enhance: Add a group by case (#29939)

Related issue: #29883 
xfail for now.

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
yanliang567 2024-01-13 01:06:51 +08:00 committed by GitHub
parent 8c89ad694e
commit c1b0562d21
2 changed files with 81 additions and 3 deletions
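
For readers unfamiliar with the feature under test: a grouping search returns at most one hit per distinct value of a scalar field. A minimal sketch against an assumed local deployment (the collection and field names here are hypothetical; group_by_field is the parameter the new case exercises):

from pymilvus import Collection, connections

connections.connect(host="localhost", port="19530")  # assumed local Milvus
coll = Collection("demo_collection")  # hypothetical collection, already indexed and loaded

res = coll.search(
    data=[[0.1] * 128],  # one query vector; dim=128 is an assumption
    anns_field="embedding",  # hypothetical vector field name
    param={"metric_type": "COSINE", "params": {"ef": 64}},
    limit=10,
    group_by_field="category",  # return the best hit per distinct "category" value
    output_fields=["category"],
)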


@@ -481,7 +481,8 @@ def gen_dataframe_multi_string_fields(string_fields, nb=ct.default_nb):
     return df
 
 
-def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True, random_primary_key=False):
+def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True,
+                                auto_id=False, random_primary_key=False):
     if not random_primary_key:
         int64_values = pd.Series(data=[i for i in range(start, start + nb)])
     else:
@@ -511,6 +512,8 @@ def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True, random_primary_key=False):
         })
     if with_json is False:
         df.drop(ct.default_json_field_name, axis=1, inplace=True)
+    if auto_id:
+        df.drop(ct.default_int64_field_name, axis=1, inplace=True)
 
     return df
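
A quick sketch of how the updated helper is meant to be used (the cf/ct module aliases follow the test suite's conventions; the import paths are an assumption drawn from the diff, not part of the commit):

from common import common_func as cf
from common import common_type as ct

df = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
# with auto_id=True the int64 primary-key column is dropped so Milvus
# can assign ids on insert; with_json=False drops the JSON column too
assert ct.default_int64_field_name not in df.columns
assert ct.default_json_field_name not in df.columns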


@@ -4438,7 +4438,7 @@ class TestCollectionSearch(TestcaseBase):
         res1 = collection_w.search(vector[:default_nq], default_search_field, search_params, limit)[0]
         res2 = collection_w.search(vector[:default_nq], default_search_field, search_params, limit)[0]
         for i in range(default_nq):
-            res1[i].ids == res2[i].ids
+            assert res1[i].ids == res2[i].ids
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("index, params", zip(ct.all_index_types[1:4], ct.default_index_params[1:4]))
@@ -4464,7 +4464,7 @@ class TestCollectionSearch(TestcaseBase):
         res1 = collection_w.search(vector, default_search_field, search_params, limit)[0]
         res2 = collection_w.search(vector, default_search_field, search_params, limit * 2)[0]
         for i in range(default_nq):
-            res1[i].ids == res2[i].ids[limit:]
+            assert res1[i].ids == res2[i].ids[limit:]
 
 
 class TestSearchBase(TestcaseBase):
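
Both hunks above fix the same latent bug: a bare comparison in Python is an expression statement whose result is silently discarded, so the original checks could never fail. A minimal illustration:

def compare(a, b):
    a == b  # evaluates to a bool that is thrown away; the test "passes" regardless
    assert a == b  # raises AssertionError when the values differ

compare([1, 2], [1, 2])  # fine
compare([1, 2], [2, 1])  # only the assert version catches the mismatch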
@@ -9572,3 +9572,78 @@ class TestSearchIterator(TestcaseBase):
                              check_task=CheckTasks.err_res,
                              check_items={"err_code": 1,
                                           "err_msg": "Not support multiple vector iterator at present"})
+
+
+class TestSearchGroupBy(TestcaseBase):
+    """ Test case of search group by """
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("metric", ct.float_metrics)
+    @pytest.mark.xfail(reason="issue #29883")
+    def test_search_group_by(self, metric):
+        """
+        target: test search group by
+        method: 1. create a collection with data
+                2. create index with different metric types
+                3. search with group by
+                   verify no duplicate values for group_by_field
+                4. search with filtering every value of group_by_field
+        verify: every record in the group-by results is the top1 for that value of the group_by_field
+        """
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        # insert with the same values for scalar fields
+        for _ in range(30):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        time.sleep(30)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 2
+        limit = 10
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+        # verify the results are the same if grouping by the pk field
+        res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG,
+                                   group_by_field=ct.default_int64_field_name)[0]
+        res2 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG)[0]
+        for i in range(nq):
+            assert res1[i].ids == res2[i].ids
+
+        # verify that every record in the group-by results is the top1 for that value of the group_by_field
+        supported_grpby_fields = [ct.default_int8_field_name, ct.default_int16_field_name,
+                                  ct.default_int32_field_name, ct.default_bool_field_name,
+                                  ct.default_string_field_name]
+        for grpby_field in supported_grpby_fields:
+            res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                       param=search_params, limit=limit,
+                                       group_by_field=grpby_field,
+                                       output_fields=[grpby_field])[0]
+            for i in range(nq):
+                grpby_values = []
+                results_num = 2 if grpby_field == ct.default_bool_field_name else limit
+                for l in range(results_num):
+                    top1 = res1[i][l]
+                    top1_grpby_pk = top1.id
+                    top1_grpby_value = top1.fields.get(grpby_field)
+                    expr = f"{grpby_field}=={top1_grpby_value}"
+                    if grpby_field == ct.default_string_field_name:
+                        expr = f"{grpby_field}=='{top1_grpby_value}'"
+                    grpby_values.append(top1_grpby_value)
+                    res_tmp = collection_w.search(data=[search_vectors[i]], anns_field=ct.default_float_vec_field_name,
+                                                  param=search_params, limit=1,
+                                                  expr=expr,
+                                                  output_fields=[grpby_field])[0]
+                    top1_expr_pk = res_tmp[0][0].id
+                    assert top1_grpby_pk == top1_expr_pk
+                # verify no duplicate values of the group_by_field in the results
+                assert len(grpby_values) == len(set(grpby_values))
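
The uniqueness check at the end captures the core grouping invariant and generalizes beyond this case; a small hypothetical helper expressing the same assertion (hit.fields is the accessor the test itself uses):

def assert_group_values_unique(hits, grpby_field):
    # In a grouped result set, each value of the group-by field
    # may appear at most once across the returned hits.
    values = [hit.fields.get(grpby_field) for hit in hits]
    assert len(values) == len(set(values)), f"duplicate {grpby_field} values: {values}"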