Add test case and function check about search/query output (#24117)

Signed-off-by: nico <cheng.yuan@zilliz.com>
pull/24152/head
nico 2023-05-16 21:43:23 +08:00 committed by GitHub
parent 6965495b9d
commit a037f36891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 115 additions and 102 deletions

View File

@ -239,9 +239,7 @@ class ResponseChecker:
search_res.done() search_res.done()
search_res = search_res.result() search_res = search_res.result()
if check_items.get("output_fields", None): if check_items.get("output_fields", None):
for field in check_items['output_fields']: assert set(search_res[0][0].entity.fields) == set(check_items["output_fields"])
assert field in search_res[0][0].entity._raw_data
assert len(check_items['output_fields']) == len(search_res[0][0].entity._raw_data)
log.info('search_results_check: Output fields of query searched is correct') log.info('search_results_check: Output fields of query searched is correct')
if len(search_res) != check_items["nq"]: if len(search_res) != check_items["nq"]:
log.error("search_results_check: Numbers of query searched (%d) " log.error("search_results_check: Numbers of query searched (%d) "

View File

@ -920,3 +920,21 @@ def install_milvus_operator_specific_config(namespace, milvus_mode, release_name
raise MilvusException(message=f'Milvus healthy timeout 1800s') raise MilvusException(message=f'Milvus healthy timeout 1800s')
return host return host
def get_wildcard_output_field_names(collection_w, output_fields):
all_fields = collection_w.schema.fields
scalar_fields = []
vector_field = []
for field in all_fields:
if field.dtype == DataType.FLOAT_VECTOR:
vector_field.append(field.name)
else:
scalar_fields.append(field.name)
if "*" in output_fields:
output_fields.remove("*")
output_fields.extend(scalar_fields)
if "%" in output_fields:
output_fields.remove("%")
output_fields.extend(vector_field)
return output_fields

View File

@ -598,6 +598,28 @@ class TestQueryParams(TestcaseBase):
check_task=CheckTasks.check_query_results, check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True}) check_items={exp_res: res, "with_vec": True})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name],
["*", default_int_field_name],
["%"], ["%", default_float_field_name], ["*", "%"]])
def test_query_output_field_wildcard(self, wildcard_output_fields):
"""
target: test query with output fields using wildcard
method: query with one output_field (wildcard)
expected: query success
"""
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
df = cf.gen_default_dataframe_data()
collection_w.insert(df)
assert collection_w.num_entities == ct.default_nb
output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
output_fields.append(default_int_field_name)
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load()
with_vec = True if ct.default_float_vec_field_name in output_fields else False
actual_res = collection_w.query(default_term_expr, output_fields=wildcard_output_fields)[0]
assert set(actual_res[0].keys()) == set(output_fields)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/12680") @pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/12680")
@pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")]]) @pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")]])

View File

