mirror of https://github.com/milvus-io/milvus.git
				
				
				
			test: Add tests for hybrid search group by (#36326)
related issue: #36295

---------

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>

pull/36373/head
							parent
							
								
									167e4fb10d
								
							
						
					
					
						commit
						e013ef1908
					
				| 
						 | 
					@ -806,10 +806,10 @@ def gen_default_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, start=0, wi
 | 
				
			||||||
    return df
 | 
					    return df
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def gen_general_default_list_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True,
 | 
					def gen_default_list_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True,
 | 
				
			||||||
                                  random_primary_key=False, multiple_dim_array=[], multiple_vector_field_name=[],
 | 
					                          random_primary_key=False, multiple_dim_array=[], multiple_vector_field_name=[],
 | 
				
			||||||
                                  vector_data_type="FLOAT_VECTOR", auto_id=False,
 | 
					                          vector_data_type="FLOAT_VECTOR", auto_id=False,
 | 
				
			||||||
                                  primary_field=ct.default_int64_field_name, nullable_fields={}):
 | 
					                          primary_field=ct.default_int64_field_name, nullable_fields={}):
 | 
				
			||||||
    insert_list = []
 | 
					    insert_list = []
 | 
				
			||||||
    if not random_primary_key:
 | 
					    if not random_primary_key:
 | 
				
			||||||
        int_values = pd.Series(data=[i for i in range(start, start + nb)])
 | 
					        int_values = pd.Series(data=[i for i in range(start, start + nb)])
 | 
				
			||||||
| 
						 | 
					@ -1244,19 +1244,19 @@ def gen_default_binary_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, star
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return df, binary_raw_values
 | 
					    return df, binary_raw_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
def gen_default_list_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True):
 | 
					# def gen_default_list_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True):
 | 
				
			||||||
    int_values = [i for i in range(start, start + nb)]
 | 
					#     int_values = [i for i in range(start, start + nb)]
 | 
				
			||||||
    float_values = [np.float32(i) for i in range(start, start + nb)]
 | 
					#     float_values = [np.float32(i) for i in range(start, start + nb)]
 | 
				
			||||||
    string_values = [str(i) for i in range(start, start + nb)]
 | 
					#     string_values = [str(i) for i in range(start, start + nb)]
 | 
				
			||||||
    json_values = [{"number": i, "string": str(i), "bool": bool(i), "list": [j for j in range(0, i)]}
 | 
					#     json_values = [{"number": i, "string": str(i), "bool": bool(i), "list": [j for j in range(0, i)]}
 | 
				
			||||||
                   for i in range(start, start + nb)]
 | 
					#                    for i in range(start, start + nb)]
 | 
				
			||||||
    float_vec_values = gen_vectors(nb, dim)
 | 
					#     float_vec_values = gen_vectors(nb, dim)
 | 
				
			||||||
    if with_json is False:
 | 
					#     if with_json is False:
 | 
				
			||||||
        data = [int_values, float_values, string_values, float_vec_values]
 | 
					#         data = [int_values, float_values, string_values, float_vec_values]
 | 
				
			||||||
    else:
 | 
					#     else:
 | 
				
			||||||
        data = [int_values, float_values, string_values, json_values, float_vec_values]
 | 
					#         data = [int_values, float_values, string_values, json_values, float_vec_values]
 | 
				
			||||||
    return data
 | 
					#     return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def gen_default_list_sparse_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=False):
 | 
					def gen_default_list_sparse_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=False):
 | 
				
			||||||
