changes for unit test

pull/232/head
peng.xu 2019-10-16 17:38:34 +08:00
parent 4aa29968a6
commit 9012f47a10
3 changed files with 56 additions and 27 deletions

View File

@ -7,12 +7,12 @@ from milvus.grpc_gen import status_pb2, milvus_pb2
logger = logging.getLogger(__name__)
class TestTracer(opentracing.Tracer):
class FakeTracer(opentracing.Tracer):
pass
class TestSpan(opentracing.Span):
class FakeSpan(opentracing.Span):
def __init__(self, context, tracer, **kwargs):
super(TestSpan, self).__init__(tracer, context)
super(FakeSpan, self).__init__(tracer, context)
self.reset()
def set_tag(self, key, value):
@ -26,7 +26,7 @@ class TestSpan(opentracing.Span):
self.logs = []
class TestRpcInfo:
class FakeRpcInfo:
def __init__(self, request, response):
self.request = request
self.response = response
@ -37,32 +37,32 @@ class TestGrpcUtils:
request = 'request'
OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success')
response = OK
rpc_info = TestRpcInfo(request=request, response=response)
span = TestSpan(context=None, tracer=TestTracer())
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
response = milvus_pb2.BoolReply(status=OK, bool_reply=False)
rpc_info = TestRpcInfo(request=request, response=response)
span = TestSpan(context=None, tracer=TestTracer())
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
response = 1
rpc_info = TestRpcInfo(request=request, response=response)
span = TestSpan(context=None, tracer=TestTracer())
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 1
assert len(span.tags) == 1
response = 0
rpc_info = TestRpcInfo(request=request, response=response)
span = TestSpan(context=None, tracer=TestTracer())
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0

View File

@ -237,13 +237,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def _add_vectors(self, param, metadata=None):
return self.connection(metadata=metadata).add_vectors(None, None, insert_param=param)
@mark_grpc_method
def Insert(self, request, context):
logger.info('Insert')
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
_status, _ids = self.connection(metadata={
'resp_class': milvus_pb2.VectorIds
}).add_vectors(None, None, insert_param=request)
_status, _ids = self._add_vectors(metadata={
'resp_class': milvus_pb2.VectorIds}, param=request)
return milvus_pb2.VectorIds(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
vector_id_array=_ids
@ -305,6 +307,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def SearchInFiles(self, request, context):
raise NotImplemented()
def _describe_table(self, table_name, metadata=None):
return self.connection(metadata=metadata).describe_table(table_name)
@mark_grpc_method
def DescribeTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -319,7 +324,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
}
logger.info('DescribeTable {}'.format(_table_name))
_status, _table = self.connection(metadata=metadata).describe_table(_table_name)
_status, _table = self._describe_table(metadata=metadata, table_name=_table_name)
if _status.OK():
return milvus_pb2.TableSchema(
@ -335,6 +340,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
)
def _count_table(self, table_name, metadata=None):
return self.connection(metadata=metadata).get_table_row_count(table_name)
@mark_grpc_method
def CountTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -351,12 +359,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
metadata = {
'resp_class': milvus_pb2.TableRowCount
}
_status, _count = self.connection(metadata=metadata).get_table_row_count(_table_name)
_status, _count = self._count_table(_table_name, metadata=metadata)
return milvus_pb2.TableRowCount(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_row_count=_count if isinstance(_count, int) else -1)
def _get_server_version(self, metadata=None):
return self.connection(metadata=metadata).server_version()
@mark_grpc_method
def Cmd(self, request, context):
_status, _cmd = Parser.parse_proto_Command(request)
@ -364,7 +376,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
if not _status.OK():
return milvus_pb2.StringReply(
status_pb2.Status(error_code=_status.code, reason=_status.message)
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
)
metadata = {
@ -372,7 +384,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
}
if _cmd == 'version':
_status, _reply = self.connection(metadata=metadata).server_version()
_status, _reply = self._get_server_version(metadata=metadata)
else:
_status, _reply = self.connection(metadata=metadata).server_status()
@ -381,19 +393,25 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
string_reply=_reply
)
def _show_tables(self):
return self.connection(metadata=metadata).show_tables()
@mark_grpc_method
def ShowTables(self, request, context):
logger.info('ShowTables')
metadata = {
'resp_class': milvus_pb2.TableName
}
_status, _results = self.connection(metadata=metadata).show_tables()
_status, _results = self._show_tables()
return milvus_pb2.TableNameList(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_names=_results
)
def _delete_by_range(self, table_name, start_date, end_date):
return self.connection().delete_vectors_by_range(table_name, start_date, end_date)
@mark_grpc_method
def DeleteByRange(self, request, context):
_status, unpacks = \
@ -405,9 +423,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
_table_name, _start_date, _end_date = unpacks
logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date))
_status = self.connection().delete_vectors_by_range(_table_name, _start_date, _end_date)
_status = self._delete_by_range(_table_name, _start_date, _end_date)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def _preload_table(self, table_name):
return self.connection().preload_table(table_name)
@mark_grpc_method
def PreloadTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -416,9 +437,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code, reason=_status.message)
logger.info('PreloadTable {}'.format(_table_name))
_status = self.connection().preload_table(_table_name)
_status = self._preload_table(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def _describe_index(self, table_name, metadata=None):
return self.connection(metadata=metadata).describe_index(table_name)
@mark_grpc_method
def DescribeIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -433,13 +457,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
}
logger.info('DescribeIndex {}'.format(_table_name))
_status, _index_param = self.connection(metadata=metadata).describe_index(_table_name)
_status, _index_param = self._describe_index(table_name=_table_name, metadata=metadata)
_index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_name=_table_name, index=_index)
def _drop_index(self, table_name):
return self.connection().drop_index(table_name)
@mark_grpc_method
def DropIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -448,5 +475,5 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code, reason=_status.message)
logger.info('DropIndex {}'.format(_table_name))
_status = self.connection().drop_index(_table_name)
_status = self._drop_index(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)

View File

@ -12,12 +12,14 @@ logger = logging.getLogger(__name__)
class TracerFactory:
@classmethod
def new_tracer(cls, tracer_type, tracer_config, span_decorator=None, **kwargs):
if not tracer_type:
return Tracer()
config = tracer_config.TRACING_CONFIG
service_name = tracer_config.TRACING_SERVICE_NAME
validate=tracer_config.TRACING_VALIDATE
if not tracer_type:
tracer_type = 'jaeger'
config = tracer_config.DEFAULT_TRACING_CONFIG
# if not tracer_type:
# tracer_type = 'jaeger'
# config = tracer_config.DEFAULT_TRACING_CONFIG
if tracer_type.lower() == 'jaeger':
config = Config(config=config,