Update tests to match removal of dsl search (#9622)

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
yanliang567 2021-10-11 15:32:55 +08:00 committed by GitHub
parent 8c510a52b2
commit 29af477c5e
8 changed files with 158 additions and 795 deletions
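
For context: this change tracks pymilvus dropping the DSL-style search payload in favor of explicit keyword arguments. A minimal before/after sketch distilled from the hunks below (`connect` stands for the test suite's Milvus handle; the dimension and names are illustrative):

import random

def gen_vectors(nq, dim):
    # Stand-in for the test utils' random-vector generator.
    return [[random.random() for _ in range(dim)] for _ in range(nq)]

field_name, top_k, dim = "float_vector", 10, 128

# Old DSL payload (removed upstream), previously passed as one positional dict:
#   connect.search(collection_name, {"bool": {"must": [{"vector": {field_name: {
#       "topk": top_k, "query": gen_vectors(1, dim),
#       "metric_type": "L2", "params": {"nprobe": 10}}}}]}})

# New explicit parameters; the tests now build this dict once and unpack it:
default_single_query = {
    "data": gen_vectors(1, dim),
    "anns_field": field_name,
    "param": {"metric_type": "L2", "params": {"nprobe": 10}},
    "limit": top_k,
}
# res = connect.search(collection_name, **default_single_query)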

@@ -13,7 +13,7 @@ pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.2.1
# pytest-parallel
-pymilvus==2.0.0rc7.dev22
+pymilvus==2.0.0rc7.dev25
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient

@@ -17,13 +17,11 @@ default_nb = ut.default_nb
row_count = ut.row_count
default_tag = ut.default_tag
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": ut.gen_vectors(1, ut.default_dim), "metric_type": "L2",
"params": {"nprobe": 10}}}}
]
}
}
"data": ut.gen_vectors(1, ut.default_dim),
"anns_field": ut.default_float_vec_field_name,
"param": {"metric_type": "L2", "params": {"nprobe": 10}},
"limit": 10,
}
class TestInsertBase:
@@ -173,7 +171,7 @@ class TestInsertBase:
result = connect.insert(collection, default_entities)
connect.flush([collection])
connect.load_collection(collection)
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
assert len(res[0]) == ut.default_top_k
@pytest.mark.tags(CaseLabel.L2)
@@ -673,10 +671,10 @@ class TestInsertBinary:
"""
result = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
-query, vecs = ut.gen_query_vectors(binary_field_name, default_binary_entities,
-ut.default_top_k, 1, metric_type="JACCARD")
+query, _ = ut.gen_search_vectors_params(binary_field_name, default_binary_entities,
+ut.default_top_k, 1, metric_type="JACCARD")
connect.load_collection(binary_collection)
-res = connect.search(binary_collection, query)
+res = connect.search(binary_collection, **query)
logging.getLogger().debug(res)
assert len(res[0]) == ut.default_top_k
@@ -920,7 +918,7 @@ class TestInsertMultiCollections:
collection_name = ut.gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.load_collection(collection)
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
assert len(res[0]) == 0
connect.insert(collection_name, default_entity)
connect.flush([collection_name])
@@ -940,7 +938,7 @@ class TestInsertMultiCollections:
result = connect.insert(collection, default_entity)
connect.flush([collection])
connect.load_collection(collection_name)
-res = connect.search(collection_name, default_single_query)
+res = connect.search(collection_name, **default_single_query)
stats = connect.get_collection_stats(collection)
assert stats[row_count] == 1
@@ -957,7 +955,7 @@ class TestInsertMultiCollections:
result = connect.insert(collection, default_entity)
connect.flush([collection])
connect.load_collection(collection_name)
-res = connect.search(collection_name, default_single_query)
+res = connect.search(collection_name, **default_single_query)
assert len(res[0]) == 0
@pytest.mark.timeout(ADD_TIMEOUT)
@@ -1238,3 +1236,4 @@ class TestInsertInvalidBinary(object):
src_vector[1] = get_field_vectors_value
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entities)
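
The binary-vector cases above follow the same keyword pattern, only with a JACCARD metric. A sketch of the kwargs the renamed helper hands to `search()` there (the vector bytes are placeholders; the helper itself appears in the utils hunk at the end of this diff):

binary_search_kwargs = {
    "data": [bytes(16)],           # one 128-bit binary query vector (placeholder value)
    "anns_field": "binary_vector", # illustrative field name
    "param": {"nprobe": 10, "metric_type": "JACCARD"},
    "limit": 10,                   # ut.default_top_k in the tests above
}
# res = connect.search(binary_collection, **binary_search_kwargs)
# assert len(res[0]) == 10        # top_k hits back for the single query vector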

@@ -29,14 +29,11 @@ uid_list = "list_collections"
uid_load = "load_collection"
field_name = default_float_vec_field_name
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": default_top_k, "query": gen_vectors(1, default_dim), "metric_type": "L2",
"params": {"nprobe": 10}}}}
]
}
}
"data": gen_vectors(1, default_dim),
"anns_field": default_float_vec_field_name,
"param": {"metric_type": "L2", "params": {"nprobe": 10}},
"limit": default_top_k,
}
class TestCollectionParams(TestcaseBase):
""" Test case of collection interface """
@@ -2841,8 +2838,8 @@ class TestLoadCollection:
connect.load_collection(collection)
connect.release_partitions(collection, [default_tag])
with pytest.raises(Exception) as e:
-connect.search(collection, default_single_query, partition_names=[default_tag])
-res = connect.search(collection, default_single_query, partition_names=[default_partition_name])
+connect.search(collection, **default_single_query, partition_names=[default_tag])
+res = connect.search(collection, **default_single_query, partition_names=[default_partition_name])
assert len(res[0]) == default_top_k
@pytest.mark.tags(CaseLabel.L2)
@@ -2860,7 +2857,7 @@ class TestLoadCollection:
connect.flush([collection])
connect.load_collection(collection)
connect.release_partitions(collection, [default_partition_name, default_tag])
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
assert len(res[0]) == 0
@pytest.mark.tags(CaseLabel.L0)
@@ -2877,7 +2874,7 @@ class TestLoadCollection:
connect.load_partitions(collection, [default_tag])
connect.release_collection(collection)
with pytest.raises(Exception):
-connect.search(collection, default_single_query)
+connect.search(collection, **default_single_query)
# assert len(res[0]) == 0
@@ -2895,11 +2892,11 @@ class TestReleaseAdvanced:
connect.insert(collection, cons.default_entities)
connect.flush([collection])
connect.load_collection(collection)
-query, _ = gen_query_vectors(field_name, cons.default_entities, top_k, nq)
-future = connect.search(collection, query, _async=True)
+params, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k, nq)
+future = connect.search(collection, **params, _async=True)
connect.release_collection(collection)
with pytest.raises(Exception):
-connect.search(collection, default_single_query)
+connect.search(collection, **default_single_query)
@pytest.mark.tags(CaseLabel.L2)
def test_release_partition_during_searching(self, connect, collection):
@@ -2911,14 +2908,14 @@ class TestReleaseAdvanced:
nq = 1000
top_k = 1
connect.create_partition(collection, default_tag)
-query, _ = gen_query_vectors(field_name, cons.default_entities, top_k, nq)
+query, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k, nq)
connect.insert(collection, cons.default_entities, partition_name=default_tag)
connect.flush([collection])
connect.load_partitions(collection, [default_tag])
-res = connect.search(collection, query, _async=True)
+res = connect.search(collection, **query, _async=True)
connect.release_partitions(collection, [default_tag])
with pytest.raises(Exception) as e:
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
@pytest.mark.tags(CaseLabel.L0)
def test_release_collection_during_searching_A(self, connect, collection):
@@ -2930,14 +2927,14 @@ class TestReleaseAdvanced:
nq = 1000
top_k = 1
connect.create_partition(collection, default_tag)
-query, _ = gen_query_vectors(field_name, cons.default_entities, top_k, nq)
+query, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k, nq)
connect.insert(collection, cons.default_entities, partition_name=default_tag)
connect.flush([collection])
connect.load_partitions(collection, [default_tag])
-res = connect.search(collection, query, _async=True)
+res = connect.search(collection, **query, _async=True)
connect.release_collection(collection)
with pytest.raises(Exception):
-connect.search(collection, default_single_query)
+connect.search(collection, **default_single_query)
def _test_release_collection_during_loading(self, connect, collection):
"""
@@ -2955,7 +2952,7 @@ class TestReleaseAdvanced:
t.start()
connect.release_collection(collection)
with pytest.raises(Exception):
-connect.search(collection, default_single_query)
+connect.search(collection, **default_single_query)
def _test_release_partition_during_loading(self, connect, collection):
"""
@@ -2973,7 +2970,7 @@ class TestReleaseAdvanced:
t = threading.Thread(target=load, args=())
t.start()
connect.release_partitions(collection, [default_tag])
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
assert len(res[0]) == 0
def _test_release_collection_during_inserting(self, connect, collection):
@@ -2993,7 +2990,7 @@ class TestReleaseAdvanced:
t.start()
connect.release_collection(collection)
with pytest.raises(Exception):
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
# assert len(res[0]) == 0
def _test_release_collection_during_indexing(self, connect, collection):
@ -3264,3 +3261,4 @@ class TestLoadPartitionInvalid(object):
partition_name = get_partition_name
with pytest.raises(Exception) as e:
connect.load_partitions(collection, [partition_name])
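
The `TestReleaseAdvanced` cases above lean on two behaviors of the new signature: with `_async=True`, `search()` returns a future instead of results, and searching after `release_collection` is expected to raise. A condensed sketch of that pattern, assuming the suite's `connect`/`collection` fixtures and the renamed helper:

params, _ = gen_search_vectors_params(field_name, cons.default_entities, top_k=1, nq=1000)
future = connect.search(collection, **params, _async=True)  # returns immediately
connect.release_collection(collection)                      # races the in-flight search
try:
    connect.search(collection, **params)                    # released -> should raise
except Exception as exc:
    print("search after release failed as expected:", exc)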

@@ -5,13 +5,11 @@ from common.common_type import CaseLabel
DELETE_TIMEOUT = 60
default_single_query = {
"bool": {
"must": [
{"vector": {default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, default_dim),
"metric_type": "L2", "params": {"nprobe": 10}}}}
]
}
}
"data": gen_vectors(1, default_dim),
"anns_field": default_float_vec_field_name,
"param": {"metric_type": "L2", "params": {"nprobe": 10}},
"limit": 10,
}
class TestFlushBase:
@@ -177,7 +175,7 @@ class TestFlushBase:
res = connect.get_collection_stats(collection)
assert res["row_count"] == len(result.primary_keys)
connect.load_collection(collection)
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
logging.getLogger().debug(res)
assert len(res) == 1
assert len(res[0].ids) == 10
@@ -241,14 +239,14 @@ class TestFlushBase:
connect.flush([collection])
# query_vecs = [vectors[0], vectors[1], vectors[-1]]
connect.load_collection(collection)
-res = connect.search(collection, default_single_query)
+res = connect.search(collection, **default_single_query)
assert len(res) == 1
assert len(res[0].ids) == 10
assert len(res[0].distances) == 10
logging.getLogger().debug(res)
# assert res
# TODO: unable to set config
# TODO: unable to set config
@pytest.mark.tags(CaseLabel.L2)
def _test_collection_count_during_flush(self, connect, collection, args):
"""
@@ -291,10 +289,10 @@ class TestFlushBase:
connect.flush([collection])
ids.extend(tmp.primary_keys)
nq = 10000
-query, query_vecs = gen_query_vectors(default_float_vec_field_name, default_entities, default_top_k, nq)
+query, query_vecs = gen_search_vectors_params(default_float_vec_field_name, default_entities, default_top_k, nq)
time.sleep(0.1)
connect.load_collection(collection)
-future = connect.search(collection, query, _async=True)
+future = connect.search(collection, **query, _async=True)
res = future.result()
assert res
delete_ids = [ids[0], ids[-1]]
@@ -411,3 +409,4 @@ class TestCollectionNameInvalid(object):
connect.flush()
except Exception as e:
assert e.args[0] == "Collection name list can not be None or empty"

@@ -22,7 +22,7 @@ uid = "test_index"
BUILD_TIMEOUT = 300
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
-query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1)
+# query = gen_search_vectors_params(field_name, default_entities, default_top_k, 1)
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
@@ -262,6 +262,7 @@ class TestIndexOperation(TestcaseBase):
"""
pass
@pytest.mark.tags(CaseLabel.L1)
+@pytest.mark.tags(CaseLabel.L1)
def test_index_drop_index(self):
"""
@@ -501,9 +502,9 @@ class TestIndexBase:
nq = get_nq
index_type = get_simple_index["index_type"]
search_param = get_search_param(index_type)
-query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param)
+params, _ = gen_search_vectors_params(field_name, default_entities, default_top_k, nq, search_params=search_param)
connect.load_collection(collection)
-res = connect.search(collection, query)
+res = connect.search(collection, **params)
assert len(res) == nq
@pytest.mark.timeout(BUILD_TIMEOUT)
@@ -701,8 +702,9 @@ class TestIndexBase:
nq = get_nq
index_type = get_simple_index["index_type"]
search_param = get_search_param(index_type)
-query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param)
-res = connect.search(collection, query)
+params, _ = gen_search_vectors_params(field_name, default_entities, default_top_k, nq,
+metric_type=metric_type, search_params=search_param)
+res = connect.search(collection, **params)
assert len(res) == nq
@pytest.mark.timeout(BUILD_TIMEOUT)
@@ -1044,10 +1046,11 @@ class TestIndexBinary:
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
connect.load_collection(binary_collection)
-query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
-logging.getLogger().info(search_param)
-res = connect.search(binary_collection, query, search_params=search_param)
+params, _ = gen_search_vectors_params(binary_field_name, default_binary_entities, default_top_k, nq,
+search_params=search_param, metric_type="JACCARD")
+logging.getLogger().info(params)
+res = connect.search(binary_collection, **params)
assert len(res) == nq
@pytest.mark.timeout(BUILD_TIMEOUT)
@@ -1282,4 +1285,4 @@ class TestIndexAsync:
logging.getLogger().info("before result")
res = future.result()
# TODO:
logging.getLogger().info(res)
logging.getLogger().info(res)

@@ -36,9 +36,9 @@ class TestMixBase:
index = connect.describe_index(collection, "")
create_target_index(default_index, default_float_vec_field_name)
assert index == default_index
-query, vecs = gen_query_vectors(default_float_vec_field_name, entities, default_top_k, nq)
+query, vecs = gen_search_vectors_params(default_float_vec_field_name, entities, default_top_k, nq)
connect.load_collection(collection)
-res = connect.search(collection, query)
+res = connect.search(collection, **query)
assert len(res) == nq
assert len(res[0]) == default_top_k
assert res[0]._distances[0] <= epsilon
@@ -198,3 +198,4 @@ def check_id_result(result, id):
return id in ids[:limit_in]
else:
return id in ids

File diff suppressed because it is too large

@@ -381,8 +381,8 @@ def assert_equal_entity(a, b):
pass
-def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
-metric_type="L2", replace_vecs=None):
+def gen_search_vectors_params(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
+metric_type="L2", replace_vecs=None):
if rand_vector is True:
dimension = len(entities[-1]["values"][0])
query_vectors = gen_vectors(nq, dimension)
@@ -390,14 +390,15 @@ def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe":
query_vectors = entities[-1]["values"][:nq]
if replace_vecs:
query_vectors = replace_vecs
-must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
-must_param["vector"][field_name]["metric_type"] = metric_type
-query = {
-"bool": {
-"must": [must_param]
-}
+search_params["metric_type"] = metric_type
+_params = {
+"data": query_vectors,
+"anns_field": field_name,
+"param": search_params,
+"limit": top_k,
}
-return query, query_vectors
+return _params, query_vectors
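
For reference, the renamed helper returns a dict that can be splatted straight into `search()`. A small usage sketch (the entity layout is illustrative, matching the `entities[-1]["values"]` convention the function reads; assumes this utils module is importable):

from utils import gen_search_vectors_params

entities = [{"values": [[0.1, 0.2, 0.3, 0.4]]}]   # last field carries the vectors
params, vecs = gen_search_vectors_params("float_vector", entities, top_k=10, nq=1)
# params == {"data": [[0.1, 0.2, 0.3, 0.4]],
#            "anns_field": "float_vector",
#            "param": {"nprobe": 10, "metric_type": "L2"},
#            "limit": 10}
# connect.search(collection_name, **params)

One caveat: `search_params={"nprobe": 10}` is a mutable default, and the function now writes `metric_type` into it, so the shared default dict is mutated across calls; callers passing their own `search_params` (as the index tests do) are unaffected.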
def update_query_expr(src_query, keep_old=True, expr=None):
@@ -884,13 +885,17 @@ def gen_normal_expressions():
def get_search_param(index_type, metric_type="L2"):
search_params = {"metric_type": metric_type}
if index_type in ivf() or index_type in binary_support():
search_params.update({"nprobe": 64})
nprobe64 = {"nprobe": 64}
search_params.update({"params": nprobe64})
elif index_type in ["HNSW", "RHNSW_FLAT", "RHNSW_SQ", "RHNSW_PQ"]:
search_params.update({"ef": 64})
ef64 = {"ef": 64}
search_params.update({"params": ef64})
elif index_type == "NSG":
search_params.update({"search_length": 100})
length100 = {"search_length": 100}
search_params.update({"params": length100})
elif index_type == "ANNOY":
search_params.update({"search_k": 1000})
search_k = {"search_k": 1000}
search_params.update({"params": search_k})
else:
logging.getLogger().error("Invalid index_type.")
raise Exception("Invalid index_type.")
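
With this change `get_search_param` nests the per-index knobs under a `params` key, so its result plugs directly into the `param` argument of the new `search()`. Expected shapes, assuming the function as rewritten above is importable:

from utils import get_search_param

assert get_search_param("IVF_FLAT") == {"metric_type": "L2", "params": {"nprobe": 64}}
assert get_search_param("HNSW", metric_type="IP") == {"metric_type": "IP", "params": {"ef": 64}}
assert get_search_param("ANNOY") == {"metric_type": "L2", "params": {"search_k": 1000}}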
@@ -1019,3 +1024,4 @@ class MyThread(threading.Thread):
super(MyThread, self).join()
if self.exc:
raise self.exc