| 
						 | 
					@ -2347,13 +2347,13 @@ def insert_data(collection_w, nb=ct.default_nb, is_binary=False, is_all_data_typ
 | 
				
			||||||
                                                                  auto_id=auto_id, primary_field=primary_field,
 | 
					                                                                  auto_id=auto_id, primary_field=primary_field,
 | 
				
			||||||
                                                                  nullable_fields=nullable_fields)
 | 
					                                                                  nullable_fields=nullable_fields)
 | 
				
			||||||
                    elif vector_data_type in ct.append_vector_type:
 | 
					                    elif vector_data_type in ct.append_vector_type:
 | 
				
			||||||
                        default_data = gen_general_default_list_data(nb // num, dim=dim, start=start, with_json=with_json,
 | 
					                        default_data = gen_default_list_data(nb // num, dim=dim, start=start, with_json=with_json,
 | 
				
			||||||
                                                                     random_primary_key=random_primary_key,
 | 
					                                                             random_primary_key=random_primary_key,
 | 
				
			||||||
                                                                     multiple_dim_array=multiple_dim_array,
 | 
					                                                             multiple_dim_array=multiple_dim_array,
 | 
				
			||||||
                                                                     multiple_vector_field_name=vector_name_list,
 | 
					                                                             multiple_vector_field_name=vector_name_list,
 | 
				
			||||||
                                                                     vector_data_type=vector_data_type,
 | 
					                                                             vector_data_type=vector_data_type,
 | 
				
			||||||
                                                                     auto_id=auto_id, primary_field=primary_field,
 | 
					                                                             auto_id=auto_id, primary_field=primary_field,
 | 
				
			||||||
                                                                     nullable_fields=nullable_fields)
 | 
					                                                             nullable_fields=nullable_fields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    default_data = gen_default_rows_data(nb // num, dim=dim, start=start, with_json=with_json,
 | 
					                    default_data = gen_default_rows_data(nb // num, dim=dim, start=start, with_json=with_json,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,7 +60,7 @@ class TestInsertParams(TestcaseBase):
 | 
				
			||||||
        data = cf.gen_default_list_data(ct.default_nb)
 | 
					        data = cf.gen_default_list_data(ct.default_nb)
 | 
				
			||||||
        mutation_res, _ = collection_w.insert(data=data)
 | 
					        mutation_res, _ = collection_w.insert(data=data)
 | 
				
			||||||
        assert mutation_res.insert_count == ct.default_nb
 | 
					        assert mutation_res.insert_count == ct.default_nb
 | 
				
			||||||
        assert mutation_res.primary_keys == data[0]
 | 
					        assert mutation_res.primary_keys == data[0].tolist()
 | 
				
			||||||
        assert collection_w.num_entities == ct.default_nb
 | 
					        assert collection_w.num_entities == ct.default_nb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L2)
 | 
					    @pytest.mark.tags(CaseLabel.L2)
 | 
				
			||||||
| 
						 | 
					@ -214,7 +214,7 @@ class TestInsertParams(TestcaseBase):
 | 
				
			||||||
        data = cf.gen_default_list_data(nb=1)
 | 
					        data = cf.gen_default_list_data(nb=1)
 | 
				
			||||||
        mutation_res, _ = collection_w.insert(data=data)
 | 
					        mutation_res, _ = collection_w.insert(data=data)
 | 
				
			||||||
        assert mutation_res.insert_count == 1
 | 
					        assert mutation_res.insert_count == 1
 | 
				
			||||||
        assert mutation_res.primary_keys == data[0]
 | 
					        assert mutation_res.primary_keys == data[0].tolist()
 | 
				
			||||||
        assert collection_w.num_entities == 1
 | 
					        assert collection_w.num_entities == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L2)
 | 
					    @pytest.mark.tags(CaseLabel.L2)
 | 
				
			||||||
| 
						 | 
					@ -1417,7 +1417,7 @@ class TestInsertString(TestcaseBase):
 | 
				
			||||||
        data = cf.gen_default_list_data(ct.default_nb)
 | 
					        data = cf.gen_default_list_data(ct.default_nb)
 | 
				
			||||||
        mutation_res, _ = collection_w.insert(data=data)
 | 
					        mutation_res, _ = collection_w.insert(data=data)
 | 
				
			||||||
        assert mutation_res.insert_count == ct.default_nb
 | 
					        assert mutation_res.insert_count == ct.default_nb
 | 
				
			||||||
        assert mutation_res.primary_keys == data[2]
 | 
					        assert mutation_res.primary_keys == data[2].tolist()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L0)
 | 
					    @pytest.mark.tags(CaseLabel.L0)
 | 
				
			||||||
    @pytest.mark.parametrize("string_fields", [[cf.gen_string_field(name="string_field1")],
 | 
					    @pytest.mark.parametrize("string_fields", [[cf.gen_string_field(name="string_field1")],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7052,7 +7052,7 @@ class TestCollectionRangeSearch(TestcaseBase):
 | 
				
			||||||
        nb = 1000
 | 
					        nb = 1000
 | 
				
			||||||
        rounds = 10
 | 
					        rounds = 10
 | 
				
			||||||
        for i in range(rounds):
 | 
					        for i in range(rounds):
 | 
				
			||||||
            data = cf.gen_general_default_list_data(nb=nb, auto_id=True, vector_data_type=vector_data_type,
 | 
					            data = cf.gen_default_list_data(nb=nb, auto_id=True, vector_data_type=vector_data_type,
 | 
				
			||||||
                                                    with_json=False, start=i*nb)
 | 
					                                                    with_json=False, start=i*nb)
 | 
				
			||||||
            collection_w.insert(data)
 | 
					            collection_w.insert(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7064,7 +7064,7 @@ class TestCollectionRangeSearch(TestcaseBase):
 | 
				
			||||||
        if with_growing is True:
 | 
					        if with_growing is True:
 | 
				
			||||||
            # add some growing segments
 | 
					            # add some growing segments
 | 
				
			||||||
            for j in range(rounds//2):
 | 
					            for j in range(rounds//2):
 | 
				
			||||||
                data = cf.gen_general_default_list_data(nb=nb, auto_id=True, vector_data_type=vector_data_type,
 | 
					                data = cf.gen_default_list_data(nb=nb, auto_id=True, vector_data_type=vector_data_type,
 | 
				
			||||||
                                                        with_json=False, start=(rounds+j)*nb)
 | 
					                                                        with_json=False, start=(rounds+j)*nb)
 | 
				
			||||||
                collection_w.insert(data)
 | 
					                collection_w.insert(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10240,7 +10240,7 @@ class TestSearchIterator(TestcaseBase):
 | 
				
			||||||
class TestSearchGroupBy(TestcaseBase):
 | 
					class TestSearchGroupBy(TestcaseBase):
 | 
				
			||||||
    """ Test case of search group by """
 | 
					    """ Test case of search group by """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L3)
 | 
					    @pytest.mark.tags(CaseLabel.L2)
 | 
				
			||||||
    @pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics))
 | 
					    @pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics))
 | 
				
			||||||
    @pytest.mark.parametrize("vector_data_type", ["FLOAT16_VECTOR", "FLOAT_VECTOR", "BFLOAT16_VECTOR"])
 | 
					    @pytest.mark.parametrize("vector_data_type", ["FLOAT16_VECTOR", "FLOAT_VECTOR", "BFLOAT16_VECTOR"])
 | 
				
			||||||
    def test_search_group_by_default(self, index_type, metric, vector_data_type):
 | 
					    def test_search_group_by_default(self, index_type, metric, vector_data_type):
 | 
				
			||||||
| 
						 | 
					@ -10273,19 +10273,19 @@ class TestSearchGroupBy(TestcaseBase):
 | 
				
			||||||
        nq = 2
 | 
					        nq = 2
 | 
				
			||||||
        limit = 15
 | 
					        limit = 15
 | 
				
			||||||
        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
 | 
					        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
 | 
				
			||||||
        # verify the results are same if gourp by pk
 | 
					        # # verify the results are same if gourp by pk
 | 
				
			||||||
        res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
 | 
					        # res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
 | 
				
			||||||
                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG,
 | 
					        #                            param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG,
 | 
				
			||||||
                                   group_by_field=ct.default_int64_field_name)[0]
 | 
					        #                            group_by_field=ct.default_int64_field_name)[0]
 | 
				
			||||||
        res2 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
 | 
					        # res2 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
 | 
				
			||||||
                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG)[0]
 | 
					        #                            param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG)[0]
 | 
				
			||||||
        hits_num = 0
 | 
					        # hits_num = 0
 | 
				
			||||||
        for i in range(nq):
 | 
					        # for i in range(nq):
 | 
				
			||||||
            # assert res1[i].ids == res2[i].ids
 | 
					        #     assert res1[i].ids == res2[i].ids
 | 
				
			||||||
            hits_num += len(set(res1[i].ids).intersection(set(res2[i].ids)))
 | 
					        #     hits_num += len(set(res1[i].ids).intersection(set(res2[i].ids)))
 | 
				
			||||||
        hit_rate = hits_num / (nq * limit)
 | 
					        # hit_rate = hits_num / (nq * limit)
 | 
				
			||||||
        log.info(f"groupy primary key hits_num: {hits_num}, nq: {nq}, limit: {limit}, hit_rate: {hit_rate}")
 | 
					        # log.info(f"groupy primary key hits_num: {hits_num}, nq: {nq}, limit: {limit}, hit_rate: {hit_rate}")
 | 
				
			||||||
        assert hit_rate >= 0.60
 | 
					        # assert hit_rate >= 0.60
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # verify that every record in groupby results is the top1 for that value of the group_by_field
 | 
					        # verify that every record in groupby results is the top1 for that value of the group_by_field
 | 
				
			||||||
        supported_grpby_fields = [ct.default_int8_field_name, ct.default_int16_field_name,
 | 
					        supported_grpby_fields = [ct.default_int8_field_name, ct.default_int16_field_name,
 | 
				
			||||||
| 
						 | 
					@ -10323,61 +10323,119 @@ class TestSearchGroupBy(TestcaseBase):
 | 
				
			||||||
                assert len(grpby_values) == len(set(grpby_values))
 | 
					                assert len(grpby_values) == len(set(grpby_values))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L0)
 | 
					    @pytest.mark.tags(CaseLabel.L0)
 | 
				
			||||||
    @pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics))
 | 
					    def test_search_group_size_default(self):
 | 
				
			||||||
    @pytest.mark.parametrize("vector_data_type", ["FLOAT_VECTOR", "FLOAT16_VECTOR",  "BFLOAT16_VECTOR"])
 | 
					 | 
				
			||||||
    @pytest.mark.parametrize("group_strict_size", [True, False])
 | 
					 | 
				
			||||||
    def test_search_group_size_default(self, index_type, metric, vector_data_type, group_strict_size):
 | 
					 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        target: test search group by
 | 
					        target: test search group by
 | 
				
			||||||
        method: 1. create a collection with data
 | 
					        method: 1. create a collection with 3 different float vectors
 | 
				
			||||||
                2. search with group by int32 with group size
 | 
					                2. build index with 3 different index types and metrics
 | 
				
			||||||
 | 
					                2. search on 3 different float vector fields with group by varchar field with group size
 | 
				
			||||||
                verify results entity = limit * group_size  and group size is full if group_strict_size is True
 | 
					                verify results entity = limit * group_size  and group size is full if group_strict_size is True
 | 
				
			||||||
                verfiy results group counts = limit if group_strict_size is False
 | 
					                verify results group counts = limit if group_strict_size is False
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
 | 
					        self._connect()
 | 
				
			||||||
                                                    vector_data_type=vector_data_type,
 | 
					        dense_types = ["FLOAT16_VECTOR", "FLOAT_VECTOR", "BFLOAT16_VECTOR"]
 | 
				
			||||||
                                                    is_all_data_type=True, with_json=False)[0]
 | 
					        dims = [16, 128, 64]
 | 
				
			||||||
        _index_params = {"index_type": index_type, "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
 | 
					        index_types = ["FLAT", "IVF_SQ8", "HNSW"]
 | 
				
			||||||
        if index_type in ["IVF_FLAT", "FLAT"]:
 | 
					        metrics = ct.float_metrics
 | 
				
			||||||
            _index_params = {"index_type": index_type, "metric_type": metric, "params": {"nlist": 128}}
 | 
					        fields = [cf.gen_int64_field(is_primary=True), cf.gen_string_field()]
 | 
				
			||||||
        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
 | 
					        for i in range(len(dense_types)):
 | 
				
			||||||
 | 
					            fields.append(cf.gen_float_vec_field(name=dense_types[i],
 | 
				
			||||||
 | 
					                                                 vector_data_type=dense_types[i], dim=dims[i]))
 | 
				
			||||||
 | 
					        schema = cf.gen_collection_schema(fields, auto_id=True)
 | 
				
			||||||
 | 
					        collection_w = self.init_collection_wrap(name=prefix, schema=schema)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # insert with the same values for scalar fields
 | 
					        # insert with the same values for scalar fields
 | 
				
			||||||
        for _ in range(500):
 | 
					        nb = 100
 | 
				
			||||||
            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
 | 
					        for _ in range(100):
 | 
				
			||||||
 | 
					            string_values = pd.Series(data=[str(i) for i in range(nb)], dtype="string")
 | 
				
			||||||
 | 
					            data = [string_values]
 | 
				
			||||||
 | 
					            for i in range(len(dense_types)):
 | 
				
			||||||
 | 
					                data.append(cf.gen_vectors(dim=dims[i], nb=nb, vector_data_type=dense_types[i]))
 | 
				
			||||||
            collection_w.insert(data)
 | 
					            collection_w.insert(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        collection_w.flush()
 | 
					        collection_w.flush()
 | 
				
			||||||
        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
 | 
					        for i in range(len(dense_types)):
 | 
				
			||||||
 | 
					            _index_params = {"index_type": index_types[i], "metric_type": metrics[i],
 | 
				
			||||||
 | 
					                             "params": cf.get_index_params_params(index_types[i])}
 | 
				
			||||||
 | 
					            collection_w.create_index(dense_types[i], _index_params)
 | 
				
			||||||
        collection_w.load()
 | 
					        collection_w.load()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        search_params = {"metric_type": metric, "params": {"ef": 128}}
 | 
					 | 
				
			||||||
        nq = 2
 | 
					        nq = 2
 | 
				
			||||||
        limit = 100
 | 
					        limit = 50
 | 
				
			||||||
        group_size = 10
 | 
					        group_size = 5
 | 
				
			||||||
        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
 | 
					        for j in range(len(dense_types)):
 | 
				
			||||||
        # verify
 | 
					            search_vectors = cf.gen_vectors(nq, dim=dims[j], vector_data_type=dense_types[j])
 | 
				
			||||||
        res1 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
 | 
					            search_params = {"params": cf.get_search_params_params(index_types[j])}
 | 
				
			||||||
                                   param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG,
 | 
					            # when group_strict_size=true, it shall return results with entities = limit * group_size
 | 
				
			||||||
                                   group_by_field=ct.default_int32_field_name, group_size=group_size,
 | 
					            res1 = collection_w.search(data=search_vectors, anns_field=dense_types[j],
 | 
				
			||||||
                                   group_strict_size=group_strict_size,
 | 
					                                       param=search_params, limit=limit, # consistency_level=CONSISTENCY_STRONG,
 | 
				
			||||||
                                   output_fields=[ct.default_int32_field_name]
 | 
					                                       group_by_field=ct.default_string_field_name,
 | 
				
			||||||
                                   )[0]
 | 
					                                       group_size=group_size, group_strict_size=True,
 | 
				
			||||||
        # print(res1)
 | 
					                                       output_fields=[ct.default_string_field_name])[0]
 | 
				
			||||||
        if group_strict_size is True:   # when true, it shall return results with entities = limit * group_size
 | 
					 | 
				
			||||||
            for i in range(nq):
 | 
					            for i in range(nq):
 | 
				
			||||||
                for l in range(limit):
 | 
					                for l in range(limit):
 | 
				
			||||||
                    group_values = []
 | 
					                    group_values = []
 | 
				
			||||||
                    for k in range(10):
 | 
					                    for k in range(10):
 | 
				
			||||||
                        group_values.append(res1[i][l].fields.get(ct.default_int32_field_name))
 | 
					                        group_values.append(res1[i][l].fields.get(ct.default_string_field_name))
 | 
				
			||||||
                    assert len(set(group_values)) == 1
 | 
					                    assert len(set(group_values)) == 1
 | 
				
			||||||
                assert len(res1[i]) == limit * group_size
 | 
					                assert len(res1[i]) == limit * group_size
 | 
				
			||||||
        else:   # when False, it shall return results with group counts = limit
 | 
					
 | 
				
			||||||
 | 
					            # when group_strict_size=false, it shall return results with group counts = limit
 | 
				
			||||||
 | 
					            res1 = collection_w.search(data=search_vectors, anns_field=dense_types[j],
 | 
				
			||||||
 | 
					                                       param=search_params, limit=limit, # consistency_level=CONSISTENCY_STRONG,
 | 
				
			||||||
 | 
					                                       group_by_field=ct.default_string_field_name,
 | 
				
			||||||
 | 
					                                       group_size=group_size, group_strict_size=False,
 | 
				
			||||||
 | 
					                                       output_fields=[ct.default_string_field_name])[0]
 | 
				
			||||||
            for i in range(nq):
 | 
					            for i in range(nq):
 | 
				
			||||||
                group_values = []
 | 
					                group_values = []
 | 
				
			||||||
                for l in range(len(res1[i])):
 | 
					                for l in range(len(res1[i])):
 | 
				
			||||||
                    group_values.append(res1[i][l].fields.get(ct.default_int32_field_name))
 | 
					                    group_values.append(res1[i][l].fields.get(ct.default_string_field_name))
 | 
				
			||||||
                assert len(set(group_values)) == limit
 | 
					                assert len(set(group_values)) == limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # hybrid search group by
 | 
				
			||||||
 | 
					        req_list = []
 | 
				
			||||||
 | 
					        for j in range(len(dense_types)):
 | 
				
			||||||
 | 
					            search_params = {
 | 
				
			||||||
 | 
					                "data": cf.gen_vectors(nq, dim=dims[j], vector_data_type=dense_types[j]),
 | 
				
			||||||
 | 
					                "anns_field": dense_types[j],
 | 
				
			||||||
 | 
					                "param": {"params": cf.get_search_params_params(index_types[j])},
 | 
				
			||||||
 | 
					                "limit": limit,
 | 
				
			||||||
 | 
					                "expr": "int64 > 0"}
 | 
				
			||||||
 | 
					            req = AnnSearchRequest(**search_params)
 | 
				
			||||||
 | 
					            req_list.append(req)
 | 
				
			||||||
 | 
					        # 4. hybrid search group by
 | 
				
			||||||
 | 
					        import numpy as np
 | 
				
			||||||
 | 
					        rank_scorers = ["max", "avg", "sum"]
 | 
				
			||||||
 | 
					        for scorer in rank_scorers:
 | 
				
			||||||
 | 
					            res = collection_w.hybrid_search(req_list, WeightedRanker(0.3, 0.3, 0.3), limit=limit,
 | 
				
			||||||
 | 
					                                             group_by_field=ct.default_string_field_name,
 | 
				
			||||||
 | 
					                                             group_size=group_size, rank_group_scorer=scorer,
 | 
				
			||||||
 | 
					                                             output_fields=[ct.default_string_field_name])[0]
 | 
				
			||||||
 | 
					            for i in range(nq):
 | 
				
			||||||
 | 
					                group_values = []
 | 
				
			||||||
 | 
					                for l in range(len(res[i])):
 | 
				
			||||||
 | 
					                    group_values.append(res[i][l].fields.get(ct.default_string_field_name))
 | 
				
			||||||
 | 
					                assert len(set(group_values)) == limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # group_distances = []
 | 
				
			||||||
 | 
					                tmp_distances = [100 for _ in range(group_size)]    # init with a large value
 | 
				
			||||||
 | 
					                group_distances = [res[i][0].distance]              # init with the first value
 | 
				
			||||||
 | 
					                for l in range(len(res[i])-1):
 | 
				
			||||||
 | 
					                    curr_group_value = res[i][l].fields.get(ct.default_string_field_name)
 | 
				
			||||||
 | 
					                    next_group_value = res[i][l+1].fields.get(ct.default_string_field_name)
 | 
				
			||||||
 | 
					                    if curr_group_value == next_group_value:
 | 
				
			||||||
 | 
					                        group_distances.append(res[i][l+1].distance)
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        if scorer == 'sum':
 | 
				
			||||||
 | 
					                            assert np.sum(group_distances) < np.sum(tmp_distances)
 | 
				
			||||||
 | 
					                        elif scorer == 'avg':
 | 
				
			||||||
 | 
					                            assert np.mean(group_distances) < np.mean(tmp_distances)
 | 
				
			||||||
 | 
					                        else:      # default max
 | 
				
			||||||
 | 
					                            assert np.max(group_distances) < np.max(tmp_distances)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        tmp_distances = group_distances
 | 
				
			||||||
 | 
					                        group_distances = [res[i][l+1].distance]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L2)
 | 
					    @pytest.mark.tags(CaseLabel.L2)
 | 
				
			||||||
    def test_search_max_group_size_and_max_limit(self):
 | 
					    def test_search_max_group_size_and_max_limit(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -10776,8 +10834,8 @@ class TestSearchGroupBy(TestcaseBase):
 | 
				
			||||||
                            check_task=CheckTasks.err_res,
 | 
					                            check_task=CheckTasks.err_res,
 | 
				
			||||||
                            check_items={"err_code": err_code, "err_msg": err_msg})
 | 
					                            check_items={"err_code": err_code, "err_msg": err_msg})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L2)
 | 
					    @pytest.mark.tags(CaseLabel.L1)
 | 
				
			||||||
    def test_hybrid_search_not_support_group_by(self):
 | 
					    def test_hybrid_search_support_group_by(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        target: verify that hybrid search does not support groupby
 | 
					        target: verify that hybrid search does not support groupby
 | 
				
			||||||
        method: 1. create a collection with multiple vector fields
 | 
					        method: 1. create a collection with multiple vector fields
 | 
				
			||||||
| 
						 | 
					@ -10786,55 +10844,61 @@ class TestSearchGroupBy(TestcaseBase):
 | 
				
			||||||
                verify: the error code and msg
 | 
					                verify: the error code and msg
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        # 1. initialize collection with data
 | 
					        # 1. initialize collection with data
 | 
				
			||||||
        dim = 33
 | 
					        dim = 128
 | 
				
			||||||
        index_type = "HNSW"
 | 
					        supported_index = ["HNSW", "FLAT", "IVF_FLAT", "IVF_SQ8"]
 | 
				
			||||||
        metric_type = "COSINE"
 | 
					        metric = ct.default_L0_metric
 | 
				
			||||||
        _index_params = {"index_type": index_type, "metric_type": metric_type, "params": {"M": 16, "efConstruction": 128}}
 | 
					 | 
				
			||||||
        collection_w, _, _, insert_ids, time_stamp = \
 | 
					        collection_w, _, _, insert_ids, time_stamp = \
 | 
				
			||||||
            self.init_collection_general(prefix, True, dim=dim,  is_index=False,
 | 
					            self.init_collection_general(prefix, True, dim=dim, is_index=False,
 | 
				
			||||||
                                         enable_dynamic_field=False, multiple_dim_array=[dim, dim])[0:5]
 | 
					                                         enable_dynamic_field=False,
 | 
				
			||||||
 | 
					                                         multiple_dim_array=[dim, dim, dim])[0:5]
 | 
				
			||||||
        # 2. extract vector field name
 | 
					        # 2. extract vector field name
 | 
				
			||||||
        vector_name_list = cf.extract_vector_field_name_list(collection_w)
 | 
					        vector_name_list = cf.extract_vector_field_name_list(collection_w)
 | 
				
			||||||
        vector_name_list.append(ct.default_float_vec_field_name)
 | 
					        vector_name_list.append(ct.default_float_vec_field_name)
 | 
				
			||||||
        for vector_name in vector_name_list:
 | 
					        for i in range(len(vector_name_list)):
 | 
				
			||||||
            collection_w.create_index(vector_name, _index_params)
 | 
					            index = supported_index[i]
 | 
				
			||||||
 | 
					            _index_params = {"index_type": index, "metric_type": metric,
 | 
				
			||||||
 | 
					                             "params": cf.get_index_params_params(index)}
 | 
				
			||||||
 | 
					            collection_w.create_index(vector_name_list[i], _index_params)
 | 
				
			||||||
        collection_w.load()
 | 
					        collection_w.load()
 | 
				
			||||||
        # 3. prepare search params
 | 
					        # 3. prepare search params
 | 
				
			||||||
        req_list = []
 | 
					        req_list = []
 | 
				
			||||||
        for vector_name in vector_name_list:
 | 
					        for vector_name in vector_name_list:
 | 
				
			||||||
            search_param = {
 | 
					            search_param = {
 | 
				
			||||||
                "data": [[random.random() for _ in range(dim)] for _ in range(1)],
 | 
					                "data": [[random.random() for _ in range(dim)] for _ in range(ct.default_nq)],
 | 
				
			||||||
                "anns_field": vector_name,
 | 
					                "anns_field": vector_name,
 | 
				
			||||||
                "param": {"metric_type": metric_type, "offset": 0},
 | 
					                "param": {"metric_type": metric, "offset": 0},
 | 
				
			||||||
                "limit": default_limit,
 | 
					                "limit": default_limit,
 | 
				
			||||||
                # "group_by_field": ct.default_int64_field_name,
 | 
					 | 
				
			||||||
                "expr": "int64 > 0"}
 | 
					                "expr": "int64 > 0"}
 | 
				
			||||||
            req = AnnSearchRequest(**search_param)
 | 
					            req = AnnSearchRequest(**search_param)
 | 
				
			||||||
            req_list.append(req)
 | 
					            req_list.append(req)
 | 
				
			||||||
        # 4. hybrid search
 | 
					        # 4. hybrid search group by
 | 
				
			||||||
        err_code = 9999
 | 
					        res = collection_w.hybrid_search(req_list, WeightedRanker(0.1, 0.9, 1, 0.2), default_limit,
 | 
				
			||||||
        err_msg = f"not support search_group_by operation in the hybrid search"
 | 
					                                         group_by_field=ct.default_string_field_name,
 | 
				
			||||||
        collection_w.hybrid_search(req_list, WeightedRanker(0.1, 0.9, 1), default_limit,
 | 
					                                         output_fields=[ct.default_string_field_name],
 | 
				
			||||||
                                   group_by_field=ct.default_int64_field_name,
 | 
					                                         check_task=CheckTasks.check_search_results,
 | 
				
			||||||
                                   check_task=CheckTasks.err_res,
 | 
					                                         check_items={"nq": ct.default_nq, "limit": default_limit})[0]
 | 
				
			||||||
                                   check_items={"err_code": err_code, "err_msg": err_msg})
 | 
					        print(res)
 | 
				
			||||||
 | 
					        for i in range(ct.default_nq):
 | 
				
			||||||
 | 
					            group_values = []
 | 
				
			||||||
 | 
					            for l in range(ct.default_limit):
 | 
				
			||||||
 | 
					                group_values.append(res[i][l].fields.get(ct.default_string_field_name))
 | 
				
			||||||
 | 
					            assert len(group_values) == len(set(group_values))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 5. hybrid search with group by on one vector field
 | 
					        # 5. hybrid search with RRFRanker on one vector field with group by
 | 
				
			||||||
        req_list = []
 | 
					        req_list = []
 | 
				
			||||||
        for vector_name in vector_name_list[:1]:
 | 
					        for vector_name in vector_name_list[:1]:
 | 
				
			||||||
            search_param = {
 | 
					            search_param = {
 | 
				
			||||||
                "data": [[random.random() for _ in range(dim)] for _ in range(1)],
 | 
					                "data": [[random.random() for _ in range(dim)] for _ in range(1)],
 | 
				
			||||||
                "anns_field": vector_name,
 | 
					                "anns_field": vector_name,
 | 
				
			||||||
                "param": {"metric_type": metric_type, "offset": 0},
 | 
					                "param": {"metric_type": metric, "offset": 0},
 | 
				
			||||||
                "limit": default_limit,
 | 
					                "limit": default_limit,
 | 
				
			||||||
                # "group_by_field": ct.default_int64_field_name,
 | 
					 | 
				
			||||||
                "expr": "int64 > 0"}
 | 
					                "expr": "int64 > 0"}
 | 
				
			||||||
            req = AnnSearchRequest(**search_param)
 | 
					            req = AnnSearchRequest(**search_param)
 | 
				
			||||||
            req_list.append(req)
 | 
					            req_list.append(req)
 | 
				
			||||||
        collection_w.hybrid_search(req_list, RRFRanker(), default_limit,
 | 
					            collection_w.hybrid_search(req_list, RRFRanker(), default_limit,
 | 
				
			||||||
                                   group_by_field=ct.default_int64_field_name,
 | 
					                                       group_by_field=ct.default_string_field_name,
 | 
				
			||||||
                                   check_task=CheckTasks.err_res,
 | 
					                                       check_task=CheckTasks.check_search_results,
 | 
				
			||||||
                                   check_items={"err_code": err_code, "err_msg": err_msg})
 | 
					                                       check_items={"nq": 1, "limit": default_limit})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.tags(CaseLabel.L1)
 | 
					    @pytest.mark.tags(CaseLabel.L1)
 | 
				
			||||||
    def test_multi_vectors_search_one_vector_group_by(self):
 | 
					    def test_multi_vectors_search_one_vector_group_by(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue