mirror of https://github.com/milvus-io/milvus.git

test: supplementing case for text match (#36693)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>

branch: pull/37180/head
parent: 7774b7275e
commit: c8dd665bf6
@@ -117,7 +117,7 @@ class ResponseChecker:
        return True

    def assert_exception(self, res, actual=True, error_dict=None):
-        assert actual is False
+        assert actual is False, f"Response of API {self.func_name} expect get error, but success"
        assert len(error_dict) > 0
        if isinstance(res, Error):
            error_code = error_dict[ct.err_code]
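The only functional change in this hunk is attaching a message to the negative-path assertion, so a failure reports which API call unexpectedly succeeded. A minimal sketch of the same pattern outside the checker class (the helper name and its arguments are illustrative, not part of the framework):

def expect_error(func_name, succeeded):
    # Fail with context instead of a bare AssertionError.
    assert succeeded is False, (
        f"Response of API {func_name} expected an error, but the call succeeded"
    )

expect_error("query", succeeded=False)  # passes; a True value would raise with the message above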
@@ -118,8 +118,6 @@ def get_bm25_ground_truth(corpus, queries, top_k=100, language="en"):
    return results, scores


def custom_tokenizer(language="en"):
    def remove_punctuation(text):
        text = text.strip()
@@ -153,11 +151,12 @@ def custom_tokenizer(language="en"):
def analyze_documents(texts, language="en"):
    tokenizer = custom_tokenizer(language)
    # Start timing
    t0 = time.time()
+    new_texts = []
+    for text in texts:
+        if isinstance(text, str):
+            new_texts.append(text)
    # Tokenize the corpus
-    tokenized = tokenizer.tokenize(texts, return_as="tuple")
+    tokenized = tokenizer.tokenize(new_texts, return_as="tuple")
    # log.info(f"Tokenized: {tokenized}")
    # Create a frequency counter
    freq = Counter()
@@ -170,13 +169,11 @@ def analyze_documents(texts, language="en"):
    # Convert token ids back to words
    word_freq = Counter({id_to_word[token_id]: count for token_id, count in freq.items()})

    # End timing
    tt = time.time() - t0
    log.debug(f"Analyze document cost time: {tt}")
    log.debug(f"word freq {word_freq.most_common(10)}")

    return word_freq


def check_token_overlap(text_a, text_b, language="en"):
    word_freq_a = analyze_documents([text_a], language)
    word_freq_b = analyze_documents([text_b], language)
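The change to analyze_documents makes the word-frequency helper tolerant of nullable columns: non-string entries (for example None from a nullable VARCHAR field) are dropped before tokenization. A standalone sketch of the same idea, using a plain whitespace tokenizer instead of the project's custom_tokenizer (the tokenizer choice and function name here are assumptions for illustration):

from collections import Counter

def word_frequencies(texts):
    # Keep only real strings; None values from nullable fields are skipped.
    clean = [t for t in texts if isinstance(t, str)]
    freq = Counter()
    for text in clean:
        # Assumed simple whitespace tokenization; the real helper uses a custom tokenizer.
        freq.update(text.lower().split())
    return freq

print(word_frequencies(["The quick fox", None, "the lazy dog"]).most_common(2))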
@@ -3054,7 +3051,7 @@ def gen_sparse_vectors(nb, dim=1000, sparse_format="dok"):
    rng = np.random.default_rng()
    vectors = [{
-        d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))
+        d: rng.random() for d in list(set(random.sample(range(dim), random.randint(20, 30)) + [0, 1]))
    } for _ in range(nb)]
    if sparse_format == "coo":
        vectors = [
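This tweak forces dimensions 0 and 1 into every generated sparse vector, so the test data always has those indices populated regardless of what random.sample picks. A self-contained sketch of the dict-of-keys ("dok") generation with that guarantee (the non-zero count and dim default mirror the helper, but the function name is illustrative):

import random
import numpy as np

def gen_dok_sparse_vectors(nb, dim=1000):
    rng = np.random.default_rng()
    vectors = []
    for _ in range(nb):
        # Random index set, then force indices 0 and 1 to be present.
        indices = set(random.sample(range(dim), random.randint(20, 30))) | {0, 1}
        vectors.append({d: rng.random() for d in indices})
    return vectors

sample = gen_dok_sparse_vectors(2, dim=100)
assert all(0 in v and 1 in v for v in sample)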
@@ -5206,6 +5206,120 @@ class TestQueryTextMatch(TestcaseBase):
            log.info(f"res len {len(res)}")
            assert len(res) == wf_map[field].most_common()[-1][1]

    @pytest.mark.tags(CaseLabel.L1)
    @pytest.mark.parametrize("combine_op", ["and", "or"])
    def test_query_text_match_with_non_varchar_fields_expr(self, combine_op):
        """
        target: test text match with non-varchar fields expr
        method: 1. enable text match for varchar field and add some non varchar fields
                2. insert data, create index and load
                3. query with text match expr and non-varchar fields expr
                4. verify the result
        expected: query result is correct
        """
        # 1. initialize with data
        fake_en = Faker("en_US")
        analyzer_params = {
            "tokenizer": "default",
        }
        dim = 128
        default_fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="age",
                dtype=DataType.INT64,
            ),
            FieldSchema(
                name="word",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="sentence",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="paragraph",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="text",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        default_schema = CollectionSchema(
            fields=default_fields, description="test collection"
        )
        collection_w = self.init_collection_wrap(
            name=cf.gen_unique_str(prefix), schema=default_schema
        )
        data = []
        data_size = 10000
        for i in range(data_size):
            d = {
                "id": i,
                "age": random.randint(1, 100),
                "word": fake_en.word().lower(),
                "sentence": fake_en.sentence().lower(),
                "paragraph": fake_en.paragraph().lower(),
                "text": fake_en.text().lower(),
                "emb": cf.gen_vectors(1, dim)[0],
            }
            data.append(d)
        batch_size = 5000
        for i in range(0, data_size, batch_size):
            collection_w.insert(
                data[i : i + batch_size]
                if i + batch_size < data_size
                else data[i:data_size]
            )
        collection_w.create_index(
            "emb",
            {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
        )
        collection_w.create_index("word", {"index_type": "INVERTED"})
        collection_w.load()
        df = pd.DataFrame(data)
        log.info(f"dataframe\n{df}")
        text_fields = ["word", "sentence", "paragraph", "text"]
        wf_map = {}
        for field in text_fields:
            wf_map[field] = cf.analyze_documents(df[field].tolist(), language="en")
        # query single field for one word
        for field in text_fields:
            token = wf_map[field].most_common()[0][0]
            tm_expr = f"TextMatch({field}, '{token}')"
            int_expr = "age > 10"
            combined_expr = f"{tm_expr} {combine_op} {int_expr}"
            log.info(f"expr: {combined_expr}")
            res, _ = collection_w.query(expr=combined_expr, output_fields=["id", field, "age"])
            log.info(f"res len {len(res)}")
            for r in res:
                if combine_op == "and":
                    assert token in r[field] and r["age"] > 10
                if combine_op == "or":
                    assert token in r[field] or r["age"] > 10

    @pytest.mark.tags(CaseLabel.L1)
    def test_query_text_match_with_some_empty_string(self):
        """
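The new parametrized case combines a TextMatch predicate with an ordinary scalar filter via "and"/"or" and then re-checks every returned row on the client side. A compact sketch of that verification logic in isolation (the row dicts and helper name are assumptions; only the boolean combination mirrors the test, and the substring check is the same simplification the test uses):

def verify_row(row, field, token, combine_op, age_threshold=10):
    # Mirror the server-side expression `TextMatch(field, 'token') <op> age > 10`.
    matched = token in row[field]
    old_enough = row["age"] > age_threshold
    return (matched and old_enough) if combine_op == "and" else (matched or old_enough)

rows = [{"word": "apple pie", "age": 42}, {"word": "banana", "age": 5}]
assert verify_row(rows[0], "word", "apple", "and") is True
assert verify_row(rows[1], "word", "apple", "or") is False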
@@ -5347,6 +5461,124 @@ class TestQueryTextMatch(TestcaseBase):
            for r in res:
                assert any([token in r[field] for token in multi_words])

    @pytest.mark.tags(CaseLabel.L1)
    def test_query_text_match_with_nullable(self):
        """
        target: test text match with nullable
        method: 1. enable text match and nullable, and insert data with varchar with some None value
                2. get the most common words and query with text match
                3. verify the result
        expected: text match successfully and result is correct
        """
        # 1. initialize with data
        analyzer_params = {
            "tokenizer": "default",
        }
        dim = 128
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="word",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
                nullable=True,
            ),
            FieldSchema(
                name="sentence",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
                nullable=True,
            ),
            FieldSchema(
                name="paragraph",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
                nullable=True,
            ),
            FieldSchema(
                name="text",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
                nullable=True,
            ),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields, description="test collection")
        data_size = 5000
        collection_w = self.init_collection_wrap(
            name=cf.gen_unique_str(prefix), schema=schema
        )
        fake = fake_en
        language = "en"
        data_null = [
            {
                "id": i,
                "word": None if random.random() < 0.9 else fake.word().lower(),
                "sentence": None if random.random() < 0.9 else fake.sentence().lower(),
                "paragraph": None if random.random() < 0.9 else fake.paragraph().lower(),
                "text": None if random.random() < 0.9 else fake.paragraph().lower(),
                "emb": [random.random() for _ in range(dim)],
            }
            for i in range(0, data_size)
        ]
        data = data_null
        df = pd.DataFrame(data)
        log.info(f"dataframe\n{df}")
        batch_size = 5000
        for i in range(0, len(df), batch_size):
            collection_w.insert(
                data[i:i + batch_size]
                if i + batch_size < len(df)
                else data[i:len(df)]
            )
        collection_w.flush()
        collection_w.create_index(
            "emb",
            {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
        )
        collection_w.load()
        text_fields = ["word", "sentence", "paragraph", "text"]
        wf_map = {}
        for field in text_fields:
            wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
        # query single field for one word
        for field in text_fields:
            token = wf_map[field].most_common()[-1][0]
            expr = f"TextMatch({field}, '{token}')"
            log.info(f"expr: {expr}")
            res, _ = collection_w.query(expr=expr, output_fields=text_fields)
            log.info(f"res len {len(res)}, \n{res}")
            assert len(res) > 0
            for r in res:
                assert token in r[field]
        # query single field for multi-word
        for field in text_fields:
            # match top 3 most common words
            multi_words = []
            for word, count in wf_map[field].most_common(3):
                multi_words.append(word)
            string_of_multi_words = " ".join(multi_words)
            expr = f"TextMatch({field}, '{string_of_multi_words}')"
            log.info(f"expr {expr}")
            res, _ = collection_w.query(expr=expr, output_fields=text_fields)
            log.info(f"res len {len(res)}, {res}")
            assert len(res) > 0
            for r in res:
                assert any([token in r[field] for token in multi_words])


class TestQueryTextMatchNegative(TestcaseBase):
    @pytest.mark.tags(CaseLabel.L0)
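Because roughly 90% of each varchar column is None in this case, any client-side cross-check has to treat nulls as non-matching. A small pandas sketch of such a manual ground-truth filter (the dataframe columns follow the test, but token_in is an illustrative helper, not one of the test utilities):

import pandas as pd

df = pd.DataFrame({
    "id": [1, 2, 3],
    "word": ["apple", None, "apple pie"],
})

def token_in(cell, token):
    # None never matches; real rows use simple containment, as in the test's check.
    return isinstance(cell, str) and token in cell

manual = df[df["word"].apply(lambda cell: token_in(cell, "apple"))]
print(manual["id"].tolist())  # -> [1, 3]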
@@ -5471,3 +5703,56 @@ class TestQueryFunction(TestcaseBase):
        collection_w.query(
            call_expr, check_task=CheckTasks.err_res, check_items=error
        )

    @pytest.mark.tags(CaseLabel.L0)
    @pytest.mark.xfail(reason="issue 36685")
    def test_query_text_match_with_unsupported_fields(self):
        """
        target: test enable text match with unsupported field
        method: 1. enable text match in unsupported field
                2. create collection
        expected: create collection failed and return error
        """
        analyzer_params = {
            "tokenizer": "default",
        }
        dim = 128
        default_fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="title",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="overview",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="age",
                dtype=DataType.INT64,
                enable_tokenizer=True,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        default_schema = CollectionSchema(
            fields=default_fields, description="test collection"
        )
        error = {ct.err_code: 2000, ct.err_msg: "field type is not supported"}
        self.init_collection_wrap(
            name=cf.gen_unique_str(prefix),
            schema=default_schema,
            check_task=CheckTasks.err_res,
            check_items=error,
        )
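This negative case pins down the expected behaviour: enabling match on a non-VARCHAR field should be rejected at collection creation (it is marked xfail until issue 36685 is resolved). A hedged sketch of the same expectation written directly against pymilvus rather than the test wrapper; it assumes a connection already established via connections.connect, and the collection name and error text matched are assumptions describing what the case expects once the issue is fixed:

import pytest
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType, MilvusException

def test_match_on_int64_is_rejected():
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
        # Text match flags on an INT64 field are invalid and should be refused.
        FieldSchema(name="age", dtype=DataType.INT64, enable_tokenizer=True, enable_match=True),
        FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=8),
    ]
    schema = CollectionSchema(fields=fields, description="negative case sketch")
    with pytest.raises(MilvusException, match="not supported"):
        Collection("tm_unsupported_field_demo", schema)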
@@ -13335,13 +13335,15 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
                enable_match=True,
                tokenizer_params=tokenizer_params,
            ),
-            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
+            FieldSchema(name="float32_emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
+            FieldSchema(name="sparse_emb", dtype=DataType.SPARSE_FLOAT_VECTOR),
        ]
        schema = CollectionSchema(fields=fields, description="test collection")
        data_size = 5000
        collection_w = self.init_collection_wrap(
            name=cf.gen_unique_str(prefix), schema=schema
        )
        log.info(f"collection {collection_w.describe()}")
        fake = fake_en
        if tokenizer == "jieba":
            language = "zh"
@@ -13356,7 +13358,8 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
                "sentence": fake.sentence().lower(),
                "paragraph": fake.paragraph().lower(),
                "text": fake.text().lower(),
-                "emb": [random.random() for _ in range(dim)],
+                "float32_emb": [random.random() for _ in range(dim)],
+                "sparse_emb": cf.gen_sparse_vectors(1, dim=10000)[0],
            }
            for i in range(data_size)
        ]
@@ -13371,9 +13374,13 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
        )
        collection_w.flush()
        collection_w.create_index(
-            "emb",
+            "float32_emb",
            {"index_type": "HNSW", "metric_type": "L2", "params": {"M": 16, "efConstruction": 500}},
        )
+        collection_w.create_index(
+            "sparse_emb",
+            {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"},
+        )
        if enable_inverted_index:
            collection_w.create_index("word", {"index_type": "INVERTED"})
        collection_w.load()
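The search test now builds two vector indexes: HNSW over the dense float32_emb field and SPARSE_INVERTED_INDEX over sparse_emb, which is scored with IP as in the test. A hedged pymilvus sketch of the same pair of calls, assuming `collection` is an already-created Collection with those two fields (the function name is illustrative):

from pymilvus import Collection

def build_indexes(collection: Collection) -> None:
    # Dense field: graph index with L2 distance.
    dense_index = {"index_type": "HNSW", "metric_type": "L2",
                   "params": {"M": 16, "efConstruction": 500}}
    # Sparse field: inverted index scored by inner product.
    sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
    collection.create_index("float32_emb", dense_index)
    collection.create_index("sparse_emb", sparse_index)
    collection.load()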
@@ -13382,47 +13389,56 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
        wf_map = {}
        for field in text_fields:
            wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
-        # query single field for one token
+        # search with filter single field for one token
        df_split = cf.split_dataframes(df, text_fields, language=language)
        log.info(f"df_split\n{df_split}")
-        for field in text_fields:
-            token = wf_map[field].most_common()[0][0]
-            expr = f"TextMatch({field}, '{token}')"
-            manual_result = df_split[
-                df_split.apply(lambda row: token in row[field], axis=1)
-            ]
-            log.info(f"expr: {expr}, manual_check_result\n: {manual_result}")
-            res_list, _ = collection_w.search(
-                data=[[random.random() for _ in range(dim)]],
-                anns_field="emb",
-                param={},
-                limit=100,
-                expr=expr, output_fields=["id", field])
-            for res in res_list:
-                assert len(res) > 0
-                log.info(f"res len {len(res)} res {res}")
-                for r in res:
-                    r = r.to_dict()
-                    assert token in r["entity"][field]
+        for ann_field in ["float32_emb", "sparse_emb"]:
+            log.info(f"ann_field {ann_field}")
+            if ann_field == "float32_emb":
+                search_data = [[random.random() for _ in range(dim)]]
+            elif ann_field == "sparse_emb":
+                search_data = cf.gen_sparse_vectors(1, dim=10000)
+            else:
+                search_data = [[random.random() for _ in range(dim)]]
+            for field in text_fields:
+                token = wf_map[field].most_common()[0][0]
+                expr = f"TextMatch({field}, '{token}')"
+                manual_result = df_split[
+                    df_split.apply(lambda row: token in row[field], axis=1)
+                ]
+                log.info(f"expr: {expr}, manual_check_result: {len(manual_result)}")
+                res_list, _ = collection_w.search(
+                    data=search_data,
+                    anns_field=ann_field,
+                    param={},
+                    limit=100,
+                    expr=expr, output_fields=["id", field])
+                for res in res_list:
+                    log.info(f"res len {len(res)} res {res}")
+                    assert len(res) > 0
+                    for r in res:
+                        r = r.to_dict()
+                        assert token in r["entity"][field]

-        # query single field for multi-word
-        for field in text_fields:
-            # match top 10 most common words
-            top_10_tokens = []
-            for word, count in wf_map[field].most_common(10):
-                top_10_tokens.append(word)
-            string_of_top_10_words = " ".join(top_10_tokens)
-            expr = f"TextMatch({field}, '{string_of_top_10_words}')"
-            log.info(f"expr {expr}")
-            res_list, _ = collection_w.search(
-                data=[[random.random() for _ in range(dim)]],
-                anns_field="emb",
-                param={},
-                limit=100,
-                expr=expr, output_fields=["id", field])
-            for res in res_list:
-                log.info(f"res len {len(res)} res {res}")
-                for r in res:
-                    r = r.to_dict()
-                    assert any([token in r["entity"][field] for token in top_10_tokens])
+        # search with filter single field for multi-token
+        for field in text_fields:
+            # match top 10 most common words
+            top_10_tokens = []
+            for word, count in wf_map[field].most_common(10):
+                top_10_tokens.append(word)
+            string_of_top_10_words = " ".join(top_10_tokens)
+            expr = f"TextMatch({field}, '{string_of_top_10_words}')"
+            log.info(f"expr {expr}")
+            res_list, _ = collection_w.search(
+                data=search_data,
+                anns_field=ann_field,
+                param={},
+                limit=100,
+                expr=expr, output_fields=["id", field])
+            for res in res_list:
+                log.info(f"res len {len(res)} res {res}")
+                assert len(res) > 0
+                for r in res:
+                    r = r.to_dict()
+                    assert any([token in r["entity"][field] for token in top_10_tokens])
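The reworked search loop runs the same TextMatch filter against both vector fields, switching the query vector format to match the anns field: a float list for the dense field, a {dim_index: value} dict for the sparse one. A condensed sketch of that dispatch (the helper name is illustrative, the sparse shape assumed to mirror cf.gen_sparse_vectors, and the actual search call is left commented because it needs a live collection):

import random

def make_query_vector(ann_field, dim=128, sparse_dim=10000):
    # Dense fields take a float list; sparse fields take a {dim_index: value} dict.
    if ann_field == "float32_emb":
        return [[random.random() for _ in range(dim)]]
    if ann_field == "sparse_emb":
        return [{i: random.random() for i in random.sample(range(sparse_dim), 25)}]
    raise ValueError(f"unknown anns field {ann_field}")

for ann_field in ["float32_emb", "sparse_emb"]:
    search_data = make_query_vector(ann_field)
    # collection_w.search(data=search_data, anns_field=ann_field,
    #                     param={}, limit=100, expr="TextMatch(word, 'apple')")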
@@ -450,7 +450,7 @@ class TestCreateCollection(TestBase):
        indexes = rsp['data']['indexes']
        assert len(indexes) == len(payload['indexParams'])
        # assert load success
-        assert rsp['data']['load'] == "LoadStateLoaded"
+        assert rsp['data']['load'] in ["LoadStateLoaded", "LoadStateLoading"]

    @pytest.mark.parametrize("auto_id", [True])
    @pytest.mark.parametrize("enable_dynamic_field", [True])
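Relaxing the assertion to also accept LoadStateLoading acknowledges that loading is asynchronous right after collection creation. If a test needs the collection fully loaded, a hedged polling helper can sit on top of the same response shape (the get_collection_describe callable and the response layout are assumptions based on the assertion above):

import time

def wait_until_loaded(get_collection_describe, timeout=60, interval=1.0):
    # Poll the describe endpoint until the load state settles on "LoadStateLoaded".
    deadline = time.time() + timeout
    while time.time() < deadline:
        rsp = get_collection_describe()
        state = rsp['data']['load']
        assert state in ["LoadStateLoaded", "LoadStateLoading"], f"unexpected load state {state}"
        if state == "LoadStateLoaded":
            return rsp
        time.sleep(interval)
    raise TimeoutError("collection did not reach LoadStateLoaded in time")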