mirror of https://github.com/milvus-io/milvus.git
parent
f74a9c25ad
commit
bbd888fb88
|
@ -30,6 +30,18 @@ class TestQueryParams(TestcaseBase):
|
|||
test Query interface
|
||||
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
||||
"""
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_invalid(self):
|
||||
"""
|
||||
target: test query with invalid term expression
|
||||
method: query with invalid term expr
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, entities = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
|
||||
error = {ct.err_code: 1, ct.err_msg: "unexpected token Identifier"}
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query(self):
|
||||
|
@ -46,6 +58,28 @@ class TestQueryParams(TestcaseBase):
|
|||
res = vectors[0].iloc[0:pos, :1].to_dict('records')
|
||||
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_no_collection(self):
|
||||
"""
|
||||
target: test the scenario which query the non-exist collection
|
||||
method: 1. create collection
|
||||
2. drop collection
|
||||
3. query the dropped collection
|
||||
expected: raise exception and report the error
|
||||
"""
|
||||
# 1. initialize without data
|
||||
collection_w = self.init_collection_general(prefix)[0]
|
||||
# 2. Drop collection
|
||||
log.info("test_query_no_collection: drop collection %s" % collection_w.name)
|
||||
collection_w.drop()
|
||||
# 3. Search without collection
|
||||
log.info("test_query_no_collection: query without collection ")
|
||||
collection_w.query(default_term_expr,
|
||||
check_task=CheckTasks.err_res,
|
||||
check_items={"err_code": 1,
|
||||
"err_msg": "DescribeCollection failed: "
|
||||
"can't find collection: %s" % collection_w.name})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_empty_collection(self):
|
||||
"""
|
||||
|
@ -1156,147 +1190,3 @@ class TestQueryOperation(TestcaseBase):
|
|||
collection_w.query(f'{ct.default_int64_field_name} in [1]',
|
||||
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following classes are copied from pymilvus test
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
|
||||
def init_data(connect, collection, nb=ut.default_nb, partition_names=None, auto_id=True):
|
||||
"""
|
||||
Generate entities and add it in collection
|
||||
"""
|
||||
if nb == ct.default_nb:
|
||||
insert_entities = default_entities
|
||||
else:
|
||||
insert_entities = ut.gen_entities(nb, is_normal=True)
|
||||
if partition_names is None:
|
||||
if auto_id:
|
||||
res = connect.insert(collection, insert_entities)
|
||||
else:
|
||||
res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
|
||||
else:
|
||||
if auto_id:
|
||||
res = connect.insert(collection, insert_entities, partition_name=partition_names)
|
||||
else:
|
||||
res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)],
|
||||
partition_name=partition_names)
|
||||
connect.flush([collection])
|
||||
ids = res.primary_keys
|
||||
return insert_entities, ids
|
||||
|
||||
|
||||
class TestQueryBase:
|
||||
"""
|
||||
test Query interface
|
||||
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=ut.gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_invalid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
method: query with term expr
|
||||
expected: verify query result
|
||||
"""
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_valid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
method: query with term expr
|
||||
expected: verify query result
|
||||
"""
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_collection_not_existed(self, connect):
|
||||
"""
|
||||
target: test query not existed collection
|
||||
method: query not existed collection
|
||||
expected: raise exception
|
||||
"""
|
||||
collection = "not_exist"
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_invalid_collection_name(self, connect, get_collection_name):
|
||||
"""
|
||||
target: test query with invalid collection name
|
||||
method: query with invalid collection name
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection_name, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
|
||||
def test_query_expr_invalid_string(self, connect, collection, expr):
|
||||
"""
|
||||
target: test query with non-string expr
|
||||
method: query with non-string expr, eg 1, [] ..
|
||||
expected: raise exception
|
||||
"""
|
||||
connect.load_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_invalid_output_fields(self, connect, collection, fields):
|
||||
"""
|
||||
target: test query with invalid output fields
|
||||
method: query with invalid field fields
|
||||
expected: raise exception
|
||||
"""
|
||||
init_data(connect, collection)
|
||||
connect.load_collection(collection)
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr, output_fields=[fields])
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_partition(self, connect, collection):
|
||||
"""
|
||||
target: test query on partition
|
||||
method: create a partition and query
|
||||
expected: verify query result
|
||||
"""
|
||||
connect.create_partition(collection, ut.default_tag)
|
||||
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_partitions(collection, [ut.default_tag])
|
||||
term_expr = f'{default_int_field_name} in {[i for i in range(default_pos)]}'
|
||||
res = connect.query(collection, term_expr, partition_names=[ut.default_tag], output_fields=["*", "%"])
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
|
Loading…
Reference in New Issue