Add search tests with tag 0331

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/4973/head^2
bigsheeper 2021-03-03 11:32:33 +08:00 committed by yefu.chen
parent f064e77f21
commit f3bdaa2ec6
1 changed files with 81 additions and 0 deletions

View File

@ -145,6 +145,7 @@ class TestSearchBase:
def get_nq(self, request):
yield request.param
@pytest.mark.tag("0331")
def test_search_flat(self, connect, collection, get_top_k, get_nq):
'''
target: test basic search function, all the search params is correct, change top-k value
@ -165,6 +166,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_search_flat_top_k(self, connect, collection, get_nq):
'''
target: test basic search function, all the search params is correct, change top-k value
@ -185,6 +187,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# @pytest.mark.skip("r0.3-test")
def test_search_field(self, connect, collection, get_top_k, get_nq):
'''
target: test basic search function, all the search params is correct, change top-k value
@ -255,6 +258,7 @@ class TestSearchBase:
assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
# TODO:
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -283,6 +287,7 @@ class TestSearchBase:
assert res[0]._distances[0] < epsilon
assert check_id_result(res[0], ids[0])
# @pytest.mark.skip("r0.3-test")
def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index):
'''
target: test search with different metric_type
@ -302,6 +307,7 @@ class TestSearchBase:
assert len(res[0]) == default_top_k
assert res[0]._distances[0] > res[0]._distances[default_top_k - 1]
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_index_empty_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -335,6 +341,7 @@ class TestSearchBase:
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res[0]) == 0
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -365,6 +372,7 @@ class TestSearchBase:
assert res[0]._distances[0] < epsilon
assert check_id_result(res[0], ids[0])
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_index_partition_not_existed(self, connect, collection, get_top_k, get_nq):
'''
@ -386,6 +394,7 @@ class TestSearchBase:
assert len(res) == nq
assert len(res[0]) == 0
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
@ -421,6 +430,7 @@ class TestSearchBase:
assert res[1]._distances[0] > epsilon
connect.release_collection(collection)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
'''
@ -459,6 +469,7 @@ class TestSearchBase:
#
# test for ip metric
#
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -476,6 +487,7 @@ class TestSearchBase:
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
assert check_id_result(res[0], ids[0])
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -501,6 +513,7 @@ class TestSearchBase:
assert check_id_result(res[0], ids[0])
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_search_ip_index_empty_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
@ -534,6 +547,7 @@ class TestSearchBase:
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res[0]) == 0
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
@ -567,6 +581,7 @@ class TestSearchBase:
# TODO:
# assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_without_connect(self, dis_connect, collection):
'''
@ -577,6 +592,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = dis_connect.search(collection, default_query)
@pytest.mark.tag("0331")
def test_search_collection_not_existed(self, connect):
'''
target: search collection not existed
@ -587,6 +603,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = connect.search(collection_name, default_query)
@pytest.mark.tag("0331")
def test_search_distance_l2(self, connect, collection):
'''
target: search collection, and check the result: distance
@ -607,6 +624,7 @@ class TestSearchBase:
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
# TODO
@pytest.mark.tag("0331")
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
@ -637,6 +655,7 @@ class TestSearchBase:
# TODO:
# assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_ip(self, connect, collection):
'''
@ -659,6 +678,7 @@ class TestSearchBase:
res = connect.search(collection, query)
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
@pytest.mark.tag("0331")
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
@ -692,6 +712,7 @@ class TestSearchBase:
# TODO:
# assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
@pytest.mark.tag("0331")
def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
'''
target: search binary_collection, and check the result: distance
@ -708,6 +729,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_binary_flat_with_L2(self, connect, binary_collection):
'''
@ -722,6 +744,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
connect.search(binary_collection, query)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_hamming_flat_index(self, connect, binary_collection):
'''
@ -740,6 +763,7 @@ class TestSearchBase:
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
# TODO
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index(self, connect, binary_collection):
'''
@ -758,6 +782,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
'''
@ -777,6 +802,7 @@ class TestSearchBase:
assert res[1][0].distance <= epsilon
assert res[1][0].id == ids[1]
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
'''
@ -795,6 +821,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
'''
@ -816,6 +843,7 @@ class TestSearchBase:
assert res[1][0].id in ids
assert res[1][0].distance <= epsilon
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
'''
@ -833,6 +861,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon
@pytest.mark.tag("0331")
@pytest.mark.level(2)
@pytest.mark.timeout(30)
def test_search_concurrent_multithreads(self, connect, args):
@ -868,6 +897,7 @@ class TestSearchBase:
for t in threads:
t.join()
@pytest.mark.tag("0331")
@pytest.mark.level(2)
@pytest.mark.timeout(30)
def test_search_concurrent_multithreads_single_connection(self, connect, args):
@ -902,6 +932,7 @@ class TestSearchBase:
for t in threads:
t.join()
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_multi_collections(self, connect, args):
'''
@ -926,6 +957,7 @@ class TestSearchBase:
assert res[i]._distances[0] < epsilon
assert res[i]._distances[1] > epsilon
# @pytest.mark.skip("r0.3-test")
def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
"""
target: test search with field, and let return entities less than topk
@ -955,6 +987,7 @@ class TestSearchDSL(object):
******************************************************************
"""
@pytest.mark.tag("0331")
def test_query_no_must(self, connect, collection):
'''
method: build query without must expr
@ -965,6 +998,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_no_vector_term_only(self, connect, collection):
'''
method: build query without vector only term
@ -978,6 +1012,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_no_vector_range_only(self, connect, collection):
'''
method: build query without vector only range
@ -991,6 +1026,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_vector_only(self, connect, collection):
entities, ids = init_data(connect, collection)
connect.load_collection(collection)
@ -998,6 +1034,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
@pytest.mark.tag("0331")
def test_query_wrong_format(self, connect, collection):
'''
method: build query without must expr, with wrong expr name
@ -1011,6 +1048,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_empty(self, connect, collection):
'''
method: search with empty query
@ -1026,6 +1064,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_term_value_not_in(self, connect, collection):
'''
@ -1043,6 +1082,7 @@ class TestSearchDSL(object):
# TODO:
# TODO:
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_term_value_all_in(self, connect, collection):
'''
@ -1059,6 +1099,7 @@ class TestSearchDSL(object):
# TODO:
# TODO:
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_term_values_not_in(self, connect, collection):
'''
@ -1075,6 +1116,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO:
# @pytest.mark.skip("r0.3-test")
def test_query_term_values_all_in(self, connect, collection):
'''
method: build query with vector and term expr, with all term can be filtered
@ -1094,6 +1136,7 @@ class TestSearchDSL(object):
assert result.id in ids[:limit]
# TODO:
# @pytest.mark.skip("r0.3-test")
def test_query_term_values_parts_in(self, connect, collection):
'''
method: build query with vector and term expr, with parts of term can be filtered
@ -1111,6 +1154,7 @@ class TestSearchDSL(object):
# TODO:
# TODO:
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_term_values_repeat(self, connect, collection):
'''
@ -1128,6 +1172,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 1
# TODO:
@pytest.mark.tag("0331")
def test_query_term_value_empty(self, connect, collection):
'''
method: build query with term value empty
@ -1140,6 +1185,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
@pytest.mark.tag("0331")
def test_query_complex_dsl(self, connect, collection):
'''
method: query with complicated dsl
@ -1163,6 +1209,7 @@ class TestSearchDSL(object):
"""
# TODO
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_term_key_error(self, connect, collection):
'''
@ -1182,6 +1229,7 @@ class TestSearchDSL(object):
def get_invalid_term(self, request):
return request.param
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
'''
@ -1196,6 +1244,7 @@ class TestSearchDSL(object):
res = connect.search(collection, query)
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_term_field_named_term(self, connect, collection):
'''
@ -1223,6 +1272,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
connect.drop_collection(collection_term)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_term_one_field_not_existed(self, connect, collection):
'''
@ -1244,6 +1294,7 @@ class TestSearchDSL(object):
"""
# TODO
@pytest.mark.tag("0331")
def test_query_range_key_error(self, connect, collection):
'''
method: build query with range key error
@ -1263,6 +1314,7 @@ class TestSearchDSL(object):
return request.param
# TODO
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
'''
@ -1276,6 +1328,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_range_string_ranges(self, connect, collection):
'''
@ -1290,6 +1343,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_range_invalid_ranges(self, connect, collection):
'''
@ -1312,6 +1366,7 @@ class TestSearchDSL(object):
def get_valid_ranges(self, request):
return request.param
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
'''
@ -1328,6 +1383,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
@pytest.mark.tag("0331")
def test_query_range_one_field_not_existed(self, connect, collection):
'''
method: build query with two fields ranges, one of fields not existed
@ -1348,6 +1404,7 @@ class TestSearchDSL(object):
"""
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_multi_term_has_common(self, connect, collection):
'''
@ -1365,6 +1422,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_multi_term_no_common(self, connect, collection):
'''
@ -1382,6 +1440,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO
# @pytest.mark.skip("r0.3-test")
def test_query_multi_term_different_fields(self, connect, collection):
'''
method: build query with multi range with same field, and ranges no common
@ -1399,6 +1458,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_single_term_multi_fields(self, connect, collection):
'''
@ -1415,6 +1475,7 @@ class TestSearchDSL(object):
res = connect.search(collection, query)
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_multi_range_has_common(self, connect, collection):
'''
@ -1432,6 +1493,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_multi_range_no_common(self, connect, collection):
'''
@ -1449,6 +1511,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_multi_range_different_fields(self, connect, collection):
'''
@ -1466,6 +1529,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_single_range_multi_fields(self, connect, collection):
'''
@ -1488,6 +1552,7 @@ class TestSearchDSL(object):
"""
# TODO
# @pytest.mark.skip("r0.3-test")
@pytest.mark.level(2)
def test_query_single_term_range_has_common(self, connect, collection):
'''
@ -1505,6 +1570,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
# TODO
# @pytest.mark.skip("r0.3-test")
def test_query_single_term_range_no_common(self, connect, collection):
'''
method: build query with single term single range
@ -1527,6 +1593,7 @@ class TestSearchDSL(object):
"""
# TODO
@pytest.mark.tag("0331")
def test_query_multi_vectors_same_field(self, connect, collection):
'''
method: build query with two vectors same field
@ -1550,6 +1617,7 @@ class TestSearchDSLBools(object):
******************************************************************
"""
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_query_no_bool(self, connect, collection):
'''
@ -1562,6 +1630,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_should_only_term(self, connect, collection):
'''
method: build query without must, with should.term instead
@ -1572,6 +1641,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_should_only_vector(self, connect, collection):
'''
method: build query without must, with should.vector instead
@ -1582,6 +1652,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_must_not_only_term(self, connect, collection):
'''
method: build query without must, with must_not.term instead
@ -1592,6 +1663,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_must_not_vector(self, connect, collection):
'''
method: build query without must, with must_not.vector instead
@ -1602,6 +1674,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
def test_query_must_should(self, connect, collection):
'''
method: build query must, and with should.term
@ -1657,12 +1730,14 @@ class TestSearchInvalid(object):
# pytest.skip("sq8h not support in CPU mode")
return request.param
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_with_invalid_collection(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
res = connect.search(collection_name, default_query)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_with_invalid_partition(self, connect, collection, get_invalid_partition):
# tag = " "
@ -1670,12 +1745,14 @@ class TestSearchInvalid(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, default_query, partition_tags=tag)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
fields = [get_invalid_field]
with pytest.raises(Exception) as e:
res = connect.search(collection, default_query, fields=fields)
@pytest.mark.tag("0331")
@pytest.mark.level(1)
def test_search_with_not_existed_field(self, connect, collection):
fields = [gen_unique_str("field_name")]
@ -1693,6 +1770,7 @@ class TestSearchInvalid(object):
def get_top_k(self, request):
yield request.param
@pytest.mark.tag("0331")
@pytest.mark.level(1)
def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
'''
@ -1716,6 +1794,7 @@ class TestSearchInvalid(object):
def get_search_params(self, request):
yield request.param
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
'''
@ -1736,6 +1815,7 @@ class TestSearchInvalid(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_with_invalid_params_binary(self, connect, binary_collection):
'''
@ -1754,6 +1834,7 @@ class TestSearchInvalid(object):
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)
@pytest.mark.tag("0331")
@pytest.mark.level(2)
def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
'''