# Mirror of https://github.com/milvus-io/milvus.git
# Python source (240 lines, 7.3 KiB).
import logging
import random
import threading

import mock
import pytest
from milvus import Milvus

from mishards import exceptions
from mishards.connections import (Connection, ConnectionGroup,
                                  ConnectionPool, ConnectionTopology)
from mishards.topology import StatusType

# Module-level logger used by the test helpers below.
logger = logging.getLogger(__name__)


@pytest.mark.usefixtures('app')
class TestConnection:
    """Tests for mishards Connection, ConnectionPool, ConnectionGroup and
    ConnectionTopology."""

    def test_connection(self):
        """Connection.connect wraps a callable with connect-and-retry logic.

        Covers:
        - a backend that never connects (``max_retry`` retries, then
          ``ConnectionConnectError``, and the callable is not run);
        - a backend that is connected (callable runs, no retries);
        - a ``None`` callable (``TypeError`` is raised);
        - the same failure routed to ``exception_handler`` instead of raised.
        """

        class Conn:
            """Fake backend whose connect()/connected() both report ``state``."""

            def __init__(self, state):
                self.state = state

            def connect(self, uri):
                return self.state

            def connected(self):
                return self.state

        FAIL_CONN = Conn(False)
        PASS_CONN = Conn(True)

        class Retry:
            """``on_retry_func`` hook that counts how often it is invoked."""

            def __init__(self):
                self.times = 0

            def __call__(self, conn):
                self.times += 1
                logger.info('Retrying {}'.format(self.times))

        class Func:
            """Callable that records whether it has been executed."""

            def __init__(self):
                self.executed = False

            def __call__(self):
                self.executed = True

        max_retry = 3
        retry_obj = Retry()

        c = Connection('client',
                       uri='xx',
                       max_retry=max_retry,
                       on_retry_func=retry_obj)

        c.conn = FAIL_CONN
        ff = Func()
        this_connect = c.connect(func=ff)
        with pytest.raises(exceptions.ConnectionConnectError):
            this_connect()
        assert retry_obj.times == max_retry
        assert not ff.executed

        # Reset the counter on the hook ``c`` actually holds. Rebinding the
        # name to a fresh Retry() would create an object ``c`` never calls,
        # and the ``times == 0`` check below would pass vacuously.
        retry_obj.times = 0

        c.conn = PASS_CONN
        this_connect = c.connect(func=ff)
        this_connect()
        assert ff.executed
        assert retry_obj.times == 0

        this_connect = c.connect(func=None)
        with pytest.raises(TypeError):
            this_connect()

        errors = []

        def error_handler(err):
            errors.append(err)

        this_connect = c.connect(func=None, exception_handler=error_handler)
        this_connect()
        assert len(errors) == 1

    # Scoped patch (restored after the test) instead of assigning to the
    # class attribute, which leaked the mock into every later test.
    @mock.patch.object(ConnectionGroup, 'on_pre_add',
                       new=mock.MagicMock(return_value=(True,)))
    def test_topology(self):
        """Group/pool bookkeeping in ConnectionTopology and ConnectionGroup."""
        w_topo = ConnectionTopology()
        status, wg1 = w_topo.create(name='wg1')
        assert w_topo.has_group(wg1)
        assert status == StatusType.OK

        # A second group with the same name is rejected.
        status, wg1_dup = w_topo.create(name='wg1')
        assert wg1_dup is None
        assert status == StatusType.DUPLICATED

        fetched_group = w_topo.get_group('wg1')
        assert fetched_group is wg1

        # Creating a pool without ``uri`` raises RuntimeError.
        with pytest.raises(RuntimeError):
            wg1.create(name='wg1_p1')

        status, wg1_p1 = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
        assert status == StatusType.OK
        assert wg1_p1 is not None
        assert len(wg1) == 1

        status, wg1_p1_dup = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
        assert status == StatusType.DUPLICATED
        assert wg1_p1_dup is None
        assert len(wg1) == 1

        status, wg1_p2 = wg1.create('wg1_p2', uri='127.0.0.1:19530')
        assert status == StatusType.OK
        assert wg1_p2 is not None
        assert len(wg1) == 2

        # Removing an unknown pool returns None and changes nothing.
        popped = wg1.remove('wg1_p3')
        assert popped is None
        assert len(wg1) == 2

        popped = wg1.remove('wg1_p2')
        assert popped.name == 'wg1_p2'
        assert len(wg1) == 1

        fetched_p1 = wg1.get(wg1_p1.name)
        assert fetched_p1 == wg1_p1

        fetched_p1 = w_topo.get_group('wg1').get('wg1_p1')

        # Keep references to the fetched connections. ``check`` in
        # test_connection_pool shows active_num dropping once a fetched
        # connection goes out of scope, so discarding them could skew counts.
        conn1 = fetched_p1.fetch()
        assert len(fetched_p1) == 1
        assert fetched_p1.active_num == 1

        conn2 = fetched_p1.fetch()
        assert len(fetched_p1) == 2
        assert fetched_p1.active_num == 2

        # release() lowers active_num but does not shrink the pool.
        conn2.release()
        assert len(fetched_p1) == 2
        assert fetched_p1.active_num == 1

        assert len(w_topo.group_names) == 1

    # Scoped patch; see test_topology.
    @mock.patch.object(ConnectionGroup, 'on_pre_add',
                       new=mock.MagicMock(return_value=(True,)))
    def test_connection_pool(self):
        """Concurrent fetches, capacity limits and release semantics."""

        def chaos_mp_fetch(capacity, count, tnum):
            """Run ``tnum`` threads, each doing ``count`` random create/remove
            operations on one group whose pools have ``capacity``.

            An exception raised inside a worker thread is not re-raised by
            ``Thread.join``. This is therefore a smoke test: it fails on a
            hang or crash, not on assertion failures inside the workers.
            """
            # Keep the run short; checked here because a failure inside a
            # worker thread would not fail the test.
            assert count < 100

            topo = ConnectionTopology()
            _, tg = topo.create('tg')
            pool_names = ['p{}:19530'.format(i) for i in range(20)]

            def worker(group, cnt, cap):
                for _ in range(cnt):
                    name = random.choice(pool_names)
                    # randint(1, 4) == 4: roughly one operation in four removes.
                    if random.randint(1, 4) == 4:
                        group.remove(name)
                    else:
                        group.create(name=name, uri=name, capacity=cap)
                        pool = group.get(name=name)
                        assert pool is not None
                        # Keep the connection referenced until the next
                        # iteration rebinds ``conn`` (see ``check`` below).
                        conn = pool.fetch(timeout=0.01)

            threads = []
            for _ in range(tnum):
                t = threading.Thread(target=worker, args=(tg, count, capacity))
                threads.append(t)
                t.start()

            for t in threads:
                t.join()

        chaos_mp_fetch(4, 40, 8)

        def check_mp_fetch(capacity=-1):
            """Fetch concurrently from one pool and count the successes.

            A negative ``capacity`` expects every fetch to succeed. Otherwise
            the pool is oversubscribed 2x and exactly ``capacity`` fetches
            should succeed before the rest time out.
            """
            w2 = ConnectionPool(name='w2', uri='127.0.0.1:19530',
                                max_retry=2, capacity=capacity)
            connections = []

            def get_connection(pool):
                conn = pool.fetch(timeout=0.1)
                if conn:
                    connections.append(conn)

            threads_num = 10 if capacity < 0 else 2 * capacity
            threads = []
            for _ in range(threads_num):
                t = threading.Thread(target=get_connection, args=(w2,))
                threads.append(t)
                t.start()

            for t in threads:
                t.join()

            expected_size = threads_num if capacity < 0 else capacity
            assert len(connections) == expected_size

        check_mp_fetch(5)
        check_mp_fetch()

        w1 = ConnectionPool(name='w1', uri='127.0.0.1:19530',
                            max_retry=2, capacity=2)
        w1_1 = w1.fetch()
        assert len(w1) == 1
        assert w1.active_num == 1
        w1_2 = w1.fetch()
        assert len(w1) == 2
        assert w1.active_num == 2

        # At capacity: a further fetch yields None.
        w1_3 = w1.fetch()
        assert w1_3 is None
        assert len(w1) == 2
        assert w1.active_num == 2

        w1_1.release()
        assert len(w1) == 2
        assert w1.active_num == 1

        def check(pool, expected_size, expected_active_num):
            # ``w`` is only referenced for the duration of this call.
            w = pool.fetch()
            assert len(pool) == expected_size
            assert pool.active_num == expected_active_num

        check(w1, 2, 2)

        # After ``check`` returns, its connection is no longer referenced
        # and the pool reports it inactive again.
        assert len(w1) == 2
        assert w1.active_num == 1

        # A connection from create() rather than fetch() cannot be
        # released back into the pool.
        wild_w = w1.create()
        with pytest.raises(RuntimeError):
            w1.release(wild_w)

        ret = w1_2.can_retry
        assert ret == w1_2.connection.can_retry
|