mirror of https://github.com/milvus-io/milvus.git
parent
6ad2252f2d
commit
28f636151c
|
@ -1,255 +0,0 @@
|
|||
import socket
|
||||
import pytest
|
||||
|
||||
from .utils import *
|
||||
|
||||
timeout = 60
|
||||
dimension = 128
|
||||
delete_timeout = 60
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--ip", action="store", default="localhost")
|
||||
parser.addoption("--service", action="store", default="")
|
||||
parser.addoption("--port", action="store", default=19530)
|
||||
parser.addoption("--http-port", action="store", default=19121)
|
||||
parser.addoption("--handler", action="store", default="GRPC")
|
||||
parser.addoption("--tag", action="store", default="all", help="only run tests matching the tag.")
|
||||
parser.addoption('--dry-run', action='store_true', default=False)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# register an additional marker
|
||||
config.addinivalue_line(
|
||||
"markers", "tag(name): mark test to run only matching the tag"
|
||||
)
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
tags = list()
|
||||
for marker in item.iter_markers(name="tag"):
|
||||
for tag in marker.args:
|
||||
tags.append(tag)
|
||||
if tags:
|
||||
cmd_tag = item.config.getoption("--tag")
|
||||
if cmd_tag != "all" and cmd_tag not in tags:
|
||||
pytest.skip("test requires tag in {!r}".format(tags))
|
||||
|
||||
|
||||
def pytest_runtestloop(session):
|
||||
if session.config.getoption('--dry-run'):
|
||||
total_passed = 0
|
||||
total_skipped = 0
|
||||
test_file_to_items = {}
|
||||
for item in session.items:
|
||||
file_name, test_class, test_func = item.nodeid.split("::")
|
||||
if test_file_to_items.get(file_name) is not None:
|
||||
test_file_to_items[file_name].append(item)
|
||||
else:
|
||||
test_file_to_items[file_name] = [item]
|
||||
for k, items in test_file_to_items.items():
|
||||
skip_case = []
|
||||
should_pass_but_skipped = []
|
||||
skipped_other_reason = []
|
||||
|
||||
level2_case = []
|
||||
for item in items:
|
||||
if "pytestmark" in item.keywords.keys():
|
||||
markers = item.keywords["pytestmark"]
|
||||
skip_case.extend([item.nodeid for marker in markers if marker.name == 'skip'])
|
||||
should_pass_but_skipped.extend([item.nodeid for marker in markers if marker.name == 'skip' and len(marker.args) > 0 and marker.args[0] == "should pass"])
|
||||
skipped_other_reason.extend([item.nodeid for marker in markers if marker.name == 'skip' and (len(marker.args) < 1 or marker.args[0] != "should pass")])
|
||||
level2_case.extend([item.nodeid for marker in markers if marker.name == 'level' and marker.args[0] == 2])
|
||||
|
||||
print("")
|
||||
print(f"[{k}]:")
|
||||
print(f" Total : {len(items):13}")
|
||||
print(f" Passed : {len(items) - len(skip_case):13}")
|
||||
print(f" Skipped : {len(skip_case):13}")
|
||||
print(f" - should pass: {len(should_pass_but_skipped):4}")
|
||||
print(f" - not supported: {len(skipped_other_reason):4}")
|
||||
print(f" Level2 : {len(level2_case):13}")
|
||||
|
||||
print(f" ---------------------------------------")
|
||||
print(f" should pass but skipped: ")
|
||||
print("")
|
||||
for nodeid in should_pass_but_skipped:
|
||||
name, test_class, test_func = nodeid.split("::")
|
||||
print(f" {name:8}: {test_class}.{test_func}")
|
||||
print("")
|
||||
print(f"===============================================")
|
||||
total_passed += len(items) - len(skip_case)
|
||||
total_skipped += len(skip_case)
|
||||
|
||||
print("Total tests : ", len(session.items))
|
||||
print("Total passed: ", total_passed)
|
||||
print("Total skiped: ", total_skipped)
|
||||
return True
|
||||
|
||||
|
||||
def check_server_connection(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
|
||||
connected = True
|
||||
if ip and (ip not in ['localhost', '127.0.0.1']):
|
||||
try:
|
||||
socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
|
||||
except Exception as e:
|
||||
print("Socket connnet failed: %s" % str(e))
|
||||
connected = False
|
||||
return connected
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def connect(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
service_name = request.config.getoption("--service")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
try:
|
||||
milvus = get_milvus(host=ip, port=port, handler=handler)
|
||||
# reset_build_index_threshold(milvus)
|
||||
except Exception as e:
|
||||
logging.getLogger().error(str(e))
|
||||
pytest.exit("Milvus server can not connected, exit pytest ...")
|
||||
def fin():
|
||||
try:
|
||||
milvus.close()
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger().info(str(e))
|
||||
request.addfinalizer(fin)
|
||||
return milvus
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dis_connect(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
service_name = request.config.getoption("--service")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
milvus = get_milvus(host=ip, port=port, handler=handler)
|
||||
milvus.close()
|
||||
return milvus
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def args(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
service_name = request.config.getoption("--service")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
args = {"ip": ip, "port": port, "handler": handler, "service_name": service_name}
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def milvus(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
return get_milvus(host=ip, port=port, handler=handler)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
default_fields = gen_default_fields()
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
connect.load_collection(collection_name)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def collection_without_loading(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
default_fields = gen_default_fields()
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
|
||||
# customised id
|
||||
@pytest.fixture(scope="function")
|
||||
def id_collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
fields = gen_default_fields(auto_id=False)
|
||||
connect.create_collection(collection_name, fields)
|
||||
connect.load_collection(collection_name)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def binary_collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
fields = gen_binary_default_fields()
|
||||
connect.create_collection(collection_name, fields)
|
||||
connect.load_collection(collection_name)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
collection_names = connect.list_collections()
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
|
||||
# customised id
|
||||
@pytest.fixture(scope="function")
|
||||
def binary_id_collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
fields = gen_binary_default_fields(auto_id=False)
|
||||
connect.create_collection(collection_name, fields)
|
||||
connect.load_collection(collection_name)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
|
@ -1,22 +0,0 @@
|
|||
from . import utils
|
||||
|
||||
default_fields = utils.gen_default_fields()
|
||||
default_binary_fields = utils.gen_binary_default_fields()
|
||||
|
||||
default_entity = utils.gen_entities(1)
|
||||
default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)
|
||||
|
||||
default_entity_row = utils.gen_entities_rows(1)
|
||||
default_raw_binary_vector_row, default_binary_entity_row = utils.gen_binary_entities_rows(1)
|
||||
|
||||
|
||||
default_entities = utils.gen_entities(utils.default_nb)
|
||||
default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)
|
||||
|
||||
|
||||
default_entities_new = utils.gen_entities_new(utils.default_nb)
|
||||
default_raw_binary_vectors_new, default_binary_entities_new = utils.gen_binary_entities_new(utils.default_nb)
|
||||
|
||||
|
||||
default_entities_rows = utils.gen_entities_rows(utils.default_nb)
|
||||
default_raw_binary_vectors_rows, default_binary_entities_rows = utils.gen_binary_entities_rows(utils.default_nb)
|
|
@ -1,127 +0,0 @@
|
|||
# STL imports
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import datetime
|
||||
import random
|
||||
import struct
|
||||
import sys
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
sys.path.append('..')
|
||||
# Third party imports
|
||||
import numpy as np
|
||||
import faker
|
||||
from faker.providers import BaseProvider
|
||||
|
||||
# local application imports
|
||||
from milvus.client.types import IndexType, MetricType, DataType
|
||||
|
||||
# grpc
|
||||
from milvus.client.grpc_handler import Prepare as gPrepare
|
||||
from milvus.grpc_gen import milvus_pb2
|
||||
|
||||
|
||||
def gen_vectors(num, dim):
|
||||
return [[random.random() for _ in range(dim)] for _ in range(num)]
|
||||
|
||||
|
||||
def gen_single_vector(dim):
|
||||
return [[random.random() for _ in range(dim)]]
|
||||
|
||||
|
||||
def gen_vector(nb, d, seed=np.random.RandomState(1234)):
|
||||
xb = seed.rand(nb, d).astype("float32")
|
||||
return xb.tolist()
|
||||
|
||||
|
||||
def gen_unique_str(str=None):
|
||||
prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
|
||||
return prefix if str is None else str + "_" + prefix
|
||||
|
||||
|
||||
def get_current_day():
|
||||
return time.strftime('%Y-%m-%d', time.localtime())
|
||||
|
||||
|
||||
def get_last_day(day):
|
||||
tmp = datetime.datetime.now() - datetime.timedelta(days=day)
|
||||
return tmp.strftime('%Y-%m-%d')
|
||||
|
||||
|
||||
def get_next_day(day):
|
||||
tmp = datetime.datetime.now() + datetime.timedelta(days=day)
|
||||
return tmp.strftime('%Y-%m-%d')
|
||||
|
||||
|
||||
def gen_long_str(num):
|
||||
string = ''
|
||||
for _ in range(num):
|
||||
char = random.choice('tomorrow')
|
||||
string += char
|
||||
|
||||
|
||||
def gen_one_binary(topk):
|
||||
ids = [random.randrange(10000000, 99999999) for _ in range(topk)]
|
||||
distances = [random.random() for _ in range(topk)]
|
||||
return milvus_pb2.TopKQueryResult(struct.pack(str(topk) + 'l', *ids), struct.pack(str(topk) + 'd', *distances))
|
||||
|
||||
|
||||
def gen_nq_binaries(nq, topk):
|
||||
return [gen_one_binary(topk) for _ in range(nq)]
|
||||
|
||||
|
||||
def fake_query_bin_result(nq, topk):
|
||||
return gen_nq_binaries(nq, topk)
|
||||
|
||||
|
||||
class FakerProvider(BaseProvider):
|
||||
|
||||
def collection_name(self):
|
||||
return 'collection_names' + str(uuid.uuid4()).replace('-', '_')
|
||||
|
||||
def normal_field_name(self):
|
||||
return 'normal_field_names' + str(uuid.uuid4()).replace('-', '_')
|
||||
|
||||
def vector_field_name(self):
|
||||
return 'vector_field_names' + str(uuid.uuid4()).replace('-', '_')
|
||||
|
||||
def name(self):
|
||||
return 'name' + str(random.randint(1000, 9999))
|
||||
|
||||
def dim(self):
|
||||
return random.randint(0, 999)
|
||||
|
||||
|
||||
fake = faker.Faker()
|
||||
fake.add_provider(FakerProvider)
|
||||
|
||||
def collection_name_factory():
|
||||
return fake.collection_name()
|
||||
|
||||
def collection_schema_factory():
|
||||
param = {
|
||||
"fields": [
|
||||
{"name": fake.normal_field_name(),"type": DataType.INT32},
|
||||
{"name": fake.vector_field_name(),"type": DataType.FLOAT_VECTOR, "params": {"dim": random.randint(1, 999)}},
|
||||
],
|
||||
"auto_id": True,
|
||||
}
|
||||
return param
|
||||
|
||||
|
||||
def records_factory(dimension, nq):
|
||||
return [[random.random() for _ in range(dimension)] for _ in range(nq)]
|
||||
|
||||
|
||||
def time_it(func):
|
||||
@wraps(func)
|
||||
def inner(*args, **kwrgs):
|
||||
pref = time.perf_counter()
|
||||
result = func(*args, **kwrgs)
|
||||
delt = time.perf_counter() - pref
|
||||
print(f"[{func.__name__}][{delt:.4}s]")
|
||||
return result
|
||||
|
||||
return inner
|
|
@ -1,19 +0,0 @@
|
|||
[pytest]
|
||||
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
|
||||
log_date_format = %Y-%m-%d %H:%M:%S
|
||||
|
||||
# cli arguments. `-x`-stop test when error occurred;
|
||||
addopts = -x
|
||||
|
||||
testpaths = .
|
||||
|
||||
log_cli = true
|
||||
log_level = 10
|
||||
|
||||
timeout = 360
|
||||
|
||||
markers =
|
||||
level: test level
|
||||
serial
|
||||
|
||||
; level = 1
|
|
@ -1,14 +0,0 @@
|
|||
grpcio==1.26.0
|
||||
grpcio-tools==1.26.0
|
||||
numpy==1.18.1
|
||||
pytest-cov==2.8.1
|
||||
pymilvus-distributed==0.0.35
|
||||
sklearn==0.0
|
||||
pytest==4.5.0
|
||||
pytest-timeout==1.3.3
|
||||
pytest-repeat==0.8.0
|
||||
allure-pytest==2.7.0
|
||||
pytest-print==0.1.2
|
||||
pytest-level==0.1.1
|
||||
pytest-xdist==1.23.2
|
||||
git+https://gitee.com/quicksilver/pytest-tags.git
|
|
@ -1,314 +0,0 @@
|
|||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "create_collection"
|
||||
|
||||
class TestCreateCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_filter_fields()
|
||||
)
|
||||
def get_filter_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_vector_fields()
|
||||
)
|
||||
def get_vector_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_segment_row_limits()
|
||||
)
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff fields: metric/field_type/...
|
||||
expected: no exception raised
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
# logging.getLogger().info(filter_field)
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
}
|
||||
# logging.getLogger().info(fields)
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff fields: metric/field_type/...
|
||||
expected: no exception raised
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.skip("no segment_row_limit")
|
||||
def test_create_collection_segment_row_limit(self, connect):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
# fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.skip("no flush")
|
||||
def _test_create_collection_auto_flush_disabled(self, connect):
|
||||
'''
|
||||
target: test create normal collection, with large auto_flush_interval
|
||||
method: create collection with corrent params
|
||||
expected: create status return ok
|
||||
'''
|
||||
disable_flush(connect)
|
||||
collection_name = gen_unique_str(uid)
|
||||
try:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
finally:
|
||||
enable_flush(connect)
|
||||
|
||||
def test_create_collection_after_insert(self, connect, collection):
|
||||
'''
|
||||
target: test insert vector, then create collection again
|
||||
method: insert vector and create collection
|
||||
expected: error raised
|
||||
'''
|
||||
# pdb.set_trace()
|
||||
connect.insert(collection, default_entity)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
def test_create_collection_after_insert_flush(self, connect, collection):
|
||||
'''
|
||||
target: test insert vector, then create collection again
|
||||
method: insert vector and create collection
|
||||
expected: error raised
|
||||
'''
|
||||
connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
def test_create_collection_without_connection(self, dis_connect):
|
||||
'''
|
||||
target: test create collection, without connection
|
||||
method: create collection with correct params, with a disconnected instance
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_collection_existed(self, connect):
|
||||
'''
|
||||
target: test create collection but the collection name have already existed
|
||||
method: create collection with the same collection_name
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_after_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: create with the same collection name after collection dropped
|
||||
method: delete, then create
|
||||
expected: create success
|
||||
'''
|
||||
connect.drop_collection(collection)
|
||||
time.sleep(2)
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 8
|
||||
threads = []
|
||||
collection_names = []
|
||||
|
||||
def create():
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=create, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
for item in collection_names:
|
||||
assert item in connect.list_collections()
|
||||
connect.drop_collection(item)
|
||||
|
||||
|
||||
class TestCreateCollectionInvalid(object):
|
||||
"""
|
||||
Test creating collections with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_metric_types()
|
||||
)
|
||||
def get_metric_type(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ints()
|
||||
)
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ints()
|
||||
)
|
||||
def get_dim(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_invalid_string(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_field_types()
|
||||
)
|
||||
def get_field_type(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.skip("no segment row limit")
|
||||
def test_create_collection_with_invalid_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
collection_name = gen_unique_str()
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_invalid_dimension(self, connect, get_dim):
|
||||
dimension = get_dim
|
||||
collection_name = gen_unique_str()
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["fields"][-1]["params"]["dim"] = dimension
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string):
|
||||
collection_name = get_invalid_string
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_collection_None(self, connect):
|
||||
'''
|
||||
target: test create collection but the collection name is None
|
||||
method: create collection, param collection_name is None
|
||||
expected: create raise error
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(None, default_fields)
|
||||
|
||||
def test_create_collection_no_dimension(self, connect):
|
||||
'''
|
||||
target: test create collection with no dimension params
|
||||
method: create collection with corrent params
|
||||
expected: create status return ok
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["fields"][-1]["params"].pop("dim")
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.skip("no segment row limit")
|
||||
def test_create_collection_no_segment_row_limit(self, connect):
|
||||
'''
|
||||
target: test create collection with no segment_row_limit params
|
||||
method: create collection with correct params
|
||||
expected: use default default_segment_row_limit
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields.pop("segment_row_limit")
|
||||
connect.create_collection(collection_name, fields)
|
||||
res = connect.get_collection_info(collection_name)
|
||||
# logging.getLogger().info(res)
|
||||
assert res["segment_row_limit"] == default_server_segment_row_limit
|
||||
|
||||
def test_create_collection_limit_fields(self, connect):
|
||||
collection_name = gen_unique_str(uid)
|
||||
limit_num = 64
|
||||
fields = copy.deepcopy(default_fields)
|
||||
for i in range(limit_num):
|
||||
field_name = gen_unique_str("field_name")
|
||||
field = {"name": field_name, "type": DataType.INT64}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
field_name = get_invalid_string
|
||||
field = {"name": field_name, "type": DataType.INT64}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
def test_create_collection_invalid_field_type(self, connect, get_field_type):
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
field_type = get_field_type
|
||||
field = {"name": "test_field", "type": field_type}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
|
@ -1,32 +0,0 @@
|
|||
import copy
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "describe_collection"
|
||||
|
||||
|
||||
class TestDescribeCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `describe_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_describe_collection(self, connect):
|
||||
'''
|
||||
target: test describe collection
|
||||
method: create collection then describe the same collection
|
||||
expected: returned value is the same
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
df = copy.deepcopy(default_fields)
|
||||
df["fields"].append({"name": "int32", "type": DataType.INT32})
|
||||
|
||||
connect.create_collection(collection_name, df)
|
||||
info = connect.describe_collection(collection_name)
|
||||
assert info.get("collection_name") == collection_name
|
||||
assert len(info.get("fields")) == 4
|
||||
|
||||
for field in info.get("fields"):
|
||||
assert field.get("name") in ["int32", "int64", "float", "float_vector"]
|
||||
if field.get("name") == "float_vector":
|
||||
assert field.get("params").get("dim") == str(default_dim)
|
|
@ -1,98 +0,0 @@
|
|||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uniq_id = "drop_collection"
|
||||
|
||||
class TestDropCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: test delete collection created with correct params
|
||||
method: create collection and then delete,
|
||||
assert the value returned by delete method
|
||||
expected: status ok, and no collection in collections
|
||||
'''
|
||||
connect.drop_collection(collection)
|
||||
time.sleep(2)
|
||||
assert not connect.has_collection(collection)
|
||||
|
||||
def test_drop_collection_without_connection(self, collection, dis_connect):
|
||||
'''
|
||||
target: test describe collection, without connection
|
||||
method: drop collection with correct params, with a disconnected instance
|
||||
expected: drop raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.drop_collection(collection)
|
||||
|
||||
def test_drop_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: random a collection name, which not existed in db,
|
||||
assert the exception raised returned by drp_collection method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(uniq_id)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_drop_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create and drop collection with multithread
|
||||
method: create and drop collection using multithread,
|
||||
expected: collections are created, and dropped
|
||||
'''
|
||||
threads_num = 8
|
||||
threads = []
|
||||
collection_names = []
|
||||
|
||||
def create():
|
||||
collection_name = gen_unique_str(uniq_id)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
connect.drop_collection(collection_name)
|
||||
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=create, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
for item in collection_names:
|
||||
assert not connect.has_collection(item)
|
||||
|
||||
|
||||
class TestDropCollectionInvalid(object):
|
||||
"""
|
||||
Test has collection with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_collection_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
def test_drop_collection_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
def test_drop_collection_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
|
@ -1,233 +0,0 @@
|
|||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "collection_info"
|
||||
|
||||
class TestInfoBase:
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_filter_fields()
|
||||
)
|
||||
def get_filter_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_vector_fields()
|
||||
)
|
||||
def get_vector_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_segment_row_limits()
|
||||
)
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
logging.getLogger().info(request.param)
|
||||
# if str(connect._cmd("mode")) == "CPU":
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
return request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `get_collection_info` function, no data in collection
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("no segment row limit and type")
|
||||
def test_info_collection_fields(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields, check info returned
|
||||
method: create collection with diff fields: metric/field_type/..., calling `get_collection_info`
|
||||
expected: no exception raised, and value returned correct
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == default_segment_row_limit
|
||||
assert len(res["fields"]) == 2
|
||||
for field in res["fields"]:
|
||||
if field["type"] == filter_field:
|
||||
assert field["name"] == filter_field["name"]
|
||||
elif field["type"] == vector_field:
|
||||
assert field["name"] == vector_field["name"]
|
||||
assert field["params"] == vector_field["params"]
|
||||
|
||||
@pytest.mark.skip("no segment row limit and type")
|
||||
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
# assert segment row count
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['segment_row_limit'] == get_segment_row_limit
|
||||
|
||||
@pytest.mark.skip("no create Index")
|
||||
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
info = connect.describe_index(collection, field_name)
|
||||
assert info == get_simple_index
|
||||
res = connect.get_collection_info(collection, default_float_vec_field_name)
|
||||
assert index["index_type"] == get_simple_index["index_type"]
|
||||
assert index["metric_type"] == get_simple_index["metric_type"]
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_without_connection(self, connect, collection, dis_connect):
|
||||
'''
|
||||
target: test get collection info, without connection
|
||||
method: calling get collection info with correct params, with a disconnected instance
|
||||
expected: get collection info raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
assert connect.get_collection_info(dis_connect, collection)
|
||||
|
||||
def test_get_collection_info_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: random a collection name, which not existed in db,
|
||||
assert the value returned by get_collection_info method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.get_collection_info(connect, collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def get_info():
|
||||
res = connect.get_collection_info(connect, collection_name)
|
||||
# assert
|
||||
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=get_info, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `get_collection_info` function, and insert data in collection
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("no segment row limit and type")
|
||||
def test_info_collection_fields_after_insert(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields, check info returned
|
||||
method: create collection with diff fields: metric/field_type/..., calling `get_collection_info`
|
||||
expected: no exception raised, and value returned correct
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
|
||||
res_ids = connect.insert(collection_name, entities)
|
||||
connect.flush([collection_name])
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == default_segment_row_limit
|
||||
assert len(res["fields"]) == 2
|
||||
for field in res["fields"]:
|
||||
if field["type"] == filter_field:
|
||||
assert field["name"] == filter_field["name"]
|
||||
elif field["type"] == vector_field:
|
||||
assert field["name"] == vector_field["name"]
|
||||
assert field["params"] == vector_field["params"]
|
||||
|
||||
@pytest.mark.skip("not segment row limit")
|
||||
def test_create_collection_segment_row_limit_after_insert(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"])
|
||||
res_ids = connect.insert(collection_name, entities)
|
||||
connect.flush([collection_name])
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == get_segment_row_limit
|
||||
|
||||
|
||||
class TestInfoInvalid(object):
|
||||
"""
|
||||
Test get collection info with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(collection_name)
|
||||
|
||||
def test_get_collection_info_None(self, connect):
|
||||
'''
|
||||
target: test create collection but the collection name is None
|
||||
method: create collection, param collection_name is None
|
||||
expected: create raise error
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(None)
|
|
@ -1,93 +0,0 @@
|
|||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "has_collection"
|
||||
|
||||
class TestHasCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `has_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_has_collection(self, connect, collection):
|
||||
'''
|
||||
target: test if the created collection existed
|
||||
method: create collection, assert the value returned by has_collection method
|
||||
expected: True
|
||||
'''
|
||||
assert connect.has_collection(collection)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_without_connection(self, collection, dis_connect):
|
||||
'''
|
||||
target: test has collection, without connection
|
||||
method: calling has collection with correct params, with a disconnected instance
|
||||
expected: has collection raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
assert dis_connect.has_collection(collection)
|
||||
|
||||
def test_has_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: random a collection name, which not existed in db,
|
||||
assert the value returned by has_collection method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str("test_collection")
|
||||
assert not connect.has_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def has():
|
||||
assert connect.has_collection(collection_name)
|
||||
# assert not assert_collection(connect, collection_name)
|
||||
for i in range(threads_num):
|
||||
t = MilvusTestThread(target=has, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
class TestHasCollectionInvalid(object):
|
||||
"""
|
||||
Test has collection with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
|
@ -1,841 +0,0 @@
|
|||
import logging
|
||||
import time
|
||||
import pdb
|
||||
import threading
|
||||
from multiprocessing import Pool, Process
|
||||
import numpy
|
||||
import pytest
|
||||
import sklearn.preprocessing
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "test_index"
|
||||
BUILD_TIMEOUT = 300
|
||||
field_name = default_float_vec_field_name
|
||||
binary_field_name = default_binary_vec_field_name
|
||||
query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1)
|
||||
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
|
||||
|
||||
|
||||
# @pytest.mark.skip("wait for debugging...")
|
||||
class TestIndexBase:
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
import copy
|
||||
logging.getLogger().info(request.param)
|
||||
#if str(connect._cmd("mode")) == "CPU":
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
return copy.deepcopy(request.param)
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
10,
|
||||
1111
|
||||
],
|
||||
)
|
||||
def get_nq(self, request):
|
||||
yield request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
def test_create_index_on_field_not_existed(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index on field not existed
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_field_name = gen_unique_str()
|
||||
ids = connect.insert(collection, default_entities)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection, tmp_field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_on_field(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index on other field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_field_name = "int64"
|
||||
ids = connect.insert(collection, default_entities)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection, tmp_field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_no_vectors(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_partition(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_partition_flush(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush()
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
def test_create_index_without_connect(self, dis_connect, collection):
|
||||
'''
|
||||
target: test create index without connection
|
||||
method: create collection and add entities in it, check if added successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_search_with_query_vectors(self, connect, collection, get_simple_index, get_nq):
|
||||
'''
|
||||
target: test create index interface, search with more query vectors
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
# logging.getLogger().info(connect.get_collection_stats(collection))
|
||||
nq = get_nq
|
||||
index_type = get_simple_index["index_type"]
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param)
|
||||
res = connect.search(collection, query)
|
||||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.skip("can't_pass_ci")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_multithread(self, connect, collection, args):
|
||||
'''
|
||||
target: test create index interface with multiprocess
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.insert(collection, default_entities)
|
||||
|
||||
def build(connect):
|
||||
connect.create_index(collection, field_name, default_index)
|
||||
|
||||
threads_num = 8
|
||||
threads = []
|
||||
for i in range(threads_num):
|
||||
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
|
||||
t = MilvusTestThread(target=build, args=(m,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
def test_create_index_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test create index interface when collection name not existed
|
||||
method: create collection and add entities in it, create index
|
||||
, make sure the collection name not in index
|
||||
expected: create index failed
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection_name, field_name, default_index)
|
||||
|
||||
@pytest.mark.skip("count_entries")
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_insert_flush(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index
|
||||
method: create collection and create index, add entities in it
|
||||
expected: create index ok, and count correct
|
||||
'''
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
count = connect.count_entities(collection)
|
||||
assert count == default_nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_same_index_repeatedly(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: check if index can be created repeatedly, with the same create_index params
|
||||
method: create index after index have been built
|
||||
expected: return code success, and search ok
|
||||
'''
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_different_index_repeatedly(self, connect, collection):
|
||||
'''
|
||||
target: check if index can be created repeatedly, with the different create_index params
|
||||
method: create another index with different index_params after index have been built
|
||||
expected: return code 0, and describe index result equals with the second index params
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}]
|
||||
for index in indexs:
|
||||
connect.create_index(collection, field_name, index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
|
||||
assert stats["row_count"] == str(default_nb)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_no_vectors_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_partition_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_partition_flush_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush()
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_search_with_query_vectors_ip(self, connect, collection, get_simple_index, get_nq):
|
||||
'''
|
||||
target: test create index interface, search with more query vectors
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
metric_type = "IP"
|
||||
ids = connect.insert(collection, default_entities)
|
||||
get_simple_index["metric_type"] = metric_type
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
# logging.getLogger().info(connect.get_collection_stats(collection))
|
||||
nq = get_nq
|
||||
index_type = get_simple_index["index_type"]
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param)
|
||||
res = connect.search(collection, query)
|
||||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.skip("test_create_index_multithread_ip")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_multithread_ip(self, connect, collection, args):
|
||||
'''
|
||||
target: test create index interface with multiprocess
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.insert(collection, default_entities)
|
||||
|
||||
def build(connect):
|
||||
default_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, default_index)
|
||||
|
||||
threads_num = 8
|
||||
threads = []
|
||||
for i in range(threads_num):
|
||||
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
|
||||
t = MilvusTestThread(target=build, args=(m,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
def test_create_index_collection_not_existed_ip(self, connect, collection):
|
||||
'''
|
||||
target: test create index interface when collection name not existed
|
||||
method: create collection and add entities in it, create index
|
||||
, make sure the collection name not in index
|
||||
expected: return code not equals to 0, create index failed
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
default_index["metric_type"] = "IP"
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection_name, field_name, default_index)
|
||||
|
||||
@pytest.mark.skip("count_entries")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_no_vectors_insert_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface when there is no vectors in collection, and does not affect the subsequent process
|
||||
method: create collection and add no vectors in it, and then create index, add entities in it
|
||||
expected: return code equals to 0
|
||||
'''
|
||||
default_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
count = connect.count_entities(collection)
|
||||
assert count == default_nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_same_index_repeatedly_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: check if index can be created repeatedly, with the same create_index params
|
||||
method: create index after index have been built
|
||||
expected: return code success, and search ok
|
||||
'''
|
||||
default_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
# TODO:
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_different_index_repeatedly_ip(self, connect, collection):
|
||||
'''
|
||||
target: check if index can be created repeatedly, with the different create_index params
|
||||
method: create another index with different index_params after index have been built
|
||||
expected: return code 0, and describe index result equals with the second index params
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}]
|
||||
for index in indexs:
|
||||
connect.create_index(collection, field_name, index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
|
||||
assert stats["row_count"] == str(default_nb)
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("get_collection_stats")
|
||||
def test_drop_index(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test drop index interface
|
||||
method: create collection and add entities in it, create index, call drop index
|
||||
expected: return code 0, and default index param
|
||||
'''
|
||||
# ids = connect.insert(collection, entities)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.drop_index(collection, field_name)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
|
||||
assert not stats["partitions"][0]["segments"]
|
||||
|
||||
@pytest.mark.skip("get_collection_stats")
|
||||
@pytest.mark.skip("drop_index raise exception")
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_index_repeatly(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test drop index repeatly
|
||||
method: create index, call drop index, and drop again
|
||||
expected: return code 0
|
||||
'''
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
connect.drop_index(collection, field_name)
|
||||
connect.drop_index(collection, field_name)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(stats)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
|
||||
assert not stats["partitions"][0]["segments"]
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_index_without_connect(self, dis_connect, collection):
|
||||
'''
|
||||
target: test drop index without connection
|
||||
method: drop index, and check if drop successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.drop_index(collection, field_name)
|
||||
|
||||
def test_drop_index_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test drop index interface when collection name not existed
|
||||
method: create collection and add entities in it, create index
|
||||
, make sure the collection name not in index, and then drop it
|
||||
expected: return code not equals to 0, drop index failed
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_index(collection_name, field_name)
|
||||
|
||||
def test_drop_index_collection_not_create(self, connect, collection):
|
||||
'''
|
||||
target: test drop index interface when index not created
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return code not equals to 0, drop index failed
|
||||
'''
|
||||
# ids = connect.insert(collection, entities)
|
||||
# no create index
|
||||
connect.drop_index(collection, field_name)
|
||||
|
||||
@pytest.mark.skip("drop_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_create_drop_index_repeatly(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create / drop index repeatly, use the same index params
|
||||
method: create index, drop index, four times
|
||||
expected: return code 0
|
||||
'''
|
||||
for i in range(4):
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.drop_index(collection, field_name)
|
||||
|
||||
@pytest.mark.skip("get_collection_stats")
|
||||
def test_drop_index_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test drop index interface
|
||||
method: create collection and add entities in it, create index, call drop index
|
||||
expected: return code 0, and default index param
|
||||
'''
|
||||
# ids = connect.insert(collection, entities)
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.drop_index(collection, field_name)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
|
||||
assert not stats["partitions"][0]["segments"]
|
||||
|
||||
@pytest.mark.skip("get_collection_stats")
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_index_repeatly_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test drop index repeatly
|
||||
method: create index, call drop index, and drop again
|
||||
expected: return code 0
|
||||
'''
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
connect.drop_index(collection, field_name)
|
||||
connect.drop_index(collection, field_name)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(stats)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
|
||||
assert not stats["partitions"][0]["segments"]
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_index_without_connect_ip(self, dis_connect, collection):
|
||||
'''
|
||||
target: test drop index without connection
|
||||
method: drop index, and check if drop successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.drop_index(collection, field_name)
|
||||
|
||||
def test_drop_index_collection_not_create_ip(self, connect, collection):
|
||||
'''
|
||||
target: test drop index interface when index not created
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return code not equals to 0, drop index failed
|
||||
'''
|
||||
# ids = connect.insert(collection, entities)
|
||||
# no create index
|
||||
connect.drop_index(collection, field_name)
|
||||
|
||||
@pytest.mark.skip("drop_index")
|
||||
@pytest.mark.skip("can't create and drop")
|
||||
@pytest.mark.level(2)
|
||||
def test_create_drop_index_repeatly_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create / drop index repeatly, use the same index params
|
||||
method: create index, drop index, four times
|
||||
expected: return code 0
|
||||
'''
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
for i in range(4):
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.drop_index(collection, field_name)
|
||||
|
||||
|
||||
class TestIndexBinary:
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
# TODO: Determine the service mode
|
||||
# if str(connect._cmd("mode")) == "CPU":
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_binary_index()
|
||||
)
|
||||
def get_jaccard_index(self, request, connect):
|
||||
if request.param["index_type"] in binary_support():
|
||||
request.param["metric_type"] = "JACCARD"
|
||||
return request.param
|
||||
else:
|
||||
pytest.skip("Skip index")
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_binary_index()
|
||||
)
|
||||
def get_l2_index(self, request, connect):
|
||||
request.param["metric_type"] = "L2"
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
10,
|
||||
1111
|
||||
],
|
||||
)
|
||||
def get_nq(self, request):
|
||||
yield request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_partition(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(binary_collection, default_tag)
|
||||
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
|
||||
'''
|
||||
target: test create index interface, search with more query vectors
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
nq = get_nq
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
|
||||
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
|
||||
logging.getLogger().info(search_param)
|
||||
res = connect.search(binary_collection, query, search_params=search_param)
|
||||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.skip("get status for build index failed")
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_invalid_metric_type_binary(self, connect, binary_collection, get_l2_index):
|
||||
'''
|
||||
target: test create index interface with invalid metric type
|
||||
method: add entitys into binary connection, flash, create index with L2 metric type.
|
||||
expected: return create_index failure
|
||||
'''
|
||||
# insert 6000 vectors
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
|
||||
if get_l2_index["index_type"] == "BIN_FLAT":
|
||||
res = connect.create_index(binary_collection, binary_field_name, get_l2_index)
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.create_index(binary_collection, binary_field_name, get_l2_index)
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `get_index_info` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("get_collection_stats does not impl")
|
||||
def test_get_index_info(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
target: test describe index interface
|
||||
method: create collection and add entities in it, create index, call describe index
|
||||
expected: return code 0, and index instructure
|
||||
'''
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
for partition in stats["partitions"]:
|
||||
segments = partition["segments"]
|
||||
if segments:
|
||||
for segment in segments:
|
||||
for file in segment["files"]:
|
||||
if "index_type" in file:
|
||||
assert file["index_type"] == get_jaccard_index["index_type"]
|
||||
|
||||
@pytest.mark.skip("get_collection_stats does not impl")
|
||||
def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
target: test describe index interface
|
||||
method: create collection, create partition and add entities in it, create index, call describe index
|
||||
expected: return code 0, and index instructure
|
||||
'''
|
||||
connect.create_partition(binary_collection, default_tag)
|
||||
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
logging.getLogger().info(stats)
|
||||
assert stats["row_count"] == default_nb
|
||||
assert len(stats["partitions"]) == 2
|
||||
for partition in stats["partitions"]:
|
||||
segments = partition["segments"]
|
||||
if segments:
|
||||
for segment in segments:
|
||||
for file in segment["files"]:
|
||||
if "index_type" in file:
|
||||
assert file["index_type"] == get_jaccard_index["index_type"]
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("get_collection_stats")
|
||||
def test_drop_index(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
target: test drop index interface
|
||||
method: create collection and add entities in it, create index, call drop index
|
||||
expected: return code 0, and default index param
|
||||
'''
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
logging.getLogger().info(stats)
|
||||
connect.drop_index(binary_collection, binary_field_name)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
|
||||
assert not stats["partitions"][0]["segments"]
|
||||
|
||||
@pytest.mark.skip("get_collection_stats does not impl")
|
||||
def test_drop_index_partition(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
target: test drop index interface
|
||||
method: create collection, create partition and add entities in it, create index on collection, call drop collection index
|
||||
expected: return code 0, and default index param
|
||||
'''
|
||||
connect.create_partition(binary_collection, default_tag)
|
||||
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
connect.drop_index(binary_collection, binary_field_name)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
for partition in stats["partitions"]:
|
||||
segments = partition["segments"]
|
||||
if segments:
|
||||
for segment in segments:
|
||||
for file in segment["files"]:
|
||||
if "index_type" not in file:
|
||||
continue
|
||||
if file["index_type"] == get_jaccard_index["index_type"]:
|
||||
assert False
|
||||
|
||||
|
||||
class TestIndexInvalid(object):
|
||||
"""
|
||||
Test create / describe / drop index interfaces with invalid collection names
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_create_index_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection_name, field_name, default_index)
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_drop_index_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_index(collection_name)
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_index()
|
||||
)
|
||||
def get_index(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_with_invalid_index_params(self, connect, collection, get_index):
|
||||
logging.getLogger().info(get_index)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
|
||||
class TestIndexAsync:
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def skip_http_check(self, args):
|
||||
if args["handler"] == "HTTP":
|
||||
pytest.skip("skip in http mode")
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
# TODO: Determine the service mode
|
||||
# if str(connect._cmd("mode")) == "CPU":
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
return request.param
|
||||
|
||||
def check_result(self, res):
|
||||
logging.getLogger().info("In callback check search result")
|
||||
logging.getLogger().info(res)
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
logging.getLogger().info("start index")
|
||||
future = connect.create_index(collection, field_name, get_simple_index, _async=True)
|
||||
logging.getLogger().info("before result")
|
||||
res = future.result()
|
||||
# TODO:
|
||||
logging.getLogger().info(res)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_drop(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
logging.getLogger().info("start index")
|
||||
future = connect.create_index(collection, field_name, get_simple_index, _async=True)
|
||||
logging.getLogger().info("DROP")
|
||||
connect.drop_collection(collection)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_with_invalid_collectionname(self, connect):
|
||||
collection_name = " "
|
||||
with pytest.raises(Exception) as e:
|
||||
future = connect.create_index(collection_name, field_name, default_index, _async=True)
|
||||
res = future.result()
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_callback(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
logging.getLogger().info("start index")
|
||||
future = connect.create_index(collection, field_name, get_simple_index, _async=True,
|
||||
_callback=self.check_result)
|
||||
logging.getLogger().info("before result")
|
||||
res = future.result()
|
||||
# TODO:
|
||||
logging.getLogger().info(res)
|
File diff suppressed because it is too large
Load Diff
|
@ -1,88 +0,0 @@
|
|||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "list_collections"
|
||||
|
||||
class TestListCollections:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `list_collections` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_list_collections(self, connect, collection):
|
||||
'''
|
||||
target: test list collections
|
||||
method: create collection, assert the value returned by list_collections method
|
||||
expected: True
|
||||
'''
|
||||
assert collection in connect.list_collections()
|
||||
|
||||
def test_list_collections_multi_collections(self, connect):
|
||||
'''
|
||||
target: test list collections
|
||||
method: create collection, assert the value returned by list_collections method
|
||||
expected: True
|
||||
'''
|
||||
collection_num = 50
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
assert collection_name in connect.list_collections()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_list_collections_without_connection(self, dis_connect):
|
||||
'''
|
||||
target: test list collections, without connection
|
||||
method: calling list collections with correct params, with a disconnected instance
|
||||
expected: list collections raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.list_collections()
|
||||
|
||||
def test_list_collections_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: random a collection name, which not existed in db,
|
||||
assert the value returned by list_collections method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
assert collection_name not in connect.list_collections()
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.skip("can't run in parallel")
|
||||
def test_list_collections_no_collection(self, connect):
|
||||
'''
|
||||
target: test show collections is correct or not, if no collection in db
|
||||
method: delete all collections,
|
||||
assert the value returned by list_collections method is equal to []
|
||||
expected: the status is ok, and the result is equal to []
|
||||
'''
|
||||
result = connect.list_collections()
|
||||
if result:
|
||||
for collection_name in result:
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_list_collections_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def _list():
|
||||
assert collection_name in connect.list_collections()
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=_list, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
|
@ -1,22 +0,0 @@
|
|||
from tests.utils import *
|
||||
from tests.constants import *
|
||||
|
||||
uniq_id = "load_collection"
|
||||
|
||||
class TestLoadCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `load_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_load_collection(self, connect, collection_without_loading):
|
||||
'''
|
||||
target: test load collection and wait for loading collection
|
||||
method: insert then flush, when flushed, try load collection
|
||||
expected: no errors
|
||||
'''
|
||||
collection = collection_without_loading
|
||||
ids = connect.insert(collection, default_entities)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
connect.load_collection(collection)
|
|
@ -1,26 +0,0 @@
|
|||
from tests.utils import *
|
||||
from tests.constants import *
|
||||
|
||||
uniq_id = "load_partitions"
|
||||
|
||||
class TestLoadPartitions:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `load_partitions` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_load_partitions(self, connect, collection):
|
||||
'''
|
||||
target: test load collection and wait for loading collection
|
||||
method: insert then flush, when flushed, try load collection
|
||||
expected: no errors
|
||||
'''
|
||||
partition_tag = "lvn9pq34u8rasjk"
|
||||
connect.create_partition(collection, partition_tag + "1")
|
||||
ids = connect.insert(collection, default_entities, partition_tag=partition_tag + "1")
|
||||
|
||||
connect.create_partition(collection, partition_tag + "2")
|
||||
ids = connect.insert(collection, default_entity, partition_tag=partition_tag + "2")
|
||||
|
||||
connect.flush([collection])
|
||||
connect.load_partitions(collection, [partition_tag + "2"])
|
|
@ -1,396 +0,0 @@
|
|||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
TIMEOUT = 120
|
||||
|
||||
class TestCreateBase:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_partition` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_create_partition(self, connect, collection):
|
||||
'''
|
||||
target: test create partition, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(600)
|
||||
@pytest.mark.skip
|
||||
def test_create_partition_limit(self, connect, collection, args):
|
||||
'''
|
||||
target: test create partitions, check status returned
|
||||
method: call function: create_partition for 4097 times
|
||||
expected: exception raised
|
||||
'''
|
||||
threads_num = 8
|
||||
threads = []
|
||||
if args["handler"] == "HTTP":
|
||||
pytest.skip("skip in http mode")
|
||||
|
||||
def create(connect, threads_num):
|
||||
for i in range(max_partition_num // threads_num):
|
||||
tag_tmp = gen_unique_str()
|
||||
connect.create_partition(collection, tag_tmp)
|
||||
|
||||
for i in range(threads_num):
|
||||
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
|
||||
t = threading.Thread(target=create, args=(m, threads_num, ))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
tag_tmp = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_partition(collection, tag_tmp)
|
||||
|
||||
def test_create_partition_repeat(self, connect, collection):
|
||||
'''
|
||||
target: test create partition, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_partition(collection, default_tag)
|
||||
|
||||
def test_create_partition_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test create partition, its owner collection name not existed in db, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status not ok
|
||||
'''
|
||||
collection_name = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_partition(collection_name, default_tag)
|
||||
|
||||
def test_create_partition_tag_name_None(self, connect, collection):
|
||||
'''
|
||||
target: test create partition, tag name set None, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok
|
||||
'''
|
||||
tag_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_partition(collection, tag_name)
|
||||
|
||||
def test_create_different_partition_tags(self, connect, collection):
|
||||
'''
|
||||
target: test create partition twice with different names
|
||||
method: call function: create_partition, and again
|
||||
expected: status ok
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
tag_name = gen_unique_str()
|
||||
connect.create_partition(collection, tag_name)
|
||||
tag_list = connect.list_partitions(collection)
|
||||
assert default_tag in tag_list
|
||||
assert tag_name in tag_list
|
||||
assert "_default" in tag_list
|
||||
|
||||
@pytest.mark.skip("not support custom id")
|
||||
def test_create_partition_insert_default(self, connect, id_collection):
|
||||
'''
|
||||
target: test create partition, and insert vectors, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok
|
||||
'''
|
||||
connect.create_partition(id_collection, default_tag)
|
||||
ids = [i for i in range(default_nb)]
|
||||
insert_ids = connect.insert(id_collection, default_entities, ids)
|
||||
assert len(insert_ids) == len(ids)
|
||||
|
||||
@pytest.mark.skip("not support custom id")
|
||||
def test_create_partition_insert_with_tag(self, connect, id_collection):
|
||||
'''
|
||||
target: test create partition, and insert vectors, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok
|
||||
'''
|
||||
connect.create_partition(id_collection, default_tag)
|
||||
ids = [i for i in range(default_nb)]
|
||||
insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
|
||||
assert len(insert_ids) == len(ids)
|
||||
|
||||
def test_create_partition_insert_with_tag_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test create partition, and insert vectors, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status not ok
|
||||
'''
|
||||
tag_new = "tag_new"
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = [i for i in range(default_nb)]
|
||||
with pytest.raises(Exception) as e:
|
||||
insert_ids = connect.insert(collection, default_entities, ids, partition_tag=tag_new)
|
||||
|
||||
@pytest.mark.skip("not support custom id")
|
||||
def test_create_partition_insert_same_tags(self, connect, id_collection):
|
||||
'''
|
||||
target: test create partition, and insert vectors, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok
|
||||
'''
|
||||
connect.create_partition(id_collection, default_tag)
|
||||
ids = [i for i in range(default_nb)]
|
||||
insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
|
||||
ids = [(i+default_nb) for i in range(default_nb)]
|
||||
new_insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
|
||||
connect.flush([id_collection])
|
||||
res = connect.count_entities(id_collection)
|
||||
assert res == default_nb * 2
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.skip("not support count entities")
|
||||
def test_create_partition_insert_same_tags_two_collections(self, connect, collection):
|
||||
'''
|
||||
target: test create two partitions, and insert vectors with the same tag to each collection, check status returned
|
||||
method: call function: create_partition
|
||||
expected: status ok, collection length is correct
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
collection_new = gen_unique_str()
|
||||
connect.create_collection(collection_new, default_fields)
|
||||
connect.create_partition(collection_new, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
ids = connect.insert(collection_new, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection, collection_new])
|
||||
res = connect.count_entities(collection)
|
||||
assert res == default_nb
|
||||
res = connect.count_entities(collection_new)
|
||||
assert res == default_nb
|
||||
|
||||
|
||||
class TestShowBase:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `list_partitions` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_list_partitions(self, connect, collection):
|
||||
'''
|
||||
target: test show partitions, check status and partitions returned
|
||||
method: create partition first, then call function: list_partitions
|
||||
expected: status ok, partition correct
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
res = connect.list_partitions(collection)
|
||||
assert default_tag in res
|
||||
|
||||
def test_list_partitions_no_partition(self, connect, collection):
|
||||
'''
|
||||
target: test show partitions with collection name, check status and partitions returned
|
||||
method: call function: list_partitions
|
||||
expected: status ok, partitions correct
|
||||
'''
|
||||
res = connect.list_partitions(collection)
|
||||
assert len(res) == 1
|
||||
|
||||
def test_show_multi_partitions(self, connect, collection):
|
||||
'''
|
||||
target: test show partitions, check status and partitions returned
|
||||
method: create partitions first, then call function: list_partitions
|
||||
expected: status ok, partitions correct
|
||||
'''
|
||||
tag_new = gen_unique_str()
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, tag_new)
|
||||
res = connect.list_partitions(collection)
|
||||
assert default_tag in res
|
||||
assert tag_new in res
|
||||
|
||||
|
||||
class TestHasBase:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `has_partition` function
|
||||
******************************************************************
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_tag_name(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_has_partition(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: create partition first, then call function: has_partition
|
||||
expected: status ok, result true
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
res = connect.has_partition(collection, default_tag)
|
||||
logging.getLogger().info(res)
|
||||
assert res
|
||||
|
||||
def test_has_partition_multi_partitions(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: create partition first, then call function: has_partition
|
||||
expected: status ok, result true
|
||||
'''
|
||||
for tag_name in [default_tag, "tag_new", "tag_new_new"]:
|
||||
connect.create_partition(collection, tag_name)
|
||||
for tag_name in [default_tag, "tag_new", "tag_new_new"]:
|
||||
res = connect.has_partition(collection, tag_name)
|
||||
assert res
|
||||
|
||||
def test_has_partition_tag_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: then call function: has_partition, with tag not existed
|
||||
expected: status ok, result empty
|
||||
'''
|
||||
res = connect.has_partition(collection, default_tag)
|
||||
logging.getLogger().info(res)
|
||||
assert not res
|
||||
|
||||
def test_has_partition_collection_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: then call function: has_partition, with collection not existed
|
||||
expected: status not ok
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.has_partition("not_existed_collection", default_tag)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
|
||||
'''
|
||||
target: test has partition, with invalid tag name, check status returned
|
||||
method: call function: has_partition
|
||||
expected: status ok
|
||||
'''
|
||||
tag_name = get_tag_name
|
||||
connect.create_partition(collection, default_tag)
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.has_partition(collection, tag_name)
|
||||
|
||||
|
||||
class TestDropBase:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_partition` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_drop_partition(self, connect, collection):
|
||||
'''
|
||||
target: test drop partition, check status and partition if existed
|
||||
method: create partitions first, then call function: drop_partition
|
||||
expected: status ok, no partitions in db
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.drop_partition(collection, default_tag)
|
||||
res = connect.list_partitions(collection)
|
||||
tag_list = []
|
||||
assert default_tag not in tag_list
|
||||
|
||||
def test_drop_partition_tag_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test drop partition, but tag not existed
|
||||
method: create partitions first, then call function: drop_partition
|
||||
expected: status not ok
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
new_tag = "new_tag"
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_partition(collection, new_tag)
|
||||
|
||||
def test_drop_partition_tag_not_existed_A(self, connect, collection):
|
||||
'''
|
||||
target: test drop partition, but collection not existed
|
||||
method: create partitions first, then call function: drop_partition
|
||||
expected: status not ok
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
new_collection = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_partition(new_collection, default_tag)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_partition_repeatedly(self, connect, collection):
|
||||
'''
|
||||
target: test drop partition twice, check status and partition if existed
|
||||
method: create partitions first, then call function: drop_partition
|
||||
expected: status not ok, no partitions in db
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.drop_partition(collection, default_tag)
|
||||
time.sleep(2)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_partition(collection, default_tag)
|
||||
tag_list = connect.list_partitions(collection)
|
||||
assert default_tag not in tag_list
|
||||
|
||||
def test_drop_partition_create(self, connect, collection):
|
||||
'''
|
||||
target: test drop partition, and create again, check status
|
||||
method: create partitions first, then call function: drop_partition, create_partition
|
||||
expected: status not ok, partition in db
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.drop_partition(collection, default_tag)
|
||||
time.sleep(2)
|
||||
connect.create_partition(collection, default_tag)
|
||||
tag_list = connect.list_partitions(collection)
|
||||
assert default_tag in tag_list
|
||||
|
||||
|
||||
class TestNameInvalid(object):
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_tag_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_partition_with_invalid_collection_name(self, connect, collection, get_collection_name):
|
||||
'''
|
||||
target: test drop partition, with invalid collection name, check status returned
|
||||
method: call function: drop_partition
|
||||
expected: status not ok
|
||||
'''
|
||||
collection_name = get_collection_name
|
||||
connect.create_partition(collection, default_tag)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_partition(collection_name, default_tag)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
|
||||
'''
|
||||
target: test drop partition, with invalid tag name, check status returned
|
||||
method: call function: drop_partition
|
||||
expected: status not ok
|
||||
'''
|
||||
tag_name = get_tag_name
|
||||
connect.create_partition(collection, default_tag)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_partition(collection, tag_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_list_partitions_with_invalid_collection_name(self, connect, collection, get_collection_name):
|
||||
'''
|
||||
target: test show partitions, with invalid collection name, check status returned
|
||||
method: call function: list_partitions
|
||||
expected: status not ok
|
||||
'''
|
||||
collection_name = get_collection_name
|
||||
connect.create_partition(collection, default_tag)
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.list_partitions(collection_name)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue