Add simple test case to search_with_expression (#5094)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/5103/head
dragondriver 2021-04-30 15:56:01 +08:00 committed by GitHub
parent 396b3f33e9
commit 56dbbfe4cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 1 deletions

View File

@ -1789,6 +1789,51 @@ class TestSearchInvalid(object):
res = connect.search(collection, query)
class TestSearchWithExpression(object):
@pytest.fixture(
scope="function",
params=[1, 10, 20],
)
def limit(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_normal_expressions(),
)
def expression(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=[
{"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}},
]
)
def index_param(self, request):
return request.param
@pytest.fixture(
scope="function",
)
def search_params(self):
return {"metric_type": "L2", "params": {"nprobe": 10}}
@pytest.mark.tags(CaseLabel.tags_smoke)
def test_search_with_expression(self, connect, collection, index_param, search_params, limit, expression):
entities, ids = init_data(connect, collection)
assert len(ids) == default_nb
connect.create_index(collection, default_float_vec_field_name, index_param)
connect.load_collection(collection)
nq = 10
query_data = entities[2]["values"][:nq]
res = connect.search_with_expression(collection, query_data, default_float_vec_field_name, search_params,
limit, expression)
assert len(res) == nq
for topk_results in res:
assert len(topk_results) <= limit
def check_id_result(result, id):
limit_in = 5
ids = [entity.id for entity in result]

View File

@ -10,7 +10,7 @@ allure-pytest==2.7.0
pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.2.1
pymilvus-distributed==0.0.62
pymilvus-distributed==0.0.63
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient

View File

@ -867,6 +867,13 @@ def gen_binary_index():
return index_params
def gen_normal_expressions():
expressions = [
"int64 > 0",
]
return expressions
def get_search_param(index_type, metric_type="L2"):
search_params = {"metric_type": metric_type}
if index_type in ivf() or index_type in binary_support():