Add test case and function check about search/query output (#24117)

Signed-off-by: nico <cheng.yuan@zilliz.com>
pull/24152/head
nico 2023-05-16 21:43:23 +08:00 committed by GitHub
parent 6965495b9d
commit a037f36891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 115 additions and 102 deletions

View File

@ -239,9 +239,7 @@ class ResponseChecker:
search_res.done()
search_res = search_res.result()
if check_items.get("output_fields", None):
for field in check_items['output_fields']:
assert field in search_res[0][0].entity._raw_data
assert len(check_items['output_fields']) == len(search_res[0][0].entity._raw_data)
assert set(search_res[0][0].entity.fields) == set(check_items["output_fields"])
log.info('search_results_check: Output fields of query searched is correct')
if len(search_res) != check_items["nq"]:
log.error("search_results_check: Numbers of query searched (%d) "

View File

@ -920,3 +920,21 @@ def install_milvus_operator_specific_config(namespace, milvus_mode, release_name
raise MilvusException(message=f'Milvus healthy timeout 1800s')
return host
def get_wildcard_output_field_names(collection_w, output_fields):
all_fields = collection_w.schema.fields
scalar_fields = []
vector_field = []
for field in all_fields:
if field.dtype == DataType.FLOAT_VECTOR:
vector_field.append(field.name)
else:
scalar_fields.append(field.name)
if "*" in output_fields:
output_fields.remove("*")
output_fields.extend(scalar_fields)
if "%" in output_fields:
output_fields.remove("%")
output_fields.extend(vector_field)
return output_fields

View File

@ -598,6 +598,28 @@ class TestQueryParams(TestcaseBase):
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name],
["*", default_int_field_name],
["%"], ["%", default_float_field_name], ["*", "%"]])
def test_query_output_field_wildcard(self, wildcard_output_fields):
"""
target: test query with output fields using wildcard
method: query with one output_field (wildcard)
expected: query success
"""
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
df = cf.gen_default_dataframe_data()
collection_w.insert(df)
assert collection_w.num_entities == ct.default_nb
output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
output_fields.append(default_int_field_name)
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load()
with_vec = True if ct.default_float_vec_field_name in output_fields else False
actual_res = collection_w.query(default_term_expr, output_fields=wildcard_output_fields)[0]
assert set(actual_res[0].keys()) == set(output_fields)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/12680")
@pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")]])

View File

