mirror of https://github.com/milvus-io/milvus.git
update for code style
parent
bef93edab9
commit
71c67f59a3
|
@ -4,6 +4,7 @@ from mishards import settings, db, create_app
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(request):
|
||||
app = create_app(settings.TestingConfig)
|
||||
|
|
|
@ -2,6 +2,7 @@ import fire
|
|||
from mishards import db
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
class DBHandler:
|
||||
@classmethod
|
||||
def create_all(cls):
|
||||
|
@ -15,9 +16,9 @@ class DBHandler:
|
|||
def fun(cls, tid):
|
||||
from mishards.factories import TablesFactory, TableFilesFactory, Tables
|
||||
f = db.Session.query(Tables).filter(and_(
|
||||
Tables.table_id==tid,
|
||||
Tables.state!=Tables.TO_DELETE)
|
||||
).first()
|
||||
Tables.table_id == tid,
|
||||
Tables.state != Tables.TO_DELETE)
|
||||
).first()
|
||||
print(f)
|
||||
|
||||
# f1 = TableFilesFactory()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import logging
|
||||
import logging
|
||||
from mishards import settings
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -8,6 +8,7 @@ db = DB()
|
|||
from mishards.server import Server
|
||||
grpc_server = Server()
|
||||
|
||||
|
||||
def create_app(testing_config=None):
|
||||
config = testing_config if testing_config else settings.DefaultConfig
|
||||
db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO)
|
||||
|
@ -24,7 +25,7 @@ def create_app(testing_config=None):
|
|||
from tracing.factory import TracerFactory
|
||||
from mishards.grpc_utils import GrpcSpanDecorator
|
||||
tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig,
|
||||
span_decorator=GrpcSpanDecorator())
|
||||
span_decorator=GrpcSpanDecorator())
|
||||
|
||||
grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover)
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from utils import singleton
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
|
||||
self.name = name
|
||||
|
@ -55,7 +56,7 @@ class Connection:
|
|||
|
||||
if not self.can_retry and not self.connected:
|
||||
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
|
||||
metadata=metadata))
|
||||
metadata=metadata))
|
||||
|
||||
self.retried = 0
|
||||
|
||||
|
@ -72,6 +73,7 @@ class Connection:
|
|||
raise e
|
||||
return inner
|
||||
|
||||
|
||||
@singleton
|
||||
class ConnectionMgr:
|
||||
def __init__(self):
|
||||
|
@ -90,10 +92,10 @@ class ConnectionMgr:
|
|||
if not throw:
|
||||
return None
|
||||
raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
|
||||
metadata=metadata)
|
||||
metadata=metadata)
|
||||
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
|
||||
threaded = {
|
||||
threading.get_ident() : this_conn
|
||||
threading.get_ident(): this_conn
|
||||
}
|
||||
self.conns[name] = threaded
|
||||
return this_conn
|
||||
|
@ -106,7 +108,7 @@ class ConnectionMgr:
|
|||
if not throw:
|
||||
return None
|
||||
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
|
||||
metadata=metadata)
|
||||
metadata=metadata)
|
||||
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
|
||||
c[tid] = this_conn
|
||||
return this_conn
|
||||
|
|
|
@ -14,8 +14,10 @@ class LocalSession(SessionBase):
|
|||
bind = options.pop('bind', None) or db.engine
|
||||
SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options)
|
||||
|
||||
|
||||
class DB:
|
||||
Model = declarative_base()
|
||||
|
||||
def __init__(self, uri=None, echo=False):
|
||||
self.echo = echo
|
||||
uri and self.init_db(uri, echo)
|
||||
|
@ -27,9 +29,9 @@ class DB:
|
|||
self.engine = create_engine(url)
|
||||
else:
|
||||
self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
|
||||
pool_pre_ping=True,
|
||||
echo=echo,
|
||||
max_overflow=0)
|
||||
pool_pre_ping=True,
|
||||
echo=echo,
|
||||
max_overflow=0)
|
||||
self.uri = uri
|
||||
self.url = url
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from mishards import grpc_server as server, exceptions
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resp_handler(err, error_code):
|
||||
if not isinstance(err, exceptions.BaseException):
|
||||
return status_pb2.Status(error_code=error_code, reason=str(err))
|
||||
|
@ -50,21 +51,25 @@ def resp_handler(err, error_code):
|
|||
status.error_code = status_pb2.UNEXPECTED_ERROR
|
||||
return status
|
||||
|
||||
|
||||
@server.errorhandler(exceptions.TableNotFoundError)
|
||||
def TableNotFoundErrorHandler(err):
|
||||
logger.error(err)
|
||||
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
|
||||
|
||||
|
||||
@server.errorhandler(exceptions.InvalidArgumentError)
|
||||
def InvalidArgumentErrorHandler(err):
|
||||
logger.error(err)
|
||||
return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)
|
||||
|
||||
|
||||
@server.errorhandler(exceptions.DBError)
|
||||
def DBErrorHandler(err):
|
||||
logger.error(err)
|
||||
return resp_handler(err, status_pb2.UNEXPECTED_ERROR)
|
||||
|
||||
|
||||
@server.errorhandler(exceptions.InvalidRangeError)
|
||||
def InvalidArgumentErrorHandler(err):
|
||||
logger.error(err)
|
||||
|
|
|
@ -1,26 +1,34 @@
|
|||
import mishards.exception_codes as codes
|
||||
|
||||
|
||||
class BaseException(Exception):
|
||||
code = codes.INVALID_CODE
|
||||
message = 'BaseException'
|
||||
|
||||
def __init__(self, message='', metadata=None):
|
||||
self.message = self.__class__.__name__ if not message else message
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class ConnectionConnectError(BaseException):
|
||||
code = codes.CONNECT_ERROR_CODE
|
||||
|
||||
|
||||
class ConnectionNotFoundError(BaseException):
|
||||
code = codes.CONNECTTION_NOT_FOUND_CODE
|
||||
|
||||
|
||||
class DBError(BaseException):
|
||||
code = codes.DB_ERROR_CODE
|
||||
|
||||
|
||||
class TableNotFoundError(BaseException):
|
||||
code = codes.TABLE_NOT_FOUND_CODE
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseException):
|
||||
code = codes.INVALID_ARGUMENT_CODE
|
||||
|
||||
|
||||
class InvalidRangeError(BaseException):
|
||||
code = codes.INVALID_DATE_RANGE_CODE
|
||||
|
|
|
@ -9,13 +9,16 @@ from faker.providers import BaseProvider
|
|||
from mishards import db
|
||||
from mishards.models import Tables, TableFiles
|
||||
|
||||
|
||||
class FakerProvider(BaseProvider):
|
||||
def this_date(self):
|
||||
t = datetime.datetime.today()
|
||||
return (t.year - 1900) * 10000 + (t.month-1)*100 + t.day
|
||||
return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day
|
||||
|
||||
|
||||
factory.Faker.add_provider(FakerProvider)
|
||||
|
||||
|
||||
class TablesFactory(SQLAlchemyModelFactory):
|
||||
class Meta:
|
||||
model = Tables
|
||||
|
@ -24,14 +27,15 @@ class TablesFactory(SQLAlchemyModelFactory):
|
|||
|
||||
id = factory.Faker('random_number', digits=16, fix_len=True)
|
||||
table_id = factory.Faker('uuid4')
|
||||
state = factory.Faker('random_element', elements=(0,1,2,3))
|
||||
dimension = factory.Faker('random_element', elements=(256,512))
|
||||
state = factory.Faker('random_element', elements=(0, 1, 2, 3))
|
||||
dimension = factory.Faker('random_element', elements=(256, 512))
|
||||
created_on = int(time.time())
|
||||
index_file_size = 0
|
||||
engine_type = factory.Faker('random_element', elements=(0,1,2,3))
|
||||
metric_type = factory.Faker('random_element', elements=(0,1))
|
||||
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
|
||||
metric_type = factory.Faker('random_element', elements=(0, 1))
|
||||
nlist = 16384
|
||||
|
||||
|
||||
class TableFilesFactory(SQLAlchemyModelFactory):
|
||||
class Meta:
|
||||
model = TableFiles
|
||||
|
@ -40,9 +44,9 @@ class TableFilesFactory(SQLAlchemyModelFactory):
|
|||
|
||||
id = factory.Faker('random_number', digits=16, fix_len=True)
|
||||
table = factory.SubFactory(TablesFactory)
|
||||
engine_type = factory.Faker('random_element', elements=(0,1,2,3))
|
||||
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
|
||||
file_id = factory.Faker('uuid4')
|
||||
file_type = factory.Faker('random_element', elements=(0,1,2,3,4))
|
||||
file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
|
||||
file_size = factory.Faker('random_number')
|
||||
updated_time = int(time.time())
|
||||
created_on = int(time.time())
|
||||
|
|
|
@ -14,21 +14,23 @@ class GrpcSpanDecorator(SpanDecorator):
|
|||
status = rpc_info.response.status
|
||||
except Exception as e:
|
||||
status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR,
|
||||
reason='Should not happen')
|
||||
reason='Should not happen')
|
||||
|
||||
if status.error_code == 0:
|
||||
return
|
||||
error_log = {'event': 'error',
|
||||
'request': rpc_info.request,
|
||||
'response': rpc_info.response
|
||||
}
|
||||
'request': rpc_info.request,
|
||||
'response': rpc_info.response
|
||||
}
|
||||
span.set_tag('error', True)
|
||||
span.log_kv(error_log)
|
||||
|
||||
|
||||
def mark_grpc_method(func):
|
||||
setattr(func, 'grpc_method', True)
|
||||
return func
|
||||
|
||||
|
||||
def is_grpc_method(func):
|
||||
if not func:
|
||||
return False
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# class GrpcArgsWrapper(object):
|
||||
|
||||
# @classmethod
|
||||
# def proto_TableName(cls):
|
||||
# @classmethod
|
||||
# def proto_TableName(cls):
|
||||
|
|
|
@ -9,8 +9,8 @@ else:
|
|||
import md5
|
||||
md5_constructor = md5.new
|
||||
|
||||
class HashRing(object):
|
||||
|
||||
class HashRing(object):
|
||||
def __init__(self, nodes=None, weights=None):
|
||||
"""`nodes` is a list of objects that have a proper __str__ representation.
|
||||
`weights` is dictionary that sets weights to the nodes. The default
|
||||
|
@ -40,13 +40,13 @@ class HashRing(object):
|
|||
if node in self.weights:
|
||||
weight = self.weights.get(node)
|
||||
|
||||
factor = math.floor((40*len(self.nodes)*weight) / total_weight);
|
||||
factor = math.floor((40 * len(self.nodes) * weight) / total_weight)
|
||||
|
||||
for j in range(0, int(factor)):
|
||||
b_key = self._hash_digest( '%s-%s' % (node, j) )
|
||||
b_key = self._hash_digest('%s-%s' % (node, j))
|
||||
|
||||
for i in range(0, 3):
|
||||
key = self._hash_val(b_key, lambda x: x+i*4)
|
||||
key = self._hash_val(b_key, lambda x: x + i * 4)
|
||||
self.ring[key] = node
|
||||
self._sorted_keys.append(key)
|
||||
|
||||
|
@ -60,7 +60,7 @@ class HashRing(object):
|
|||
pos = self.get_node_pos(string_key)
|
||||
if pos is None:
|
||||
return None
|
||||
return self.ring[ self._sorted_keys[pos] ]
|
||||
return self.ring[self._sorted_keys[pos]]
|
||||
|
||||
def get_node_pos(self, string_key):
|
||||
"""Given a string key a corresponding node in the hash ring is returned
|
||||
|
@ -94,6 +94,7 @@ class HashRing(object):
|
|||
yield None, None
|
||||
|
||||
returned_values = set()
|
||||
|
||||
def distinct_filter(value):
|
||||
if str(value) not in returned_values:
|
||||
returned_values.add(str(value))
|
||||
|
@ -121,10 +122,8 @@ class HashRing(object):
|
|||
return self._hash_val(b_key, lambda x: x)
|
||||
|
||||
def _hash_val(self, b_key, entry_fn):
|
||||
return (( b_key[entry_fn(3)] << 24)
|
||||
|(b_key[entry_fn(2)] << 16)
|
||||
|(b_key[entry_fn(1)] << 8)
|
||||
| b_key[entry_fn(0)] )
|
||||
return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | (
|
||||
b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)]
|
||||
|
||||
def _hash_digest(self, key):
|
||||
m = md5_constructor()
|
||||
|
@ -132,12 +131,13 @@ class HashRing(object):
|
|||
m.update(key)
|
||||
return m.digest()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from collections import defaultdict
|
||||
servers = ['192.168.0.246:11212',
|
||||
'192.168.0.247:11212',
|
||||
'192.168.0.248:11212',
|
||||
'192.168.0.249:11212']
|
||||
servers = [
|
||||
'192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212',
|
||||
'192.168.0.249:11212'
|
||||
]
|
||||
|
||||
ring = HashRing(servers)
|
||||
keys = ['{}'.format(i) for i in range(100)]
|
||||
|
@ -146,5 +146,5 @@ if __name__ == '__main__':
|
|||
server = ring.get_node(k)
|
||||
mapped[server].append(k)
|
||||
|
||||
for k,v in mapped.items():
|
||||
for k, v in mapped.items():
|
||||
print(k, v)
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from mishards import (
|
||||
settings, create_app)
|
||||
from mishards import (settings, create_app)
|
||||
|
||||
|
||||
def main():
|
||||
server = create_app(settings.TestingConfig if settings.TESTING else settings.DefaultConfig)
|
||||
server = create_app(
|
||||
settings.TestingConfig if settings.TESTING else settings.DefaultConfig)
|
||||
server.run(port=settings.SERVER_PORT)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import logging
|
||||
from sqlalchemy import (Integer, Boolean, Text,
|
||||
String, BigInteger, func, and_, or_,
|
||||
Column)
|
||||
String, BigInteger, func, and_, or_,
|
||||
Column)
|
||||
from sqlalchemy.orm import relationship, backref
|
||||
|
||||
from mishards import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TableFiles(db.Model):
|
||||
FILE_TYPE_NEW = 0
|
||||
FILE_TYPE_RAW = 1
|
||||
|
@ -57,16 +58,16 @@ class Tables(db.Model):
|
|||
|
||||
def files_to_search(self, date_range=None):
|
||||
cond = or_(
|
||||
TableFiles.file_type==TableFiles.FILE_TYPE_RAW,
|
||||
TableFiles.file_type==TableFiles.FILE_TYPE_TO_INDEX,
|
||||
TableFiles.file_type==TableFiles.FILE_TYPE_INDEX,
|
||||
TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
|
||||
TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX,
|
||||
TableFiles.file_type == TableFiles.FILE_TYPE_INDEX,
|
||||
)
|
||||
if date_range:
|
||||
cond = and_(
|
||||
cond,
|
||||
or_(
|
||||
and_(TableFiles.date>=d[0], TableFiles.date<d[1]) for d in date_range
|
||||
)
|
||||
and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range
|
||||
)
|
||||
)
|
||||
|
||||
files = self.files.filter(cond)
|
||||
|
|
|
@ -33,7 +33,7 @@ class Server:
|
|||
self.server_impl = grpc.server(
|
||||
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
|
||||
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
|
||||
(cygrpc.ChannelArgKey.max_receive_message_length, -1)]
|
||||
(cygrpc.ChannelArgKey.max_receive_message_length, -1)]
|
||||
)
|
||||
|
||||
self.server_impl = self.tracer.decorate(self.server_impl)
|
||||
|
@ -46,7 +46,7 @@ class Server:
|
|||
ip = socket.gethostbyname(url.hostname)
|
||||
socket.inet_pton(socket.AF_INET, ip)
|
||||
self.conn_mgr.register('WOSERVER',
|
||||
'{}://{}:{}'.format(url.scheme, ip, url.port or 80))
|
||||
'{}://{}:{}'.format(url.scheme, ip, url.port or 80))
|
||||
|
||||
def register_pre_run_handler(self, func):
|
||||
logger.info('Regiterring {} into server pre_run_handlers'.format(func))
|
||||
|
|
|
@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
|
||||
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
|
||||
from milvus.client.Abstract import Range
|
||||
from milvus.client import types
|
||||
from milvus.client import types as Types
|
||||
|
||||
from mishards import (db, settings, exceptions)
|
||||
from mishards.grpc_utils import mark_grpc_method
|
||||
|
@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
MAX_NPROBE = 2048
|
||||
|
||||
def __init__(self, conn_mgr, tracer, *args, **kwargs):
|
||||
self.conn_mgr = conn_mgr
|
||||
self.table_meta = {}
|
||||
|
@ -44,8 +45,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
return conn.conn
|
||||
|
||||
def _format_date(self, start, end):
|
||||
return ((start.year-1900)*10000 + (start.month-1)*100 + start.day
|
||||
, (end.year-1900)*10000 + (end.month-1)*100 + end.day)
|
||||
return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)
|
||||
|
||||
def _range_to_date(self, range_obj, metadata=None):
|
||||
try:
|
||||
|
@ -54,8 +54,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
assert start < end
|
||||
except (ValueError, AssertionError):
|
||||
raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
|
||||
range_obj.start_date, range_obj.end_date
|
||||
), metadata=metadata)
|
||||
range_obj.start_date, range_obj.end_date
|
||||
), metadata=metadata)
|
||||
|
||||
return self._format_date(start, end)
|
||||
|
||||
|
@ -63,9 +63,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
# PXU TODO: Implement Thread-local Context
|
||||
try:
|
||||
table = db.Session.query(Tables).filter(and_(
|
||||
Tables.table_id==table_id,
|
||||
Tables.state!=Tables.TO_DELETE
|
||||
)).first()
|
||||
Tables.table_id == table_id,
|
||||
Tables.state != Tables.TO_DELETE
|
||||
)).first()
|
||||
except sqlalchemy_exc.SQLAlchemyError as e:
|
||||
raise exceptions.DBError(message=str(e), metadata=metadata)
|
||||
|
||||
|
@ -93,7 +93,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
return routing
|
||||
|
||||
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
|
||||
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success")
|
||||
status = status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success")
|
||||
if not files_n_topk_results:
|
||||
return status, []
|
||||
|
||||
|
@ -107,7 +107,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
for request_pos, each_request_results in enumerate(files_collection.topk_query_result):
|
||||
request_results[request_pos].extend(each_request_results.query_result_arrays)
|
||||
request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance,
|
||||
reverse=reverse)[:topk]
|
||||
reverse=reverse)[:topk]
|
||||
|
||||
calc_time = time.time() - calc_time
|
||||
logger.info('Merge takes {}'.format(calc_time))
|
||||
|
@ -127,7 +127,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
|
||||
routing = {}
|
||||
with self.tracer.start_span('get_routing',
|
||||
child_of=context.get_active_span().context):
|
||||
child_of=context.get_active_span().context):
|
||||
routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata)
|
||||
logger.info('Routing: {}'.format(routing))
|
||||
|
||||
|
@ -140,28 +140,28 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
|
||||
def search(addr, query_params, vectors, topk, nprobe, **kwargs):
|
||||
logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format(
|
||||
addr, query_params, len(vectors), topk, nprobe
|
||||
))
|
||||
addr, query_params, len(vectors), topk, nprobe
|
||||
))
|
||||
|
||||
conn = self.query_conn(addr, metadata=metadata)
|
||||
start = time.time()
|
||||
span = kwargs.get('span', None)
|
||||
span = span if span else context.get_active_span().context
|
||||
with self.tracer.start_span('search_{}'.format(addr),
|
||||
child_of=context.get_active_span().context):
|
||||
child_of=context.get_active_span().context):
|
||||
ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
|
||||
file_ids=query_params['file_ids'],
|
||||
query_records=vectors,
|
||||
top_k=topk,
|
||||
nprobe=nprobe,
|
||||
lazy=True)
|
||||
file_ids=query_params['file_ids'],
|
||||
query_records=vectors,
|
||||
top_k=topk,
|
||||
nprobe=nprobe,
|
||||
lazy=True)
|
||||
end = time.time()
|
||||
logger.info('search_vectors_in_files takes: {}'.format(end - start))
|
||||
|
||||
all_topk_results.append(ret)
|
||||
|
||||
with self.tracer.start_span('do_search',
|
||||
child_of=context.get_active_span().context) as span:
|
||||
child_of=context.get_active_span().context) as span:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
for addr, params in routing.items():
|
||||
res = pool.submit(search, addr, params, vectors, topk, nprobe, span=span)
|
||||
|
@ -170,9 +170,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
for res in rs:
|
||||
res.result()
|
||||
|
||||
reverse = table_meta.metric_type == types.MetricType.IP
|
||||
reverse = table_meta.metric_type == Types.MetricType.IP
|
||||
with self.tracer.start_span('do_merge',
|
||||
child_of=context.get_active_span().context):
|
||||
child_of=context.get_active_span().context):
|
||||
return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
|
||||
|
||||
@mark_grpc_method
|
||||
|
@ -201,8 +201,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
logger.info('HasTable {}'.format(_table_name))
|
||||
|
||||
_bool = self.connection(metadata={
|
||||
'resp_class': milvus_pb2.BoolReply
|
||||
}).has_table(_table_name)
|
||||
'resp_class': milvus_pb2.BoolReply
|
||||
}).has_table(_table_name)
|
||||
|
||||
return milvus_pb2.BoolReply(
|
||||
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
|
||||
|
@ -244,7 +244,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
|
||||
_status, _ids = self.connection(metadata={
|
||||
'resp_class': milvus_pb2.VectorIds
|
||||
}).add_vectors(None, None, insert_param=request)
|
||||
}).add_vectors(None, None, insert_param=request)
|
||||
return milvus_pb2.VectorIds(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
vector_id_array=_ids
|
||||
|
@ -266,7 +266,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
|
||||
if nprobe > self.MAX_NPROBE or nprobe <= 0:
|
||||
raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe),
|
||||
metadata=metadata)
|
||||
metadata=metadata)
|
||||
|
||||
table_meta = self.table_meta.get(table_name, None)
|
||||
|
||||
|
@ -332,8 +332,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
)
|
||||
|
||||
return milvus_pb2.TableSchema(
|
||||
table_name=_table_name,
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_name=_table_name,
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
|
@ -391,8 +391,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
_status, _results = self.connection(metadata=metadata).show_tables()
|
||||
|
||||
return milvus_pb2.TableNameList(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_names=_results
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_names=_results
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
|
@ -426,7 +426,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
|
||||
if not _status.OK():
|
||||
return milvus_pb2.IndexParam(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
)
|
||||
|
||||
metadata = {
|
||||
|
@ -439,7 +439,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
|||
_index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
|
||||
|
||||
return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_name=_table_name, index=_index)
|
||||
table_name=_table_name, index=_index)
|
||||
|
||||
@mark_grpc_method
|
||||
def DropIndex(self, request, context):
|
||||
|
|
|
@ -39,13 +39,15 @@ if SD_PROVIDER == 'Kubernetes':
|
|||
elif SD_PROVIDER == 'Static':
|
||||
from sd.static_provider import StaticProviderSettings
|
||||
SD_PROVIDER_SETTINGS = StaticProviderSettings(
|
||||
hosts=env.list('SD_STATIC_HOSTS', [])
|
||||
)
|
||||
hosts=env.list('SD_STATIC_HOSTS', [])
|
||||
)
|
||||
|
||||
TESTING = env.bool('TESTING', False)
|
||||
TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530')
|
||||
|
||||
TRACING_TYPE = env.str('TRACING_TYPE', '')
|
||||
|
||||
|
||||
class TracingConfig:
|
||||
TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards')
|
||||
TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True)
|
||||
|
@ -54,7 +56,7 @@ class TracingConfig:
|
|||
'sampler': {
|
||||
'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
|
||||
'param': env.str('TRACING_SAMPLER_PARAM', "1"),
|
||||
},
|
||||
},
|
||||
'local_agent': {
|
||||
'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'),
|
||||
'reporting_port': env.str('TRACING_REPORTING_PORT', '5775')
|
||||
|
@ -62,10 +64,12 @@ class TracingConfig:
|
|||
'logging': env.bool('TRACING_LOGGING', True)
|
||||
}
|
||||
|
||||
|
||||
class DefaultConfig:
|
||||
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
|
||||
SQL_ECHO = env.bool('SQL_ECHO', False)
|
||||
|
||||
|
||||
TESTING = env.bool('TESTING', False)
|
||||
if TESTING:
|
||||
class TestingConfig(DefaultConfig):
|
||||
|
|
|
@ -6,6 +6,7 @@ from mishards import exceptions
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('app')
|
||||
class TestConnection:
|
||||
def test_manager(self):
|
||||
|
@ -30,8 +31,10 @@ class TestConnection:
|
|||
class Conn:
|
||||
def __init__(self, state):
|
||||
self.state = state
|
||||
|
||||
def connect(self, uri):
|
||||
return self.state
|
||||
|
||||
def connected(self):
|
||||
return self.state
|
||||
FAIL_CONN = Conn(False)
|
||||
|
@ -48,6 +51,7 @@ class TestConnection:
|
|||
class Func():
|
||||
def __init__(self):
|
||||
self.executed = False
|
||||
|
||||
def __call__(self):
|
||||
self.executed = True
|
||||
|
||||
|
@ -55,8 +59,8 @@ class TestConnection:
|
|||
|
||||
RetryObj = Retry()
|
||||
c = Connection('client', uri='',
|
||||
max_retry=max_retry,
|
||||
on_retry_func=RetryObj)
|
||||
max_retry=max_retry,
|
||||
on_retry_func=RetryObj)
|
||||
c.conn = FAIL_CONN
|
||||
ff = Func()
|
||||
this_connect = c.connect(func=ff)
|
||||
|
|
|
@ -3,12 +3,13 @@ import pytest
|
|||
from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory
|
||||
from mishards import db, create_app, settings
|
||||
from mishards.factories import (
|
||||
Tables, TableFiles,
|
||||
TablesFactory, TableFilesFactory
|
||||
)
|
||||
Tables, TableFiles,
|
||||
TablesFactory, TableFilesFactory
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('app')
|
||||
class TestModels:
|
||||
def test_files_to_search(self):
|
||||
|
|
|
@ -24,4 +24,5 @@ class ProviderManager:
|
|||
def get_provider(cls, name):
|
||||
return cls.PROVIDERS.get(name, None)
|
||||
|
||||
|
||||
from sd import kubernetes_provider, static_provider
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
if __name__ == '__main__':
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
@ -71,7 +72,6 @@ class K8SHeartbeatHandler(threading.Thread, K8SMixin):
|
|||
|
||||
self.queue.put(event_message)
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
|
||||
|
@ -98,18 +98,18 @@ class K8SEventListener(threading.Thread, K8SMixin):
|
|||
resource_version = ''
|
||||
w = watch.Watch()
|
||||
for event in w.stream(self.v1.list_namespaced_event, namespace=self.namespace,
|
||||
field_selector='involvedObject.kind=Pod'):
|
||||
field_selector='involvedObject.kind=Pod'):
|
||||
if self.terminate:
|
||||
break
|
||||
|
||||
resource_version = int(event['object'].metadata.resource_version)
|
||||
|
||||
info = dict(
|
||||
eType='WatchEvent',
|
||||
pod=event['object'].involved_object.name,
|
||||
reason=event['object'].reason,
|
||||
message=event['object'].message,
|
||||
start_up=self.at_start_up,
|
||||
eType='WatchEvent',
|
||||
pod=event['object'].involved_object.name,
|
||||
reason=event['object'].reason,
|
||||
message=event['object'].message,
|
||||
start_up=self.at_start_up,
|
||||
)
|
||||
self.at_start_up = False
|
||||
# logger.info('Received event: {}'.format(info))
|
||||
|
@ -135,7 +135,7 @@ class EventHandler(threading.Thread):
|
|||
def on_pod_started(self, event, **kwargs):
|
||||
try_cnt = 3
|
||||
pod = None
|
||||
while try_cnt > 0:
|
||||
while try_cnt > 0:
|
||||
try_cnt -= 1
|
||||
try:
|
||||
pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], namespace=self.namespace)
|
||||
|
@ -203,6 +203,7 @@ class EventHandler(threading.Thread):
|
|||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
||||
class KubernetesProviderSettings:
|
||||
def __init__(self, namespace, pod_patt, label_selector, in_cluster, poll_interval, **kwargs):
|
||||
self.namespace = namespace
|
||||
|
@ -211,10 +212,12 @@ class KubernetesProviderSettings:
|
|||
self.in_cluster = in_cluster
|
||||
self.poll_interval = poll_interval
|
||||
|
||||
|
||||
@singleton
|
||||
@ProviderManager.register_service_provider
|
||||
class KubernetesProvider(object):
|
||||
NAME = 'Kubernetes'
|
||||
|
||||
def __init__(self, settings, conn_mgr, **kwargs):
|
||||
self.namespace = settings.namespace
|
||||
self.pod_patt = settings.pod_patt
|
||||
|
@ -233,27 +236,27 @@ class KubernetesProvider(object):
|
|||
self.v1 = client.CoreV1Api()
|
||||
|
||||
self.listener = K8SEventListener(
|
||||
message_queue=self.queue,
|
||||
namespace=self.namespace,
|
||||
in_cluster=self.in_cluster,
|
||||
v1=self.v1,
|
||||
**kwargs
|
||||
)
|
||||
message_queue=self.queue,
|
||||
namespace=self.namespace,
|
||||
in_cluster=self.in_cluster,
|
||||
v1=self.v1,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.pod_heartbeater = K8SHeartbeatHandler(
|
||||
message_queue=self.queue,
|
||||
namespace=self.namespace,
|
||||
label_selector=self.label_selector,
|
||||
in_cluster=self.in_cluster,
|
||||
v1=self.v1,
|
||||
poll_interval=self.poll_interval,
|
||||
**kwargs
|
||||
)
|
||||
message_queue=self.queue,
|
||||
namespace=self.namespace,
|
||||
label_selector=self.label_selector,
|
||||
in_cluster=self.in_cluster,
|
||||
v1=self.v1,
|
||||
poll_interval=self.poll_interval,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.event_handler = EventHandler(mgr=self,
|
||||
message_queue=self.queue,
|
||||
namespace=self.namespace,
|
||||
pod_patt=self.pod_patt, **kwargs)
|
||||
message_queue=self.queue,
|
||||
namespace=self.namespace,
|
||||
pod_patt=self.pod_patt, **kwargs)
|
||||
|
||||
def add_pod(self, name, ip):
|
||||
self.conn_mgr.register(name, 'tcp://{}:19530'.format(ip))
|
||||
|
@ -276,9 +279,11 @@ class KubernetesProvider(object):
|
|||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
class Connect:
|
||||
def register(self, name, value):
|
||||
logger.error('Register: {} - {}'.format(name, value))
|
||||
|
||||
def unregister(self, name):
|
||||
logger.error('Unregister: {}'.format(name))
|
||||
|
||||
|
@ -289,16 +294,16 @@ if __name__ == '__main__':
|
|||
connect_mgr = Connect()
|
||||
|
||||
settings = KubernetesProviderSettings(
|
||||
namespace='xp',
|
||||
pod_patt=".*-ro-servers-.*",
|
||||
label_selector='tier=ro-servers',
|
||||
poll_interval=5,
|
||||
in_cluster=False)
|
||||
namespace='xp',
|
||||
pod_patt=".*-ro-servers-.*",
|
||||
label_selector='tier=ro-servers',
|
||||
poll_interval=5,
|
||||
in_cluster=False)
|
||||
|
||||
provider_class = ProviderManager.get_provider('Kubernetes')
|
||||
t = provider_class(conn_mgr=connect_mgr,
|
||||
settings=settings
|
||||
)
|
||||
settings=settings
|
||||
)
|
||||
t.start()
|
||||
cnt = 100
|
||||
while cnt > 0:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
if __name__ == '__main__':
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
@ -6,14 +7,17 @@ import socket
|
|||
from utils import singleton
|
||||
from sd import ProviderManager
|
||||
|
||||
|
||||
class StaticProviderSettings:
|
||||
def __init__(self, hosts):
|
||||
self.hosts = hosts
|
||||
|
||||
|
||||
@singleton
|
||||
@ProviderManager.register_service_provider
|
||||
class KubernetesProvider(object):
|
||||
NAME = 'Static'
|
||||
|
||||
def __init__(self, settings, conn_mgr, **kwargs):
|
||||
self.conn_mgr = conn_mgr
|
||||
self.hosts = [socket.gethostbyname(host) for host in settings.hosts]
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
def empty_server_interceptor_decorator(target_server, interceptor):
|
||||
return target_server
|
||||
|
||||
|
||||
class Tracer:
|
||||
def __init__(self, tracer=None,
|
||||
interceptor=None,
|
||||
server_decorator=empty_server_interceptor_decorator):
|
||||
interceptor=None,
|
||||
server_decorator=empty_server_interceptor_decorator):
|
||||
self.tracer = tracer
|
||||
self.interceptor = interceptor
|
||||
self.server_decorator=server_decorator
|
||||
self.server_decorator = server_decorator
|
||||
|
||||
def decorate(self, server):
|
||||
return self.server_decorator(server, self.interceptor)
|
||||
|
@ -16,7 +17,7 @@ class Tracer:
|
|||
self.tracer and self.tracer.close()
|
||||
|
||||
def start_span(self, operation_name=None,
|
||||
child_of=None, references=None, tags=None,
|
||||
start_time=None, ignore_active_span=False):
|
||||
child_of=None, references=None, tags=None,
|
||||
start_time=None, ignore_active_span=False):
|
||||
return self.tracer.start_span(operation_name, child_of,
|
||||
references, tags, start_time, ignore_active_span)
|
||||
references, tags, start_time, ignore_active_span)
|
||||
|
|
|
@ -4,7 +4,7 @@ from grpc_opentracing.grpcext import intercept_server
|
|||
from grpc_opentracing import open_tracing_server_interceptor
|
||||
|
||||
from tracing import (Tracer,
|
||||
empty_server_interceptor_decorator)
|
||||
empty_server_interceptor_decorator)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -17,14 +17,14 @@ class TracerFactory:
|
|||
|
||||
if tracer_type.lower() == 'jaeger':
|
||||
config = Config(config=tracer_config.TRACING_CONFIG,
|
||||
service_name=tracer_config.TRACING_SERVICE_NAME,
|
||||
validate=tracer_config.TRACING_VALIDATE
|
||||
)
|
||||
service_name=tracer_config.TRACING_SERVICE_NAME,
|
||||
validate=tracer_config.TRACING_VALIDATE
|
||||
)
|
||||
|
||||
tracer = config.initialize_tracer()
|
||||
tracer_interceptor = open_tracing_server_interceptor(tracer,
|
||||
log_payloads=tracer_config.TRACING_LOG_PAYLOAD,
|
||||
span_decorator=span_decorator)
|
||||
log_payloads=tracer_config.TRACING_LOG_PAYLOAD,
|
||||
span_decorator=span_decorator)
|
||||
|
||||
return Tracer(tracer, tracer_interceptor, intercept_server)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from functools import wraps
|
||||
|
||||
|
||||
def singleton(cls):
|
||||
instances = {}
|
||||
@wraps(cls)
|
||||
|
|
|
@ -9,18 +9,22 @@ class InfoFilter(logging.Filter):
|
|||
def filter(self, rec):
|
||||
return rec.levelno == logging.INFO
|
||||
|
||||
|
||||
class DebugFilter(logging.Filter):
|
||||
def filter(self, rec):
|
||||
return rec.levelno == logging.DEBUG
|
||||
|
||||
|
||||
class WarnFilter(logging.Filter):
|
||||
def filter(self, rec):
|
||||
return rec.levelno == logging.WARN
|
||||
|
||||
|
||||
class ErrorFilter(logging.Filter):
|
||||
def filter(self, rec):
|
||||
return rec.levelno == logging.ERROR
|
||||
|
||||
|
||||
class CriticalFilter(logging.Filter):
|
||||
def filter(self, rec):
|
||||
return rec.levelno == logging.CRITICAL
|
||||
|
@ -36,6 +40,7 @@ COLORS = {
|
|||
'ENDC': '\033[0m',
|
||||
}
|
||||
|
||||
|
||||
class ColorFulFormatColMixin:
|
||||
def format_col(self, message_str, level_name):
|
||||
if level_name in COLORS.keys():
|
||||
|
@ -43,12 +48,14 @@ class ColorFulFormatColMixin:
|
|||
'ENDC')
|
||||
return message_str
|
||||
|
||||
|
||||
class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin):
|
||||
def format(self, record):
|
||||
message_str = super(ColorfulFormatter, self).format(record)
|
||||
|
||||
return self.format_col(message_str, level_name=record.levelname)
|
||||
|
||||
|
||||
def config(log_level, log_path, name, tz='UTC'):
|
||||
def build_log_file(level, log_path, name, tz):
|
||||
utc_now = datetime.datetime.utcnow()
|
||||
|
@ -56,7 +63,7 @@ def config(log_level, log_path, name, tz='UTC'):
|
|||
local_tz = timezone(tz)
|
||||
tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz)
|
||||
return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"),
|
||||
level)
|
||||
level)
|
||||
|
||||
if not os.path.exists(log_path):
|
||||
os.makedirs(log_path)
|
||||
|
@ -66,10 +73,10 @@ def config(log_level, log_path, name, tz='UTC'):
|
|||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'default': {
|
||||
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)'
|
||||
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)'
|
||||
},
|
||||
'colorful_console': {
|
||||
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)',
|
||||
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)',
|
||||
'()': ColorfulFormatter,
|
||||
},
|
||||
},
|
||||
|
@ -133,8 +140,8 @@ def config(log_level, log_path, name, tz='UTC'):
|
|||
},
|
||||
'loggers': {
|
||||
'': {
|
||||
'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file', \
|
||||
'milvus_error_file', 'milvus_critical_file'],
|
||||
'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file',
|
||||
'milvus_error_file', 'milvus_critical_file'],
|
||||
'level': log_level,
|
||||
'propagate': False
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue