test: add supplementary cases for text match (#36693)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/37180/head
zhuwenxing 2024-10-28 10:31:40 +08:00 committed by GitHub
parent 7774b7275e
commit c8dd665bf6
GPG Key ID: B5690EEEBB952194
5 changed files with 354 additions and 56 deletions
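For orientation: the new cases all exercise the TextMatch filter expression, both on its own and combined with ordinary scalar predicates. A minimal sketch of the pattern, assuming an already created, match-enabled, and loaded pymilvus Collection named collection (not part of this diff):

# TextMatch keeps rows whose tokenized field contains any of the given tokens;
# it can be combined with scalar predicates via and / or, as in the tests below
expr = "TextMatch(sentence, 'network latency') and age > 10"
res = collection.query(expr=expr, output_fields=["id", "sentence", "age"])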

View File

@ -117,7 +117,7 @@ class ResponseChecker:
return True
def assert_exception(self, res, actual=True, error_dict=None):
assert actual is False
assert actual is False, f"Response of API {self.func_name} expected an error but the call succeeded"
assert len(error_dict) > 0
if isinstance(res, Error):
error_code = error_dict[ct.err_code]

View File

@ -118,8 +118,6 @@ def get_bm25_ground_truth(corpus, queries, top_k=100, language="en"):
return results, scores
def custom_tokenizer(language="en"):
def remove_punctuation(text):
text = text.strip()
@ -153,11 +151,12 @@ def custom_tokenizer(language="en"):
def analyze_documents(texts, language="en"):
tokenizer = custom_tokenizer(language)
# Start timing
t0 = time.time()
new_texts = []
for text in texts:
if isinstance(text, str):
new_texts.append(text)
# Tokenize the corpus
tokenized = tokenizer.tokenize(texts, return_as="tuple")
tokenized = tokenizer.tokenize(new_texts, return_as="tuple")
# log.info(f"Tokenized: {tokenized}")
# Create a frequency counter
freq = Counter()
@ -170,13 +169,11 @@ def analyze_documents(texts, language="en"):
# Convert token ids back to words
word_freq = Counter({id_to_word[token_id]: count for token_id, count in freq.items()})
# End timing
tt = time.time() - t0
log.debug(f"Analyze document cost time: {tt}")
log.debug(f"word freq {word_freq.most_common(10)}")
return word_freq
def check_token_overlap(text_a, text_b, language="en"):
word_freq_a = analyze_documents([text_a], language)
word_freq_b = analyze_documents([text_b], language)
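The filter added here makes analyze_documents tolerant of None entries, which the nullable test below relies on. The same pattern in isolation, using a plain whitespace split as a stand-in for the project's custom tokenizer (an assumption made for brevity):

from collections import Counter

def word_freq_ignoring_non_strings(texts):
    # drop None / non-string entries before tokenizing, mirroring the new_texts filter above
    new_texts = [t for t in texts if isinstance(t, str)]
    freq = Counter()
    for text in new_texts:
        freq.update(text.lower().split())
    return freq

# word_freq_ignoring_non_strings(["the quick fox", None, "the lazy dog"])
# -> Counter({'the': 2, 'quick': 1, 'fox': 1, 'lazy': 1, 'dog': 1})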
@ -3054,7 +3051,7 @@ def gen_sparse_vectors(nb, dim=1000, sparse_format="dok"):
rng = np.random.default_rng()
vectors = [{
d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))
d: rng.random() for d in list(set(random.sample(range(dim), random.randint(20, 30)) + [0, 1]))
} for _ in range(nb)]
if sparse_format == "coo":
vectors = [
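The union with [0, 1] guarantees that dimensions 0 and 1 appear in every generated sparse vector, so any two vectors produced by this helper (including a query vector) share at least two dimensions. A quick standalone check of that property (a sketch outside the test framework):

import random
import numpy as np

rng = np.random.default_rng()
dim = 1000
vectors = [
    {d: rng.random()
     for d in set(random.sample(range(dim), random.randint(20, 30))) | {0, 1}}
    for _ in range(100)
]
assert all(0 in v and 1 in v for v in vectors)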

View File

@ -5206,6 +5206,120 @@ class TestQueryTextMatch(TestcaseBase):
log.info(f"res len {len(res)}")
assert len(res) == wf_map[field].most_common()[-1][1]
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("combine_op", ["and", "or"])
def test_query_text_match_with_non_varchar_fields_expr(self, combine_op):
"""
target: test text match combined with non-varchar field expr
method: 1. enable text match on the varchar fields and add some non-varchar fields
2. insert data, create index and load
3. query with a text match expr combined with a non-varchar field expr
4. verify the result
expected: query result is correct
"""
# 1. initialize with data
fake_en = Faker("en_US")
analyzer_params = {
"tokenizer": "default",
}
dim = 128
default_fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
FieldSchema(
name="age",
dtype=DataType.INT64,
),
FieldSchema(
name="word",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="sentence",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="paragraph",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="text",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
]
default_schema = CollectionSchema(
fields=default_fields, description="test collection"
)
collection_w = self.init_collection_wrap(
name=cf.gen_unique_str(prefix), schema=default_schema
)
data = []
data_size = 10000
for i in range(data_size):
d = {
"id": i,
"age": random.randint(1, 100),
"word": fake_en.word().lower(),
"sentence": fake_en.sentence().lower(),
"paragraph": fake_en.paragraph().lower(),
"text": fake_en.text().lower(),
"emb": cf.gen_vectors(1, dim)[0],
}
data.append(d)
batch_size = 5000
for i in range(0, data_size, batch_size):
collection_w.insert(
data[i : i + batch_size]
if i + batch_size < data_size
else data[i:data_size]
)
collection_w.create_index(
"emb",
{"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
)
collection_w.create_index("word", {"index_type": "INVERTED"})
collection_w.load()
df = pd.DataFrame(data)
log.info(f"dataframe\n{df}")
text_fields = ["word", "sentence", "paragraph", "text"]
wf_map = {}
for field in text_fields:
wf_map[field] = cf.analyze_documents(df[field].tolist(), language="en")
# query single field for one word
for field in text_fields:
token = wf_map[field].most_common()[0][0]
tm_expr = f"TextMatch({field}, '{token}')"
int_expr = "age > 10"
combined_expr = f"{tm_expr} {combine_op} {int_expr}"
log.info(f"expr: {combined_expr}")
res, _ = collection_w.query(expr=combined_expr, output_fields=["id", field, "age"])
log.info(f"res len {len(res)}")
for r in res:
if combine_op == "and":
assert token in r[field] and r["age"] > 10
if combine_op == "or":
assert token in r[field] or r["age"] > 10
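A rough pandas cross-check of the and/or semantics can make failures easier to diagnose. A sketch reusing the df, token, and res from the loop above; substring containment only approximates tokenized matching, so the comparison is logged rather than asserted:

# cross-check the query result against the in-memory dataframe
mask_tm = df[field].apply(lambda s: token in s)
mask_int = df["age"] > 10
mask = (mask_tm & mask_int) if combine_op == "and" else (mask_tm | mask_int)
log.info(f"pandas cross-check for {field}: {int(mask.sum())} rows vs query: {len(res)}")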
@pytest.mark.tags(CaseLabel.L1)
def test_query_text_match_with_some_empty_string(self):
"""
@ -5347,6 +5461,124 @@ class TestQueryTextMatch(TestcaseBase):
for r in res:
assert any([token in r[field] for token in multi_words])
@pytest.mark.tags(CaseLabel.L1)
def test_query_text_match_with_nullable(self):
"""
target: test text match with nullable fields
method: 1. enable text match on nullable varchar fields and insert data where the varchar values are sometimes None
2. get the most common words and query with text match
3. verify the result
expected: text match succeeds and the result is correct
"""
# 1. initialize with data
analyzer_params = {
"tokenizer": "default",
}
dim = 128
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
FieldSchema(
name="word",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
nullable=True,
),
FieldSchema(
name="sentence",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
nullable=True,
),
FieldSchema(
name="paragraph",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
nullable=True,
),
FieldSchema(
name="text",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
nullable=True,
),
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
]
schema = CollectionSchema(fields=fields, description="test collection")
data_size = 5000
collection_w = self.init_collection_wrap(
name=cf.gen_unique_str(prefix), schema=schema
)
fake = fake_en
language = "en"
data_null = [
{
"id": i,
"word": None if random.random() < 0.9 else fake.word().lower(),
"sentence": None if random.random() < 0.9 else fake.sentence().lower(),
"paragraph": None if random.random() < 0.9 else fake.paragraph().lower(),
"text": None if random.random() < 0.9 else fake.paragraph().lower(),
"emb": [random.random() for _ in range(dim)],
}
for i in range(0, data_size)
]
data = data_null
df = pd.DataFrame(data)
log.info(f"dataframe\n{df}")
batch_size = 5000
for i in range(0, len(df), batch_size):
collection_w.insert(
data[i:i + batch_size]
if i + batch_size < len(df)
else data[i:len(df)]
)
collection_w.flush()
collection_w.create_index(
"emb",
{"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
)
collection_w.load()
text_fields = ["word", "sentence", "paragraph", "text"]
wf_map = {}
for field in text_fields:
wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
# query single field for one word
for field in text_fields:
token = wf_map[field].most_common()[-1][0]
expr = f"TextMatch({field}, '{token}')"
log.info(f"expr: {expr}")
res, _ = collection_w.query(expr=expr, output_fields=text_fields)
log.info(f"res len {len(res)}, \n{res}")
assert len(res) > 0
for r in res:
assert token in r[field]
# query single field for multi-word
for field in text_fields:
# match top 3 most common words
multi_words = []
for word, count in wf_map[field].most_common(3):
multi_words.append(word)
string_of_multi_words = " ".join(multi_words)
expr = f"TextMatch({field}, '{string_of_multi_words}')"
log.info(f"expr {expr}")
res, _ = collection_w.query(expr=expr, output_fields=text_fields)
log.info(f"res len {len(res)}, {res}")
assert len(res) > 0
for r in res:
assert any([token in r[field] for token in multi_words])
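With roughly 90% of each varchar column set to None, the null expectation is worth making explicit: a matched row can never carry a null for the matched field, and an explicit assertion fails more readably than the TypeError that 'word in None' would raise. A small sketch in the style of the checks above:

for r in res:
    assert r[field] is not None, f"TextMatch on {field} returned a null value: {r}"
    assert any(word in r[field] for word in multi_words)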
class TestQueryTextMatchNegative(TestcaseBase):
@pytest.mark.tags(CaseLabel.L0)
@ -5471,3 +5703,56 @@ class TestQueryFunction(TestcaseBase):
collection_w.query(
call_expr, check_task=CheckTasks.err_res, check_items=error
)
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.xfail(reason="issue 36685")
def test_query_text_match_with_unsupported_fields(self):
"""
target: test enabling text match on an unsupported field type
method: 1. enable text match on an unsupported (non-varchar) field
2. create collection
expected: collection creation fails and returns an error
"""
analyzer_params = {
"tokenizer": "default",
}
dim = 128
default_fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
FieldSchema(
name="title",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="overview",
dtype=DataType.VARCHAR,
max_length=65535,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="age",
dtype=DataType.INT64,
enable_tokenizer=True,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
]
default_schema = CollectionSchema(
fields=default_fields, description="test collection"
)
error = {ct.err_code: 2000, ct.err_msg: "field type is not supported"}
self.init_collection_wrap(
name=cf.gen_unique_str(prefix),
schema=default_schema,
check_task=CheckTasks.err_res,
check_items=error,
)
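The failure here comes from the server rejecting tokenizer/match flags on an INT64 field. A client-side pre-check with the same intent could look like the hypothetical helper below (not part of this diff; only VARCHAR fields support analyzers and text match):

from pymilvus import DataType

def validate_match_field(name, dtype, enable_match):
    # tokenizer / text-match flags are only meaningful on VARCHAR fields
    if enable_match and dtype != DataType.VARCHAR:
        raise ValueError(f"text match is not supported on field '{name}' of type {dtype}")

# validate_match_field("age", DataType.INT64, enable_match=True)  # raises ValueError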

View File

@ -13335,13 +13335,15 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
enable_match=True,
tokenizer_params=tokenizer_params,
),
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
FieldSchema(name="float32_emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
FieldSchema(name="sparse_emb", dtype=DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(fields=fields, description="test collection")
data_size = 5000
collection_w = self.init_collection_wrap(
name=cf.gen_unique_str(prefix), schema=schema
)
log.info(f"collection {collection_w.describe()}")
fake = fake_en
if tokenizer == "jieba":
language = "zh"
@ -13356,7 +13358,8 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
"sentence": fake.sentence().lower(),
"paragraph": fake.paragraph().lower(),
"text": fake.text().lower(),
"emb": [random.random() for _ in range(dim)],
"float32_emb": [random.random() for _ in range(dim)],
"sparse_emb": cf.gen_sparse_vectors(1, dim=10000)[0],
}
for i in range(data_size)
]
@ -13371,9 +13374,13 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
)
collection_w.flush()
collection_w.create_index(
"emb",
"float32_emb",
{"index_type": "HNSW", "metric_type": "L2", "params": {"M": 16, "efConstruction": 500}},
)
collection_w.create_index(
"sparse_emb",
{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"},
)
if enable_inverted_index:
collection_w.create_index("word", {"index_type": "INVERTED"})
collection_w.load()
@ -13382,47 +13389,56 @@ class TestSearchWithTextMatchFilter(TestcaseBase):
wf_map = {}
for field in text_fields:
wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
# query single field for one token
# search with filter single field for one token
df_split = cf.split_dataframes(df, text_fields, language=language)
log.info(f"df_split\n{df_split}")
for field in text_fields:
token = wf_map[field].most_common()[0][0]
expr = f"TextMatch({field}, '{token}')"
manual_result = df_split[
df_split.apply(lambda row: token in row[field], axis=1)
]
log.info(f"expr: {expr}, manual_check_result\n: {manual_result}")
res_list, _ = collection_w.search(
data=[[random.random() for _ in range(dim)]],
anns_field="emb",
param={},
limit=100,
expr=expr, output_fields=["id", field])
for res in res_list:
assert len(res) > 0
log.info(f"res len {len(res)} res {res}")
for r in res:
r = r.to_dict()
assert token in r["entity"][field]
for ann_field in ["float32_emb", "sparse_emb"]:
log.info(f"ann_field {ann_field}")
if ann_field == "float32_emb":
search_data = [[random.random() for _ in range(dim)]]
elif ann_field == "sparse_emb":
search_data = cf.gen_sparse_vectors(1, dim=10000)
else:
search_data = [[random.random() for _ in range(dim)]]
for field in text_fields:
token = wf_map[field].most_common()[0][0]
expr = f"TextMatch({field}, '{token}')"
manual_result = df_split[
df_split.apply(lambda row: token in row[field], axis=1)
]
log.info(f"expr: {expr}, manual_check_result: {len(manual_result)}")
res_list, _ = collection_w.search(
data=search_data,
anns_field=ann_field,
param={},
limit=100,
expr=expr, output_fields=["id", field])
for res in res_list:
log.info(f"res len {len(res)} res {res}")
assert len(res) > 0
for r in res:
r = r.to_dict()
assert token in r["entity"][field]
# query single field for multi-word
for field in text_fields:
# match top 10 most common words
top_10_tokens = []
for word, count in wf_map[field].most_common(10):
top_10_tokens.append(word)
string_of_top_10_words = " ".join(top_10_tokens)
expr = f"TextMatch({field}, '{string_of_top_10_words}')"
log.info(f"expr {expr}")
res_list, _ = collection_w.search(
data=[[random.random() for _ in range(dim)]],
anns_field="emb",
param={},
limit=100,
expr=expr, output_fields=["id", field])
for res in res_list:
log.info(f"res len {len(res)} res {res}")
for r in res:
r = r.to_dict()
assert any([token in r["entity"][field] for token in top_10_tokens])
# search with filter single field for multi-token
for field in text_fields:
# match top 10 most common words
top_10_tokens = []
for word, count in wf_map[field].most_common(10):
top_10_tokens.append(word)
string_of_top_10_words = " ".join(top_10_tokens)
expr = f"TextMatch({field}, '{string_of_top_10_words}')"
log.info(f"expr {expr}")
res_list, _ = collection_w.search(
data=search_data,
anns_field=ann_field,
param={},
limit=100,
expr=expr, output_fields=["id", field])
for res in res_list:
log.info(f"res len {len(res)} res {res}")
assert len(res) > 0
for r in res:
r = r.to_dict()
assert any([token in r["entity"][field] for token in top_10_tokens])

View File

@ -450,7 +450,7 @@ class TestCreateCollection(TestBase):
indexes = rsp['data']['indexes']
assert len(indexes) == len(payload['indexParams'])
# assert load success
assert rsp['data']['load'] == "LoadStateLoaded"
assert rsp['data']['load'] in ["LoadStateLoaded", "LoadStateLoading"]
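Accepting LoadStateLoading avoids a race with the asynchronous load. If a test ever needs to block until the load completes, a small polling helper could be used instead (a sketch; the describe callable is passed in because the concrete REST client method name is not part of this diff):

import time

def wait_for_loaded(describe, name, timeout=60, interval=1):
    # poll until the describe response reports LoadStateLoaded, or give up
    deadline = time.time() + timeout
    while time.time() < deadline:
        rsp = describe(name)  # expected to return the same dict shape as rsp above
        if rsp['data']['load'] == "LoadStateLoaded":
            return rsp
        time.sleep(interval)
    raise TimeoutError(f"collection {name} was not loaded within {timeout}s")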
@pytest.mark.parametrize("auto_id", [True])
@pytest.mark.parametrize("enable_dynamic_field", [True])