Support output vectors in query (#6636)

* support output vectors in query

Signed-off-by: zhenwu <zhenxiang.li@zilliz.com>

* add query case

Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com>

* add query case

Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com>
pull/6910/head
del-zhenwu 2021-07-30 16:27:22 +08:00 committed by GitHub
parent ac50c5dd89
commit ec22185ff4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 62 additions and 32 deletions

View File

@ -78,6 +78,7 @@ class TestQueryBase:
def get_simple_index(self, request, connect):
return request.param
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_invalid(self, connect, collection):
"""
target: test query
@ -91,6 +92,7 @@ class TestQueryBase:
with pytest.raises(Exception):
res = connect.query(collection, term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_valid(self, connect, collection):
"""
target: test query
@ -101,35 +103,39 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_collection_not_existed(self, connect):
"""
target: test query on a non-existent collection
method: query a non-existent collection
expected: raise exception
"""
ex_msg = 'find collection'
collection = "not_exist"
with pytest.raises(Exception, match=ex_msg):
with pytest.raises(Exception):
connect.query(collection, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_without_connect(self, dis_connect, collection):
"""
target: test query without connection
method: close connect and query
expected: raise exception
"""
ex_msg = 'NoneType'
with pytest.raises(Exception, match=ex_msg):
with pytest.raises(Exception):
dis_connect.query(collection, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_invalid_collection_name(self, connect, get_collection_name):
"""
target: test query with invalid collection name
@ -140,6 +146,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection_name, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_after_index(self, connect, collection, get_simple_index):
"""
target: test query after creating index
@ -151,13 +158,13 @@ class TestQueryBase:
connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index)
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
logging.getLogger().info(res)
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# # ut.assert_equal_vector(res[i][ut.default_float_vec_field_name], entities[-1]["values"][i])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[-1]["values"][index])
def test_query_after_search(self, connect, collection):
"""
@ -175,15 +182,15 @@ class TestQueryBase:
assert len(search_res) == nq
assert len(search_res[0]) == top_k
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
logging.getLogger().info(res)
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_empty_collection(self, connect, collection):
"""
target: test query empty collection
@ -195,6 +202,7 @@ class TestQueryBase:
logging.getLogger().info(res)
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_without_loading(self, connect, collection):
"""
target: test query without loading
@ -206,6 +214,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_collection_not_primary_key(self, connect, collection):
"""
target: test query on collection that is not on the primary field
@ -219,6 +228,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_none(self, connect, collection):
"""
target: test query with none expr
@ -231,6 +241,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, None)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
def test_query_expr_invalid_string(self, connect, collection, expr):
"""
@ -244,6 +255,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_not_existed_field(self, connect, collection):
"""
target: test query with not existed field
@ -260,6 +272,7 @@ class TestQueryBase:
@pytest.mark.parametrize("expr", [f'{default_int_field_name} inn [1, 2]',
f'{default_int_field_name} not in [1, 2]',
f'{default_int_field_name} in not [1, 2]'])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_wrong_term_keyword(self, connect, collection, expr):
"""
target: test query with wrong term expr keyword
@ -273,6 +286,7 @@ class TestQueryBase:
@pytest.mark.parametrize("expr", [f'{default_int_field_name} in 1',
f'{default_int_field_name} in "in"',
f'{default_int_field_name} in (mn)'])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_non_array_term(self, connect, collection, expr):
"""
target: test query with non-array term expr
@ -283,6 +297,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_empty_term_array(self, connect, collection):
"""
target: test query with empty array term expr
@ -296,6 +311,7 @@ class TestQueryBase:
res = connect.query(collection, term_expr)
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_single_term_array(self, connect, collection):
"""
target: test query with single array term expr
@ -306,14 +322,14 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in [0]'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == 1
assert res[0][default_int_field_name] == entities[0]["values"][0]
assert res[0][default_float_field_name] == entities[1]["values"][0]
# not support
# ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
@pytest.mark.xfail(reason="#6072")
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_binary_expr_single_term_array(self, connect, binary_collection):
"""
target: test query with single array term expr
@ -324,14 +340,14 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(binary_collection)
term_expr = f'{default_int_field_name} in [0]'
res = connect.query(binary_collection, term_expr)
res = connect.query(binary_collection, term_expr, output_fields=["*", "%"])
assert len(res) == 1
assert res[0][default_int_field_name] == binary_entities[0]["values"][0]
assert res[1][default_float_field_name] == binary_entities[1]["values"][0]
# not support
# assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
@pytest.mark.level(2)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_all_term_array(self, connect, collection):
"""
target: test query with all array term expr
@ -342,14 +358,14 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == ut.default_nb
for _id, index in enumerate(ids):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_repeated_term_array(self, connect, collection):
"""
target: test query with repeated term array on primary field with unique value
@ -364,6 +380,7 @@ class TestQueryBase:
res = connect.query(collection, term_expr)
assert len(res) == 2
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_inconstant_term_array(self, connect, collection):
"""
target: test query with term expr whose field and array are inconsistent
@ -377,6 +394,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_mix_term_array(self, connect, collection):
"""
target: test query with mix type value expr
@ -391,6 +409,7 @@ class TestQueryBase:
connect.query(collection, expr)
@pytest.mark.parametrize("constant", [[1], (), {}])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_non_constant_array_term(self, connect, collection, constant):
"""
target: test query with non-constant array term expr
@ -404,6 +423,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_field_empty(self, connect, collection):
"""
target: test query with none output field
@ -414,11 +434,11 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
res = connect.query(collection, default_term_expr, output_fields=[])
# not support float_vector
fields = [default_int_field_name, default_float_field_name]
for field in fields:
assert field in res[0].keys()
assert default_int_field_name in res[0].keys()
assert default_float_field_name not in res[0].keys()
assert ut.default_float_vec_field_name not in res[0].keys()
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_one_field(self, connect, collection):
"""
target: test query with output one field
@ -432,6 +452,7 @@ class TestQueryBase:
assert default_int_field_name in res[0].keys()
assert len(res[0].keys()) == 1
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_all_fields(self, connect, collection):
"""
target: test query with all output fields
@ -447,6 +468,7 @@ class TestQueryBase:
for field in fields:
assert field in res[0].keys()
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_not_existed_field(self, connect, collection):
"""
target: test query output not existed field
@ -459,6 +481,7 @@ class TestQueryBase:
connect.query(collection, default_term_expr, output_fields=["int"])
# @pytest.mark.xfail(reason="#6074")
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_part_not_existed_field(self, connect, collection):
"""
target: test query output part not existed field
@ -471,6 +494,7 @@ class TestQueryBase:
connect.query(collection, default_term_expr, output_fields=[default_int_field_name, "int"])
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_invalid_output_fields(self, connect, collection, fields):
"""
target: test query with invalid output fields
@ -488,6 +512,7 @@ class TestQueryPartition:
test Query interface
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
"""
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_partition(self, connect, collection):
"""
target: test query on partition
@ -498,13 +523,13 @@ class TestQueryPartition:
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
assert len(ids) == ut.default_nb
connect.load_partitions(collection, [ut.default_tag])
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag], output_fields=["*", "%"])
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_partition_without_loading(self, connect, collection):
"""
target: test query on partition without loading
@ -517,6 +542,7 @@ class TestQueryPartition:
with pytest.raises(Exception):
connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_default_partition(self, connect, collection):
"""
target: test query on default partition
@ -526,12 +552,11 @@ class TestQueryPartition:
entities, ids = init_data(connect, collection)
assert len(ids) == ut.default_nb
connect.load_collection(collection)
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name], output_fields=["*", "%"])
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.xfail(reason="#6075")
def test_query_empty_partition(self, connect, collection):
@ -545,6 +570,7 @@ class TestQueryPartition:
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_not_existed_partition(self, connect, collection):
"""
target: test query on a not existed partition
@ -556,6 +582,7 @@ class TestQueryPartition:
with pytest.raises(Exception):
connect.query(collection, default_term_expr, partition_names=[tag])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_partition_repeatedly(self, connect, collection):
"""
target: test query repeatedly on partition
@ -570,6 +597,7 @@ class TestQueryPartition:
res_two = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
assert res_one == res_two
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_another_partition(self, connect, collection):
"""
target: test query another partition
@ -583,6 +611,7 @@ class TestQueryPartition:
res = connect.query(collection, term_expr, partition_names=[ut.default_tag])
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_multi_partitions_multi_results(self, connect, collection):
"""
target: test query on multi partitions and get multi results
@ -601,6 +630,7 @@ class TestQueryPartition:
assert len(res) == 1
assert res[0][default_int_field_name] == entities_2[0]["values"][0]
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_multi_partitions_single_result(self, connect, collection):
"""
target: test query on multi partitions and get single result