mirror of https://github.com/milvus-io/milvus.git
* [skip ci](shards): export MAX_WORKERS as configurable parameter Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): skip mishards .env git info Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): support more robust static discovery host configuration Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): update static provider that terminate server if connection to downstream server error during startup Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): add topology.py Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): add connection pool Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): add topology test Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): refactory using topo Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): refactory static discovery using topo Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): refactory kubernetes discovery using topo Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): add more test for connection pool Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): export 19541 and 19542 for all_in_one demo Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): check version on new connection Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): mock connections Signed-off-by: peng.xu <peng.xu@zilliz.com> * [skip ci](shards): update tests Signed-off-by: peng.xu <peng.xu@zilliz.com>pull/1648/head
parent
4088f5e9a2
commit
5f2f8bdc8b
|
@ -31,3 +31,4 @@ cov_html/
|
|||
|
||||
# temp
|
||||
shards/all_in_one_with_mysql/metadata/
|
||||
shards/mishards/.env
|
||||
|
|
|
@ -4,6 +4,8 @@ services:
|
|||
runtime: nvidia
|
||||
restart: always
|
||||
image: milvusdb/milvus:0.6.0-gpu-d120719-2b40dd
|
||||
ports:
|
||||
- "0.0.0.0:19540:19530"
|
||||
volumes:
|
||||
- /tmp/milvus/db:/var/lib/milvus/db
|
||||
- ./wr_server.yml:/opt/milvus/conf/server_config.yaml
|
||||
|
@ -12,6 +14,8 @@ services:
|
|||
runtime: nvidia
|
||||
restart: always
|
||||
image: milvusdb/milvus:0.6.0-gpu-d120719-2b40dd
|
||||
ports:
|
||||
- "0.0.0.0:19541:19530"
|
||||
volumes:
|
||||
- /tmp/milvus/db:/var/lib/milvus/db
|
||||
- ./ro_server.yml:/opt/milvus/conf/server_config.yaml
|
||||
|
|
|
@ -2,8 +2,10 @@ import os
|
|||
import logging
|
||||
import pytest
|
||||
import grpc
|
||||
import mock
|
||||
import tempfile
|
||||
import shutil
|
||||
import time
|
||||
from mishards import settings, db, create_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -18,6 +20,9 @@ settings.TestingConfig.SQLALCHEMY_DATABASE_URI = 'sqlite:///{}?check_same_thread
|
|||
|
||||
@pytest.fixture
|
||||
def app(request):
|
||||
from mishards.connections import ConnectionGroup
|
||||
ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,))
|
||||
time.sleep(0.1)
|
||||
app = create_app(settings.TestingConfig)
|
||||
db.drop_all()
|
||||
db.create_all()
|
||||
|
|
|
@ -13,10 +13,10 @@ class DiscoveryFactory(BaseMixin):
|
|||
super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME)
|
||||
|
||||
def _create(self, plugin_class, **kwargs):
|
||||
conn_mgr = kwargs.pop('conn_mgr', None)
|
||||
if not conn_mgr:
|
||||
raise RuntimeError('Please pass conn_mgr to create discovery!')
|
||||
readonly_topo = kwargs.pop('readonly_topo', None)
|
||||
if not readonly_topo:
|
||||
raise RuntimeError('Please pass readonly_topo to create discovery!')
|
||||
|
||||
plugin_config = DiscoveryConfig.Create()
|
||||
plugin = plugin_class.Create(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs)
|
||||
plugin = plugin_class.Create(plugin_config=plugin_config, readonly_topo=readonly_topo, **kwargs)
|
||||
return plugin
|
||||
|
|
|
@ -181,7 +181,7 @@ class EventHandler(threading.Thread):
|
|||
self.mgr.delete_pod(name=event['pod'])
|
||||
|
||||
def on_pod_heartbeat(self, event, **kwargs):
|
||||
names = self.mgr.conn_mgr.conn_names
|
||||
names = self.mgr.readonly_topo.group_names
|
||||
|
||||
running_names = set()
|
||||
for each_event in event['events']:
|
||||
|
@ -195,7 +195,7 @@ class EventHandler(threading.Thread):
|
|||
for name in to_delete:
|
||||
self.mgr.delete_pod(name)
|
||||
|
||||
logger.info(self.mgr.conn_mgr.conn_names)
|
||||
logger.info(self.mgr.readonly_topo.group_names)
|
||||
|
||||
def handle_event(self, event):
|
||||
if event['eType'] == EventType.PodHeartBeat:
|
||||
|
@ -237,7 +237,7 @@ class KubernetesProviderSettings:
|
|||
class KubernetesProvider(object):
|
||||
name = 'kubernetes'
|
||||
|
||||
def __init__(self, plugin_config, conn_mgr, **kwargs):
|
||||
def __init__(self, plugin_config, readonly_topo, **kwargs):
|
||||
self.namespace = plugin_config.DISCOVERY_KUBERNETES_NAMESPACE
|
||||
self.pod_patt = plugin_config.DISCOVERY_KUBERNETES_POD_PATT
|
||||
self.label_selector = plugin_config.DISCOVERY_KUBERNETES_LABEL_SELECTOR
|
||||
|
@ -250,7 +250,7 @@ class KubernetesProvider(object):
|
|||
self.kwargs = kwargs
|
||||
self.queue = queue.Queue()
|
||||
|
||||
self.conn_mgr = conn_mgr
|
||||
self.readonly_topo = readonly_topo
|
||||
|
||||
if not self.namespace:
|
||||
self.namespace = open(incluster_namespace_path).read()
|
||||
|
@ -281,10 +281,24 @@ class KubernetesProvider(object):
|
|||
**kwargs)
|
||||
|
||||
def add_pod(self, name, ip):
|
||||
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
|
||||
ok = True
|
||||
status = StatusType.OK
|
||||
try:
|
||||
uri = 'tcp://{}:{}'.format(ip, self.port)
|
||||
status, group = self.readonly_topo.create(name=name)
|
||||
if status == StatusType.OK:
|
||||
status, pool = group.create(name=name, uri=uri)
|
||||
except ConnectionConnectError as exc:
|
||||
ok = False
|
||||
logger.error('Connection error to: {}'.format(addr))
|
||||
|
||||
if ok and status == StatusType.OK:
|
||||
logger.info('KubernetesProvider Add Group \"{}\" Of 1 Address: {}'.format(name, uri))
|
||||
return ok
|
||||
|
||||
def delete_pod(self, name):
|
||||
self.conn_mgr.unregister(name)
|
||||
pool = self.readonly_topo.delete_group(name)
|
||||
return True
|
||||
|
||||
def start(self):
|
||||
self.listener.daemon = True
|
||||
|
@ -299,8 +313,8 @@ class KubernetesProvider(object):
|
|||
self.event_handler.stop()
|
||||
|
||||
@classmethod
|
||||
def Create(cls, conn_mgr, plugin_config, **kwargs):
|
||||
discovery = cls(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs)
|
||||
def Create(cls, readonly_topo, plugin_config, **kwargs):
|
||||
discovery = cls(config=plugin_config, readonly_topo=readonly_topo, **kwargs)
|
||||
return discovery
|
||||
|
||||
|
||||
|
|
|
@ -6,37 +6,72 @@ if __name__ == '__main__':
|
|||
import logging
|
||||
import socket
|
||||
from environs import Env
|
||||
from mishards.exceptions import ConnectionConnectError
|
||||
from mishards.topology import StatusType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
env = Env()
|
||||
|
||||
DELIMITER = ':'
|
||||
|
||||
def parse_host(addr):
|
||||
splited_arr = addr.split(DELIMITER)
|
||||
return splited_arr
|
||||
|
||||
def resolve_address(addr, default_port):
|
||||
addr_arr = parse_host(addr)
|
||||
assert len(addr_arr) >= 1 and len(addr_arr) <= 2, 'Invalid Addr: {}'.format(addr)
|
||||
port = addr_arr[1] if len(addr_arr) == 2 else default_port
|
||||
return '{}:{}'.format(socket.gethostbyname(addr_arr[0]), port)
|
||||
|
||||
class StaticDiscovery(object):
|
||||
name = 'static'
|
||||
|
||||
def __init__(self, config, conn_mgr, **kwargs):
|
||||
self.conn_mgr = conn_mgr
|
||||
def __init__(self, config, readonly_topo, **kwargs):
|
||||
self.readonly_topo = readonly_topo
|
||||
hosts = env.list('DISCOVERY_STATIC_HOSTS', [])
|
||||
self.port = env.int('DISCOVERY_STATIC_PORT', 19530)
|
||||
self.hosts = [socket.gethostbyname(host) for host in hosts]
|
||||
self.hosts = [resolve_address(host, self.port) for host in hosts]
|
||||
|
||||
def start(self):
|
||||
ok = True
|
||||
for host in self.hosts:
|
||||
self.add_pod(host, host)
|
||||
ok &= self.add_pod(host, host)
|
||||
if not ok: break
|
||||
if ok and len(self.hosts) == 0:
|
||||
logger.error('No address is specified')
|
||||
ok = False
|
||||
return ok
|
||||
|
||||
def stop(self):
|
||||
for host in self.hosts:
|
||||
self.delete_pod(host)
|
||||
|
||||
def add_pod(self, name, ip):
|
||||
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
|
||||
def add_pod(self, name, addr):
|
||||
ok = True
|
||||
status = StatusType.OK
|
||||
try:
|
||||
uri = 'tcp://{}'.format(addr)
|
||||
status, group = self.readonly_topo.create(name=name)
|
||||
if status == StatusType.OK:
|
||||
status, pool = group.create(name=name, uri=uri)
|
||||
if status not in (StatusType.OK, StatusType.DUPLICATED):
|
||||
ok = False
|
||||
except ConnectionConnectError as exc:
|
||||
ok = False
|
||||
logger.error('Connection error to: {}'.format(addr))
|
||||
|
||||
if ok and status == StatusType.OK:
|
||||
logger.info('StaticDiscovery Add Static Group \"{}\" Of 1 Address: {}'.format(name, addr))
|
||||
return ok
|
||||
|
||||
def delete_pod(self, name):
|
||||
self.conn_mgr.unregister(name)
|
||||
pool = self.readonly_topo.delete_group(name)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def Create(cls, conn_mgr, plugin_config, **kwargs):
|
||||
discovery = cls(config=plugin_config, conn_mgr=conn_mgr, **kwargs)
|
||||
def Create(cls, readonly_topo, plugin_config, **kwargs):
|
||||
discovery = cls(config=plugin_config, readonly_topo=readonly_topo, **kwargs)
|
||||
return discovery
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ DEBUG=True
|
|||
WOSERVER=tcp://127.0.0.1:19530
|
||||
SERVER_PORT=19535
|
||||
SERVER_TEST_PORT=19888
|
||||
MAX_WORKERS=50
|
||||
|
||||
#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
|
||||
SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
|
||||
|
|
|
@ -15,12 +15,14 @@ def create_app(testing_config=None):
|
|||
pool_recycle=config.SQL_POOL_RECYCLE, pool_timeout=config.SQL_POOL_TIMEOUT,
|
||||
pool_pre_ping=config.SQL_POOL_PRE_PING, max_overflow=config.SQL_MAX_OVERFLOW)
|
||||
|
||||
from mishards.connections import ConnectionMgr
|
||||
connect_mgr = ConnectionMgr()
|
||||
from mishards.connections import ConnectionMgr, ConnectionTopology
|
||||
|
||||
readonly_topo = ConnectionTopology()
|
||||
writable_topo = ConnectionTopology()
|
||||
|
||||
from discovery.factory import DiscoveryFactory
|
||||
discover = DiscoveryFactory(config.DISCOVERY_PLUGIN_PATH).create(config.DISCOVERY_CLASS_NAME,
|
||||
conn_mgr=connect_mgr)
|
||||
readonly_topo=readonly_topo)
|
||||
|
||||
from mishards.grpc_utils import GrpcSpanDecorator
|
||||
from tracer.factory import TracerFactory
|
||||
|
@ -30,12 +32,15 @@ def create_app(testing_config=None):
|
|||
|
||||
from mishards.router.factory import RouterFactory
|
||||
router = RouterFactory(config.ROUTER_PLUGIN_PATH).create(config.ROUTER_CLASS_NAME,
|
||||
conn_mgr=connect_mgr)
|
||||
readonly_topo=readonly_topo,
|
||||
writable_topo=writable_topo)
|
||||
|
||||
grpc_server.init_app(conn_mgr=connect_mgr,
|
||||
grpc_server.init_app(writable_topo=writable_topo,
|
||||
readonly_topo=readonly_topo,
|
||||
tracer=tracer,
|
||||
router=router,
|
||||
discover=discover)
|
||||
discover=discover,
|
||||
max_workers=settings.MAX_WORKERS)
|
||||
|
||||
from mishards import exception_handlers
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
import threading
|
||||
import enum
|
||||
from functools import wraps
|
||||
from milvus import Milvus
|
||||
from milvus.client.hooks import BaseSearchHook
|
||||
|
||||
from mishards import (settings, exceptions)
|
||||
from mishards import (settings, exceptions, topology)
|
||||
from utils import singleton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -81,6 +82,140 @@ class Connection:
|
|||
raise e
|
||||
return inner
|
||||
|
||||
def __str__(self):
|
||||
return '<Connection: {}:{}>'.format(self.name, id(self))
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class ProxyMixin:
|
||||
def __getattr__(self, name):
|
||||
target = self.__dict__.get(name, None)
|
||||
if target or not self.connection:
|
||||
return target
|
||||
return getattr(self.connection, name)
|
||||
|
||||
|
||||
class ScopedConnection(ProxyMixin):
|
||||
def __init__(self, pool, connection):
|
||||
self.pool = pool
|
||||
self.connection = connection
|
||||
|
||||
def __del__(self):
|
||||
self.release()
|
||||
|
||||
def __str__(self):
|
||||
return self.connection.__str__()
|
||||
|
||||
def release(self):
|
||||
if not self.pool or not self.connection:
|
||||
return
|
||||
self.pool.release(self.connection)
|
||||
self.pool = None
|
||||
self.connection = None
|
||||
|
||||
|
||||
class ConnectionPool(topology.TopoObject):
|
||||
def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs):
|
||||
super().__init__(name)
|
||||
self.capacity = capacity
|
||||
self.pending_pool = set()
|
||||
self.active_pool = set()
|
||||
self.connection_ownership = {}
|
||||
self.uri = uri
|
||||
self.max_retry = max_retry
|
||||
self.kwargs = kwargs
|
||||
self.cv = threading.Condition()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pending_pool) + len(self.active_pool)
|
||||
|
||||
@property
|
||||
def active_num(self):
|
||||
return len(self.active_pool)
|
||||
|
||||
def _is_full(self):
|
||||
if self.capacity < 0:
|
||||
return False
|
||||
return len(self) >= self.capacity
|
||||
|
||||
def fetch(self, timeout=1):
|
||||
with self.cv:
|
||||
timeout_times = 0
|
||||
while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1):
|
||||
self.cv.notifyAll()
|
||||
self.cv.wait(timeout)
|
||||
timeout_times += 1
|
||||
|
||||
connection = None
|
||||
if timeout_times >= 1:
|
||||
return connection
|
||||
|
||||
# logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
|
||||
if len(self.pending_pool) == 0:
|
||||
connection = self.create()
|
||||
else:
|
||||
connection = self.pending_pool.pop()
|
||||
# logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name))
|
||||
self.active_pool.add(connection)
|
||||
scoped_connection = ScopedConnection(self, connection)
|
||||
return scoped_connection
|
||||
|
||||
def release(self, connection):
|
||||
with self.cv:
|
||||
if connection not in self.active_pool:
|
||||
raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name))
|
||||
# logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name))
|
||||
# logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
|
||||
self.active_pool.remove(connection)
|
||||
self.pending_pool.add(connection)
|
||||
|
||||
def create(self):
|
||||
connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
|
||||
return connection
|
||||
|
||||
|
||||
class ConnectionGroup(topology.TopoGroup):
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def on_pre_add(self, topo_object):
|
||||
conn = topo_object.fetch()
|
||||
conn.on_connect(metadata=None)
|
||||
status, version = conn.conn.server_version()
|
||||
if not status.OK():
|
||||
logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name))
|
||||
return False
|
||||
if version not in settings.SERVER_VERSIONS:
|
||||
logger.error('Cannot connect to server of version: {}. Only {} supported'.format(version,
|
||||
settings.SERVER_VERSIONS))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create(self, name, **kwargs):
|
||||
uri = kwargs.get('uri', None)
|
||||
if not uri:
|
||||
raise RuntimeError('\"uri\" is required to create connection pool')
|
||||
pool = ConnectionPool(name=name, **kwargs)
|
||||
status = self.add(pool)
|
||||
if status != topology.StatusType.OK:
|
||||
pool = None
|
||||
return status, pool
|
||||
|
||||
|
||||
class ConnectionTopology(topology.Topology):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def create(self, name):
|
||||
group = ConnectionGroup(name)
|
||||
status = self.add_group(group)
|
||||
if status == topology.StatusType.DUPLICATED:
|
||||
group = None
|
||||
return status, group
|
||||
|
||||
|
||||
@singleton
|
||||
class ConnectionMgr:
|
||||
|
@ -126,6 +261,14 @@ class ConnectionMgr:
|
|||
def on_new_meta(self, name, url):
|
||||
logger.info('Register Connection: name={};url={}'.format(name, url))
|
||||
self.metas[name] = url
|
||||
conn = self.conn(name, metadata=None)
|
||||
conn.on_connect(metadata=None)
|
||||
status, _ = conn.conn.server_version()
|
||||
if not status.OK():
|
||||
logger.error('Cannot connect to newly added address: {}. Remove it now'.format(name))
|
||||
self.unregister(name)
|
||||
return False
|
||||
return True
|
||||
|
||||
def on_duplicate_meta(self, name, url):
|
||||
if self.metas[name] == url:
|
||||
|
@ -135,19 +278,22 @@ class ConnectionMgr:
|
|||
|
||||
def on_same_meta(self, name, url):
|
||||
# logger.warning('Register same meta: {}:{}'.format(name, url))
|
||||
pass
|
||||
return True
|
||||
|
||||
def on_diff_meta(self, name, url):
|
||||
logger.warning('Received {} with diff url={}'.format(name, url))
|
||||
self.metas[name] = url
|
||||
self.conns[name] = {}
|
||||
return True
|
||||
|
||||
def on_unregister_meta(self, name, url):
|
||||
logger.info('Unregister name={};url={}'.format(name, url))
|
||||
self.conns.pop(name, None)
|
||||
return True
|
||||
|
||||
def on_nonexisted_meta(self, name):
|
||||
logger.warning('Non-existed meta: {}'.format(name))
|
||||
return False
|
||||
|
||||
def register(self, name, url):
|
||||
meta = self.metas.get(name)
|
||||
|
|
|
@ -2,20 +2,21 @@ from mishards import exceptions
|
|||
|
||||
|
||||
class RouterMixin:
|
||||
def __init__(self, conn_mgr):
|
||||
self.conn_mgr = conn_mgr
|
||||
def __init__(self, writable_topo, readonly_topo):
|
||||
self.writable_topo = writable_topo
|
||||
self.readonly_topo = readonly_topo
|
||||
|
||||
def routing(self, table_name, metadata=None, **kwargs):
|
||||
raise NotImplemented()
|
||||
|
||||
def connection(self, metadata=None):
|
||||
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
|
||||
conn = self.writable_topo.get_group('default').get('WOSERVER').fetch()
|
||||
if conn:
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
|
||||
def query_conn(self, name, metadata=None):
|
||||
conn = self.conn_mgr.conn(name, metadata=metadata)
|
||||
conn = self.readonly_topo.get_group(name).get(name).fetch()
|
||||
if not conn:
|
||||
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
|
||||
conn.on_connect(metadata=metadata)
|
||||
|
|
|
@ -12,8 +12,9 @@ logger = logging.getLogger(__name__)
|
|||
class Factory(RouterMixin):
|
||||
name = 'FileBasedHashRingRouter'
|
||||
|
||||
def __init__(self, conn_mgr, **kwargs):
|
||||
super(Factory, self).__init__(conn_mgr)
|
||||
def __init__(self, writable_topo, readonly_topo, **kwargs):
|
||||
super(Factory, self).__init__(writable_topo=writable_topo,
|
||||
readonly_topo=readonly_topo)
|
||||
|
||||
def routing(self, table_name, partition_tags=None, metadata=None, **kwargs):
|
||||
range_array = kwargs.pop('range_array', None)
|
||||
|
@ -46,7 +47,7 @@ class Factory(RouterMixin):
|
|||
|
||||
db.remove_session()
|
||||
|
||||
servers = self.conn_mgr.conn_names
|
||||
servers = self.readonly_topo.group_names
|
||||
logger.info('Available servers: {}'.format(servers))
|
||||
|
||||
ring = HashRing(servers)
|
||||
|
@ -65,10 +66,13 @@ class Factory(RouterMixin):
|
|||
|
||||
@classmethod
|
||||
def Create(cls, **kwargs):
|
||||
conn_mgr = kwargs.pop('conn_mgr', None)
|
||||
if not conn_mgr:
|
||||
raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(self.name))
|
||||
router = cls(conn_mgr, **kwargs)
|
||||
writable_topo = kwargs.pop('writable_topo', None)
|
||||
if not writable_topo:
|
||||
raise RuntimeError('Cannot find \'writable_topo\' to initialize \'{}\''.format(self.name))
|
||||
readonly_topo = kwargs.pop('readonly_topo', None)
|
||||
if not readonly_topo:
|
||||
raise RuntimeError('Cannot find \'readonly_topo\' to initialize \'{}\''.format(self.name))
|
||||
router = cls(writable_topo=writable_topo, readonly_topo=readonly_topo, **kwargs)
|
||||
return router
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import sys
|
||||
import grpc
|
||||
import time
|
||||
import socket
|
||||
|
@ -23,7 +24,8 @@ class Server:
|
|||
self.exit_flag = False
|
||||
|
||||
def init_app(self,
|
||||
conn_mgr,
|
||||
writable_topo,
|
||||
readonly_topo,
|
||||
tracer,
|
||||
router,
|
||||
discover,
|
||||
|
@ -31,11 +33,14 @@ class Server:
|
|||
max_workers=10,
|
||||
**kwargs):
|
||||
self.port = int(port)
|
||||
self.conn_mgr = conn_mgr
|
||||
self.writable_topo = writable_topo
|
||||
self.readonly_topo = readonly_topo
|
||||
self.tracer = tracer
|
||||
self.router = router
|
||||
self.discover = discover
|
||||
|
||||
logger.debug('Init grpc server with max_workers: {}'.format(max_workers))
|
||||
|
||||
self.server_impl = grpc.server(
|
||||
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
|
||||
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
|
||||
|
@ -50,8 +55,8 @@ class Server:
|
|||
url = urlparse(woserver)
|
||||
ip = socket.gethostbyname(url.hostname)
|
||||
socket.inet_pton(socket.AF_INET, ip)
|
||||
self.conn_mgr.register(
|
||||
'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
|
||||
_, group = self.writable_topo.create('default')
|
||||
group.create(name='WOSERVER', uri='{}://{}:{}'.format(url.scheme, ip, url.port or 80))
|
||||
|
||||
def register_pre_run_handler(self, func):
|
||||
logger.info('Regiterring {} into server pre_run_handlers'.format(func))
|
||||
|
@ -83,7 +88,7 @@ class Server:
|
|||
def on_pre_run(self):
|
||||
for handler in self.pre_run_handlers:
|
||||
handler()
|
||||
self.discover.start()
|
||||
return self.discover.start()
|
||||
|
||||
def start(self, port=None):
|
||||
handler_class = self.decorate_handler(ServiceHandler)
|
||||
|
@ -97,7 +102,11 @@ class Server:
|
|||
def run(self, port):
|
||||
logger.info('Milvus server start ......')
|
||||
port = port or self.port
|
||||
self.on_pre_run()
|
||||
ok = self.on_pre_run()
|
||||
|
||||
if not ok:
|
||||
logger.error('Terminate server due to error found in on_pre_run')
|
||||
sys.exit(1)
|
||||
|
||||
self.start(port)
|
||||
logger.info('Listening on port {}'.format(port))
|
||||
|
|
|
@ -12,6 +12,7 @@ else:
|
|||
env.read_env()
|
||||
|
||||
|
||||
SERVER_VERSIONS = ['0.6.0']
|
||||
DEBUG = env.bool('DEBUG', False)
|
||||
MAX_RETRY = env.int('MAX_RETRY', 3)
|
||||
|
||||
|
@ -26,6 +27,7 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
|
|||
SERVER_PORT = env.int('SERVER_PORT', 19530)
|
||||
SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530)
|
||||
WOSERVER = env.str('WOSERVER')
|
||||
MAX_WORKERS = env.int('MAX_WORKERS', 50)
|
||||
|
||||
|
||||
class TracingConfig:
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import logging
|
||||
import pytest
|
||||
import mock
|
||||
import random
|
||||
import threading
|
||||
|
||||
from milvus import Milvus
|
||||
from mishards.connections import (ConnectionMgr, Connection)
|
||||
from mishards.connections import (ConnectionMgr, Connection,
|
||||
ConnectionPool, ConnectionTopology, ConnectionGroup)
|
||||
from mishards.topology import StatusType
|
||||
from mishards import exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -11,6 +15,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@pytest.mark.usefixtures('app')
|
||||
class TestConnection:
|
||||
@pytest.mark.skip
|
||||
def test_manager(self):
|
||||
mgr = ConnectionMgr()
|
||||
|
||||
|
@ -99,3 +104,161 @@ class TestConnection:
|
|||
this_connect = c.connect(func=None, exception_handler=error_handler)
|
||||
this_connect()
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_topology(self):
|
||||
ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,))
|
||||
w_topo = ConnectionTopology()
|
||||
status, wg1 = w_topo.create(name='wg1')
|
||||
assert w_topo.has_group(wg1)
|
||||
assert status == StatusType.OK
|
||||
|
||||
status, wg1_dup = w_topo.create(name='wg1')
|
||||
assert wg1_dup is None
|
||||
assert status == StatusType.DUPLICATED
|
||||
|
||||
fetched_group = w_topo.get_group('wg1')
|
||||
assert id(fetched_group) == id(wg1)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
wg1.create(name='wg1_p1')
|
||||
|
||||
status, wg1_p1 = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
|
||||
assert status == StatusType.OK
|
||||
assert wg1_p1 is not None
|
||||
assert len(wg1) == 1
|
||||
|
||||
status, wg1_p1_dup = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
|
||||
assert status == StatusType.DUPLICATED
|
||||
assert wg1_p1_dup is None
|
||||
assert len(wg1) == 1
|
||||
|
||||
status, wg1_p2 = wg1.create('wg1_p2', uri='127.0.0.1:19530')
|
||||
assert status == StatusType.OK
|
||||
assert wg1_p2 is not None
|
||||
assert len(wg1) == 2
|
||||
|
||||
poped = wg1.remove('wg1_p3')
|
||||
assert poped is None
|
||||
assert len(wg1) == 2
|
||||
|
||||
poped = wg1.remove('wg1_p2')
|
||||
assert poped.name == 'wg1_p2'
|
||||
assert len(wg1) == 1
|
||||
|
||||
fetched_p1 = wg1.get(wg1_p1.name)
|
||||
assert fetched_p1 == wg1_p1
|
||||
|
||||
fetched_p1 = w_topo.get_group('wg1').get('wg1_p1')
|
||||
|
||||
conn1 = fetched_p1.fetch()
|
||||
assert len(fetched_p1) == 1
|
||||
assert fetched_p1.active_num == 1
|
||||
|
||||
conn2 = fetched_p1.fetch()
|
||||
assert len(fetched_p1) == 2
|
||||
assert fetched_p1.active_num == 2
|
||||
|
||||
conn2.release()
|
||||
assert len(fetched_p1) == 2
|
||||
assert fetched_p1.active_num == 1
|
||||
|
||||
assert len(w_topo.group_names) == 1
|
||||
|
||||
def test_connection_pool(self):
|
||||
ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,))
|
||||
|
||||
def choaz_mp_fetch(capacity, count, tnum):
|
||||
threads_num = 5
|
||||
topo = ConnectionTopology()
|
||||
_, tg = topo.create('tg')
|
||||
pool_size = 20
|
||||
pool_names = ['p{}:19530'.format(i) for i in range(pool_size)]
|
||||
|
||||
threads = []
|
||||
def Worker(group, cnt, capacity):
|
||||
ori_cnt = cnt
|
||||
assert cnt < 100
|
||||
while cnt >= 0:
|
||||
name = pool_names[random.randint(0, pool_size-1)]
|
||||
cnt -= 1
|
||||
remove = (random.randint(1,4)%4 == 0)
|
||||
if remove:
|
||||
pool = group.get(name=name)
|
||||
# if name.startswith("p1:"):
|
||||
# logger.error('{} CNT={} [Remove] Group \"{}\" has pool of SIZE={} ACTIVE={}'.format(threading.get_ident(), ori_cnt-cnt, name, len(pool), pool.active_num))
|
||||
group.remove(name)
|
||||
|
||||
else:
|
||||
group.create(name=name, uri=name, capacity=capacity)
|
||||
pool = group.get(name=name)
|
||||
assert pool is not None
|
||||
conn = pool.fetch(timeout=0.01)
|
||||
# if name.startswith("p1:"):
|
||||
# logger.error('{} CNT={} [Adding] Group \"{}\" has pool of SIZE={} ACTIVE={}'.format(threading.get_ident(), ori_cnt-cnt, name, len(pool), pool.active_num))
|
||||
|
||||
for _ in range(threads_num):
|
||||
t = threading.Thread(target=Worker, args=(tg, count, tnum))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
choaz_mp_fetch(4, 40, 8)
|
||||
|
||||
def check_mp_fetch(capacity=-1):
|
||||
w2 = ConnectionPool(name='w2', uri='127.0.0.1:19530', max_retry=2, capacity=capacity)
|
||||
connections = []
|
||||
def GetConnection(pool):
|
||||
conn = pool.fetch(timeout=0.1)
|
||||
if conn:
|
||||
connections.append(conn)
|
||||
|
||||
threads = []
|
||||
threads_num = 10 if capacity < 0 else 2*capacity
|
||||
for _ in range(threads_num):
|
||||
t = threading.Thread(target=GetConnection, args=(w2,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
expected_size = threads_num if capacity < 0 else capacity
|
||||
|
||||
assert len(connections) == expected_size
|
||||
|
||||
check_mp_fetch(5)
|
||||
check_mp_fetch()
|
||||
|
||||
w1 = ConnectionPool(name='w1', uri='127.0.0.1:19530', max_retry=2, capacity=2)
|
||||
w1_1 = w1.fetch()
|
||||
assert len(w1) == 1
|
||||
assert w1.active_num == 1
|
||||
w1_2 = w1.fetch()
|
||||
assert len(w1) == 2
|
||||
assert w1.active_num == 2
|
||||
w1_3 = w1.fetch()
|
||||
assert w1_3 is None
|
||||
assert len(w1) == 2
|
||||
assert w1.active_num == 2
|
||||
|
||||
w1_1.release()
|
||||
assert len(w1) == 2
|
||||
assert w1.active_num == 1
|
||||
|
||||
def check(pool, expected_size, expected_active_num):
|
||||
w = pool.fetch()
|
||||
assert len(pool) == expected_size
|
||||
assert pool.active_num == expected_active_num
|
||||
|
||||
check(w1, 2, 2)
|
||||
|
||||
assert len(w1) == 2
|
||||
assert w1.active_num == 1
|
||||
|
||||
wild_w = w1.create()
|
||||
with pytest.raises(RuntimeError):
|
||||
w1.release(wild_w)
|
||||
|
||||
ret = w1_2.can_retry
|
||||
assert ret == w1_2.connection.can_retry
|
||||
|
|
|
@ -14,6 +14,8 @@ from mishards.service_handler import ServiceHandler
|
|||
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
|
||||
from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables
|
||||
from mishards.router import RouterMixin
|
||||
from mishards.connections import (ConnectionMgr, Connection,
|
||||
ConnectionPool, ConnectionTopology, ConnectionGroup)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -23,15 +25,13 @@ BAD = Status(code=Status.PERMISSION_DENIED, message='Fail')
|
|||
|
||||
@pytest.mark.usefixtures('started_app')
|
||||
class TestServer:
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
m = Milvus()
|
||||
m.connect(host='localhost', port=settings.SERVER_TEST_PORT)
|
||||
return m
|
||||
|
||||
def test_server_start(self, started_app):
|
||||
assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER
|
||||
|
||||
def test_cmd(self, started_app):
|
||||
ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK,
|
||||
''))
|
||||
|
@ -228,6 +228,7 @@ class TestServer:
|
|||
def random_data(self, n, dimension):
|
||||
return [[random.random() for _ in range(dimension)] for _ in range(n)]
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_search(self, started_app):
|
||||
table_name = inspect.currentframe().f_code.co_name
|
||||
to_index_cnt = random.randint(10, 20)
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
import logging
|
||||
import threading
|
||||
import enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TopoObject:
|
||||
def __init__(self, name, **kwargs):
|
||||
self.name = name
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.name == other
|
||||
return self.name == other.name
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __str__(self):
|
||||
return '<TopoObject: {}>'.format(self.name)
|
||||
|
||||
class StatusType(enum.Enum):
|
||||
OK = 1
|
||||
DUPLICATED = 2
|
||||
ADD_ERROR = 3
|
||||
VERSION_ERROR = 4
|
||||
|
||||
|
||||
class TopoGroup:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.items = {}
|
||||
self.cv = threading.Condition()
|
||||
|
||||
def on_duplicate(self, topo_object):
|
||||
logger.warning('Duplicated topo_object \"{}\" into group \"{}\"'.format(topo_object, self.name))
|
||||
|
||||
def on_added(self, topo_object):
|
||||
return True
|
||||
|
||||
def on_pre_add(self, topo_object):
|
||||
return True
|
||||
|
||||
def _add_no_lock(self, topo_object):
|
||||
if topo_object.name in self.items:
|
||||
return StatusType.DUPLICATED
|
||||
logger.info('Adding topo_object \"{}\" into group \"{}\"'.format(topo_object, self.name))
|
||||
ok = self.on_pre_add(topo_object)
|
||||
if not ok:
|
||||
return StatusType.VERSION_ERROR
|
||||
self.items[topo_object.name] = topo_object
|
||||
ok = self.on_added(topo_object)
|
||||
if not ok:
|
||||
self._remove_no_lock(topo_object.name)
|
||||
|
||||
return StatusType.OK if ok else StatusType.ADD_ERROR
|
||||
|
||||
def add(self, topo_object):
|
||||
with self.cv:
|
||||
return self._add_no_lock(topo_object)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def __str__(self):
|
||||
return '<TopoGroup: {}>'.format(self.name)
|
||||
|
||||
def get(self, name):
|
||||
return self.items.get(name, None)
|
||||
|
||||
def _remove_no_lock(self, name):
|
||||
logger.info('Removing topo_object \"{}\" from group \"{}\"'.format(name, self.name))
|
||||
return self.items.pop(name, None)
|
||||
|
||||
def remove(self, name):
|
||||
with self.cv:
|
||||
return self._remove_no_lock(name)
|
||||
|
||||
|
||||
class Topology:
|
||||
def __init__(self):
|
||||
self.topo_groups = {}
|
||||
self.cv = threading.Condition()
|
||||
|
||||
def on_duplicated_group(self, group):
|
||||
logger.warning('Duplicated group \"{}\" found!'.format(group))
|
||||
return StatusType.DUPLICATED
|
||||
|
||||
def on_pre_add_group(self, group):
|
||||
logger.debug('Pre add group \"{}\"'.format(group))
|
||||
return StatusType.OK
|
||||
|
||||
def on_post_add_group(self, group):
|
||||
logger.debug('Post add group \"{}\"'.format(group))
|
||||
return StatusType.OK
|
||||
|
||||
def get_group(self, name):
|
||||
return self.topo_groups.get(name, None)
|
||||
|
||||
def has_group(self, group):
|
||||
key = group if isinstance(group, str) else group.name
|
||||
return key in self.topo_groups
|
||||
|
||||
def _add_group_no_lock(self, group):
|
||||
logger.info('Adding group \"{}\"'.format(group))
|
||||
self.topo_groups[group.name] = group
|
||||
|
||||
def add_group(self, group):
|
||||
self.on_pre_add_group(group)
|
||||
if self.has_group(group):
|
||||
return self.on_duplicated_group(group)
|
||||
with self.cv:
|
||||
self._add_group_no_lock(group)
|
||||
return self.on_post_add_group(group)
|
||||
|
||||
def on_delete_not_existed_group(self, group):
|
||||
logger.warning('Deleting non-existed group \"{}\"'.format(group))
|
||||
|
||||
def on_pre_delete_group(self, group):
|
||||
logger.debug('Pre delete group \"{}\"'.format(group))
|
||||
|
||||
def on_post_delete_group(self, group):
|
||||
logger.debug('Post delete group \"{}\"'.format(group))
|
||||
|
||||
def _delete_group_no_lock(self, group):
|
||||
logger.info('Deleting group \"{}\"'.format(group))
|
||||
delete_key = group if isinstance(group, str) else group.name
|
||||
return self.topo_groups.pop(delete_key, None)
|
||||
|
||||
def delete_group(self, group):
|
||||
self.on_pre_delete_group(group)
|
||||
with self.cv:
|
||||
deleted_group = self._delete_group_lock(group)
|
||||
if not deleted_group:
|
||||
return self.on_delete_not_existed_group(group)
|
||||
return self.on_post_delete_group(group)
|
||||
|
||||
@property
|
||||
def group_names(self):
|
||||
return self.topo_groups.keys()
|
|
@ -26,6 +26,8 @@ class JaegerFactory:
|
|||
tracer,
|
||||
log_payloads=plugin_config.TRACING_LOG_PAYLOAD,
|
||||
span_decorator=span_decorator)
|
||||
jaeger_logger = logging.getLogger('jaeger_tracing')
|
||||
jaeger_logger.setLevel(logging.ERROR)
|
||||
|
||||
return Tracer(tracer, tracer_interceptor, intercept_server)
|
||||
|
||||
|
|
Loading…
Reference in New Issue