mirror of https://github.com/milvus-io/milvus.git
Update tests to match removal of dsl search (#9622)
Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>pull/9624/head
parent
8c510a52b2
commit
29af477c5e
|
@ -13,7 +13,7 @@ pytest-print==0.2.1
|
|||
pytest-level==0.1.1
|
||||
pytest-xdist==2.2.1
|
||||
# pytest-parallel
|
||||
pymilvus==2.0.0rc7.dev22
|
||||
pymilvus==2.0.0rc7.dev25
|
||||
pytest-rerunfailures==9.1.1
|
||||
git+https://github.com/Projectplace/pytest-tags
|
||||
ndg-httpsclient
|
||||
|
|
|
@ -17,13 +17,11 @@ default_nb = ut.default_nb
|
|||
row_count = ut.row_count
|
||||
default_tag = ut.default_tag
|
||||
default_single_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"vector": {field_name: {"topk": 10, "query": ut.gen_vectors(1, ut.default_dim), "metric_type": "L2",
|
||||
"params": {"nprobe": 10}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
"data": ut.gen_vectors(1, ut.default_dim),
|
||||
"anns_field": ut.default_float_vec_field_name,
|
||||
"param": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
|
||||
class TestInsertBase:
|
||||
|
@ -173,7 +171,7 @@ class TestInsertBase:
|
|||
result = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
assert len(res[0]) == ut.default_top_k
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
|
@ -673,10 +671,10 @@ class TestInsertBinary:
|
|||
"""
|
||||
result = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
query, vecs = ut.gen_query_vectors(binary_field_name, default_binary_entities,
|
||||
ut.default_top_k, 1, metric_type="JACCARD")
|
||||
query, _ = ut.gen_search_vectors_params(binary_field_name, default_binary_entities,
|
||||
ut.default_top_k, 1, metric_type="JACCARD")
|
||||
connect.load_collection(binary_collection)
|
||||
res = connect.search(binary_collection, query)
|
||||
res = connect.search(binary_collection, **query)
|
||||
logging.getLogger().debug(res)
|
||||
assert len(res[0]) == ut.default_top_k
|
||||
|
||||
|
@ -920,7 +918,7 @@ class TestInsertMultiCollections:
|
|||
collection_name = ut.gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
assert len(res[0]) == 0
|
||||
connect.insert(collection_name, default_entity)
|
||||
connect.flush([collection_name])
|
||||
|
@ -940,7 +938,7 @@ class TestInsertMultiCollections:
|
|||
result = connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
connect.load_collection(collection_name)
|
||||
res = connect.search(collection_name, default_single_query)
|
||||
res = connect.search(collection_name, **default_single_query)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats[row_count] == 1
|
||||
|
||||
|
@ -957,7 +955,7 @@ class TestInsertMultiCollections:
|
|||
result = connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
connect.load_collection(collection_name)
|
||||
res = connect.search(collection_name, default_single_query)
|
||||
res = connect.search(collection_name, **default_single_query)
|
||||
assert len(res[0]) == 0
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
|
@ -1238,3 +1236,4 @@ class TestInsertInvalidBinary(object):
|
|||
src_vector[1] = get_field_vectors_value
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(binary_collection, tmp_entities)
|
||||
|
||||
|
|
|
@ -29,14 +29,11 @@ uid_list = "list_collections"
|
|||
uid_load = "load_collection"
|
||||
field_name = default_float_vec_field_name
|
||||
default_single_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"vector": {field_name: {"topk": default_top_k, "query": gen_vectors(1, default_dim), "metric_type": "L2",
|
||||
"params": {"nprobe": 10}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
"data": gen_vectors(1, default_dim),
|
||||
"anns_field": default_float_vec_field_name,
|
||||
"param": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"limit": default_top_k,
|
||||
}
|
||||
|
||||
class TestCollectionParams(TestcaseBase):
|
||||
""" Test case of collection interface """
|
||||
|
@ -2841,8 +2838,8 @@ class TestLoadCollection:
|
|||
connect.load_collection(collection)
|
||||
connect.release_partitions(collection, [default_tag])
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.search(collection, default_single_query, partition_names=[default_tag])
|
||||
res = connect.search(collection, default_single_query, partition_names=[default_partition_name])
|
||||
connect.search(collection, **default_single_query, partition_names=[default_tag])
|
||||
res = connect.search(collection, **default_single_query, partition_names=[default_partition_name])
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
|
@ -2860,7 +2857,7 @@ class TestLoadCollection:
|
|||
connect.flush([collection])
|
||||
connect.load_collection(collection)
|
||||
connect.release_partitions(collection, [default_partition_name, default_tag])
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
assert len(res[0]) == 0
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
|
@ -2877,7 +2874,7 @@ class TestLoadCollection:
|
|||
connect.load_partitions(collection, [default_tag])
|
||||
connect.release_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.search(collection, default_single_query)
|
||||
connect.search(collection, **default_single_query)
|
||||
# assert len(res[0]) == 0
|
||||
|
||||
|
||||
|
@ -2895,11 +2892,11 @@ class TestReleaseAdvanced:
|
|||
connect.insert(collection, cons.default_entities)
|
||||
connect.flush([collection])
|
||||
connect.load_collection(collection)
|
||||
query, _ = gen_query_vectors(field_name, cons.default_entities, top_k, nq)
|
||||
future = connect.search(collection, query, _async=True)
|
||||
params, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k, nq)
|
||||
future = connect.search(collection, **params, _async=True)
|
||||
connect.release_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.search(collection, default_single_query)
|
||||
connect.search(collection, **default_single_query)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_release_partition_during_searching(self, connect, collection):
|
||||
|
@ -2911,14 +2908,14 @@ class TestReleaseAdvanced:
|
|||
nq = 1000
|
||||
top_k = 1
|
||||
connect.create_partition(collection, default_tag)
|
||||
query, _ = gen_query_vectors(field_name, cons.default_entities, top_k, nq)
|
||||
query, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k, nq)
|
||||
connect.insert(collection, cons.default_entities, partition_name=default_tag)
|
||||
connect.flush([collection])
|
||||
connect.load_partitions(collection, [default_tag])
|
||||
res = connect.search(collection, query, _async=True)
|
||||
res = connect.search(collection, **query, _async=True)
|
||||
connect.release_partitions(collection, [default_tag])
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_release_collection_during_searching_A(self, connect, collection):
|
||||
|
@ -2930,14 +2927,14 @@ class TestReleaseAdvanced:
|
|||
nq = 1000
|
||||
top_k = 1
|
||||
connect.create_partition(collection, default_tag)
|
||||
query, _ = gen_query_vectors(field_name, cons.default_entities, top_k, nq)
|
||||
query, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k, nq)
|
||||
connect.insert(collection, cons.default_entities, partition_name=default_tag)
|
||||
connect.flush([collection])
|
||||
connect.load_partitions(collection, [default_tag])
|
||||
res = connect.search(collection, query, _async=True)
|
||||
res = connect.search(collection, **query, _async=True)
|
||||
connect.release_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.search(collection, default_single_query)
|
||||
connect.search(collection, **default_single_query)
|
||||
|
||||
def _test_release_collection_during_loading(self, connect, collection):
|
||||
"""
|
||||
|
@ -2955,7 +2952,7 @@ class TestReleaseAdvanced:
|
|||
t.start()
|
||||
connect.release_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.search(collection, default_single_query)
|
||||
connect.search(collection, **default_single_query)
|
||||
|
||||
def _test_release_partition_during_loading(self, connect, collection):
|
||||
"""
|
||||
|
@ -2973,7 +2970,7 @@ class TestReleaseAdvanced:
|
|||
t = threading.Thread(target=load, args=())
|
||||
t.start()
|
||||
connect.release_partitions(collection, [default_tag])
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection,**default_single_query)
|
||||
assert len(res[0]) == 0
|
||||
|
||||
def _test_release_collection_during_inserting(self, connect, collection):
|
||||
|
@ -2993,7 +2990,7 @@ class TestReleaseAdvanced:
|
|||
t.start()
|
||||
connect.release_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
# assert len(res[0]) == 0
|
||||
|
||||
def _test_release_collection_during_indexing(self, connect, collection):
|
||||
|
@ -3264,3 +3261,4 @@ class TestLoadPartitionInvalid(object):
|
|||
partition_name = get_partition_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.load_partitions(collection, [partition_name])
|
||||
|
||||
|
|
|
@ -5,13 +5,11 @@ from common.common_type import CaseLabel
|
|||
|
||||
DELETE_TIMEOUT = 60
|
||||
default_single_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"vector": {default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, default_dim),
|
||||
"metric_type": "L2", "params": {"nprobe": 10}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
"data": gen_vectors(1, default_dim),
|
||||
"anns_field": default_float_vec_field_name,
|
||||
"param": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
|
||||
class TestFlushBase:
|
||||
|
@ -177,7 +175,7 @@ class TestFlushBase:
|
|||
res = connect.get_collection_stats(collection)
|
||||
assert res["row_count"] == len(result.primary_keys)
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
logging.getLogger().debug(res)
|
||||
assert len(res) == 1
|
||||
assert len(res[0].ids) == 10
|
||||
|
@ -241,14 +239,14 @@ class TestFlushBase:
|
|||
connect.flush([collection])
|
||||
# query_vecs = [vectors[0], vectors[1], vectors[-1]]
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, default_single_query)
|
||||
res = connect.search(collection, **default_single_query)
|
||||
assert len(res) == 1
|
||||
assert len(res[0].ids) == 10
|
||||
assert len(res[0].distances) == 10
|
||||
logging.getLogger().debug(res)
|
||||
# assert res
|
||||
|
||||
# TODO: unable to set config
|
||||
# TODO: unable to set config
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def _test_collection_count_during_flush(self, connect, collection, args):
|
||||
"""
|
||||
|
@ -291,10 +289,10 @@ class TestFlushBase:
|
|||
connect.flush([collection])
|
||||
ids.extend(tmp.primary_keys)
|
||||
nq = 10000
|
||||
query, query_vecs = gen_query_vectors(default_float_vec_field_name, default_entities, default_top_k, nq)
|
||||
query, query_vecs = gen_search_vectors_params(default_float_vec_field_name, default_entities, default_top_k, nq)
|
||||
time.sleep(0.1)
|
||||
connect.load_collection(collection)
|
||||
future = connect.search(collection, query, _async=True)
|
||||
future = connect.search(collection, **query, _async=True)
|
||||
res = future.result()
|
||||
assert res
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
|
@ -411,3 +409,4 @@ class TestCollectionNameInvalid(object):
|
|||
connect.flush()
|
||||
except Exception as e:
|
||||
assert e.args[0] == "Collection name list can not be None or empty"
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ uid = "test_index"
|
|||
BUILD_TIMEOUT = 300
|
||||
field_name = default_float_vec_field_name
|
||||
binary_field_name = default_binary_vec_field_name
|
||||
query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1)
|
||||
# query = gen_search_vectors_params(field_name, default_entities, default_top_k, 1)
|
||||
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
|
||||
|
||||
|
||||
|
@ -262,6 +262,7 @@ class TestIndexOperation(TestcaseBase):
|
|||
"""
|
||||
pass
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_index_drop_index(self):
|
||||
"""
|
||||
|
@ -501,9 +502,9 @@ class TestIndexBase:
|
|||
nq = get_nq
|
||||
index_type = get_simple_index["index_type"]
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param)
|
||||
params, _ = gen_search_vectors_params(field_name, default_entities, default_top_k, nq, search_params=search_param)
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, query)
|
||||
res = connect.search(collection, **params)
|
||||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
|
@ -701,8 +702,9 @@ class TestIndexBase:
|
|||
nq = get_nq
|
||||
index_type = get_simple_index["index_type"]
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param)
|
||||
res = connect.search(collection, query)
|
||||
params, _ = gen_search_vectors_params(field_name, default_entities, default_top_k, nq,
|
||||
metric_type=metric_type, search_params=search_param)
|
||||
res = connect.search(collection, **params)
|
||||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
|
@ -1044,10 +1046,11 @@ class TestIndexBinary:
|
|||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
connect.load_collection(binary_collection)
|
||||
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
|
||||
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
|
||||
logging.getLogger().info(search_param)
|
||||
res = connect.search(binary_collection, query, search_params=search_param)
|
||||
params, _ = gen_search_vectors_params(binary_field_name, default_binary_entities, default_top_k, nq,
|
||||
search_params=search_param, metric_type="JACCARD")
|
||||
logging.getLogger().info(params)
|
||||
res = connect.search(binary_collection, **params)
|
||||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
|
@ -1282,4 +1285,4 @@ class TestIndexAsync:
|
|||
logging.getLogger().info("before result")
|
||||
res = future.result()
|
||||
# TODO:
|
||||
logging.getLogger().info(res)
|
||||
logging.getLogger().info(res)
|
||||
|
|
|
@ -36,9 +36,9 @@ class TestMixBase:
|
|||
index = connect.describe_index(collection, "")
|
||||
create_target_index(default_index, default_float_vec_field_name)
|
||||
assert index == default_index
|
||||
query, vecs = gen_query_vectors(default_float_vec_field_name, entities, default_top_k, nq)
|
||||
query, vecs = gen_search_vectors_params(default_float_vec_field_name, entities, default_top_k, nq)
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, query)
|
||||
res = connect.search(collection, **query)
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
assert res[0]._distances[0] <= epsilon
|
||||
|
@ -198,3 +198,4 @@ def check_id_result(result, id):
|
|||
return id in ids[:limit_in]
|
||||
else:
|
||||
return id in ids
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -381,8 +381,8 @@ def assert_equal_entity(a, b):
|
|||
pass
|
||||
|
||||
|
||||
def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
|
||||
metric_type="L2", replace_vecs=None):
|
||||
def gen_search_vectors_params(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
|
||||
metric_type="L2", replace_vecs=None):
|
||||
if rand_vector is True:
|
||||
dimension = len(entities[-1]["values"][0])
|
||||
query_vectors = gen_vectors(nq, dimension)
|
||||
|
@ -390,14 +390,15 @@ def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe":
|
|||
query_vectors = entities[-1]["values"][:nq]
|
||||
if replace_vecs:
|
||||
query_vectors = replace_vecs
|
||||
must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
|
||||
must_param["vector"][field_name]["metric_type"] = metric_type
|
||||
query = {
|
||||
"bool": {
|
||||
"must": [must_param]
|
||||
}
|
||||
|
||||
search_params["metric_type"] = metric_type
|
||||
_params = {
|
||||
"data": query_vectors,
|
||||
"anns_field": field_name,
|
||||
"param": search_params,
|
||||
"limit": top_k,
|
||||
}
|
||||
return query, query_vectors
|
||||
return _params, query_vectors
|
||||
|
||||
|
||||
def update_query_expr(src_query, keep_old=True, expr=None):
|
||||
|
@ -884,13 +885,17 @@ def gen_normal_expressions():
|
|||
def get_search_param(index_type, metric_type="L2"):
|
||||
search_params = {"metric_type": metric_type}
|
||||
if index_type in ivf() or index_type in binary_support():
|
||||
search_params.update({"nprobe": 64})
|
||||
nprobe64 = {"nprobe": 64}
|
||||
search_params.update({"params": nprobe64})
|
||||
elif index_type in ["HNSW", "RHNSW_FLAT", "RHNSW_SQ", "RHNSW_PQ"]:
|
||||
search_params.update({"ef": 64})
|
||||
ef64 = {"ef": 64}
|
||||
search_params.update({"params": ef64})
|
||||
elif index_type == "NSG":
|
||||
search_params.update({"search_length": 100})
|
||||
length100 = {"search_length": 100}
|
||||
search_params.update({"params": length100})
|
||||
elif index_type == "ANNOY":
|
||||
search_params.update({"search_k": 1000})
|
||||
search_k = {"search_k": 1000}
|
||||
search_params.update({"params": search_k})
|
||||
else:
|
||||
logging.getLogger().error("Invalid index_type.")
|
||||
raise Exception("Invalid index_type.")
|
||||
|
@ -1019,3 +1024,4 @@ class MyThread(threading.Thread):
|
|||
super(MyThread, self).join()
|
||||
if self.exc:
|
||||
raise self.exc
|
||||
|
||||
|
|
Loading…
Reference in New Issue