@ -2823,7 +2823,7 @@ class TestCollectionSearch(TestcaseBase):
# 2. search
log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name)
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
res = collection_w.search(vectors[:nq], default_search_field,
collection_w.search(vectors[:nq], default_search_field,
default_search_params, default_limit,
default_search_exp, _async=_async,
output_fields=[],
@ -2831,12 +2831,8 @@ class TestCollectionSearch(TestcaseBase):
check_items={"nq": nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})[0]
if _async:
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) == 0
assert res[0][0].entity.fields == []
"_async": _async,
"output_fields": []})
@pytest.mark.tags(CaseLabel.L1)
def test_search_with_output_field(self, auto_id, _async):
@ -2859,12 +2855,8 @@ class TestCollectionSearch(TestcaseBase):
check_items={"nq": default_nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})[0]
if _async:
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert default_int64_field_name in res[0][0].entity._row_data
"_async": _async,
"output_fields": [default_int64_field_name]})[0]
@pytest.mark.tags(CaseLabel.L1)
def test_search_with_output_vector_field(self, auto_id, _async):
@ -2878,7 +2870,7 @@ class TestCollectionSearch(TestcaseBase):
auto_id=auto_id)[0:4]
# 2. search
log.info("test_search_with_output_field: Searching collection %s" % collection_w.name)
res = collection_w.search(vectors[:default_nq], default_search_field,
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
default_search_exp, _async=_async,
output_fields=[field_name],
@ -2886,12 +2878,8 @@ class TestCollectionSearch(TestcaseBase):
check_items={"nq": default_nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})[0]
if _async:
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert field_name in res[0][0].entity._row_data
"_async": _async,
"output_fields": [field_name]})
@pytest.mark.tags(CaseLabel.L2)
def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async):
@ -2908,21 +2896,17 @@ class TestCollectionSearch(TestcaseBase):
# 2. search
log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name)
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
res = collection_w.search(vectors[:nq], default_search_field,
output_fields = [default_int64_field_name, default_float_field_name]
collection_w.search(vectors[:nq], default_search_field,
default_search_params, default_limit,
default_search_exp, _async=_async,
output_fields=[default_int64_field_name,
default_float_field_name],
output_fields=output_fields,
check_task=CheckTasks.check_search_results,
check_items={"nq": nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})[0]
if _async:
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
"_async": _async,
"output_fields": output_fields})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("index, params",
@ -2955,7 +2939,8 @@ class TestCollectionSearch(TestcaseBase):
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
"limit": default_limit,
"output_fields": [field_name]})[0]
# 4. check the result vectors should be equal to the inserted
for _id in range(default_limit):
@ -2965,6 +2950,8 @@ class TestCollectionSearch(TestcaseBase):
if vectorInsert != vectorRes:
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
log.info(data[field_name][res[0][_id].id][i])
log.info(res[0][_id].entity.float_vector[i])
assert float(str(vectorInsert)) == float(vectorRes)
@pytest.mark.tags(CaseLabel.L2)
@ -3050,19 +3037,14 @@ class TestCollectionSearch(TestcaseBase):
collection_w = self.init_collection_general(prefix, True)[0]
# 2. search with output field vector
res = collection_w.search(vectors[:1], default_search_field,
output_fields = [default_float_field_name, default_string_field_name, default_search_field]
collection_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[default_float_field_name,
default_string_field_name,
default_search_field],
output_fields=output_fields,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
# 3. check the result
assert default_float_field_name, default_string_field_name in res[0][0].entity._row_data
assert default_search_field in res[0][0].entity._row_data
assert len(res[0][0].entity._row_data) == 3
"limit": default_limit,
"output_fields": output_fields})
@pytest.mark.tags(CaseLabel.L2)
def test_search_output_vector_field_and_pk_field(self):
@ -3078,17 +3060,13 @@ class TestCollectionSearch(TestcaseBase):
# 2. search with output field vector
output_fields = [default_int64_field_name, default_string_field_name, default_search_field]
res = collection_w.search(vectors[:1], default_search_field,
collection_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=output_fields,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
# 3. check the result variables
assert default_int64_field_name, default_string_field_name in res[0][0].entity._row_data
assert default_search_field in res[0][0].entity._row_data
assert len(res[0][0].entity._row_data) == 3
"limit": default_limit,
"output_fields": output_fields})
@pytest.mark.tags(CaseLabel.L2)
def test_search_output_field_vector_with_partition(self):
@ -3129,8 +3107,9 @@ class TestCollectionSearch(TestcaseBase):
assert str(vectorInsert) == vectorRes
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [["*"], ["*", default_float_field_name]])
def test_search_with_output_field_wildcard(self, output_fields, auto_id, _async):
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name], ["*", default_search_field],
["%"], ["%", default_float_field_name], ["*", "%"]])
def test_search_with_output_field_wildcard(self, wildcard_output_fields, auto_id, _async):
"""
target: test search with output fields using wildcard
method: search with one output_field (wildcard)
@ -3141,21 +3120,17 @@ class TestCollectionSearch(TestcaseBase):
auto_id=auto_id)[0:4]
# 2. search
log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name)
res = collection_w.search(vectors[:default_nq], default_search_field,
output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
default_search_exp, _async=_async,
output_fields=output_fields,
output_fields=wildcard_output_fields,
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})[0]
if _async:
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
"_async": _async,
"output_fields": output_fields})
@pytest.mark.tags(CaseLabel.L2)
def test_search_multi_collections(self, nb, nq, dim, auto_id, _async):