mirror of https://github.com/milvus-io/milvus.git
Support output vectors in query (#6636)
* support output vectors in query Signed-off-by: zhenwu <zhenxiang.li@zilliz.com> * add query case Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com> * add query case Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com>pull/6910/head
parent
ac50c5dd89
commit
ec22185ff4
|
@ -78,6 +78,7 @@ class TestQueryBase:
|
|||
def get_simple_index(self, request, connect):
|
||||
return request.param
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_invalid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
|
@ -91,6 +92,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
res = connect.query(collection, term_expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_valid(self, connect, collection):
|
||||
"""
|
||||
target: test query
|
||||
|
@ -101,35 +103,39 @@ class TestQueryBase:
|
|||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
||||
res = connect.query(collection, term_expr)
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
# not support
|
||||
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_collection_not_existed(self, connect):
|
||||
"""
|
||||
target: test query not existed collection
|
||||
method: query not existed collection
|
||||
expected: raise exception
|
||||
"""
|
||||
ex_msg = 'find collection'
|
||||
collection = "not_exist"
|
||||
with pytest.raises(Exception, match=ex_msg):
|
||||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_without_connect(self, dis_connect, collection):
|
||||
"""
|
||||
target: test query without connection
|
||||
method: close connect and query
|
||||
expected: raise exception
|
||||
"""
|
||||
ex_msg = 'NoneType'
|
||||
with pytest.raises(Exception, match=ex_msg):
|
||||
with pytest.raises(Exception):
|
||||
dis_connect.query(collection, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_invalid_collection_name(self, connect, get_collection_name):
|
||||
"""
|
||||
target: test query with invalid collection name
|
||||
|
@ -140,6 +146,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection_name, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_after_index(self, connect, collection, get_simple_index):
|
||||
"""
|
||||
target: test query after creating index
|
||||
|
@ -151,13 +158,13 @@ class TestQueryBase:
|
|||
connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index)
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
||||
res = connect.query(collection, term_expr)
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
logging.getLogger().info(res)
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
# # ut.assert_equal_vector(res[i][ut.default_float_vec_field_name], entities[-1]["values"][i])
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[-1]["values"][index])
|
||||
|
||||
def test_query_after_search(self, connect, collection):
|
||||
"""
|
||||
|
@ -175,15 +182,15 @@ class TestQueryBase:
|
|||
assert len(search_res) == nq
|
||||
assert len(search_res[0]) == top_k
|
||||
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
|
||||
res = connect.query(collection, term_expr)
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
logging.getLogger().info(res)
|
||||
assert len(res) == default_pos
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
# not support
|
||||
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_empty_collection(self, connect, collection):
|
||||
"""
|
||||
target: test query empty collection
|
||||
|
@ -195,6 +202,7 @@ class TestQueryBase:
|
|||
logging.getLogger().info(res)
|
||||
assert len(res) == 0
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_without_loading(self, connect, collection):
|
||||
"""
|
||||
target: test query without loading
|
||||
|
@ -206,6 +214,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_collection_not_primary_key(self, connect, collection):
|
||||
"""
|
||||
target: test query on collection that not on the primary field
|
||||
|
@ -219,6 +228,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, term_expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_none(self, connect, collection):
|
||||
"""
|
||||
target: test query with none expr
|
||||
|
@ -231,6 +241,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, None)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
|
||||
def test_query_expr_invalid_string(self, connect, collection, expr):
|
||||
"""
|
||||
|
@ -244,6 +255,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_not_existed_field(self, connect, collection):
|
||||
"""
|
||||
target: test query with not existed field
|
||||
|
@ -260,6 +272,7 @@ class TestQueryBase:
|
|||
@pytest.mark.parametrize("expr", [f'{default_int_field_name} inn [1, 2]',
|
||||
f'{default_int_field_name} not in [1, 2]',
|
||||
f'{default_int_field_name} in not [1, 2]'])
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_wrong_term_keyword(self, connect, collection, expr):
|
||||
"""
|
||||
target: test query with wrong term expr keyword
|
||||
|
@ -273,6 +286,7 @@ class TestQueryBase:
|
|||
@pytest.mark.parametrize("expr", [f'{default_int_field_name} in 1',
|
||||
f'{default_int_field_name} in "in"',
|
||||
f'{default_int_field_name} in (mn)'])
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_non_array_term(self, connect, collection, expr):
|
||||
"""
|
||||
target: test query with non-array term expr
|
||||
|
@ -283,6 +297,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_empty_term_array(self, connect, collection):
|
||||
"""
|
||||
target: test query with empty array term expr
|
||||
|
@ -296,6 +311,7 @@ class TestQueryBase:
|
|||
res = connect.query(collection, term_expr)
|
||||
assert len(res) == 0
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_single_term_array(self, connect, collection):
|
||||
"""
|
||||
target: test query with single array term expr
|
||||
|
@ -306,14 +322,14 @@ class TestQueryBase:
|
|||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in [0]'
|
||||
res = connect.query(collection, term_expr)
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == 1
|
||||
assert res[0][default_int_field_name] == entities[0]["values"][0]
|
||||
assert res[0][default_float_field_name] == entities[1]["values"][0]
|
||||
# not support
|
||||
# ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
|
||||
ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
|
||||
|
||||
@pytest.mark.xfail(reason="#6072")
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_binary_expr_single_term_array(self, connect, binary_collection):
|
||||
"""
|
||||
target: test query with single array term expr
|
||||
|
@ -324,14 +340,14 @@ class TestQueryBase:
|
|||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(binary_collection)
|
||||
term_expr = f'{default_int_field_name} in [0]'
|
||||
res = connect.query(binary_collection, term_expr)
|
||||
res = connect.query(binary_collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == 1
|
||||
assert res[0][default_int_field_name] == binary_entities[0]["values"][0]
|
||||
assert res[1][default_float_field_name] == binary_entities[1]["values"][0]
|
||||
# not support
|
||||
# assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
|
||||
assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_all_term_array(self, connect, collection):
|
||||
"""
|
||||
target: test query with all array term expr
|
||||
|
@ -342,14 +358,14 @@ class TestQueryBase:
|
|||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
term_expr = f'{default_int_field_name} in {ids}'
|
||||
res = connect.query(collection, term_expr)
|
||||
res = connect.query(collection, term_expr, output_fields=["*", "%"])
|
||||
assert len(res) == ut.default_nb
|
||||
for _id, index in enumerate(ids):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
# not support
|
||||
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_repeated_term_array(self, connect, collection):
|
||||
"""
|
||||
target: test query with repeated term array on primary field with unique value
|
||||
|
@ -364,6 +380,7 @@ class TestQueryBase:
|
|||
res = connect.query(collection, term_expr)
|
||||
assert len(res) == 2
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_inconstant_term_array(self, connect, collection):
|
||||
"""
|
||||
target: test query with term expr that field and array are inconsistent
|
||||
|
@ -377,6 +394,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_mix_term_array(self, connect, collection):
|
||||
"""
|
||||
target: test query with mix type value expr
|
||||
|
@ -391,6 +409,7 @@ class TestQueryBase:
|
|||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.parametrize("constant", [[1], (), {}])
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_expr_non_constant_array_term(self, connect, collection, constant):
|
||||
"""
|
||||
target: test query with non-constant array term expr
|
||||
|
@ -404,6 +423,7 @@ class TestQueryBase:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, expr)
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_output_field_empty(self, connect, collection):
|
||||
"""
|
||||
target: test query with none output field
|
||||
|
@ -414,11 +434,11 @@ class TestQueryBase:
|
|||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
res = connect.query(collection, default_term_expr, output_fields=[])
|
||||
# not support float_vector
|
||||
fields = [default_int_field_name, default_float_field_name]
|
||||
for field in fields:
|
||||
assert field in res[0].keys()
|
||||
assert default_int_field_name in res[0].keys()
|
||||
assert default_float_field_name not in res[0].keys()
|
||||
assert ut.default_float_vec_field_name not in res[0].keys()
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_output_one_field(self, connect, collection):
|
||||
"""
|
||||
target: test query with output one field
|
||||
|
@ -432,6 +452,7 @@ class TestQueryBase:
|
|||
assert default_int_field_name in res[0].keys()
|
||||
assert len(res[0].keys()) == 1
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_output_all_fields(self, connect, collection):
|
||||
"""
|
||||
target: test query with none output field
|
||||
|
@ -447,6 +468,7 @@ class TestQueryBase:
|
|||
for field in fields:
|
||||
assert field in res[0].keys()
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_output_not_existed_field(self, connect, collection):
|
||||
"""
|
||||
target: test query output not existed field
|
||||
|
@ -459,6 +481,7 @@ class TestQueryBase:
|
|||
connect.query(collection, default_term_expr, output_fields=["int"])
|
||||
|
||||
# @pytest.mark.xfail(reason="#6074")
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_output_part_not_existed_field(self, connect, collection):
|
||||
"""
|
||||
target: test query output part not existed field
|
||||
|
@ -471,6 +494,7 @@ class TestQueryBase:
|
|||
connect.query(collection, default_term_expr, output_fields=[default_int_field_name, "int"])
|
||||
|
||||
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_invalid_output_fields(self, connect, collection, fields):
|
||||
"""
|
||||
target: test query with invalid output fields
|
||||
|
@ -488,6 +512,7 @@ class TestQueryPartition:
|
|||
test Query interface
|
||||
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
||||
"""
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_partition(self, connect, collection):
|
||||
"""
|
||||
target: test query on partition
|
||||
|
@ -498,13 +523,13 @@ class TestQueryPartition:
|
|||
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_partitions(collection, [ut.default_tag])
|
||||
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
|
||||
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag], output_fields=["*", "%"])
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
# not support
|
||||
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_partition_without_loading(self, connect, collection):
|
||||
"""
|
||||
target: test query on partition without loading
|
||||
|
@ -517,6 +542,7 @@ class TestQueryPartition:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_default_partition(self, connect, collection):
|
||||
"""
|
||||
target: test query on default partition
|
||||
|
@ -526,12 +552,11 @@ class TestQueryPartition:
|
|||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == ut.default_nb
|
||||
connect.load_collection(collection)
|
||||
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
|
||||
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name], output_fields=["*", "%"])
|
||||
for _id, index in enumerate(ids[:default_pos]):
|
||||
if res[index][default_int_field_name] == entities[0]["values"][index]:
|
||||
assert res[index][default_float_field_name] == entities[1]["values"][index]
|
||||
# not support
|
||||
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
|
||||
|
||||
@pytest.mark.xfail(reason="#6075")
|
||||
def test_query_empty_partition(self, connect, collection):
|
||||
|
@ -545,6 +570,7 @@ class TestQueryPartition:
|
|||
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
|
||||
assert len(res) == 0
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_not_existed_partition(self, connect, collection):
|
||||
"""
|
||||
target: test query on a not existed partition
|
||||
|
@ -556,6 +582,7 @@ class TestQueryPartition:
|
|||
with pytest.raises(Exception):
|
||||
connect.query(collection, default_term_expr, partition_names=[tag])
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_partition_repeatedly(self, connect, collection):
|
||||
"""
|
||||
target: test query repeatedly on partition
|
||||
|
@ -570,6 +597,7 @@ class TestQueryPartition:
|
|||
res_two = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
|
||||
assert res_one == res_two
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_another_partition(self, connect, collection):
|
||||
"""
|
||||
target: test query another partition
|
||||
|
@ -583,6 +611,7 @@ class TestQueryPartition:
|
|||
res = connect.query(collection, term_expr, partition_names=[ut.default_tag])
|
||||
assert len(res) == 0
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_multi_partitions_multi_results(self, connect, collection):
|
||||
"""
|
||||
target: test query on multi partitions and get multi results
|
||||
|
@ -601,6 +630,7 @@ class TestQueryPartition:
|
|||
assert len(res) == 1
|
||||
assert res[0][default_int_field_name] == entities_2[0]["values"][0]
|
||||
|
||||
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
|
||||
def test_query_multi_partitions_single_result(self, connect, collection):
|
||||
"""
|
||||
target: test query on multi partitions and get single result
|
||||
|
|
Loading…
Reference in New Issue