mirror of https://github.com/milvus-io/milvus.git
changes for unit test
parent
4aa29968a6
commit
9012f47a10
|
@ -7,12 +7,12 @@ from milvus.grpc_gen import status_pb2, milvus_pb2
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestTracer(opentracing.Tracer):
|
||||
class FakeTracer(opentracing.Tracer):
|
||||
pass
|
||||
|
||||
class TestSpan(opentracing.Span):
|
||||
class FakeSpan(opentracing.Span):
|
||||
def __init__(self, context, tracer, **kwargs):
|
||||
super(TestSpan, self).__init__(tracer, context)
|
||||
super(FakeSpan, self).__init__(tracer, context)
|
||||
self.reset()
|
||||
|
||||
def set_tag(self, key, value):
|
||||
|
@ -26,7 +26,7 @@ class TestSpan(opentracing.Span):
|
|||
self.logs = []
|
||||
|
||||
|
||||
class TestRpcInfo:
|
||||
class FakeRpcInfo:
|
||||
def __init__(self, request, response):
|
||||
self.request = request
|
||||
self.response = response
|
||||
|
@ -37,32 +37,32 @@ class TestGrpcUtils:
|
|||
request = 'request'
|
||||
OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success')
|
||||
response = OK
|
||||
rpc_info = TestRpcInfo(request=request, response=response)
|
||||
span = TestSpan(context=None, tracer=TestTracer())
|
||||
rpc_info = FakeRpcInfo(request=request, response=response)
|
||||
span = FakeSpan(context=None, tracer=FakeTracer())
|
||||
span_deco = GrpcSpanDecorator()
|
||||
span_deco(span, rpc_info)
|
||||
assert len(span.logs) == 0
|
||||
assert len(span.tags) == 0
|
||||
|
||||
response = milvus_pb2.BoolReply(status=OK, bool_reply=False)
|
||||
rpc_info = TestRpcInfo(request=request, response=response)
|
||||
span = TestSpan(context=None, tracer=TestTracer())
|
||||
rpc_info = FakeRpcInfo(request=request, response=response)
|
||||
span = FakeSpan(context=None, tracer=FakeTracer())
|
||||
span_deco = GrpcSpanDecorator()
|
||||
span_deco(span, rpc_info)
|
||||
assert len(span.logs) == 0
|
||||
assert len(span.tags) == 0
|
||||
|
||||
response = 1
|
||||
rpc_info = TestRpcInfo(request=request, response=response)
|
||||
span = TestSpan(context=None, tracer=TestTracer())
|
||||
rpc_info = FakeRpcInfo(request=request, response=response)
|
||||
span = FakeSpan(context=None, tracer=FakeTracer())
|
||||
span_deco = GrpcSpanDecorator()
|
||||
span_deco(span, rpc_info)
|
||||
assert len(span.logs) == 1
|
||||
assert len(span.tags) == 1
|
||||
|
||||
response = 0
|
||||
rpc_info = TestRpcInfo(request=request, response=response)
|
||||
span = TestSpan(context=None, tracer=TestTracer())
|
||||
rpc_info = FakeRpcInfo(request=request, response=response)
|
||||
span = FakeSpan(context=None, tracer=FakeTracer())
|
||||
span_deco = GrpcSpanDecorator()
|
||||
span_deco(span, rpc_info)
|
||||
assert len(span.logs) == 0
|
||||
|
|
|
@ -237,13 +237,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def _add_vectors(self, param, metadata=None):
|
||||
return self.connection(metadata=metadata).add_vectors(None, None, insert_param=param)
|
||||
|
||||
@mark_grpc_method
|
||||
def Insert(self, request, context):
|
||||
logger.info('Insert')
|
||||
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
|
||||
_status, _ids = self.connection(metadata={
|
||||
'resp_class': milvus_pb2.VectorIds
|
||||
}).add_vectors(None, None, insert_param=request)
|
||||
_status, _ids = self._add_vectors(metadata={
|
||||
'resp_class': milvus_pb2.VectorIds}, param=request)
|
||||
return milvus_pb2.VectorIds(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
vector_id_array=_ids
|
||||
|
@ -305,6 +307,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
def SearchInFiles(self, request, context):
|
||||
raise NotImplemented()
|
||||
|
||||
def _describe_table(self, table_name, metadata=None):
|
||||
return self.connection(metadata=metadata).describe_table(table_name)
|
||||
|
||||
@mark_grpc_method
|
||||
def DescribeTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
@ -319,7 +324,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
}
|
||||
|
||||
logger.info('DescribeTable {}'.format(_table_name))
|
||||
_status, _table = self.connection(metadata=metadata).describe_table(_table_name)
|
||||
_status, _table = self._describe_table(metadata=metadata, table_name=_table_name)
|
||||
|
||||
if _status.OK():
|
||||
return milvus_pb2.TableSchema(
|
||||
|
@ -335,6 +340,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
)
|
||||
|
||||
def _count_table(self, table_name, metadata=None):
|
||||
return self.connection(metadata=metadata).get_table_row_count(table_name)
|
||||
|
||||
@mark_grpc_method
|
||||
def CountTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
@ -351,12 +359,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
metadata = {
|
||||
'resp_class': milvus_pb2.TableRowCount
|
||||
}
|
||||
_status, _count = self.connection(metadata=metadata).get_table_row_count(_table_name)
|
||||
_status, _count = self._count_table(_table_name, metadata=metadata)
|
||||
|
||||
return milvus_pb2.TableRowCount(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_row_count=_count if isinstance(_count, int) else -1)
|
||||
|
||||
|
||||
def _get_server_version(self, metadata=None):
|
||||
return self.connection(metadata=metadata).server_version()
|
||||
|
||||
@mark_grpc_method
|
||||
def Cmd(self, request, context):
|
||||
_status, _cmd = Parser.parse_proto_Command(request)
|
||||
|
@ -364,7 +376,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
|
||||
if not _status.OK():
|
||||
return milvus_pb2.StringReply(
|
||||
status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
)
|
||||
|
||||
metadata = {
|
||||
|
@ -372,7 +384,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
}
|
||||
|
||||
if _cmd == 'version':
|
||||
_status, _reply = self.connection(metadata=metadata).server_version()
|
||||
_status, _reply = self._get_server_version(metadata=metadata)
|
||||
else:
|
||||
_status, _reply = self.connection(metadata=metadata).server_status()
|
||||
|
||||
|
@ -381,19 +393,25 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
string_reply=_reply
|
||||
)
|
||||
|
||||
def _show_tables(self):
|
||||
return self.connection(metadata=metadata).show_tables()
|
||||
|
||||
@mark_grpc_method
|
||||
def ShowTables(self, request, context):
|
||||
logger.info('ShowTables')
|
||||
metadata = {
|
||||
'resp_class': milvus_pb2.TableName
|
||||
}
|
||||
_status, _results = self.connection(metadata=metadata).show_tables()
|
||||
_status, _results = self._show_tables()
|
||||
|
||||
return milvus_pb2.TableNameList(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_names=_results
|
||||
)
|
||||
|
||||
def _delete_by_range(self, table_name, start_date, end_date):
|
||||
return self.connection().delete_vectors_by_range(table_name, start_date, end_date)
|
||||
|
||||
@mark_grpc_method
|
||||
def DeleteByRange(self, request, context):
|
||||
_status, unpacks = \
|
||||
|
@ -405,9 +423,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
_table_name, _start_date, _end_date = unpacks
|
||||
|
||||
logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date))
|
||||
_status = self.connection().delete_vectors_by_range(_table_name, _start_date, _end_date)
|
||||
_status = self._delete_by_range(_table_name, _start_date, _end_date)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def _preload_table(self, table_name):
|
||||
return self.connection().preload_table(table_name)
|
||||
|
||||
@mark_grpc_method
|
||||
def PreloadTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
@ -416,9 +437,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
logger.info('PreloadTable {}'.format(_table_name))
|
||||
_status = self.connection().preload_table(_table_name)
|
||||
_status = self._preload_table(_table_name)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def _describe_index(self, table_name, metadata=None):
|
||||
return self.connection(metadata=metadata).describe_index(table_name)
|
||||
|
||||
@mark_grpc_method
|
||||
def DescribeIndex(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
@ -433,13 +457,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
}
|
||||
|
||||
logger.info('DescribeIndex {}'.format(_table_name))
|
||||
_status, _index_param = self.connection(metadata=metadata).describe_index(_table_name)
|
||||
_status, _index_param = self._describe_index(table_name=_table_name, metadata=metadata)
|
||||
|
||||
_index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
|
||||
|
||||
return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_name=_table_name, index=_index)
|
||||
|
||||
def _drop_index(self, table_name):
|
||||
return self.connection().drop_index(table_name)
|
||||
|
||||
@mark_grpc_method
|
||||
def DropIndex(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
@ -448,5 +475,5 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
logger.info('DropIndex {}'.format(_table_name))
|
||||
_status = self.connection().drop_index(_table_name)
|
||||
_status = self._drop_index(_table_name)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
|
|
@ -12,12 +12,14 @@ logger = logging.getLogger(__name__)
|
|||
class TracerFactory:
|
||||
@classmethod
|
||||
def new_tracer(cls, tracer_type, tracer_config, span_decorator=None, **kwargs):
|
||||
if not tracer_type:
|
||||
return Tracer()
|
||||
config = tracer_config.TRACING_CONFIG
|
||||
service_name = tracer_config.TRACING_SERVICE_NAME
|
||||
validate=tracer_config.TRACING_VALIDATE
|
||||
if not tracer_type:
|
||||
tracer_type = 'jaeger'
|
||||
config = tracer_config.DEFAULT_TRACING_CONFIG
|
||||
# if not tracer_type:
|
||||
# tracer_type = 'jaeger'
|
||||
# config = tracer_config.DEFAULT_TRACING_CONFIG
|
||||
|
||||
if tracer_type.lower() == 'jaeger':
|
||||
config = Config(config=config,
|
||||
|
|
Loading…
Reference in New Issue