Debug insert and collection stats

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
pull/4973/head^2
ThreadDao 2021-03-02 19:21:38 +08:00 committed by yefu.chen
parent 9c5c8f35e8
commit f064e77f21
6 changed files with 243 additions and 169 deletions

View File

@ -147,7 +147,6 @@ class TestSearchBase:
yield request.param
# PASS
@pytest.mark.skip("r0.3-test")
def test_search_flat(self, connect, collection, get_top_k, get_nq):
'''
target: test basic search function, all the search params is corrent, change top-k value
@ -258,7 +257,6 @@ class TestSearchBase:
assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
# Pass
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -306,8 +304,8 @@ class TestSearchBase:
assert len(res[0]) == default_top_k
# pass
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
@pytest.mark.skip("r0.3-test")
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search function, all the search params is corrent, test all index params, and build
@ -338,7 +336,6 @@ class TestSearchBase:
assert len(res) == nq
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -424,7 +421,6 @@ class TestSearchBase:
assert res[1]._distances[0] > epsilon
# Pass
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
'''
@ -514,7 +510,6 @@ class TestSearchBase:
assert check_id_result(res[0], ids[0])
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -548,7 +543,6 @@ class TestSearchBase:
assert len(res) == nq
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
@ -609,7 +603,6 @@ class TestSearchBase:
res = connect.search(collection_name, default_query)
# PASS
@pytest.mark.skip("r0.3-test")
def test_search_distance_l2(self, connect, collection):
'''
target: search collection, and check the result: distance
@ -629,7 +622,6 @@ class TestSearchBase:
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
# Pass
@pytest.mark.skip("r0.3-test")
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
@ -684,7 +676,6 @@ class TestSearchBase:
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
# Pass
@pytest.mark.skip("r0.3-test")
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
@ -771,7 +762,6 @@ class TestSearchBase:
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index(self, connect, binary_collection):
'''
@ -790,7 +780,6 @@ class TestSearchBase:
assert len(res[0]) == 0
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
'''
@ -810,7 +799,6 @@ class TestSearchBase:
assert res[1][0].id == ids[1]
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
'''
@ -829,7 +817,6 @@ class TestSearchBase:
assert len(res[0]) == 0
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
'''
@ -851,7 +838,6 @@ class TestSearchBase:
assert res[1][0].distance <= epsilon
# PASS
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
'''
@ -985,7 +971,6 @@ class TestSearchBase:
assert getattr(r.entity, "int64") == getattr(r.entity, "id")
@pytest.mark.skip("r0.3-test")
class TestSearchDSL(object):
"""
******************************************************************
@ -1579,7 +1564,6 @@ class TestSearchDSL(object):
res = connect.search(collection, query)
@pytest.mark.skip("r0.3-test")
class TestSearchDSLBools(object):
"""
******************************************************************
@ -1766,7 +1750,6 @@ class TestSearchInvalid(object):
yield request.param
# Pass
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
'''
@ -1788,7 +1771,6 @@ class TestSearchInvalid(object):
res = connect.search(collection, query)
# pass
@pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_with_invalid_params_binary(self, connect, binary_collection):
'''

View File

