test: add full text search checker in test (#37122)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/37061/head
zhuwenxing 2024-10-25 14:09:29 +08:00 committed by GitHub
parent 6ef014d931
commit ac2858d418
5 changed files with 93 additions and 4 deletions


@@ -223,6 +223,7 @@ class Op(Enum):
release_collection = 'release_collection'
release_partition = 'release_partition'
search = 'search'
full_text_search = 'full_text_search'
hybrid_search = 'hybrid_search'
query = 'query'
text_match = 'text_match'
@@ -363,6 +364,7 @@ class Checker:
self.scalar_field_names = cf.get_scalar_field_name_list(schema=schema)
self.float_vector_field_names = cf.get_float_vec_field_name_list(schema=schema)
self.binary_vector_field_names = cf.get_binary_vec_field_name_list(schema=schema)
self.bm25_sparse_field_names = cf.get_bm25_vec_field_name_list(schema=schema)
# get index of collection
indexes = [index.to_dict() for index in self.c_wrap.indexes]
indexed_fields = [index['field'] for index in indexes]
@@ -393,6 +395,15 @@
timeout=timeout,
enable_traceback=enable_traceback,
check_task=CheckTasks.check_nothing)
for f in self.bm25_sparse_field_names:
if f in indexed_fields:
continue
self.c_wrap.create_index(f,
constants.DEFAULT_BM25_INDEX_PARAM,
timeout=timeout,
enable_traceback=enable_traceback,
check_task=CheckTasks.check_nothing)
self.replica_number = replica_number
self.c_wrap.load(replica_number=self.replica_number)
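For reference outside the chaos framework, a minimal sketch of the same BM25 index creation with plain pymilvus; the connection details and the demo_fts / text_sparse_emb names are placeholders, not taken from this change.

from pymilvus import connections, Collection

# Placeholder connection and collection; any collection whose schema has a
# BM25 function writing into "text_sparse_emb" would behave the same way.
connections.connect(host="localhost", port="19530")
collection = Collection("demo_fts")

bm25_index_param = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25",
                    "params": {"bm25_k1": 1.5, "bm25_b": 0.75}}

# Mirror the Checker's guard: only create the index if the field is not indexed yet.
indexed_fields = [index.to_dict()["field"] for index in collection.indexes]
if "text_sparse_emb" not in indexed_fields:
    collection.create_index("text_sparse_emb", bm25_index_param)
collection.load()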
@@ -658,6 +669,41 @@ class SearchChecker(Checker):
sleep(constants.WAIT_PER_OP / 10)
class FullTextSearchChecker(Checker):
"""check full text search operations in a dependent thread"""
def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None):
if collection_name is None:
collection_name = cf.gen_unique_str("FullTextSearchChecker_")
super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema)
self.insert_data()
@trace()
def full_text_search(self):
bm25_anns_field = random.choice(self.bm25_sparse_field_names)
res, result = self.c_wrap.search(
data=cf.gen_vectors(5, self.dim, vector_data_type="TEXT_SPARSE_VECTOR"),
anns_field=bm25_anns_field,
param=constants.DEFAULT_BM25_SEARCH_PARAM,
limit=1,
partition_names=self.p_names,
timeout=search_timeout,
check_task=CheckTasks.check_nothing
)
return res, result
@exception_handler()
def run_task(self):
res, result = self.full_text_search()
return res, result
def keep_running(self):
while self._keep_running:
self.run_task()
sleep(constants.WAIT_PER_OP / 10)
class HybridSearchChecker(Checker):
"""check hybrid search operations in a dependent thread"""


@@ -25,4 +25,6 @@ DEFAULT_INDEX_PARAM = {"index_type": "HNSW", "metric_type": "L2", "params": {"M"
DEFAULT_SEARCH_PARAM = {"metric_type": "L2", "params": {"ef": 64}}
DEFAULT_BINARY_INDEX_PARAM = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"M": 48}}
DEFAULT_BINARY_SEARCH_PARAM = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
DEFAULT_BM25_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25", "params": {"bm25_k1": 1.5, "bm25_b": 0.75}}
DEFAULT_BM25_SEARCH_PARAM = {"metric_type": "BM25", "params": {}}
CHAOS_INFO_SAVE_PATH = "/tmp/ci_logs/chaos_info.json"
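To illustrate how the two new constants pair up, a hedged sketch of a raw full text search call; the collection and field names are placeholders and a Milvus build with BM25 full text search support is assumed.

from pymilvus import connections, Collection

DEFAULT_BM25_SEARCH_PARAM = {"metric_type": "BM25", "params": {}}

connections.connect(host="localhost", port="19530")
collection = Collection("demo_fts")           # placeholder; indexed with DEFAULT_BM25_INDEX_PARAM
collection.load()

# For BM25 the query "vectors" are raw strings; the server tokenizes them with
# the analyzer of the input text field and scores matches with BM25.
results = collection.search(
    data=["distributed vector database"],
    anns_field="text_sparse_emb",             # BM25 function output field (placeholder name)
    param=DEFAULT_BM25_SEARCH_PARAM,
    limit=3,
    output_fields=["text"],
)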


@@ -7,6 +7,7 @@ from chaos.checker import (InsertChecker,
UpsertChecker,
FlushChecker,
SearchChecker,
FullTextSearchChecker,
HybridSearchChecker,
QueryChecker,
TextMatchChecker,
@@ -76,6 +77,7 @@ class TestOperations(TestBase):
Op.upsert: UpsertChecker(collection_name=c_name),
Op.flush: FlushChecker(collection_name=c_name),
Op.search: SearchChecker(collection_name=c_name),
Op.full_text_search: FullTextSearchChecker(collection_name=c_name),
Op.hybrid_search: HybridSearchChecker(collection_name=c_name),
Op.query: QueryChecker(collection_name=c_name),
Op.text_match: TextMatchChecker(collection_name=c_name),


@@ -9,6 +9,7 @@ from chaos.checker import (CollectionCreateChecker,
UpsertChecker,
FlushChecker,
SearchChecker,
FullTextSearchChecker,
HybridSearchChecker,
QueryChecker,
TextMatchChecker,
@@ -75,6 +76,7 @@ class TestOperations(TestBase):
Op.flush: FlushChecker(collection_name=c_name),
Op.index: IndexCreateChecker(collection_name=c_name),
Op.search: SearchChecker(collection_name=c_name),
Op.full_text_search: FullTextSearchChecker(collection_name=c_name),
Op.hybrid_search: HybridSearchChecker(collection_name=c_name),
Op.query: QueryChecker(collection_name=c_name),
Op.text_match: TextMatchChecker(collection_name=c_name),


@@ -25,7 +25,7 @@ import bm25s
import jieba
import re
from pymilvus import CollectionSchema, DataType
from pymilvus import CollectionSchema, DataType, FunctionType, Function
from bm25s.tokenization import Tokenizer
@@ -717,12 +717,21 @@ def gen_all_datatype_collection_schema(description=ct.default_desc, primary_fiel
gen_array_field(name="array_bool", element_type=DataType.BOOL),
gen_float_vec_field(dim=dim),
gen_float_vec_field(name="image_emb", dim=dim),
gen_float_vec_field(name="text_emb", dim=dim),
gen_float_vec_field(name="text_sparse_emb", vector_data_type="SPARSE_FLOAT_VECTOR"),
gen_float_vec_field(name="voice_emb", dim=dim),
]
schema, _ = ApiCollectionSchemaWrapper().init_collection_schema(fields=fields, description=description,
primary_field=primary_field, auto_id=auto_id,
enable_dynamic_field=enable_dynamic_field, **kwargs)
bm25_function = Function(
name=f"text",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names=["text_sparse_emb"],
params={},
)
schema.add_function(bm25_function)
return schema
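The schema change above is the core of the feature: a VARCHAR field feeds a BM25 function whose output is a sparse vector field. Below is a self-contained sketch of the same pattern with a minimal field list; the enable_analyzer flag is an assumption about the pymilvus full text search API, not something this diff touches.

from pymilvus import (CollectionSchema, FieldSchema, DataType,
                      Function, FunctionType)

fields = [
    FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
    # enable_analyzer is assumed to be needed so the server can tokenize the text
    FieldSchema("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True),
    FieldSchema("text_sparse_emb", DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(fields, description="full text search demo")

bm25_function = Function(
    name="text_bm25",
    function_type=FunctionType.BM25,
    input_field_names=["text"],
    output_field_names=["text_sparse_emb"],
    params={},
)
schema.add_function(bm25_function)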
@@ -1018,7 +1027,11 @@ def gen_vectors(nb, dim, vector_data_type="FLOAT_VECTOR"):
vectors = gen_bf16_vectors(nb, dim)[1]
elif vector_data_type == "SPARSE_FLOAT_VECTOR":
vectors = gen_sparse_vectors(nb, dim)
elif vector_data_type == "TEXT_SPARSE_VECTOR":
vectors = gen_text_vectors(nb)
else:
log.error(f"Invalid vector data type: {vector_data_type}")
raise Exception(f"Invalid vector data type: {vector_data_type}")
if dim > 1:
if vector_data_type == "FLOAT_VECTOR":
vectors = preprocessing.normalize(vectors, axis=1, norm='l2')
@@ -1026,6 +1039,15 @@ def gen_vectors(nb, dim, vector_data_type="FLOAT_VECTOR"):
return vectors
def gen_text_vectors(nb, language="en"):
fake = Faker("en_US")
if language == "zh":
fake = Faker("zh_CN")
vectors = [" milvus " + fake.text() for _ in range(nb)]
return vectors
def gen_string(nb):
string_values = [str(random.random()) for _ in range(nb)]
return string_values
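With the branch added to gen_vectors above, TEXT_SPARSE_VECTOR queries are plain text rather than numeric vectors. A quick usage sketch within this module (Faker output is random, so the exact strings vary; dim is accepted but ignored for the text branch):

queries = gen_vectors(nb=3, dim=128, vector_data_type="TEXT_SPARSE_VECTOR")
# e.g. [" milvus Lorem ipsum ...", " milvus Dolor sit ...", " milvus ..."]
assert all(isinstance(q, str) and q.startswith(" milvus ") for q in queries)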
@@ -1766,6 +1788,18 @@ def get_binary_vec_field_name_list(schema=None):
return vec_fields
def get_bm25_vec_field_name_list(schema=None):
if not hasattr(schema, "functions"):
return []
functions = schema.functions
bm25_func = [func for func in functions if func.type == FunctionType.BM25]
bm25_outputs = []
for func in bm25_func:
bm25_outputs.extend(func.output_field_names)
bm25_outputs = list(set(bm25_outputs))
return bm25_outputs
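Usage sketch, within this module: paired with gen_all_datatype_collection_schema above, the helper reports which output fields the Checker can pass as anns_field for full text search.

schema = gen_all_datatype_collection_schema()
bm25_fields = get_bm25_vec_field_name_list(schema=schema)
# expected: ["text_sparse_emb"], the BM25 function's output field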
def get_dim_by_schema(schema=None):
if schema is None:
schema = gen_default_collection_schema()
@@ -3042,7 +3076,10 @@ def gen_vectors_based_on_vector_type(num, dim, vector_data_type=ct.float_type):
vectors = gen_bf16_vectors(num, dim)[1]
elif vector_data_type == ct.sparse_vector:
vectors = gen_sparse_vectors(num, dim)
elif vector_data_type == ct.text_sparse_vector:
vectors = gen_text_vectors(num)
else:
raise Exception("vector_data_type is invalid")
return vectors