mirror of https://github.com/milvus-io/milvus.git
enhance: Add more tests for groupby (#30346)
Related issue: #30033. Skip the tests until the bug fixes land. --------- Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
Branch: pull/30413/head
parent 878c4c9463
commit 54150253e7
@@ -540,7 +540,7 @@ def gen_default_rows_data_all_data_type(nb=ct.default_nb, dim=ct.default_dim, st
     return array
 
 
-def gen_default_binary_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, start=0):
+def gen_default_binary_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, start=0, auto_id=False):
     int_values = pd.Series(data=[i for i in range(start, start + nb)])
     float_values = pd.Series(data=[np.float32(i) for i in range(start, start + nb)], dtype="float32")
     string_values = pd.Series(data=[str(i) for i in range(start, start + nb)], dtype="string")
@@ -551,6 +551,12 @@ def gen_default_binary_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, star
         ct.default_string_field_name: string_values,
         ct.default_binary_vec_field_name: binary_vec_values
     })
+    if auto_id is True:
+        df = pd.DataFrame({
+            ct.default_float_field_name: float_values,
+            ct.default_string_field_name: string_values,
+            ct.default_binary_vec_field_name: binary_vec_values
+        })
     return df, binary_raw_values
 
 
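With the new auto_id flag above, gen_default_binary_dataframe_data rebuilds the DataFrame without the int64 primary-key column, so inserts rely on Milvus-generated ids. A minimal usage sketch, assuming the suite's cf/ct utility modules; the assertions are illustrative, not part of the suite:

    # auto_id=True drops the pk column so Milvus assigns primary keys on insert
    df, binary_raw_values = cf.gen_default_binary_dataframe_data(nb=100, auto_id=True)
    assert ct.default_int64_field_name not in df.columns

    # the default keeps the explicit pk column
    df, _ = cf.gen_default_binary_dataframe_data(nb=100)
    assert ct.default_int64_field_name in df.columns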
@@ -9577,10 +9577,10 @@ class TestSearchIterator(TestcaseBase):
 class TestSearchGroupBy(TestcaseBase):
     """ Test case of search group by """
 
-    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.tags(CaseLabel.L0)
     @pytest.mark.parametrize("metric", ct.float_metrics)
-    @pytest.mark.xfail(reason="issue #29883")
-    def test_search_group_by(self, metric):
+    @pytest.mark.skip(reason="issue #29883")
+    def test_search_group_by_default(self, metric):
         """
         target: test search group by
         method: 1. create a collection with data
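The marker swap above is what the commit message means by skipping the tests: pytest.mark.xfail still runs the test and records an expected failure, while pytest.mark.skip keeps it from running at all until issue #29883 is fixed. For contrast (only one marker would be applied in practice):

    @pytest.mark.xfail(reason="issue #29883")  # test executes; failure reported as xfail
    @pytest.mark.skip(reason="issue #29883")   # test is not executed at all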
@@ -9647,3 +9647,390 @@ class TestSearchGroupBy(TestcaseBase):
             # verify no dup values of the group_by_field in results
             assert len(grpby_values) == len(set(grpby_values))
 
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("metric", ["JACCARD", "HAMMING"])
+    @pytest.mark.skip(reason="issue #29883")
+    def test_search_binary_vec_group_by(self, metric):
+        """
+        target: test search group by
+        method: 1. create a collection with binary vectors
+                2. create index with different metric types
+                3. search with group by
+                   verify no duplicate values for group_by_field
+                4. search with filtering every value of group_by_field
+        verify: verify that every record in groupby results is the top1 for that value of the group_by_field
+        """
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_binary=True)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_binary_vec_field_name, index_params=_index)
+        # insert with the same values for scalar fields
+        for _ in range(30):
+            data = cf.gen_default_binary_dataframe_data(nb=100, auto_id=True)[0]
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.create_index(ct.default_binary_vec_field_name, index_params=_index)
+        time.sleep(30)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 2
+        limit = 10
+        search_vectors = cf.gen_binary_vectors(nq, dim=ct.default_dim)[1]
+
+        # verify the results are the same if group by pk
+        res1 = collection_w.search(data=search_vectors, anns_field=ct.default_binary_vec_field_name,
+                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG,
+                                   group_by_field=ct.default_int64_field_name)[0]
+        res2 = collection_w.search(data=search_vectors, anns_field=ct.default_binary_vec_field_name,
+                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG)[0]
+        # for i in range(nq):
+        #     assert res1[i].ids == res2[i].ids
+
+        # verify that every record in groupby results is the top1 for that value of the group_by_field
+        supported_grpby_fields = [ct.default_string_field_name]
+        for grpby_field in supported_grpby_fields:
+            res1 = collection_w.search(data=search_vectors, anns_field=ct.default_binary_vec_field_name,
+                                       param=search_params, limit=limit,
+                                       group_by_field=grpby_field,
+                                       output_fields=[grpby_field])[0]
+            for i in range(nq):
+                grpby_values = []
+                results_num = 2 if grpby_field == ct.default_bool_field_name else limit
+                for l in range(results_num):
+                    top1 = res1[i][l]
+                    top1_grpby_pk = top1.id
+                    top1_grpby_value = top1.fields.get(grpby_field)
+                    expr = f"{grpby_field}=={top1_grpby_value}"
+                    if grpby_field == ct.default_string_field_name:
+                        expr = f"{grpby_field}=='{top1_grpby_value}'"
+                    grpby_values.append(top1_grpby_value)
+                    res_tmp = collection_w.search(data=[search_vectors[i]], anns_field=ct.default_binary_vec_field_name,
+                                                  param=search_params, limit=1,
+                                                  expr=expr,
+                                                  output_fields=[grpby_field])[0]
+                    top1_expr_pk = res_tmp[0][0].id
+                    assert top1_grpby_pk == top1_expr_pk
+                # verify no dup values of the group_by_field in results
+                assert len(grpby_values) == len(set(grpby_values))
+
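The inner loop above is the core group-by invariant that the remaining positive tests repeat: each hit returned for a group must be the top-1 result when the same search is filtered to that group's value. A standalone sketch of the check against a plain pymilvus Collection; the helper name and signature are illustrative, not part of the suite:

    from pymilvus import Collection

    def assert_group_hit_is_top1(collection: Collection, query_vec, anns_field: str,
                                 search_params: dict, grpby_field: str, hit) -> None:
        # the group representative must equal the plain top-1 result
        # when the search is restricted to that group's value
        value = hit.fields.get(grpby_field)
        expr = f"{grpby_field} == '{value}'" if isinstance(value, str) else f"{grpby_field} == {value}"
        res = collection.search(data=[query_vec], anns_field=anns_field,
                                param=search_params, limit=1, expr=expr,
                                output_fields=[grpby_field])
        assert res[0][0].id == hit.id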
+    @pytest.mark.skip(reason="issue #29883")
+    @pytest.mark.tags(CaseLabel.L0)
+    @pytest.mark.parametrize("grpby_field", [ct.default_string_field_name, ct.default_int8_field_name])
+    def test_search_group_by_with_field_indexed(self, grpby_field):
+        """
+        target: test search group by with the group-by field indexed
+        method: 1. create a collection with data
+                2. create index for the vector field and the groupby field
+                3. search with group by
+                4. search with filtering every value of group_by_field
+        verify: verify that every record in groupby results is the top1 for that value of the group_by_field
+        """
+        metric = "COSINE"
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        # insert with the same values (per insert round) for scalar fields
+        for _ in range(100):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.create_index(grpby_field)
+        time.sleep(30)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 2
+        limit = 20
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+        # verify that every record in groupby results is the top1 for that value of the group_by_field
+        res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                   param=search_params, limit=limit,
+                                   group_by_field=grpby_field,
+                                   output_fields=[grpby_field])[0]
+        for i in range(nq):
+            grpby_values = []
+            results_num = 2 if grpby_field == ct.default_bool_field_name else limit
+            for l in range(results_num):
+                top1 = res1[i][l]
+                top1_grpby_pk = top1.id
+                top1_grpby_value = top1.fields.get(grpby_field)
+                expr = f"{grpby_field}=={top1_grpby_value}"
+                if grpby_field == ct.default_string_field_name:
+                    expr = f"{grpby_field}=='{top1_grpby_value}'"
+                grpby_values.append(top1_grpby_value)
+                res_tmp = collection_w.search(data=[search_vectors[i]], anns_field=ct.default_float_vec_field_name,
+                                              param=search_params, limit=1,
+                                              expr=expr,
+                                              output_fields=[grpby_field])[0]
+                top1_expr_pk = res_tmp[0][0].id
+                log.info(f"nq={i}, limit={l}")
+                assert top1_grpby_pk == top1_expr_pk
+            # verify no dup values of the group_by_field in results
+            assert len(grpby_values) == len(set(grpby_values))
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.skip(reason="issue #29967")
+    @pytest.mark.parametrize("grpby_unsupported_field", [ct.default_float_field_name, ct.default_json_field_name,
+                                                         ct.default_double_field_name, ct.default_float_vec_field_name])
+    def test_search_group_by_unsupported_filed(self, grpby_unsupported_field):
+        """
+        target: test search group by with an unsupported field
+        method: 1. create a collection with data
+                2. create index
+                3. search with group by the unsupported fields
+        verify: the error code and msg
+        """
+        metric = "IP"
+        collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
+                                                    is_all_data_type=True, with_json=True)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 1
+        limit = 1
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+        # search with groupby
+        err_code = 999
+        err_msg = "unsupported"
+        collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                            param=search_params, limit=limit,
+                            group_by_field=grpby_unsupported_field,
+                            check_task=CheckTasks.err_res,
+                            check_items={"err_code": err_code, "err_msg": err_msg})
+
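In plain pymilvus terms, the err_res check task above amounts to catching a MilvusException and matching its message. A hedged sketch under that assumption; the helper name and the dummy 128-dim query vector are illustrative:

    from pymilvus import Collection, MilvusException

    def expect_group_by_rejected(collection: Collection, anns_field: str,
                                 search_params: dict, bad_field: str) -> None:
        # grouping by a float, double, json, or vector field should be
        # rejected with an "unsupported" error instead of returning results
        try:
            collection.search(data=[[0.0] * 128], anns_field=anns_field,
                              param=search_params, limit=1, group_by_field=bad_field)
        except MilvusException as e:
            assert "unsupported" in str(e)
        else:
            raise AssertionError("group_by on an unsupported field type did not fail")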
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("index, params",
+                             zip(ct.all_index_types[:7],
+                                 ct.default_index_params[:7]))
+    def test_search_group_by_unsupported_index(self, index, params):
+        """
+        target: test search group by with an unsupported vector index
+        method: 1. create a collection with data
+                2. create an index that does not support group by
+                3. search with group by
+        verify: the error code and msg
+        """
+        if index == "HNSW":
+            pass  # HNSW is supported
+        else:
+            metric = "L2"
+            collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
+                                                        is_all_data_type=True, with_json=False)[0]
+            index_params = {"index_type": index, "params": params, "metric_type": metric}
+            collection_w.create_index(ct.default_float_vec_field_name, index_params)
+            collection_w.load()
+
+            search_params = {"params": {}}
+            nq = 1
+            limit = 1
+            search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+            # search with groupby
+            err_code = 999
+            err_msg = "Unexpected index"
+            if index in ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "SCANN"]:
+                err_msg = "not supported for current index type"
+            collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                param=search_params, limit=limit,
+                                group_by_field=ct.default_int8_field_name,
+                                check_task=CheckTasks.err_res,
+                                check_items={"err_code": err_code, "err_msg": err_msg})
+
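As the expected messages above imply, among the index types exercised here only HNSW supports group_by at this point: the IVF family and SCANN report "not supported for current index type", and the remaining types fall back to a generic "Unexpected index" error.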
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("grpby_nonexist_field", ["nonexit_field", 100])
+    def test_search_group_by_nonexit_filed(self, grpby_nonexist_field):
+        """
+        target: test search group by with a nonexistent field
+        method: 1. create a collection with data
+                2. create index
+                3. search with group by the nonexistent field
+        verify: the error code and msg
+        """
+        metric = "IP"
+        collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
+                                                    is_all_data_type=True, with_json=True)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 1
+        limit = 1
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+        # search with groupby
+        err_code = 1700
+        err_msg = f"groupBy field not found in schema: field not found[field={grpby_nonexist_field}]"
+        collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                            param=search_params, limit=limit,
+                            group_by_field=grpby_nonexist_field,
+                            check_task=CheckTasks.err_res,
+                            check_items={"err_code": err_code, "err_msg": err_msg})
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.skip(reason="issue #30033")
+    def test_search_pagination_group_by(self):
+        """
+        target: test search group by
+        method: 1. create a collection with data
+                2. create index HNSW
+                3. search with groupby and pagination
+                4. search once with groupby and limit = limit * page_rounds
+        verify: search with groupby and pagination returns correct results
+        """
+        # 1. create a collection
+        metric = "COSINE"
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        # insert with the same values for scalar fields
+        for _ in range(50):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.load()
+        # 2. search pagination with offset
+        limit = 10
+        page_rounds = 3
+        search_param = {"metric_type": metric}
+        grpby_field = ct.default_string_field_name
+        search_vectors = cf.gen_vectors(1, dim=ct.default_dim)
+        all_pages_ids = []
+        for r in range(page_rounds):
+            page_res = collection_w.search(search_vectors, anns_field=default_search_field,
+                                           param=search_param, limit=limit, offset=limit * r,
+                                           expr=default_search_exp, group_by_field=grpby_field,
+                                           check_task=CheckTasks.check_search_results,
+                                           check_items={"nq": 1, "limit": limit},
+                                           )[0]
+            all_pages_ids += page_res[0].ids
+
+        total_res = collection_w.search(search_vectors, anns_field=default_search_field,
+                                        param=search_param, limit=limit * page_rounds,
+                                        expr=default_search_exp, group_by_field=grpby_field,
+                                        output_fields=[grpby_field],
+                                        check_task=CheckTasks.check_search_results,
+                                        check_items={"nq": 1, "limit": limit * page_rounds}
+                                        )[0]
+        assert total_res[0].ids == all_pages_ids
+        grpby_field_values = []
+        for i in range(limit * page_rounds):
+            grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
+        assert len(grpby_field_values) == len(set(grpby_field_values))
+
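The pagination check above boils down to one invariant: pages fetched with increasing offset must concatenate to exactly the result of a single group-by search with limit = limit * page_rounds. A compact sketch of that invariant against a plain pymilvus Collection; the function name and parameters are illustrative:

    from pymilvus import Collection

    def paged_ids_match_single_search(collection: Collection, vectors, anns_field: str,
                                      param: dict, grpby_field: str,
                                      limit: int = 10, page_rounds: int = 3) -> bool:
        # fetch page_rounds pages of `limit` hits each, then compare with
        # one search that asks for all of them at once
        paged_ids = []
        for r in range(page_rounds):
            page = collection.search(data=vectors, anns_field=anns_field, param=param,
                                     limit=limit, offset=limit * r,
                                     group_by_field=grpby_field)[0]
            paged_ids += list(page.ids)
        total = collection.search(data=vectors, anns_field=anns_field, param=param,
                                  limit=limit * page_rounds,
                                  group_by_field=grpby_field)[0]
        return list(total.ids) == paged_ids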
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.skip(reason="issue #30033")
+    def test_search_iterator_group_by(self):
+        """
+        target: test search group by
+        method: 1. create a collection with data
+                2. create index HNSW
+                3. search iterator with group by
+                4. search with filtering every value of group_by_field
+        verify: verify successfully and iterators are correct
+        """
+        metric = "COSINE"
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        # insert with the same values for scalar fields
+        value_num = 50
+        for _ in range(value_num):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.load()
+
+        grpby_field = ct.default_int64_field_name
+        search_vectors = cf.gen_vectors(1, dim=ct.default_dim)
+        search_params = {"metric_type": metric}
+        batch_size = 10
+
+        # res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
+        #                           search_params, group_by_field=grpby_field, limit=10)[0]
+
+        ite_res = collection_w.search_iterator(search_vectors, ct.default_float_vec_field_name,
+                                               search_params, batch_size, group_by_field=grpby_field
+                                               )[0]
+        iterators = 0
+        while True:
+            res = ite_res.next()  # turn to the next page
+            if len(res) == 0:
+                ite_res.close()  # close the iterator
+                break
+            iterators += 1
+        assert iterators == value_num / batch_size
+
+    @pytest.mark.tags(CaseLabel.L2)
+    def test_range_search_group_by(self):
+        """
+        target: test range search group by
+        method: 1. create a collection with data
+                2. create index hnsw
+                3. range search with group by
+        verify: no duplicate values of the group_by_field in results
+        """
+        metric = "COSINE"
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        # insert with the same values for scalar fields
+        for _ in range(30):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        time.sleep(10)
+        collection_w.load()
+
+        nq = 1
+        limit = 10
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+        grpby_field = ct.default_int32_field_name
+        range_search_params = {"metric_type": "COSINE", "params": {"radius": 0.1,
+                                                                   "range_filter": 0.5}}
+        res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
+                                  range_search_params, limit,
+                                  default_search_exp, group_by_field=grpby_field,
+                                  output_fields=[grpby_field],
+                                  check_task=CheckTasks.check_search_results,
+                                  check_items={"nq": nq, "limit": limit})[0]
+        grpby_field_values = []
+        for i in range(limit):
+            grpby_field_values.append(res[0][i].fields.get(grpby_field))
+        assert len(grpby_field_values) == len(set(grpby_field_values))
+
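A note on the range parameters above: COSINE is a larger-is-closer metric, so radius is the outer (lower) similarity bound and range_filter the inner (upper) one, and the search keeps hits whose score lies between them. Annotated under that assumption:

    # keep hits with 0.1 < cosine similarity <= 0.5 (higher score = closer);
    # for smaller-is-closer metrics such as L2 the two bounds invert
    range_search_params = {"metric_type": "COSINE",
                           "params": {"radius": 0.1, "range_filter": 0.5}}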
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.skip(reason="not completed")
+    def test_hybrid_search_group_by(self):
+        """
+        target: test hybrid search group by
+        method: 1. create a collection with multiple vector fields
+                2. create index hnsw and hnsw
+                3. hybrid_search with group by
+        verify: the error code and msg
+        """
+        pass
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.skip(reason="not completed")
+    def test_multi_vectors_search_one_vector_group_by(self):
+        """
+        target: test search group by on one of multiple vector fields
+        method: 1. create a collection with multiple vector fields
+                2. create index hnsw and ivfflat
+                3. search with group by on the vector field with hnsw index
+        verify: search successfully
+        """
+        pass