mirror of https://github.com/milvus-io/milvus.git

test: add array inverted index function test (#35874)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>

Branch: pull/35909/head
Parent: b2eb9fe2a7
Commit: 57422cb2ed
@@ -63,6 +63,148 @@ class ParamInfo:

param_info = ParamInfo()


def generate_array_dataset(size, array_length, hit_probabilities, target_values):
    dataset = []
    target_array_length = target_values.get('array_length_field', None)
    target_array_access = target_values.get('array_access', None)
    all_target_values = set(
        val for sublist in target_values.values() for val in (sublist if isinstance(sublist, list) else [sublist]))
    for i in range(size):
        entry = {"id": i}

        # Generate random arrays for each condition
        for condition in hit_probabilities.keys():
            available_values = [val for val in range(1, 100) if val not in all_target_values]
            array = random.sample(available_values, array_length)

            # Ensure the array meets the condition based on its probability
            if random.random() < hit_probabilities[condition]:
                if condition == 'contains':
                    if target_values[condition] not in array:
                        array[random.randint(0, array_length - 1)] = target_values[condition]
                elif condition == 'contains_any':
                    if not any(val in array for val in target_values[condition]):
                        array[random.randint(0, array_length - 1)] = random.choice(target_values[condition])
                elif condition == 'contains_all':
                    indices = random.sample(range(array_length), len(target_values[condition]))
                    for idx, val in zip(indices, target_values[condition]):
                        array[idx] = val
                elif condition == 'equals':
                    array = target_values[condition][:]
                elif condition == 'array_length_field':
                    array = [random.randint(0, 10) for _ in range(target_array_length)]
                elif condition == 'array_access':
                    array = [random.randint(0, 10) for _ in range(random.randint(10, 20))]
                    array[target_array_access[0]] = target_array_access[1]
                else:
                    raise ValueError(f"Unknown condition: {condition}")

            entry[condition] = array

        dataset.append(entry)

    return dataset


def prepare_array_test_data(data_size, hit_rate=0.005, dim=128):
    size = data_size  # Number of arrays in the dataset
    array_length = 10  # Length of each array

    # Probabilities that an array hits the target condition
    hit_probabilities = {
        'contains': hit_rate,
        'contains_any': hit_rate,
        'contains_all': hit_rate,
        'equals': hit_rate,
        'array_length_field': hit_rate,
        'array_access': hit_rate
    }

    # Target values for each condition
    target_values = {
        'contains': 42,
        'contains_any': [21, 37, 42],
        'contains_all': [15, 30],
        'equals': [1, 2, 3, 4, 5],
        'array_length_field': 5,  # array length == 5
        'array_access': [0, 5]  # index == 0 and value == 5
    }

    # Generate dataset
    dataset = generate_array_dataset(size, array_length, hit_probabilities, target_values)
    data = {
        "id": pd.Series([x["id"] for x in dataset]),
        "contains": pd.Series([x["contains"] for x in dataset]),
        "contains_any": pd.Series([x["contains_any"] for x in dataset]),
        "contains_all": pd.Series([x["contains_all"] for x in dataset]),
        "equals": pd.Series([x["equals"] for x in dataset]),
        "array_length_field": pd.Series([x["array_length_field"] for x in dataset]),
        "array_access": pd.Series([x["array_access"] for x in dataset]),
        "emb": pd.Series([np.array([random.random() for j in range(dim)], dtype=np.dtype("float32")) for _ in
                          range(size)])
    }

    # Define testing conditions
    contains_value = target_values['contains']
    contains_any_values = target_values['contains_any']
    contains_all_values = target_values['contains_all']
    equals_array = target_values['equals']

    # Compute the ground truth for each condition
    contains_result = [d for d in dataset if contains_value in d["contains"]]
    contains_any_result = [d for d in dataset if any(val in d["contains_any"] for val in contains_any_values)]
    contains_all_result = [d for d in dataset if all(val in d["contains_all"] for val in contains_all_values)]
    equals_result = [d for d in dataset if d["equals"] == equals_array]
    array_length_result = [d for d in dataset if len(d["array_length_field"]) == target_values['array_length_field']]
    array_access_result = [d for d in dataset if d["array_access"][0] == target_values['array_access'][1]]

    # Calculate and log the hit proportions
    contains_ratio = len(contains_result) / size
    contains_any_ratio = len(contains_any_result) / size
    contains_all_ratio = len(contains_all_result) / size
    equals_ratio = len(equals_result) / size
    array_length_ratio = len(array_length_result) / size
    array_access_ratio = len(array_access_result) / size

    log.info(f"\nProportion of arrays that contain the value: {contains_ratio}")
    log.info(f"Proportion of arrays that contain any of the values: {contains_any_ratio}")
    log.info(f"Proportion of arrays that contain all of the values: {contains_all_ratio}")
    log.info(f"Proportion of arrays that equal the target array: {equals_ratio}")
    log.info(f"Proportion of arrays that have the target array length: {array_length_ratio}")
    log.info(f"Proportion of arrays that have the target array access: {array_access_ratio}")

    train_df = pd.DataFrame(data)

    target_id = {
        "contains": [r["id"] for r in contains_result],
        "contains_any": [r["id"] for r in contains_any_result],
        "contains_all": [r["id"] for r in contains_all_result],
        "equals": [r["id"] for r in equals_result],
        "array_length": [r["id"] for r in array_length_result],
        "array_access": [r["id"] for r in array_access_result]
    }
    target_id_list = [target_id[key] for key in
                      ["contains", "contains_any", "contains_all", "equals", "array_length", "array_access"]]

    filters = [
        "array_contains(contains, 42)",
        "array_contains_any(contains_any, [21, 37, 42])",
        "array_contains_all(contains_all, [15, 30])",
        "equals == [1,2,3,4,5]",
        "array_length(array_length_field) == 5",
        "array_access[0] == 5"
    ]
    query_expr = []
    for i in range(len(filters)):
        item = {
            "expr": filters[i],
            "ground_truth": target_id_list[i],
        }
        query_expr.append(item)
    return train_df, query_expr


def gen_unique_str(str_value=None):
    prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
    return "test_" + prefix if str_value is None else str_value + "_" + prefix
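
Taken together, these helpers build the fixture that the new array tests consume: generate_array_dataset synthesizes rows whose array columns hit each condition with a configurable probability, and prepare_array_test_data wraps the rows in a DataFrame and pairs each Milvus filter expression with the ids expected to match it. The sketch below shows how a caller would use them; it is illustrative only (not part of this diff) and assumes the module-level imports of the common functions file (random, string, numpy as np, pandas as pd, and the configured log) are available.

# Illustrative only: how the new tests consume these helpers.
train_df, query_expr = prepare_array_test_data(data_size=3000, hit_rate=0.05, dim=128)

# train_df holds one row per entity: id, the six array columns, and the "emb" vector.
log.info(train_df.columns.tolist())

# query_expr pairs each Milvus filter expression with the ids expected to match it.
for case in query_expr:
    log.info(f"{case['expr']} -> {len(case['ground_truth'])} expected ids")
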
@@ -7,6 +7,10 @@ from common.code_mapping import CollectionErrorMessage as clem
from common.code_mapping import ConnectionErrorMessage as cem
from base.client_base import TestcaseBase
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY
from pymilvus import (
    FieldSchema, CollectionSchema, DataType,
    Collection
)
import threading
from pymilvus import DefaultConfig
from datetime import datetime
@@ -1520,7 +1524,7 @@ class TestQueryParams(TestcaseBase):
    def test_query_output_fields_simple_wildcard(self):
        """
        target: test query output_fields with simple wildcard (* and %)
        method: specify output_fields as "*"
        expected: output all scalar fields; output all fields
        """
        # init collection with fields: int64, float, float_vec, float_vector1
@@ -2566,7 +2570,7 @@ class TestQueryOperation(TestcaseBase):
        """
        target: test the scenario of querying with many logical expressions
        method: 1. create collection
                2. query with an expr like: int64 == 0 || int64 == 1 ........
        expected: run successfully
        """
        c_name = cf.gen_unique_str(prefix)
@@ -2577,14 +2581,14 @@ class TestQueryOperation(TestcaseBase):
        collection_w.load()
        multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60))
        _, check_res = collection_w.query(multi_exprs, output_fields=[f'{default_int_field_name}'])
        assert(check_res == True)

    @pytest.mark.tags(CaseLabel.L0)
    def test_search_multi_logical_exprs(self):
        """
        target: test the scenario of searching with many logical expressions
        method: 1. create collection
                2. search with an expr like: int64 == 0 || int64 == 1 ........
        expected: run successfully
        """
        c_name = cf.gen_unique_str(prefix)
@@ -2593,15 +2597,15 @@ class TestQueryOperation(TestcaseBase):
        collection_w.insert(df)
        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
        collection_w.load()

        multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60))

        collection_w.load()
        vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)]
        limit = 1000
        _, check_res = collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name,
                            ct.default_search_params, limit, multi_exprs)
        assert(check_res == True)

class TestQueryString(TestcaseBase):
    """
@@ -2947,8 +2951,8 @@ class TestQueryString(TestcaseBase):
    @pytest.mark.tags(CaseLabel.L2)
    def test_query_with_create_diskann_index(self):
        """
        target: test query after creating a diskann index
        method: create a collection and build a diskann index
        expected: verify query result
        """
        collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2]
@@ -2968,8 +2972,8 @@ class TestQueryString(TestcaseBase):
    @pytest.mark.tags(CaseLabel.L2)
    def test_query_with_create_diskann_with_string_pk(self):
        """
        target: test query after creating a diskann index
        method: create a collection with a string pk and build a diskann index
        expected: verify query result
        """
        collection_w, vectors = self.init_collection_general(prefix, insert_data=True,
@@ -2986,7 +2990,7 @@ class TestQueryString(TestcaseBase):
    @pytest.mark.tags(CaseLabel.L1)
    def test_query_with_scalar_field(self):
        """
        target: test query with scalar field
        method: create collection, string field is primary
                collection load and insert empty data with string field
                collection query uses string expr in string field
@@ -3015,6 +3019,48 @@ class TestQueryString(TestcaseBase):
        res, _ = collection_w.query(expr, output_fields=output_fields)

        assert len(res) == 4


class TestQueryArray(TestcaseBase):

    @pytest.mark.tags(CaseLabel.L1)
    @pytest.mark.parametrize("array_element_data_type", [DataType.INT64])
    def test_query_array_with_inverted_index(self, array_element_data_type):
        # create collection
        additional_params = {"max_length": 1000} if array_element_data_type == DataType.VARCHAR else {}
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="contains_any", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="contains_all", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="equals", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="array_length_field", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="array_access", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128)
        ]
        schema = CollectionSchema(fields=fields, description="test collection", enable_dynamic_field=True)
        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
        # insert data
        train_data, query_expr = cf.prepare_array_test_data(3000, hit_rate=0.05)
        collection_w.insert(train_data)
        # build a vector index on the embedding field and an inverted index on every array field
        index_params = {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}}
        collection_w.create_index("emb", index_params=index_params)
        for f in ["contains", "contains_any", "contains_all", "equals", "array_length_field", "array_access"]:
            collection_w.create_index(f, {"index_type": "INVERTED"})
        collection_w.load()

        # every filter expression must return exactly its ground-truth ids
        for item in query_expr:
            expr = item["expr"]
            ground_truth = item["ground_truth"]
            res, _ = collection_w.query(
                expr=expr,
                output_fields=["*"],
            )
            assert len(res) == len(ground_truth)
            for i in range(len(res)):
                assert res[i]["id"] == ground_truth[i]


class TestQueryCount(TestcaseBase):
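
For readers outside the test harness, the same flow expressed with the raw pymilvus API looks roughly like the sketch below. It is a hedged illustration rather than part of the change: the connection address and collection name are invented, and only one array field is shown.

from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType

# Assumed: a Milvus instance reachable at this address (illustrative values only).
connections.connect(host="localhost", port="19530")

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=2000),
    FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128),
]
collection = Collection("array_inverted_demo", CollectionSchema(fields))

# Vector index for ANN search, plus a scalar INVERTED index on the array field.
collection.create_index("emb", {"metric_type": "L2", "index_type": "HNSW",
                                "params": {"M": 48, "efConstruction": 500}})
collection.create_index("contains", {"index_type": "INVERTED"})
collection.load()

# After inserting data, the filter expressions produced by prepare_array_test_data
# can be evaluated directly with query().
res = collection.query(expr="array_contains(contains, 42)", output_fields=["id"])
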
@@ -1,6 +1,10 @@
import numpy as np
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
from pymilvus import (
    FieldSchema, CollectionSchema, DataType,
    Collection
)
from common.constants import *
from utils.util_pymilvus import *
from common.common_type import CaseLabel, CheckTasks
@@ -5237,7 +5241,7 @@ class TestSearchBase(TestcaseBase):

class TestSearchDSL(TestcaseBase):
    @pytest.mark.tags(CaseLabel.L0)
    def test_search_vector_only(self):  # renamed from test_query_vector_only
        """
        target: test search normal scenario
        method: search vector only
@@ -5254,6 +5258,54 @@ class TestSearchDSL(TestcaseBase):
                            check_items={"nq": nq,
                                         "ids": insert_ids,
                                         "limit": ct.default_top_k})


class TestSearchArray(TestcaseBase):

    @pytest.mark.tags(CaseLabel.L1)
    @pytest.mark.parametrize("array_element_data_type", [DataType.INT64])
    def test_search_array_with_inverted_index(self, array_element_data_type):
        # create collection
        additional_params = {"max_length": 1000} if array_element_data_type == DataType.VARCHAR else {}
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="contains_any", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="contains_all", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="equals", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="array_length_field", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="array_access", dtype=DataType.ARRAY, element_type=array_element_data_type,
                        max_capacity=2000, **additional_params),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128)
        ]
        schema = CollectionSchema(fields=fields, description="test collection", enable_dynamic_field=True)
        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
        # insert data
        train_data, query_expr = cf.prepare_array_test_data(3000, hit_rate=0.05)
        collection_w.insert(train_data)
        # build a vector index on the embedding field and an inverted index on every array field
        index_params = {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}}
        collection_w.create_index("emb", index_params=index_params)
        for f in ["contains", "contains_any", "contains_all", "equals", "array_length_field", "array_access"]:
            collection_w.create_index(f, {"index_type": "INVERTED"})
        collection_w.load()

        # every hit returned by a filtered search must come from the filter's ground-truth ids
        for item in query_expr:
            expr = item["expr"]
            ground_truth_candidate = item["ground_truth"]
            res, _ = collection_w.search(
                data=[np.array([random.random() for j in range(128)], dtype=np.dtype("float32"))],
                anns_field="emb",
                param={"metric_type": "L2", "params": {"M": 32, "efConstruction": 360}},
                limit=10,
                expr=expr,
                output_fields=["*"],
            )
            assert len(res) == 1
            for i in range(len(res)):
                assert len(res[i]) == 10
                for hit in res[i]:
                    assert hit.id in ground_truth_candidate


class TestSearchString(TestcaseBase):
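
Unlike the query test above, which expects an exact, ordered match against the ground truth, this search test only checks candidate membership: a filtered ANN search returns at most limit hits per query vector, and each hit must satisfy the filter expression. The sketch below continues the illustrative pymilvus example from the query section; the collection (with data inserted and indexes built) and ground_truth_ids (the ids known to satisfy the filter) are assumed names, not part of the diff.

import random
import numpy as np

# Assumed from the earlier sketch: `collection` with data and indexes,
# and `ground_truth_ids`, the ids known to satisfy the filter expression.
query_vec = [np.array([random.random() for _ in range(128)], dtype=np.float32)]
res = collection.search(
    data=query_vec,
    anns_field="emb",
    param={"metric_type": "L2", "params": {"ef": 100}},  # HNSW search-time parameter
    limit=10,
    expr="array_contains(contains, 42)",
    output_fields=["id"],
)
# One result list per query vector; every returned hit must come from the ground-truth set.
hit_ids = {hit.id for hit in res[0]}
assert hit_ids.issubset(set(ground_truth_ids))
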
@@ -12869,4 +12921,4 @@ class TestSparseSearch(TestcaseBase):
        collection_w.search_iterator(data[-1][-1:], ct.default_sparse_vec_field_name,
                                     ct.default_sparse_search_params, batch_size,
                                     check_task=CheckTasks.check_search_iterator,
                                     check_items={"batch_size": batch_size})