diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py index d13987c70b..fa9c52eddc 100644 --- a/shards/mishards/connections.py +++ b/shards/mishards/connections.py @@ -1,7 +1,10 @@ +import time +import enum +import json import logging import threading -import enum from functools import wraps +from collections import defaultdict from milvus import Milvus from milvus.client.hooks import BaseSearchHook @@ -89,6 +92,26 @@ class Connection: return self.__str__() +class Duration: + def __init__(self): + self.start_ts = time.time() + self.end_ts = None + + def stop(self): + if self.end_ts: + return False + + self.end_ts = time.time() + return True + + @property + def value(self): + if not self.end_ts: + return None + + return self.end_ts - self.start_ts + + class ProxyMixin: def __getattr__(self, name): target = self.__dict__.get(name, None) @@ -101,6 +124,7 @@ class ScopedConnection(ProxyMixin): def __init__(self, pool, connection): self.pool = pool self.connection = connection + self.duration = Duration() def __del__(self): self.release() @@ -112,6 +136,8 @@ class ScopedConnection(ProxyMixin): if not self.pool or not self.connection: return self.pool.release(self.connection) + self.duration.stop() + self.pool.record_duration(self.connection, self.duration) self.pool = None self.connection = None @@ -127,6 +153,30 @@ class ConnectionPool(topology.TopoObject): self.max_retry = max_retry self.kwargs = kwargs self.cv = threading.Condition() + self.durations = defaultdict(list) + + def record_duration(self, conn, duration): + if len(self.durations[conn]) >= 10000: + self.durations[conn].pop(0) + + self.durations[conn].append(duration) + + def stats(self): + out = {'connections': {}} + connections = out['connections'] + take_time = [] + for conn, durations in self.durations.items(): + total_time = sum(d.value for d in durations) + connections[id(conn)] = { + 'total_time': total_time, + 'called_times': len(durations) + } + take_time.append(total_time) + + out['max-time'] = max(take_time) + out['num'] = len(self.durations) + logger.debug(json.dumps(out, indent=2)) + return out def __len__(self): return len(self.pending_pool) + len(self.active_pool) @@ -152,7 +202,7 @@ class ConnectionPool(topology.TopoObject): if timeout_times >= 1: return connection - # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) + # logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) if len(self.pending_pool) == 0: connection = self.create() else: @@ -180,6 +230,13 @@ class ConnectionGroup(topology.TopoGroup): def __init__(self, name): super().__init__(name) + def stats(self): + out = {} + for name, item in self.items.items(): + out[name] = item.stats() + + return out + def on_pre_add(self, topo_object): conn = topo_object.fetch() conn.on_connect(metadata=None) @@ -209,6 +266,13 @@ class ConnectionTopology(topology.Topology): def __init__(self): super().__init__() + def stats(self): + out = {} + for name, group in self.topo_groups.items(): + out[name] = group.stats() + + return out + def create(self, name): group = ConnectionGroup(name) status = self.add_group(group) diff --git a/shards/mishards/router/__init__.py b/shards/mishards/router/__init__.py index e435ea3cc0..2567682fda 100644 --- a/shards/mishards/router/__init__.py +++ b/shards/mishards/router/__init__.py @@ -13,6 +13,7 @@ class RouterMixin: conn = self.writable_topo.get_group('default').get('WOSERVER').fetch() if conn: conn.on_connect(metadata=metadata) + # PXU TODO: should return conn return conn.conn def query_conn(self, name, metadata=None): @@ -20,4 +21,4 @@ class RouterMixin: if not conn: raise exceptions.ConnectionNotFoundError(name, metadata=metadata) conn.on_connect(metadata=metadata) - return conn.conn + return conn diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py index 66942269fd..5612af5333 100644 --- a/shards/mishards/service_handler.py +++ b/shards/mishards/service_handler.py @@ -1,6 +1,7 @@ import logging import time import datetime +import json from collections import defaultdict import multiprocessing @@ -142,7 +143,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): with self.tracer.start_span('search_{}'.format(addr), child_of=span): - ret = conn.search_vectors_in_files(table_name=query_params['table_id'], + ret = conn.conn.search_vectors_in_files(table_name=query_params['table_id'], file_ids=query_params['file_ids'], query_records=vectors, top_k=topk, @@ -440,6 +441,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): metadata = {'resp_class': milvus_pb2.StringReply} + if _cmd == 'conn_stats': + stats = self.router.readonly_topo.stats() + return milvus_pb2.StringReply(status=status_pb2.Status( + error_code=status_pb2.SUCCESS), + string_reply=json.dumps(stats, indent=2)) + if _cmd == 'version': _status, _reply = self._get_server_version(metadata=metadata) else: diff --git a/shards/mishards/test_server.py b/shards/mishards/test_server.py index b90cdf7875..b7a3ad370d 100644 --- a/shards/mishards/test_server.py +++ b/shards/mishards/test_server.py @@ -265,7 +265,7 @@ class TestServer: param['nprobe'] = 2048 RouterMixin.connection = mock.MagicMock(return_value=Milvus()) - RouterMixin.query_conn = mock.MagicMock(return_value=Milvus()) + RouterMixin.query_conn.conn = mock.MagicMock(return_value=Milvus()) Milvus.describe_table = mock.MagicMock(return_value=(BAD, table_schema)) status, ret = self.client.search_vectors(**param) diff --git a/shards/requirements.txt b/shards/requirements.txt index 8f4667f34d..47e1e521c2 100644 --- a/shards/requirements.txt +++ b/shards/requirements.txt @@ -30,7 +30,7 @@ requests-oauthlib==1.2.0 rsa==4.0 six==1.12.0 SQLAlchemy==1.3.5 -urllib3==1.25.8 +urllib3==1.25.3 jaeger-client>=3.4.0 grpcio-opentracing>=1.0 mock==2.0.0