mirror of https://github.com/milvus-io/milvus.git
implement router plugins
parent
2ab0e0eb93
commit
63997d55ec
|
@ -28,8 +28,9 @@ def create_app(testing_config=None):
|
|||
settings.TracingConfig,
|
||||
span_decorator=GrpcSpanDecorator())
|
||||
|
||||
from mishards.routings import RouterFactory
|
||||
router = RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)
|
||||
from mishards.router.factory import RouterFactory
|
||||
router = RouterFactory(config.ROUTER_PLUGIN_PATH).create(config.ROUTER_CLASS_NAME,
|
||||
conn_mgr=connect_mgr)
|
||||
|
||||
grpc_server.init_app(conn_mgr=connect_mgr,
|
||||
tracer=tracer,
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
from mishards import exceptions
|
||||
|
||||
|
||||
class RouterMixin:
|
||||
def __init__(self, conn_mgr):
|
||||
self.conn_mgr = conn_mgr
|
||||
|
||||
def routing(self, table_name, metadata=None, **kwargs):
|
||||
raise NotImplemented()
|
||||
|
||||
def connection(self, metadata=None):
|
||||
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
|
||||
if conn:
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
|
||||
def query_conn(self, name, metadata=None):
|
||||
conn = self.conn_mgr.conn(name, metadata=metadata)
|
||||
if not conn:
|
||||
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
|
@ -0,0 +1,49 @@
|
|||
import os
|
||||
import logging
|
||||
from functools import partial
|
||||
# from pluginbase import PluginBase
|
||||
# import importlib
|
||||
from utils.pluginextension import MiPluginBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
get_path = partial(os.path.join, here)
|
||||
|
||||
PLUGIN_PACKAGE_NAME = 'router.plugins'
|
||||
plugin_base = MiPluginBase(package=PLUGIN_PACKAGE_NAME,
|
||||
searchpath=[get_path('./plugins')])
|
||||
|
||||
|
||||
class RouterFactory(object):
|
||||
PLUGIN_TYPE = 'Router'
|
||||
|
||||
def __init__(self, searchpath=None):
|
||||
self.plugin_package_name = PLUGIN_PACKAGE_NAME
|
||||
self.class_map = {}
|
||||
searchpath = searchpath if searchpath else []
|
||||
searchpath = [searchpath] if isinstance(searchpath, str) else searchpath
|
||||
self.source = plugin_base.make_plugin_source(searchpath=searchpath,
|
||||
identifier=self.__class__.__name__)
|
||||
|
||||
for plugin_name in self.source.list_plugins():
|
||||
plugin = self.source.load_plugin(plugin_name)
|
||||
plugin.setup(self)
|
||||
|
||||
def on_plugin_setup(self, plugin_class):
|
||||
name = getattr(plugin_class, 'name', plugin_class.__name__)
|
||||
self.class_map[name.lower()] = plugin_class
|
||||
|
||||
def plugin(self, name):
|
||||
return self.class_map.get(name, None)
|
||||
|
||||
def create(self, class_name, class_config=None, **kwargs):
|
||||
if not class_name:
|
||||
raise RuntimeError('Please specify router class_name first!')
|
||||
|
||||
this_class = self.plugin(class_name.lower())
|
||||
if not this_class:
|
||||
raise RuntimeError('{} Plugin \'{}\' Not Installed!'.format(self.PLUGIN_TYPE, class_name))
|
||||
|
||||
router = this_class.create(class_config, **kwargs)
|
||||
return router
|
|
@ -1,64 +1,19 @@
|
|||
import logging
|
||||
from sqlalchemy import exc as sqlalchemy_exc
|
||||
from sqlalchemy import and_
|
||||
|
||||
from mishards.models import Tables
|
||||
from mishards.router import RouterMixin
|
||||
from mishards import exceptions, db
|
||||
from mishards.hash_ring import HashRing
|
||||
from mishards.models import Tables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouteManager:
|
||||
ROUTER_CLASSES = {}
|
||||
|
||||
@classmethod
|
||||
def register_router_class(cls, target):
|
||||
name = target.__dict__.get('NAME', None)
|
||||
name = name if name else target.__class__.__name__
|
||||
cls.ROUTER_CLASSES[name] = target
|
||||
return target
|
||||
|
||||
@classmethod
|
||||
def get_router_class(cls, name):
|
||||
return cls.ROUTER_CLASSES.get(name, None)
|
||||
|
||||
|
||||
class RouterFactory:
|
||||
@classmethod
|
||||
def new_router(cls, name, conn_mgr, **kwargs):
|
||||
router_class = RouteManager.get_router_class(name)
|
||||
assert router_class
|
||||
return router_class(conn_mgr, **kwargs)
|
||||
|
||||
|
||||
class RouterMixin:
|
||||
def __init__(self, conn_mgr):
|
||||
self.conn_mgr = conn_mgr
|
||||
|
||||
def routing(self, table_name, metadata=None, **kwargs):
|
||||
raise NotImplemented()
|
||||
|
||||
def connection(self, metadata=None):
|
||||
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
|
||||
if conn:
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
|
||||
def query_conn(self, name, metadata=None):
|
||||
conn = self.conn_mgr.conn(name, metadata=metadata)
|
||||
if not conn:
|
||||
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
|
||||
|
||||
@RouteManager.register_router_class
|
||||
class FileBasedHashRingRouter(RouterMixin):
|
||||
NAME = 'FileBasedHashRingRouter'
|
||||
class Factory(RouterMixin):
|
||||
name = 'FileBasedHashRingRouter'
|
||||
|
||||
def __init__(self, conn_mgr, **kwargs):
|
||||
super(FileBasedHashRingRouter, self).__init__(conn_mgr)
|
||||
super(Factory, self).__init__(conn_mgr)
|
||||
|
||||
def routing(self, table_name, metadata=None, **kwargs):
|
||||
range_array = kwargs.pop('range_array', None)
|
||||
|
@ -94,3 +49,16 @@ class FileBasedHashRingRouter(RouterMixin):
|
|||
routing[target_host]['file_ids'].append(str(f.id))
|
||||
|
||||
return routing
|
||||
|
||||
@classmethod
|
||||
def create(cls, config, **kwargs):
|
||||
conn_mgr = kwargs.pop('conn_mgr', None)
|
||||
if not conn_mgr:
|
||||
raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(self.name))
|
||||
router = cls(conn_mgr, **kwargs)
|
||||
return router
|
||||
|
||||
|
||||
def setup(app):
|
||||
logger.info('Plugin \'{}\' Installed In Package: {}'.format(__file__, app.plugin_package_name))
|
||||
app.on_plugin_setup(Factory)
|
|
@ -76,6 +76,7 @@ class DefaultConfig:
|
|||
SQL_ECHO = env.bool('SQL_ECHO', False)
|
||||
TRACING_PLUGIN_PATH = env.str('TRACING_PLUGIN_PATH', '')
|
||||
TRACING_TYPE = env.str('TRACING_TYPE', '')
|
||||
ROUTER_PLUGIN_PATH = env.str('ROUTER_PLUGIN_PATH', '')
|
||||
ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter')
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
import importlib
|
||||
from pluginbase import PluginBase, PluginSource
|
||||
|
||||
|
||||
class MiPluginSource(PluginSource):
|
||||
def load_plugin(self, name):
|
||||
if '.' in name:
|
||||
raise ImportError('Plugin names cannot contain dots.')
|
||||
with self:
|
||||
return importlib.import_module(self.base.package + '.' + name)
|
||||
|
||||
|
||||
class MiPluginBase(PluginBase):
|
||||
def make_plugin_source(self, *args, **kwargs):
|
||||
return MiPluginSource(self, *args, **kwargs)
|
Loading…
Reference in New Issue