diff --git a/shards/discovery/factory.py b/shards/discovery/factory.py index 80334daf68..5f5c7fcf95 100644 --- a/shards/discovery/factory.py +++ b/shards/discovery/factory.py @@ -8,6 +8,7 @@ PLUGIN_PACKAGE_NAME = 'discovery.plugins' class DiscoveryFactory(BaseMixin): PLUGIN_TYPE = 'Discovery' + def __init__(self, searchpath=None): super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) @@ -17,5 +18,5 @@ class DiscoveryFactory(BaseMixin): raise RuntimeError('Please pass conn_mgr to create discovery!') plugin_config = DiscoveryConfig.Create() - plugin = plugin_class.create(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) + plugin = plugin_class.Create(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) return plugin diff --git a/shards/discovery/plugins/kubernetes_provider.py b/shards/discovery/plugins/kubernetes_provider.py index c9d9a3ad5a..aaf6091f83 100644 --- a/shards/discovery/plugins/kubernetes_provider.py +++ b/shards/discovery/plugins/kubernetes_provider.py @@ -299,7 +299,7 @@ class KubernetesProvider(object): self.event_handler.stop() @classmethod - def create(cls, conn_mgr, plugin_config, **kwargs): + def Create(cls, conn_mgr, plugin_config, **kwargs): discovery = cls(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) return discovery diff --git a/shards/discovery/plugins/static_provider.py b/shards/discovery/plugins/static_provider.py index 0f8bdb3d25..9bea62f2da 100644 --- a/shards/discovery/plugins/static_provider.py +++ b/shards/discovery/plugins/static_provider.py @@ -33,7 +33,7 @@ class StaticDiscovery(object): self.conn_mgr.unregister(name) @classmethod - def create(cls, conn_mgr, plugin_config, **kwargs): + def Create(cls, conn_mgr, plugin_config, **kwargs): discovery = cls(config=plugin_config, conn_mgr=conn_mgr, **kwargs) return discovery diff --git a/shards/mishards/__init__.py b/shards/mishards/__init__.py index e0792348a9..96463caa93 100644 --- a/shards/mishards/__init__.py +++ b/shards/mishards/__init__.py @@ -23,7 +23,7 @@ def create_app(testing_config=None): from mishards.grpc_utils import GrpcSpanDecorator from tracer.factory import TracerFactory tracer = TracerFactory(config.TRACING_PLUGIN_PATH).create(config.TRACING_TYPE, - settings.TracingConfig, + plugin_config=settings.TracingConfig, span_decorator=GrpcSpanDecorator()) from mishards.router.factory import RouterFactory diff --git a/shards/mishards/router/factory.py b/shards/mishards/router/factory.py index ea29a26a1d..a8f85c0df8 100644 --- a/shards/mishards/router/factory.py +++ b/shards/mishards/router/factory.py @@ -13,5 +13,5 @@ class RouterFactory(BaseMixin): super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) def _create(self, plugin_class, **kwargs): - router = plugin_class.create(**kwargs) + router = plugin_class.Create(**kwargs) return router diff --git a/shards/mishards/router/plugins/file_based_hash_ring_router.py b/shards/mishards/router/plugins/file_based_hash_ring_router.py index 4697189f35..b90935129e 100644 --- a/shards/mishards/router/plugins/file_based_hash_ring_router.py +++ b/shards/mishards/router/plugins/file_based_hash_ring_router.py @@ -51,7 +51,7 @@ class Factory(RouterMixin): return routing @classmethod - def create(cls, **kwargs): + def Create(cls, **kwargs): conn_mgr = kwargs.pop('conn_mgr', None) if not conn_mgr: raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(self.name)) diff --git a/shards/tracer/factory.py b/shards/tracer/factory.py index fff7a885e4..0e54a5aeb6 100644 --- a/shards/tracer/factory.py +++ b/shards/tracer/factory.py @@ -1,50 +1,27 @@ import os import logging -from functools import partial -from utils.pluginextension import MiPluginBase as PluginBase from tracer import Tracer - +from utils.plugins import BaseMixin logger = logging.getLogger(__name__) - -here = os.path.abspath(os.path.dirname(__file__)) -get_path = partial(os.path.join, here) - PLUGIN_PACKAGE_NAME = 'tracer.plugins' -plugin_base = PluginBase(package=PLUGIN_PACKAGE_NAME, - searchpath=[get_path('./plugins')]) -class TracerFactory(object): +class TracerFactory(BaseMixin): + PLUGIN_TYPE = 'Tracer' + def __init__(self, searchpath=None): - self.plugin_package_name = PLUGIN_PACKAGE_NAME - self.tracer_map = {} - searchpath = searchpath if searchpath else [] - searchpath = [searchpath] if isinstance(searchpath, str) else searchpath - self.source = plugin_base.make_plugin_source(searchpath=searchpath, - identifier=self.__class__.__name__) + super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) - for plugin_name in self.source.list_plugins(): - plugin = self.source.load_plugin(plugin_name) - plugin.setup(self) - - def on_plugin_setup(self, plugin_class): - name = getattr(plugin_class, 'name', plugin_class.__name__) - self.tracer_map[name.lower()] = plugin_class - - def plugin(self, name): - return self.tracer_map.get(name, None) - - def create(self, - tracer_type, - tracer_config, - span_decorator=None, - **kwargs): - if not tracer_type: + def create(self, class_name, **kwargs): + if not class_name: return Tracer() - plugin_class = self.plugin(tracer_type.lower()) - if not plugin_class: - raise RuntimeError('Tracer Plugin \'{}\' not installed!'.format(tracer_type)) + return super().create(class_name, **kwargs) - tracer = plugin_class.create(tracer_config, span_decorator=span_decorator, **kwargs) - return tracer + def _create(self, plugin_class, **kwargs): + plugin_config = kwargs.pop('plugin_config', None) + if not plugin_config: + raise RuntimeError('\'{}\' Plugin Config is Required!'.format(self.PLUGIN_TYPE)) + + plugin = plugin_class.Create(plugin_config=plugin_config, **kwargs) + return plugin diff --git a/shards/tracer/plugins/jaeger_factory.py b/shards/tracer/plugins/jaeger_factory.py index 7b18a86130..923f2f805d 100644 --- a/shards/tracer/plugins/jaeger_factory.py +++ b/shards/tracer/plugins/jaeger_factory.py @@ -12,10 +12,11 @@ PLUGIN_NAME = __file__ class JaegerFactory: name = 'jaeger' @classmethod - def create(cls, tracer_config, span_decorator=None, **kwargs): - tracing_config = tracer_config.TRACING_CONFIG - service_name = tracer_config.TRACING_SERVICE_NAME - validate = tracer_config.TRACING_VALIDATE + def Create(cls, plugin_config, **kwargs): + tracing_config = plugin_config.TRACING_CONFIG + span_decorator = kwargs.pop('span_decorator', None) + service_name = plugin_config.TRACING_SERVICE_NAME + validate = plugin_config.TRACING_VALIDATE config = Config(config=tracing_config, service_name=service_name, validate=validate) @@ -23,7 +24,7 @@ class JaegerFactory: tracer = config.initialize_tracer() tracer_interceptor = open_tracing_server_interceptor( tracer, - log_payloads=tracer_config.TRACING_LOG_PAYLOAD, + log_payloads=plugin_config.TRACING_LOG_PAYLOAD, span_decorator=span_decorator) return Tracer(tracer, tracer_interceptor, intercept_server) diff --git a/shards/utils/plugins/__init__.py b/shards/utils/plugins/__init__.py index 361dda66f9..633f1164a7 100644 --- a/shards/utils/plugins/__init__.py +++ b/shards/utils/plugins/__init__.py @@ -5,7 +5,8 @@ from utils.pluginextension import MiPluginBase as PluginBase class BaseMixin(object): - def __init__(self, package_name, searchpath=None): + + def __init__(self, package_name, searchpath=None): self.plugin_package_name = package_name caller_path = os.path.dirname(inspect.stack()[1][1]) get_path = partial(os.path.join, caller_path)