@ -10,18 +10,19 @@ from constants import *
uid = "get_collection_stats"
class TestGetCollectionStats:
"""
******************************************************************
The following cases are used to test `collection_stats` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_collection_name(self, request):
def get_invalid_collection_name(self, request):
yield request.param
@pytest.fixture(
@ -46,6 +47,17 @@ class TestGetCollectionStats:
else:
pytest.skip("Skip index Temporary")
@pytest.fixture(
scope="function",
params=[
1,
1000,
2001
],
)
def insert_count(self, request):
yield request.param
def test_get_collection_stats_name_not_existed(self, connect, collection):
'''
target: get collection stats where collection name does not exist
@ -53,22 +65,19 @@ class TestGetCollectionStats:
expected: status not ok
'''
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.get_collection_stats(collection_name)
connect.drop_collection(collection_name)
with pytest.raises(Exception) as e:
connect.get_collection_stats(collection_name)
@pytest.mark.level(2)
def test_get_collection_stats_name_invalid(self, connect, get_collection_name):
def test_get_collection_stats_name_invalid(self, connect, get_invalid_collection_name):
'''
target: get collection stats where collection name is invalid
method: call collection_stats with invalid collection_name
expected: status not ok
'''
collection_name = get_collection_name
collection_name = get_invalid_collection_name
with pytest.raises(Exception) as e:
stats = connect.get_collection_stats(collection_name)
connect.get_collection_stats(collection_name)
def test_get_collection_stats_empty(self, connect, collection):
'''
@ -77,10 +86,17 @@ class TestGetCollectionStats:
expected: segment = []
'''
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == 0
# assert len(stats["partitions"]) == 1
# assert stats["partitions"][0]["tag"] == default_partition_name
# assert stats["partitions"][0]["row_count"] == 0
connect.flush([collection])
assert stats[row_count] == 0
def test_get_collection_stats_without_connection(self, collection, dis_connect):
'''
target: test count_entities, without connection
method: calling count_entities with correct params, with a disconnected instance
expected: count_entities raise exception
'''
with pytest.raises(Exception) as e:
dis_connect.get_collection_stats(collection)
def test_get_collection_stats_batch(self, connect, collection):
'''
@ -89,12 +105,10 @@ class TestGetCollectionStats:
expected: count as expected
'''
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb
# assert len(stats["partitions"]) == 1
# assert stats["partitions"][0]["tag"] == default_partition_name
# assert stats["partitions"][0]["row_count"] == default_nb
assert int(stats[row_count]) == default_nb
def test_get_collection_stats_single(self, connect, collection):
'''
@ -104,13 +118,10 @@ class TestGetCollectionStats:
'''
nb = 10
for i in range(nb):
ids = connect.insert(collection, default_entity)
connect.insert(collection, default_entity)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb
# assert len(stats["partitions"]) == 1
# assert stats["partitions"][0]["tag"] == default_partition_name
# assert stats["partitions"][0]["row_count"] == nb
assert stats[row_count] == nb
@pytest.mark.skip("delete_by_id not support yet")
def test_get_collection_stats_after_delete(self, connect, collection):
@ -184,12 +195,10 @@ class TestGetCollectionStats:
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
assert len(ids) == default_nb
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb
# assert len(stats["partitions"]) == 2
# assert stats["partitions"][1]["tag"] == default_tag
# assert stats["partitions"][1]["row_count"] == default_nb
assert stats[row_count] == default_nb
def test_get_collection_stats_partitions(self, connect, collection):
'''
@ -200,26 +209,88 @@ class TestGetCollectionStats:
new_tag = "new_tag"
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb
# for partition in stats["partitions"]:
# if partition["tag"] == default_tag:
# assert partition["row_count"] == default_nb
# else:
# assert partition["row_count"] == 0
ids = connect.insert(collection, default_entities, partition_tag=new_tag)
assert stats[row_count] == default_nb
connect.insert(collection, default_entities, partition_tag=new_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb * 2
# for partition in stats["partitions"]:
# if partition["tag"] in [default_tag, new_tag]:
# assert partition["row_count"] == default_nb
ids = connect.insert(collection, default_entities)
assert stats[row_count] == default_nb * 2
connect.insert(collection, default_entities)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb * 3
assert stats[row_count] == default_nb * 3
# @pytest.mark.tags("0331")
def test_get_collection_stats_partitions_A(self, connect, collection, insert_count):
'''
target: test collection rows_count is correct or not
method: create collection, create partitions and add entities in it,
assert the value returned by count_entities method is equal to length of entities
expected: the count is equal to the length of entities
'''
new_tag = "new_tag"
entities = gen_entities(insert_count)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
connect.insert(collection, entities)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats[row_count] == insert_count
# @pytest.mark.tags("0331")
def test_get_collection_stats_partitions_B(self, connect, collection, insert_count):
'''
target: test collection rows_count is correct or not
method: create collection, create partitions and add entities in one of the partitions,
assert the value returned by count_entities method is equal to length of entities
expected: the count is equal to the length of entities
'''
new_tag = "new_tag"
entities = gen_entities(insert_count)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
connect.insert(collection, entities, partition_tag=default_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats[row_count] == insert_count
# @pytest.mark.tags("0331")
def test_get_collection_stats_partitions_C(self, connect, collection, insert_count):
'''
target: test collection rows_count is correct or not
method: create collection, create partitions and add entities in one of the partitions,
assert the value returned by count_entities method is equal to length of entities
expected: the count is equal to the length of vectors
'''
new_tag = "new_tag"
entities = gen_entities(insert_count)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
connect.insert(collection, entities)
connect.insert(collection, entities, partition_tag=default_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats[row_count] == insert_count*2
# @pytest.mark.tags("0331")
def test_get_collection_stats_partitions_D(self, connect, collection, insert_count):
'''
target: test collection rows_count is correct or not
method: create collection, create partitions and add entities in one of the partitions,
assert the value returned by count_entities method is equal to length of entities
expected: the collection count is equal to the length of entities
'''
new_tag = "new_tag"
entities = gen_entities(insert_count)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
connect.insert(collection, entities, partition_tag=default_tag)
connect.insert(collection, entities, partition_tag=new_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats[row_count] == insert_count*2
# TODO: assert metric type in stats response
def test_get_collection_stats_after_index_created(self, connect, collection, get_simple_index):
@ -228,17 +299,11 @@ class TestGetCollectionStats:
method: create collection, add vectors, create index and call collection_stats
expected: status ok, index created and shown in segments
'''
ids = connect.insert(collection, default_entities)
connect.insert(collection, default_entities)
connect.flush([collection])
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
assert stats["row_count"] == default_nb
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["name"] == default_float_vec_field_name and "index_type" in file:
# assert file["data_size"] > 0
# assert file["index_type"] == get_simple_index["index_type"]
# break
assert stats[row_count] == default_nb
# TODO: assert metric type in stats response
def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
@ -249,16 +314,12 @@ class TestGetCollectionStats:
'''
get_simple_index["metric_type"] = "IP"
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
connect.flush([collection])
get_simple_index.update({"metric_type": "IP"})
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["name"] == default_float_vec_field_name and "index_type" in file:
# assert file["data_size"] > 0
# assert file["index_type"] == get_simple_index["index_type"]
# break
assert stats[row_count] == default_nb
# TODO: assert metric type in stats response
def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
@ -269,14 +330,9 @@ class TestGetCollectionStats:
'''
ids = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
connect.create_index(binary_collection, default_binary_vec_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == default_nb
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["name"] == default_float_vec_field_name and "index_type" in file:
# assert file["data_size"] > 0
# assert file["index_type"] == get_simple_index["index_type"]
# break
assert stats[row_count] == default_nb
def test_get_collection_stats_after_create_different_index(self, connect, collection):
'''
@ -288,14 +344,9 @@ class TestGetCollectionStats:
connect.flush([collection])
for index_type in ["IVF_FLAT", "IVF_SQ8"]:
connect.create_index(collection, default_float_vec_field_name,
{"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
{"index_type": index_type, "params": {"nlist": 1024}, "metric_type": "L2"})
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["name"] == default_float_vec_field_name and "index_type" in file:
# assert file["data_size"] > 0
# assert file["index_type"] == index_type
# break
assert stats[row_count] == default_nb
def test_collection_count_multi_collections(self, connect):
'''
@ -310,12 +361,11 @@ class TestGetCollectionStats:
collection_name = gen_unique_str(uid)
collection_list.append(collection_name)
connect.create_collection(collection_name, default_fields)
res = connect.insert(collection_name, default_entities)
ids = connect.insert(collection_name, default_entities)
connect.flush(collection_list)
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
# assert stats["partitions"][0]["row_count"] == default_nb
assert stats["row_count"] == default_nb
assert stats[row_count] == default_nb
connect.drop_collection(collection_list[i])
@pytest.mark.level(2)
@ -334,23 +384,19 @@ class TestGetCollectionStats:
connect.create_collection(collection_name, default_fields)
res = connect.insert(collection_name, default_entities)
connect.flush(collection_list)
index_1 = {"index_type": "IVF_SQ8", "params": {"nlist": 1024}, "metric_type": "L2"}
index_2 = {"index_type": "IVF_FLAT", "params": {"nlist": 1024}, "metric_type": "L2"}
if i % 2:
connect.create_index(collection_name, default_float_vec_field_name,
{"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
connect.create_index(collection_name, default_float_vec_field_name, index_1)
else:
connect.create_index(collection_name, default_float_vec_field_name,
{"index_type": "IVF_FLAT","params":{"nlist": 1024}, "metric_type": "L2"})
connect.create_index(collection_name, default_float_vec_field_name, index_2)
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
assert stats["row_count"] == default_nb
# if i % 2:
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["name"] == default_float_vec_field_name and "index_type" in file:
# assert file["index_type"] == "IVF_SQ8"
# break
# else:
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["name"] == default_float_vec_field_name and "index_type" in file:
# assert file["index_type"] == "IVF_FLAT"
# break
assert stats[row_count] == default_nb
index = connect.describe_index(collection_list[i], default_float_vec_field_name)
if i % 2:
assert index == index_1
else:
assert index == index_2
# break
connect.drop_collection(collection_list[i])

View File

@ -6,19 +6,20 @@ import time
import threading
from multiprocessing import Process
import sklearn.preprocessing
import pytest
from utils import *
from constants import *
uid = "create_collection"
class TestCreateCollection:
"""
******************************************************************
The following cases are used to test `create_collection` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_single_filter_fields()
@ -52,8 +53,8 @@ class TestCreateCollection:
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
# "segment_row_limit": default_segment_row_limit
"fields": [filter_field, vector_field],
# "segment_row_limit": default_segment_row_limit
}
logging.getLogger().info(fields)
connect.create_collection(collection_name, fields)
@ -93,7 +94,7 @@ class TestCreateCollection:
expected: error raised
'''
connect.insert(collection, default_entity)
connect.flush([collection])
# connect.flush([collection])
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
@ -140,7 +141,7 @@ class TestCreateCollection:
method: create collection using multithread,
expected: collections are created
'''
threads_num = 8
threads_num = 8
threads = []
collection_names = []
@ -148,6 +149,7 @@ class TestCreateCollection:
collection_name = gen_unique_str(uid)
collection_names.append(collection_name)
connect.create_collection(collection_name, default_fields)
for i in range(threads_num):
t = TestThread(target=create, args=())
threads.append(t)
@ -155,7 +157,7 @@ class TestCreateCollection:
time.sleep(0.2)
for t in threads:
t.join()
for item in collection_names:
assert item in connect.list_collections()
connect.drop_collection(item)
@ -165,6 +167,7 @@ class TestCreateCollectionInvalid(object):
"""
Test creating collections with invalid params
"""
@pytest.fixture(
scope="function",
params=gen_invalid_metric_types()
@ -217,7 +220,7 @@ class TestCreateCollectionInvalid(object):
fields = copy.deepcopy(default_fields)
fields["fields"][-1]["params"]["dim"] = dimension
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
@pytest.mark.level(2)
@pytest.mark.tags("0331")

View File

@ -5,7 +5,7 @@ import copy
import threading
from multiprocessing import Pool, Process
import pytest
from milvus import DataType
from milvus import DataType, ParamError, BaseException
from utils import *
from constants import *
@ -36,8 +36,9 @@ class TestInsertBase:
)
def get_simple_index(self, request, connect):
# if str(connect._cmd("mode")) == "CPU":
# if request.param["index_type"] in index_cpu_not_support():
# pytest.skip("CPU not support index_type: ivf_sq8h")
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
logging.getLogger().info(request.param)
return request.param
@pytest.fixture(
@ -54,6 +55,7 @@ class TestInsertBase:
def get_vector_field(self, request):
yield request.param
@pytest.mark.tags("0331")
def test_insert_with_empty_entity(self, connect, collection):
'''
target: test insert with empty entity list
@ -62,7 +64,7 @@ class TestInsertBase:
'''
entities = []
with pytest.raises(ParamError) as e:
status, ids = connect.insert(collection, entities)
connect.insert(collection, entities)
def test_insert_with_None(self, connect, collection):
'''
@ -71,10 +73,11 @@ class TestInsertBase:
expected: raises a ParamError
'''
entity = None
with pytest.raises(Exception) as e:
status, ids = connect.insert(collection, entity)
with pytest.raises(ParamError) as e:
connect.insert(collection, entity)
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_insert_collection_not_existed(self, connect):
'''
target: test insert, with collection not existed
@ -82,10 +85,11 @@ class TestInsertBase:
expected: raise a BaseException
'''
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
with pytest.raises(BaseException) as e:
connect.insert(collection_name, default_entities)
@pytest.mark.level(2)
@pytest.mark.tags("0331")
def test_insert_without_connect(self, dis_connect, collection):
'''
target: test insert entities without connection
@ -93,28 +97,30 @@ class TestInsertBase:
expected: raise exception
'''
with pytest.raises(Exception) as e:
ids = dis_connect.insert(collection, default_entities)
dis_connect.insert(collection, default_entities)
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_insert_drop_collection(self, connect, collection):
'''
target: test delete collection after insert entities
method: insert entities and drop collection
expected: has_collection false
'''
ids = connect.insert(collection, default_entity_row)
ids = connect.insert(collection, default_entity)
assert len(ids) == 1
connect.drop_collection(collection)
assert connect.has_collection(collection) == False
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_insert_flush_drop_collection(self, connect, collection):
'''
target: test drop collection after insert entities for a while
method: insert entities, sleep, and delete collection
expected: has_collection false
'''
ids = connect.insert(collection, default_entity_row)
ids = connect.insert(collection, default_entity)
assert len(ids) == 1
connect.flush([collection])
connect.drop_collection(collection)
@ -133,10 +139,6 @@ class TestInsertBase:
connect.create_index(collection, field_name, get_simple_index)
index = connect.describe_index(collection, field_name)
assert index == get_simple_index
# fields = info["fields"]
# for field in fields:
# if field["name"] == field_name:
# assert field["indexes"][0] == get_simple_index
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_after_create_index(self, connect, collection, get_simple_index):
@ -200,9 +202,10 @@ class TestInsertBase:
assert len(res_ids) == nb
assert res_ids == ids
stats = connect.get_collection_stats(id_collection)
assert stats["row_count"] == nb
assert stats[row_count] == nb
@pytest.mark.timeout(ADD_TIMEOUT)
# @pytest.mark.tags("0331")
def test_insert_the_same_ids(self, connect, id_collection, insert_count):
'''
target: test insert vectors in collection, use customize the same ids
@ -216,7 +219,7 @@ class TestInsertBase:
assert len(res_ids) == nb
assert res_ids == ids
stats = connect.get_collection_stats(id_collection)
assert stats["row_count"] == nb
assert stats[row_count] == nb
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):
@ -231,16 +234,17 @@ class TestInsertBase:
collection_name = gen_unique_str("test_collection")
fields = {
"fields": [filter_field, vector_field],
"auto_id": True
"auto_id": False
}
connect.create_collection(collection_name, fields)
ids = [i for i in range(nb)]
entities = gen_entities_by_fields(fields["fields"], nb, dim)
entities = gen_entities_by_fields(fields["fields"], nb, default_dim)
logging.getLogger().info(entities)
res_ids = connect.insert(collection_name, entities, ids)
assert res_ids == ids
connect.flush([collection_name])
stats = connect.get_collection_stats(id_collection)
assert stats["row_count"] == nb
stats = connect.get_collection_stats(collection_name)
assert stats[row_count] == nb
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_not_match(self, connect, id_collection, insert_count):
@ -250,7 +254,7 @@ class TestInsertBase:
expected: exception raised
'''
nb = insert_count
with pytest.raises(Exception) as e:
with pytest.raises(BaseException) as e:
connect.insert(id_collection, gen_entities(nb))
# TODO
@ -262,9 +266,9 @@ class TestInsertBase:
expected: BaseException raised
'''
ids = [i for i in range(default_nb)]
res_ids = connect.insert(id_collection, default_entities, ids)
with pytest.raises(Exception) as e:
res_ids_new = connect.insert(id_collection, default_entities)
connect.insert(collection, default_entities, ids)
with pytest.raises(BaseException) as e:
connect.insert(collection, default_entities)
# TODO: assert exception && enable
@pytest.mark.level(2)
@ -275,8 +279,8 @@ class TestInsertBase:
method: test insert vectors twice, use not ids first, and then use customize ids
expected: error raised
'''
with pytest.raises(Exception) as e:
res_ids = connect.insert(id_collection, default_entities)
with pytest.raises(BaseException) as e:
connect.insert(id_collection, default_entities)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_length_not_match_batch(self, connect, id_collection):
@ -287,8 +291,8 @@ class TestInsertBase:
'''
ids = [i for i in range(1, default_nb)]
logging.getLogger().info(len(ids))
with pytest.raises(Exception) as e:
res_ids = connect.insert(id_collection, default_entities, ids)
with pytest.raises(BaseException) as e:
connect.insert(id_collection, default_entities, ids)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_length_not_match_single(self, connect, id_collection):
@ -299,8 +303,8 @@ class TestInsertBase:
'''
ids = [i for i in range(1, default_nb)]
logging.getLogger().info(len(ids))
with pytest.raises(Exception) as e:
res_ids = connect.insert(id_collection, default_entity, ids)
with pytest.raises(BaseException) as e:
connect.insert(id_collection, default_entity, ids)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_partition(self, connect, collection):
@ -313,9 +317,9 @@ class TestInsertBase:
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(collection, default_tag)
connect.flush([collection_name])
stats = connect.get_collection_stats(id_collection)
assert stats["row_count"] == default_nb
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats[row_count] == default_nb
# TODO
@pytest.mark.timeout(ADD_TIMEOUT)
@ -331,17 +335,18 @@ class TestInsertBase:
assert res_ids == ids
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_insert_default_partition(self, connect, collection):
'''
target: test insert entities into default partition
method: create partition and insert info collection without tag params
expected: the collection row count equals to nb
'''
default_tag = "_default"
with pytest.raises(Exception) as e:
connect.create_partition(collection, default_tag)
with pytest.raises(BaseException) as e:
connect.create_partition(collection, default_partition_name)
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_insert_partition_not_existed(self, connect, collection):
'''
target: test insert entities in collection created before
@ -350,7 +355,7 @@ class TestInsertBase:
'''
tag = gen_unique_str()
with pytest.raises(Exception) as e:
ids = connect.insert(collection, default_entities, partition_tag=tag)
connect.insert(collection, default_entities, partition_tag=tag)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_partition_repeatedly(self, connect, collection):
@ -364,7 +369,7 @@ class TestInsertBase:
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
res = connect.get_collection_stats(collection)
assert res["row_count"] == 2 * default_nb
assert res[row_count] == 2 * default_nb
def test_insert_dim_not_matched(self, connect, collection):
'''
@ -375,9 +380,11 @@ class TestInsertBase:
vectors = gen_vectors(default_nb, int(default_dim) // 2)
insert_entities = copy.deepcopy(default_entities)
insert_entities[-1][default_float_vec_field_name] = vectors
# logging.getLogger().info(len(insert_entities[-1][default_float_vec_field_name][0]))
with pytest.raises(Exception) as e:
ids = connect.insert(collection, insert_entities)
connect.insert(collection, insert_entities)
@pytest.mark.tags("0331")
def test_insert_with_field_name_not_match(self, connect, collection):
'''
target: test insert entities, with the entity field name updated
@ -400,6 +407,7 @@ class TestInsertBase:
connect.insert(collection, tmp_entity)
@pytest.mark.level(2)
@pytest.mark.tags("0331")
def test_insert_with_field_value_not_match(self, connect, collection):
'''
target: test insert entities, with the entity field value updated
@ -410,6 +418,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_field_more(self, connect, collection):
'''
target: test insert entities, with more fields than collection schema
@ -420,6 +429,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_field_vector_more(self, connect, collection):
'''
target: test insert entities, with more fields than collection schema
@ -430,6 +440,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_field_less(self, connect, collection):
'''
target: test insert entities, with less fields than collection schema
@ -440,6 +451,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_field_vector_less(self, connect, collection):
'''
target: test insert entities, with less fields than collection schema
@ -450,6 +462,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_no_field_vector_value(self, connect, collection):
'''
target: test insert entities, with no vector field value
@ -461,6 +474,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_no_field_vector_type(self, connect, collection):
'''
target: test insert entities, with no vector field type
@ -472,6 +486,7 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_no_field_vector_name(self, connect, collection):
'''
target: test insert entities, with no vector field name
@ -537,6 +552,7 @@ class TestInsertBinary:
request.param["metric_type"] = "JACCARD"
return request.param
# @pytest.mark.tags("0331")
def test_insert_binary_entities(self, connect, binary_collection):
'''
target: test insert entities in binary collection
@ -545,10 +561,11 @@ class TestInsertBinary:
'''
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush()
connect.flush([binary_collection])
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == default_nb
assert stats[row_count] == default_nb
# @pytest.mark.tags("0331")
def test_insert_binary_partition(self, connect, binary_collection):
'''
target: test insert entities and create partition tag
@ -559,8 +576,9 @@ class TestInsertBinary:
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(binary_collection, default_tag)
connect.flush([binary_collection])
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == default_nb
assert stats[row_count] == default_nb
def test_insert_binary_multi_times(self, connect, binary_collection):
'''
@ -573,7 +591,7 @@ class TestInsertBinary:
assert len(ids) == 1
connect.flush([binary_collection])
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == default_nb
assert stats[row_count] == default_nb
def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
'''
@ -610,7 +628,8 @@ class TestInsertBinary:
'''
ids = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD")
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1,
metric_type="JACCARD")
connect.load_collection(binary_collection)
res = connect.search(binary_collection, query)
logging.getLogger().debug(res)
@ -638,9 +657,10 @@ class TestInsertAsync:
assert not result
def check_result(self, result):
logging.getLogger().info("In callback check status")
logging.getLogger().info("In callback check results")
assert result
@pytest.mark.tags("0331")
def test_insert_async(self, connect, collection, insert_count):
'''
target: test insert vectors with different length of vectors
@ -654,6 +674,7 @@ class TestInsertAsync:
assert len(ids) == nb
@pytest.mark.level(2)
@pytest.mark.tags("0331")
def test_insert_async_false(self, connect, collection, insert_count):
'''
target: test insert vectors with different length of vectors
@ -666,6 +687,7 @@ class TestInsertAsync:
connect.flush([collection])
assert len(ids) == nb
# @pytest.mark.tags("0331")
def test_insert_async_callback(self, connect, collection, insert_count):
'''
target: test insert vectors with different length of vectors
@ -675,6 +697,8 @@ class TestInsertAsync:
nb = insert_count
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status)
future.done()
ids = future.result()
assert len(ids) == nb
@pytest.mark.level(2)
def test_insert_async_long(self, connect, collection):
@ -685,14 +709,15 @@ class TestInsertAsync:
'''
nb = 50000
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result)
result = future.result()
assert len(result) == nb
ids = future.result()
assert len(ids) == nb
connect.flush([collection])
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
assert stats["row_count"] == nb
assert stats[row_count] == nb
@pytest.mark.level(2)
# @pytest.mark.tags("0331")
def test_insert_async_callback_timeout(self, connect, collection):
'''
target: test insert vectors with different length of vectors
@ -704,7 +729,7 @@ class TestInsertAsync:
with pytest.raises(Exception) as e:
result = future.result()
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == 0
assert stats[row_count] == 0
def test_insert_async_invalid_params(self, connect):
'''
@ -714,8 +739,9 @@ class TestInsertAsync:
'''
collection_new = gen_unique_str()
future = connect.insert(collection_new, default_entities, _async=True)
future.done()
with pytest.raises(Exception) as e:
result = future.result()
ids = future.result()
def test_insert_async_invalid_params_raise_exception(self, connect, collection):
'''
@ -747,6 +773,7 @@ class TestInsertMultiCollections:
# pytest.skip("sq8h not support in CPU mode")
return request.param
# @pytest.mark.tags("0331")
def test_insert_entity_multi_collections(self, connect):
'''
target: test insert entities
@ -763,9 +790,10 @@ class TestInsertMultiCollections:
connect.flush([collection_name])
assert len(ids) == default_nb
stats = connect.get_collection_stats(collection_name)
assert stats["row_count"] == default_nb
assert stats[row_count] == default_nb
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_drop_collection_insert_entity_another(self, connect, collection):
'''
target: test insert vector to collection_1 after collection_2 deleted
@ -780,6 +808,7 @@ class TestInsertMultiCollections:
assert len(ids) == 1
@pytest.mark.timeout(ADD_TIMEOUT)
@pytest.mark.tags("0331")
def test_create_index_insert_entity_another(self, connect, collection, get_simple_index):
'''
target: test insert vector to collection_2 after build index for collection_1
@ -807,7 +836,7 @@ class TestInsertMultiCollections:
index = connect.describe_index(collection_name, field_name)
assert index == get_simple_index
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == 1
assert stats[row_count] == 1
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_entity_sleep_create_index_another(self, connect, collection, get_simple_index):
@ -822,10 +851,10 @@ class TestInsertMultiCollections:
connect.flush([collection])
connect.create_index(collection_name, field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == 1
assert stats[row_count] == 1
@pytest.mark.timeout(ADD_TIMEOUT)
def test_search_entity_insert_vector_another(self, connect, collection):
def test_search_entity_insert_entity_another(self, connect, collection):
'''
target: test insert entity to collection_1 after search collection_2
method: search collection and insert entity
@ -838,7 +867,7 @@ class TestInsertMultiCollections:
ids = connect.insert(collection_name, default_entity)
connect.flush()
stats = connect.get_collection_stats(collection_name)
assert stats["row_count"] == 1
assert stats[row_count] == 1
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_entity_search_entity_another(self, connect, collection):
@ -876,9 +905,11 @@ class TestInsertMultiCollections:
connect.insert(collection, default_entities)
connect.flush([collection])
connect.load_collection(collection)
def release():
connect.release_collection(collection)
t = threading.Thread(target=release, args=())
t = threading.Thread(target=release, args=(collection,))
t.start()
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
@ -938,6 +969,7 @@ class TestInsertInvalid(object):
def get_field_vectors_value(self, request):
yield request.param
@pytest.mark.tags("0331")
def test_insert_ids_invalid(self, connect, id_collection, get_entity_id):
'''
target: test insert, with using customize ids, which are not int64
@ -949,11 +981,13 @@ class TestInsertInvalid(object):
with pytest.raises(Exception):
connect.insert(id_collection, default_entities, ids)
@pytest.mark.tags("0331")
def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception):
connect.insert(collection_name, default_entity)
@pytest.mark.tags("0331")
def test_insert_with_invalid_partition_name(self, connect, collection, get_tag_name):
tag_name = get_tag_name
connect.create_partition(collection, default_tag)
@ -963,11 +997,13 @@ class TestInsertInvalid(object):
else:
connect.insert(collection, default_entity, partition_tag=tag_name)
@pytest.mark.tags("0331")
def test_insert_with_invalid_field_name(self, connect, collection, get_field_name):
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
# @pytest.mark.tags("0331")
def test_insert_with_invalid_field_type(self, connect, collection, get_field_type):
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type)
@ -980,6 +1016,7 @@ class TestInsertInvalid(object):
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.tags("0331")
def test_insert_with_invalid_field_entity_value(self, connect, collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(default_entity)
src_vector = tmp_entity[-1]["values"]
@ -1043,12 +1080,14 @@ class TestInsertInvalidBinary(object):
yield request.param
@pytest.mark.level(2)
@pytest.mark.tags("0331")
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
# @pytest.mark.tags("0331")
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
with pytest.raises(Exception):
@ -1063,6 +1102,7 @@ class TestInsertInvalidBinary(object):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
@pytest.mark.tags("0331")
def test_insert_ids_invalid(self, connect, binary_id_collection, get_entity_id):
'''
target: test insert, with using customize ids, which are not int64

View File

@ -38,7 +38,7 @@ class TestConnect:
expected: raise an error after disconnected
'''
with pytest.raises(Exception) as e:
connect.close()
dis_connect.close()
@pytest.mark.tags("0331")
def test_connect_correct_ip_port(self, args):

View File

@ -31,6 +31,7 @@ default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"
default_partition_name = "_default"
default_tag = "1970_01_01"
row_count = "row_count"
# TODO:
# TODO: disable RHNSW_SQ/PQ in 0.11.0
@ -43,6 +44,7 @@ all_index_types = [
"HNSW",
# "NSG",
"ANNOY",
"RHNSW_FLAT",
"RHNSW_PQ",
"RHNSW_SQ",
"BIN_FLAT",
@ -54,10 +56,11 @@ default_index_params = [
{"nlist": 128},
{"nlist": 128},
# {"nlist": 128},
{"nlist": 128, "m": 16},
{"nlist": 128, "m": 16, "nbits": 8},
{"M": 48, "efConstruction": 500},
# {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
{"n_trees": 50},
{"M": 48, "efConstruction": 500},
{"M": 48, "efConstruction": 500, "PQM": 64},
{"M": 48, "efConstruction": 500},
{"nlist": 128},