update for code style

pull/232/head
peng.xu 2019-10-14 13:42:12 +08:00
parent bef93edab9
commit 71c67f59a3
25 changed files with 201 additions and 143 deletions

View File

@@ -4,6 +4,7 @@ from mishards import settings, db, create_app
logger = logging.getLogger(__name__)
@pytest.fixture
def app(request):
app = create_app(settings.TestingConfig)

View File

@@ -2,6 +2,7 @@ import fire
from mishards import db
from sqlalchemy import and_
class DBHandler:
@classmethod
def create_all(cls):
@@ -15,9 +16,9 @@ class DBHandler:
def fun(cls, tid):
from mishards.factories import TablesFactory, TableFilesFactory, Tables
f = db.Session.query(Tables).filter(and_(
Tables.table_id==tid,
Tables.state!=Tables.TO_DELETE)
).first()
Tables.table_id == tid,
Tables.state != Tables.TO_DELETE)
).first()
print(f)
# f1 = TableFilesFactory()
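Note on the hunk above: the query relies on SQLAlchemy overloading == and != on Column objects to build SQL expressions, which is why PEP 8 operator spacing applies to them like any other binary operator. A minimal self-contained sketch of the same filter shape, assuming SQLAlchemy 1.4+ (the model and data are stand-ins, not this repo's):

from sqlalchemy import create_engine, Column, Integer, String, and_
from sqlalchemy.orm import declarative_base, Session

Base = declarative_base()

class Tables(Base):
    __tablename__ = 'tables'
    TO_DELETE = 3
    id = Column(Integer, primary_key=True)
    table_id = Column(String(64))
    state = Column(Integer)

engine = create_engine('sqlite://')   # in-memory database for the sketch
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Tables(table_id='tid-1', state=0))
    session.commit()
    # Column.__eq__/__ne__ return SQL expressions, not booleans
    f = session.query(Tables).filter(and_(
        Tables.table_id == 'tid-1',
        Tables.state != Tables.TO_DELETE)).first()
    print(f.table_id)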

View File

@@ -1,4 +1,4 @@
import logging
import logging
from mishards import settings
logger = logging.getLogger()
@@ -8,6 +8,7 @@ db = DB()
from mishards.server import Server
grpc_server = Server()
def create_app(testing_config=None):
config = testing_config if testing_config else settings.DefaultConfig
db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO)
@@ -24,7 +25,7 @@ def create_app(testing_config=None):
from tracing.factory import TracerFactory
from mishards.grpc_utils import GrpcSpanDecorator
tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig,
span_decorator=GrpcSpanDecorator())
span_decorator=GrpcSpanDecorator())
grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover)

View File

@@ -10,6 +10,7 @@ from utils import singleton
logger = logging.getLogger(__name__)
class Connection:
def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
self.name = name
@@ -55,7 +56,7 @@ class Connection:
if not self.can_retry and not self.connected:
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
metadata=metadata))
metadata=metadata))
self.retried = 0
@@ -72,6 +73,7 @@ class Connection:
raise e
return inner
@singleton
class ConnectionMgr:
def __init__(self):
@@ -90,10 +92,10 @@ class ConnectionMgr:
if not throw:
return None
raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
metadata=metadata)
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
threaded = {
threading.get_ident() : this_conn
threading.get_ident(): this_conn
}
self.conns[name] = threaded
return this_conn
@@ -106,7 +108,7 @@ class ConnectionMgr:
if not throw:
return None
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
metadata=metadata)
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
c[tid] = this_conn
return this_conn
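Note: both lookup paths in this file key a Connection first by name and then by thread id, so every thread gets its own client. A minimal sketch of that two-level registry, with a hypothetical factory callable standing in for the real Connection class:

import threading

class ConnectionMgr:
    def __init__(self):
        self.conns = {}   # name -> {thread id -> connection}

    def conn(self, name, factory):
        threaded = self.conns.setdefault(name, {})
        tid = threading.get_ident()
        if tid not in threaded:
            threaded[tid] = factory(name)   # hypothetical stand-in
        return threaded[tid]

mgr = ConnectionMgr()
same = mgr.conn('WOSERVER', factory=lambda name: object())
assert same is mgr.conn('WOSERVER', factory=lambda name: object())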

View File

@@ -14,8 +14,10 @@ class LocalSession(SessionBase):
bind = options.pop('bind', None) or db.engine
SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options)
class DB:
Model = declarative_base()
def __init__(self, uri=None, echo=False):
self.echo = echo
uri and self.init_db(uri, echo)
@@ -27,9 +29,9 @@ class DB:
self.engine = create_engine(url)
else:
self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
pool_pre_ping=True,
echo=echo,
max_overflow=0)
pool_pre_ping=True,
echo=echo,
max_overflow=0)
self.uri = uri
self.url = url
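Note: the engine above is built with explicit pool settings; annotated, they read as below. The URI is a stand-in so the sketch runs anywhere, and QueuePool is forced because SQLite would otherwise pick a pool class that rejects these arguments:

from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool

engine = create_engine(
    'sqlite:///example.db',   # stand-in URI, not this repo's
    poolclass=QueuePool,      # the default for MySQL/Postgres URIs
    pool_size=100,            # connections kept open in the pool
    pool_recycle=5,           # recycle connections older than 5 seconds
    pool_timeout=30,          # seconds to wait for a free connection
    pool_pre_ping=True,       # ping connections before handing them out
    max_overflow=0)           # never exceed pool_size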

View File

@@ -4,6 +4,7 @@ from mishards import grpc_server as server, exceptions
logger = logging.getLogger(__name__)
def resp_handler(err, error_code):
if not isinstance(err, exceptions.BaseException):
return status_pb2.Status(error_code=error_code, reason=str(err))
@@ -50,21 +51,25 @@ def resp_handler(err, error_code):
status.error_code = status_pb2.UNEXPECTED_ERROR
return status
@server.errorhandler(exceptions.TableNotFoundError)
def TableNotFoundErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
@server.errorhandler(exceptions.InvalidArgumentError)
def InvalidArgumentErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)
@server.errorhandler(exceptions.DBError)
def DBErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.UNEXPECTED_ERROR)
@server.errorhandler(exceptions.InvalidRangeError)
def InvalidRangeErrorHandler(err):
logger.error(err)

View File

@@ -1,26 +1,34 @@
import mishards.exception_codes as codes
class BaseException(Exception):
code = codes.INVALID_CODE
message = 'BaseException'
def __init__(self, message='', metadata=None):
self.message = self.__class__.__name__ if not message else message
self.metadata = metadata
class ConnectionConnectError(BaseException):
code = codes.CONNECT_ERROR_CODE
class ConnectionNotFoundError(BaseException):
code = codes.CONNECTTION_NOT_FOUND_CODE
class DBError(BaseException):
code = codes.DB_ERROR_CODE
class TableNotFoundError(BaseException):
code = codes.TABLE_NOT_FOUND_CODE
class InvalidArgumentError(BaseException):
code = codes.INVALID_ARGUMENT_CODE
class InvalidRangeError(BaseException):
code = codes.INVALID_DATE_RANGE_CODE

View File

@@ -9,13 +9,16 @@ from faker.providers import BaseProvider
from mishards import db
from mishards.models import Tables, TableFiles
class FakerProvider(BaseProvider):
def this_date(self):
t = datetime.datetime.today()
return (t.year - 1900) * 10000 + (t.month-1)*100 + t.day
return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day
factory.Faker.add_provider(FakerProvider)
class TablesFactory(SQLAlchemyModelFactory):
class Meta:
model = Tables
@@ -24,14 +27,15 @@ class TablesFactory(SQLAlchemyModelFactory):
id = factory.Faker('random_number', digits=16, fix_len=True)
table_id = factory.Faker('uuid4')
state = factory.Faker('random_element', elements=(0,1,2,3))
dimension = factory.Faker('random_element', elements=(256,512))
state = factory.Faker('random_element', elements=(0, 1, 2, 3))
dimension = factory.Faker('random_element', elements=(256, 512))
created_on = int(time.time())
index_file_size = 0
engine_type = factory.Faker('random_element', elements=(0,1,2,3))
metric_type = factory.Faker('random_element', elements=(0,1))
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
metric_type = factory.Faker('random_element', elements=(0, 1))
nlist = 16384
class TableFilesFactory(SQLAlchemyModelFactory):
class Meta:
model = TableFiles
@@ -40,9 +44,9 @@ class TableFilesFactory(SQLAlchemyModelFactory):
id = factory.Faker('random_number', digits=16, fix_len=True)
table = factory.SubFactory(TablesFactory)
engine_type = factory.Faker('random_element', elements=(0,1,2,3))
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
file_id = factory.Faker('uuid4')
file_type = factory.Faker('random_element', elements=(0,1,2,3,4))
file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
file_size = factory.Faker('random_number')
updated_time = int(time.time())
created_on = int(time.time())
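Note: the factories above generate test rows through factory_boy's SQLAlchemy integration. A minimal runnable sketch of the same pattern, with a stand-in model and session rather than this repo's:

import factory
from factory.alchemy import SQLAlchemyModelFactory
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker

Base = declarative_base()
engine = create_engine('sqlite://')
Session = scoped_session(sessionmaker(bind=engine))

class Tables(Base):
    __tablename__ = 'tables'
    id = Column(Integer, primary_key=True)
    table_id = Column(String(64))
    state = Column(Integer)

Base.metadata.create_all(engine)

class TablesFactory(SQLAlchemyModelFactory):
    class Meta:
        model = Tables
        sqlalchemy_session = Session

    table_id = factory.Faker('uuid4')
    state = factory.Faker('random_element', elements=(0, 1, 2, 3))

t = TablesFactory()   # builds a row and adds it to the session
print(t.table_id, t.state)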

View File

@@ -14,21 +14,23 @@ class GrpcSpanDecorator(SpanDecorator):
status = rpc_info.response.status
except Exception as e:
status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR,
reason='Should not happen')
reason='Should not happen')
if status.error_code == 0:
return
error_log = {'event': 'error',
'request': rpc_info.request,
'response': rpc_info.response
}
'request': rpc_info.request,
'response': rpc_info.response
}
span.set_tag('error', True)
span.log_kv(error_log)
def mark_grpc_method(func):
setattr(func, 'grpc_method', True)
return func
def is_grpc_method(func):
if not func:
return False
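Note: the hunk cuts off inside is_grpc_method; presumably it checks the attribute that mark_grpc_method sets. A sketch of the full marker-attribute pattern (the getattr line is an assumption, not copied from the repo):

def mark_grpc_method(func):
    setattr(func, 'grpc_method', True)
    return func

def is_grpc_method(func):
    if not func:
        return False
    return getattr(func, 'grpc_method', False)   # assumed remainder

@mark_grpc_method
def Search():
    pass

print(is_grpc_method(Search))   # True
print(is_grpc_method(None))     # False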

View File

@@ -1,4 +1,4 @@
# class GrpcArgsWrapper(object):
# @classmethod
# def proto_TableName(cls):
# @classmethod
# def proto_TableName(cls):

View File

@@ -9,8 +9,8 @@ else:
import md5
md5_constructor = md5.new
class HashRing(object):
class HashRing(object):
def __init__(self, nodes=None, weights=None):
"""`nodes` is a list of objects that have a proper __str__ representation.
`weights` is dictionary that sets weights to the nodes. The default
@@ -40,13 +40,13 @@ class HashRing(object):
if node in self.weights:
weight = self.weights.get(node)
factor = math.floor((40*len(self.nodes)*weight) / total_weight);
factor = math.floor((40 * len(self.nodes) * weight) / total_weight)
for j in range(0, int(factor)):
b_key = self._hash_digest( '%s-%s' % (node, j) )
b_key = self._hash_digest('%s-%s' % (node, j))
for i in range(0, 3):
key = self._hash_val(b_key, lambda x: x+i*4)
key = self._hash_val(b_key, lambda x: x + i * 4)
self.ring[key] = node
self._sorted_keys.append(key)
@@ -60,7 +60,7 @@ class HashRing(object):
pos = self.get_node_pos(string_key)
if pos is None:
return None
return self.ring[ self._sorted_keys[pos] ]
return self.ring[self._sorted_keys[pos]]
def get_node_pos(self, string_key):
"""Given a string key a corresponding node in the hash ring is returned
@@ -94,6 +94,7 @@ class HashRing(object):
yield None, None
returned_values = set()
def distinct_filter(value):
if str(value) not in returned_values:
returned_values.add(str(value))
@@ -121,10 +122,8 @@ class HashRing(object):
return self._hash_val(b_key, lambda x: x)
def _hash_val(self, b_key, entry_fn):
return (( b_key[entry_fn(3)] << 24)
|(b_key[entry_fn(2)] << 16)
|(b_key[entry_fn(1)] << 8)
| b_key[entry_fn(0)] )
return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | (
b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)]
def _hash_digest(self, key):
m = md5_constructor()
@@ -132,12 +131,13 @@ class HashRing(object):
m.update(key)
return m.digest()
if __name__ == '__main__':
from collections import defaultdict
servers = ['192.168.0.246:11212',
'192.168.0.247:11212',
'192.168.0.248:11212',
'192.168.0.249:11212']
servers = [
'192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212',
'192.168.0.249:11212'
]
ring = HashRing(servers)
keys = ['{}'.format(i) for i in range(100)]
@@ -146,5 +146,5 @@ if __name__ == '__main__':
server = ring.get_node(k)
mapped[server].append(k)
for k,v in mapped.items():
for k, v in mapped.items():
print(k, v)
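Note: the _hash_val rewrite in this file is purely cosmetic. It still packs four bytes of the MD5 digest into one 32-bit ring key, little-endian (byte 0 is least significant), and each 16-byte digest yields three keys at byte offsets 0, 4, and 8. A standalone sketch of that packing:

import hashlib

def hash_val(b_key, entry_fn):
    # same expression as the reformatted _hash_val above
    return ((b_key[entry_fn(3)] << 24)
            | (b_key[entry_fn(2)] << 16)
            | (b_key[entry_fn(1)] << 8)
            | b_key[entry_fn(0)])

digest = hashlib.md5(b'192.168.0.246:11212-0').digest()
for i in range(3):
    # three ring positions from one digest, offsets 0, 4, 8
    print(hash_val(digest, lambda x, i=i: x + i * 4))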

View File

@@ -1,13 +1,16 @@
import os, sys
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from mishards import (
settings, create_app)
from mishards import (settings, create_app)
def main():
server = create_app(settings.TestingConfig if settings.TESTING else settings.DefaultConfig)
server = create_app(
settings.TestingConfig if settings.TESTING else settings.DefaultConfig)
server.run(port=settings.SERVER_PORT)
return 0
if __name__ == '__main__':
sys.exit(main())

View File

@@ -1,13 +1,14 @@
import logging
from sqlalchemy import (Integer, Boolean, Text,
String, BigInteger, func, and_, or_,
Column)
String, BigInteger, func, and_, or_,
Column)
from sqlalchemy.orm import relationship, backref
from mishards import db
logger = logging.getLogger(__name__)
class TableFiles(db.Model):
FILE_TYPE_NEW = 0
FILE_TYPE_RAW = 1
@@ -57,16 +58,16 @@ class Tables(db.Model):
def files_to_search(self, date_range=None):
cond = or_(
TableFiles.file_type==TableFiles.FILE_TYPE_RAW,
TableFiles.file_type==TableFiles.FILE_TYPE_TO_INDEX,
TableFiles.file_type==TableFiles.FILE_TYPE_INDEX,
TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX,
TableFiles.file_type == TableFiles.FILE_TYPE_INDEX,
)
if date_range:
cond = and_(
cond,
or_(
and_(TableFiles.date>=d[0], TableFiles.date<d[1]) for d in date_range
)
and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range
)
)
files = self.files.filter(cond)

View File

@@ -33,7 +33,7 @@ class Server:
self.server_impl = grpc.server(
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
(cygrpc.ChannelArgKey.max_receive_message_length, -1)]
(cygrpc.ChannelArgKey.max_receive_message_length, -1)]
)
self.server_impl = self.tracer.decorate(self.server_impl)
@@ -46,7 +46,7 @@ class Server:
ip = socket.gethostbyname(url.hostname)
socket.inet_pton(socket.AF_INET, ip)
self.conn_mgr.register('WOSERVER',
'{}://{}:{}'.format(url.scheme, ip, url.port or 80))
'{}://{}:{}'.format(url.scheme, ip, url.port or 80))
def register_pre_run_handler(self, func):
logger.info('Registering {} into server pre_run_handlers'.format(func))

View File

@@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client.Abstract import Range
from milvus.client import types
from milvus.client import types as Types
from mishards import (db, settings, exceptions)
from mishards.grpc_utils import mark_grpc_method
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
MAX_NPROBE = 2048
def __init__(self, conn_mgr, tracer, *args, **kwargs):
self.conn_mgr = conn_mgr
self.table_meta = {}
@@ -44,8 +45,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return conn.conn
def _format_date(self, start, end):
return ((start.year-1900)*10000 + (start.month-1)*100 + start.day
, (end.year-1900)*10000 + (end.month-1)*100 + end.day)
return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)
def _range_to_date(self, range_obj, metadata=None):
try:
@@ -54,8 +54,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
assert start < end
except (ValueError, AssertionError):
raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
range_obj.start_date, range_obj.end_date
), metadata=metadata)
range_obj.start_date, range_obj.end_date
), metadata=metadata)
return self._format_date(start, end)
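Note: _format_date packs a date into one integer as (year - 1900) * 10000 + (month - 1) * 100 + day, so 2019-10-14 becomes 119 * 10000 + 9 * 100 + 14 = 1190914; this is the same encoding FakerProvider.this_date uses in the factories file. A one-function sketch:

import datetime

def format_date(d):
    # years since 1900, zero-based month, day of month
    return (d.year - 1900) * 10000 + (d.month - 1) * 100 + d.day

print(format_date(datetime.date(2019, 10, 14)))   # 1190914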
@@ -63,9 +63,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
# PXU TODO: Implement Thread-local Context
try:
table = db.Session.query(Tables).filter(and_(
Tables.table_id==table_id,
Tables.state!=Tables.TO_DELETE
)).first()
Tables.table_id == table_id,
Tables.state != Tables.TO_DELETE
)).first()
except sqlalchemy_exc.SQLAlchemyError as e:
raise exceptions.DBError(message=str(e), metadata=metadata)
@@ -93,7 +93,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return routing
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success")
status = status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success")
if not files_n_topk_results:
return status, []
@@ -107,7 +107,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
for request_pos, each_request_results in enumerate(files_collection.topk_query_result):
request_results[request_pos].extend(each_request_results.query_result_arrays)
request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance,
reverse=reverse)[:topk]
reverse=reverse)[:topk]
calc_time = time.time() - calc_time
logger.info('Merge takes {}'.format(calc_time))
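Note: the merge above concatenates per-file results for each request, re-sorts by distance, and keeps topk; reverse=True applies when larger scores are better (the IP metric, as set in the search flow below). A sketch with plain (id, distance) tuples standing in for query_result_arrays:

def merge_topk(result_lists, topk, reverse=False):
    merged = []
    for results in result_lists:   # one list per shard's answer
        merged.extend(results)
    return sorted(merged, key=lambda x: x[1], reverse=reverse)[:topk]

shard_a = [('id1', 0.1), ('id3', 0.7)]
shard_b = [('id2', 0.3), ('id4', 0.9)]
print(merge_topk([shard_a, shard_b], topk=2))                 # L2: smallest first
print(merge_topk([shard_a, shard_b], topk=2, reverse=True))   # IP: largest first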
@@ -127,7 +127,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
routing = {}
with self.tracer.start_span('get_routing',
child_of=context.get_active_span().context):
child_of=context.get_active_span().context):
routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata)
logger.info('Routing: {}'.format(routing))
@@ -140,28 +140,28 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def search(addr, query_params, vectors, topk, nprobe, **kwargs):
logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format(
addr, query_params, len(vectors), topk, nprobe
))
addr, query_params, len(vectors), topk, nprobe
))
conn = self.query_conn(addr, metadata=metadata)
start = time.time()
span = kwargs.get('span', None)
span = span if span else context.get_active_span().context
with self.tracer.start_span('search_{}'.format(addr),
child_of=context.get_active_span().context):
child_of=context.get_active_span().context):
ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
file_ids=query_params['file_ids'],
query_records=vectors,
top_k=topk,
nprobe=nprobe,
lazy=True)
file_ids=query_params['file_ids'],
query_records=vectors,
top_k=topk,
nprobe=nprobe,
lazy=True)
end = time.time()
logger.info('search_vectors_in_files takes: {}'.format(end - start))
all_topk_results.append(ret)
with self.tracer.start_span('do_search',
child_of=context.get_active_span().context) as span:
child_of=context.get_active_span().context) as span:
with ThreadPoolExecutor(max_workers=workers) as pool:
for addr, params in routing.items():
res = pool.submit(search, addr, params, vectors, topk, nprobe, span=span)
@@ -170,9 +170,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
for res in rs:
res.result()
reverse = table_meta.metric_type == types.MetricType.IP
reverse = table_meta.metric_type == Types.MetricType.IP
with self.tracer.start_span('do_merge',
child_of=context.get_active_span().context):
child_of=context.get_active_span().context):
return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
@mark_grpc_method
@@ -201,8 +201,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('HasTable {}'.format(_table_name))
_bool = self.connection(metadata={
'resp_class': milvus_pb2.BoolReply
}).has_table(_table_name)
'resp_class': milvus_pb2.BoolReply
}).has_table(_table_name)
return milvus_pb2.BoolReply(
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
@@ -244,7 +244,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
# TODO: The SDK interface add_vectors() could update, add a key 'row_id_array'
_status, _ids = self.connection(metadata={
'resp_class': milvus_pb2.VectorIds
}).add_vectors(None, None, insert_param=request)
}).add_vectors(None, None, insert_param=request)
return milvus_pb2.VectorIds(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
vector_id_array=_ids
@@ -266,7 +266,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
if nprobe > self.MAX_NPROBE or nprobe <= 0:
raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe),
metadata=metadata)
metadata=metadata)
table_meta = self.table_meta.get(table_name, None)
@@ -332,8 +332,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
)
return milvus_pb2.TableSchema(
table_name=_table_name,
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_name=_table_name,
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
)
@mark_grpc_method
@@ -391,8 +391,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
_status, _results = self.connection(metadata=metadata).show_tables()
return milvus_pb2.TableNameList(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_names=_results
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_names=_results
)
@mark_grpc_method
@@ -426,7 +426,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
if not _status.OK():
return milvus_pb2.IndexParam(
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
)
metadata = {
@@ -439,7 +439,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
_index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_name=_table_name, index=_index)
table_name=_table_name, index=_index)
@mark_grpc_method
def DropIndex(self, request, context):

View File

@@ -39,13 +39,15 @@ if SD_PROVIDER == 'Kubernetes':
elif SD_PROVIDER == 'Static':
from sd.static_provider import StaticProviderSettings
SD_PROVIDER_SETTINGS = StaticProviderSettings(
hosts=env.list('SD_STATIC_HOSTS', [])
)
hosts=env.list('SD_STATIC_HOSTS', [])
)
TESTING = env.bool('TESTING', False)
TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530')
TRACING_TYPE = env.str('TRACING_TYPE', '')
class TracingConfig:
TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards')
TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True)
@@ -54,7 +56,7 @@ class TracingConfig:
'sampler': {
'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
'param': env.str('TRACING_SAMPLER_PARAM', "1"),
},
},
'local_agent': {
'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'),
'reporting_port': env.str('TRACING_REPORTING_PORT', '5775')
@@ -62,10 +64,12 @@ class TracingConfig:
'logging': env.bool('TRACING_LOGGING', True)
}
class DefaultConfig:
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
SQL_ECHO = env.bool('SQL_ECHO', False)
TESTING = env.bool('TESTING', False)
if TESTING:
class TestingConfig(DefaultConfig):
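Note: every setting in this file is read through an environs Env, which parses and type-casts environment variables with defaults. A minimal sketch of the pattern (the port default is illustrative, not taken from this file):

from environs import Env

env = Env()
env.read_env()   # also picks up a .env file if one exists

TESTING = env.bool('TESTING', False)
SERVER_PORT = env.int('SERVER_PORT', 19530)        # illustrative default
STATIC_HOSTS = env.list('SD_STATIC_HOSTS', [])     # comma-separated list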

View File

@@ -6,6 +6,7 @@ from mishards import exceptions
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestConnection:
def test_manager(self):
@@ -30,8 +31,10 @@ class TestConnection:
class Conn:
def __init__(self, state):
self.state = state
def connect(self, uri):
return self.state
def connected(self):
return self.state
FAIL_CONN = Conn(False)
@@ -48,6 +51,7 @@ class TestConnection:
class Func():
def __init__(self):
self.executed = False
def __call__(self):
self.executed = True
@@ -55,8 +59,8 @@ class TestConnection:
RetryObj = Retry()
c = Connection('client', uri='',
max_retry=max_retry,
on_retry_func=RetryObj)
max_retry=max_retry,
on_retry_func=RetryObj)
c.conn = FAIL_CONN
ff = Func()
this_connect = c.connect(func=ff)

View File

@@ -3,12 +3,13 @@ import pytest
from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory
from mishards import db, create_app, settings
from mishards.factories import (
Tables, TableFiles,
TablesFactory, TableFilesFactory
)
Tables, TableFiles,
TablesFactory, TableFilesFactory
)
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestModels:
def test_files_to_search(self):

View File

@@ -24,4 +24,5 @@ class ProviderManager:
def get_provider(cls, name):
return cls.PROVIDERS.get(name, None)
from sd import kubernetes_provider, static_provider
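Note: the import at the bottom of this file is what triggers registration; each provider module decorates its class with register_service_provider at import time. A sketch of the registry, assuming it keys providers by their NAME attribute:

class ProviderManager:
    PROVIDERS = {}

    @classmethod
    def register_service_provider(cls, target):
        cls.PROVIDERS[target.NAME] = target   # assumed keying by NAME
        return target

    @classmethod
    def get_provider(cls, name):
        return cls.PROVIDERS.get(name, None)

@ProviderManager.register_service_provider
class StaticProvider:
    NAME = 'Static'

print(ProviderManager.get_provider('Static'))   # the registered class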

View File

@@ -1,4 +1,5 @@
import os, sys
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -71,7 +72,6 @@ class K8SHeartbeatHandler(threading.Thread, K8SMixin):
self.queue.put(event_message)
except Exception as exc:
logger.error(exc)
@@ -98,18 +98,18 @@ class K8SEventListener(threading.Thread, K8SMixin):
resource_version = ''
w = watch.Watch()
for event in w.stream(self.v1.list_namespaced_event, namespace=self.namespace,
field_selector='involvedObject.kind=Pod'):
field_selector='involvedObject.kind=Pod'):
if self.terminate:
break
resource_version = int(event['object'].metadata.resource_version)
info = dict(
eType='WatchEvent',
pod=event['object'].involved_object.name,
reason=event['object'].reason,
message=event['object'].message,
start_up=self.at_start_up,
eType='WatchEvent',
pod=event['object'].involved_object.name,
reason=event['object'].reason,
message=event['object'].message,
start_up=self.at_start_up,
)
self.at_start_up = False
# logger.info('Received event: {}'.format(info))
@@ -135,7 +135,7 @@ class EventHandler(threading.Thread):
def on_pod_started(self, event, **kwargs):
try_cnt = 3
pod = None
while try_cnt > 0:
while try_cnt > 0:
try_cnt -= 1
try:
pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], namespace=self.namespace)
@@ -203,6 +203,7 @@ class EventHandler(threading.Thread):
except queue.Empty:
continue
class KubernetesProviderSettings:
def __init__(self, namespace, pod_patt, label_selector, in_cluster, poll_interval, **kwargs):
self.namespace = namespace
@@ -211,10 +212,12 @@ class KubernetesProviderSettings:
self.in_cluster = in_cluster
self.poll_interval = poll_interval
@singleton
@ProviderManager.register_service_provider
class KubernetesProvider(object):
NAME = 'Kubernetes'
def __init__(self, settings, conn_mgr, **kwargs):
self.namespace = settings.namespace
self.pod_patt = settings.pod_patt
@@ -233,27 +236,27 @@ class KubernetesProvider(object):
self.v1 = client.CoreV1Api()
self.listener = K8SEventListener(
message_queue=self.queue,
namespace=self.namespace,
in_cluster=self.in_cluster,
v1=self.v1,
**kwargs
)
message_queue=self.queue,
namespace=self.namespace,
in_cluster=self.in_cluster,
v1=self.v1,
**kwargs
)
self.pod_heartbeater = K8SHeartbeatHandler(
message_queue=self.queue,
namespace=self.namespace,
label_selector=self.label_selector,
in_cluster=self.in_cluster,
v1=self.v1,
poll_interval=self.poll_interval,
**kwargs
)
message_queue=self.queue,
namespace=self.namespace,
label_selector=self.label_selector,
in_cluster=self.in_cluster,
v1=self.v1,
poll_interval=self.poll_interval,
**kwargs
)
self.event_handler = EventHandler(mgr=self,
message_queue=self.queue,
namespace=self.namespace,
pod_patt=self.pod_patt, **kwargs)
message_queue=self.queue,
namespace=self.namespace,
pod_patt=self.pod_patt, **kwargs)
def add_pod(self, name, ip):
self.conn_mgr.register(name, 'tcp://{}:19530'.format(ip))
@@ -276,9 +279,11 @@ class KubernetesProvider(object):
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
class Connect:
def register(self, name, value):
logger.error('Register: {} - {}'.format(name, value))
def unregister(self, name):
logger.error('Unregister: {}'.format(name))
@@ -289,16 +294,16 @@ if __name__ == '__main__':
connect_mgr = Connect()
settings = KubernetesProviderSettings(
namespace='xp',
pod_patt=".*-ro-servers-.*",
label_selector='tier=ro-servers',
poll_interval=5,
in_cluster=False)
namespace='xp',
pod_patt=".*-ro-servers-.*",
label_selector='tier=ro-servers',
poll_interval=5,
in_cluster=False)
provider_class = ProviderManager.get_provider('Kubernetes')
t = provider_class(conn_mgr=connect_mgr,
settings=settings
)
settings=settings
)
t.start()
cnt = 100
while cnt > 0:

View File

@@ -1,4 +1,5 @@
import os, sys
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -6,14 +7,17 @@ import socket
from utils import singleton
from sd import ProviderManager
class StaticProviderSettings:
def __init__(self, hosts):
self.hosts = hosts
@singleton
@ProviderManager.register_service_provider
class KubernetesProvider(object):
NAME = 'Static'
def __init__(self, settings, conn_mgr, **kwargs):
self.conn_mgr = conn_mgr
self.hosts = [socket.gethostbyname(host) for host in settings.hosts]

View File

@@ -1,13 +1,14 @@
def empty_server_interceptor_decorator(target_server, interceptor):
return target_server
class Tracer:
def __init__(self, tracer=None,
interceptor=None,
server_decorator=empty_server_interceptor_decorator):
interceptor=None,
server_decorator=empty_server_interceptor_decorator):
self.tracer = tracer
self.interceptor = interceptor
self.server_decorator=server_decorator
self.server_decorator = server_decorator
def decorate(self, server):
return self.server_decorator(server, self.interceptor)
@@ -16,7 +17,7 @@ class Tracer:
self.tracer and self.tracer.close()
def start_span(self, operation_name=None,
child_of=None, references=None, tags=None,
start_time=None, ignore_active_span=False):
child_of=None, references=None, tags=None,
start_time=None, ignore_active_span=False):
return self.tracer.start_span(operation_name, child_of,
references, tags, start_time, ignore_active_span)
references, tags, start_time, ignore_active_span)

View File

@@ -4,7 +4,7 @@ from grpc_opentracing.grpcext import intercept_server
from grpc_opentracing import open_tracing_server_interceptor
from tracing import (Tracer,
empty_server_interceptor_decorator)
empty_server_interceptor_decorator)
logger = logging.getLogger(__name__)
@@ -17,14 +17,14 @@ class TracerFactory:
if tracer_type.lower() == 'jaeger':
config = Config(config=tracer_config.TRACING_CONFIG,
service_name=tracer_config.TRACING_SERVICE_NAME,
validate=tracer_config.TRACING_VALIDATE
)
service_name=tracer_config.TRACING_SERVICE_NAME,
validate=tracer_config.TRACING_VALIDATE
)
tracer = config.initialize_tracer()
tracer_interceptor = open_tracing_server_interceptor(tracer,
log_payloads=tracer_config.TRACING_LOG_PAYLOAD,
span_decorator=span_decorator)
log_payloads=tracer_config.TRACING_LOG_PAYLOAD,
span_decorator=span_decorator)
return Tracer(tracer, tracer_interceptor, intercept_server)

View File

@@ -1,5 +1,6 @@
from functools import wraps
def singleton(cls):
instances = {}
@wraps(cls)
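Note: the hunk stops at the @wraps line; a sketch of how such a decorator conventionally continues (the closure body is an assumption, not copied from utils.py):

from functools import wraps

def singleton(cls):
    instances = {}

    @wraps(cls)
    def getinstance(*args, **kwargs):
        # construct once, then hand back the same object forever
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    return getinstance

@singleton
class Config:
    pass

assert Config() is Config()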

View File

@@ -9,18 +9,22 @@ class InfoFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.INFO
class DebugFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.DEBUG
class WarnFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.WARN
class ErrorFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.ERROR
class CriticalFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.CRITICAL
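Note: each filter above pins a handler to exactly one level; without a filter, a handler passes every record at or above its level. A runnable sketch wiring one of them up:

import logging

class InfoFilter(logging.Filter):
    def filter(self, rec):
        return rec.levelno == logging.INFO

handler = logging.StreamHandler()
handler.addFilter(InfoFilter())   # this handler now emits only INFO records

logger = logging.getLogger('demo')
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
logger.info('shown')
logger.warning('suppressed by the filter')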
@@ -36,6 +40,7 @@ COLORS = {
'ENDC': '\033[0m',
}
class ColorFulFormatColMixin:
def format_col(self, message_str, level_name):
if level_name in COLORS.keys():
@@ -43,12 +48,14 @@ class ColorFulFormatColMixin:
'ENDC')
return message_str
class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin):
def format(self, record):
message_str = super(ColorfulFormatter, self).format(record)
return self.format_col(message_str, level_name=record.levelname)
def config(log_level, log_path, name, tz='UTC'):
def build_log_file(level, log_path, name, tz):
utc_now = datetime.datetime.utcnow()
@@ -56,7 +63,7 @@ def config(log_level, log_path, name, tz='UTC'):
local_tz = timezone(tz)
tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz)
return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"),
level)
level)
if not os.path.exists(log_path):
os.makedirs(log_path)
@@ -66,10 +73,10 @@ def config(log_level, log_path, name, tz='UTC'):
'disable_existing_loggers': False,
'formatters': {
'default': {
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)'
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)'
},
'colorful_console': {
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)',
'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)',
'()': ColorfulFormatter,
},
},
@@ -133,8 +140,8 @@ def config(log_level, log_path, name, tz='UTC'):
},
'loggers': {
'': {
'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file', \
'milvus_error_file', 'milvus_critical_file'],
'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file',
'milvus_error_file', 'milvus_critical_file'],
'level': log_level,
'propagate': False
},