mirror of https://github.com/milvus-io/milvus.git
test: add array inverted index function test (#35874)
/kind improvement --------- Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>pull/35909/head
parent
b2eb9fe2a7
commit
57422cb2ed
|
@ -63,6 +63,148 @@ class ParamInfo:
|
|||
param_info = ParamInfo()
|
||||
|
||||
|
||||
def generate_array_dataset(size, array_length, hit_probabilities, target_values):
|
||||
dataset = []
|
||||
target_array_length = target_values.get('array_length_field', None)
|
||||
target_array_access = target_values.get('array_access', None)
|
||||
all_target_values = set(
|
||||
val for sublist in target_values.values() for val in (sublist if isinstance(sublist, list) else [sublist]))
|
||||
for i in range(size):
|
||||
entry = {"id": i}
|
||||
|
||||
# Generate random arrays for each condition
|
||||
for condition in hit_probabilities.keys():
|
||||
available_values = [val for val in range(1, 100) if val not in all_target_values]
|
||||
array = random.sample(available_values, array_length)
|
||||
|
||||
# Ensure the array meets the condition based on its probability
|
||||
if random.random() < hit_probabilities[condition]:
|
||||
if condition == 'contains':
|
||||
if target_values[condition] not in array:
|
||||
array[random.randint(0, array_length - 1)] = target_values[condition]
|
||||
elif condition == 'contains_any':
|
||||
if not any(val in array for val in target_values[condition]):
|
||||
array[random.randint(0, array_length - 1)] = random.choice(target_values[condition])
|
||||
elif condition == 'contains_all':
|
||||
indices = random.sample(range(array_length), len(target_values[condition]))
|
||||
for idx, val in zip(indices, target_values[condition]):
|
||||
array[idx] = val
|
||||
elif condition == 'equals':
|
||||
array = target_values[condition][:]
|
||||
elif condition == 'array_length_field':
|
||||
array = [random.randint(0, 10) for _ in range(target_array_length)]
|
||||
elif condition == 'array_access':
|
||||
array = [random.randint(0, 10) for _ in range(random.randint(10, 20))]
|
||||
array[target_array_access[0]] = target_array_access[1]
|
||||
else:
|
||||
raise ValueError(f"Unknown condition: {condition}")
|
||||
|
||||
entry[condition] = array
|
||||
|
||||
dataset.append(entry)
|
||||
|
||||
return dataset
|
||||
|
||||
def prepare_array_test_data(data_size, hit_rate=0.005, dim=128):
|
||||
size = data_size # Number of arrays in the dataset
|
||||
array_length = 10 # Length of each array
|
||||
|
||||
# Probabilities that an array hits the target condition
|
||||
hit_probabilities = {
|
||||
'contains': hit_rate,
|
||||
'contains_any': hit_rate,
|
||||
'contains_all': hit_rate,
|
||||
'equals': hit_rate,
|
||||
'array_length_field': hit_rate,
|
||||
'array_access': hit_rate
|
||||
}
|
||||
|
||||
# Target values for each condition
|
||||
target_values = {
|
||||
'contains': 42,
|
||||
'contains_any': [21, 37, 42],
|
||||
'contains_all': [15, 30],
|
||||
'equals': [1,2,3,4,5],
|
||||
'array_length_field': 5, # array length == 5
|
||||
'array_access': [0, 5] # index=0, and value == 5
|
||||
}
|
||||
|
||||
# Generate dataset
|
||||
dataset = generate_array_dataset(size, array_length, hit_probabilities, target_values)
|
||||
data = {
|
||||
"id": pd.Series([x["id"] for x in dataset]),
|
||||
"contains": pd.Series([x["contains"] for x in dataset]),
|
||||
"contains_any": pd.Series([x["contains_any"] for x in dataset]),
|
||||
"contains_all": pd.Series([x["contains_all"] for x in dataset]),
|
||||
"equals": pd.Series([x["equals"] for x in dataset]),
|
||||
"array_length_field": pd.Series([x["array_length_field"] for x in dataset]),
|
||||
"array_access": pd.Series([x["array_access"] for x in dataset]),
|
||||
"emb": pd.Series([np.array([random.random() for j in range(dim)], dtype=np.dtype("float32")) for _ in
|
||||
range(size)])
|
||||
}
|
||||
# Define testing conditions
|
||||
contains_value = target_values['contains']
|
||||
contains_any_values = target_values['contains_any']
|
||||
contains_all_values = target_values['contains_all']
|
||||
equals_array = target_values['equals']
|
||||
|
||||
# Perform tests
|
||||
contains_result = [d for d in dataset if contains_value in d["contains"]]
|
||||
contains_any_result = [d for d in dataset if any(val in d["contains_any"] for val in contains_any_values)]
|
||||
contains_all_result = [d for d in dataset if all(val in d["contains_all"] for val in contains_all_values)]
|
||||
equals_result = [d for d in dataset if d["equals"] == equals_array]
|
||||
array_length_result = [d for d in dataset if len(d["array_length_field"]) == target_values['array_length_field']]
|
||||
array_access_result = [d for d in dataset if d["array_access"][0] == target_values['array_access'][1]]
|
||||
# Calculate and log.info proportions
|
||||
contains_ratio = len(contains_result) / size
|
||||
contains_any_ratio = len(contains_any_result) / size
|
||||
contains_all_ratio = len(contains_all_result) / size
|
||||
equals_ratio = len(equals_result) / size
|
||||
array_length_ratio = len(array_length_result) / size
|
||||
array_access_ratio = len(array_access_result) / size
|
||||
|
||||
log.info(f"\nProportion of arrays that contain the value: {contains_ratio}")
|
||||
log.info(f"Proportion of arrays that contain any of the values: {contains_any_ratio}")
|
||||
log.info(f"Proportion of arrays that contain all of the values: {contains_all_ratio}")
|
||||
log.info(f"Proportion of arrays that equal the target array: {equals_ratio}")
|
||||
log.info(f"Proportion of arrays that have the target array length: {array_length_ratio}")
|
||||
log.info(f"Proportion of arrays that have the target array access: {array_access_ratio}")
|
||||
|
||||
|
||||
|
||||
train_df = pd.DataFrame(data)
|
||||
|
||||
target_id = {
|
||||
"contains": [r["id"] for r in contains_result],
|
||||
"contains_any": [r["id"] for r in contains_any_result],
|
||||
"contains_all": [r["id"] for r in contains_all_result],
|
||||
"equals": [r["id"] for r in equals_result],
|
||||
"array_length": [r["id"] for r in array_length_result],
|
||||
"array_access": [r["id"] for r in array_access_result]
|
||||
}
|
||||
target_id_list = [target_id[key] for key in ["contains", "contains_any", "contains_all", "equals", "array_length", "array_access"]]
|
||||
|
||||
|
||||
filters = [
|
||||
"array_contains(contains, 42)",
|
||||
"array_contains_any(contains_any, [21, 37, 42])",
|
||||
"array_contains_all(contains_all, [15, 30])",
|
||||
"equals == [1,2,3,4,5]",
|
||||
"array_length(array_length_field) == 5",
|
||||
"array_access[0] == 5"
|
||||
|
||||
]
|
||||
query_expr = []
|
||||
for i in range(len(filters)):
|
||||
item = {
|
||||
"expr": filters[i],
|
||||
"ground_truth": target_id_list[i],
|
||||
}
|
||||
query_expr.append(item)
|
||||
return train_df, query_expr
|
||||
|
||||
|
||||
|
||||
def gen_unique_str(str_value=None):
|
||||
prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
|
||||
return "test_" + prefix if str_value is None else str_value + "_" + prefix
|
||||
|
|
|
@ -7,6 +7,10 @@ from common.code_mapping import CollectionErrorMessage as clem
|
|||
from common.code_mapping import ConnectionErrorMessage as cem
|
||||
from base.client_base import TestcaseBase
|
||||
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY
|
||||
from pymilvus import (
|
||||
FieldSchema, CollectionSchema, DataType,
|
||||
Collection
|
||||
)
|
||||
import threading
|
||||
from pymilvus import DefaultConfig
|
||||
from datetime import datetime
|
||||
|
@ -1520,7 +1524,7 @@ class TestQueryParams(TestcaseBase):
|
|||
def test_query_output_fields_simple_wildcard(self):
|
||||
"""
|
||||
target: test query output_fields with simple wildcard (* and %)
|
||||
method: specify output_fields as "*"
|
||||
method: specify output_fields as "*"
|
||||
expected: output all scale field; output all fields
|
||||
"""
|
||||
# init collection with fields: int64, float, float_vec, float_vector1
|
||||
|
@ -2566,7 +2570,7 @@ class TestQueryOperation(TestcaseBase):
|
|||
"""
|
||||
target: test the scenario which query with many logical expressions
|
||||
method: 1. create collection
|
||||
3. query the expr that like: int64 == 0 || int64 == 1 ........
|
||||
3. query the expr that like: int64 == 0 || int64 == 1 ........
|
||||
expected: run successfully
|
||||
"""
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
|
@ -2577,14 +2581,14 @@ class TestQueryOperation(TestcaseBase):
|
|||
collection_w.load()
|
||||
multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60))
|
||||
_, check_res = collection_w.query(multi_exprs, output_fields=[f'{default_int_field_name}'])
|
||||
assert(check_res == True)
|
||||
assert(check_res == True)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_search_multi_logical_exprs(self):
|
||||
"""
|
||||
target: test the scenario which search with many logical expressions
|
||||
method: 1. create collection
|
||||
3. search with the expr that like: int64 == 0 || int64 == 1 ........
|
||||
3. search with the expr that like: int64 == 0 || int64 == 1 ........
|
||||
expected: run successfully
|
||||
"""
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
|
@ -2593,15 +2597,15 @@ class TestQueryOperation(TestcaseBase):
|
|||
collection_w.insert(df)
|
||||
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
|
||||
collection_w.load()
|
||||
|
||||
|
||||
multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60))
|
||||
|
||||
|
||||
collection_w.load()
|
||||
vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)]
|
||||
limit = 1000
|
||||
_, check_res = collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name,
|
||||
ct.default_search_params, limit, multi_exprs)
|
||||
assert(check_res == True)
|
||||
assert(check_res == True)
|
||||
|
||||
class TestQueryString(TestcaseBase):
|
||||
"""
|
||||
|
@ -2947,8 +2951,8 @@ class TestQueryString(TestcaseBase):
|
|||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_with_create_diskann_index(self):
|
||||
"""
|
||||
target: test query after create diskann index
|
||||
method: create a collection and build diskann index
|
||||
target: test query after create diskann index
|
||||
method: create a collection and build diskann index
|
||||
expected: verify query result
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2]
|
||||
|
@ -2968,8 +2972,8 @@ class TestQueryString(TestcaseBase):
|
|||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_with_create_diskann_with_string_pk(self):
|
||||
"""
|
||||
target: test query after create diskann index
|
||||
method: create a collection with string pk and build diskann index
|
||||
target: test query after create diskann index
|
||||
method: create a collection with string pk and build diskann index
|
||||
expected: verify query result
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True,
|
||||
|
@ -2986,7 +2990,7 @@ class TestQueryString(TestcaseBase):
|
|||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_with_scalar_field(self):
|
||||
"""
|
||||
target: test query with Scalar field
|
||||
target: test query with Scalar field
|
||||
method: create collection , string field is primary
|
||||
collection load and insert empty data with string field
|
||||
collection query uses string expr in string field
|
||||
|
@ -3015,6 +3019,48 @@ class TestQueryString(TestcaseBase):
|
|||
res, _ = collection_w.query(expr, output_fields=output_fields)
|
||||
|
||||
assert len(res) == 4
|
||||
class TestQueryArray(TestcaseBase):
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("array_element_data_type", [DataType.INT64])
|
||||
def test_query_array_with_inverted_index(self, array_element_data_type):
|
||||
# create collection
|
||||
additional_params = {"max_length": 1000} if array_element_data_type == DataType.VARCHAR else {}
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="contains_any", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="contains_all", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="equals", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="array_length_field", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="array_access", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
schema = CollectionSchema(fields=fields, description="test collection", enable_dynamic_field=True)
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
|
||||
# insert data
|
||||
train_data, query_expr = cf.prepare_array_test_data(3000, hit_rate=0.05)
|
||||
collection_w.insert(train_data)
|
||||
index_params = {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}}
|
||||
collection_w.create_index("emb", index_params=index_params)
|
||||
for f in ["contains", "contains_any", "contains_all", "equals", "array_length_field", "array_access"]:
|
||||
collection_w.create_index(f, {"index_type": "INVERTED"})
|
||||
collection_w.load()
|
||||
|
||||
for item in query_expr:
|
||||
expr = item["expr"]
|
||||
ground_truth = item["ground_truth"]
|
||||
res, _ = collection_w.query(
|
||||
expr=expr,
|
||||
output_fields=["*"],
|
||||
)
|
||||
assert len(res) == len(ground_truth)
|
||||
for i in range(len(res)):
|
||||
assert res[i]["id"] == ground_truth[i]
|
||||
|
||||
|
||||
class TestQueryCount(TestcaseBase):
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import numpy as np
|
||||
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
|
||||
from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
|
||||
from pymilvus import (
|
||||
FieldSchema, CollectionSchema, DataType,
|
||||
Collection
|
||||
)
|
||||
from common.constants import *
|
||||
from utils.util_pymilvus import *
|
||||
from common.common_type import CaseLabel, CheckTasks
|
||||
|
@ -5237,7 +5241,7 @@ class TestSearchBase(TestcaseBase):
|
|||
|
||||
class TestSearchDSL(TestcaseBase):
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_vector_only(self):
|
||||
def test_search_vector_only(self):
|
||||
"""
|
||||
target: test search normal scenario
|
||||
method: search vector only
|
||||
|
@ -5254,6 +5258,54 @@ class TestSearchDSL(TestcaseBase):
|
|||
check_items={"nq": nq,
|
||||
"ids": insert_ids,
|
||||
"limit": ct.default_top_k})
|
||||
class TestSearchArray(TestcaseBase):
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("array_element_data_type", [DataType.INT64])
|
||||
def test_search_array_with_inverted_index(self, array_element_data_type):
|
||||
# create collection
|
||||
additional_params = {"max_length": 1000} if array_element_data_type == DataType.VARCHAR else {}
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="contains_any", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="contains_all", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="equals", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="array_length_field", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="array_access", dtype=DataType.ARRAY, element_type=array_element_data_type,
|
||||
max_capacity=2000, **additional_params),
|
||||
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
schema = CollectionSchema(fields=fields, description="test collection", enable_dynamic_field=True)
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
|
||||
# insert data
|
||||
train_data, query_expr = cf.prepare_array_test_data(3000, hit_rate=0.05)
|
||||
collection_w.insert(train_data)
|
||||
index_params = {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}}
|
||||
collection_w.create_index("emb", index_params=index_params)
|
||||
for f in ["contains", "contains_any", "contains_all", "equals", "array_length_field", "array_access"]:
|
||||
collection_w.create_index(f, {"index_type": "INVERTED"})
|
||||
collection_w.load()
|
||||
|
||||
for item in query_expr:
|
||||
expr = item["expr"]
|
||||
ground_truth_candidate = item["ground_truth"]
|
||||
res, _ = collection_w.search(
|
||||
data = [np.array([random.random() for j in range(128)], dtype=np.dtype("float32"))],
|
||||
anns_field="emb",
|
||||
param={"metric_type": "L2", "params": {"M": 32, "efConstruction": 360}},
|
||||
limit=10,
|
||||
expr=expr,
|
||||
output_fields=["*"],
|
||||
)
|
||||
assert len(res) == 1
|
||||
for i in range(len(res)):
|
||||
assert len(res[i]) == 10
|
||||
for hit in res[i]:
|
||||
assert hit.id in ground_truth_candidate
|
||||
|
||||
|
||||
class TestSearchString(TestcaseBase):
|
||||
|
@ -12869,4 +12921,4 @@ class TestSparseSearch(TestcaseBase):
|
|||
collection_w.search_iterator(data[-1][-1:], ct.default_sparse_vec_field_name,
|
||||
ct.default_sparse_search_params, batch_size,
|
||||
check_task=CheckTasks.check_search_iterator,
|
||||
check_items={"batch_size": batch_size})
|
||||
check_items={"batch_size": batch_size})
|
||||
|
|
Loading…
Reference in New Issue