[skip ci](shards): fix low concurrency issue (#1672)

Signed-off-by: peng.xu <peng.xu@zilliz.com>
pull/1680/head
XuPeng-SH 2020-03-17 09:57:36 +08:00 committed by GitHub
parent 21c7b8f09c
commit 4173626088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 6 deletions

View File

@ -1,7 +1,10 @@
import time
import enum
import json
import logging
import threading
import enum
from functools import wraps
from collections import defaultdict
from milvus import Milvus
from milvus.client.hooks import BaseSearchHook
@ -89,6 +92,26 @@ class Connection:
return self.__str__()
class Duration:
    """Wall-clock stopwatch that starts on construction and can be stopped once.

    ``start_ts`` and ``end_ts`` hold ``time.time()`` readings; ``end_ts``
    stays ``None`` until :meth:`stop` has been called.
    """

    def __init__(self):
        self.start_ts = time.time()
        self.end_ts = None

    def stop(self):
        """Record the end timestamp.

        Returns:
            True if this call stopped the timer, False if the timer had
            already been stopped (the first end timestamp is kept).
        """
        # Compare with None explicitly: a falsy timestamp such as 0.0 must
        # still count as "already stopped".
        if self.end_ts is not None:
            return False
        self.end_ts = time.time()
        return True

    @property
    def value(self):
        """Elapsed seconds between start and stop, or None if not stopped."""
        if self.end_ts is None:
            return None
        return self.end_ts - self.start_ts
class ProxyMixin:
def __getattr__(self, name):
target = self.__dict__.get(name, None)
@ -101,6 +124,7 @@ class ScopedConnection(ProxyMixin):
def __init__(self, pool, connection):
    """Store ``pool`` and ``connection`` and start a ``Duration`` timer.

    The ``Duration`` takes its start timestamp when it is constructed, so
    timing begins as soon as this wrapper is created.
    """
    self.pool, self.connection = pool, connection
    self.duration = Duration()
def __del__(self):
self.release()
@ -112,6 +136,8 @@ class ScopedConnection(ProxyMixin):
if not self.pool or not self.connection:
return
self.pool.release(self.connection)
self.duration.stop()
self.pool.record_duration(self.connection, self.duration)
self.pool = None
self.connection = None
@ -127,6 +153,30 @@ class ConnectionPool(topology.TopoObject):
self.max_retry = max_retry
self.kwargs = kwargs
self.cv = threading.Condition()
self.durations = defaultdict(list)
def record_duration(self, conn, duration):
if len(self.durations[conn]) >= 10000:
self.durations[conn].pop(0)
self.durations[conn].append(duration)
def stats(self):
    """Summarise the durations recorded for each connection.

    Returns:
        dict with three keys:
          * ``'connections'``: maps ``id(conn)`` to a dict holding
            ``'total_time'`` (sum of the finished durations) and
            ``'called_times'`` (number of recorded durations).
          * ``'max-time'``: the largest per-connection ``total_time``,
            or 0 when nothing has been recorded yet.
          * ``'num'``: number of connections with recorded durations.

    The summary is also written to the debug log as indented JSON.
    """
    connections = {}
    take_time = []
    for conn, durations in self.durations.items():
        # A duration that was never stopped reports None; leave it out of
        # the sum instead of letting sum() raise TypeError.
        elapsed = [d.value for d in durations]
        total_time = sum(v for v in elapsed if v is not None)
        connections[id(conn)] = {
            'total_time': total_time,
            'called_times': len(durations)
        }
        take_time.append(total_time)

    out = {'connections': connections}
    # max() on an empty sequence raises ValueError, which used to break the
    # stats call whenever no duration had been recorded yet.
    out['max-time'] = max(take_time, default=0)
    out['num'] = len(self.durations)
    logger.debug(json.dumps(out, indent=2))
    return out
def __len__(self):
return len(self.pending_pool) + len(self.active_pool)
@ -152,7 +202,7 @@ class ConnectionPool(topology.TopoObject):
if timeout_times >= 1:
return connection
# logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
# logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
if len(self.pending_pool) == 0:
connection = self.create()
else:
@ -180,6 +230,13 @@ class ConnectionGroup(topology.TopoGroup):
def __init__(self, name):
    """Create the group; ``name`` is forwarded unchanged to the parent initialiser."""
    super().__init__(name)
def stats(self):
out = {}
for name, item in self.items.items():
out[name] = item.stats()
return out
def on_pre_add(self, topo_object):
conn = topo_object.fetch()
conn.on_connect(metadata=None)
@ -209,6 +266,13 @@ class ConnectionTopology(topology.Topology):
def __init__(self):
    """Create the topology by running the parent initialiser with no arguments."""
    super().__init__()
def stats(self):
out = {}
for name, group in self.topo_groups.items():
out[name] = group.stats()
return out
def create(self, name):
group = ConnectionGroup(name)
status = self.add_group(group)

View File

@ -13,6 +13,7 @@ class RouterMixin:
conn = self.writable_topo.get_group('default').get('WOSERVER').fetch()
if conn:
conn.on_connect(metadata=metadata)
# PXU TODO: should return conn
return conn.conn
def query_conn(self, name, metadata=None):
@ -20,4 +21,4 @@ class RouterMixin:
if not conn:
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
conn.on_connect(metadata=metadata)
return conn.conn
return conn

View File

@ -1,6 +1,7 @@
import logging
import time
import datetime
import json
from collections import defaultdict
import multiprocessing
@ -142,7 +143,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
with self.tracer.start_span('search_{}'.format(addr),
child_of=span):
ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
ret = conn.conn.search_vectors_in_files(table_name=query_params['table_id'],
file_ids=query_params['file_ids'],
query_records=vectors,
top_k=topk,
@ -440,6 +441,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
metadata = {'resp_class': milvus_pb2.StringReply}
if _cmd == 'conn_stats':
stats = self.router.readonly_topo.stats()
return milvus_pb2.StringReply(status=status_pb2.Status(
error_code=status_pb2.SUCCESS),
string_reply=json.dumps(stats, indent=2))
if _cmd == 'version':
_status, _reply = self._get_server_version(metadata=metadata)
else:

View File

@ -265,7 +265,7 @@ class TestServer:
param['nprobe'] = 2048
RouterMixin.connection = mock.MagicMock(return_value=Milvus())
RouterMixin.query_conn = mock.MagicMock(return_value=Milvus())
RouterMixin.query_conn.conn = mock.MagicMock(return_value=Milvus())
Milvus.describe_table = mock.MagicMock(return_value=(BAD,
table_schema))
status, ret = self.client.search_vectors(**param)

View File

@ -30,7 +30,7 @@ requests-oauthlib==1.2.0
rsa==4.0
six==1.12.0
SQLAlchemy==1.3.5
urllib3==1.25.8
urllib3==1.25.3
jaeger-client>=3.4.0
grpcio-opentracing>=1.0
mock==2.0.0