@ -2823,7 +2823,7 @@ class TestCollectionSearch(TestcaseBase):
# 2. search # 2. search
log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name) log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name)
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
res = collection_w.search(vectors[:nq], default_search_field, collection_w.search(vectors[:nq], default_search_field,
default_search_params, default_limit, default_search_params, default_limit,
default_search_exp, _async=_async, default_search_exp, _async=_async,
output_fields=[], output_fields=[],
@ -2831,12 +2831,8 @@ class TestCollectionSearch(TestcaseBase):
check_items={"nq": nq, check_items={"nq": nq,
"ids": insert_ids, "ids": insert_ids,
"limit": default_limit, "limit": default_limit,
"_async": _async})[0] "_async": _async,
if _async: "output_fields": []})
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) == 0
assert res[0][0].entity.fields == []
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_search_with_output_field(self, auto_id, _async): def test_search_with_output_field(self, auto_id, _async):
@ -2859,12 +2855,8 @@ class TestCollectionSearch(TestcaseBase):
check_items={"nq": default_nq, check_items={"nq": default_nq,
"ids": insert_ids, "ids": insert_ids,
"limit": default_limit, "limit": default_limit,
"_async": _async})[0] "_async": _async,
if _async: "output_fields": [default_int64_field_name]})[0]
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert default_int64_field_name in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_search_with_output_vector_field(self, auto_id, _async): def test_search_with_output_vector_field(self, auto_id, _async):
@ -2878,7 +2870,7 @@ class TestCollectionSearch(TestcaseBase):
auto_id=auto_id)[0:4] auto_id=auto_id)[0:4]
# 2. search # 2. search
log.info("test_search_with_output_field: Searching collection %s" % collection_w.name) log.info("test_search_with_output_field: Searching collection %s" % collection_w.name)
res = collection_w.search(vectors[:default_nq], default_search_field, collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit, default_search_params, default_limit,
default_search_exp, _async=_async, default_search_exp, _async=_async,
output_fields=[field_name], output_fields=[field_name],
@ -2886,12 +2878,8 @@ class TestCollectionSearch(TestcaseBase):
check_items={"nq": default_nq, check_items={"nq": default_nq,
"ids": insert_ids, "ids": insert_ids,
"limit": default_limit, "limit": default_limit,
"_async": _async})[0] "_async": _async,
if _async: "output_fields": [field_name]})
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert field_name in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async): def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async):
@ -2908,21 +2896,17 @@ class TestCollectionSearch(TestcaseBase):
# 2. search # 2. search
log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name) log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name)
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
res = collection_w.search(vectors[:nq], default_search_field, output_fields = [default_int64_field_name, default_float_field_name]
collection_w.search(vectors[:nq], default_search_field,
default_search_params, default_limit, default_search_params, default_limit,
default_search_exp, _async=_async, default_search_exp, _async=_async,
output_fields=[default_int64_field_name, output_fields=output_fields,
default_float_field_name],
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": nq, check_items={"nq": nq,
"ids": insert_ids, "ids": insert_ids,
"limit": default_limit, "limit": default_limit,
"_async": _async})[0] "_async": _async,
if _async: "output_fields": output_fields})
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("index, params", @pytest.mark.parametrize("index, params",
@ -2955,7 +2939,8 @@ class TestCollectionSearch(TestcaseBase):
output_fields=[field_name], output_fields=[field_name],
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": 1, check_items={"nq": 1,
"limit": default_limit})[0] "limit": default_limit,
"output_fields": [field_name]})[0]
# 4. check the result vectors should be equal to the inserted # 4. check the result vectors should be equal to the inserted
for _id in range(default_limit): for _id in range(default_limit):
@ -2965,6 +2950,8 @@ class TestCollectionSearch(TestcaseBase):
if vectorInsert != vectorRes: if vectorInsert != vectorRes:
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP') getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000')) vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
log.info(data[field_name][res[0][_id].id][i])
log.info(res[0][_id].entity.float_vector[i])
assert float(str(vectorInsert)) == float(vectorRes) assert float(str(vectorInsert)) == float(vectorRes)
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
@ -3050,19 +3037,14 @@ class TestCollectionSearch(TestcaseBase):
collection_w = self.init_collection_general(prefix, True)[0] collection_w = self.init_collection_general(prefix, True)[0]
# 2. search with output field vector # 2. search with output field vector
res = collection_w.search(vectors[:1], default_search_field, output_fields = [default_float_field_name, default_string_field_name, default_search_field]
collection_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp, default_search_params, default_limit, default_search_exp,
output_fields=[default_float_field_name, output_fields=output_fields,
default_string_field_name,
default_search_field],
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": 1, check_items={"nq": 1,
"limit": default_limit})[0] "limit": default_limit,
"output_fields": output_fields})
# 3. check the result
assert default_float_field_name, default_string_field_name in res[0][0].entity._row_data
assert default_search_field in res[0][0].entity._row_data
assert len(res[0][0].entity._row_data) == 3
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_search_output_vector_field_and_pk_field(self): def test_search_output_vector_field_and_pk_field(self):
@ -3078,17 +3060,13 @@ class TestCollectionSearch(TestcaseBase):
# 2. search with output field vector # 2. search with output field vector
output_fields = [default_int64_field_name, default_string_field_name, default_search_field] output_fields = [default_int64_field_name, default_string_field_name, default_search_field]
res = collection_w.search(vectors[:1], default_search_field, collection_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp, default_search_params, default_limit, default_search_exp,
output_fields=output_fields, output_fields=output_fields,
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": 1, check_items={"nq": 1,
"limit": default_limit})[0] "limit": default_limit,
"output_fields": output_fields})
# 3. check the result variables
assert default_int64_field_name, default_string_field_name in res[0][0].entity._row_data
assert default_search_field in res[0][0].entity._row_data
assert len(res[0][0].entity._row_data) == 3
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_search_output_field_vector_with_partition(self): def test_search_output_field_vector_with_partition(self):
@ -3129,8 +3107,9 @@ class TestCollectionSearch(TestcaseBase):
assert str(vectorInsert) == vectorRes assert str(vectorInsert) == vectorRes
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [["*"], ["*", default_float_field_name]]) @pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name], ["*", default_search_field],
def test_search_with_output_field_wildcard(self, output_fields, auto_id, _async): ["%"], ["%", default_float_field_name], ["*", "%"]])
def test_search_with_output_field_wildcard(self, wildcard_output_fields, auto_id, _async):
""" """
target: test search with output fields using wildcard target: test search with output fields using wildcard
method: search with one output_field (wildcard) method: search with one output_field (wildcard)
@ -3141,21 +3120,17 @@ class TestCollectionSearch(TestcaseBase):
auto_id=auto_id)[0:4] auto_id=auto_id)[0:4]
# 2. search # 2. search
log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name) log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name)
output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
res = collection_w.search(vectors[:default_nq], default_search_field, collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit, default_search_params, default_limit,
default_search_exp, _async=_async, default_search_exp, _async=_async,
output_fields=output_fields, output_fields=wildcard_output_fields,
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq, check_items={"nq": default_nq,
"ids": insert_ids, "ids": insert_ids,
"limit": default_limit, "limit": default_limit,
"_async": _async})[0] "_async": _async,
if _async: "output_fields": output_fields})
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_search_multi_collections(self, nb, nq, dim, auto_id, _async): def test_search_multi_collections(self, nb, nq, dim, auto_id, _async):