Update query test cases (#15550)

Signed-off-by: Binbin Lv <binbin.lv@zilliz.com>
pull/15537/head
binbin 2022-02-14 13:59:48 +08:00 committed by GitHub
parent f74a9c25ad
commit bbd888fb88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 34 additions and 144 deletions

View File

@ -30,6 +30,18 @@ class TestQueryParams(TestcaseBase):
test Query interface
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
"""
@pytest.mark.tags(CaseLabel.L2)
def test_query_invalid(self):
"""
target: test query with invalid term expression
method: query with invalid term expr
expected: raise exception
"""
collection_w, entities = self.init_collection_general(prefix, insert_data=True)[0:2]
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
error = {ct.err_code: 1, ct.err_msg: "unexpected token Identifier"}
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L0)
def test_query(self):
@ -46,6 +58,28 @@ class TestQueryParams(TestcaseBase):
res = vectors[0].iloc[0:pos, :1].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1)
def test_query_no_collection(self):
"""
target: test the scenario which query the non-exist collection
method: 1. create collection
2. drop collection
3. query the dropped collection
expected: raise exception and report the error
"""
# 1. initialize without data
collection_w = self.init_collection_general(prefix)[0]
# 2. Drop collection
log.info("test_query_no_collection: drop collection %s" % collection_w.name)
collection_w.drop()
# 3. Search without collection
log.info("test_query_no_collection: query without collection ")
collection_w.query(default_term_expr,
check_task=CheckTasks.err_res,
check_items={"err_code": 1,
"err_msg": "DescribeCollection failed: "
"can't find collection: %s" % collection_w.name})
@pytest.mark.tags(CaseLabel.L2)
def test_query_empty_collection(self):
"""
@ -1156,147 +1190,3 @@ class TestQueryOperation(TestcaseBase):
collection_w.query(f'{ct.default_int64_field_name} in [1]',
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
"""
******************************************************************
The following classes are copied from pymilvus test
******************************************************************
"""
def init_data(connect, collection, nb=ut.default_nb, partition_names=None, auto_id=True):
"""
Generate entities and add it in collection
"""
if nb == ct.default_nb:
insert_entities = default_entities
else:
insert_entities = ut.gen_entities(nb, is_normal=True)
if partition_names is None:
if auto_id:
res = connect.insert(collection, insert_entities)
else:
res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
else:
if auto_id:
res = connect.insert(collection, insert_entities, partition_name=partition_names)
else:
res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)],
partition_name=partition_names)
connect.flush([collection])
ids = res.primary_keys
return insert_entities, ids
class TestQueryBase:
"""
test Query interface
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
"""
@pytest.fixture(
scope="function",
params=ut.gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.mark.tags(CaseLabel.L2)
def test_query_invalid(self, connect, collection):
"""
target: test query
method: query with term expr
expected: verify query result
"""
entities, ids = init_data(connect, collection)
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
with pytest.raises(Exception):
connect.query(collection, term_expr)
@pytest.mark.tags(CaseLabel.L0)
def test_query_valid(self, connect, collection):
"""
target: test query
method: query with term expr
expected: verify query result
"""
entities, ids = init_data(connect, collection)
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(CaseLabel.L2)
def test_query_collection_not_existed(self, connect):
"""
target: test query not existed collection
method: query not existed collection
expected: raise exception
"""
collection = "not_exist"
with pytest.raises(Exception):
connect.query(collection, default_term_expr)
@pytest.mark.tags(CaseLabel.L2)
def test_query_invalid_collection_name(self, connect, get_collection_name):
"""
target: test query with invalid collection name
method: query with invalid collection name
expected: raise exception
"""
collection_name = get_collection_name
with pytest.raises(Exception):
connect.query(collection_name, default_term_expr)
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
def test_query_expr_invalid_string(self, connect, collection, expr):
"""
target: test query with non-string expr
method: query with non-string expr, eg 1, [] ..
expected: raise exception
"""
connect.load_collection(collection)
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
@pytest.mark.tags(CaseLabel.L2)
def test_query_invalid_output_fields(self, connect, collection, fields):
"""
target: test query with invalid output fields
method: query with invalid field fields
expected: raise exception
"""
init_data(connect, collection)
connect.load_collection(collection)
with pytest.raises(Exception):
connect.query(collection, default_term_expr, output_fields=[fields])
@pytest.mark.tags(CaseLabel.L0)
def test_query_partition(self, connect, collection):
"""
target: test query on partition
method: create a partition and query
expected: verify query result
"""
connect.create_partition(collection, ut.default_tag)
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
assert len(ids) == ut.default_nb
connect.load_partitions(collection, [ut.default_tag])
term_expr = f'{default_int_field_name} in {[i for i in range(default_pos)]}'
res = connect.query(collection, term_expr, partition_names=[ut.default_tag], output_fields=["*", "%"])
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])