add test case: flat search with metric_type (#3558)

Signed-off-by: zw <zw@milvus.io>

Co-authored-by: zw <zw@milvus.io>
pull/3609/head
del-zhenwu 2020-09-03 11:30:17 +08:00 committed by GitHub
parent 273863f54d
commit a9751fec4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 53 additions and 0 deletions

View File

@ -155,6 +155,32 @@ class TestSearchBase:
else:
assert not status.OK()
def test_search_top_k_flat_index_metric_type(self, connect, collection):
'''
target: test basic search fuction, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
vectors, ids = self.init_data(connect, collection)
query_vec = [vectors[0]]
status, result = connect.search(collection, top_k, query_vec, params={"metric_type": MetricType.IP.value})
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert result[0][0].distance >= 1 - epsilon
assert check_result(result[0], ids[0])
@pytest.mark.level(2)
def test_search_top_k_flat_index_metric_type_invalid(self, connect, collection):
'''
target: test basic search fuction, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
vectors, ids = self.init_data(connect, collection)
query_vec = [vectors[0]]
status, result = connect.search(collection, top_k, query_vec, params={"metric_type": MetricType.JACCARD.value})
assert not status.OK()
def test_search_l2_index_params(self, connect, collection, get_simple_index):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
@ -616,6 +642,33 @@ class TestSearchBase:
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
def test_search_distance_jaccard_flat_index_metric_type(self, connect, jac_collection):
'''
target: search ip_collection, and check the result: distance
method: compare the return distance value with value computed with HAMMING
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, jac_collection, nb=2)
index_type = IndexType.FLAT
index_param = {
"nlist": 16384
}
connect.create_index(jac_collection, index_type, index_param)
logging.getLogger().info(connect.get_collection_info(jac_collection))
logging.getLogger().info(connect.get_index_info(jac_collection))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
distance_0 = hamming(query_int_vectors[0], int_vectors[0])
distance_1 = hamming(query_int_vectors[0], int_vectors[1])
search_param = get_search_param(index_type)
search_param.update({"metric_type": MetricType.HAMMING.value})
status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
assert status.OK()
logging.getLogger().info(status)
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
def test_search_distance_hamming_flat_index(self, connect, ham_collection):
'''
target: search ip_collection, and check the result: distance