From d995b3f0fa9e7d6194571e75291cf9df665ca437 Mon Sep 17 00:00:00 2001 From: binbin <83755740+binbinlv@users.noreply.github.com> Date: Wed, 27 Mar 2024 12:05:09 +0800 Subject: [PATCH] test: modify hybrid search cases (#31624) issue: #31339 Signed-off-by: binbin lv --- tests/python_client/testcases/test_search.py | 26 ++++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index cd51e6655b..ff8eec07b6 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -10710,21 +10710,31 @@ class TestCollectionHybridSearchValid(TestcaseBase): collection_w.load() # 3. prepare search params req_list = [] + id_list = [] for i in range(len(vector_name_list)): + vectors = [[random.random() for _ in range(multiple_dim_array[i])] for _ in range(1)] + search_params = {"metric_type": metric_type, "offset": 0} search_param = { - "data": [[random.random() for _ in range(multiple_dim_array[i])] for _ in range(1)], + "data": vectors, "anns_field": vector_name_list[i], - "param": {"metric_type": metric_type, "offset": 0}, + "param": search_params, "limit": default_limit, - "expr": "int64 > 0"} + "expr": default_search_exp} req = AnnSearchRequest(**search_param) req_list.append(req) + search_res = collection_w.search(vectors[:1], vector_name_list[i], + search_params, default_limit, + default_search_exp, + check_task=CheckTasks.check_search_results, + check_items={"nq": 1, + "ids": insert_ids, + "limit": default_limit})[0] + id_list.extend(search_res[0].ids) # 4. hybrid search - collection_w.hybrid_search(req_list, WeightedRanker(0.1, 0.9), default_limit*len(req_list)+1, - check_task=CheckTasks.check_search_results, - check_items={"nq": 1, - "ids": insert_ids, - "limit": default_limit*len(req_list)}) + hybrid_search = \ + collection_w.hybrid_search(req_list, WeightedRanker(0.1, 0.9), default_limit * len(req_list) + 1)[0] + assert len(hybrid_search) == 1 + assert len(hybrid_search[0].ids) == len(list(set(id_list))) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("primary_field", [ct.default_int64_field_name, ct.default_string_field_name])