# milvus/shards/mishards/connections.py
#
# 283 lines
# 8.9 KiB
# Python

import copy
import time
import json
import logging
import threading
from functools import wraps
from collections import defaultdict
from milvus import Milvus
from mishards import (settings, exceptions, topology)
from utils import singleton
logger = logging.getLogger(__name__)
# class Searchook(BaseSearchHook):
#
# def on_response(self, *args, **kwargs):
# return True
#
#
# class Connection:
# def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
# self.name = name
# self.uri = uri
# self.max_retry = max_retry
# self.retried = 0
# self.conn = Milvus()
# self.error_handlers = [] if not error_handlers else error_handlers
# self.on_retry_func = kwargs.get('on_retry_func', None)
#
# # define search hook
# self.conn.set_hook(search_in_file=Searchook())
# # self._connect()
#
# def __str__(self):
# return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
#
# def _connect(self, metadata=None):
# try:
# self.conn.connect(uri=self.uri)
# except Exception as e:
# if not self.error_handlers:
# raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
# for handler in self.error_handlers:
# handler(e, metadata=metadata)
#
# @property
# def can_retry(self):
# return self.retried < self.max_retry
#
# @property
# def connected(self):
# return self.conn.connected()
#
# def on_retry(self):
# if self.on_retry_func:
# self.on_retry_func(self)
# else:
# self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))
#
# def on_connect(self, metadata=None):
# while not self.connected and self.can_retry:
# self.retried += 1
# self.on_retry()
# self._connect(metadata=metadata)
#
# if not self.can_retry and not self.connected:
# raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
# metadata=metadata))
#
# self.retried = 0
#
# def connect(self, func, exception_handler=None):
# @wraps(func)
# def inner(*args, **kwargs):
# self.on_connect()
# try:
# return func(*args, **kwargs)
# except Exception as e:
# if exception_handler:
# exception_handler(e)
# else:
# raise e
# return inner
#
# def __str__(self):
# return '<Connection: {}:{}>'.format(self.name, id(self))
#
# def __repr__(self):
# return self.__str__()
#
#
# class Duration:
# def __init__(self):
# self.start_ts = time.time()
# self.end_ts = None
#
# def stop(self):
# if self.end_ts:
# return False
#
# self.end_ts = time.time()
# return True
#
# @property
# def value(self):
# if not self.end_ts:
# return None
#
# return self.end_ts - self.start_ts
#
#
# class ProxyMixin:
# def __getattr__(self, name):
# target = self.__dict__.get(name, None)
# if target or not self.connection:
# return target
# return getattr(self.connection, name)
#
#
# class ScopedConnection(ProxyMixin):
# def __init__(self, pool, connection):
# self.pool = pool
# self.connection = connection
# self.duration = Duration()
#
# def __del__(self):
# self.release()
#
# def __str__(self):
# return self.connection.__str__()
#
# def release(self):
# if not self.pool or not self.connection:
# return
# self.pool.release(self.connection)
# self.duration.stop()
# self.pool.record_duration(self.connection, self.duration)
# self.pool = None
# self.connection = None
#
#
# class ConnectionPool(topology.TopoObject):
# def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs):
# super().__init__(name)
# self.capacity = capacity
# self.pending_pool = set()
# self.active_pool = set()
# self.connection_ownership = {}
# self.uri = uri
# self.max_retry = max_retry
# self.kwargs = kwargs
# self.cv = threading.Condition()
# self.durations = defaultdict(list)
#
# def record_duration(self, conn, duration):
# if len(self.durations[conn]) >= 10000:
# self.durations[conn].pop(0)
#
# self.durations[conn].append(duration)
#
# def stats(self):
# out = {'connections': {}}
# connections = out['connections']
# take_time = []
# for conn, durations in self.durations.items():
# total_time = sum(d.value for d in durations)
# connections[id(conn)] = {
# 'total_time': total_time,
# 'called_times': len(durations)
# }
# take_time.append(total_time)
#
# out['max-time'] = max(take_time)
# out['num'] = len(self.durations)
# logger.debug(json.dumps(out, indent=2))
# return out
#
# def __len__(self):
# return len(self.pending_pool) + len(self.active_pool)
#
# @property
# def active_num(self):
# return len(self.active_pool)
#
# def _is_full(self):
# if self.capacity < 0:
# return False
# return len(self) >= self.capacity
#
# def fetch(self, timeout=1):
# with self.cv:
# timeout_times = 0
# while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1):
# self.cv.notifyAll()
# self.cv.wait(timeout)
# timeout_times += 1
#
# connection = None
# if timeout_times >= 1:
# return connection
#
# # logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
# if len(self.pending_pool) == 0:
# connection = self.create()
# else:
# connection = self.pending_pool.pop()
# # logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name))
# self.active_pool.add(connection)
# scoped_connection = ScopedConnection(self, connection)
# return scoped_connection
#
# def release(self, connection):
# with self.cv:
# if connection not in self.active_pool:
# raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name))
# # logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name))
# # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
# self.active_pool.remove(connection)
# self.pending_pool.add(connection)
#
# def create(self):
# connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
# return connection
class ConnectionGroup(topology.TopoGroup):
    """Topology group whose members are ``Milvus`` client objects.

    Members are created through :meth:`create` and registered with the
    inherited ``add`` method; they are read back through ``self.items``.
    """

    def __init__(self, name):
        super().__init__(name)

    def stats(self):
        """Return a dict mapping each member's name to its ``stats()`` result."""
        return {
            member_name: member.stats()
            for member_name, member in self.items.items()
        }

    def on_pre_add(self, topo_object):
        """Check that ``topo_object`` is reachable and runs a supported version.

        NOTE(review): presumably called by the base ``TopoGroup`` before an
        object is registered -- confirm in the ``topology`` module.

        Returns:
            ``False`` (after logging an error) when the status returned by
            ``server_version()`` is not OK, or when the reported version is
            not in ``settings.SERVER_VERSIONS``; ``True`` otherwise.
        """
        status, version = topo_object.server_version()

        if not status.OK():
            message = 'Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name)
            logger.error(message)
            return False

        if version in settings.SERVER_VERSIONS:
            return True

        message = 'Cannot connect to server of version: {}. Only {} supported'.format(
            version, settings.SERVER_VERSIONS)
        logger.error(message)
        return False

    def create(self, name, **kwargs):
        """Build a ``Milvus`` client named ``name`` and add it to this group.

        Args:
            name: Name passed to the ``Milvus`` constructor.
            **kwargs: Forwarded to the ``Milvus`` constructor; must contain a
                truthy ``uri``. ``max_retry`` is always overwritten with
                ``settings.MAX_RETRY``.

        Returns:
            A ``(status, pool)`` tuple where ``status`` is the result of
            ``self.add`` and ``pool`` is the new client, or ``None`` when
            ``status`` is not ``topology.StatusType.OK``.

        Raises:
            RuntimeError: If ``uri`` is missing or falsy.
        """
        if not kwargs.get('uri'):
            raise RuntimeError('\"uri\" is required to create connection pool')

        # Deep-copy so the caller's kwargs are never shared with the client,
        # then force the retry limit from settings.
        client_kwargs = dict(copy.deepcopy(kwargs), max_retry=settings.MAX_RETRY)
        client = Milvus(name=name, **client_kwargs)

        status = self.add(client)
        return status, (None if status != topology.StatusType.OK else client)
class ConnectionTopology(topology.Topology):
    """Topology whose groups are :class:`ConnectionGroup` instances."""

    def __init__(self):
        super().__init__()

    def stats(self):
        """Return a dict mapping each group's name to its ``stats()`` result."""
        return {
            group_name: group.stats()
            for group_name, group in self.topo_groups.items()
        }

    def create(self, name):
        """Create a ``ConnectionGroup`` called ``name`` and register it.

        Returns:
            A ``(status, group)`` tuple where ``status`` is the result of
            ``self.add_group`` and ``group`` is the new group, or ``None``
            when ``status`` is ``topology.StatusType.DUPLICATED``.
        """
        new_group = ConnectionGroup(name)
        status = self.add_group(new_group)
        # Only a DUPLICATED status hides the new group from the caller.
        is_duplicate = status == topology.StatusType.DUPLICATED
        return status, (None if is_duplicate else new_group)