mirror of https://github.com/milvus-io/milvus.git
fix: Add hybridsearch result cases (#30834)
#30694 Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>pull/30561/head
parent
8addd75481
commit
21b41e96fc
|
@ -11709,3 +11709,72 @@ class TestCollectionHybridSearchValid(TestcaseBase):
|
|||
for i in range(len(score_answer[:limit])):
|
||||
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_hybrid_search_result_L2_order(self):
|
||||
"""
|
||||
target: test hybrid search result having correct order for L2 distance
|
||||
method: create connection, collection, insert and search
|
||||
expected: hybrid search successfully and result order is correct
|
||||
"""
|
||||
# 1. initialize collection with data
|
||||
collection_w, _, _, insert_ids, time_stamp = \
|
||||
self.init_collection_general(prefix, True, is_index=False, multiple_dim_array=[default_dim, default_dim])[0:5]
|
||||
|
||||
# 2. create index
|
||||
vector_name_list = cf.extract_vector_field_name_list(collection_w)
|
||||
vector_name_list.append(ct.default_float_vec_field_name)
|
||||
for i in range(len(vector_name_list)) :
|
||||
default_index = { "index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128},}
|
||||
collection_w.create_index(vector_name_list[i], default_index)
|
||||
collection_w.load()
|
||||
|
||||
# 3. prepare search params
|
||||
req_list = []
|
||||
weights = [0.2, 0.3, 0.5]
|
||||
for i in range(len(vector_name_list)):
|
||||
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
|
||||
search_param = {
|
||||
"data": vectors,
|
||||
"anns_field": vector_name_list[i],
|
||||
"param": {"metric_type": "L2", "offset": 0},
|
||||
"limit": default_limit,
|
||||
"expr": "int64 > 0"}
|
||||
req = AnnSearchRequest(**search_param)
|
||||
req_list.append(req)
|
||||
# 4. hybrid search
|
||||
res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), 10)
|
||||
is_sorted_decrease = lambda lst: all(lst[i]['distance'] >= lst[i+1]['distance'] for i in range(len(lst)-1))
|
||||
assert is_sorted_decrease(res[0])
|
||||
print(res)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_hybrid_search_result_order(self):
|
||||
"""
|
||||
target: test hybrid search result having correct order for cosine distance
|
||||
method: create connection, collection, insert and search
|
||||
expected: hybrid search successfully and result order is correct
|
||||
"""
|
||||
# 1. initialize collection with data
|
||||
collection_w, _, _, insert_ids, time_stamp = \
|
||||
self.init_collection_general(prefix, True, multiple_dim_array=[default_dim, default_dim])[0:5]
|
||||
# 2. extract vector field name
|
||||
vector_name_list = cf.extract_vector_field_name_list(collection_w)
|
||||
vector_name_list.append(ct.default_float_vec_field_name)
|
||||
# 3. prepare search params
|
||||
req_list = []
|
||||
weights = [0.2, 0.3, 0.5]
|
||||
for i in range(len(vector_name_list)):
|
||||
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
|
||||
search_param = {
|
||||
"data": vectors,
|
||||
"anns_field": vector_name_list[i],
|
||||
"param": {"metric_type": "COSINE", "offset": 0},
|
||||
"limit": default_limit,
|
||||
"expr": "int64 > 0"}
|
||||
req = AnnSearchRequest(**search_param)
|
||||
req_list.append(req)
|
||||
# 4. hybrid search
|
||||
res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), 10)
|
||||
is_sorted_ascend = lambda lst: all(lst[i]['distance'] <= lst[i+1]['distance'] for i in range(len(lst)-1))
|
||||
assert is_sorted_ascend(res[0])
|
||||
print(res)
|
Loading…
Reference in New Issue