mirror of https://github.com/milvus-io/milvus.git
add test case: flat search with metric_type (#3558)
Signed-off-by: zw <zw@milvus.io> Co-authored-by: zw <zw@milvus.io>pull/3609/head
parent
273863f54d
commit
a9751fec4f
|
@ -155,6 +155,32 @@ class TestSearchBase:
|
|||
else:
|
||||
assert not status.OK()
|
||||
|
||||
def test_search_top_k_flat_index_metric_type(self, connect, collection):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, change top-k value
|
||||
method: search with the given vectors, check the result
|
||||
expected: search status ok, and the length of the result is top_k
|
||||
'''
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_vec = [vectors[0]]
|
||||
status, result = connect.search(collection, top_k, query_vec, params={"metric_type": MetricType.IP.value})
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert result[0][0].distance >= 1 - epsilon
|
||||
assert check_result(result[0], ids[0])
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_search_top_k_flat_index_metric_type_invalid(self, connect, collection):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, change top-k value
|
||||
method: search with the given vectors, check the result
|
||||
expected: search status ok, and the length of the result is top_k
|
||||
'''
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_vec = [vectors[0]]
|
||||
status, result = connect.search(collection, top_k, query_vec, params={"metric_type": MetricType.JACCARD.value})
|
||||
assert not status.OK()
|
||||
|
||||
def test_search_l2_index_params(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, test all index params, and build
|
||||
|
@ -616,6 +642,33 @@ class TestSearchBase:
|
|||
logging.getLogger().info(result)
|
||||
assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
|
||||
|
||||
def test_search_distance_jaccard_flat_index_metric_type(self, connect, jac_collection):
|
||||
'''
|
||||
target: search ip_collection, and check the result: distance
|
||||
method: compare the return distance value with value computed with HAMMING
|
||||
expected: the return distance equals to the computed value
|
||||
'''
|
||||
# from scipy.spatial import distance
|
||||
nprobe = 512
|
||||
int_vectors, vectors, ids = self.init_binary_data(connect, jac_collection, nb=2)
|
||||
index_type = IndexType.FLAT
|
||||
index_param = {
|
||||
"nlist": 16384
|
||||
}
|
||||
connect.create_index(jac_collection, index_type, index_param)
|
||||
logging.getLogger().info(connect.get_collection_info(jac_collection))
|
||||
logging.getLogger().info(connect.get_index_info(jac_collection))
|
||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
|
||||
distance_0 = hamming(query_int_vectors[0], int_vectors[0])
|
||||
distance_1 = hamming(query_int_vectors[0], int_vectors[1])
|
||||
search_param = get_search_param(index_type)
|
||||
search_param.update({"metric_type": MetricType.HAMMING.value})
|
||||
status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
|
||||
assert status.OK()
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(result)
|
||||
assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
|
||||
|
||||
def test_search_distance_hamming_flat_index(self, connect, ham_collection):
|
||||
'''
|
||||
target: search ip_collection, and check the result: distance
|
||||
|
|
Loading…
Reference in New Issue