[skip ci](shards): fix low concurrency issue (#1672)

Signed-off-by: peng.xu <peng.xu@zilliz.com>
pull/1680/head
XuPeng-SH 2020-03-17 09:57:36 +08:00 committed by GitHub
parent 21c7b8f09c
commit 4173626088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 6 deletions

View File

@ -1,7 +1,10 @@
import time
import enum
import json
import logging import logging
import threading import threading
import enum
from functools import wraps from functools import wraps
from collections import defaultdict
from milvus import Milvus from milvus import Milvus
from milvus.client.hooks import BaseSearchHook from milvus.client.hooks import BaseSearchHook
@ -89,6 +92,26 @@ class Connection:
return self.__str__() return self.__str__()
class Duration:
    """Wall-clock stopwatch that starts on construction and stops once.

    ``start_ts`` is captured with ``time.time()`` when the instance is
    created; ``end_ts`` stays ``None`` until ``stop()`` is called.
    """

    def __init__(self):
        self.start_ts = time.time()
        self.end_ts = None

    def stop(self):
        """Record the end timestamp.

        Returns:
            True if this call stopped the stopwatch, False if it had
            already been stopped (the first end timestamp is kept).
        """
        # Compare against None explicitly so that a set-but-falsy end
        # timestamp (0.0) is still recognised as "already stopped".
        if self.end_ts is not None:
            return False
        self.end_ts = time.time()
        return True

    @property
    def value(self):
        """Elapsed seconds between start and stop, or None if still running."""
        if self.end_ts is None:
            return None
        return self.end_ts - self.start_ts
class ProxyMixin: class ProxyMixin:
def __getattr__(self, name): def __getattr__(self, name):
target = self.__dict__.get(name, None) target = self.__dict__.get(name, None)
@ -101,6 +124,7 @@ class ScopedConnection(ProxyMixin):
def __init__(self, pool, connection): def __init__(self, pool, connection):
self.pool = pool self.pool = pool
self.connection = connection self.connection = connection
self.duration = Duration()
def __del__(self): def __del__(self):
self.release() self.release()
@ -112,6 +136,8 @@ class ScopedConnection(ProxyMixin):
if not self.pool or not self.connection: if not self.pool or not self.connection:
return return
self.pool.release(self.connection) self.pool.release(self.connection)
self.duration.stop()
self.pool.record_duration(self.connection, self.duration)
self.pool = None self.pool = None
self.connection = None self.connection = None
@ -127,6 +153,30 @@ class ConnectionPool(topology.TopoObject):
self.max_retry = max_retry self.max_retry = max_retry
self.kwargs = kwargs self.kwargs = kwargs
self.cv = threading.Condition() self.cv = threading.Condition()
self.durations = defaultdict(list)
def record_duration(self, conn, duration, max_records=10000):
    """Append ``duration`` to the history kept for ``conn``.

    The per-connection history in ``self.durations`` is bounded: before
    appending, the oldest entries are dropped so that at most
    ``max_records`` entries remain afterwards.

    Args:
        conn: Key identifying the connection in ``self.durations``.
        duration: The record to store for that connection.
        max_records: Maximum number of records kept per connection
            (expected to be >= 1). Defaults to 10000.
    """
    records = self.durations[conn]
    # Number of old entries to discard so the list holds at most
    # ``max_records`` items once the new one has been appended.
    overflow = len(records) - max_records + 1
    if overflow > 0:
        del records[:overflow]
    records.append(duration)
def stats(self):
    """Summarise the durations recorded in ``self.durations``.

    Returns:
        A dict with three keys:
        ``'connections'`` maps ``id(conn)`` to that connection's
        ``'total_time'`` (sum of its duration values) and
        ``'called_times'`` (number of recorded durations);
        ``'max-time'`` is the largest per-connection total, or 0 when
        nothing has been recorded yet;
        ``'num'`` is the number of connections with recorded durations.

    The summary is also emitted as indented JSON at debug level.
    """
    connections = {}
    take_time = []
    for conn, durations in self.durations.items():
        total_time = sum(d.value for d in durations)
        connections[id(conn)] = {
            'total_time': total_time,
            'called_times': len(durations)
        }
        take_time.append(total_time)

    out = {
        'connections': connections,
        # max() on an empty sequence raises ValueError; fall back to 0
        # so stats can be queried before any duration was recorded.
        'max-time': max(take_time, default=0),
        'num': len(self.durations),
    }
    logger.debug(json.dumps(out, indent=2))
    return out
def __len__(self): def __len__(self):
return len(self.pending_pool) + len(self.active_pool) return len(self.pending_pool) + len(self.active_pool)
@ -152,7 +202,7 @@ class ConnectionPool(topology.TopoObject):
if timeout_times >= 1: if timeout_times >= 1:
return connection return connection
# logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) # logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
if len(self.pending_pool) == 0: if len(self.pending_pool) == 0:
connection = self.create() connection = self.create()
else: else:
@ -180,6 +230,13 @@ class ConnectionGroup(topology.TopoGroup):
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
def stats(self):
    """Return a dict mapping each entry name in ``self.items`` to the
    result of that entry's own ``stats()`` call."""
    return {key: member.stats() for key, member in self.items.items()}
def on_pre_add(self, topo_object): def on_pre_add(self, topo_object):
conn = topo_object.fetch() conn = topo_object.fetch()
conn.on_connect(metadata=None) conn.on_connect(metadata=None)
@ -209,6 +266,13 @@ class ConnectionTopology(topology.Topology):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def stats(self):
    """Return a dict mapping each group name in ``self.topo_groups`` to
    the result of that group's own ``stats()`` call."""
    return {key: grp.stats() for key, grp in self.topo_groups.items()}
def create(self, name): def create(self, name):
group = ConnectionGroup(name) group = ConnectionGroup(name)
status = self.add_group(group) status = self.add_group(group)

View File

@ -13,6 +13,7 @@ class RouterMixin:
conn = self.writable_topo.get_group('default').get('WOSERVER').fetch() conn = self.writable_topo.get_group('default').get('WOSERVER').fetch()
if conn: if conn:
conn.on_connect(metadata=metadata) conn.on_connect(metadata=metadata)
# PXU TODO: should return conn
return conn.conn return conn.conn
def query_conn(self, name, metadata=None): def query_conn(self, name, metadata=None):
@ -20,4 +21,4 @@ class RouterMixin:
if not conn: if not conn:
raise exceptions.ConnectionNotFoundError(name, metadata=metadata) raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
conn.on_connect(metadata=metadata) conn.on_connect(metadata=metadata)
return conn.conn return conn

View File

@ -1,6 +1,7 @@
import logging import logging
import time import time
import datetime import datetime
import json
from collections import defaultdict from collections import defaultdict
import multiprocessing import multiprocessing
@ -142,7 +143,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
with self.tracer.start_span('search_{}'.format(addr), with self.tracer.start_span('search_{}'.format(addr),
child_of=span): child_of=span):
ret = conn.search_vectors_in_files(table_name=query_params['table_id'], ret = conn.conn.search_vectors_in_files(table_name=query_params['table_id'],
file_ids=query_params['file_ids'], file_ids=query_params['file_ids'],
query_records=vectors, query_records=vectors,
top_k=topk, top_k=topk,
@ -440,6 +441,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
metadata = {'resp_class': milvus_pb2.StringReply} metadata = {'resp_class': milvus_pb2.StringReply}
if _cmd == 'conn_stats':
stats = self.router.readonly_topo.stats()
return milvus_pb2.StringReply(status=status_pb2.Status(
error_code=status_pb2.SUCCESS),
string_reply=json.dumps(stats, indent=2))
if _cmd == 'version': if _cmd == 'version':
_status, _reply = self._get_server_version(metadata=metadata) _status, _reply = self._get_server_version(metadata=metadata)
else: else:

View File

@ -265,7 +265,7 @@ class TestServer:
param['nprobe'] = 2048 param['nprobe'] = 2048
RouterMixin.connection = mock.MagicMock(return_value=Milvus()) RouterMixin.connection = mock.MagicMock(return_value=Milvus())
RouterMixin.query_conn = mock.MagicMock(return_value=Milvus()) RouterMixin.query_conn.conn = mock.MagicMock(return_value=Milvus())
Milvus.describe_table = mock.MagicMock(return_value=(BAD, Milvus.describe_table = mock.MagicMock(return_value=(BAD,
table_schema)) table_schema))
status, ret = self.client.search_vectors(**param) status, ret = self.client.search_vectors(**param)

View File

@ -30,7 +30,7 @@ requests-oauthlib==1.2.0
rsa==4.0 rsa==4.0
six==1.12.0 six==1.12.0
SQLAlchemy==1.3.5 SQLAlchemy==1.3.5
urllib3==1.25.8 urllib3==1.25.3
jaeger-client>=3.4.0 jaeger-client>=3.4.0
grpcio-opentracing>=1.0 grpcio-opentracing>=1.0
mock==2.0.0 mock==2.0.0