# milvus/tests/milvus_python_test/collection/test_create_collection.py
#
# 349 lines
# 12 KiB
# Python

import pdb
import copy
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
import sklearn.preprocessing
import pytest
from milvus import IndexType, MetricType
from utils import *
# Prefix handed to gen_unique_str() when building collection names.
collection_id = "create_collection"

# Sizing knobs.
nb = 1    # number passed to gen_entities() below
dim = 128  # not referenced in this part of the file; presumably the vector dimension -- TODO confirm

# segment_size written into the schemas built by the field tests.
segment_size = 10
# Value test_create_collection_no_segment_size expects when "segment_size" is omitted.
default_segment_size = 1024

# Not referenced in this part of the file; presumably a wait (seconds) used
# around collection drops -- TODO confirm against callers.
drop_collection_interval_time = 3

# Shared fixtures-as-data: tests deep-copy default_fields before mutating it.
default_fields = gen_default_fields()
entities = gen_entities(nb)
class TestCreateCollection:
    """
    ******************************************************************
      The following cases are used to test `create_collection` function
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        """Yield one item from gen_single_filter_fields() per parametrized run."""
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        """Yield one item from gen_single_vector_fields() per parametrized run."""
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_segment_sizes()
    )
    def get_segment_size(self, request):
        """Yield one item from gen_segment_sizes() per parametrized run."""
        yield request.param

    def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields
        method: create collection with diff fields: metric/field_type/...
        expected: no exception raised
        '''
        filter_field = get_filter_field
        logging.getLogger().info(filter_field)
        vector_field = get_vector_field
        collection_name = gen_unique_str(collection_id)
        fields = {
            "fields": [filter_field, vector_field],
            "segment_size": segment_size
        }
        logging.getLogger().info(fields)
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    # TODO: this case does not create an index yet; it currently repeats
    # test_create_collection_fields without the logging.
    def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields
        method: create collection with diff fields: metric/field_type/...
        expected: no exception raised
        '''
        filter_field = get_filter_field
        vector_field = get_vector_field
        collection_name = gen_unique_str(collection_id)
        fields = {
            "fields": [filter_field, vector_field],
            "segment_size": segment_size
        }
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    def test_create_collection_segment_size(self, connect, get_segment_size):
        '''
        target: test create normal collection with different segment sizes
        method: create collection with diff segment_size
        expected: no exception raised
        '''
        collection_name = gen_unique_str(collection_id)
        # Copy first so the shared module-level default_fields is not mutated.
        fields = copy.deepcopy(default_fields)
        fields["segment_size"] = get_segment_size
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    def test_create_collection_auto_flush_disabled(self, connect):
        '''
        target: test create normal collection, with large auto_flush_interval
        method: create collection with correct params
        expected: create status return ok
        '''
        disable_flush(connect)
        collection_name = gen_unique_str(collection_id)
        try:
            connect.create_collection(collection_name, default_fields)
        finally:
            # Always restore flushing, even when creation fails.
            enable_flush(connect)

    def test_create_collection_after_insert(self, connect, collection):
        '''
        target: test insert vector, then create collection again
        method: insert vector and create collection
        expected: error raised
        '''
        connect.insert(collection, entities)
        with pytest.raises(Exception):
            connect.create_collection(collection, default_fields)

    def test_create_collection_after_insert_flush(self, connect, collection):
        '''
        target: test insert vector, then create collection again
        method: insert vector, flush, and create collection
        expected: error raised
        '''
        connect.insert(collection, entities)
        connect.flush([collection])
        with pytest.raises(Exception):
            connect.create_collection(collection, default_fields)

    # TODO: assert exception
    @pytest.mark.level(2)
    def test_create_collection_without_connection(self, dis_connect):
        '''
        target: test create collection, without connection
        method: create collection with correct params, with a disconnected instance
        expected: create raise exception
        '''
        collection_name = gen_unique_str(collection_id)
        with pytest.raises(Exception):
            # Call through the dis_connect fixture. `connect` is not a
            # parameter of this test, so referencing it would satisfy
            # pytest.raises() without ever exercising the disconnected client.
            dis_connect.create_collection(collection_name, default_fields)

    def test_create_collection_existed(self, connect):
        '''
        target: test create collection but the collection name have already existed
        method: create collection with the same collection_name
        expected: create status return not ok
        '''
        collection_name = gen_unique_str(collection_id)
        connect.create_collection(collection_name, default_fields)
        with pytest.raises(Exception):
            connect.create_collection(collection_name, default_fields)

    @pytest.mark.level(2)
    def test_create_collection_multithread(self, connect):
        '''
        target: test create collection with multithread
        method: create collection using multithread,
        expected: collections are created
        '''
        threads_num = 8
        threads = []
        collection_names = []

        def create():
            collection_name = gen_unique_str(collection_id)
            # Record the name before creating, so a failed creation shows up
            # in the membership assertion at the end of the test.
            collection_names.append(collection_name)
            connect.create_collection(collection_name, default_fields)

        for _ in range(threads_num):
            t = threading.Thread(target=create, args=())
            threads.append(t)
            t.start()
            # This module imports `sleep` directly (`from time import sleep`)
            # and does not import the `time` module itself.
            sleep(0.2)
        for t in threads:
            t.join()
        res = connect.list_collections()
        for item in collection_names:
            assert item in res
class TestCreateCollectionInvalid(object):
    """
    Test creating collections with invalid params
    """

    @pytest.fixture(scope="function", params=gen_invalid_metric_types())
    def get_metric_type(self, request):
        """Supply one item from gen_invalid_metric_types() per run."""
        return request.param

    @pytest.fixture(scope="function", params=gen_invalid_ints())
    def get_segment_size(self, request):
        """Supply one item from gen_invalid_ints() per run."""
        return request.param

    @pytest.fixture(scope="function", params=gen_invalid_ints())
    def get_dim(self, request):
        """Supply one item from gen_invalid_ints() per run."""
        return request.param

    @pytest.fixture(scope="function", params=gen_invalid_strs())
    def get_invalid_string(self, request):
        """Supply one item from gen_invalid_strs() per run."""
        return request.param

    @pytest.fixture(scope="function", params=gen_invalid_field_types())
    def get_field_type(self, request):
        """Supply one item from gen_invalid_field_types() per run."""
        return request.param

    @pytest.mark.level(2)
    def test_create_collection_with_invalid_segment_size(self, connect, get_segment_size):
        '''
        target: test create collection with an invalid segment_size
        method: overwrite "segment_size" in a copy of the default fields
        expected: exception raised
        '''
        name = gen_unique_str()
        schema = copy.deepcopy(default_fields)
        schema["segment_size"] = get_segment_size
        with pytest.raises(Exception):
            connect.create_collection(name, schema)

    @pytest.mark.level(2)
    def test_create_collection_with_invalid_metric_type(self, connect, get_metric_type):
        '''
        target: test create collection with an invalid metric_type
        method: overwrite "metric_type" in the params of the last default field
        expected: exception raised
        '''
        name = gen_unique_str()
        schema = copy.deepcopy(default_fields)
        last_field_params = schema["fields"][-1]["params"]
        last_field_params["metric_type"] = get_metric_type
        with pytest.raises(Exception):
            connect.create_collection(name, schema)

    @pytest.mark.level(2)
    def test_create_collection_with_invalid_dimension(self, connect, get_dim):
        '''
        target: test create collection with an invalid dimension
        method: overwrite "dimension" in the params of the last default field
        expected: exception raised
        '''
        name = gen_unique_str()
        schema = copy.deepcopy(default_fields)
        last_field_params = schema["fields"][-1]["params"]
        last_field_params["dimension"] = get_dim
        with pytest.raises(Exception):
            connect.create_collection(name, schema)

    @pytest.mark.level(2)
    def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string):
        '''
        target: test create collection with an invalid collection name
        method: pass each invalid string as the collection name
        expected: exception raised
        '''
        with pytest.raises(Exception):
            connect.create_collection(get_invalid_string, default_fields)

    @pytest.mark.level(2)
    def test_create_collection_with_empty_collectionname(self, connect):
        '''
        target: test create collection with an empty collection name
        method: pass '' as the collection name
        expected: exception raised
        '''
        with pytest.raises(Exception):
            connect.create_collection('', default_fields)

    @pytest.mark.level(2)
    def test_create_collection_with_none_collectionname(self, connect):
        '''
        target: test create collection with None as the collection name
        method: pass None as the collection name
        expected: exception raised
        '''
        with pytest.raises(Exception):
            connect.create_collection(None, default_fields)

    def test_create_collection_None(self, connect):
        '''
        target: test create collection but the collection name is None
        method: create collection, param collection_name is None
        expected: create raise error
        '''
        with pytest.raises(Exception):
            connect.create_collection(None, default_fields)

    def test_create_collection_no_dimension(self, connect):
        '''
        target: test create collection with no dimension params
        method: remove "dimension" from the params of the last default field
        expected: exception raised
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        last_field_params = schema["fields"][-1]["params"]
        last_field_params.pop("dimension")
        with pytest.raises(Exception):
            connect.create_collection(name, schema)

    def test_create_collection_no_segment_size(self, connect):
        '''
        target: test create collection with no segment_size params
        method: remove "segment_size" from a copy of the default fields
        expected: use default default_segment_size
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        schema.pop("segment_size")
        connect.create_collection(name, schema)
        info = connect.get_collection_info(name)
        logging.getLogger().info(info)
        assert info["segment_size"] == default_segment_size

    # TODO:
    def _test_create_collection_no_metric_type(self, connect):
        '''
        target: test create collection with no metric_type params
        method: remove "metric_type" from the params of the last default field
        expected: use default L2
        '''
        # The leading underscore keeps pytest from collecting this case.
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        last_field_params = schema["fields"][-1]["params"]
        last_field_params.pop("metric_type")
        connect.create_collection(name, schema)
        info = connect.get_collection_info(name)
        logging.getLogger().info(info)
        assert info["metric_type"] == "L2"

    # TODO: assert exception
    def test_create_collection_limit_fields(self, connect):
        '''
        target: test create collection with too many fields
        method: append 64 extra INT8 fields to a copy of the default fields
        expected: exception raised
        '''
        name = gen_unique_str(collection_id)
        limit_num = 64
        schema = copy.deepcopy(default_fields)
        schema["fields"].extend(
            {"field": gen_unique_str("field_name"), "type": DataType.INT8}
            for _ in range(limit_num)
        )
        with pytest.raises(Exception):
            connect.create_collection(name, schema)

    # TODO: assert exception
    def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
        '''
        target: test create collection with an invalid field name
        method: append an INT8 field named with each invalid string
        expected: exception raised
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        extra_field = {"field": get_invalid_string, "type": DataType.INT8}
        schema["fields"].append(extra_field)
        with pytest.raises(Exception):
            connect.create_collection(name, schema)

    # TODO: assert exception
    def test_create_collection_invalid_field_type(self, connect, get_field_type):
        '''
        target: test create collection with an invalid field type
        method: append a field "test_field" typed with each invalid field type
        expected: exception raised
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        extra_field = {"field": "test_field", "type": get_field_type}
        schema["fields"].append(extra_field)
        with pytest.raises(Exception):
            connect.create_collection(name, schema)