mirror of https://github.com/milvus-io/milvus.git
Debug insert and collection stats
Signed-off-by: ThreadDao <yufen.zong@zilliz.com>pull/4973/head^2
parent
9c5c8f35e8
commit
f064e77f21
|
@ -147,7 +147,6 @@ class TestSearchBase:
|
|||
yield request.param
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
def test_search_flat(self, connect, collection, get_top_k, get_nq):
|
||||
'''
|
||||
target: test basic search function, all the search params is corrent, change top-k value
|
||||
|
@ -258,7 +257,6 @@ class TestSearchBase:
|
|||
assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
'''
|
||||
|
@ -306,8 +304,8 @@ class TestSearchBase:
|
|||
assert len(res[0]) == default_top_k
|
||||
|
||||
# pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
'''
|
||||
target: test basic search function, all the search params is corrent, test all index params, and build
|
||||
|
@ -338,7 +336,6 @@ class TestSearchBase:
|
|||
assert len(res) == nq
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
'''
|
||||
|
@ -424,7 +421,6 @@ class TestSearchBase:
|
|||
assert res[1]._distances[0] > epsilon
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
|
||||
'''
|
||||
|
@ -514,7 +510,6 @@ class TestSearchBase:
|
|||
assert check_id_result(res[0], ids[0])
|
||||
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
|
||||
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
'''
|
||||
|
@ -548,7 +543,6 @@ class TestSearchBase:
|
|||
assert len(res) == nq
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
|
||||
'''
|
||||
|
@ -609,7 +603,6 @@ class TestSearchBase:
|
|||
res = connect.search(collection_name, default_query)
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
def test_search_distance_l2(self, connect, collection):
|
||||
'''
|
||||
target: search collection, and check the result: distance
|
||||
|
@ -629,7 +622,6 @@ class TestSearchBase:
|
|||
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
|
||||
'''
|
||||
target: search collection, and check the result: distance
|
||||
|
@ -684,7 +676,6 @@ class TestSearchBase:
|
|||
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
|
||||
'''
|
||||
target: search collection, and check the result: distance
|
||||
|
@ -771,7 +762,6 @@ class TestSearchBase:
|
|||
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_substructure_flat_index(self, connect, binary_collection):
|
||||
'''
|
||||
|
@ -790,7 +780,6 @@ class TestSearchBase:
|
|||
assert len(res[0]) == 0
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
|
||||
'''
|
||||
|
@ -810,7 +799,6 @@ class TestSearchBase:
|
|||
assert res[1][0].id == ids[1]
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
|
||||
'''
|
||||
|
@ -829,7 +817,6 @@ class TestSearchBase:
|
|||
assert len(res[0]) == 0
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
|
||||
'''
|
||||
|
@ -851,7 +838,6 @@ class TestSearchBase:
|
|||
assert res[1][0].distance <= epsilon
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
|
||||
'''
|
||||
|
@ -985,7 +971,6 @@ class TestSearchBase:
|
|||
assert getattr(r.entity, "int64") == getattr(r.entity, "id")
|
||||
|
||||
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
class TestSearchDSL(object):
|
||||
"""
|
||||
******************************************************************
|
||||
|
@ -1579,7 +1564,6 @@ class TestSearchDSL(object):
|
|||
res = connect.search(collection, query)
|
||||
|
||||
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
class TestSearchDSLBools(object):
|
||||
"""
|
||||
******************************************************************
|
||||
|
@ -1766,7 +1750,6 @@ class TestSearchInvalid(object):
|
|||
yield request.param
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
|
||||
'''
|
||||
|
@ -1788,7 +1771,6 @@ class TestSearchInvalid(object):
|
|||
res = connect.search(collection, query)
|
||||
|
||||
# pass
|
||||
@pytest.mark.skip("r0.3-test")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_with_invalid_params_binary(self, connect, binary_collection):
|
||||
'''
|
||||
|
|
|
@ -10,18 +10,19 @@ from constants import *
|
|||
|
||||
uid = "get_collection_stats"
|
||||
|
||||
|
||||
class TestGetCollectionStats:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `collection_stats` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
def get_invalid_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
|
@ -46,6 +47,17 @@ class TestGetCollectionStats:
|
|||
else:
|
||||
pytest.skip("Skip index Temporary")
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
1000,
|
||||
2001
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_get_collection_stats_name_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: get collection stats where collection name does not exist
|
||||
|
@ -53,22 +65,19 @@ class TestGetCollectionStats:
|
|||
expected: status not ok
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
connect.get_collection_stats(collection_name)
|
||||
connect.drop_collection(collection_name)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_stats(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_stats_name_invalid(self, connect, get_collection_name):
|
||||
def test_get_collection_stats_name_invalid(self, connect, get_invalid_collection_name):
|
||||
'''
|
||||
target: get collection stats where collection name is invalid
|
||||
method: call collection_stats with invalid collection_name
|
||||
expected: status not ok
|
||||
'''
|
||||
collection_name = get_collection_name
|
||||
collection_name = get_invalid_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
stats = connect.get_collection_stats(collection_name)
|
||||
connect.get_collection_stats(collection_name)
|
||||
|
||||
def test_get_collection_stats_empty(self, connect, collection):
|
||||
'''
|
||||
|
@ -77,10 +86,17 @@ class TestGetCollectionStats:
|
|||
expected: segment = []
|
||||
'''
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == 0
|
||||
# assert len(stats["partitions"]) == 1
|
||||
# assert stats["partitions"][0]["tag"] == default_partition_name
|
||||
# assert stats["partitions"][0]["row_count"] == 0
|
||||
connect.flush([collection])
|
||||
assert stats[row_count] == 0
|
||||
|
||||
def test_get_collection_stats_without_connection(self, collection, dis_connect):
|
||||
'''
|
||||
target: test count_entities, without connection
|
||||
method: calling count_entities with correct params, with a disconnected instance
|
||||
expected: count_entities raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.get_collection_stats(collection)
|
||||
|
||||
def test_get_collection_stats_batch(self, connect, collection):
|
||||
'''
|
||||
|
@ -89,12 +105,10 @@ class TestGetCollectionStats:
|
|||
expected: count as expected
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
# assert len(stats["partitions"]) == 1
|
||||
# assert stats["partitions"][0]["tag"] == default_partition_name
|
||||
# assert stats["partitions"][0]["row_count"] == default_nb
|
||||
assert int(stats[row_count]) == default_nb
|
||||
|
||||
def test_get_collection_stats_single(self, connect, collection):
|
||||
'''
|
||||
|
@ -104,13 +118,10 @@ class TestGetCollectionStats:
|
|||
'''
|
||||
nb = 10
|
||||
for i in range(nb):
|
||||
ids = connect.insert(collection, default_entity)
|
||||
connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb
|
||||
# assert len(stats["partitions"]) == 1
|
||||
# assert stats["partitions"][0]["tag"] == default_partition_name
|
||||
# assert stats["partitions"][0]["row_count"] == nb
|
||||
assert stats[row_count] == nb
|
||||
|
||||
@pytest.mark.skip("delete_by_id not support yet")
|
||||
def test_get_collection_stats_after_delete(self, connect, collection):
|
||||
|
@ -184,12 +195,10 @@ class TestGetCollectionStats:
|
|||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
# assert len(stats["partitions"]) == 2
|
||||
# assert stats["partitions"][1]["tag"] == default_tag
|
||||
# assert stats["partitions"][1]["row_count"] == default_nb
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
def test_get_collection_stats_partitions(self, connect, collection):
|
||||
'''
|
||||
|
@ -200,26 +209,88 @@ class TestGetCollectionStats:
|
|||
new_tag = "new_tag"
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
# for partition in stats["partitions"]:
|
||||
# if partition["tag"] == default_tag:
|
||||
# assert partition["row_count"] == default_nb
|
||||
# else:
|
||||
# assert partition["row_count"] == 0
|
||||
ids = connect.insert(collection, default_entities, partition_tag=new_tag)
|
||||
assert stats[row_count] == default_nb
|
||||
connect.insert(collection, default_entities, partition_tag=new_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb * 2
|
||||
# for partition in stats["partitions"]:
|
||||
# if partition["tag"] in [default_tag, new_tag]:
|
||||
# assert partition["row_count"] == default_nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert stats[row_count] == default_nb * 2
|
||||
connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb * 3
|
||||
assert stats[row_count] == default_nb * 3
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_get_collection_stats_partitions_A(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test collection rows_count is correct or not
|
||||
method: create collection, create partitions and add entities in it,
|
||||
assert the value returned by count_entities method is equal to length of entities
|
||||
expected: the count is equal to the length of entities
|
||||
'''
|
||||
new_tag = "new_tag"
|
||||
entities = gen_entities(insert_count)
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
connect.insert(collection, entities)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats[row_count] == insert_count
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_get_collection_stats_partitions_B(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test collection rows_count is correct or not
|
||||
method: create collection, create partitions and add entities in one of the partitions,
|
||||
assert the value returned by count_entities method is equal to length of entities
|
||||
expected: the count is equal to the length of entities
|
||||
'''
|
||||
new_tag = "new_tag"
|
||||
entities = gen_entities(insert_count)
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
connect.insert(collection, entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats[row_count] == insert_count
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_get_collection_stats_partitions_C(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test collection rows_count is correct or not
|
||||
method: create collection, create partitions and add entities in one of the partitions,
|
||||
assert the value returned by count_entities method is equal to length of entities
|
||||
expected: the count is equal to the length of vectors
|
||||
'''
|
||||
new_tag = "new_tag"
|
||||
entities = gen_entities(insert_count)
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
connect.insert(collection, entities)
|
||||
connect.insert(collection, entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats[row_count] == insert_count*2
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_get_collection_stats_partitions_D(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test collection rows_count is correct or not
|
||||
method: create collection, create partitions and add entities in one of the partitions,
|
||||
assert the value returned by count_entities method is equal to length of entities
|
||||
expected: the collection count is equal to the length of entities
|
||||
'''
|
||||
new_tag = "new_tag"
|
||||
entities = gen_entities(insert_count)
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
connect.insert(collection, entities, partition_tag=default_tag)
|
||||
connect.insert(collection, entities, partition_tag=new_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats[row_count] == insert_count*2
|
||||
|
||||
# TODO: assert metric type in stats response
|
||||
def test_get_collection_stats_after_index_created(self, connect, collection, get_simple_index):
|
||||
|
@ -228,17 +299,11 @@ class TestGetCollectionStats:
|
|||
method: create collection, add vectors, create index and call collection_stats
|
||||
expected: status ok, index created and shown in segments
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(stats)
|
||||
assert stats["row_count"] == default_nb
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
# assert file["data_size"] > 0
|
||||
# assert file["index_type"] == get_simple_index["index_type"]
|
||||
# break
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
# TODO: assert metric type in stats response
|
||||
def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
|
||||
|
@ -249,16 +314,12 @@ class TestGetCollectionStats:
|
|||
'''
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
get_simple_index.update({"metric_type": "IP"})
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
# assert file["data_size"] > 0
|
||||
# assert file["index_type"] == get_simple_index["index_type"]
|
||||
# break
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
# TODO: assert metric type in stats response
|
||||
def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
|
||||
|
@ -269,14 +330,9 @@ class TestGetCollectionStats:
|
|||
'''
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
|
||||
connect.create_index(binary_collection, default_binary_vec_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
# assert file["data_size"] > 0
|
||||
# assert file["index_type"] == get_simple_index["index_type"]
|
||||
# break
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
def test_get_collection_stats_after_create_different_index(self, connect, collection):
|
||||
'''
|
||||
|
@ -288,14 +344,9 @@ class TestGetCollectionStats:
|
|||
connect.flush([collection])
|
||||
for index_type in ["IVF_FLAT", "IVF_SQ8"]:
|
||||
connect.create_index(collection, default_float_vec_field_name,
|
||||
{"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
{"index_type": index_type, "params": {"nlist": 1024}, "metric_type": "L2"})
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
# assert file["data_size"] > 0
|
||||
# assert file["index_type"] == index_type
|
||||
# break
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
def test_collection_count_multi_collections(self, connect):
|
||||
'''
|
||||
|
@ -310,12 +361,11 @@ class TestGetCollectionStats:
|
|||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
res = connect.insert(collection_name, default_entities)
|
||||
ids = connect.insert(collection_name, default_entities)
|
||||
connect.flush(collection_list)
|
||||
for i in range(collection_num):
|
||||
stats = connect.get_collection_stats(collection_list[i])
|
||||
# assert stats["partitions"][0]["row_count"] == default_nb
|
||||
assert stats["row_count"] == default_nb
|
||||
assert stats[row_count] == default_nb
|
||||
connect.drop_collection(collection_list[i])
|
||||
|
||||
@pytest.mark.level(2)
|
||||
|
@ -334,23 +384,19 @@ class TestGetCollectionStats:
|
|||
connect.create_collection(collection_name, default_fields)
|
||||
res = connect.insert(collection_name, default_entities)
|
||||
connect.flush(collection_list)
|
||||
index_1 = {"index_type": "IVF_SQ8", "params": {"nlist": 1024}, "metric_type": "L2"}
|
||||
index_2 = {"index_type": "IVF_FLAT", "params": {"nlist": 1024}, "metric_type": "L2"}
|
||||
if i % 2:
|
||||
connect.create_index(collection_name, default_float_vec_field_name,
|
||||
{"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
connect.create_index(collection_name, default_float_vec_field_name, index_1)
|
||||
else:
|
||||
connect.create_index(collection_name, default_float_vec_field_name,
|
||||
{"index_type": "IVF_FLAT","params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
connect.create_index(collection_name, default_float_vec_field_name, index_2)
|
||||
for i in range(collection_num):
|
||||
stats = connect.get_collection_stats(collection_list[i])
|
||||
assert stats["row_count"] == default_nb
|
||||
# if i % 2:
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
# assert file["index_type"] == "IVF_SQ8"
|
||||
# break
|
||||
# else:
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
# assert file["index_type"] == "IVF_FLAT"
|
||||
# break
|
||||
assert stats[row_count] == default_nb
|
||||
index = connect.describe_index(collection_list[i], default_float_vec_field_name)
|
||||
if i % 2:
|
||||
assert index == index_1
|
||||
else:
|
||||
assert index == index_2
|
||||
# break
|
||||
connect.drop_collection(collection_list[i])
|
||||
|
|
|
@ -6,19 +6,20 @@ import time
|
|||
import threading
|
||||
from multiprocessing import Process
|
||||
import sklearn.preprocessing
|
||||
|
||||
import pytest
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
uid = "create_collection"
|
||||
|
||||
|
||||
class TestCreateCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_filter_fields()
|
||||
|
@ -52,8 +53,8 @@ class TestCreateCollection:
|
|||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
# "segment_row_limit": default_segment_row_limit
|
||||
"fields": [filter_field, vector_field],
|
||||
# "segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
logging.getLogger().info(fields)
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
@ -93,7 +94,7 @@ class TestCreateCollection:
|
|||
expected: error raised
|
||||
'''
|
||||
connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
# connect.flush([collection])
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
|
@ -140,7 +141,7 @@ class TestCreateCollection:
|
|||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 8
|
||||
threads_num = 8
|
||||
threads = []
|
||||
collection_names = []
|
||||
|
||||
|
@ -148,6 +149,7 @@ class TestCreateCollection:
|
|||
collection_name = gen_unique_str(uid)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
for i in range(threads_num):
|
||||
t = TestThread(target=create, args=())
|
||||
threads.append(t)
|
||||
|
@ -155,7 +157,7 @@ class TestCreateCollection:
|
|||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
for item in collection_names:
|
||||
assert item in connect.list_collections()
|
||||
connect.drop_collection(item)
|
||||
|
@ -165,6 +167,7 @@ class TestCreateCollectionInvalid(object):
|
|||
"""
|
||||
Test creating collections with invalid params
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_metric_types()
|
||||
|
@ -217,7 +220,7 @@ class TestCreateCollectionInvalid(object):
|
|||
fields = copy.deepcopy(default_fields)
|
||||
fields["fields"][-1]["params"]["dim"] = dimension
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags("0331")
|
||||
|
|
|
@ -5,7 +5,7 @@ import copy
|
|||
import threading
|
||||
from multiprocessing import Pool, Process
|
||||
import pytest
|
||||
from milvus import DataType
|
||||
from milvus import DataType, ParamError, BaseException
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
|
@ -36,8 +36,9 @@ class TestInsertBase:
|
|||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
# if str(connect._cmd("mode")) == "CPU":
|
||||
# if request.param["index_type"] in index_cpu_not_support():
|
||||
# pytest.skip("CPU not support index_type: ivf_sq8h")
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("CPU not support index_type: ivf_sq8h")
|
||||
logging.getLogger().info(request.param)
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(
|
||||
|
@ -54,6 +55,7 @@ class TestInsertBase:
|
|||
def get_vector_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_empty_entity(self, connect, collection):
|
||||
'''
|
||||
target: test insert with empty entity list
|
||||
|
@ -62,7 +64,7 @@ class TestInsertBase:
|
|||
'''
|
||||
entities = []
|
||||
with pytest.raises(ParamError) as e:
|
||||
status, ids = connect.insert(collection, entities)
|
||||
connect.insert(collection, entities)
|
||||
|
||||
def test_insert_with_None(self, connect, collection):
|
||||
'''
|
||||
|
@ -71,10 +73,11 @@ class TestInsertBase:
|
|||
expected: raises a ParamError
|
||||
'''
|
||||
entity = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status, ids = connect.insert(collection, entity)
|
||||
with pytest.raises(ParamError) as e:
|
||||
connect.insert(collection, entity)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test insert, with collection not existed
|
||||
|
@ -82,10 +85,11 @@ class TestInsertBase:
|
|||
expected: raise a BaseException
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.insert(collection_name, default_entities)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_without_connect(self, dis_connect, collection):
|
||||
'''
|
||||
target: test insert entities without connection
|
||||
|
@ -93,28 +97,30 @@ class TestInsertBase:
|
|||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
ids = dis_connect.insert(collection, default_entities)
|
||||
dis_connect.insert(collection, default_entities)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: test delete collection after insert entities
|
||||
method: insert entities and drop collection
|
||||
expected: has_collection false
|
||||
'''
|
||||
ids = connect.insert(collection, default_entity_row)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
assert len(ids) == 1
|
||||
connect.drop_collection(collection)
|
||||
assert connect.has_collection(collection) == False
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_flush_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: test drop collection after insert entities for a while
|
||||
method: insert entities, sleep, and delete collection
|
||||
expected: has_collection false
|
||||
'''
|
||||
ids = connect.insert(collection, default_entity_row)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
assert len(ids) == 1
|
||||
connect.flush([collection])
|
||||
connect.drop_collection(collection)
|
||||
|
@ -133,10 +139,6 @@ class TestInsertBase:
|
|||
connect.create_index(collection, field_name, get_simple_index)
|
||||
index = connect.describe_index(collection, field_name)
|
||||
assert index == get_simple_index
|
||||
# fields = info["fields"]
|
||||
# for field in fields:
|
||||
# if field["name"] == field_name:
|
||||
# assert field["indexes"][0] == get_simple_index
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_after_create_index(self, connect, collection, get_simple_index):
|
||||
|
@ -200,9 +202,10 @@ class TestInsertBase:
|
|||
assert len(res_ids) == nb
|
||||
assert res_ids == ids
|
||||
stats = connect.get_collection_stats(id_collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats[row_count] == nb
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_the_same_ids(self, connect, id_collection, insert_count):
|
||||
'''
|
||||
target: test insert vectors in collection, use customize the same ids
|
||||
|
@ -216,7 +219,7 @@ class TestInsertBase:
|
|||
assert len(res_ids) == nb
|
||||
assert res_ids == ids
|
||||
stats = connect.get_collection_stats(id_collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats[row_count] == nb
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):
|
||||
|
@ -231,16 +234,17 @@ class TestInsertBase:
|
|||
collection_name = gen_unique_str("test_collection")
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"auto_id": True
|
||||
"auto_id": False
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
ids = [i for i in range(nb)]
|
||||
entities = gen_entities_by_fields(fields["fields"], nb, dim)
|
||||
entities = gen_entities_by_fields(fields["fields"], nb, default_dim)
|
||||
logging.getLogger().info(entities)
|
||||
res_ids = connect.insert(collection_name, entities, ids)
|
||||
assert res_ids == ids
|
||||
connect.flush([collection_name])
|
||||
stats = connect.get_collection_stats(id_collection)
|
||||
assert stats["row_count"] == nb
|
||||
stats = connect.get_collection_stats(collection_name)
|
||||
assert stats[row_count] == nb
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_ids_not_match(self, connect, id_collection, insert_count):
|
||||
|
@ -250,7 +254,7 @@ class TestInsertBase:
|
|||
expected: exception raised
|
||||
'''
|
||||
nb = insert_count
|
||||
with pytest.raises(Exception) as e:
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.insert(id_collection, gen_entities(nb))
|
||||
|
||||
# TODO
|
||||
|
@ -262,9 +266,9 @@ class TestInsertBase:
|
|||
expected: BaseException raised
|
||||
'''
|
||||
ids = [i for i in range(default_nb)]
|
||||
res_ids = connect.insert(id_collection, default_entities, ids)
|
||||
with pytest.raises(Exception) as e:
|
||||
res_ids_new = connect.insert(id_collection, default_entities)
|
||||
connect.insert(collection, default_entities, ids)
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.insert(collection, default_entities)
|
||||
|
||||
# TODO: assert exception && enable
|
||||
@pytest.mark.level(2)
|
||||
|
@ -275,8 +279,8 @@ class TestInsertBase:
|
|||
method: test insert vectors twice, use not ids first, and then use customize ids
|
||||
expected: error raised
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
res_ids = connect.insert(id_collection, default_entities)
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.insert(id_collection, default_entities)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_ids_length_not_match_batch(self, connect, id_collection):
|
||||
|
@ -287,8 +291,8 @@ class TestInsertBase:
|
|||
'''
|
||||
ids = [i for i in range(1, default_nb)]
|
||||
logging.getLogger().info(len(ids))
|
||||
with pytest.raises(Exception) as e:
|
||||
res_ids = connect.insert(id_collection, default_entities, ids)
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.insert(id_collection, default_entities, ids)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_ids_length_not_match_single(self, connect, id_collection):
|
||||
|
@ -299,8 +303,8 @@ class TestInsertBase:
|
|||
'''
|
||||
ids = [i for i in range(1, default_nb)]
|
||||
logging.getLogger().info(len(ids))
|
||||
with pytest.raises(Exception) as e:
|
||||
res_ids = connect.insert(id_collection, default_entity, ids)
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.insert(id_collection, default_entity, ids)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_partition(self, connect, collection):
|
||||
|
@ -313,9 +317,9 @@ class TestInsertBase:
|
|||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
assert len(ids) == default_nb
|
||||
assert connect.has_partition(collection, default_tag)
|
||||
connect.flush([collection_name])
|
||||
stats = connect.get_collection_stats(id_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
# TODO
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
|
@ -331,17 +335,18 @@ class TestInsertBase:
|
|||
assert res_ids == ids
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_default_partition(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities into default partition
|
||||
method: create partition and insert info collection without tag params
|
||||
expected: the collection row count equals to nb
|
||||
'''
|
||||
default_tag = "_default"
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_partition(collection, default_tag)
|
||||
with pytest.raises(BaseException) as e:
|
||||
connect.create_partition(collection, default_partition_name)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_partition_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities in collection created before
|
||||
|
@ -350,7 +355,7 @@ class TestInsertBase:
|
|||
'''
|
||||
tag = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
ids = connect.insert(collection, default_entities, partition_tag=tag)
|
||||
connect.insert(collection, default_entities, partition_tag=tag)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_partition_repeatedly(self, connect, collection):
|
||||
|
@ -364,7 +369,7 @@ class TestInsertBase:
|
|||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
res = connect.get_collection_stats(collection)
|
||||
assert res["row_count"] == 2 * default_nb
|
||||
assert res[row_count] == 2 * default_nb
|
||||
|
||||
def test_insert_dim_not_matched(self, connect, collection):
|
||||
'''
|
||||
|
@ -375,9 +380,11 @@ class TestInsertBase:
|
|||
vectors = gen_vectors(default_nb, int(default_dim) // 2)
|
||||
insert_entities = copy.deepcopy(default_entities)
|
||||
insert_entities[-1][default_float_vec_field_name] = vectors
|
||||
# logging.getLogger().info(len(insert_entities[-1][default_float_vec_field_name][0]))
|
||||
with pytest.raises(Exception) as e:
|
||||
ids = connect.insert(collection, insert_entities)
|
||||
connect.insert(collection, insert_entities)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_field_name_not_match(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with the entity field name updated
|
||||
|
@ -400,6 +407,7 @@ class TestInsertBase:
|
|||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_field_value_not_match(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with the entity field value updated
|
||||
|
@ -410,6 +418,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_field_more(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with more fields than collection schema
|
||||
|
@ -420,6 +429,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_field_vector_more(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with more fields than collection schema
|
||||
|
@ -430,6 +440,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_field_less(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with less fields than collection schema
|
||||
|
@ -440,6 +451,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_field_vector_less(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with less fields than collection schema
|
||||
|
@ -450,6 +462,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_no_field_vector_value(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with no vector field value
|
||||
|
@ -461,6 +474,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_no_field_vector_type(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with no vector field type
|
||||
|
@ -472,6 +486,7 @@ class TestInsertBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_no_field_vector_name(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, with no vector field name
|
||||
|
@ -537,6 +552,7 @@ class TestInsertBinary:
|
|||
request.param["metric_type"] = "JACCARD"
|
||||
return request.param
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_binary_entities(self, connect, binary_collection):
|
||||
'''
|
||||
target: test insert entities in binary collection
|
||||
|
@ -545,10 +561,11 @@ class TestInsertBinary:
|
|||
'''
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush()
|
||||
connect.flush([binary_collection])
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_binary_partition(self, connect, binary_collection):
|
||||
'''
|
||||
target: test insert entities and create partition tag
|
||||
|
@ -559,8 +576,9 @@ class TestInsertBinary:
|
|||
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
|
||||
assert len(ids) == default_nb
|
||||
assert connect.has_partition(binary_collection, default_tag)
|
||||
connect.flush([binary_collection])
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
def test_insert_binary_multi_times(self, connect, binary_collection):
|
||||
'''
|
||||
|
@ -573,7 +591,7 @@ class TestInsertBinary:
|
|||
assert len(ids) == 1
|
||||
connect.flush([binary_collection])
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == default_nb
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
|
||||
'''
|
||||
|
@ -610,7 +628,8 @@ class TestInsertBinary:
|
|||
'''
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD")
|
||||
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1,
|
||||
metric_type="JACCARD")
|
||||
connect.load_collection(binary_collection)
|
||||
res = connect.search(binary_collection, query)
|
||||
logging.getLogger().debug(res)
|
||||
|
@ -638,9 +657,10 @@ class TestInsertAsync:
|
|||
assert not result
|
||||
|
||||
def check_result(self, result):
|
||||
logging.getLogger().info("In callback check status")
|
||||
logging.getLogger().info("In callback check results")
|
||||
assert result
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_async(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test insert vectors with different length of vectors
|
||||
|
@ -654,6 +674,7 @@ class TestInsertAsync:
|
|||
assert len(ids) == nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_async_false(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test insert vectors with different length of vectors
|
||||
|
@ -666,6 +687,7 @@ class TestInsertAsync:
|
|||
connect.flush([collection])
|
||||
assert len(ids) == nb
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_async_callback(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test insert vectors with different length of vectors
|
||||
|
@ -675,6 +697,8 @@ class TestInsertAsync:
|
|||
nb = insert_count
|
||||
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status)
|
||||
future.done()
|
||||
ids = future.result()
|
||||
assert len(ids) == nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_insert_async_long(self, connect, collection):
|
||||
|
@ -685,14 +709,15 @@ class TestInsertAsync:
|
|||
'''
|
||||
nb = 50000
|
||||
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result)
|
||||
result = future.result()
|
||||
assert len(result) == nb
|
||||
ids = future.result()
|
||||
assert len(ids) == nb
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(stats)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats[row_count] == nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_async_callback_timeout(self, connect, collection):
|
||||
'''
|
||||
target: test insert vectors with different length of vectors
|
||||
|
@ -704,7 +729,7 @@ class TestInsertAsync:
|
|||
with pytest.raises(Exception) as e:
|
||||
result = future.result()
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == 0
|
||||
assert stats[row_count] == 0
|
||||
|
||||
def test_insert_async_invalid_params(self, connect):
|
||||
'''
|
||||
|
@ -714,8 +739,9 @@ class TestInsertAsync:
|
|||
'''
|
||||
collection_new = gen_unique_str()
|
||||
future = connect.insert(collection_new, default_entities, _async=True)
|
||||
future.done()
|
||||
with pytest.raises(Exception) as e:
|
||||
result = future.result()
|
||||
ids = future.result()
|
||||
|
||||
def test_insert_async_invalid_params_raise_exception(self, connect, collection):
|
||||
'''
|
||||
|
@ -747,6 +773,7 @@ class TestInsertMultiCollections:
|
|||
# pytest.skip("sq8h not support in CPU mode")
|
||||
return request.param
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_entity_multi_collections(self, connect):
|
||||
'''
|
||||
target: test insert entities
|
||||
|
@ -763,9 +790,10 @@ class TestInsertMultiCollections:
|
|||
connect.flush([collection_name])
|
||||
assert len(ids) == default_nb
|
||||
stats = connect.get_collection_stats(collection_name)
|
||||
assert stats["row_count"] == default_nb
|
||||
assert stats[row_count] == default_nb
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_drop_collection_insert_entity_another(self, connect, collection):
|
||||
'''
|
||||
target: test insert vector to collection_1 after collection_2 deleted
|
||||
|
@ -780,6 +808,7 @@ class TestInsertMultiCollections:
|
|||
assert len(ids) == 1
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_create_index_insert_entity_another(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test insert vector to collection_2 after build index for collection_1
|
||||
|
@ -807,7 +836,7 @@ class TestInsertMultiCollections:
|
|||
index = connect.describe_index(collection_name, field_name)
|
||||
assert index == get_simple_index
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == 1
|
||||
assert stats[row_count] == 1
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_entity_sleep_create_index_another(self, connect, collection, get_simple_index):
|
||||
|
@ -822,10 +851,10 @@ class TestInsertMultiCollections:
|
|||
connect.flush([collection])
|
||||
connect.create_index(collection_name, field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == 1
|
||||
assert stats[row_count] == 1
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_search_entity_insert_vector_another(self, connect, collection):
|
||||
def test_search_entity_insert_entity_another(self, connect, collection):
|
||||
'''
|
||||
target: test insert entity to collection_1 after search collection_2
|
||||
method: search collection and insert entity
|
||||
|
@ -838,7 +867,7 @@ class TestInsertMultiCollections:
|
|||
ids = connect.insert(collection_name, default_entity)
|
||||
connect.flush()
|
||||
stats = connect.get_collection_stats(collection_name)
|
||||
assert stats["row_count"] == 1
|
||||
assert stats[row_count] == 1
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_entity_search_entity_another(self, connect, collection):
|
||||
|
@ -876,9 +905,11 @@ class TestInsertMultiCollections:
|
|||
connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
connect.load_collection(collection)
|
||||
|
||||
def release():
|
||||
connect.release_collection(collection)
|
||||
t = threading.Thread(target=release, args=())
|
||||
|
||||
t = threading.Thread(target=release, args=(collection,))
|
||||
t.start()
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
|
@ -938,6 +969,7 @@ class TestInsertInvalid(object):
|
|||
def get_field_vectors_value(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_ids_invalid(self, connect, id_collection, get_entity_id):
|
||||
'''
|
||||
target: test insert, with using customize ids, which are not int64
|
||||
|
@ -949,11 +981,13 @@ class TestInsertInvalid(object):
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(id_collection, default_entities, ids)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection_name, default_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_partition_name(self, connect, collection, get_tag_name):
|
||||
tag_name = get_tag_name
|
||||
connect.create_partition(collection, default_tag)
|
||||
|
@ -963,11 +997,13 @@ class TestInsertInvalid(object):
|
|||
else:
|
||||
connect.insert(collection, default_entity, partition_tag=tag_name)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_field_name(self, connect, collection, get_field_name):
|
||||
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_field_type(self, connect, collection, get_field_type):
|
||||
field_type = get_field_type
|
||||
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type)
|
||||
|
@ -980,6 +1016,7 @@ class TestInsertInvalid(object):
|
|||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_field_entity_value(self, connect, collection, get_field_vectors_value):
|
||||
tmp_entity = copy.deepcopy(default_entity)
|
||||
src_vector = tmp_entity[-1]["values"]
|
||||
|
@ -1043,12 +1080,14 @@ class TestInsertInvalidBinary(object):
|
|||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
|
||||
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(binary_collection, tmp_entity)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
# @pytest.mark.tags("0331")
|
||||
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
|
||||
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
|
||||
with pytest.raises(Exception):
|
||||
|
@ -1063,6 +1102,7 @@ class TestInsertInvalidBinary(object):
|
|||
connect.insert(binary_collection, tmp_entity)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_insert_ids_invalid(self, connect, binary_id_collection, get_entity_id):
|
||||
'''
|
||||
target: test insert, with using customize ids, which are not int64
|
||||
|
|
|
@ -38,7 +38,7 @@ class TestConnect:
|
|||
expected: raise an error after disconnected
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.close()
|
||||
dis_connect.close()
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_correct_ip_port(self, args):
|
||||
|
|
|
@ -31,6 +31,7 @@ default_float_vec_field_name = "float_vector"
|
|||
default_binary_vec_field_name = "binary_vector"
|
||||
default_partition_name = "_default"
|
||||
default_tag = "1970_01_01"
|
||||
row_count = "row_count"
|
||||
|
||||
# TODO:
|
||||
# TODO: disable RHNSW_SQ/PQ in 0.11.0
|
||||
|
@ -43,6 +44,7 @@ all_index_types = [
|
|||
"HNSW",
|
||||
# "NSG",
|
||||
"ANNOY",
|
||||
"RHNSW_FLAT",
|
||||
"RHNSW_PQ",
|
||||
"RHNSW_SQ",
|
||||
"BIN_FLAT",
|
||||
|
@ -54,10 +56,11 @@ default_index_params = [
|
|||
{"nlist": 128},
|
||||
{"nlist": 128},
|
||||
# {"nlist": 128},
|
||||
{"nlist": 128, "m": 16},
|
||||
{"nlist": 128, "m": 16, "nbits": 8},
|
||||
{"M": 48, "efConstruction": 500},
|
||||
# {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
|
||||
{"n_trees": 50},
|
||||
{"M": 48, "efConstruction": 500},
|
||||
{"M": 48, "efConstruction": 500, "PQM": 64},
|
||||
{"M": 48, "efConstruction": 500},
|
||||
{"nlist": 128},
|
||||
|
|
Loading…
Reference in New Issue