# milvus/shards/mishards/test_connections.py
# (240 lines, 7.3 KiB, Python)

import logging
import pytest
import mock
import random
import threading
from milvus import Milvus
from mishards.connections import (Connection,
ConnectionPool, ConnectionTopology, ConnectionGroup)
from mishards.topology import StatusType
from mishards import exceptions
# Module-level logger; the fake retry callback in TestConnection.test_connection
# reports each retry through it.
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestConnection:
    """Tests for mishards Connection, ConnectionPool, ConnectionGroup and
    ConnectionTopology."""

    def test_connection(self):
        """Check Connection.connect() on the failing, succeeding and
        error-handler paths.

        The client held in ``c.conn`` is replaced by a stub whose
        ``connect()``/``connected()`` results are fixed, so no server is used.
        """

        class Conn:
            # Stub client: connect() and connected() both report a fixed state.
            def __init__(self, state):
                self.state = state

            def connect(self, uri):
                return self.state

            def connected(self):
                return self.state

        FAIL_CONN = Conn(False)
        PASS_CONN = Conn(True)

        class Retry:
            # Retry callback that counts how often it is invoked.
            def __init__(self):
                self.times = 0

            def __call__(self, conn):
                self.times += 1
                logger.info('Retrying {}'.format(self.times))

        class Func:
            # Wrapped callable that records whether it was executed.
            def __init__(self):
                self.executed = False

            def __call__(self):
                self.executed = True

        max_retry = 3
        RetryObj = Retry()
        c = Connection('client',
                       uri='xx',
                       max_retry=max_retry,
                       on_retry_func=RetryObj)

        # Never-connecting client: the test expects exactly max_retry retries,
        # then ConnectionConnectError, and the wrapped function never runs.
        c.conn = FAIL_CONN
        ff = Func()
        this_connect = c.connect(func=ff)
        with pytest.raises(exceptions.ConnectionConnectError):
            this_connect()
        assert RetryObj.times == max_retry
        assert not ff.executed

        # Connected client: the wrapped function runs and the retry callback
        # registered on ``c`` must not fire again. Compare against that same
        # callback object -- a freshly created Retry() is not attached to
        # ``c`` and would always report 0.
        retries_before = RetryObj.times
        c.conn = PASS_CONN
        this_connect = c.connect(func=ff)
        this_connect()
        assert ff.executed
        assert RetryObj.times == retries_before

        # func=None is not callable: without an exception handler the test
        # expects the TypeError to propagate...
        this_connect = c.connect(func=None)
        with pytest.raises(TypeError):
            this_connect()

        # ...and with a handler the error is passed to it exactly once.
        errors = []

        def error_handler(err):
            errors.append(err)

        this_connect = c.connect(func=None, exception_handler=error_handler)
        this_connect()
        assert len(errors) == 1

    # Patched only for the duration of the test (restored afterwards) so the
    # stub does not leak into other tests; presumably it lets pools be added
    # without a reachable server -- confirm against ConnectionGroup.
    @mock.patch.object(ConnectionGroup, 'on_pre_add',
                       mock.MagicMock(return_value=(True,)))
    def test_topology(self):
        """Exercise group/pool bookkeeping in ConnectionTopology."""
        w_topo = ConnectionTopology()

        # Creating a group returns (status, group); a duplicate name is
        # rejected with DUPLICATED and no group.
        status, wg1 = w_topo.create(name='wg1')
        assert w_topo.has_group(wg1)
        assert status == StatusType.OK

        status, wg1_dup = w_topo.create(name='wg1')
        assert wg1_dup is None
        assert status == StatusType.DUPLICATED

        fetched_group = w_topo.get_group('wg1')
        assert id(fetched_group) == id(wg1)

        # A pool cannot be created in a group without a uri.
        with pytest.raises(RuntimeError):
            wg1.create(name='wg1_p1')

        status, wg1_p1 = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
        assert status == StatusType.OK
        assert wg1_p1 is not None
        assert len(wg1) == 1

        status, wg1_p1_dup = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
        assert status == StatusType.DUPLICATED
        assert wg1_p1_dup is None
        assert len(wg1) == 1

        status, wg1_p2 = wg1.create('wg1_p2', uri='127.0.0.1:19530')
        assert status == StatusType.OK
        assert wg1_p2 is not None
        assert len(wg1) == 2

        # Removing an unknown pool is a no-op returning None; removing a known
        # one returns the removed pool.
        poped = wg1.remove('wg1_p3')
        assert poped is None
        assert len(wg1) == 2

        poped = wg1.remove('wg1_p2')
        assert poped.name == 'wg1_p2'
        assert len(wg1) == 1

        fetched_p1 = wg1.get(wg1_p1.name)
        assert fetched_p1 == wg1_p1

        # fetch() grows both the pool size and the active count; release()
        # lowers only the active count.
        fetched_p1 = w_topo.get_group('wg1').get('wg1_p1')
        conn1 = fetched_p1.fetch()
        assert len(fetched_p1) == 1
        assert fetched_p1.active_num == 1

        conn2 = fetched_p1.fetch()
        assert len(fetched_p1) == 2
        assert fetched_p1.active_num == 2

        conn2.release()
        assert len(fetched_p1) == 2
        assert fetched_p1.active_num == 1

        assert len(w_topo.group_names) == 1

    # Scoped patch; see test_topology for the rationale.
    @mock.patch.object(ConnectionGroup, 'on_pre_add',
                       mock.MagicMock(return_value=(True,)))
    def test_connection_pool(self):
        """Stress a ConnectionGroup and check ConnectionPool capacity limits."""

        def chaos_mp_fetch(capacity, count, tnum):
            """Run ``tnum`` threads that randomly create/fetch/remove pools.

            capacity: per-pool capacity passed to ``group.create``.
            count: each worker runs ``count + 1`` iterations (must be < 100).
            tnum: number of worker threads.
            """
            topo = ConnectionTopology()
            _, tg = topo.create('tg')
            pool_names = ['p{}:19530'.format(i) for i in range(20)]

            # NOTE(review): an AssertionError raised inside a worker thread is
            # reported via threading.excepthook but does NOT fail this test;
            # only a crash/deadlock in the main thread does.
            def worker(group, cnt, capacity):
                assert cnt < 100
                conn = None
                while cnt >= 0:
                    cnt -= 1
                    name = random.choice(pool_names)
                    # One in four iterations removes the pool.
                    if random.randint(1, 4) == 4:
                        group.remove(name)
                        continue
                    group.create(name=name, uri=name, capacity=capacity)
                    pool = group.get(name=name)
                    assert pool is not None
                    # Kept bound until the next fetch, so the previous
                    # connection object stays referenced in the meantime.
                    conn = pool.fetch(timeout=0.01)

            threads = [threading.Thread(target=worker,
                                        args=(tg, count, capacity))
                       for _ in range(tnum)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

        chaos_mp_fetch(4, 40, 8)

        def check_mp_fetch(capacity=-1):
            """Fetch concurrently from one pool and count the successes.

            For ``capacity < 0``, 10 threads run and all are expected to get a
            connection; otherwise ``2 * capacity`` threads race and exactly
            ``capacity`` are expected to succeed within the 0.1s timeout.
            """
            w2 = ConnectionPool(name='w2', uri='127.0.0.1:19530',
                                max_retry=2, capacity=capacity)
            connections = []

            def get_connection(pool):
                conn = pool.fetch(timeout=0.1)
                if conn:
                    connections.append(conn)

            threads_num = 10 if capacity < 0 else 2 * capacity
            threads = [threading.Thread(target=get_connection, args=(w2,))
                       for _ in range(threads_num)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            expected_size = threads_num if capacity < 0 else capacity
            assert len(connections) == expected_size

        check_mp_fetch(5)
        check_mp_fetch()

        # A capacity-2 pool: a third fetch yields None, release() frees a slot
        # without shrinking the pool.
        w1 = ConnectionPool(name='w1', uri='127.0.0.1:19530',
                            max_retry=2, capacity=2)
        w1_1 = w1.fetch()
        assert len(w1) == 1
        assert w1.active_num == 1
        w1_2 = w1.fetch()
        assert len(w1) == 2
        assert w1.active_num == 2
        w1_3 = w1.fetch()
        assert w1_3 is None
        assert len(w1) == 2
        assert w1.active_num == 2

        w1_1.release()
        assert len(w1) == 2
        assert w1.active_num == 1

        def check(pool, expected_size, expected_active_num):
            # ``w`` is held only for the duration of this call; the asserts
            # after check() returns expect the active count to drop back,
            # i.e. dropping the last reference returns the connection.
            w = pool.fetch()
            assert len(pool) == expected_size
            assert pool.active_num == expected_active_num

        check(w1, 2, 2)

        assert len(w1) == 2
        assert w1.active_num == 1

        # A connection created outside the pool cannot be released into it.
        wild_w = w1.create()
        with pytest.raises(RuntimeError):
            w1.release(wild_w)

        ret = w1_2.can_retry
        assert ret == w1_2.connection.can_retry