Remove import star of test file (#9425)

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
pull/9428/head
ThreadDao 2021-10-07 22:38:43 +08:00 committed by GitHub
parent 3b5c17aa0e
commit c8fa7026aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 83 additions and 76 deletions

View File

@ -1,18 +1,25 @@
import copy
import logging
import threading
import pytest
from pymilvus import DataType, ParamError, BaseException
from utils.utils import *
from utils import utils as ut
from common.constants import default_entity, default_entities, default_binary_entity, default_binary_entities, \
default_fields
from common.common_type import CaseLabel
ADD_TIMEOUT = 60
uid = "test_insert"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
field_name = ut.default_float_vec_field_name
binary_field_name = ut.default_binary_vec_field_name
default_nb = ut.default_nb
row_count = ut.row_count
default_tag = ut.default_tag
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2",
{"vector": {field_name: {"topk": 10, "query": ut.gen_vectors(1, ut.default_dim), "metric_type": "L2",
"params": {"nprobe": 10}}}}
]
}
@ -28,25 +35,25 @@ class TestInsertBase:
@pytest.fixture(
scope="function",
params=gen_simple_index()
params=ut.gen_simple_index()
)
def get_simple_index(self, request, connect):
# if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
if request.param["index_type"] in ut.index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
logging.getLogger().info(request.param)
return request.param
@pytest.fixture(
scope="function",
params=gen_single_filter_fields()
params=ut.gen_single_filter_fields()
)
def get_filter_field(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_single_vector_fields()
params=ut.gen_single_vector_fields()
)
def get_vector_field(self, request):
yield request.param
@ -81,7 +88,7 @@ class TestInsertBase:
method: insert entity into a random named collection
expected: raise a BaseException
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
with pytest.raises(BaseException) as e:
connect.insert(collection_name, default_entities)
@ -136,7 +143,7 @@ class TestInsertBase:
connect.create_index(collection, field_name, get_simple_index)
if get_simple_index["index_type"] != "FLAT":
index = connect.describe_index(collection, "")
create_target_index(get_simple_index, field_name)
ut.create_target_index(get_simple_index, field_name)
assert index == get_simple_index
@pytest.mark.timeout(ADD_TIMEOUT)
@ -152,7 +159,7 @@ class TestInsertBase:
assert len(result.primary_keys) == default_nb
if get_simple_index["index_type"] != "FLAT":
index = connect.describe_index(collection, "")
create_target_index(get_simple_index, field_name)
ut.create_target_index(get_simple_index, field_name)
assert index == get_simple_index
@pytest.mark.timeout(ADD_TIMEOUT)
@ -167,18 +174,18 @@ class TestInsertBase:
connect.flush([collection])
connect.load_collection(collection)
res = connect.search(collection, default_single_query)
assert len(res[0]) == default_top_k
assert len(res[0]) == ut.default_top_k
@pytest.mark.tags(CaseLabel.L2)
def _test_insert_segment_row_count(self, connect, collection):
nb = default_segment_row_limit + 1
result = connect.insert(collection, gen_entities(nb))
nb = ut.default_segment_row_limit + 1
result = connect.insert(collection, ut.gen_entities(nb))
connect.flush([collection])
assert len(result.primary_keys) == nb
stats = connect.get_collection_stats(collection)
assert len(stats['partitions'][0]['segments']) == 2
for segment in stats['partitions'][0]['segments']:
assert segment['row_count'] in [default_segment_row_limit, 1]
assert segment['row_count'] in [ut.default_segment_row_limit, 1]
@pytest.fixture(
scope="function",
@ -200,7 +207,7 @@ class TestInsertBase:
"""
nb = insert_count
ids = [i for i in range(nb)]
entities = gen_entities(nb)
entities = ut.gen_entities(nb)
entities[0]["values"] = ids
result = connect.insert(id_collection, entities)
connect.flush([id_collection])
@ -219,7 +226,7 @@ class TestInsertBase:
"""
nb = insert_count
ids = [1 for i in range(nb)]
entities = gen_entities(nb)
entities = ut.gen_entities(nb)
entities[0]["values"] = ids
result = connect.insert(id_collection, entities)
connect.flush([id_collection])
@ -239,14 +246,14 @@ class TestInsertBase:
nb = 5
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str("test_collection")
collection_name = ut.gen_unique_str("test_collection")
fields = {
"fields": [gen_primary_field(), filter_field, vector_field],
"fields": [ut.gen_primary_field(), filter_field, vector_field],
"auto_id": False
}
connect.create_collection(collection_name, fields)
ids = [i for i in range(nb)]
entities = gen_entities_by_fields(fields["fields"], nb, default_dim, ids)
entities = ut.gen_entities_by_fields(fields["fields"], nb, ut.default_dim, ids)
logging.getLogger().info(entities)
result = connect.insert(collection_name, entities)
assert result.primary_keys == ids
@ -264,7 +271,7 @@ class TestInsertBase:
"""
nb = insert_count
with pytest.raises(Exception) as e:
entities = gen_entities(nb)
entities = ut.gen_entities(nb)
del entities[0]
connect.insert(id_collection, entities)
@ -355,7 +362,7 @@ class TestInsertBase:
"""
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
entities = gen_entities(default_nb)
entities = ut.gen_entities(default_nb)
entities[0]["values"] = ids
result = connect.insert(id_collection, entities, partition_name=default_tag)
assert result.primary_keys == ids
@ -369,7 +376,7 @@ class TestInsertBase:
method: create partition and insert info collection without tag params
expected: the collection row count equals to nb
"""
result = connect.insert(collection, default_entities, partition_name=default_partition_name)
result = connect.insert(collection, default_entities, partition_name=ut.default_partition_name)
assert len(result.primary_keys) == default_nb
connect.flush([collection])
stats = connect.get_collection_stats(collection)
@ -383,7 +390,7 @@ class TestInsertBase:
method: create collection and insert entities in it, with the not existed partition_name param
expected: error raised
"""
tag = gen_unique_str()
tag = ut.gen_unique_str()
with pytest.raises(Exception) as e:
connect.insert(collection, default_entities, partition_name=tag)
@ -409,7 +416,7 @@ class TestInsertBase:
method: the entities dimension is half of the collection dimension, check the status
expected: error raised
"""
vectors = gen_vectors(default_nb, int(default_dim) // 2)
vectors = ut.gen_vectors(default_nb, int(ut.default_dim) // 2)
insert_entities = copy.deepcopy(default_entities)
insert_entities[-1]["values"] = vectors
with pytest.raises(Exception) as e:
@ -422,7 +429,7 @@ class TestInsertBase:
method: update entity field name
expected: error raised
"""
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", "int64new")
tmp_entity = ut.update_field_name(copy.deepcopy(default_entity), "int64", "int64new")
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -433,7 +440,7 @@ class TestInsertBase:
method: update entity field type
expected: error raised
"""
tmp_entity = update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT)
tmp_entity = ut.update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -444,7 +451,7 @@ class TestInsertBase:
method: update entity field value
expected: error raised
"""
tmp_entity = update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's')
tmp_entity = ut.update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's')
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -455,7 +462,7 @@ class TestInsertBase:
method: add entity field
expected: error raised
"""
tmp_entity = add_field(copy.deepcopy(default_entity))
tmp_entity = ut.add_field(copy.deepcopy(default_entity))
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -466,7 +473,7 @@ class TestInsertBase:
method: add entity vector field
expected: error raised
"""
tmp_entity = add_vector_field(default_nb, default_dim)
tmp_entity = ut.add_vector_field(default_nb, ut.default_dim)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -477,7 +484,7 @@ class TestInsertBase:
method: remove entity field
expected: error raised
"""
tmp_entity = remove_field(copy.deepcopy(default_entity))
tmp_entity = ut.remove_field(copy.deepcopy(default_entity))
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -488,7 +495,7 @@ class TestInsertBase:
method: remove entity vector field
expected: error raised
"""
tmp_entity = remove_vector_field(copy.deepcopy(default_entity))
tmp_entity = ut.remove_vector_field(copy.deepcopy(default_entity))
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -542,7 +549,7 @@ class TestInsertBase:
pytest.skip("Skip test in http mode")
thread_num = 8
threads = []
milvus = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"], try_connect=False)
milvus = ut.get_milvus(host=args["ip"], port=args["port"], handler=args["handler"], try_connect=False)
def insert(thread_i):
logging.getLogger().info("In thread-%d" % thread_i)
@ -567,7 +574,7 @@ class TestInsertBase:
expected: the count is equal to 0
"""
delete_nums = 500
disable_flush(connect)
ut.disable_flush(connect)
result = connect.insert(collection, default_entities)
ids = result.primary_keys
res = connect.get_entity_by_id(collection, ids[:delete_nums])
@ -578,7 +585,7 @@ class TestInsertBase:
class TestInsertBinary:
@pytest.fixture(
scope="function",
params=gen_binary_index()
params=ut.gen_binary_index()
)
def get_binary_index(self, request):
request.param["metric_type"] = "JACCARD"
@ -638,7 +645,7 @@ class TestInsertBinary:
assert len(result.primary_keys) == default_nb
connect.flush([binary_collection])
index = connect.describe_index(binary_collection, "")
create_target_index(get_binary_index, binary_field_name)
ut.create_target_index(get_binary_index, binary_field_name)
assert index == get_binary_index
@pytest.mark.timeout(ADD_TIMEOUT)
@ -654,7 +661,7 @@ class TestInsertBinary:
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_binary_index)
index = connect.describe_index(binary_collection, "")
create_target_index(get_binary_index, binary_field_name)
ut.create_target_index(get_binary_index, binary_field_name)
assert index == get_binary_index
@pytest.mark.tags(CaseLabel.L0)
@ -666,12 +673,12 @@ class TestInsertBinary:
"""
result = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1,
query, vecs = ut.gen_query_vectors(binary_field_name, default_binary_entities, ut.default_top_k, 1,
metric_type="JACCARD")
connect.load_collection(binary_collection)
res = connect.search(binary_collection, query)
logging.getLogger().debug(res)
assert len(res[0]) == default_top_k
assert len(res[0]) == ut.default_top_k
class TestInsertAsync:
@ -706,7 +713,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
"""
nb = insert_count
future = connect.insert(collection, gen_entities(nb), _async=True)
future = connect.insert(collection, ut.gen_entities(nb), _async=True)
ids = future.result().primary_keys
connect.flush([collection])
assert len(ids) == nb
@ -719,7 +726,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
"""
nb = insert_count
result = connect.insert(collection, gen_entities(nb), _async=False)
result = connect.insert(collection, ut.gen_entities(nb), _async=False)
# ids = future.result()
connect.flush([collection])
assert len(result.primary_keys) == nb
@ -732,7 +739,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
"""
nb = insert_count
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result)
future = connect.insert(collection, ut.gen_entities(nb), _async=True, _callback=self.check_result)
future.done()
ids = future.result().primary_keys
assert len(ids) == nb
@ -745,7 +752,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
"""
nb = 50000
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result)
future = connect.insert(collection, ut.gen_entities(nb), _async=True, _callback=self.check_result)
result = future.result()
assert len(result.primary_keys) == nb
connect.flush([collection])
@ -761,7 +768,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
"""
nb = 100000
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status, timeout=1)
future = connect.insert(collection, ut.gen_entities(nb), _async=True, _callback=self.check_status, timeout=1)
with pytest.raises(Exception) as e:
result = future.result()
@ -772,7 +779,7 @@ class TestInsertAsync:
method: set different vectors as insert method params
expected: length of ids is equal to the length of vectors
"""
collection_new = gen_unique_str()
collection_new = ut.gen_unique_str()
future = connect.insert(collection_new, default_entities, _async=True)
future.done()
with pytest.raises(Exception) as e:
@ -802,7 +809,7 @@ class TestInsertMultiCollections:
@pytest.fixture(
scope="function",
params=gen_simple_index()
params=ut.gen_simple_index()
)
def get_simple_index(self, request, connect):
logging.getLogger().info(request.param)
@ -821,7 +828,7 @@ class TestInsertMultiCollections:
collection_num = 10
collection_list = []
for i in range(collection_num):
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
collection_list.append(collection_name)
connect.create_collection(collection_name, default_fields)
result = connect.insert(collection_name, default_entities)
@ -840,7 +847,7 @@ class TestInsertMultiCollections:
method: delete collection_2 and insert vector to collection_1
expected: row count equals the length of entities inserted
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.drop_collection(collection)
result = connect.insert(collection_name, default_entity)
@ -855,14 +862,14 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.create_index(collection, field_name, get_simple_index)
result = connect.insert(collection_name, default_entity)
assert len(result.primary_keys) == 1
if get_simple_index["index_type"] != "FLAT":
index = connect.describe_index(collection, "")
create_target_index(get_simple_index, field_name)
ut.create_target_index(get_simple_index, field_name)
assert index == get_simple_index
connect.drop_collection(collection_name)
@ -874,14 +881,14 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
result = connect.insert(collection, default_entity)
connect.flush([collection])
connect.create_index(collection_name, field_name, get_simple_index)
if get_simple_index["index_type"] != "FLAT":
index = connect.describe_index(collection_name, "")
create_target_index(get_simple_index, field_name)
ut.create_target_index(get_simple_index, field_name)
assert index == get_simple_index
stats = connect.get_collection_stats(collection)
assert stats[row_count] == 1
@ -894,7 +901,7 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
result = connect.insert(collection, default_entity)
connect.flush([collection])
@ -910,7 +917,7 @@ class TestInsertMultiCollections:
method: search collection and insert entity
expected: status ok
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.load_collection(collection)
res = connect.search(collection, default_single_query)
@ -928,7 +935,7 @@ class TestInsertMultiCollections:
method: search collection and insert entity
expected: status ok
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
result = connect.insert(collection, default_entity)
connect.flush([collection])
@ -945,7 +952,7 @@ class TestInsertMultiCollections:
method: search collection, sleep, and insert entity
expected: status ok
"""
collection_name = gen_unique_str(uid)
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
result = connect.insert(collection, default_entity)
connect.flush([collection])
@ -982,49 +989,49 @@ class TestInsertInvalid(object):
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_tag_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_field_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_field_type(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_field_int_value(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_ints()
params=ut.gen_invalid_ints()
)
def get_entity_id(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_vectors()
params=ut.gen_invalid_vectors()
)
def get_field_vectors_value(self, request):
yield request.param
@ -1069,21 +1076,21 @@ class TestInsertInvalid(object):
@pytest.mark.tags(CaseLabel.L2)
def test_insert_with_invalid_field_name(self, connect, collection, get_field_name):
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
tmp_entity = ut.update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags(CaseLabel.L2)
def test_insert_with_invalid_field_type(self, connect, collection, get_field_type):
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type)
tmp_entity = ut.update_field_type(copy.deepcopy(default_entity), 'float', field_type)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags(CaseLabel.L2)
def test_insert_with_invalid_field_value(self, connect, collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'int64', field_value)
tmp_entity = ut.update_field_type(copy.deepcopy(default_entity), 'int64', field_value)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -1108,49 +1115,49 @@ class TestInsertInvalidBinary(object):
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_tag_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_field_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_field_type(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
params=ut.gen_invalid_strs()
)
def get_field_int_value(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_ints()
params=ut.gen_invalid_ints()
)
def get_entity_id(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_vectors()
params=ut.gen_invalid_vectors()
)
def get_field_vectors_value(self, request):
yield request.param
@ -1162,13 +1169,13 @@ class TestInsertInvalidBinary(object):
method: insert with invalid field name
expected: raise exception
"""
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
tmp_entity = ut.update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.tags(CaseLabel.L2)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
tmp_entity = ut.update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@ -1205,7 +1212,7 @@ class TestInsertInvalidBinary(object):
expected: raise exception
"""
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type)
tmp_entity = ut.update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)