mirror of https://github.com/milvus-io/milvus.git
669 lines
29 KiB
Python
669 lines
29 KiB
Python
import pdb
|
|
import logging
|
|
|
|
import pytest
|
|
from pymilvus import DataType
|
|
|
|
import utils as ut
|
|
|
|
default_entities = ut.gen_entities(ut.default_nb, is_normal=True)
|
|
raw_vectors, default_binary_entities = ut.gen_binary_entities(ut.default_nb)
|
|
default_int_field_name = "int64"
|
|
default_float_field_name = "float"
|
|
default_pos = 5
|
|
default_term_expr = f'{default_int_field_name} in {[i for i in range(default_pos)]}'
|
|
|
|
|
|
def init_data(connect, collection, nb=ut.default_nb, partition_names=None, auto_id=True):
|
|
"""
|
|
Generate entities and add it in collection
|
|
"""
|
|
if nb == 3000:
|
|
insert_entities = default_entities
|
|
else:
|
|
insert_entities = ut.gen_entities(nb, is_normal=True)
|
|
if partition_names is None:
|
|
if auto_id:
|
|
ids = connect.insert(collection, insert_entities)
|
|
else:
|
|
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
|
|
else:
|
|
if auto_id:
|
|
ids = connect.insert(collection, insert_entities, partition_name=partition_names)
|
|
else:
|
|
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)],
|
|
partition_name=partition_names)
|
|
connect.flush([collection])
|
|
return insert_entities, ids
|
|
|
|
|
|
def init_binary_data(connect, collection, nb=3000, insert=True, partition_names=None):
|
|
"""
|
|
Generate entities and add it in collection
|
|
"""
|
|
ids = []
|
|
# global binary_entities
|
|
global raw_vectors
|
|
if nb == 3000:
|
|
insert_entities = default_binary_entities
|
|
insert_raw_vectors = raw_vectors
|
|
else:
|
|
insert_raw_vectors, insert_entities = ut.gen_binary_entities(nb)
|
|
if insert is True:
|
|
if partition_names is None:
|
|
ids = connect.insert(collection, insert_entities)
|
|
else:
|
|
ids = connect.insert(collection, insert_entities, partition_name=partition_names)
|
|
connect.flush([collection])
|
|
return insert_raw_vectors, insert_entities, ids
|
|
|
|
|
|
class TestQueryBase:
|
|
"""
|
|
test Query interface
|
|
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
|
"""
|
|
|
|
@pytest.fixture(
|
|
scope="function",
|
|
params=ut.gen_invalid_strs()
|
|
)
|
|
def get_collection_name(self, request):
|
|
yield request.param
|
|
|
|
@pytest.fixture(
|
|
scope="function",
|
|
params=ut.gen_simple_index()
|
|
)
|
|
def get_simple_index(self, request, connect):
|
|
return request.param
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_invalid(self, connect, collection):
|
|
"""
|
|
target: test query
|
|
method: query with term expr
|
|
expected: verify query result
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_int_field_name} in {entities[:default_pos]}'
|
|
with pytest.raises(Exception):
|
|
res = connect.query(collection, term_expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_valid(self, connect, collection):
|
|
"""
|
|
target: test query
|
|
method: query with term expr
|
|
expected: verify query result
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
|
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
|
assert len(res) == default_pos
|
|
for _id, index in enumerate(ids[:default_pos]):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
|
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
|
|
assert len(res) == default_pos
|
|
for _id, index in enumerate(ids[:default_pos]):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_collection_not_existed(self, connect):
|
|
"""
|
|
target: test query not existed collection
|
|
method: query not existed collection
|
|
expected: raise exception
|
|
"""
|
|
collection = "not_exist"
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_without_connect(self, dis_connect, collection):
|
|
"""
|
|
target: test query without connection
|
|
method: close connect and query
|
|
expected: raise exception
|
|
"""
|
|
with pytest.raises(Exception):
|
|
dis_connect.query(collection, default_term_expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_invalid_collection_name(self, connect, get_collection_name):
|
|
"""
|
|
target: test query with invalid collection name
|
|
method: query with invalid collection name
|
|
expected: raise exception
|
|
"""
|
|
collection_name = get_collection_name
|
|
with pytest.raises(Exception):
|
|
connect.query(collection_name, default_term_expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_after_index(self, connect, collection, get_simple_index):
|
|
"""
|
|
target: test query after creating index
|
|
method: query after index
|
|
expected: query result is correct
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index)
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
|
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
|
logging.getLogger().info(res)
|
|
assert len(res) == default_pos
|
|
for _id, index in enumerate(ids[:default_pos]):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
|
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[-1]["values"][index])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.L2)
|
|
def test_query_after_search(self, connect, collection):
|
|
"""
|
|
target: test query after search
|
|
method: query after search
|
|
expected: query result is correct
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
top_k = 10
|
|
nq = 2
|
|
query, _ = ut.gen_query_vectors(ut.default_float_vec_field_name, entities, top_k=top_k, nq=nq)
|
|
connect.load_collection(collection)
|
|
search_res = connect.search(collection, query)
|
|
assert len(search_res) == nq
|
|
assert len(search_res[0]) == top_k
|
|
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
|
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
|
logging.getLogger().info(res)
|
|
assert len(res) == default_pos
|
|
for _id, index in enumerate(ids[:default_pos]):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
|
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_empty_collection(self, connect, collection):
|
|
"""
|
|
target: test query empty collection
|
|
method: query on a empty collection
|
|
expected: todo
|
|
"""
|
|
connect.load_collection(collection)
|
|
res = connect.query(collection, default_term_expr)
|
|
logging.getLogger().info(res)
|
|
assert len(res) == 0
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_without_loading(self, connect, collection):
|
|
"""
|
|
target: test query without loading
|
|
method: no loading before query
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_collection_not_primary_key(self, connect, collection):
|
|
"""
|
|
target: test query on collection that not on the primary field
|
|
method: 1.create collection with auto_id=True 2.query on the other field
|
|
expected: exception raised
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_float_field_name} in {entities[:default_pos]}'
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, term_expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_none(self, connect, collection):
|
|
"""
|
|
target: test query with none expr
|
|
method: query with expr None
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, None)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
|
|
def test_query_expr_invalid_string(self, connect, collection, expr):
|
|
"""
|
|
target: test query with non-string expr
|
|
method: query with non-string expr, eg 1, [] ..
|
|
expected: raise exception
|
|
"""
|
|
# entities, ids = init_data(connect, collection)
|
|
# assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_not_existed_field(self, connect, collection):
|
|
"""
|
|
target: test query with not existed field
|
|
method: query by term expr with fake field
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = 'field in [1, 2]'
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, term_expr)
|
|
|
|
@pytest.mark.parametrize("expr", [f'{default_int_field_name} inn [1, 2]',
|
|
f'{default_int_field_name} not in [1, 2]',
|
|
f'{default_int_field_name} in not [1, 2]'])
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_wrong_term_keyword(self, connect, collection, expr):
|
|
"""
|
|
target: test query with wrong term expr keyword
|
|
method: query with wrong keyword term expr
|
|
expected: raise exception
|
|
"""
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, expr)
|
|
|
|
@pytest.mark.parametrize("expr", [f'{default_int_field_name} in 1',
|
|
f'{default_int_field_name} in "in"',
|
|
f'{default_int_field_name} in (mn)'])
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_non_array_term(self, connect, collection, expr):
|
|
"""
|
|
target: test query with non-array term expr
|
|
method: query with non-array term expr
|
|
expected: raise exception
|
|
"""
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_empty_term_array(self, connect, collection):
|
|
"""
|
|
target: test query with empty array term expr
|
|
method: query with empty term expr
|
|
expected: todo
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_int_field_name} in []'
|
|
res = connect.query(collection, term_expr)
|
|
assert len(res) == 0
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_single_term_array(self, connect, collection):
|
|
"""
|
|
target: test query with single array term expr
|
|
method: query with single array value
|
|
expected: query result is one entity
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_int_field_name} in [0]'
|
|
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
|
assert len(res) == 1
|
|
assert res[0][default_int_field_name] == entities[0]["values"][0]
|
|
assert res[0][default_float_field_name] == entities[1]["values"][0]
|
|
ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
|
|
|
|
@pytest.mark.xfail(reason="#6072")
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_binary_expr_single_term_array(self, connect, binary_collection):
|
|
"""
|
|
target: test query with single array term expr
|
|
method: query with single array value
|
|
expected: query result is one entity
|
|
"""
|
|
_, binary_entities, ids = init_binary_data(connect, binary_collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(binary_collection)
|
|
term_expr = f'{default_int_field_name} in [0]'
|
|
res = connect.query(binary_collection, term_expr, output_fields=["*", "%"])
|
|
assert len(res) == 1
|
|
assert res[0][default_int_field_name] == binary_entities[0]["values"][0]
|
|
assert res[1][default_float_field_name] == binary_entities[1]["values"][0]
|
|
assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_all_term_array(self, connect, collection):
|
|
"""
|
|
target: test query with all array term expr
|
|
method: query with all array value
|
|
expected: verify query result
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
term_expr = f'{default_int_field_name} in {ids}'
|
|
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
|
assert len(res) == ut.default_nb
|
|
for _id, index in enumerate(ids):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
|
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_repeated_term_array(self, connect, collection):
|
|
"""
|
|
target: test query with repeated term array on primary field with unique value
|
|
method: query with repeated array value
|
|
expected: verify query result
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
int_values = [0, 0]
|
|
term_expr = f'{default_int_field_name} in {int_values}'
|
|
res = connect.query(collection, term_expr)
|
|
assert len(res) == 2
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_inconstant_term_array(self, connect, collection):
|
|
"""
|
|
target: test query with term expr that field and array are inconsistent
|
|
method: query with int field and float values
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
expr = f'{default_int_field_name} in [1.0, 2.0]'
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_mix_term_array(self, connect, collection):
|
|
"""
|
|
target: test query with mix type value expr
|
|
method: query with term expr that has int and float type value
|
|
expected: todo
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
expr = f'{default_int_field_name} in [1, 2.]'
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, expr)
|
|
|
|
@pytest.mark.parametrize("constant", [[1], (), {}])
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_expr_non_constant_array_term(self, connect, collection, constant):
|
|
"""
|
|
target: test query with non-constant array term expr
|
|
method: query with non-constant array expr
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
expr = f'{default_int_field_name} in [{constant}]'
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, expr)
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_output_field_empty(self, connect, collection):
|
|
"""
|
|
target: test query with none output field
|
|
method: query with output field=None
|
|
expected: return all fields
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
res = connect.query(collection, default_term_expr, output_fields=[])
|
|
assert default_int_field_name in res[0].keys()
|
|
assert default_float_field_name not in res[0].keys()
|
|
assert ut.default_float_vec_field_name not in res[0].keys()
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_output_one_field(self, connect, collection):
|
|
"""
|
|
target: test query with output one field
|
|
method: query with output one field
|
|
expected: return one field
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
res = connect.query(collection, default_term_expr, output_fields=[default_int_field_name])
|
|
assert default_int_field_name in res[0].keys()
|
|
assert len(res[0].keys()) == 1
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_output_all_fields(self, connect, collection):
|
|
"""
|
|
target: test query with none output field
|
|
method: query with output field=None
|
|
expected: return all fields
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
# fields = [default_int_field_name, default_float_field_name, ut.default_float_vec_field_name]
|
|
fields = [default_int_field_name, default_float_field_name]
|
|
res = connect.query(collection, default_term_expr, output_fields=fields)
|
|
for field in fields:
|
|
assert field in res[0].keys()
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_output_not_existed_field(self, connect, collection):
|
|
"""
|
|
target: test query output not existed field
|
|
method: query with not existed output field
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr, output_fields=["int"])
|
|
|
|
# @pytest.mark.xfail(reason="#6074")
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_output_part_not_existed_field(self, connect, collection):
|
|
"""
|
|
target: test query output part not existed field
|
|
method: query with part not existed field
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr, output_fields=[default_int_field_name, "int"])
|
|
|
|
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_invalid_output_fields(self, connect, collection, fields):
|
|
"""
|
|
target: test query with invalid output fields
|
|
method: query with invalid field fields
|
|
expected: raise exception
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
connect.load_collection(collection)
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr, output_fields=[fields])
|
|
|
|
|
|
class TestQueryPartition:
|
|
"""
|
|
test Query interface
|
|
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
|
"""
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_partition(self, connect, collection):
|
|
"""
|
|
target: test query on partition
|
|
method: create a partition and query
|
|
expected: verify query result
|
|
"""
|
|
connect.create_partition(collection, ut.default_tag)
|
|
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_partitions(collection, [ut.default_tag])
|
|
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag], output_fields=["*", "%"])
|
|
for _id, index in enumerate(ids[:default_pos]):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
|
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_partition_without_loading(self, connect, collection):
|
|
"""
|
|
target: test query on partition without loading
|
|
method: query on partition and no loading
|
|
expected: raise exception
|
|
"""
|
|
connect.create_partition(collection, ut.default_tag)
|
|
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
|
|
assert len(ids) == ut.default_nb
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_default_partition(self, connect, collection):
|
|
"""
|
|
target: test query on default partition
|
|
method: query on default partition
|
|
expected: verify query result
|
|
"""
|
|
entities, ids = init_data(connect, collection)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_collection(collection)
|
|
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name], output_fields=["*", "%"])
|
|
for _id, index in enumerate(ids[:default_pos]):
|
|
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
|
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
|
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.L2)
|
|
def test_query_empty_partition(self, connect, collection):
|
|
"""
|
|
target: test query on empty partition
|
|
method: query on a empty collection
|
|
expected: empty query result
|
|
"""
|
|
connect.create_partition(collection, ut.default_tag)
|
|
connect.load_partitions(collection, [ut.default_partition_name])
|
|
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
|
|
assert len(res) == 0
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_not_existed_partition(self, connect, collection):
|
|
"""
|
|
target: test query on a not existed partition
|
|
method: query on not existed partition
|
|
expected: raise exception
|
|
"""
|
|
connect.load_partitions(collection, [ut.default_partition_name])
|
|
tag = ut.gen_unique_str()
|
|
with pytest.raises(Exception):
|
|
connect.query(collection, default_term_expr, partition_names=[tag])
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_partition_repeatedly(self, connect, collection):
|
|
"""
|
|
target: test query repeatedly on partition
|
|
method: query on partition twice
|
|
expected: verify query result
|
|
"""
|
|
connect.create_partition(collection, ut.default_tag)
|
|
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
|
|
assert len(ids) == ut.default_nb
|
|
connect.load_partitions(collection, [ut.default_tag])
|
|
res_one = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
|
|
res_two = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
|
|
assert res_one == res_two
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_another_partition(self, connect, collection):
|
|
"""
|
|
target: test query another partition
|
|
method: 1. insert entities into two partitions
|
|
2. query on one partition and query result empty
|
|
expected: query result is empty
|
|
"""
|
|
insert_entities_into_two_partitions_in_half(connect, collection)
|
|
half = ut.default_nb // 2
|
|
term_expr = f'{default_int_field_name} in [{half}]'
|
|
res = connect.query(collection, term_expr, partition_names=[ut.default_tag])
|
|
assert len(res) == 0
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_multi_partitions_multi_results(self, connect, collection):
|
|
"""
|
|
target: test query on multi partitions and get multi results
|
|
method: 1.insert entities into two partitions
|
|
2.query on two partitions and query multi result
|
|
expected: query results from two partitions
|
|
"""
|
|
entities, entities_2 = insert_entities_into_two_partitions_in_half(connect, collection)
|
|
half = ut.default_nb // 2
|
|
term_expr = f'{default_int_field_name} in [{half - 1}]'
|
|
res = connect.query(collection, term_expr, partition_names=[ut.default_tag, ut.default_partition_name])
|
|
assert len(res) == 1
|
|
assert res[0][default_int_field_name] == entities[0]["values"][-1]
|
|
term_expr = f'{default_int_field_name} in [{half}]'
|
|
res = connect.query(collection, term_expr, partition_names=[ut.default_tag, ut.default_partition_name])
|
|
assert len(res) == 1
|
|
assert res[0][default_int_field_name] == entities_2[0]["values"][0]
|
|
|
|
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
|
def test_query_multi_partitions_single_result(self, connect, collection):
|
|
"""
|
|
target: test query on multi partitions and get single result
|
|
method: 1.insert into two partitions
|
|
2.query on two partitions and query single result
|
|
expected: query from two partitions and get single result
|
|
"""
|
|
entities, entities_2 = insert_entities_into_two_partitions_in_half(connect, collection)
|
|
half = ut.default_nb // 2
|
|
term_expr = f'{default_int_field_name} in [{half}]'
|
|
res = connect.query(collection, term_expr, partition_names=[ut.default_tag, ut.default_partition_name])
|
|
assert len(res) == 1
|
|
assert res[0][default_int_field_name] == entities_2[0]["values"][0]
|
|
|
|
|
|
def insert_entities_into_two_partitions_in_half(connect, collection):
|
|
"""
|
|
insert default entities into two partitions(default_tag and _default) in half(int64 and float fields values)
|
|
:param connect: milvus connect
|
|
:param collection: milvus created collection
|
|
:return: entities of default_tag and entities_2 of _default
|
|
"""
|
|
connect.create_partition(collection, ut.default_tag)
|
|
half = ut.default_nb // 2
|
|
entities, _ = init_data(connect, collection, nb=half, partition_names=ut.default_tag)
|
|
vectors = ut.gen_vectors(half, ut.default_dim)
|
|
entities_2 = [
|
|
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(half, ut.default_nb)]},
|
|
{"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(half, ut.default_nb)]},
|
|
{"name": ut.default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
|
|
]
|
|
connect.insert(collection, entities_2)
|
|
connect.flush([collection])
|
|
connect.load_collection(collection)
|
|
return entities, entities_2
|