mirror of https://github.com/milvus-io/milvus.git
Merge part query cases of pymilvus and orm (#8253)
Signed-off-by: ThreadDao <yufen.zong@zilliz.com>pull/8255/head
parent
dc328679a1
commit
75ce32dcad
|
@ -58,130 +58,6 @@ def init_binary_data(connect, collection, nb=3000, insert=True, partition_names=
|
|||
return insert_raw_vectors, insert_entities, ids
|
||||
|
||||
|
||||
class TestQueryBase:
|
||||
"""
|
||||
test Query interface
|
||||
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=ut.gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=ut.gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
return request.param
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_invalid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
method: query with term expr
|
||||
expected: verify query result
|
||||
"""
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
|
||||
with pytest.raises(Exception):
|
||||
res = connect.query(collection, term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_valid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
method: query with term expr
|
||||
expected: verify query result
|
||||
"""
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_collection_not_existed(self, connect):
|
||||
"""
|
||||
target: test query not existed collection
|
||||
method: query not existed collection
|
||||
expected: raise exception
|
||||
"""
|
||||
collection = "not_exist"
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_invalid_collection_name(self, connect, get_collection_name):
|
||||
"""
|
||||
target: test query with invalid collection name
|
||||
method: query with invalid collection name
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection_name, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
|
||||
def test_query_expr_invalid_string(self, connect, collection, expr):
|
||||
"""
|
||||
target: test query with non-string expr
|
||||
method: query with non-string expr, eg 1, [] ..
|
||||
expected: raise exception
|
||||
"""
|
||||
# entities, ids = init_data(connect, collection)
|
||||
# assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.xfail(reason="#6072")
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_binary_expr_single_term_array(self, connect, binary_collection):
|
||||
"""
|
||||
target: test query with single array term expr
|
||||
method: query with single array value
|
||||
expected: query result is one entity
|
||||
"""
|
||||
_, binary_entities, ids = init_binary_data(connect, binary_collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(binary_collection)
|
||||
term_expr = f'{default_int_field_name} in [0]'
|
||||
res = connect.query(binary_collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == 1
|
||||
assert res[0][default_int_field_name] == binary_entities[0]["values"][0]
|
||||
assert res[1][default_float_field_name] == binary_entities[1]["values"][0]
|
||||
assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
|
||||
|
||||
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_invalid_output_fields(self, connect, collection, fields):
|
||||
"""
|
||||
target: test query with invalid output fields
|
||||
method: query with invalid field fields
|
||||
expected: raise exception
|
||||
"""
|
||||
init_data(connect, collection)
|
||||
connect.load_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr, output_fields=[fields])
|
||||
|
||||
|
||||
class TestQueryPartition:
|
||||
"""
|
||||
test Query interface
|
||||
|
|
|
@ -11,6 +11,7 @@ from common import common_func as cf
|
|||
from common import common_type as ct
|
||||
from common.common_type import CaseLabel, CheckTasks
|
||||
from utils.util_log import test_log as log
|
||||
import utils.utils as ut
|
||||
|
||||
prefix = "query"
|
||||
exp_res = "exp_res"
|
||||
|
@ -18,6 +19,11 @@ default_term_expr = f'{ct.default_int64_field_name} in [0, 1]'
|
|||
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}}
|
||||
binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}}
|
||||
|
||||
default_entities = ut.gen_entities(ut.default_nb, is_normal=True)
|
||||
default_pos = 5
|
||||
default_int_field_name = "int64"
|
||||
default_float_field_name = "float"
|
||||
|
||||
|
||||
class TestQueryBase(TestcaseBase):
|
||||
"""
|
||||
|
@ -1056,3 +1062,130 @@ class TestQueryOperation(TestcaseBase):
|
|||
res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
|
||||
assert len(res) == 1
|
||||
assert res[0][ct.default_int64_field_name] == half
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following classes are copied from pymilvus test
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
|
||||
def init_data(connect, collection, nb=ut.default_nb, partition_names=None, auto_id=True):
|
||||
"""
|
||||
Generate entities and add it in collection
|
||||
"""
|
||||
if nb == 3000:
|
||||
insert_entities = default_entities
|
||||
else:
|
||||
insert_entities = ut.gen_entities(nb, is_normal=True)
|
||||
if partition_names is None:
|
||||
if auto_id:
|
||||
res = connect.insert(collection, insert_entities)
|
||||
else:
|
||||
res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
|
||||
else:
|
||||
if auto_id:
|
||||
res = connect.insert(collection, insert_entities, partition_name=partition_names)
|
||||
else:
|
||||
res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)],
|
||||
partition_name=partition_names)
|
||||
connect.flush([collection])
|
||||
ids = res.primary_keys
|
||||
return insert_entities, ids
|
||||
|
||||
|
||||
class TestQueryBase:
|
||||
"""
|
||||
test Query interface
|
||||
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=ut.gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_invalid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
method: query with term expr
|
||||
expected: verify query result
|
||||
"""
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_valid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
method: query with term expr
|
||||
expected: verify query result
|
||||
"""
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_collection_not_existed(self, connect):
|
||||
"""
|
||||
target: test query not existed collection
|
||||
method: query not existed collection
|
||||
expected: raise exception
|
||||
"""
|
||||
collection = "not_exist"
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_invalid_collection_name(self, connect, get_collection_name):
|
||||
"""
|
||||
target: test query with invalid collection name
|
||||
method: query with invalid collection name
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection_name, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
|
||||
def test_query_expr_invalid_string(self, connect, collection, expr):
|
||||
"""
|
||||
target: test query with non-string expr
|
||||
method: query with non-string expr, eg 1, [] ..
|
||||
expected: raise exception
|
||||
"""
|
||||
connect.load_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_invalid_output_fields(self, connect, collection, fields):
|
||||
"""
|
||||
target: test query with invalid output fields
|
||||
method: query with invalid field fields
|
||||
expected: raise exception
|
||||
"""
|
||||
init_data(connect, collection)
|
||||
connect.load_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr, output_fields=[fields])
|
||||
|
|
Loading…
Reference in New Issue