From 71c67f59a3b1d348c0e27c49a642bf64b0227a5a Mon Sep 17 00:00:00 2001 From: "peng.xu" Date: Mon, 14 Oct 2019 13:42:12 +0800 Subject: [PATCH] update for code style --- conftest.py | 1 + manager.py | 7 ++- mishards/__init__.py | 5 +- mishards/connections.py | 10 ++-- mishards/db_base.py | 8 ++- mishards/exception_handlers.py | 5 ++ mishards/exceptions.py | 8 +++ mishards/factories.py | 18 +++--- mishards/grpc_utils/__init__.py | 10 ++-- mishards/grpc_utils/grpc_args_wrapper.py | 4 +- mishards/hash_ring.py | 28 +++++----- mishards/main.py | 11 ++-- mishards/models.py | 15 ++--- mishards/server.py | 4 +- mishards/service_handler.py | 64 ++++++++++----------- mishards/settings.py | 10 +++- mishards/test_connections.py | 8 ++- mishards/test_models.py | 7 ++- sd/__init__.py | 1 + sd/kubernetes_provider.py | 71 +++++++++++++----------- sd/static_provider.py | 6 +- tracing/__init__.py | 13 +++-- tracing/factory.py | 12 ++-- utils/__init__.py | 1 + utils/logger_helper.py | 17 ++++-- 25 files changed, 201 insertions(+), 143 deletions(-) diff --git a/conftest.py b/conftest.py index c4fed5cc7e..d6c9f3acc7 100644 --- a/conftest.py +++ b/conftest.py @@ -4,6 +4,7 @@ from mishards import settings, db, create_app logger = logging.getLogger(__name__) + @pytest.fixture def app(request): app = create_app(settings.TestingConfig) diff --git a/manager.py b/manager.py index 31f5894d2d..931c90ebc8 100644 --- a/manager.py +++ b/manager.py @@ -2,6 +2,7 @@ import fire from mishards import db from sqlalchemy import and_ + class DBHandler: @classmethod def create_all(cls): @@ -15,9 +16,9 @@ class DBHandler: def fun(cls, tid): from mishards.factories import TablesFactory, TableFilesFactory, Tables f = db.Session.query(Tables).filter(and_( - Tables.table_id==tid, - Tables.state!=Tables.TO_DELETE) - ).first() + Tables.table_id == tid, + Tables.state != Tables.TO_DELETE) + ).first() print(f) # f1 = TableFilesFactory() diff --git a/mishards/__init__.py b/mishards/__init__.py index b351986cba..47d8adb6e3 100644 --- a/mishards/__init__.py +++ b/mishards/__init__.py @@ -1,4 +1,4 @@ -import logging +import logging from mishards import settings logger = logging.getLogger() @@ -8,6 +8,7 @@ db = DB() from mishards.server import Server grpc_server = Server() + def create_app(testing_config=None): config = testing_config if testing_config else settings.DefaultConfig db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO) @@ -24,7 +25,7 @@ def create_app(testing_config=None): from tracing.factory import TracerFactory from mishards.grpc_utils import GrpcSpanDecorator tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig, - span_decorator=GrpcSpanDecorator()) + span_decorator=GrpcSpanDecorator()) grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover) diff --git a/mishards/connections.py b/mishards/connections.py index 22524c3a20..ccd8e7e81b 100644 --- a/mishards/connections.py +++ b/mishards/connections.py @@ -10,6 +10,7 @@ from utils import singleton logger = logging.getLogger(__name__) + class Connection: def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs): self.name = name @@ -55,7 +56,7 @@ class Connection: if not self.can_retry and not self.connected: raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry, - metadata=metadata)) + metadata=metadata)) self.retried = 0 @@ -72,6 +73,7 @@ class Connection: raise e return inner + @singleton class ConnectionMgr: def __init__(self): @@ -90,10 +92,10 @@ class ConnectionMgr: if not throw: return None raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name), - metadata=metadata) + metadata=metadata) this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY) threaded = { - threading.get_ident() : this_conn + threading.get_ident(): this_conn } self.conns[name] = threaded return this_conn @@ -106,7 +108,7 @@ class ConnectionMgr: if not throw: return None raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name), - metadata=metadata) + metadata=metadata) this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY) c[tid] = this_conn return this_conn diff --git a/mishards/db_base.py b/mishards/db_base.py index b1492aa8f5..6fb3aef4e1 100644 --- a/mishards/db_base.py +++ b/mishards/db_base.py @@ -14,8 +14,10 @@ class LocalSession(SessionBase): bind = options.pop('bind', None) or db.engine SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options) + class DB: Model = declarative_base() + def __init__(self, uri=None, echo=False): self.echo = echo uri and self.init_db(uri, echo) @@ -27,9 +29,9 @@ class DB: self.engine = create_engine(url) else: self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30, - pool_pre_ping=True, - echo=echo, - max_overflow=0) + pool_pre_ping=True, + echo=echo, + max_overflow=0) self.uri = uri self.url = url diff --git a/mishards/exception_handlers.py b/mishards/exception_handlers.py index 16ba34a3b1..1e5ffb3529 100644 --- a/mishards/exception_handlers.py +++ b/mishards/exception_handlers.py @@ -4,6 +4,7 @@ from mishards import grpc_server as server, exceptions logger = logging.getLogger(__name__) + def resp_handler(err, error_code): if not isinstance(err, exceptions.BaseException): return status_pb2.Status(error_code=error_code, reason=str(err)) @@ -50,21 +51,25 @@ def resp_handler(err, error_code): status.error_code = status_pb2.UNEXPECTED_ERROR return status + @server.errorhandler(exceptions.TableNotFoundError) def TableNotFoundErrorHandler(err): logger.error(err) return resp_handler(err, status_pb2.TABLE_NOT_EXISTS) + @server.errorhandler(exceptions.InvalidArgumentError) def InvalidArgumentErrorHandler(err): logger.error(err) return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT) + @server.errorhandler(exceptions.DBError) def DBErrorHandler(err): logger.error(err) return resp_handler(err, status_pb2.UNEXPECTED_ERROR) + @server.errorhandler(exceptions.InvalidRangeError) def InvalidArgumentErrorHandler(err): logger.error(err) diff --git a/mishards/exceptions.py b/mishards/exceptions.py index 2aa2b39eb9..acd9372d6a 100644 --- a/mishards/exceptions.py +++ b/mishards/exceptions.py @@ -1,26 +1,34 @@ import mishards.exception_codes as codes + class BaseException(Exception): code = codes.INVALID_CODE message = 'BaseException' + def __init__(self, message='', metadata=None): self.message = self.__class__.__name__ if not message else message self.metadata = metadata + class ConnectionConnectError(BaseException): code = codes.CONNECT_ERROR_CODE + class ConnectionNotFoundError(BaseException): code = codes.CONNECTTION_NOT_FOUND_CODE + class DBError(BaseException): code = codes.DB_ERROR_CODE + class TableNotFoundError(BaseException): code = codes.TABLE_NOT_FOUND_CODE + class InvalidArgumentError(BaseException): code = codes.INVALID_ARGUMENT_CODE + class InvalidRangeError(BaseException): code = codes.INVALID_DATE_RANGE_CODE diff --git a/mishards/factories.py b/mishards/factories.py index 26e9ab2619..c4037fe2d7 100644 --- a/mishards/factories.py +++ b/mishards/factories.py @@ -9,13 +9,16 @@ from faker.providers import BaseProvider from mishards import db from mishards.models import Tables, TableFiles + class FakerProvider(BaseProvider): def this_date(self): t = datetime.datetime.today() - return (t.year - 1900) * 10000 + (t.month-1)*100 + t.day + return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day + factory.Faker.add_provider(FakerProvider) + class TablesFactory(SQLAlchemyModelFactory): class Meta: model = Tables @@ -24,14 +27,15 @@ class TablesFactory(SQLAlchemyModelFactory): id = factory.Faker('random_number', digits=16, fix_len=True) table_id = factory.Faker('uuid4') - state = factory.Faker('random_element', elements=(0,1,2,3)) - dimension = factory.Faker('random_element', elements=(256,512)) + state = factory.Faker('random_element', elements=(0, 1, 2, 3)) + dimension = factory.Faker('random_element', elements=(256, 512)) created_on = int(time.time()) index_file_size = 0 - engine_type = factory.Faker('random_element', elements=(0,1,2,3)) - metric_type = factory.Faker('random_element', elements=(0,1)) + engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) + metric_type = factory.Faker('random_element', elements=(0, 1)) nlist = 16384 + class TableFilesFactory(SQLAlchemyModelFactory): class Meta: model = TableFiles @@ -40,9 +44,9 @@ class TableFilesFactory(SQLAlchemyModelFactory): id = factory.Faker('random_number', digits=16, fix_len=True) table = factory.SubFactory(TablesFactory) - engine_type = factory.Faker('random_element', elements=(0,1,2,3)) + engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) file_id = factory.Faker('uuid4') - file_type = factory.Faker('random_element', elements=(0,1,2,3,4)) + file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4)) file_size = factory.Faker('random_number') updated_time = int(time.time()) created_on = int(time.time()) diff --git a/mishards/grpc_utils/__init__.py b/mishards/grpc_utils/__init__.py index 550913ed60..f5225b2a66 100644 --- a/mishards/grpc_utils/__init__.py +++ b/mishards/grpc_utils/__init__.py @@ -14,21 +14,23 @@ class GrpcSpanDecorator(SpanDecorator): status = rpc_info.response.status except Exception as e: status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, - reason='Should not happen') + reason='Should not happen') if status.error_code == 0: return error_log = {'event': 'error', - 'request': rpc_info.request, - 'response': rpc_info.response - } + 'request': rpc_info.request, + 'response': rpc_info.response + } span.set_tag('error', True) span.log_kv(error_log) + def mark_grpc_method(func): setattr(func, 'grpc_method', True) return func + def is_grpc_method(func): if not func: return False diff --git a/mishards/grpc_utils/grpc_args_wrapper.py b/mishards/grpc_utils/grpc_args_wrapper.py index a864b1e400..7447dbd995 100644 --- a/mishards/grpc_utils/grpc_args_wrapper.py +++ b/mishards/grpc_utils/grpc_args_wrapper.py @@ -1,4 +1,4 @@ # class GrpcArgsWrapper(object): - # @classmethod - # def proto_TableName(cls): \ No newline at end of file +# @classmethod +# def proto_TableName(cls): diff --git a/mishards/hash_ring.py b/mishards/hash_ring.py index bfec108c5c..a97f3f580e 100644 --- a/mishards/hash_ring.py +++ b/mishards/hash_ring.py @@ -9,8 +9,8 @@ else: import md5 md5_constructor = md5.new -class HashRing(object): +class HashRing(object): def __init__(self, nodes=None, weights=None): """`nodes` is a list of objects that have a proper __str__ representation. `weights` is dictionary that sets weights to the nodes. The default @@ -40,13 +40,13 @@ class HashRing(object): if node in self.weights: weight = self.weights.get(node) - factor = math.floor((40*len(self.nodes)*weight) / total_weight); + factor = math.floor((40 * len(self.nodes) * weight) / total_weight) for j in range(0, int(factor)): - b_key = self._hash_digest( '%s-%s' % (node, j) ) + b_key = self._hash_digest('%s-%s' % (node, j)) for i in range(0, 3): - key = self._hash_val(b_key, lambda x: x+i*4) + key = self._hash_val(b_key, lambda x: x + i * 4) self.ring[key] = node self._sorted_keys.append(key) @@ -60,7 +60,7 @@ class HashRing(object): pos = self.get_node_pos(string_key) if pos is None: return None - return self.ring[ self._sorted_keys[pos] ] + return self.ring[self._sorted_keys[pos]] def get_node_pos(self, string_key): """Given a string key a corresponding node in the hash ring is returned @@ -94,6 +94,7 @@ class HashRing(object): yield None, None returned_values = set() + def distinct_filter(value): if str(value) not in returned_values: returned_values.add(str(value)) @@ -121,10 +122,8 @@ class HashRing(object): return self._hash_val(b_key, lambda x: x) def _hash_val(self, b_key, entry_fn): - return (( b_key[entry_fn(3)] << 24) - |(b_key[entry_fn(2)] << 16) - |(b_key[entry_fn(1)] << 8) - | b_key[entry_fn(0)] ) + return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | ( + b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)] def _hash_digest(self, key): m = md5_constructor() @@ -132,12 +131,13 @@ class HashRing(object): m.update(key) return m.digest() + if __name__ == '__main__': from collections import defaultdict - servers = ['192.168.0.246:11212', - '192.168.0.247:11212', - '192.168.0.248:11212', - '192.168.0.249:11212'] + servers = [ + '192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212', + '192.168.0.249:11212' + ] ring = HashRing(servers) keys = ['{}'.format(i) for i in range(100)] @@ -146,5 +146,5 @@ if __name__ == '__main__': server = ring.get_node(k) mapped[server].append(k) - for k,v in mapped.items(): + for k, v in mapped.items(): print(k, v) diff --git a/mishards/main.py b/mishards/main.py index 5d8db0a179..3f69484ee4 100644 --- a/mishards/main.py +++ b/mishards/main.py @@ -1,13 +1,16 @@ -import os, sys +import os +import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from mishards import ( - settings, create_app) +from mishards import (settings, create_app) + def main(): - server = create_app(settings.TestingConfig if settings.TESTING else settings.DefaultConfig) + server = create_app( + settings.TestingConfig if settings.TESTING else settings.DefaultConfig) server.run(port=settings.SERVER_PORT) return 0 + if __name__ == '__main__': sys.exit(main()) diff --git a/mishards/models.py b/mishards/models.py index 0f7bb603ae..54cf5f8ed9 100644 --- a/mishards/models.py +++ b/mishards/models.py @@ -1,13 +1,14 @@ import logging from sqlalchemy import (Integer, Boolean, Text, - String, BigInteger, func, and_, or_, - Column) + String, BigInteger, func, and_, or_, + Column) from sqlalchemy.orm import relationship, backref from mishards import db logger = logging.getLogger(__name__) + class TableFiles(db.Model): FILE_TYPE_NEW = 0 FILE_TYPE_RAW = 1 @@ -57,16 +58,16 @@ class Tables(db.Model): def files_to_search(self, date_range=None): cond = or_( - TableFiles.file_type==TableFiles.FILE_TYPE_RAW, - TableFiles.file_type==TableFiles.FILE_TYPE_TO_INDEX, - TableFiles.file_type==TableFiles.FILE_TYPE_INDEX, + TableFiles.file_type == TableFiles.FILE_TYPE_RAW, + TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX, + TableFiles.file_type == TableFiles.FILE_TYPE_INDEX, ) if date_range: cond = and_( cond, or_( - and_(TableFiles.date>=d[0], TableFiles.date= d[0], TableFiles.date < d[1]) for d in date_range + ) ) files = self.files.filter(cond) diff --git a/mishards/server.py b/mishards/server.py index c044bbb7ad..032d101cba 100644 --- a/mishards/server.py +++ b/mishards/server.py @@ -33,7 +33,7 @@ class Server: self.server_impl = grpc.server( thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers), options=[(cygrpc.ChannelArgKey.max_send_message_length, -1), - (cygrpc.ChannelArgKey.max_receive_message_length, -1)] + (cygrpc.ChannelArgKey.max_receive_message_length, -1)] ) self.server_impl = self.tracer.decorate(self.server_impl) @@ -46,7 +46,7 @@ class Server: ip = socket.gethostbyname(url.hostname) socket.inet_pton(socket.AF_INET, ip) self.conn_mgr.register('WOSERVER', - '{}://{}:{}'.format(url.scheme, ip, url.port or 80)) + '{}://{}:{}'.format(url.scheme, ip, url.port or 80)) def register_pre_run_handler(self, func): logger.info('Regiterring {} into server pre_run_handlers'.format(func)) diff --git a/mishards/service_handler.py b/mishards/service_handler.py index 60d64cef37..2a1e0eef02 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 from milvus.grpc_gen.milvus_pb2 import TopKQueryResult from milvus.client.Abstract import Range -from milvus.client import types +from milvus.client import types as Types from mishards import (db, settings, exceptions) from mishards.grpc_utils import mark_grpc_method @@ -24,6 +24,7 @@ logger = logging.getLogger(__name__) class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): MAX_NPROBE = 2048 + def __init__(self, conn_mgr, tracer, *args, **kwargs): self.conn_mgr = conn_mgr self.table_meta = {} @@ -44,8 +45,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return conn.conn def _format_date(self, start, end): - return ((start.year-1900)*10000 + (start.month-1)*100 + start.day - , (end.year-1900)*10000 + (end.month-1)*100 + end.day) + return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day) def _range_to_date(self, range_obj, metadata=None): try: @@ -54,8 +54,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): assert start < end except (ValueError, AssertionError): raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format( - range_obj.start_date, range_obj.end_date - ), metadata=metadata) + range_obj.start_date, range_obj.end_date + ), metadata=metadata) return self._format_date(start, end) @@ -63,9 +63,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): # PXU TODO: Implement Thread-local Context try: table = db.Session.query(Tables).filter(and_( - Tables.table_id==table_id, - Tables.state!=Tables.TO_DELETE - )).first() + Tables.table_id == table_id, + Tables.state != Tables.TO_DELETE + )).first() except sqlalchemy_exc.SQLAlchemyError as e: raise exceptions.DBError(message=str(e), metadata=metadata) @@ -93,7 +93,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return routing def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs): - status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success") + status = status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success") if not files_n_topk_results: return status, [] @@ -107,7 +107,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): for request_pos, each_request_results in enumerate(files_collection.topk_query_result): request_results[request_pos].extend(each_request_results.query_result_arrays) request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance, - reverse=reverse)[:topk] + reverse=reverse)[:topk] calc_time = time.time() - calc_time logger.info('Merge takes {}'.format(calc_time)) @@ -127,7 +127,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): routing = {} with self.tracer.start_span('get_routing', - child_of=context.get_active_span().context): + child_of=context.get_active_span().context): routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata) logger.info('Routing: {}'.format(routing)) @@ -140,28 +140,28 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): def search(addr, query_params, vectors, topk, nprobe, **kwargs): logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format( - addr, query_params, len(vectors), topk, nprobe - )) + addr, query_params, len(vectors), topk, nprobe + )) conn = self.query_conn(addr, metadata=metadata) start = time.time() span = kwargs.get('span', None) span = span if span else context.get_active_span().context with self.tracer.start_span('search_{}'.format(addr), - child_of=context.get_active_span().context): + child_of=context.get_active_span().context): ret = conn.search_vectors_in_files(table_name=query_params['table_id'], - file_ids=query_params['file_ids'], - query_records=vectors, - top_k=topk, - nprobe=nprobe, - lazy=True) + file_ids=query_params['file_ids'], + query_records=vectors, + top_k=topk, + nprobe=nprobe, + lazy=True) end = time.time() logger.info('search_vectors_in_files takes: {}'.format(end - start)) all_topk_results.append(ret) with self.tracer.start_span('do_search', - child_of=context.get_active_span().context) as span: + child_of=context.get_active_span().context) as span: with ThreadPoolExecutor(max_workers=workers) as pool: for addr, params in routing.items(): res = pool.submit(search, addr, params, vectors, topk, nprobe, span=span) @@ -170,9 +170,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): for res in rs: res.result() - reverse = table_meta.metric_type == types.MetricType.IP + reverse = table_meta.metric_type == Types.MetricType.IP with self.tracer.start_span('do_merge', - child_of=context.get_active_span().context): + child_of=context.get_active_span().context): return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata) @mark_grpc_method @@ -201,8 +201,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('HasTable {}'.format(_table_name)) _bool = self.connection(metadata={ - 'resp_class': milvus_pb2.BoolReply - }).has_table(_table_name) + 'resp_class': milvus_pb2.BoolReply + }).has_table(_table_name) return milvus_pb2.BoolReply( status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"), @@ -244,7 +244,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' _status, _ids = self.connection(metadata={ 'resp_class': milvus_pb2.VectorIds - }).add_vectors(None, None, insert_param=request) + }).add_vectors(None, None, insert_param=request) return milvus_pb2.VectorIds( status=status_pb2.Status(error_code=_status.code, reason=_status.message), vector_id_array=_ids @@ -266,7 +266,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): if nprobe > self.MAX_NPROBE or nprobe <= 0: raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe), - metadata=metadata) + metadata=metadata) table_meta = self.table_meta.get(table_name, None) @@ -332,8 +332,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): ) return milvus_pb2.TableSchema( - table_name=_table_name, - status=status_pb2.Status(error_code=_status.code, reason=_status.message), + table_name=_table_name, + status=status_pb2.Status(error_code=_status.code, reason=_status.message), ) @mark_grpc_method @@ -391,8 +391,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): _status, _results = self.connection(metadata=metadata).show_tables() return milvus_pb2.TableNameList( - status=status_pb2.Status(error_code=_status.code, reason=_status.message), - table_names=_results + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + table_names=_results ) @mark_grpc_method @@ -426,7 +426,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): if not _status.OK(): return milvus_pb2.IndexParam( - status=status_pb2.Status(error_code=_status.code, reason=_status.message) + status=status_pb2.Status(error_code=_status.code, reason=_status.message) ) metadata = { @@ -439,7 +439,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): _index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist) return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message), - table_name=_table_name, index=_index) + table_name=_table_name, index=_index) @mark_grpc_method def DropIndex(self, request, context): diff --git a/mishards/settings.py b/mishards/settings.py index f5028cbbc7..4563538a08 100644 --- a/mishards/settings.py +++ b/mishards/settings.py @@ -39,13 +39,15 @@ if SD_PROVIDER == 'Kubernetes': elif SD_PROVIDER == 'Static': from sd.static_provider import StaticProviderSettings SD_PROVIDER_SETTINGS = StaticProviderSettings( - hosts=env.list('SD_STATIC_HOSTS', []) - ) + hosts=env.list('SD_STATIC_HOSTS', []) + ) TESTING = env.bool('TESTING', False) TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530') TRACING_TYPE = env.str('TRACING_TYPE', '') + + class TracingConfig: TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards') TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True) @@ -54,7 +56,7 @@ class TracingConfig: 'sampler': { 'type': env.str('TRACING_SAMPLER_TYPE', 'const'), 'param': env.str('TRACING_SAMPLER_PARAM', "1"), - }, + }, 'local_agent': { 'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'), 'reporting_port': env.str('TRACING_REPORTING_PORT', '5775') @@ -62,10 +64,12 @@ class TracingConfig: 'logging': env.bool('TRACING_LOGGING', True) } + class DefaultConfig: SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI') SQL_ECHO = env.bool('SQL_ECHO', False) + TESTING = env.bool('TESTING', False) if TESTING: class TestingConfig(DefaultConfig): diff --git a/mishards/test_connections.py b/mishards/test_connections.py index 1f46b60f8b..f1c54f0c61 100644 --- a/mishards/test_connections.py +++ b/mishards/test_connections.py @@ -6,6 +6,7 @@ from mishards import exceptions logger = logging.getLogger(__name__) + @pytest.mark.usefixtures('app') class TestConnection: def test_manager(self): @@ -30,8 +31,10 @@ class TestConnection: class Conn: def __init__(self, state): self.state = state + def connect(self, uri): return self.state + def connected(self): return self.state FAIL_CONN = Conn(False) @@ -48,6 +51,7 @@ class TestConnection: class Func(): def __init__(self): self.executed = False + def __call__(self): self.executed = True @@ -55,8 +59,8 @@ class TestConnection: RetryObj = Retry() c = Connection('client', uri='', - max_retry=max_retry, - on_retry_func=RetryObj) + max_retry=max_retry, + on_retry_func=RetryObj) c.conn = FAIL_CONN ff = Func() this_connect = c.connect(func=ff) diff --git a/mishards/test_models.py b/mishards/test_models.py index 85dcc246aa..d60b62713e 100644 --- a/mishards/test_models.py +++ b/mishards/test_models.py @@ -3,12 +3,13 @@ import pytest from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory from mishards import db, create_app, settings from mishards.factories import ( - Tables, TableFiles, - TablesFactory, TableFilesFactory - ) + Tables, TableFiles, + TablesFactory, TableFilesFactory +) logger = logging.getLogger(__name__) + @pytest.mark.usefixtures('app') class TestModels: def test_files_to_search(self): diff --git a/sd/__init__.py b/sd/__init__.py index 6dfba5ddc1..7943887d0f 100644 --- a/sd/__init__.py +++ b/sd/__init__.py @@ -24,4 +24,5 @@ class ProviderManager: def get_provider(cls, name): return cls.PROVIDERS.get(name, None) + from sd import kubernetes_provider, static_provider diff --git a/sd/kubernetes_provider.py b/sd/kubernetes_provider.py index 51665a0cb5..924f1fc8a4 100644 --- a/sd/kubernetes_provider.py +++ b/sd/kubernetes_provider.py @@ -1,4 +1,5 @@ -import os, sys +import os +import sys if __name__ == '__main__': sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -71,7 +72,6 @@ class K8SHeartbeatHandler(threading.Thread, K8SMixin): self.queue.put(event_message) - except Exception as exc: logger.error(exc) @@ -98,18 +98,18 @@ class K8SEventListener(threading.Thread, K8SMixin): resource_version = '' w = watch.Watch() for event in w.stream(self.v1.list_namespaced_event, namespace=self.namespace, - field_selector='involvedObject.kind=Pod'): + field_selector='involvedObject.kind=Pod'): if self.terminate: break resource_version = int(event['object'].metadata.resource_version) info = dict( - eType='WatchEvent', - pod=event['object'].involved_object.name, - reason=event['object'].reason, - message=event['object'].message, - start_up=self.at_start_up, + eType='WatchEvent', + pod=event['object'].involved_object.name, + reason=event['object'].reason, + message=event['object'].message, + start_up=self.at_start_up, ) self.at_start_up = False # logger.info('Received event: {}'.format(info)) @@ -135,7 +135,7 @@ class EventHandler(threading.Thread): def on_pod_started(self, event, **kwargs): try_cnt = 3 pod = None - while try_cnt > 0: + while try_cnt > 0: try_cnt -= 1 try: pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], namespace=self.namespace) @@ -203,6 +203,7 @@ class EventHandler(threading.Thread): except queue.Empty: continue + class KubernetesProviderSettings: def __init__(self, namespace, pod_patt, label_selector, in_cluster, poll_interval, **kwargs): self.namespace = namespace @@ -211,10 +212,12 @@ class KubernetesProviderSettings: self.in_cluster = in_cluster self.poll_interval = poll_interval + @singleton @ProviderManager.register_service_provider class KubernetesProvider(object): NAME = 'Kubernetes' + def __init__(self, settings, conn_mgr, **kwargs): self.namespace = settings.namespace self.pod_patt = settings.pod_patt @@ -233,27 +236,27 @@ class KubernetesProvider(object): self.v1 = client.CoreV1Api() self.listener = K8SEventListener( - message_queue=self.queue, - namespace=self.namespace, - in_cluster=self.in_cluster, - v1=self.v1, - **kwargs - ) + message_queue=self.queue, + namespace=self.namespace, + in_cluster=self.in_cluster, + v1=self.v1, + **kwargs + ) self.pod_heartbeater = K8SHeartbeatHandler( - message_queue=self.queue, - namespace=self.namespace, - label_selector=self.label_selector, - in_cluster=self.in_cluster, - v1=self.v1, - poll_interval=self.poll_interval, - **kwargs - ) + message_queue=self.queue, + namespace=self.namespace, + label_selector=self.label_selector, + in_cluster=self.in_cluster, + v1=self.v1, + poll_interval=self.poll_interval, + **kwargs + ) self.event_handler = EventHandler(mgr=self, - message_queue=self.queue, - namespace=self.namespace, - pod_patt=self.pod_patt, **kwargs) + message_queue=self.queue, + namespace=self.namespace, + pod_patt=self.pod_patt, **kwargs) def add_pod(self, name, ip): self.conn_mgr.register(name, 'tcp://{}:19530'.format(ip)) @@ -276,9 +279,11 @@ class KubernetesProvider(object): if __name__ == '__main__': logging.basicConfig(level=logging.INFO) + class Connect: def register(self, name, value): logger.error('Register: {} - {}'.format(name, value)) + def unregister(self, name): logger.error('Unregister: {}'.format(name)) @@ -289,16 +294,16 @@ if __name__ == '__main__': connect_mgr = Connect() settings = KubernetesProviderSettings( - namespace='xp', - pod_patt=".*-ro-servers-.*", - label_selector='tier=ro-servers', - poll_interval=5, - in_cluster=False) + namespace='xp', + pod_patt=".*-ro-servers-.*", + label_selector='tier=ro-servers', + poll_interval=5, + in_cluster=False) provider_class = ProviderManager.get_provider('Kubernetes') t = provider_class(conn_mgr=connect_mgr, - settings=settings - ) + settings=settings + ) t.start() cnt = 100 while cnt > 0: diff --git a/sd/static_provider.py b/sd/static_provider.py index 423d6c4d60..5c97c4efd0 100644 --- a/sd/static_provider.py +++ b/sd/static_provider.py @@ -1,4 +1,5 @@ -import os, sys +import os +import sys if __name__ == '__main__': sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -6,14 +7,17 @@ import socket from utils import singleton from sd import ProviderManager + class StaticProviderSettings: def __init__(self, hosts): self.hosts = hosts + @singleton @ProviderManager.register_service_provider class KubernetesProvider(object): NAME = 'Static' + def __init__(self, settings, conn_mgr, **kwargs): self.conn_mgr = conn_mgr self.hosts = [socket.gethostbyname(host) for host in settings.hosts] diff --git a/tracing/__init__.py b/tracing/__init__.py index 27c57473db..5014309a52 100644 --- a/tracing/__init__.py +++ b/tracing/__init__.py @@ -1,13 +1,14 @@ def empty_server_interceptor_decorator(target_server, interceptor): return target_server + class Tracer: def __init__(self, tracer=None, - interceptor=None, - server_decorator=empty_server_interceptor_decorator): + interceptor=None, + server_decorator=empty_server_interceptor_decorator): self.tracer = tracer self.interceptor = interceptor - self.server_decorator=server_decorator + self.server_decorator = server_decorator def decorate(self, server): return self.server_decorator(server, self.interceptor) @@ -16,7 +17,7 @@ class Tracer: self.tracer and self.tracer.close() def start_span(self, operation_name=None, - child_of=None, references=None, tags=None, - start_time=None, ignore_active_span=False): + child_of=None, references=None, tags=None, + start_time=None, ignore_active_span=False): return self.tracer.start_span(operation_name, child_of, - references, tags, start_time, ignore_active_span) + references, tags, start_time, ignore_active_span) diff --git a/tracing/factory.py b/tracing/factory.py index fd06fe3cac..648dfa291e 100644 --- a/tracing/factory.py +++ b/tracing/factory.py @@ -4,7 +4,7 @@ from grpc_opentracing.grpcext import intercept_server from grpc_opentracing import open_tracing_server_interceptor from tracing import (Tracer, - empty_server_interceptor_decorator) + empty_server_interceptor_decorator) logger = logging.getLogger(__name__) @@ -17,14 +17,14 @@ class TracerFactory: if tracer_type.lower() == 'jaeger': config = Config(config=tracer_config.TRACING_CONFIG, - service_name=tracer_config.TRACING_SERVICE_NAME, - validate=tracer_config.TRACING_VALIDATE - ) + service_name=tracer_config.TRACING_SERVICE_NAME, + validate=tracer_config.TRACING_VALIDATE + ) tracer = config.initialize_tracer() tracer_interceptor = open_tracing_server_interceptor(tracer, - log_payloads=tracer_config.TRACING_LOG_PAYLOAD, - span_decorator=span_decorator) + log_payloads=tracer_config.TRACING_LOG_PAYLOAD, + span_decorator=span_decorator) return Tracer(tracer, tracer_interceptor, intercept_server) diff --git a/utils/__init__.py b/utils/__init__.py index ec7f32bcbc..c1d55e76c0 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,5 +1,6 @@ from functools import wraps + def singleton(cls): instances = {} @wraps(cls) diff --git a/utils/logger_helper.py b/utils/logger_helper.py index 1b59aa40ec..55ce3206ab 100644 --- a/utils/logger_helper.py +++ b/utils/logger_helper.py @@ -9,18 +9,22 @@ class InfoFilter(logging.Filter): def filter(self, rec): return rec.levelno == logging.INFO + class DebugFilter(logging.Filter): def filter(self, rec): return rec.levelno == logging.DEBUG + class WarnFilter(logging.Filter): def filter(self, rec): return rec.levelno == logging.WARN + class ErrorFilter(logging.Filter): def filter(self, rec): return rec.levelno == logging.ERROR + class CriticalFilter(logging.Filter): def filter(self, rec): return rec.levelno == logging.CRITICAL @@ -36,6 +40,7 @@ COLORS = { 'ENDC': '\033[0m', } + class ColorFulFormatColMixin: def format_col(self, message_str, level_name): if level_name in COLORS.keys(): @@ -43,12 +48,14 @@ class ColorFulFormatColMixin: 'ENDC') return message_str + class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin): def format(self, record): message_str = super(ColorfulFormatter, self).format(record) return self.format_col(message_str, level_name=record.levelname) + def config(log_level, log_path, name, tz='UTC'): def build_log_file(level, log_path, name, tz): utc_now = datetime.datetime.utcnow() @@ -56,7 +63,7 @@ def config(log_level, log_path, name, tz='UTC'): local_tz = timezone(tz) tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz) return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"), - level) + level) if not os.path.exists(log_path): os.makedirs(log_path) @@ -66,10 +73,10 @@ def config(log_level, log_path, name, tz='UTC'): 'disable_existing_loggers': False, 'formatters': { 'default': { - 'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)' + 'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)' }, 'colorful_console': { - 'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)', + 'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)', '()': ColorfulFormatter, }, }, @@ -133,8 +140,8 @@ def config(log_level, log_path, name, tz='UTC'): }, 'loggers': { '': { - 'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file', \ - 'milvus_error_file', 'milvus_critical_file'], + 'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file', + 'milvus_error_file', 'milvus_critical_file'], 'level': log_level, 'propagate': False },