mirror of https://github.com/milvus-io/milvus.git
test: add restful cases for text match feature (#36405)
/kind improvement

Since creating a collection with text match enabled is not yet implemented on the RESTful interface, the tests temporarily create the collection through pymilvus. This PR adds cases that exercise text match filters in search and query requests through the RESTful interface.

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
parent bd7910632a
commit 31353ae406
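For context, the pattern these cases rely on is: create the collection (with enable_match=True on the VARCHAR fields) through pymilvus, then drive inserts, searches, and queries through the RESTful client. The sketch below is illustrative only, not part of the diff; the connection URI and the v2 endpoint path are assumptions, and error handling is omitted.

# Minimal sketch of the hybrid pymilvus + RESTful pattern (assumptions:
# a local Milvus at localhost:19530 and the v2 endpoint
# /v2/vectordb/entities/search; adjust both for a real deployment).
import random
import requests
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection

connections.connect(host="localhost", port="19530")
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    # enable_match builds the inverted index that TextMatch filters require
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535,
                enable_match=True, analyzer_params={"tokenizer": "default"}),
    FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=8),
]
Collection(name="text_match_demo", schema=CollectionSchema(fields=fields))
# ... insert data, create an index on "emb", and load the collection ...

# the search itself goes through the RESTful interface
rsp = requests.post(
    "http://localhost:19530/v2/vectordb/entities/search",
    json={
        "collectionName": "text_match_demo",
        "data": [[random.random() for _ in range(8)]],
        "filter": "TextMatch(text, 'hello')",
        "outputFields": ["*"],
    },
)
print(rsp.json())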
@@ -1,20 +1,24 @@
import random
from sklearn import preprocessing
import numpy as np
import pandas as pd
import sys
import json
import time

import utils.utils
from utils import constant
from utils.utils import gen_collection_name, get_sorted_distance
from utils.util_log import test_log as logger
import pytest
from base.testbase import TestBase
-from utils.utils import (gen_unique_str, get_data_by_payload, get_common_fields_by_data, gen_vector)
+from utils.utils import (gen_unique_str, get_data_by_payload, get_common_fields_by_data, gen_vector, analyze_documents)
from pymilvus import (
    FieldSchema, CollectionSchema, DataType,
    Collection, utility
)
from faker import Faker
Faker.seed(19530)
fake_en = Faker("en_US")
fake_zh = Faker("zh_CN")


@pytest.mark.L0
@@ -1661,6 +1665,106 @@ class TestSearchVector(TestBase):
        assert len(res) == limit

    @pytest.mark.parametrize("tokenizer", ["jieba", "default"])
    def test_search_vector_with_text_match_filter(self, tokenizer):
        """
        Search vectors with a text match filter and verify that every hit
        contains the matched token
        """
        fake = fake_en
        language = "en"
        if tokenizer == "jieba":
            fake = fake_zh
            language = "zh"
        # create the collection via pymilvus, since text match fields cannot
        # yet be created through the RESTful interface
        dim = 128
        analyzer_params = {
            "tokenizer": tokenizer,
        }
        name = gen_collection_name()
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="word",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                is_partition_key=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="sentence",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="paragraph",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="text",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields, description="test collection")
        collection = Collection(name=name, schema=schema)
        rsp = self.collection_client.collection_describe(name)
        logger.info(f"rsp: {rsp}")
        assert rsp['code'] == 0
        data_size = 3000
        batch_size = 1000
        # insert data
        data = [
            {
                "id": i,
                "word": fake.word().lower(),
                "sentence": fake.sentence().lower(),
                "paragraph": fake.sentence().lower(),
                "text": fake.text().lower(),
                "emb": [random.random() for _ in range(dim)]
            }
            for i in range(data_size)
        ]
        df = pd.DataFrame(data)
        text_fields = ["word", "sentence", "paragraph", "text"]
        # build a per-field word-frequency map so the test can later pick a
        # token that is guaranteed to exist in the inserted data
        wf_map = {}
        for field in text_fields:
            wf_map[field] = analyze_documents(df[field].tolist(), language=language)
        for i in range(0, data_size, batch_size):
            tmp = data[i:i + batch_size]
            payload = {
                "collectionName": name,
                "data": tmp,
            }
            rsp = self.vector_client.vector_insert(payload)
            assert rsp['code'] == 0
            assert rsp['data']['insertCount'] == len(tmp)
        collection.create_index(
            "emb",
            {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
        )
        collection.load()
        time.sleep(5)
        vector_to_search = [[random.random() for _ in range(dim)]]
        for field in text_fields:
            # the most frequent token must match at least one entity
            token = wf_map[field].most_common()[0][0]
            expr = f"TextMatch({field}, '{token}')"
            logger.info(f"expr: {expr}")
            rsp = self.vector_client.vector_search({"collectionName": name, "data": vector_to_search, "filter": expr, "outputFields": ["*"]})
            assert rsp['code'] == 0, rsp
            for d in rsp['data']:
                assert token in d[field]
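As a side note, TextMatch expressions should compose with the ordinary boolean filter syntax, so variants of the filter above could mix a text match with scalar predicates. A hypothetical sketch (the tokens are placeholders, not values from the test data):

# Hypothetical filter variants built on the same TextMatch syntax; each
# string would be passed as the "filter" value of a search or query payload.
token_a, token_b = "apple", "banana"  # placeholder tokens
exprs = [
    f"TextMatch(word, '{token_a}')",                                          # single field
    f"TextMatch(word, '{token_a}') and id < 1000",                            # with a scalar predicate
    f"TextMatch(sentence, '{token_a}') or TextMatch(sentence, '{token_b}')",  # OR of two matches
]
for expr in exprs:
    print(expr)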


@pytest.mark.L0
class TestSearchVectorNegative(TestBase):
@@ -2371,6 +2475,106 @@ class TestQueryVector(TestBase):
            if "like" in filter_expr:
                assert name.startswith(prefix)

    @pytest.mark.parametrize("tokenizer", ["jieba", "default"])
    def test_query_vector_with_text_match_filter(self, tokenizer):
        """
        Query entities with a text match filter and verify that every result
        contains the matched token
        """
        fake = fake_en
        language = "en"
        if tokenizer == "jieba":
            fake = fake_zh
            language = "zh"
        # create the collection via pymilvus, since text match fields cannot
        # yet be created through the RESTful interface
        dim = 128
        analyzer_params = {
            "tokenizer": tokenizer,
        }
        name = gen_collection_name()
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="word",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                is_partition_key=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="sentence",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="paragraph",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="text",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields, description="test collection")
        collection = Collection(name=name, schema=schema)
        rsp = self.collection_client.collection_describe(name)
        logger.info(f"rsp: {rsp}")
        assert rsp['code'] == 0
        data_size = 3000
        batch_size = 1000
        # insert data
        data = [
            {
                "id": i,
                "word": fake.word().lower(),
                "sentence": fake.sentence().lower(),
                "paragraph": fake.sentence().lower(),
                "text": fake.text().lower(),
                "emb": [random.random() for _ in range(dim)]
            }
            for i in range(data_size)
        ]
        df = pd.DataFrame(data)
        text_fields = ["word", "sentence", "paragraph", "text"]
        # build a per-field word-frequency map so the test can later pick a
        # token that is guaranteed to exist in the inserted data
        wf_map = {}
        for field in text_fields:
            wf_map[field] = analyze_documents(df[field].tolist(), language=language)
        for i in range(0, data_size, batch_size):
            tmp = data[i:i + batch_size]
            payload = {
                "collectionName": name,
                "data": tmp,
            }
            rsp = self.vector_client.vector_insert(payload)
            assert rsp['code'] == 0
            assert rsp['data']['insertCount'] == len(tmp)
        collection.create_index(
            "emb",
            {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
        )
        collection.load()
        time.sleep(5)
        for field in text_fields:
            # the most frequent token must match at least one entity
            token = wf_map[field].most_common()[0][0]
            expr = f"TextMatch({field}, '{token}')"
            logger.info(f"expr: {expr}")
            rsp = self.vector_client.vector_query({"collectionName": name, "filter": expr, "outputFields": ["*"]})
            assert rsp['code'] == 0, rsp
            for d in rsp['data']:
                assert token in d[field]
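For reference, a raw-HTTP equivalent of the vector_query call above might look like the sketch below; the host and the v2 endpoint path are assumptions, since the test's client wrapper hides them, and the collection name and token are placeholders.

# Sketch of the RESTful query request issued above (assumed host and
# endpoint path; query takes a filter but no search vector).
import requests

rsp = requests.post(
    "http://localhost:19530/v2/vectordb/entities/query",
    json={
        "collectionName": "text_match_demo",
        "filter": "TextMatch(word, 'hello')",
        "outputFields": ["*"],
    },
)
assert rsp.json()["code"] == 0, rsp.text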


@pytest.mark.L0
class TestQueryVectorNegative(TestBase):
@@ -11,9 +11,50 @@ import requests
from loguru import logger
import datetime
from sklearn.metrics import pairwise_distances
from collections import Counter
import bm25s
import jieba
fake = Faker()
rng = np.random.default_rng()


def analyze_documents(texts, language="en"):
    stopwords = "en"
    if language in ["en", "english"]:
        stopwords = "en"
    if language in ["zh", "cn", "chinese"]:
        stopword = " "
        new_texts = []
        for doc in texts:
            # pre-segment Chinese text with jieba; bm25s expects space-delimited tokens
            seg_list = jieba.cut(doc, cut_all=True)
            new_texts.append(" ".join(seg_list))
        texts = new_texts
        stopwords = [stopword]
    # Start timing
    t0 = time.time()

    # Tokenize the corpus
    tokenized = bm25s.tokenize(texts, lower=True, stopwords=stopwords)
    # log.info(f"Tokenized: {tokenized}")
    # Create a frequency counter
    freq = Counter()

    # Count the frequency of each token
    for doc_ids in tokenized.ids:
        freq.update(doc_ids)
    # Create a reverse vocabulary mapping
    id_to_word = {id: word for word, id in tokenized.vocab.items()}

    # Convert token ids back to words
    word_freq = Counter({id_to_word[token_id]: count for token_id, count in freq.items()})

    # End timing
    tt = time.time() - t0
    logger.info(f"Analyze document cost time: {tt}")

    return word_freq
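A small usage sketch of analyze_documents (made-up corpus; requires bm25s): the tests pick the most frequent token from the returned Counter, so a TextMatch filter on that token is guaranteed to return results.

# Usage sketch: the most frequent non-stopword token is guaranteed to occur
# in the corpus, so filtering on it must produce hits.
corpus = [
    "the quick brown fox jumps over the lazy dog",
    "the dog sleeps while the fox runs",
]
word_freq = analyze_documents(corpus, language="en")
token, count = word_freq.most_common(1)[0]
print(token, count)  # e.g. fox 2 -- stopwords like 'the' are filtered out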


def random_string(length=8):
    letters = string.ascii_letters
    return ''.join(random.choice(letters) for _ in range(length))
@@ -105,10 +105,10 @@ cd ${ROOT}/tests/restful_client_v2

if [[ -n "${TEST_TIMEOUT:-}" ]]; then
-    timeout "${TEST_TIMEOUT}" pytest testcases --endpoint http://${MILVUS_SERVICE_NAME}:${MILVUS_SERVICE_PORT} --minio_host ${MINIO_SERVICE_NAME} -v -x -m L0 -n 6 --timeout 180 \
+    timeout "${TEST_TIMEOUT}" pytest testcases --endpoint http://${MILVUS_SERVICE_NAME}:${MILVUS_SERVICE_PORT} --minio_host ${MINIO_SERVICE_NAME} -v -x -m L0 -n 6 --timeout 240 \
        --html=${CI_LOG_PATH}/report_restful.html --self-contained-html
else
-    pytest testcases --endpoint http://${MILVUS_SERVICE_NAME}:${MILVUS_SERVICE_PORT} --minio_host ${MINIO_SERVICE_NAME} -v -x -m L0 -n 6 --timeout 180 \
+    pytest testcases --endpoint http://${MILVUS_SERVICE_NAME}:${MILVUS_SERVICE_PORT} --minio_host ${MINIO_SERVICE_NAME} -v -x -m L0 -n 6 --timeout 240 \
        --html=${CI_LOG_PATH}/report_restful.html --self-contained-html
fi
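Here --timeout is the per-test limit from the pytest-timeout plugin (the outer timeout "${TEST_TIMEOUT}" wrapper bounds the whole run). If a single case needs more headroom, pytest-timeout also allows a per-test override via a marker; a hypothetical example:

import pytest


@pytest.mark.timeout(300)  # hypothetical per-test override of the CLI --timeout
def test_slow_restful_case():
    ...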