mirror of https://github.com/milvus-io/milvus.git
fix bug in test_server
parent 7b0a731e04
commit 9a4c732563
Dockerfile
@@ -1,10 +0,0 @@
FROM python:3.6
RUN apt update && apt install -y \
    less \
    telnet
RUN mkdir /source
WORKDIR /source
ADD ./requirements.txt ./
RUN pip install -r requirements.txt
COPY . .
CMD python mishards/main.py
build.sh
@@ -1,39 +0,0 @@
#!/bin/bash

BOLD=`tput bold`
NORMAL=`tput sgr0`
YELLOW='\033[1;33m'
ENDC='\033[0m'

echo -e "${BOLD}MISHARDS_REGISTRY=${MISHARDS_REGISTRY}${ENDC}"

function build_image() {
    dockerfile=$1
    remote_registry=$2
    tagged=$2
    buildcmd="docker build -t ${tagged} -f ${dockerfile} ."
    echo -e "${BOLD}$buildcmd${NORMAL}"
    $buildcmd
    pushcmd="docker push ${remote_registry}"
    echo -e "${BOLD}$pushcmd${NORMAL}"
    $pushcmd
    echo -e "${YELLOW}${BOLD}Image: ${remote_registry}${NORMAL}${ENDC}"
}

case "$1" in

    all)
        [[ -z $MISHARDS_REGISTRY ]] && {
            echo -e "${YELLOW}Error: Please set docker registry first:${ENDC}\n\t${BOLD}export MISHARDS_REGISTRY=xxxx\n${ENDC}"
            exit 1
        }

        version=""
        [[ ! -z $2 ]] && version=":${2}"
        build_image "Dockerfile" "${MISHARDS_REGISTRY}${version}" "${MISHARDS_REGISTRY}"
        ;;
    *)
        echo "Usage: [option...] {base | apps}"
        echo "all, Usage: build.sh all [tagname|] => {docker_registry}:\${tagname}"
        ;;
esac
conftest.py
@@ -1,27 +0,0 @@
import logging
import pytest
import grpc

from mishards import settings, db, create_app

logger = logging.getLogger(__name__)


@pytest.fixture
def app(request):
    app = create_app(settings.TestingConfig)
    db.drop_all()
    db.create_all()

    yield app

    db.drop_all()


@pytest.fixture
def started_app(app):
    app.on_pre_run()
    app.start(settings.SERVER_TEST_PORT)

    yield app

    app.stop()
manager.py
@@ -1,28 +0,0 @@
import fire
from mishards import db
from sqlalchemy import and_


class DBHandler:
    @classmethod
    def create_all(cls):
        db.create_all()

    @classmethod
    def drop_all(cls):
        db.drop_all()

    @classmethod
    def fun(cls, tid):
        from mishards.factories import TablesFactory, TableFilesFactory, Tables
        f = db.Session.query(Tables).filter(and_(
            Tables.table_id == tid,
            Tables.state != Tables.TO_DELETE)
        ).first()
        print(f)

        # f1 = TableFilesFactory()


if __name__ == '__main__':
    fire.Fire(DBHandler)
@@ -1,33 +0,0 @@
DEBUG=True

WOSERVER=tcp://127.0.0.1:19530
SERVER_PORT=19532
SERVER_TEST_PORT=19888

SD_PROVIDER=Static

SD_NAMESPACE=xp
SD_IN_CLUSTER=False
SD_POLL_INTERVAL=5
SD_ROSERVER_POD_PATT=.*-ro-servers-.*
SD_LABEL_SELECTOR=tier=ro-servers

SD_STATIC_HOSTS=127.0.0.1
SD_STATIC_PORT=19530

#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
SQL_ECHO=True

#SQLALCHEMY_DATABASE_TEST_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_TEST_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
SQL_TEST_ECHO=False

# TRACING_TEST_TYPE=jaeger
TRACING_TYPE=jaeger
TRACING_SERVICE_NAME=fortest
TRACING_SAMPLER_TYPE=const
TRACING_SAMPLER_PARAM=1
TRACING_LOG_PAYLOAD=True
#TRACING_SAMPLER_TYPE=probabilistic
#TRACING_SAMPLER_PARAM=0.5
@@ -1,36 +0,0 @@
import logging
from mishards import settings
logger = logging.getLogger()

from mishards.db_base import DB
db = DB()

from mishards.server import Server
grpc_server = Server()


def create_app(testing_config=None):
    config = testing_config if testing_config else settings.DefaultConfig
    db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO)

    from mishards.connections import ConnectionMgr
    connect_mgr = ConnectionMgr()

    from sd import ProviderManager

    sd_provider_class = ProviderManager.get_provider(settings.SD_PROVIDER)
    discover = sd_provider_class(settings=settings.SD_PROVIDER_SETTINGS, conn_mgr=connect_mgr)

    from tracing.factory import TracerFactory
    from mishards.grpc_utils import GrpcSpanDecorator
    tracer = TracerFactory.new_tracer(config.TRACING_TYPE, settings.TracingConfig,
                                      span_decorator=GrpcSpanDecorator())

    from mishards.routings import RouterFactory
    router = RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)

    grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, router=router, discover=discover)

    from mishards import exception_handlers

    return grpc_server
@@ -1,154 +0,0 @@
import logging
import threading
from functools import wraps
from milvus import Milvus

from mishards import (settings, exceptions)
from utils import singleton

logger = logging.getLogger(__name__)


class Connection:
    def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
        self.name = name
        self.uri = uri
        self.max_retry = max_retry
        self.retried = 0
        self.conn = Milvus()
        self.error_handlers = [] if not error_handlers else error_handlers
        self.on_retry_func = kwargs.get('on_retry_func', None)
        # self._connect()

    def __str__(self):
        return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)

    def _connect(self, metadata=None):
        try:
            self.conn.connect(uri=self.uri)
        except Exception as e:
            if not self.error_handlers:
                raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
            for handler in self.error_handlers:
                handler(e, metadata=metadata)

    @property
    def can_retry(self):
        return self.retried < self.max_retry

    @property
    def connected(self):
        return self.conn.connected()

    def on_retry(self):
        if self.on_retry_func:
            self.on_retry_func(self)
        else:
            self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))

    def on_connect(self, metadata=None):
        while not self.connected and self.can_retry:
            self.retried += 1
            self.on_retry()
            self._connect(metadata=metadata)

        if not self.can_retry and not self.connected:
            raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry),
                                                    metadata=metadata)

        self.retried = 0

    def connect(self, func, exception_handler=None):
        @wraps(func)
        def inner(*args, **kwargs):
            self.on_connect()
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if exception_handler:
                    exception_handler(e)
                else:
                    raise e
        return inner


@singleton
class ConnectionMgr:
    def __init__(self):
        self.metas = {}
        self.conns = {}

    @property
    def conn_names(self):
        return set(self.metas.keys()) - set(['WOSERVER'])

    def conn(self, name, metadata, throw=False):
        c = self.conns.get(name, None)
        if not c:
            url = self.metas.get(name, None)
            if not url:
                if not throw:
                    return None
                raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
                                                         metadata=metadata)
            this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
            threaded = {
                threading.get_ident(): this_conn
            }
            self.conns[name] = threaded
            return this_conn

        tid = threading.get_ident()
        rconn = c.get(tid, None)
        if not rconn:
            url = self.metas.get(name, None)
            if not url:
                if not throw:
                    return None
                raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
                                                         metadata=metadata)
            this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
            c[tid] = this_conn
            return this_conn

        return rconn

    def on_new_meta(self, name, url):
        logger.info('Register Connection: name={};url={}'.format(name, url))
        self.metas[name] = url

    def on_duplicate_meta(self, name, url):
        if self.metas[name] == url:
            return self.on_same_meta(name, url)

        return self.on_diff_meta(name, url)

    def on_same_meta(self, name, url):
        # logger.warning('Register same meta: {}:{}'.format(name, url))
        pass

    def on_diff_meta(self, name, url):
        logger.warning('Received {} with diff url={}'.format(name, url))
        self.metas[name] = url
        self.conns[name] = {}

    def on_unregister_meta(self, name, url):
        logger.info('Unregister name={};url={}'.format(name, url))
        self.conns.pop(name, None)

    def on_nonexisted_meta(self, name):
        logger.warning('Non-existed meta: {}'.format(name))

    def register(self, name, url):
        meta = self.metas.get(name)
        if not meta:
            return self.on_new_meta(name, url)
        else:
            return self.on_duplicate_meta(name, url)

    def unregister(self, name):
        logger.info('Unregister Connection: name={}'.format(name))
        url = self.metas.pop(name, None)
        if url is None:
            return self.on_nonexisted_meta(name)
        return self.on_unregister_meta(name, url)
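For reference, a minimal usage sketch of the ConnectionMgr above, mirroring what the test file at the end of this commit exercises; the address string is a placeholder, not part of the deleted code:

# Illustrative sketch only -- placeholder URI, not part of the commit.
from mishards.connections import ConnectionMgr

mgr = ConnectionMgr()
mgr.register('pod1', 'tcp://127.0.0.1:19530')      # read-only server
mgr.register('WOSERVER', 'tcp://127.0.0.1:19530')  # write server is excluded from conn_names

assert mgr.conn_names == {'pod1'}

conn = mgr.conn('pod1', metadata=None)  # lazily builds one Connection per thread
conn.on_connect(metadata=None)          # retries up to settings.MAX_RETRY before raising

mgr.unregister('pod1')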
@@ -1,52 +0,0 @@
import logging
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm.session import Session as SessionBase

logger = logging.getLogger(__name__)


class LocalSession(SessionBase):
    def __init__(self, db, autocommit=False, autoflush=True, **options):
        self.db = db
        bind = options.pop('bind', None) or db.engine
        SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options)


class DB:
    Model = declarative_base()

    def __init__(self, uri=None, echo=False):
        self.echo = echo
        uri and self.init_db(uri, echo)
        self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self))

    def init_db(self, uri, echo=False):
        url = make_url(uri)
        if url.get_backend_name() == 'sqlite':
            self.engine = create_engine(url)
        else:
            self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
                                        pool_pre_ping=True,
                                        echo=echo,
                                        max_overflow=0)
        self.uri = uri
        self.url = url

    def __str__(self):
        return '<DB: backend={};database={}>'.format(self.url.get_backend_name(), self.url.database)

    @property
    def Session(self):
        return self.session_factory()

    def remove_session(self):
        self.session_factory.remove()

    def drop_all(self):
        self.Model.metadata.drop_all(self.engine)

    def create_all(self):
        self.Model.metadata.create_all(self.engine)
@@ -1,10 +0,0 @@
INVALID_CODE = -1

CONNECT_ERROR_CODE = 10001
CONNECTION_NOT_FOUND_CODE = 10002
DB_ERROR_CODE = 10003

TABLE_NOT_FOUND_CODE = 20001
INVALID_ARGUMENT_CODE = 20002
INVALID_DATE_RANGE_CODE = 20003
INVALID_TOPK_CODE = 20004
@@ -1,82 +0,0 @@
import logging
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from mishards import grpc_server as server, exceptions

logger = logging.getLogger(__name__)


def resp_handler(err, error_code):
    if not isinstance(err, exceptions.BaseException):
        return status_pb2.Status(error_code=error_code, reason=str(err))

    status = status_pb2.Status(error_code=error_code, reason=err.message)

    if err.metadata is None:
        return status

    resp_class = err.metadata.get('resp_class', None)
    if not resp_class:
        return status

    if resp_class == milvus_pb2.BoolReply:
        return resp_class(status=status, bool_reply=False)

    if resp_class == milvus_pb2.VectorIds:
        return resp_class(status=status, vector_id_array=[])

    if resp_class == milvus_pb2.TopKQueryResultList:
        return resp_class(status=status, topk_query_result=[])

    if resp_class == milvus_pb2.TableRowCount:
        return resp_class(status=status, table_row_count=-1)

    if resp_class == milvus_pb2.TableName:
        return resp_class(status=status, table_name=[])

    if resp_class == milvus_pb2.StringReply:
        return resp_class(status=status, string_reply='')

    if resp_class == milvus_pb2.TableSchema:
        return milvus_pb2.TableSchema(
            status=status
        )

    if resp_class == milvus_pb2.IndexParam:
        return milvus_pb2.IndexParam(
            table_name=milvus_pb2.TableName(
                status=status
            )
        )

    status.error_code = status_pb2.UNEXPECTED_ERROR
    return status


@server.errorhandler(exceptions.TableNotFoundError)
def TableNotFoundErrorHandler(err):
    logger.error(err)
    return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)


@server.errorhandler(exceptions.InvalidTopKError)
def InvalidTopKErrorHandler(err):
    logger.error(err)
    return resp_handler(err, status_pb2.ILLEGAL_TOPK)


@server.errorhandler(exceptions.InvalidArgumentError)
def InvalidArgumentErrorHandler(err):
    logger.error(err)
    return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)


@server.errorhandler(exceptions.DBError)
def DBErrorHandler(err):
    logger.error(err)
    return resp_handler(err, status_pb2.UNEXPECTED_ERROR)


@server.errorhandler(exceptions.InvalidRangeError)
def InvalidRangeErrorHandler(err):
    logger.error(err)
    return resp_handler(err, status_pb2.ILLEGAL_RANGE)
@@ -1,38 +0,0 @@
import mishards.exception_codes as codes


class BaseException(Exception):
    code = codes.INVALID_CODE
    message = 'BaseException'

    def __init__(self, message='', metadata=None):
        self.message = self.__class__.__name__ if not message else message
        self.metadata = metadata


class ConnectionConnectError(BaseException):
    code = codes.CONNECT_ERROR_CODE


class ConnectionNotFoundError(BaseException):
    code = codes.CONNECTION_NOT_FOUND_CODE


class DBError(BaseException):
    code = codes.DB_ERROR_CODE


class TableNotFoundError(BaseException):
    code = codes.TABLE_NOT_FOUND_CODE


class InvalidTopKError(BaseException):
    code = codes.INVALID_TOPK_CODE


class InvalidArgumentError(BaseException):
    code = codes.INVALID_ARGUMENT_CODE


class InvalidRangeError(BaseException):
    code = codes.INVALID_DATE_RANGE_CODE
@@ -1,54 +0,0 @@
import time
import datetime
import random
import factory
from factory.alchemy import SQLAlchemyModelFactory
from faker import Faker
from faker.providers import BaseProvider

from milvus.client.types import MetricType
from mishards import db
from mishards.models import Tables, TableFiles


class FakerProvider(BaseProvider):
    def this_date(self):
        t = datetime.datetime.today()
        return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day


factory.Faker.add_provider(FakerProvider)


class TablesFactory(SQLAlchemyModelFactory):
    class Meta:
        model = Tables
        sqlalchemy_session = db.session_factory
        sqlalchemy_session_persistence = 'commit'

    id = factory.Faker('random_number', digits=16, fix_len=True)
    table_id = factory.Faker('uuid4')
    state = factory.Faker('random_element', elements=(0, 1))
    dimension = factory.Faker('random_element', elements=(256, 512))
    created_on = int(time.time())
    index_file_size = 0
    engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
    metric_type = factory.Faker('random_element', elements=(MetricType.L2, MetricType.IP))
    nlist = 16384


class TableFilesFactory(SQLAlchemyModelFactory):
    class Meta:
        model = TableFiles
        sqlalchemy_session = db.session_factory
        sqlalchemy_session_persistence = 'commit'

    id = factory.Faker('random_number', digits=16, fix_len=True)
    table = factory.SubFactory(TablesFactory)
    engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
    file_id = factory.Faker('uuid4')
    file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
    file_size = factory.Faker('random_number')
    updated_time = int(time.time())
    created_on = int(time.time())
    date = factory.Faker('this_date')
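These factories seed the metadata database in tests; a minimal sketch of how they are typically driven, assuming db.init_db() has already run (conftest.py does this via create_app):

# Illustrative sketch only -- assumes the DB has been initialized first.
from mishards import db
from mishards.factories import TablesFactory, TableFilesFactory

db.create_all()
table = TablesFactory()  # inserts and commits one Tables row filled with fake data
files = TableFilesFactory.create_batch(3, table=table)  # three TableFiles rows linked to it
print(table.table_id, [f.file_id for f in files])
db.drop_all()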
@@ -1,37 +0,0 @@
from grpc_opentracing import SpanDecorator
from milvus.grpc_gen import status_pb2


class GrpcSpanDecorator(SpanDecorator):
    def __call__(self, span, rpc_info):
        status = None
        if not rpc_info.response:
            return
        if isinstance(rpc_info.response, status_pb2.Status):
            status = rpc_info.response
        else:
            try:
                status = rpc_info.response.status
            except Exception as e:
                status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR,
                                           reason='Should not happen')

        if status.error_code == 0:
            return
        error_log = {'event': 'error',
                     'request': rpc_info.request,
                     'response': rpc_info.response
                     }
        span.set_tag('error', True)
        span.log_kv(error_log)


def mark_grpc_method(func):
    setattr(func, 'grpc_method', True)
    return func


def is_grpc_method(func):
    if not func:
        return False
    return getattr(func, 'grpc_method', False)
@@ -1,102 +0,0 @@
from milvus import Status
from functools import wraps


def error_status(func):
    @wraps(func)
    def inner(*args, **kwargs):
        try:
            results = func(*args, **kwargs)
        except Exception as e:
            return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None

        return Status(code=0, message="Success"), results

    return inner


class GrpcArgsParser(object):

    @classmethod
    @error_status
    def parse_proto_TableSchema(cls, param):
        _table_schema = {
            'status': param.status,
            'table_name': param.table_name,
            'dimension': param.dimension,
            'index_file_size': param.index_file_size,
            'metric_type': param.metric_type
        }

        return _table_schema

    @classmethod
    @error_status
    def parse_proto_TableName(cls, param):
        return param.table_name

    @classmethod
    @error_status
    def parse_proto_Index(cls, param):
        _index = {
            'index_type': param.index_type,
            'nlist': param.nlist
        }

        return _index

    @classmethod
    @error_status
    def parse_proto_IndexParam(cls, param):
        _table_name = param.table_name
        _status, _index = cls.parse_proto_Index(param.index)

        if not _status.OK():
            raise Exception("Argument parse error")

        return _table_name, _index

    @classmethod
    @error_status
    def parse_proto_Command(cls, param):
        _cmd = param.cmd

        return _cmd

    @classmethod
    @error_status
    def parse_proto_Range(cls, param):
        _start_value = param.start_value
        _end_value = param.end_value

        return _start_value, _end_value

    @classmethod
    @error_status
    def parse_proto_RowRecord(cls, param):
        return list(param.vector_data)

    @classmethod
    @error_status
    def parse_proto_SearchParam(cls, param):
        _table_name = param.table_name
        _topk = param.topk
        _nprobe = param.nprobe
        _status, _range = cls.parse_proto_Range(param.query_range_array)

        if not _status.OK():
            raise Exception("Argument parse error")

        _row_record = param.query_record_array

        return _table_name, _row_record, _range, _topk

    @classmethod
    @error_status
    def parse_proto_DeleteByRangeParam(cls, param):
        _table_name = param.table_name
        _range = param.range
        _start_value = _range.start_value
        _end_value = _range.end_value

        return _table_name, _start_value, _end_value
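Every parser method above is wrapped by error_status, so callers always get a (Status, result) pair instead of an exception; a minimal sketch of that contract, using a hypothetical stub in place of a real milvus_pb2.TableName message:

# Illustrative sketch only -- TableNameStub is a made-up stand-in for the proto message.
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser

class TableNameStub:
    table_name = 'test_table'

status, table_name = Parser.parse_proto_TableName(TableNameStub())
if status.OK():
    print(table_name)      # 'test_table'
else:
    print(status.message)  # parse failures are reported via Status, never raised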
@@ -1,4 +0,0 @@
# class GrpcArgsWrapper(object):

#     @classmethod
#     def proto_TableName(cls):
@@ -1,75 +0,0 @@
import logging
import opentracing
from mishards.grpc_utils import GrpcSpanDecorator, is_grpc_method
from milvus.grpc_gen import status_pb2, milvus_pb2

logger = logging.getLogger(__name__)


class FakeTracer(opentracing.Tracer):
    pass


class FakeSpan(opentracing.Span):
    def __init__(self, context, tracer, **kwargs):
        super(FakeSpan, self).__init__(tracer, context)
        self.reset()

    def set_tag(self, key, value):
        self.tags.append({key: value})

    def log_kv(self, key_values, timestamp=None):
        self.logs.append(key_values)

    def reset(self):
        self.tags = []
        self.logs = []


class FakeRpcInfo:
    def __init__(self, request, response):
        self.request = request
        self.response = response


class TestGrpcUtils:
    def test_span_deco(self):
        request = 'request'
        OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success')
        response = OK
        rpc_info = FakeRpcInfo(request=request, response=response)
        span = FakeSpan(context=None, tracer=FakeTracer())
        span_deco = GrpcSpanDecorator()
        span_deco(span, rpc_info)
        assert len(span.logs) == 0
        assert len(span.tags) == 0

        response = milvus_pb2.BoolReply(status=OK, bool_reply=False)
        rpc_info = FakeRpcInfo(request=request, response=response)
        span = FakeSpan(context=None, tracer=FakeTracer())
        span_deco = GrpcSpanDecorator()
        span_deco(span, rpc_info)
        assert len(span.logs) == 0
        assert len(span.tags) == 0

        response = 1
        rpc_info = FakeRpcInfo(request=request, response=response)
        span = FakeSpan(context=None, tracer=FakeTracer())
        span_deco = GrpcSpanDecorator()
        span_deco(span, rpc_info)
        assert len(span.logs) == 1
        assert len(span.tags) == 1

        response = 0
        rpc_info = FakeRpcInfo(request=request, response=response)
        span = FakeSpan(context=None, tracer=FakeTracer())
        span_deco = GrpcSpanDecorator()
        span_deco(span, rpc_info)
        assert len(span.logs) == 0
        assert len(span.tags) == 0

    def test_is_grpc_method(self):
        target = 1
        assert not is_grpc_method(target)
        target = None
        assert not is_grpc_method(target)
@@ -1,150 +0,0 @@
import math
import sys
from bisect import bisect

if sys.version_info >= (2, 5):
    import hashlib
    md5_constructor = hashlib.md5
else:
    import md5
    md5_constructor = md5.new


class HashRing(object):
    def __init__(self, nodes=None, weights=None):
        """`nodes` is a list of objects that have a proper __str__ representation.
        `weights` is dictionary that sets weights to the nodes. The default
        weight is that all nodes are equal.
        """
        self.ring = dict()
        self._sorted_keys = []

        self.nodes = nodes

        if not weights:
            weights = {}
        self.weights = weights

        self._generate_circle()

    def _generate_circle(self):
        """Generates the circle.
        """
        total_weight = 0
        for node in self.nodes:
            total_weight += self.weights.get(node, 1)

        for node in self.nodes:
            weight = 1

            if node in self.weights:
                weight = self.weights.get(node)

            factor = math.floor((40 * len(self.nodes) * weight) / total_weight)

            for j in range(0, int(factor)):
                b_key = self._hash_digest('%s-%s' % (node, j))

                for i in range(0, 3):
                    key = self._hash_val(b_key, lambda x: x + i * 4)
                    self.ring[key] = node
                    self._sorted_keys.append(key)

        self._sorted_keys.sort()

    def get_node(self, string_key):
        """Given a string key a corresponding node in the hash ring is returned.

        If the hash ring is empty, `None` is returned.
        """
        pos = self.get_node_pos(string_key)
        if pos is None:
            return None
        return self.ring[self._sorted_keys[pos]]

    def get_node_pos(self, string_key):
        """Given a string key a corresponding node in the hash ring is returned
        along with its position in the ring.

        If the hash ring is empty, (`None`, `None`) is returned.
        """
        if not self.ring:
            return None

        key = self.gen_key(string_key)

        nodes = self._sorted_keys
        pos = bisect(nodes, key)

        if pos == len(nodes):
            return 0
        else:
            return pos

    def iterate_nodes(self, string_key, distinct=True):
        """Given a string key it returns the nodes as a generator that can hold the key.

        The generator iterates one time through the ring
        starting at the correct position.

        if `distinct` is set, then the nodes returned will be unique,
        i.e. no virtual copies will be returned.
        """
        if not self.ring:
            yield None, None

        returned_values = set()

        def distinct_filter(value):
            if str(value) not in returned_values:
                returned_values.add(str(value))
                return value

        pos = self.get_node_pos(string_key)
        for key in self._sorted_keys[pos:]:
            val = distinct_filter(self.ring[key])
            if val:
                yield val

        for i, key in enumerate(self._sorted_keys):
            if i < pos:
                val = distinct_filter(self.ring[key])
                if val:
                    yield val

    def gen_key(self, key):
        """Given a string key it returns a long value,
        this long value represents a place on the hash ring.

        md5 is currently used because it mixes well.
        """
        b_key = self._hash_digest(key)
        return self._hash_val(b_key, lambda x: x)

    def _hash_val(self, b_key, entry_fn):
        return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | (
            b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)]

    def _hash_digest(self, key):
        m = md5_constructor()
        key = key.encode()
        m.update(key)
        return m.digest()


if __name__ == '__main__':
    from collections import defaultdict
    servers = [
        '192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212',
        '192.168.0.249:11212'
    ]

    ring = HashRing(servers)
    keys = ['{}'.format(i) for i in range(100)]
    mapped = defaultdict(list)
    for k in keys:
        server = ring.get_node(k)
        mapped[server].append(k)

    for k, v in mapped.items():
        print(k, v)
@@ -1,15 +0,0 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from mishards import (settings, create_app)


def main():
    server = create_app(settings.DefaultConfig)
    server.run(port=settings.SERVER_PORT)
    return 0


if __name__ == '__main__':
    sys.exit(main())
@@ -1,76 +0,0 @@
import logging
from sqlalchemy import (Integer, Boolean, Text,
                        String, BigInteger, and_, or_,
                        Column)
from sqlalchemy.orm import relationship, backref

from mishards import db

logger = logging.getLogger(__name__)


class TableFiles(db.Model):
    FILE_TYPE_NEW = 0
    FILE_TYPE_RAW = 1
    FILE_TYPE_TO_INDEX = 2
    FILE_TYPE_INDEX = 3
    FILE_TYPE_TO_DELETE = 4
    FILE_TYPE_NEW_MERGE = 5
    FILE_TYPE_NEW_INDEX = 6
    FILE_TYPE_BACKUP = 7

    __tablename__ = 'TableFiles'

    id = Column(BigInteger, primary_key=True, autoincrement=True)
    table_id = Column(String(50))
    engine_type = Column(Integer)
    file_id = Column(String(50))
    file_type = Column(Integer)
    file_size = Column(Integer, default=0)
    row_count = Column(Integer, default=0)
    updated_time = Column(BigInteger)
    created_on = Column(BigInteger)
    date = Column(Integer)

    table = relationship(
        'Tables',
        primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)',
        backref=backref('files', uselist=True, lazy='dynamic')
    )


class Tables(db.Model):
    TO_DELETE = 1
    NORMAL = 0

    __tablename__ = 'Tables'

    id = Column(BigInteger, primary_key=True, autoincrement=True)
    table_id = Column(String(50), unique=True)
    state = Column(Integer)
    dimension = Column(Integer)
    created_on = Column(Integer)
    flag = Column(Integer, default=0)
    index_file_size = Column(Integer)
    engine_type = Column(Integer)
    nlist = Column(Integer)
    metric_type = Column(Integer)

    def files_to_search(self, date_range=None):
        cond = or_(
            TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
            TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX,
            TableFiles.file_type == TableFiles.FILE_TYPE_INDEX,
        )
        if date_range:
            cond = and_(
                cond,
                or_(
                    and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range
                )
            )

        files = self.files.filter(cond)

        logger.debug('DATE_RANGE: {}'.format(date_range))
        return files
@@ -1,96 +0,0 @@
import logging
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import and_

from mishards import exceptions, db
from mishards.hash_ring import HashRing
from mishards.models import Tables

logger = logging.getLogger(__name__)


class RouteManager:
    ROUTER_CLASSES = {}

    @classmethod
    def register_router_class(cls, target):
        name = target.__dict__.get('NAME', None)
        name = name if name else target.__class__.__name__
        cls.ROUTER_CLASSES[name] = target
        return target

    @classmethod
    def get_router_class(cls, name):
        return cls.ROUTER_CLASSES.get(name, None)


class RouterFactory:
    @classmethod
    def new_router(cls, name, conn_mgr, **kwargs):
        router_class = RouteManager.get_router_class(name)
        assert router_class
        return router_class(conn_mgr, **kwargs)


class RouterMixin:
    def __init__(self, conn_mgr):
        self.conn_mgr = conn_mgr

    def routing(self, table_name, metadata=None, **kwargs):
        raise NotImplementedError()

    def connection(self, metadata=None):
        conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
        if conn:
            conn.on_connect(metadata=metadata)
        return conn.conn

    def query_conn(self, name, metadata=None):
        conn = self.conn_mgr.conn(name, metadata=metadata)
        if not conn:
            raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
        conn.on_connect(metadata=metadata)
        return conn.conn


@RouteManager.register_router_class
class FileBasedHashRingRouter(RouterMixin):
    NAME = 'FileBasedHashRingRouter'

    def __init__(self, conn_mgr, **kwargs):
        super(FileBasedHashRingRouter, self).__init__(conn_mgr)

    def routing(self, table_name, metadata=None, **kwargs):
        range_array = kwargs.pop('range_array', None)
        return self._route(table_name, range_array, metadata, **kwargs)

    def _route(self, table_name, range_array, metadata=None, **kwargs):
        # PXU TODO: Implement Thread-local Context
        # PXU TODO: Session life mgt
        try:
            table = db.Session.query(Tables).filter(
                and_(Tables.table_id == table_name,
                     Tables.state != Tables.TO_DELETE)).first()
        except sqlalchemy_exc.SQLAlchemyError as e:
            raise exceptions.DBError(message=str(e), metadata=metadata)

        if not table:
            raise exceptions.TableNotFoundError(table_name, metadata=metadata)
        files = table.files_to_search(range_array)
        db.remove_session()

        servers = self.conn_mgr.conn_names
        logger.info('Available servers: {}'.format(servers))

        ring = HashRing(servers)

        routing = {}

        for f in files:
            target_host = ring.get_node(str(f.id))
            sub = routing.get(target_host, None)
            if not sub:
                routing[target_host] = {'table_id': table_name, 'file_ids': []}
            routing[target_host]['file_ids'].append(str(f.id))

        return routing
@@ -1,122 +0,0 @@
import logging
import grpc
import time
import socket
import inspect
from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
from grpc._cython import cygrpc
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from mishards.grpc_utils import is_grpc_method
from mishards.service_handler import ServiceHandler
from mishards import settings

logger = logging.getLogger(__name__)


class Server:
    def __init__(self):
        self.pre_run_handlers = set()
        self.grpc_methods = set()
        self.error_handlers = {}
        self.exit_flag = False

    def init_app(self,
                 conn_mgr,
                 tracer,
                 router,
                 discover,
                 port=19530,
                 max_workers=10,
                 **kwargs):
        self.port = int(port)
        self.conn_mgr = conn_mgr
        self.tracer = tracer
        self.router = router
        self.discover = discover

        self.server_impl = grpc.server(
            thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])

        self.server_impl = self.tracer.decorate(self.server_impl)

        self.register_pre_run_handler(self.pre_run_handler)

    def pre_run_handler(self):
        woserver = settings.WOSERVER
        url = urlparse(woserver)
        ip = socket.gethostbyname(url.hostname)
        socket.inet_pton(socket.AF_INET, ip)
        self.conn_mgr.register(
            'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))

    def register_pre_run_handler(self, func):
        logger.info('Registering {} into server pre_run_handlers'.format(func))
        self.pre_run_handlers.add(func)
        return func

    def wrap_method_with_errorhandler(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if e.__class__ in self.error_handlers:
                    return self.error_handlers[e.__class__](e)
                raise

        return wrapper

    def errorhandler(self, exception):
        if inspect.isclass(exception) and issubclass(exception, Exception):

            def wrapper(func):
                self.error_handlers[exception] = func
                return func

            return wrapper
        return exception

    def on_pre_run(self):
        for handler in self.pre_run_handlers:
            handler()
        self.discover.start()

    def start(self, port=None):
        handler_class = self.decorate_handler(ServiceHandler)
        add_MilvusServiceServicer_to_server(
            handler_class(tracer=self.tracer,
                          router=self.router), self.server_impl)
        self.server_impl.add_insecure_port("[::]:{}".format(
            str(port or self.port)))
        self.server_impl.start()

    def run(self, port):
        logger.info('Milvus server start ......')
        port = port or self.port
        self.on_pre_run()

        self.start(port)
        logger.info('Listening on port {}'.format(port))

        try:
            while not self.exit_flag:
                time.sleep(5)
        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        logger.info('Server is shutting down ......')
        self.exit_flag = True
        self.server_impl.stop(0)
        self.tracer.close()
        logger.info('Server is closed')

    def decorate_handler(self, handler):
        for key, attr in handler.__dict__.items():
            if is_grpc_method(attr):
                setattr(handler, key, self.wrap_method_with_errorhandler(attr))
        return handler
@@ -1,475 +0,0 @@
import logging
import time
import datetime
from collections import defaultdict

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client.abstract import Range
from milvus.client import types as Types

from mishards import (db, settings, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards import utilities

logger = logging.getLogger(__name__)


class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
    MAX_NPROBE = 2048
    MAX_TOPK = 2048

    def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs):
        self.table_meta = {}
        self.error_handlers = {}
        self.tracer = tracer
        self.router = router
        self.max_workers = max_workers

    def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
        status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                   reason="Success")
        if not files_n_topk_results:
            return status, []

        request_results = defaultdict(list)

        calc_time = time.time()
        for files_collection in files_n_topk_results:
            if isinstance(files_collection, tuple):
                status, _ = files_collection
                return status, []
            for request_pos, each_request_results in enumerate(
                    files_collection.topk_query_result):
                request_results[request_pos].extend(
                    each_request_results.query_result_arrays)
                request_results[request_pos] = sorted(
                    request_results[request_pos],
                    key=lambda x: x.distance,
                    reverse=reverse)[:topk]

        calc_time = time.time() - calc_time
        logger.info('Merge takes {}'.format(calc_time))

        results = sorted(request_results.items())
        topk_query_result = []

        for result in results:
            query_result = TopKQueryResult(query_result_arrays=result[1])
            topk_query_result.append(query_result)

        return status, topk_query_result

    def _do_query(self,
                  context,
                  table_id,
                  table_meta,
                  vectors,
                  topk,
                  nprobe,
                  range_array=None,
                  **kwargs):
        metadata = kwargs.get('metadata', None)
        range_array = [
            utilities.range_to_date(r, metadata=metadata) for r in range_array
        ] if range_array else None

        routing = {}
        p_span = None if self.tracer.empty else context.get_active_span(
        ).context
        with self.tracer.start_span('get_routing', child_of=p_span):
            routing = self.router.routing(table_id,
                                          range_array=range_array,
                                          metadata=metadata)
        logger.info('Routing: {}'.format(routing))

        metadata = kwargs.get('metadata', None)

        rs = []
        all_topk_results = []

        def search(addr, query_params, vectors, topk, nprobe, **kwargs):
            logger.info(
                'Send Search Request: addr={};params={};nq={};topk={};nprobe={}'
                .format(addr, query_params, len(vectors), topk, nprobe))

            conn = self.router.query_conn(addr, metadata=metadata)
            start = time.time()
            span = kwargs.get('span', None)
            span = span if span else (None if self.tracer.empty else
                                      context.get_active_span().context)

            with self.tracer.start_span('search_{}'.format(addr),
                                        child_of=span):
                ret = conn.search_vectors_in_files(
                    table_name=query_params['table_id'],
                    file_ids=query_params['file_ids'],
                    query_records=vectors,
                    top_k=topk,
                    nprobe=nprobe,
                    lazy_=True)
                end = time.time()
                logger.info('search_vectors_in_files takes: {}'.format(end - start))

                all_topk_results.append(ret)

        with self.tracer.start_span('do_search', child_of=p_span) as span:
            with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
                for addr, params in routing.items():
                    res = pool.submit(search,
                                      addr,
                                      params,
                                      vectors,
                                      topk,
                                      nprobe,
                                      span=span)
                    rs.append(res)

                for res in rs:
                    res.result()

        reverse = table_meta.metric_type == Types.MetricType.IP
        with self.tracer.start_span('do_merge', child_of=p_span):
            return self._do_merge(all_topk_results,
                                  topk,
                                  reverse=reverse,
                                  metadata=metadata)

    def _create_table(self, table_schema):
        return self.router.connection().create_table(table_schema)

    @mark_grpc_method
    def CreateTable(self, request, context):
        _status, _table_schema = Parser.parse_proto_TableSchema(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('CreateTable {}'.format(_table_schema['table_name']))

        _status = self._create_table(_table_schema)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _has_table(self, table_name, metadata=None):
        return self.router.connection(metadata=metadata).has_table(table_name)

    @mark_grpc_method
    def HasTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return milvus_pb2.BoolReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message),
                bool_reply=False)

        logger.info('HasTable {}'.format(_table_name))

        _status, _bool = self._has_table(_table_name,
                                         metadata={'resp_class': milvus_pb2.BoolReply})

        return milvus_pb2.BoolReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            bool_reply=_bool)

    def _delete_table(self, table_name):
        return self.router.connection().delete_table(table_name)

    @mark_grpc_method
    def DropTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropTable {}'.format(_table_name))

        _status = self._delete_table(_table_name)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _create_index(self, table_name, index):
        return self.router.connection().create_index(table_name, index)

    @mark_grpc_method
    def CreateIndex(self, request, context):
        _status, unpacks = Parser.parse_proto_IndexParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _table_name, _index = unpacks

        logger.info('CreateIndex {}'.format(_table_name))

        # TODO: interface create_table incompleted
        _status = self._create_index(_table_name, _index)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _add_vectors(self, param, metadata=None):
        return self.router.connection(metadata=metadata).add_vectors(
            None, None, insert_param=param)

    @mark_grpc_method
    def Insert(self, request, context):
        logger.info('Insert')
        # TODO: The SDK interface add_vectors() could update, add a key 'row_id_array'
        _status, _ids = self._add_vectors(
            metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
        return milvus_pb2.VectorIds(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            vector_id_array=_ids)

    @mark_grpc_method
    def Search(self, request, context):

        table_name = request.table_name

        topk = request.topk
        nprobe = request.nprobe

        logger.info('Search {}: topk={} nprobe={}'.format(
            table_name, topk, nprobe))

        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}

        if nprobe > self.MAX_NPROBE or nprobe <= 0:
            raise exceptions.InvalidArgumentError(
                message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

        if topk > self.MAX_TOPK or topk <= 0:
            raise exceptions.InvalidTopKError(
                message='Invalid topk: {}'.format(topk), metadata=metadata)

        table_meta = self.table_meta.get(table_name, None)

        if not table_meta:
            status, info = self.router.connection(
                metadata=metadata).describe_table(table_name)
            if not status.OK():
                raise exceptions.TableNotFoundError(table_name,
                                                    metadata=metadata)

            self.table_meta[table_name] = info
            table_meta = info

        start = time.time()

        query_record_array = []

        for query_record in request.query_record_array:
            query_record_array.append(list(query_record.vector_data))

        query_range_array = []
        for query_range in request.query_range_array:
            query_range_array.append(
                Range(query_range.start_value, query_range.end_value))

        status, results = self._do_query(context,
                                         table_name,
                                         table_meta,
                                         query_record_array,
                                         topk,
                                         nprobe,
                                         query_range_array,
                                         metadata=metadata)

        now = time.time()
        logger.info('SearchVector takes: {}'.format(now - start))

        topk_result_list = milvus_pb2.TopKQueryResultList(
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            topk_query_result=results)
        return topk_result_list

    @mark_grpc_method
    def SearchInFiles(self, request, context):
        raise NotImplementedError()

    def _describe_table(self, table_name, metadata=None):
        return self.router.connection(metadata=metadata).describe_table(table_name)

    @mark_grpc_method
    def DescribeTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return milvus_pb2.TableSchema(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        metadata = {'resp_class': milvus_pb2.TableSchema}

        logger.info('DescribeTable {}'.format(_table_name))
        _status, _table = self._describe_table(metadata=metadata,
                                               table_name=_table_name)

        if _status.OK():
            return milvus_pb2.TableSchema(
                table_name=_table_name,
                index_file_size=_table.index_file_size,
                dimension=_table.dimension,
                metric_type=_table.metric_type,
                status=status_pb2.Status(error_code=_status.code,
                                         reason=_status.message),
            )

        return milvus_pb2.TableSchema(
            table_name=_table_name,
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
        )

    def _count_table(self, table_name, metadata=None):
        return self.router.connection(
            metadata=metadata).get_table_row_count(table_name)

    @mark_grpc_method
    def CountTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            status = status_pb2.Status(error_code=_status.code,
                                       reason=_status.message)

            return milvus_pb2.TableRowCount(status=status)

        logger.info('CountTable {}'.format(_table_name))

        metadata = {'resp_class': milvus_pb2.TableRowCount}
        _status, _count = self._count_table(_table_name, metadata=metadata)

        return milvus_pb2.TableRowCount(
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
            table_row_count=_count if isinstance(_count, int) else -1)

    def _get_server_version(self, metadata=None):
        return self.router.connection(metadata=metadata).server_version()

    @mark_grpc_method
    def Cmd(self, request, context):
        _status, _cmd = Parser.parse_proto_Command(request)
        logger.info('Cmd: {}'.format(_cmd))

        if not _status.OK():
            return milvus_pb2.StringReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {'resp_class': milvus_pb2.StringReply}

        if _cmd == 'version':
            _status, _reply = self._get_server_version(metadata=metadata)
        else:
            _status, _reply = self.router.connection(
                metadata=metadata).server_status()

        return milvus_pb2.StringReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            string_reply=_reply)

    def _show_tables(self, metadata=None):
        return self.router.connection(metadata=metadata).show_tables()

    @mark_grpc_method
    def ShowTables(self, request, context):
        logger.info('ShowTables')
        metadata = {'resp_class': milvus_pb2.TableName}
        _status, _results = self._show_tables(metadata=metadata)

        return milvus_pb2.TableNameList(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            table_names=_results)

    def _delete_by_range(self, table_name, start_date, end_date):
        return self.router.connection().delete_vectors_by_range(table_name,
                                                                 start_date,
                                                                 end_date)

    @mark_grpc_method
    def DeleteByRange(self, request, context):
        _status, unpacks = \
            Parser.parse_proto_DeleteByRangeParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _table_name, _start_date, _end_date = unpacks

        logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date,
                                                     _end_date))
        _status = self._delete_by_range(_table_name, _start_date, _end_date)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _preload_table(self, table_name):
        return self.router.connection().preload_table(table_name)

    @mark_grpc_method
    def PreloadTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('PreloadTable {}'.format(_table_name))
        _status = self._preload_table(_table_name)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _describe_index(self, table_name, metadata=None):
        return self.router.connection(metadata=metadata).describe_index(table_name)

    @mark_grpc_method
    def DescribeIndex(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {'resp_class': milvus_pb2.IndexParam}

        logger.info('DescribeIndex {}'.format(_table_name))
        _status, _index_param = self._describe_index(table_name=_table_name,
                                                     metadata=metadata)

        if not _index_param:
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        _index = milvus_pb2.Index(index_type=_index_param._index_type,
                                  nlist=_index_param._nlist)

        return milvus_pb2.IndexParam(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            table_name=_table_name,
            index=_index)

    def _drop_index(self, table_name):
        return self.router.connection().drop_index(table_name)

    @mark_grpc_method
    def DropIndex(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropIndex {}'.format(_table_name))
        _status = self._drop_index(_table_name)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)
@ -1,94 +0,0 @@
import sys
import os

from environs import Env
env = Env()

FROM_EXAMPLE = env.bool('FROM_EXAMPLE', False)
if FROM_EXAMPLE:
    from dotenv import load_dotenv
    load_dotenv('./mishards/.env.example')
else:
    env.read_env()

DEBUG = env.bool('DEBUG', False)

LOG_LEVEL = env.str('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO')
LOG_PATH = env.str('LOG_PATH', '/tmp/mishards')
LOG_NAME = env.str('LOG_NAME', 'logfile')
TIMEZONE = env.str('TIMEZONE', 'UTC')

from utils.logger_helper import config
config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)

TIMEOUT = env.int('TIMEOUT', 60)
MAX_RETRY = env.int('MAX_RETRY', 3)

SERVER_PORT = env.int('SERVER_PORT', 19530)
SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530)
WOSERVER = env.str('WOSERVER')

SD_PROVIDER_SETTINGS = None
SD_PROVIDER = env.str('SD_PROVIDER', 'Kubernetes')
if SD_PROVIDER == 'Kubernetes':
    from sd.kubernetes_provider import KubernetesProviderSettings
    SD_PROVIDER_SETTINGS = KubernetesProviderSettings(
        namespace=env.str('SD_NAMESPACE', ''),
        in_cluster=env.bool('SD_IN_CLUSTER', False),
        poll_interval=env.int('SD_POLL_INTERVAL', 5),
        pod_patt=env.str('SD_ROSERVER_POD_PATT', ''),
        label_selector=env.str('SD_LABEL_SELECTOR', ''),
        port=env.int('SD_PORT', 19530))
elif SD_PROVIDER == 'Static':
    from sd.static_provider import StaticProviderSettings
    SD_PROVIDER_SETTINGS = StaticProviderSettings(
        hosts=env.list('SD_STATIC_HOSTS', []),
        port=env.int('SD_STATIC_PORT', 19530))

# TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530')


class TracingConfig:
    TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards')
    TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True)
    TRACING_LOG_PAYLOAD = env.bool('TRACING_LOG_PAYLOAD', False)
    TRACING_CONFIG = {
        'sampler': {
            'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
            'param': env.str('TRACING_SAMPLER_PARAM', "1"),
        },
        'local_agent': {
            'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'),
            'reporting_port': env.str('TRACING_REPORTING_PORT', '5775')
        },
        'logging': env.bool('TRACING_LOGGING', True)
    }
    DEFAULT_TRACING_CONFIG = {
        'sampler': {
            'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
            'param': env.str('TRACING_SAMPLER_PARAM', "0"),
        }
    }


class DefaultConfig:
    SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
    SQL_ECHO = env.bool('SQL_ECHO', False)
    TRACING_TYPE = env.str('TRACING_TYPE', '')
    ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter')


class TestingConfig(DefaultConfig):
    SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI', '')
    SQL_ECHO = env.bool('SQL_TEST_ECHO', False)
    TRACING_TYPE = env.str('TRACING_TEST_TYPE', '')
    ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter')


if __name__ == '__main__':
    import logging
    logger = logging.getLogger(__name__)
    logger.debug('DEBUG')
    logger.info('INFO')
    logger.warning('WARN')
    logger.error('ERROR')
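As a quick, hedged illustration of how these environs-backed settings resolve (the variable name below is the real one from this file; the exported value is only an example):

    # Minimal sketch: with SERVER_PORT unset, the declared default wins;
    # with `export SERVER_PORT=19532` done before import, env.int picks it up.
    from environs import Env
    env = Env()
    port = env.int('SERVER_PORT', 19530)   # -> 19532 if exported, else 19530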
@ -1,101 +0,0 @@
import logging
import pytest
import mock

from milvus import Milvus
from mishards.connections import (ConnectionMgr, Connection)
from mishards import exceptions

logger = logging.getLogger(__name__)


@pytest.mark.usefixtures('app')
class TestConnection:
    def test_manager(self):
        mgr = ConnectionMgr()

        mgr.register('pod1', '111')
        mgr.register('pod2', '222')
        mgr.register('pod2', '222')
        mgr.register('pod2', '2222')
        assert len(mgr.conn_names) == 2

        mgr.unregister('pod1')
        assert len(mgr.conn_names) == 1

        mgr.unregister('pod2')
        assert len(mgr.conn_names) == 0

        mgr.register('WOSERVER', 'xxxx')
        assert len(mgr.conn_names) == 0

        assert not mgr.conn('XXXX', None)
        with pytest.raises(exceptions.ConnectionNotFoundError):
            mgr.conn('XXXX', None, True)

        mgr.conn('WOSERVER', None)

    def test_connection(self):
        class Conn:
            def __init__(self, state):
                self.state = state

            def connect(self, uri):
                return self.state

            def connected(self):
                return self.state

        FAIL_CONN = Conn(False)
        PASS_CONN = Conn(True)

        class Retry:
            def __init__(self):
                self.times = 0

            def __call__(self, conn):
                self.times += 1
                logger.info('Retrying {}'.format(self.times))

        class Func():
            def __init__(self):
                self.executed = False

            def __call__(self):
                self.executed = True

        max_retry = 3

        RetryObj = Retry()

        c = Connection('client',
                       uri='xx',
                       max_retry=max_retry,
                       on_retry_func=RetryObj)
        c.conn = FAIL_CONN
        ff = Func()
        this_connect = c.connect(func=ff)
        with pytest.raises(exceptions.ConnectionConnectError):
            this_connect()
        assert RetryObj.times == max_retry
        assert not ff.executed
        RetryObj = Retry()

        c.conn = PASS_CONN
        this_connect = c.connect(func=ff)
        this_connect()
        assert ff.executed
        assert RetryObj.times == 0

        this_connect = c.connect(func=None)
        with pytest.raises(TypeError):
            this_connect()

        errors = []

        def error_handler(err):
            errors.append(err)

        this_connect = c.connect(func=None, exception_handler=error_handler)
        this_connect()
        assert len(errors) == 1
@ -1,39 +0,0 @@
import logging
import pytest
from mishards import db, create_app, settings
from mishards.factories import (
    Tables, TableFiles,
    TablesFactory, TableFilesFactory
)

logger = logging.getLogger(__name__)


@pytest.mark.usefixtures('app')
class TestModels:
    def test_files_to_search(self):
        table = TablesFactory()
        new_files_cnt = 5
        to_index_cnt = 10
        raw_cnt = 20
        backup_cnt = 12
        to_delete_cnt = 9
        index_cnt = 8
        new_index_cnt = 6
        new_merge_cnt = 11

        new_files = TableFilesFactory.create_batch(new_files_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW, date=110)
        to_index_files = TableFilesFactory.create_batch(to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX, date=110)
        raw_files = TableFilesFactory.create_batch(raw_cnt, table=table, file_type=TableFiles.FILE_TYPE_RAW, date=120)
        backup_files = TableFilesFactory.create_batch(backup_cnt, table=table, file_type=TableFiles.FILE_TYPE_BACKUP, date=110)
        index_files = TableFilesFactory.create_batch(index_cnt, table=table, file_type=TableFiles.FILE_TYPE_INDEX, date=110)
        new_index_files = TableFilesFactory.create_batch(new_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_INDEX, date=110)
        new_merge_files = TableFilesFactory.create_batch(new_merge_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_MERGE, date=110)
        to_delete_files = TableFilesFactory.create_batch(to_delete_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_DELETE, date=110)
        assert table.files_to_search().count() == raw_cnt + index_cnt + to_index_cnt

        assert table.files_to_search([(100, 115)]).count() == index_cnt + to_index_cnt
        assert table.files_to_search([(111, 120)]).count() == 0
        assert table.files_to_search([(111, 121)]).count() == raw_cnt
        assert table.files_to_search([(110, 121)]).count() == raw_cnt + index_cnt + to_index_cnt
@ -1,279 +0,0 @@
import logging
import pytest
import mock
import datetime
import random
import faker
import inspect
from milvus import Milvus
from milvus.client.types import Status, IndexType, MetricType
from milvus.client.abstract import IndexParam, TableSchema
from milvus.grpc_gen import status_pb2, milvus_pb2
from mishards import db, create_app, settings
from mishards.service_handler import ServiceHandler
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables
from mishards.routings import RouterMixin

logger = logging.getLogger(__name__)

OK = Status(code=Status.SUCCESS, message='Success')
BAD = Status(code=Status.PERMISSION_DENIED, message='Fail')


@pytest.mark.usefixtures('started_app')
class TestServer:
    @property
    def client(self):
        m = Milvus()
        m.connect(host='localhost', port=settings.SERVER_TEST_PORT)
        return m

    def test_server_start(self, started_app):
        assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER

    def test_cmd(self, started_app):
        ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK, ''))
        status, _ = self.client.server_version()
        assert status.OK()

        Parser.parse_proto_Command = mock.MagicMock(return_value=(BAD, 'cmd'))
        status, _ = self.client.server_version()
        assert not status.OK()

    def test_drop_index(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        ServiceHandler._drop_index = mock.MagicMock(return_value=OK)
        status = self.client.drop_index(table_name)
        assert status.OK()

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(BAD, table_name))
        status = self.client.drop_index(table_name)
        assert not status.OK()

    def test_describe_index(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        index_type = IndexType.FLAT
        nlist = 1
        index_param = IndexParam(table_name=table_name,
                                 index_type=index_type,
                                 nlist=nlist)
        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(OK, table_name))
        ServiceHandler._describe_index = mock.MagicMock(
            return_value=(OK, index_param))
        status, ret = self.client.describe_index(table_name)
        assert status.OK()
        assert ret._table_name == index_param._table_name

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(BAD, table_name))
        status, _ = self.client.describe_index(table_name)
        assert not status.OK()

    def test_preload(self, started_app):
        table_name = inspect.currentframe().f_code.co_name

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(OK, table_name))
        ServiceHandler._preload_table = mock.MagicMock(return_value=OK)
        status = self.client.preload_table(table_name)
        assert status.OK()

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(BAD, table_name))
        status = self.client.preload_table(table_name)
        assert not status.OK()

    @pytest.mark.skip
    def test_delete_by_range(self, started_app):
        table_name = inspect.currentframe().f_code.co_name

        unpacked = table_name, datetime.datetime.today(
        ), datetime.datetime.today()

        Parser.parse_proto_DeleteByRangeParam = mock.MagicMock(
            return_value=(OK, unpacked))
        ServiceHandler._delete_by_range = mock.MagicMock(return_value=OK)
        status = self.client.delete_vectors_by_range(*unpacked)
        assert status.OK()

        Parser.parse_proto_DeleteByRangeParam = mock.MagicMock(
            return_value=(BAD, unpacked))
        status = self.client.delete_vectors_by_range(*unpacked)
        assert not status.OK()

    def test_count_table(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        count = random.randint(100, 200)

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(OK, table_name))
        ServiceHandler._count_table = mock.MagicMock(return_value=(OK, count))
        status, ret = self.client.get_table_row_count(table_name)
        assert status.OK()
        assert ret == count

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(BAD, table_name))
        status, _ = self.client.get_table_row_count(table_name)
        assert not status.OK()

    def test_show_tables(self, started_app):
        tables = ['t1', 't2']
        ServiceHandler._show_tables = mock.MagicMock(return_value=(OK, tables))
        status, ret = self.client.show_tables()
        assert status.OK()
        assert ret == tables

    def test_describe_table(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        dimension = 128
        nlist = 1
        table_schema = TableSchema(table_name=table_name,
                                   index_file_size=100,
                                   metric_type=MetricType.L2,
                                   dimension=dimension)
        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(OK, table_schema.table_name))
        ServiceHandler._describe_table = mock.MagicMock(
            return_value=(OK, table_schema))
        status, _ = self.client.describe_table(table_name)
        assert status.OK()

        ServiceHandler._describe_table = mock.MagicMock(
            return_value=(BAD, table_schema))
        status, _ = self.client.describe_table(table_name)
        assert not status.OK()

        Parser.parse_proto_TableName = mock.MagicMock(return_value=(BAD, 'cmd'))
        status, ret = self.client.describe_table(table_name)
        assert not status.OK()

    def test_insert(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        vectors = [[random.random() for _ in range(16)] for _ in range(10)]
        ids = [random.randint(1000000, 20000000) for _ in range(10)]
        ServiceHandler._add_vectors = mock.MagicMock(return_value=(OK, ids))
        status, ret = self.client.add_vectors(
            table_name=table_name, records=vectors)
        assert status.OK()
        assert ids == ret

    def test_create_index(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        unpacks = table_name, None
        Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(OK, unpacks))
        ServiceHandler._create_index = mock.MagicMock(return_value=OK)
        status = self.client.create_index(table_name=table_name)
        assert status.OK()

        Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(BAD, None))
        status = self.client.create_index(table_name=table_name)
        assert not status.OK()

    def test_drop_table(self, started_app):
        table_name = inspect.currentframe().f_code.co_name

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(OK, table_name))
        ServiceHandler._delete_table = mock.MagicMock(return_value=OK)
        status = self.client.delete_table(table_name=table_name)
        assert status.OK()

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(BAD, table_name))
        status = self.client.delete_table(table_name=table_name)
        assert not status.OK()

    def test_has_table(self, started_app):
        table_name = inspect.currentframe().f_code.co_name

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(OK, table_name))
        ServiceHandler._has_table = mock.MagicMock(return_value=(OK, True))
        has = self.client.has_table(table_name=table_name)
        assert has

        Parser.parse_proto_TableName = mock.MagicMock(
            return_value=(BAD, table_name))
        status, has = self.client.has_table(table_name=table_name)
        assert not status.OK()
        assert not has

    def test_create_table(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        dimension = 128
        table_schema = dict(table_name=table_name,
                            index_file_size=100,
                            metric_type=MetricType.L2,
                            dimension=dimension)

        ServiceHandler._create_table = mock.MagicMock(return_value=OK)
        status = self.client.create_table(table_schema)
        assert status.OK()

        Parser.parse_proto_TableSchema = mock.MagicMock(return_value=(BAD, None))
        status = self.client.create_table(table_schema)
        assert not status.OK()

    def random_data(self, n, dimension):
        return [[random.random() for _ in range(dimension)] for _ in range(n)]

    def test_search(self, started_app):
        table_name = inspect.currentframe().f_code.co_name
        to_index_cnt = random.randint(10, 20)
        table = TablesFactory(table_id=table_name, state=Tables.NORMAL)
        to_index_files = TableFilesFactory.create_batch(
            to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX)
        topk = random.randint(5, 10)
        nq = random.randint(5, 10)
        param = {
            'table_name': table_name,
            'query_records': self.random_data(nq, table.dimension),
            'top_k': topk,
            'nprobe': 2049
        }

        result = [
            milvus_pb2.TopKQueryResult(query_result_arrays=[
                milvus_pb2.QueryResult(id=i, distance=random.random())
                for i in range(topk)
            ]) for i in range(nq)
        ]

        mock_results = milvus_pb2.TopKQueryResultList(status=status_pb2.Status(
            error_code=status_pb2.SUCCESS, reason="Success"),
            topk_query_result=result)

        table_schema = TableSchema(table_name=table_name,
                                   index_file_size=table.index_file_size,
                                   metric_type=table.metric_type,
                                   dimension=table.dimension)

        status, _ = self.client.search_vectors(**param)
        assert status.code == Status.ILLEGAL_ARGUMENT

        param['nprobe'] = 2048
        RouterMixin.connection = mock.MagicMock(return_value=Milvus())
        RouterMixin.query_conn = mock.MagicMock(return_value=Milvus())
        Milvus.describe_table = mock.MagicMock(return_value=(BAD, table_schema))
        status, ret = self.client.search_vectors(**param)
        assert status.code == Status.TABLE_NOT_EXISTS

        Milvus.describe_table = mock.MagicMock(return_value=(OK, table_schema))
        Milvus.search_vectors_in_files = mock.MagicMock(
            return_value=mock_results)

        status, ret = self.client.search_vectors(**param)
        assert status.OK()
        assert len(ret) == nq
@ -1,20 +0,0 @@
import datetime
from mishards import exceptions


def format_date(start, end):
    return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
            (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)


def range_to_date(range_obj, metadata=None):
    try:
        start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
        end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
        assert start < end
    except (ValueError, AssertionError):
        raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
            range_obj.start_date, range_obj.end_date),
            metadata=metadata)

    return format_date(start, end)
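A short worked example of the packing done by format_date (the dates are arbitrary): each date is encoded as (year - 1900) * 10000 + (month - 1) * 100 + day.

    import datetime
    start = datetime.datetime(2019, 6, 1)
    end = datetime.datetime(2019, 6, 25)
    # (119*10000 + 5*100 + 1, 119*10000 + 5*100 + 25)
    assert format_date(start, end) == (1190501, 1190525)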
@ -1,36 +0,0 @@
environs==4.2.0
factory-boy==2.12.0
Faker==1.0.7
fire==0.1.3
google-auth==1.6.3
grpcio==1.22.0
grpcio-tools==1.22.0
kubernetes==10.0.1
MarkupSafe==1.1.1
marshmallow==2.19.5
pymysql==0.9.3
protobuf==3.9.1
py==1.8.0
pyasn1==0.4.7
pyasn1-modules==0.2.6
pylint==2.3.1
pymilvus-test==0.2.28
#pymilvus==0.2.0
pyparsing==2.4.0
pytest==4.6.3
pytest-level==0.1.1
pytest-print==0.1.2
pytest-repeat==0.8.0
pytest-timeout==1.3.3
python-dateutil==2.8.0
python-dotenv==0.10.3
pytz==2019.1
requests==2.22.0
requests-oauthlib==1.2.0
rsa==4.0
six==1.12.0
SQLAlchemy==1.3.5
urllib3==1.25.3
jaeger-client>=3.4.0
grpcio-opentracing>=1.0
mock==2.0.0
@ -1,28 +0,0 @@
import logging
import inspect
# from utils import singleton

logger = logging.getLogger(__name__)


class ProviderManager:
    PROVIDERS = {}

    @classmethod
    def register_service_provider(cls, target):
        if inspect.isfunction(target):
            cls.PROVIDERS[target.__name__] = target
        elif inspect.isclass(target):
            name = target.__dict__.get('NAME', None)
            name = name if name else target.__name__
            cls.PROVIDERS[name] = target
        else:
            assert False, 'Cannot register_service_provider for: {}'.format(target)
        return target

    @classmethod
    def get_provider(cls, name):
        return cls.PROVIDERS.get(name, None)


from sd import kubernetes_provider, static_provider
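The registry above is keyed by the NAME class attribute. A hedged sketch of how a new provider would plug in (DummyProvider is hypothetical and not part of the codebase):

    @ProviderManager.register_service_provider
    class DummyProvider:
        NAME = 'Dummy'   # hypothetical provider name used as the registry key

        def __init__(self, settings, conn_mgr, **kwargs):
            pass

    assert ProviderManager.get_provider('Dummy') is DummyProvider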
@ -1,331 +0,0 @@
import os
import sys
if __name__ == '__main__':
    sys.path.append(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__))))

import re
import logging
import time
import copy
import threading
import queue
import enum
from kubernetes import client, config, watch

from utils import singleton
from sd import ProviderManager

logger = logging.getLogger(__name__)

INCLUSTER_NAMESPACE_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/namespace'


class EventType(enum.Enum):
    PodHeartBeat = 1
    Watch = 2


class K8SMixin:
    def __init__(self, namespace, in_cluster=False, **kwargs):
        self.namespace = namespace
        self.in_cluster = in_cluster
        self.kwargs = kwargs
        self.v1 = kwargs.get('v1', None)
        if not self.namespace:
            self.namespace = open(INCLUSTER_NAMESPACE_PATH).read()

        if not self.v1:
            config.load_incluster_config(
            ) if self.in_cluster else config.load_kube_config()
            self.v1 = client.CoreV1Api()


class K8SHeartbeatHandler(threading.Thread, K8SMixin):
    def __init__(self,
                 message_queue,
                 namespace,
                 label_selector,
                 in_cluster=False,
                 **kwargs):
        K8SMixin.__init__(self,
                          namespace=namespace,
                          in_cluster=in_cluster,
                          **kwargs)
        threading.Thread.__init__(self)
        self.queue = message_queue
        self.terminate = False
        self.label_selector = label_selector
        self.poll_interval = kwargs.get('poll_interval', 5)

    def run(self):
        while not self.terminate:
            try:
                pods = self.v1.list_namespaced_pod(
                    namespace=self.namespace,
                    label_selector=self.label_selector)
                event_message = {'eType': EventType.PodHeartBeat, 'events': []}
                for item in pods.items:
                    pod = self.v1.read_namespaced_pod(name=item.metadata.name,
                                                      namespace=self.namespace)
                    name = pod.metadata.name
                    ip = pod.status.pod_ip
                    phase = pod.status.phase
                    reason = pod.status.reason
                    message = pod.status.message
                    ready = True if phase == 'Running' else False

                    pod_event = dict(pod=name,
                                     ip=ip,
                                     ready=ready,
                                     reason=reason,
                                     message=message)

                    event_message['events'].append(pod_event)

                self.queue.put(event_message)

            except Exception as exc:
                logger.error(exc)

            time.sleep(self.poll_interval)

    def stop(self):
        self.terminate = True


class K8SEventListener(threading.Thread, K8SMixin):
    def __init__(self, message_queue, namespace, in_cluster=False, **kwargs):
        K8SMixin.__init__(self,
                          namespace=namespace,
                          in_cluster=in_cluster,
                          **kwargs)
        threading.Thread.__init__(self)
        self.queue = message_queue
        self.terminate = False
        self.at_start_up = True
        self._stop_event = threading.Event()

    def stop(self):
        self.terminate = True
        self._stop_event.set()

    def run(self):
        resource_version = ''
        w = watch.Watch()
        for event in w.stream(self.v1.list_namespaced_event,
                              namespace=self.namespace,
                              field_selector='involvedObject.kind=Pod'):
            if self.terminate:
                break

            resource_version = int(event['object'].metadata.resource_version)

            info = dict(
                eType=EventType.Watch,
                pod=event['object'].involved_object.name,
                reason=event['object'].reason,
                message=event['object'].message,
                start_up=self.at_start_up,
            )
            self.at_start_up = False
            # logger.info('Received event: {}'.format(info))
            self.queue.put(info)


class EventHandler(threading.Thread):
    def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs):
        threading.Thread.__init__(self)
        self.mgr = mgr
        self.queue = message_queue
        self.kwargs = kwargs
        self.terminate = False
        self.pod_patt = re.compile(pod_patt)
        self.namespace = namespace

    def stop(self):
        self.terminate = True

    def on_drop(self, event, **kwargs):
        pass

    def on_pod_started(self, event, **kwargs):
        try_cnt = 3
        pod = None
        while try_cnt > 0:
            try_cnt -= 1
            try:
                pod = self.mgr.v1.read_namespaced_pod(name=event['pod'],
                                                      namespace=self.namespace)
                if not pod.status.pod_ip:
                    time.sleep(0.5)
                    continue
                break
            except client.rest.ApiException as exc:
                time.sleep(0.5)

        if try_cnt <= 0 and not pod:
            if not event['start_up']:
                logger.error('Pod {} is started but cannot read pod'.format(
                    event['pod']))
            return
        elif try_cnt <= 0 and not pod.status.pod_ip:
            logger.warning('NoPodIPFoundError')
            return

        logger.info('Register POD {} with IP {}'.format(
            pod.metadata.name, pod.status.pod_ip))
        self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip)

    def on_pod_killing(self, event, **kwargs):
        logger.info('Unregister POD {}'.format(event['pod']))
        self.mgr.delete_pod(name=event['pod'])

    def on_pod_heartbeat(self, event, **kwargs):
        names = self.mgr.conn_mgr.conn_names

        running_names = set()
        for each_event in event['events']:
            if each_event['ready']:
                self.mgr.add_pod(name=each_event['pod'], ip=each_event['ip'])
                running_names.add(each_event['pod'])
            else:
                self.mgr.delete_pod(name=each_event['pod'])

        to_delete = names - running_names
        for name in to_delete:
            self.mgr.delete_pod(name)

        logger.info(self.mgr.conn_mgr.conn_names)

    def handle_event(self, event):
        if event['eType'] == EventType.PodHeartBeat:
            return self.on_pod_heartbeat(event)

        if not event or (event['reason'] not in ('Started', 'Killing')):
            return self.on_drop(event)

        if not re.match(self.pod_patt, event['pod']):
            return self.on_drop(event)

        logger.info('Handling event: {}'.format(event))

        if event['reason'] == 'Started':
            return self.on_pod_started(event)

        return self.on_pod_killing(event)

    def run(self):
        while not self.terminate:
            try:
                event = self.queue.get(timeout=1)
                self.handle_event(event)
            except queue.Empty:
                continue


class KubernetesProviderSettings:
    def __init__(self, namespace, pod_patt, label_selector, in_cluster,
                 poll_interval, port=None, **kwargs):
        self.namespace = namespace
        self.pod_patt = pod_patt
        self.label_selector = label_selector
        self.in_cluster = in_cluster
        self.poll_interval = poll_interval
        self.port = int(port) if port else 19530


@singleton
@ProviderManager.register_service_provider
class KubernetesProvider(object):
    NAME = 'Kubernetes'

    def __init__(self, settings, conn_mgr, **kwargs):
        self.namespace = settings.namespace
        self.pod_patt = settings.pod_patt
        self.label_selector = settings.label_selector
        self.in_cluster = settings.in_cluster
        self.poll_interval = settings.poll_interval
        self.port = settings.port
        self.kwargs = kwargs
        self.queue = queue.Queue()

        self.conn_mgr = conn_mgr

        if not self.namespace:
            self.namespace = open(INCLUSTER_NAMESPACE_PATH).read()

        config.load_incluster_config(
        ) if self.in_cluster else config.load_kube_config()
        self.v1 = client.CoreV1Api()

        self.listener = K8SEventListener(message_queue=self.queue,
                                         namespace=self.namespace,
                                         in_cluster=self.in_cluster,
                                         v1=self.v1,
                                         **kwargs)

        self.pod_heartbeater = K8SHeartbeatHandler(
            message_queue=self.queue,
            namespace=self.namespace,
            label_selector=self.label_selector,
            in_cluster=self.in_cluster,
            v1=self.v1,
            poll_interval=self.poll_interval,
            **kwargs)

        self.event_handler = EventHandler(mgr=self,
                                          message_queue=self.queue,
                                          namespace=self.namespace,
                                          pod_patt=self.pod_patt,
                                          **kwargs)

    def add_pod(self, name, ip):
        self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))

    def delete_pod(self, name):
        self.conn_mgr.unregister(name)

    def start(self):
        self.listener.daemon = True
        self.listener.start()
        self.event_handler.start()

        self.pod_heartbeater.start()

    def stop(self):
        self.listener.stop()
        self.pod_heartbeater.stop()
        self.event_handler.stop()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    class Connect:
        def register(self, name, value):
            logger.error('Register: {} - {}'.format(name, value))

        def unregister(self, name):
            logger.error('Unregister: {}'.format(name))

        @property
        def conn_names(self):
            return set()

    connect_mgr = Connect()

    settings = KubernetesProviderSettings(namespace='xp',
                                          pod_patt=".*-ro-servers-.*",
                                          label_selector='tier=ro-servers',
                                          poll_interval=5,
                                          in_cluster=False)

    provider_class = ProviderManager.get_provider('Kubernetes')
    t = provider_class(conn_mgr=connect_mgr, settings=settings)
    t.start()
    cnt = 100
    while cnt > 0:
        time.sleep(2)
        cnt -= 1
    t.stop()
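For reference, a hedged sketch of the PodHeartBeat message that K8SHeartbeatHandler.run puts on the queue and EventHandler.on_pod_heartbeat consumes (the pod name and IP below are placeholders):

    event = {
        'eType': EventType.PodHeartBeat,
        'events': [
            {'pod': 'milvus-ro-servers-0', 'ip': '10.0.0.12', 'ready': True,
             'reason': None, 'message': None},
        ],
    }
    # handler.handle_event(event) registers every ready pod with the connection
    # manager and unregisters any previously known pod missing from the snapshot.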
@ -1,39 +0,0 @@
import os
import sys
if __name__ == '__main__':
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import socket
from utils import singleton
from sd import ProviderManager


class StaticProviderSettings:
    def __init__(self, hosts, port=None):
        self.hosts = hosts
        self.port = int(port) if port else 19530


@singleton
@ProviderManager.register_service_provider
class StaticProvider(object):
    NAME = 'Static'

    def __init__(self, settings, conn_mgr, **kwargs):
        self.conn_mgr = conn_mgr
        self.hosts = [socket.gethostbyname(host) for host in settings.hosts]
        self.port = settings.port

    def start(self):
        for host in self.hosts:
            self.add_pod(host, host)

    def stop(self):
        for host in self.hosts:
            self.delete_pod(host)

    def add_pod(self, name, ip):
        self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))

    def delete_pod(self, name):
        self.conn_mgr.unregister(name)
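A minimal usage sketch for the static provider (the host list and the fake connection manager below are illustrative only):

    class FakeConnMgr:
        def register(self, name, uri):
            print('register', name, uri)

        def unregister(self, name):
            print('unregister', name)

    settings = StaticProviderSettings(hosts=['127.0.0.1'], port=19530)
    provider_class = ProviderManager.get_provider('Static')
    provider = provider_class(settings=settings, conn_mgr=FakeConnMgr())
    provider.start()   # registers tcp://127.0.0.1:19530 under the resolved IP
    provider.stop()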
@ -1,4 +0,0 @@
[tool:pytest]
testpaths = mishards
log_cli=true
log_cli_level=info
@ -1,45 +0,0 @@
version: "2.3"
services:
  milvus:
    runtime: nvidia
    restart: always
    image: registry.zilliz.com/milvus/engine:branch-0.5.0-release-4316de
    # ports:
    #   - "0.0.0.0:19530:19530"
    volumes:
      - /tmp/milvus/db:/opt/milvus/db

  jaeger:
    restart: always
    image: jaegertracing/all-in-one:1.14
    ports:
      - "0.0.0.0:5775:5775/udp"
      - "0.0.0.0:16686:16686"
      - "0.0.0.0:9441:9441"
    environment:
      COLLECTOR_ZIPKIN_HTTP_PORT: 9411

  mishards:
    restart: always
    image: registry.zilliz.com/milvus/mishards:v0.0.4
    ports:
      - "0.0.0.0:19530:19531"
      - "0.0.0.0:19532:19532"
    volumes:
      - /tmp/milvus/db:/tmp/milvus/db
      # - /tmp/mishards_env:/source/mishards/.env
    command: ["python", "mishards/main.py"]
    environment:
      FROM_EXAMPLE: 'true'
      DEBUG: 'true'
      SERVER_PORT: 19531
      WOSERVER: tcp://milvus:19530
      SD_STATIC_HOSTS: milvus
      TRACING_TYPE: jaeger
      TRACING_SERVICE_NAME: mishards-demo
      TRACING_REPORTING_HOST: jaeger
      TRACING_REPORTING_PORT: 5775

    depends_on:
      - milvus
      - jaeger
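With this compose file up, host port 19530 maps to the mishards SERVER_PORT 19531, so a client talks to the proxy rather than to Milvus directly. A hedged connection sketch using the same client API as the tests above:

    from milvus import Milvus

    client = Milvus()
    client.connect(host='127.0.0.1', port='19530')   # mishards entry point
    status, version = client.server_version()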
@ -1,43 +0,0 @@
from contextlib import contextmanager


def empty_server_interceptor_decorator(target_server, interceptor):
    return target_server


@contextmanager
def EmptySpan(*args, **kwargs):
    yield None
    return


class Tracer:
    def __init__(self,
                 tracer=None,
                 interceptor=None,
                 server_decorator=empty_server_interceptor_decorator):
        self.tracer = tracer
        self.interceptor = interceptor
        self.server_decorator = server_decorator

    def decorate(self, server):
        return self.server_decorator(server, self.interceptor)

    @property
    def empty(self):
        return self.tracer is None

    def close(self):
        self.tracer and self.tracer.close()

    def start_span(self,
                   operation_name=None,
                   child_of=None,
                   references=None,
                   tags=None,
                   start_time=None,
                   ignore_active_span=False):
        if self.empty:
            return EmptySpan()
        return self.tracer.start_span(operation_name, child_of, references,
                                      tags, start_time, ignore_active_span)
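A short usage sketch of the no-op path (no real tracer configured):

    tracer = Tracer()                      # empty tracer
    with tracer.start_span('search') as span:
        assert span is None                # EmptySpan yields None, so instrumentation is a no-op
    tracer.close()                         # safe: close() is guarded by `self.tracer and ...`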
@ -1,40 +0,0 @@
import logging
from jaeger_client import Config
from grpc_opentracing.grpcext import intercept_server
from grpc_opentracing import open_tracing_server_interceptor

from tracing import (Tracer, empty_server_interceptor_decorator)

logger = logging.getLogger(__name__)


class TracerFactory:
    @classmethod
    def new_tracer(cls,
                   tracer_type,
                   tracer_config,
                   span_decorator=None,
                   **kwargs):
        if not tracer_type:
            return Tracer()
        config = tracer_config.TRACING_CONFIG
        service_name = tracer_config.TRACING_SERVICE_NAME
        validate = tracer_config.TRACING_VALIDATE
        # if not tracer_type:
        #     tracer_type = 'jaeger'
        #     config = tracer_config.DEFAULT_TRACING_CONFIG

        if tracer_type.lower() == 'jaeger':
            config = Config(config=config,
                            service_name=service_name,
                            validate=validate)

            tracer = config.initialize_tracer()
            tracer_interceptor = open_tracing_server_interceptor(
                tracer,
                log_payloads=tracer_config.TRACING_LOG_PAYLOAD,
                span_decorator=span_decorator)

            return Tracer(tracer, tracer_interceptor, intercept_server)

        assert False, 'Unsupported tracer type: {}'.format(tracer_type)
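A hedged wiring sketch, assuming a Jaeger agent is reachable at the host/port set in TracingConfig and that grpc_server is an existing grpc.Server instance:

    from mishards import settings

    tracer = TracerFactory.new_tracer('jaeger', settings.TracingConfig)
    grpc_server = tracer.decorate(grpc_server)   # wraps the server with the tracing interceptor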
@ -1,11 +0,0 @@
from functools import wraps


def singleton(cls):
    instances = {}

    @wraps(cls)
    def getinstance(*args, **kw):
        if cls not in instances:
            instances[cls] = cls(*args, **kw)
        return instances[cls]
    return getinstance
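Usage sketch for the decorator above (Counter is illustrative only):

    @singleton
    class Counter:
        def __init__(self):
            self.value = 0

    a = Counter()
    b = Counter()
    assert a is b   # every call returns the single cached instance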
@ -1,152 +0,0 @@
import os
import datetime
from pytz import timezone
from logging import Filter
import logging.config


class InfoFilter(logging.Filter):
    def filter(self, rec):
        return rec.levelno == logging.INFO


class DebugFilter(logging.Filter):
    def filter(self, rec):
        return rec.levelno == logging.DEBUG


class WarnFilter(logging.Filter):
    def filter(self, rec):
        return rec.levelno == logging.WARN


class ErrorFilter(logging.Filter):
    def filter(self, rec):
        return rec.levelno == logging.ERROR


class CriticalFilter(logging.Filter):
    def filter(self, rec):
        return rec.levelno == logging.CRITICAL


COLORS = {
    'HEADER': '\033[95m',
    'INFO': '\033[92m',
    'DEBUG': '\033[94m',
    'WARNING': '\033[93m',
    'ERROR': '\033[95m',
    'CRITICAL': '\033[91m',
    'ENDC': '\033[0m',
}


class ColorFulFormatColMixin:
    def format_col(self, message_str, level_name):
        if level_name in COLORS.keys():
            message_str = COLORS.get(level_name) + message_str + COLORS.get(
                'ENDC')
        return message_str


class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin):
    def format(self, record):
        message_str = super(ColorfulFormatter, self).format(record)

        return self.format_col(message_str, level_name=record.levelname)


def config(log_level, log_path, name, tz='UTC'):
    def build_log_file(level, log_path, name, tz):
        utc_now = datetime.datetime.utcnow()
        utc_tz = timezone('UTC')
        local_tz = timezone(tz)
        tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz)
        return '{}-{}-{}.log'.format(os.path.join(log_path, name),
                                     tznow.strftime("%m-%d-%Y-%H:%M:%S"),
                                     level)

    if not os.path.exists(log_path):
        os.makedirs(log_path)

    LOGGING = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'default': {
                'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)',
            },
            'colorful_console': {
                'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)',
                '()': ColorfulFormatter,
            },
        },
        'filters': {
            'InfoFilter': {
                '()': InfoFilter,
            },
            'DebugFilter': {
                '()': DebugFilter,
            },
            'WarnFilter': {
                '()': WarnFilter,
            },
            'ErrorFilter': {
                '()': ErrorFilter,
            },
            'CriticalFilter': {
                '()': CriticalFilter,
            },
        },
        'handlers': {
            'milvus_celery_console': {
                'class': 'logging.StreamHandler',
                'formatter': 'colorful_console',
            },
            'milvus_debug_file': {
                'level': 'DEBUG',
                'filters': ['DebugFilter'],
                'class': 'logging.handlers.RotatingFileHandler',
                'formatter': 'default',
                'filename': build_log_file('debug', log_path, name, tz)
            },
            'milvus_info_file': {
                'level': 'INFO',
                'filters': ['InfoFilter'],
                'class': 'logging.handlers.RotatingFileHandler',
                'formatter': 'default',
                'filename': build_log_file('info', log_path, name, tz)
            },
            'milvus_warn_file': {
                'level': 'WARN',
                'filters': ['WarnFilter'],
                'class': 'logging.handlers.RotatingFileHandler',
                'formatter': 'default',
                'filename': build_log_file('warn', log_path, name, tz)
            },
            'milvus_error_file': {
                'level': 'ERROR',
                'filters': ['ErrorFilter'],
                'class': 'logging.handlers.RotatingFileHandler',
                'formatter': 'default',
                'filename': build_log_file('error', log_path, name, tz)
            },
            'milvus_critical_file': {
                'level': 'CRITICAL',
                'filters': ['CriticalFilter'],
                'class': 'logging.handlers.RotatingFileHandler',
                'formatter': 'default',
                'filename': build_log_file('critical', log_path, name, tz)
            },
        },
        'loggers': {
            '': {
                'handlers': ['milvus_celery_console', 'milvus_info_file',
                             'milvus_debug_file', 'milvus_warn_file',
                             'milvus_error_file', 'milvus_critical_file'],
                'level': log_level,
                'propagate': False
            },
        },
        'propagate': False,
    }

    logging.config.dictConfig(LOGGING)
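Usage sketch mirroring the call made in settings.py (the path and name are the defaults from that file):

    import logging

    config('DEBUG', '/tmp/mishards', 'logfile', 'UTC')
    logging.getLogger(__name__).info('mishards logging configured')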