Add output fields value check (#24255)

Signed-off-by: nico <cheng.yuan@zilliz.com>
pull/25293/head
nico 2023-07-07 16:58:25 +08:00 committed by GitHub
parent 342cfcad46
commit 9b64f12a6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 75 deletions

View File

@ -296,6 +296,9 @@ class ResponseChecker:
if check_items.get("output_fields", None):
assert set(search_res[0][0].entity.fields) == set(check_items["output_fields"])
log.info('search_results_check: Output fields of query searched is correct')
if check_items.get("original_entities", None):
original_entities = check_items["original_entities"][0]
pc.output_field_value_check(search_res, original_entities)
if len(search_res) != check_items["nq"]:
log.error("search_results_check: Numbers of query searched (%d) "
"is not equal with expected (%d)"

View File

@ -223,4 +223,24 @@ def equal_entities_list(exp, actual, primary_field, with_vec=False):
exp.remove(a)
except Exception as ex:
log.error(ex)
return True if len(exp) == 0 else False
return True if len(exp) == 0 else False
def output_field_value_check(search_res, original):
"""
check if the value of output fields is correct
:param search_res: the search result of specific output fields
:param original: the data in the collection
:return: True or False
"""
limit = len(search_res[0])
for i in range(limit):
entity = eval(str(search_res[0][i]).split('entity: ', 1)[1])
_id = search_res[0][i].id
for field in entity.keys():
if isinstance(entity[field], list):
for order in range(0, len(entity[field]), 4):
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
else:
assert original[field][_id] == entity[field]
return True

View File

@ -1101,17 +1101,7 @@ def install_milvus_operator_specific_config(namespace, milvus_mode, release_name
def get_wildcard_output_field_names(collection_w, output_fields):
all_fields = collection_w.schema.fields
scalar_fields = []
vector_field = []
for field in all_fields:
if field.dtype == DataType.FLOAT_VECTOR:
vector_field.append(field.name)
else:
scalar_fields.append(field.name)
if "*" in output_fields:
output_fields.remove("*")
output_fields.extend(scalar_fields)
if "%" in output_fields:
output_fields.remove("%")
output_fields.extend(vector_field)
output_fields.extend(all_fields)
return output_fields

View File

@ -632,8 +632,7 @@ class TestQueryParams(TestcaseBase):
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name],
["*", default_int_field_name],
["%"], ["%", default_float_field_name], ["*", "%"]])
["*", default_int_field_name]])
def test_query_output_field_wildcard(self, wildcard_output_fields):
"""
target: test query with output fields using wildcard

View File

@ -1752,7 +1752,6 @@ class TestCollectionSearch(TestcaseBase):
"_async": _async})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip(reason="debug")
def test_search_max_dim(self, auto_id, _async):
"""
target: test search with max configuration
@ -1801,7 +1800,6 @@ class TestCollectionSearch(TestcaseBase):
"limit": nq,
"_async": _async})
@pytest.mark.xfail(reason="issue #19129")
@pytest.mark.tags(CaseLabel.L2)
def test_search_max_nq(self, auto_id, dim, _async):
"""
@ -1810,7 +1808,7 @@ class TestCollectionSearch(TestcaseBase):
expected: search successfully with max nq
"""
self._connect()
nq = 17000
nq = 16384
collection_w, _, _, insert_ids = self.init_collection_general(prefix, True,
auto_id=auto_id,
dim=dim)[0:4]
@ -3225,9 +3223,7 @@ class TestCollectionSearch(TestcaseBase):
expected: search success
"""
# 1. create a collection and insert data
collection_w = self.init_collection_general(prefix, is_index=False)[0]
data = cf.gen_default_dataframe_data()
collection_w.insert(data)
collection_w, _vectors = self.init_collection_general(prefix, True, is_index=False)[:2]
# 2. create index and load
default_index = {"index_type": index, "params": params, "metric_type": metrics}
@ -3236,25 +3232,14 @@ class TestCollectionSearch(TestcaseBase):
# 3. search with output field vector
search_params = cf.gen_search_param(index, metrics)[0]
res = collection_w.search(vectors[:1], default_search_field,
search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit,
"output_fields": [field_name]})[0]
# 4. check the result vectors should be equal to the inserted
for _id in range(default_limit):
for i in range(default_dim):
vectorInsert = str(data[field_name][res[0][_id].id][i])[:7]
vectorRes = str(res[0][_id].entity.float_vector[i])[:7]
if vectorInsert != vectorRes:
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
log.info(data[field_name][res[0][_id].id][i])
log.info(res[0][_id].entity.float_vector[i])
assert float(str(vectorInsert)) == float(vectorRes)
collection_w.search(vectors[:1], default_search_field,
search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit,
"original_entities": _vectors,
"output_fields": [field_name]})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip(reason="issue #23661")
@ -3304,27 +3289,18 @@ class TestCollectionSearch(TestcaseBase):
expected: search success
"""
# 1. create a collection and insert data
collection_w = self.init_collection_general(prefix, is_index=False, dim=dim)[0]
data = cf.gen_default_dataframe_data(dim=dim)
collection_w.insert(data)
collection_w, _vectors = self.init_collection_general(prefix, True, dim=dim)[:2]
# 2. create index and load
index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "COSINE"}
collection_w.create_index("float_vector", index_params)
collection_w.load()
# 3. search with output field vector
# 2. search with output field vector
vectors = cf.gen_vectors(default_nq, dim=dim)
res = collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"limit": default_limit})[0]
# 4. check the result vectors should be equal to the inserted
for i in range(default_limit):
assert len(res[0][i].entity.float_vector) == len(data[field_name][res[0][i].id])
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"limit": default_limit,
"original_entities": _vectors,
"output_fields": [field_name]})
@pytest.mark.tags(CaseLabel.L2)
def test_search_output_vector_field_and_scalar_field(self, enable_dynamic_field):
@ -3336,16 +3312,28 @@ class TestCollectionSearch(TestcaseBase):
expected: search success
"""
# 1. initialize a collection
collection_w = self.init_collection_general(prefix, True, enable_dynamic_field=enable_dynamic_field)[0]
collection_w, _vectors = self.init_collection_general(prefix, True,
enable_dynamic_field=enable_dynamic_field)[:2]
# 2. search with output field vector
output_fields = [default_float_field_name, default_string_field_name, default_search_field]
original_entities = []
if enable_dynamic_field:
entities = []
for vector in _vectors[0]:
entities.append({default_float_field_name: vector[default_float_field_name],
default_string_field_name: vector[default_string_field_name],
default_search_field: vector[default_search_field]})
original_entities.append(pd.DataFrame(entities))
else:
original_entities = _vectors
collection_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=output_fields,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit,
"original_entities": original_entities,
"output_fields": output_fields})
@pytest.mark.tags(CaseLabel.L2)
@ -3391,22 +3379,14 @@ class TestCollectionSearch(TestcaseBase):
collection_w.load()
# 3. search with output field vector
res = partition_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
# 4. check the result vectors should be equal to the inserted
for _id in range(default_limit):
for i in range(default_dim):
vectorInsert = str(data[field_name][res[0][_id].id][i])[:7]
vectorRes = str(res[0][_id].entity.float_vector[i])[:7]
if vectorInsert != vectorRes:
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
assert str(vectorInsert) == vectorRes
partition_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit,
"original_entities": [data],
"output_fields": [field_name]})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("wildcard_output_fields", [["*"], ["*", default_float_field_name],
@ -3437,6 +3417,26 @@ class TestCollectionSearch(TestcaseBase):
"_async": _async,
"output_fields": output_fields})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("invalid_output_fields", [["%"], [""], ["-"]])
def test_search_with_invalid_output_fields(self, invalid_output_fields, auto_id):
"""
target: test search with output fields using wildcard
method: search with one output_field (wildcard)
expected: search success
"""
# 1. initialize with data
collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4]
# 2. search
log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name)
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
default_search_exp,
output_fields=invalid_output_fields,
check_task=CheckTasks.err_res,
check_items={"err_code": 1,
"err_msg": "field %s is not exist" % invalid_output_fields[0]})
@pytest.mark.tags(CaseLabel.L2)
def test_search_multi_collections(self, nb, nq, dim, auto_id, _async):
"""