mirror of https://github.com/milvus-io/milvus.git
enhance: Add a group by case (#29939)
Related issue: #29883, xfail for now.
Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
pull/29957/head
parent 8c89ad694e
commit c1b0562d21
@@ -481,7 +481,8 @@ def gen_dataframe_multi_string_fields(string_fields, nb=ct.default_nb):
     return df
 
 
-def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True, random_primary_key=False):
+def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True,
+                                auto_id=False, random_primary_key=False):
     if not random_primary_key:
         int64_values = pd.Series(data=[i for i in range(start, start + nb)])
     else:
@@ -511,6 +512,8 @@ def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True, random_primary_key=False):
     })
     if with_json is False:
         df.drop(ct.default_json_field_name, axis=1, inplace=True)
+    if auto_id:
+        df.drop(ct.default_int64_field_name, axis=1, inplace=True)
 
     return df
 
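Taken together, the two hunks above let callers pass auto_id=True so the int64 primary-key column is dropped from the generated dataframe, matching an auto-id collection schema. A minimal standalone sketch of that behavior, with simplified column names standing in for the ct constants:

import pandas as pd

def gen_df(nb=3, with_json=True, auto_id=False):
    df = pd.DataFrame({
        "int64": pd.Series([i for i in range(nb)]),          # primary-key column
        "float": pd.Series([float(i) for i in range(nb)]),
        "json": pd.Series(['{"k": %d}' % i for i in range(nb)]),
    })
    if with_json is False:
        df.drop("json", axis=1, inplace=True)
    if auto_id:
        df.drop("int64", axis=1, inplace=True)               # let Milvus assign the pk
    return df

print(gen_df(auto_id=True, with_json=False).columns.tolist())  # ['float']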
@@ -4438,7 +4438,7 @@ class TestCollectionSearch(TestcaseBase):
         res1 = collection_w.search(vector[:default_nq], default_search_field, search_params, limit)[0]
         res2 = collection_w.search(vector[:default_nq], default_search_field, search_params, limit)[0]
         for i in range(default_nq):
-            res1[i].ids == res2[i].ids
+            assert res1[i].ids == res2[i].ids
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("index, params", zip(ct.all_index_types[1:4], ct.default_index_params[1:4]))
@@ -4464,7 +4464,7 @@ class TestCollectionSearch(TestcaseBase):
         res1 = collection_w.search(vector, default_search_field, search_params, limit)[0]
         res2 = collection_w.search(vector, default_search_field, search_params, limit * 2)[0]
         for i in range(default_nq):
-            res1[i].ids == res2[i].ids[limit:]
+            assert res1[i].ids == res2[i].ids[limit:]
 
 
 class TestSearchBase(TestcaseBase):
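The two hunks above fix the same latent bug: a bare == comparison is an expression statement whose result Python evaluates and silently discards, so the original lines could never fail the test. Adding assert turns the comparison into an actual check:

res1_ids = [1, 2, 3]
res2_ids = [7, 8, 9]
res1_ids == res2_ids          # no-op: the False result is discarded
assert res1_ids == res2_ids   # raises AssertionError, failing the test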
@@ -9572,3 +9572,78 @@ class TestSearchIterator(TestcaseBase):
                                      check_task=CheckTasks.err_res,
                                      check_items={"err_code": 1,
                                                   "err_msg": "Not support multiple vector iterator at present"})
+
+
+class TestSearchGroupBy(TestcaseBase):
+    """ Test case of search group by """
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("metric", ct.float_metrics)
+    @pytest.mark.xfail(reason="issue #29883")
+    def test_search_group_by(self, metric):
+        """
+        target: test search group by
+        method: 1. create a collection with data
+                2. create index with different metric types
+                3. search with group by
+                   verify no duplicate values for group_by_field
+                4. search with filtering every value of group_by_field
+                   verify that every record in groupby results is the top1 for that value of the group_by_field
+        """
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        # insert with the same values for scalar fields
+        for _ in range(30):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        time.sleep(30)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 2
+        limit = 10
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+        # verify the results are the same if group by pk
+        res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG,
+                                   group_by_field=ct.default_int64_field_name)[0]
+        res2 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG)[0]
+        for i in range(nq):
+            assert res1[i].ids == res2[i].ids
+
+        # verify that every record in groupby results is the top1 for that value of the group_by_field
+        supported_grpby_fields = [ct.default_int8_field_name, ct.default_int16_field_name,
+                                  ct.default_int32_field_name, ct.default_bool_field_name,
+                                  ct.default_string_field_name]
+        for grpby_field in supported_grpby_fields:
+            res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                       param=search_params, limit=limit,
+                                       group_by_field=grpby_field,
+                                       output_fields=[grpby_field])[0]
+            for i in range(nq):
+                grpby_values = []
+                results_num = 2 if grpby_field == ct.default_bool_field_name else limit
+                for l in range(results_num):
+                    top1 = res1[i][l]
+                    top1_grpby_pk = top1.id
+                    top1_grpby_value = top1.fields.get(grpby_field)
+                    expr = f"{grpby_field}=={top1_grpby_value}"
+                    if grpby_field == ct.default_string_field_name:
+                        expr = f"{grpby_field}=='{top1_grpby_value}'"
+                    grpby_values.append(top1_grpby_value)
+                    res_tmp = collection_w.search(data=[search_vectors[i]], anns_field=ct.default_float_vec_field_name,
+                                                  param=search_params, limit=1,
+                                                  expr=expr,
+                                                  output_fields=[grpby_field])[0]
+                    top1_expr_pk = res_tmp[0][0].id
+                    assert top1_grpby_pk == top1_expr_pk
+                # verify no dup values of the group_by_field in results
+                assert len(grpby_values) == len(set(grpby_values))
+
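One detail worth noting in the test body above: the boolean expr used to re-query each group value must quote string values but not numeric ones. A distilled sketch of that branch (function and field names here are illustrative, not from the test utilities):

def build_group_expr(field_name, value, is_string_field):
    # Milvus boolean expressions wrap string literals in single quotes
    if is_string_field:
        return f"{field_name}=='{value}'"
    return f"{field_name}=={value}"

print(build_group_expr("varchar", "abc", True))   # varchar=='abc'
print(build_group_expr("int8", 7, False))         # int8==7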