Test search with round decimal (#11211)

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
pull/11291/head
ThreadDao 2021-11-05 10:07:04 +08:00 committed by GitHub
parent f24261450f
commit 93d869a39e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 6 deletions

View File

@ -103,13 +103,13 @@ class ApiCollectionWrapper:
return res, check_result
def search(self, data, anns_field, param, limit, expr=None,
partition_names=None, output_fields=None, timeout=None,
partition_names=None, output_fields=None, timeout=None, round_decimal=-1,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.search, data, anns_field, param, limit,
expr, partition_names, output_fields, timeout], **kwargs)
expr, partition_names, output_fields, timeout, round_decimal], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
data=data, anns_field=anns_field, param=param, limit=limit,
expr=expr, partition_names=partition_names,

View File

@ -37,6 +37,8 @@ entity = gen_entities(1, is_normal=True)
entities = gen_entities(default_nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(default_nb)
default_query, _ = gen_search_vectors_params(field_name, entities, default_top_k, nq)
# default_binary_query, _ = gen_search_vectors_params(binary_field_name, binary_entities, default_top_k, nq)
@ -717,6 +719,25 @@ class TestCollectionSearchInvalid(TestcaseBase):
check_items={"err_code": 1,
"err_msg": "`travel_timestamp` value %s is illegal" % invalid_travel_time})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("round_decimal", [7, -2, 999, 1.0, None, [1], "string", {}])
def test_search_invalid_round_decimal(self, round_decimal):
"""
target: test search with invalid round decimal
method: search with invalid round decimal
expected: raise exception and report the error
"""
# 1. initialize with data
collection_w = self.init_collection_general(prefix, True, nb=10)[0]
# 2. search
log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
default_search_exp, round_decimal=round_decimal,
check_task=CheckTasks.err_res,
check_items={"err_code": 1,
"err_msg": f"`round_decimal` value {round_decimal} is illegal"})
class TestCollectionSearch(TestcaseBase):
""" Test case of search interface """
@ -765,7 +786,7 @@ class TestCollectionSearch(TestcaseBase):
collection_w.search(vectors[:nq], default_search_field,
default_search_params, default_limit,
default_search_exp,
travel_timestamp=time_stamp-1,
travel_timestamp=time_stamp - 1,
check_task=CheckTasks.check_search_results,
check_items={"nq": nq,
"ids": [],
@ -1173,7 +1194,7 @@ class TestCollectionSearch(TestcaseBase):
collection_w.search(vectors[:nq], default_search_field,
default_search_params, limit,
default_search_exp, _async=_async,
travel_timestamp=time_stamp+1,
travel_timestamp=time_stamp + 1,
check_task=CheckTasks.check_search_results,
check_items={"nq": nq,
"ids": insert_ids,
@ -2076,6 +2097,37 @@ class TestCollectionSearch(TestcaseBase):
for t in threads:
t.join()
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("round_decimal", [0, 1, 2, 3, 4, 5, 6])
def test_search_round_decimal(self, round_decimal):
"""
target: test search with invalid round decimal
method: search with invalid round decimal
expected: raise exception and report the error
"""
import math
tmp_nb = 500
tmp_nq = 1
tmp_limit = 5
# 1. initialize with data
collection_w = self.init_collection_general(prefix, True, nb=tmp_nb)[0]
# 2. search
log.info("test_search_round_decimal: Searching collection %s" % collection_w.name)
res, _ = collection_w.search(vectors[:tmp_nq], default_search_field,
default_search_params, tmp_limit)
res_round, _ = collection_w.search(vectors[:tmp_nq], default_search_field,
default_search_params, tmp_limit, round_decimal=round_decimal)
abs_tol = pow(10, 1 - round_decimal)
# log.debug(f'abs_tol: {abs_tol}')
for i in range(tmp_limit):
dis_expect = round(res[0][i].distance, round_decimal)
dis_actual = res_round[0][i].distance
# log.debug(f'actual: {dis_actual}, expect: {dis_expect}')
# abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
assert math.isclose(dis_actual, dis_expect, rel_tol=0, abs_tol=abs_tol)
"""
******************************************************************
@ -2216,7 +2268,7 @@ class TestSearchBase:
method: search with the given vectors, check the result
expected: the length of the result is top_k
"""
top_k = 16385 # max top k is 16384
top_k = 16385 # max top k is 16384
nq = get_nq
entities, ids = init_data(connect, collection)
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq)
@ -2465,7 +2517,8 @@ class TestSearchBase:
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP",
search_params=search_param)
connect.load_collection(collection)
res = connect.search(collection, **query)
assert check_id_result(res[0], ids[0])