mirror of https://github.com/milvus-io/milvus.git
Add test case and function check about search/query output (#24117)
Signed-off-by: nico <cheng.yuan@zilliz.com>pull/24152/head
parent
6965495b9d
commit
a037f36891
|
@ -239,9 +239,7 @@ class ResponseChecker:
|
||||||
search_res.done()
|
search_res.done()
|
||||||
search_res = search_res.result()
|
search_res = search_res.result()
|
||||||
if check_items.get("output_fields", None):
|
if check_items.get("output_fields", None):
|
||||||
for field in check_items['output_fields']:
|
assert set(search_res[0][0].entity.fields) == set(check_items["output_fields"])
|
||||||
assert field in search_res[0][0].entity._raw_data
|
|
||||||
assert len(check_items['output_fields']) == len(search_res[0][0].entity._raw_data)
|
|
||||||
log.info('search_results_check: Output fields of query searched is correct')
|
log.info('search_results_check: Output fields of query searched is correct')
|
||||||
if len(search_res) != check_items["nq"]:
|
if len(search_res) != check_items["nq"]:
|
||||||
log.error("search_results_check: Numbers of query searched (%d) "
|
log.error("search_results_check: Numbers of query searched (%d) "
|
||||||
|
|
|
@ -920,3 +920,21 @@ def install_milvus_operator_specific_config(namespace, milvus_mode, release_name
|
||||||
raise MilvusException(message=f'Milvus healthy timeout 1800s')
|
raise MilvusException(message=f'Milvus healthy timeout 1800s')
|
||||||
|
|
||||||
return host
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
def get_wildcard_output_field_names(collection_w, output_fields):
|
||||||
|
all_fields = collection_w.schema.fields
|
||||||
|
scalar_fields = []
|
||||||
|
vector_field = []
|
||||||
|
for field in all_fields:
|
||||||
|
if field.dtype == DataType.FLOAT_VECTOR:
|
||||||
|
vector_field.append(field.name)
|
||||||
|
else:
|
||||||
|
scalar_fields.append(field.name)
|
||||||
|
if "*" in output_fields:
|
||||||
|
output_fields.remove("*")
|
||||||
|
output_fields.extend(scalar_fields)
|
||||||
|
if "%" in output_fields:
|
||||||
|
output_fields.remove("%")
|
||||||
|
output_fields.extend(vector_field)
|
||||||
|
return output_fields
|
||||||
|
|
|
@ -598,6 +598,28 @@ class TestQueryParams(TestcaseBase):
|
||||||
check_task=CheckTasks.check_query_results,
|
check_task=CheckTasks.check_query_results,
|
||||||
check_items={exp_res: res, "with_vec": True})
|
check_items={exp_res: res, "with_vec": True})
|
||||||
|
|
||||||
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
|
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name],
|
||||||
|
["*", default_int_field_name],
|
||||||
|
["%"], ["%", default_float_field_name], ["*", "%"]])
|
||||||
|
def test_query_output_field_wildcard(self, wildcard_output_fields):
|
||||||
|
"""
|
||||||
|
target: test query with output fields using wildcard
|
||||||
|
method: query with one output_field (wildcard)
|
||||||
|
expected: query success
|
||||||
|
"""
|
||||||
|
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
|
||||||
|
df = cf.gen_default_dataframe_data()
|
||||||
|
collection_w.insert(df)
|
||||||
|
assert collection_w.num_entities == ct.default_nb
|
||||||
|
output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
|
||||||
|
output_fields.append(default_int_field_name)
|
||||||
|
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
|
||||||
|
collection_w.load()
|
||||||
|
with_vec = True if ct.default_float_vec_field_name in output_fields else False
|
||||||
|
actual_res = collection_w.query(default_term_expr, output_fields=wildcard_output_fields)[0]
|
||||||
|
assert set(actual_res[0].keys()) == set(output_fields)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
@pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/12680")
|
@pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/12680")
|
||||||
@pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")]])
|
@pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")]])
|
||||||
|
|
|
@ -2823,7 +2823,7 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
# 2. search
|
# 2. search
|
||||||
log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name)
|
log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name)
|
||||||
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
|
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
|
||||||
res = collection_w.search(vectors[:nq], default_search_field,
|
collection_w.search(vectors[:nq], default_search_field,
|
||||||
default_search_params, default_limit,
|
default_search_params, default_limit,
|
||||||
default_search_exp, _async=_async,
|
default_search_exp, _async=_async,
|
||||||
output_fields=[],
|
output_fields=[],
|
||||||
|
@ -2831,12 +2831,8 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
check_items={"nq": nq,
|
check_items={"nq": nq,
|
||||||
"ids": insert_ids,
|
"ids": insert_ids,
|
||||||
"limit": default_limit,
|
"limit": default_limit,
|
||||||
"_async": _async})[0]
|
"_async": _async,
|
||||||
if _async:
|
"output_fields": []})
|
||||||
res.done()
|
|
||||||
res = res.result()
|
|
||||||
assert len(res[0][0].entity._row_data) == 0
|
|
||||||
assert res[0][0].entity.fields == []
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_search_with_output_field(self, auto_id, _async):
|
def test_search_with_output_field(self, auto_id, _async):
|
||||||
|
@ -2859,12 +2855,8 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
check_items={"nq": default_nq,
|
check_items={"nq": default_nq,
|
||||||
"ids": insert_ids,
|
"ids": insert_ids,
|
||||||
"limit": default_limit,
|
"limit": default_limit,
|
||||||
"_async": _async})[0]
|
"_async": _async,
|
||||||
if _async:
|
"output_fields": [default_int64_field_name]})[0]
|
||||||
res.done()
|
|
||||||
res = res.result()
|
|
||||||
assert len(res[0][0].entity._row_data) != 0
|
|
||||||
assert default_int64_field_name in res[0][0].entity._row_data
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_search_with_output_vector_field(self, auto_id, _async):
|
def test_search_with_output_vector_field(self, auto_id, _async):
|
||||||
|
@ -2878,7 +2870,7 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
auto_id=auto_id)[0:4]
|
auto_id=auto_id)[0:4]
|
||||||
# 2. search
|
# 2. search
|
||||||
log.info("test_search_with_output_field: Searching collection %s" % collection_w.name)
|
log.info("test_search_with_output_field: Searching collection %s" % collection_w.name)
|
||||||
res = collection_w.search(vectors[:default_nq], default_search_field,
|
collection_w.search(vectors[:default_nq], default_search_field,
|
||||||
default_search_params, default_limit,
|
default_search_params, default_limit,
|
||||||
default_search_exp, _async=_async,
|
default_search_exp, _async=_async,
|
||||||
output_fields=[field_name],
|
output_fields=[field_name],
|
||||||
|
@ -2886,12 +2878,8 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
check_items={"nq": default_nq,
|
check_items={"nq": default_nq,
|
||||||
"ids": insert_ids,
|
"ids": insert_ids,
|
||||||
"limit": default_limit,
|
"limit": default_limit,
|
||||||
"_async": _async})[0]
|
"_async": _async,
|
||||||
if _async:
|
"output_fields": [field_name]})
|
||||||
res.done()
|
|
||||||
res = res.result()
|
|
||||||
assert len(res[0][0].entity._row_data) != 0
|
|
||||||
assert field_name in res[0][0].entity._row_data
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async):
|
def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async):
|
||||||
|
@ -2908,21 +2896,17 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
# 2. search
|
# 2. search
|
||||||
log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name)
|
log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name)
|
||||||
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
|
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
|
||||||
res = collection_w.search(vectors[:nq], default_search_field,
|
output_fields = [default_int64_field_name, default_float_field_name]
|
||||||
|
collection_w.search(vectors[:nq], default_search_field,
|
||||||
default_search_params, default_limit,
|
default_search_params, default_limit,
|
||||||
default_search_exp, _async=_async,
|
default_search_exp, _async=_async,
|
||||||
output_fields=[default_int64_field_name,
|
output_fields=output_fields,
|
||||||
default_float_field_name],
|
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": nq,
|
check_items={"nq": nq,
|
||||||
"ids": insert_ids,
|
"ids": insert_ids,
|
||||||
"limit": default_limit,
|
"limit": default_limit,
|
||||||
"_async": _async})[0]
|
"_async": _async,
|
||||||
if _async:
|
"output_fields": output_fields})
|
||||||
res.done()
|
|
||||||
res = res.result()
|
|
||||||
assert len(res[0][0].entity._row_data) != 0
|
|
||||||
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
@pytest.mark.parametrize("index, params",
|
@pytest.mark.parametrize("index, params",
|
||||||
|
@ -2955,7 +2939,8 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
output_fields=[field_name],
|
output_fields=[field_name],
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": 1,
|
check_items={"nq": 1,
|
||||||
"limit": default_limit})[0]
|
"limit": default_limit,
|
||||||
|
"output_fields": [field_name]})[0]
|
||||||
|
|
||||||
# 4. check the result vectors should be equal to the inserted
|
# 4. check the result vectors should be equal to the inserted
|
||||||
for _id in range(default_limit):
|
for _id in range(default_limit):
|
||||||
|
@ -2965,6 +2950,8 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
if vectorInsert != vectorRes:
|
if vectorInsert != vectorRes:
|
||||||
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
|
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
|
||||||
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
|
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
|
||||||
|
log.info(data[field_name][res[0][_id].id][i])
|
||||||
|
log.info(res[0][_id].entity.float_vector[i])
|
||||||
assert float(str(vectorInsert)) == float(vectorRes)
|
assert float(str(vectorInsert)) == float(vectorRes)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
|
@ -3050,19 +3037,14 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
collection_w = self.init_collection_general(prefix, True)[0]
|
collection_w = self.init_collection_general(prefix, True)[0]
|
||||||
|
|
||||||
# 2. search with output field vector
|
# 2. search with output field vector
|
||||||
res = collection_w.search(vectors[:1], default_search_field,
|
output_fields = [default_float_field_name, default_string_field_name, default_search_field]
|
||||||
|
collection_w.search(vectors[:1], default_search_field,
|
||||||
default_search_params, default_limit, default_search_exp,
|
default_search_params, default_limit, default_search_exp,
|
||||||
output_fields=[default_float_field_name,
|
output_fields=output_fields,
|
||||||
default_string_field_name,
|
|
||||||
default_search_field],
|
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": 1,
|
check_items={"nq": 1,
|
||||||
"limit": default_limit})[0]
|
"limit": default_limit,
|
||||||
|
"output_fields": output_fields})
|
||||||
# 3. check the result
|
|
||||||
assert default_float_field_name, default_string_field_name in res[0][0].entity._row_data
|
|
||||||
assert default_search_field in res[0][0].entity._row_data
|
|
||||||
assert len(res[0][0].entity._row_data) == 3
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_search_output_vector_field_and_pk_field(self):
|
def test_search_output_vector_field_and_pk_field(self):
|
||||||
|
@ -3078,17 +3060,13 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
|
|
||||||
# 2. search with output field vector
|
# 2. search with output field vector
|
||||||
output_fields = [default_int64_field_name, default_string_field_name, default_search_field]
|
output_fields = [default_int64_field_name, default_string_field_name, default_search_field]
|
||||||
res = collection_w.search(vectors[:1], default_search_field,
|
collection_w.search(vectors[:1], default_search_field,
|
||||||
default_search_params, default_limit, default_search_exp,
|
default_search_params, default_limit, default_search_exp,
|
||||||
output_fields=output_fields,
|
output_fields=output_fields,
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": 1,
|
check_items={"nq": 1,
|
||||||
"limit": default_limit})[0]
|
"limit": default_limit,
|
||||||
|
"output_fields": output_fields})
|
||||||
# 3. check the result variables
|
|
||||||
assert default_int64_field_name, default_string_field_name in res[0][0].entity._row_data
|
|
||||||
assert default_search_field in res[0][0].entity._row_data
|
|
||||||
assert len(res[0][0].entity._row_data) == 3
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_search_output_field_vector_with_partition(self):
|
def test_search_output_field_vector_with_partition(self):
|
||||||
|
@ -3129,8 +3107,9 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
assert str(vectorInsert) == vectorRes
|
assert str(vectorInsert) == vectorRes
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
@pytest.mark.parametrize("output_fields", [["*"], ["*", default_float_field_name]])
|
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name], ["*", default_search_field],
|
||||||
def test_search_with_output_field_wildcard(self, output_fields, auto_id, _async):
|
["%"], ["%", default_float_field_name], ["*", "%"]])
|
||||||
|
def test_search_with_output_field_wildcard(self, wildcard_output_fields, auto_id, _async):
|
||||||
"""
|
"""
|
||||||
target: test search with output fields using wildcard
|
target: test search with output fields using wildcard
|
||||||
method: search with one output_field (wildcard)
|
method: search with one output_field (wildcard)
|
||||||
|
@ -3141,21 +3120,17 @@ class TestCollectionSearch(TestcaseBase):
|
||||||
auto_id=auto_id)[0:4]
|
auto_id=auto_id)[0:4]
|
||||||
# 2. search
|
# 2. search
|
||||||
log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name)
|
log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name)
|
||||||
|
output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
|
||||||
res = collection_w.search(vectors[:default_nq], default_search_field,
|
collection_w.search(vectors[:default_nq], default_search_field,
|
||||||
default_search_params, default_limit,
|
default_search_params, default_limit,
|
||||||
default_search_exp, _async=_async,
|
default_search_exp, _async=_async,
|
||||||
output_fields=output_fields,
|
output_fields=wildcard_output_fields,
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": default_nq,
|
check_items={"nq": default_nq,
|
||||||
"ids": insert_ids,
|
"ids": insert_ids,
|
||||||
"limit": default_limit,
|
"limit": default_limit,
|
||||||
"_async": _async})[0]
|
"_async": _async,
|
||||||
if _async:
|
"output_fields": output_fields})
|
||||||
res.done()
|
|
||||||
res = res.result()
|
|
||||||
assert len(res[0][0].entity._row_data) != 0
|
|
||||||
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_search_multi_collections(self, nb, nq, dim, auto_id, _async):
|
def test_search_multi_collections(self, nb, nq, dim, auto_id, _async):
|
||||||
|
|
Loading…
Reference in New Issue