Update test cases for new default metric type cosine (#25064)

Signed-off-by: nico <cheng.yuan@zilliz.com>
pull/25118/head
nico 2023-06-25 14:48:44 +08:00 committed by GitHub
parent ccf3f0066f
commit 9e787416be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 15 deletions

View File

@ -4569,11 +4569,12 @@ class TestSearchString(TestcaseBase):
# 1. initialize with data
collection_w, _, _, insert_ids = \
self.init_collection_general(prefix, True, dim=dim, primary_field=ct.default_string_field_name,
enable_dynamic_field=enable_dynamic_field)[0:4]
enable_dynamic_field=enable_dynamic_field, is_index=False)[0:4]
collection_w.create_index(field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search
log.info("test_search_string_field_is_primary_true: searching collection %s" % collection_w.name)
range_search_params = {"metric_type": "L2", "params": {"nprobe": 10, "radius": 1000,
"range_filter": 0}}
range_search_params = {"metric_type": "L2", "params": {"radius": 1000, "range_filter": 0}}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_string_field_name, default_float_field_name]
collection_w.search(vectors[:default_nq], default_search_field,
@ -5182,8 +5183,7 @@ class TestSearchPagination(TestcaseBase):
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("expression", [i for i in range(25)])
def test_search_pagination_with_index_partition(self, offset, auto_id, _async, expression):
def test_search_pagination_with_index_partition(self, offset, auto_id, _async):
"""
target: test search pagination with index and partition
method: create connection, collection, insert data, create index and search
@ -5202,9 +5202,9 @@ class TestSearchPagination(TestcaseBase):
# 3. search through partitions
par = collection_w.partitions
limit = 100
search_param = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset}
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset}
search_res = collection_w.search(vectors[:default_nq], default_search_field,
search_param, limit, default_search_exp,
search_params, limit, default_search_exp,
[par[0].name, par[1].name], _async=_async,
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
@ -5212,7 +5212,8 @@ class TestSearchPagination(TestcaseBase):
"limit": limit,
"_async": _async})[0]
# 3. search through partitions with offset+limit
res = collection_w.search(vectors[:default_nq], default_search_field, default_search_params,
search_params = {"metric_type": "L2"}
res = collection_w.search(vectors[:default_nq], default_search_field, search_params,
limit + offset, default_search_exp,
[par[0].name, par[1].name], _async=_async)[0]
if _async:
@ -5279,17 +5280,17 @@ class TestSearchPagination(TestcaseBase):
collection_w.insert(data)
collection_w.load()
# 3. search
search_param = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset}
vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
search_params = {"offset": offset}
search_res = collection_w.search(vectors[:default_nq], default_search_field,
search_param, default_limit,
search_params, default_limit,
default_search_exp, _async=_async,
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"limit": default_limit,
"_async": _async})[0]
# 4. search through partitions with offset+limit
res = collection_w.search(vectors[:default_nq], default_search_field, default_search_params,
search_params = {}
res = collection_w.search(vectors[:default_nq], default_search_field, search_params,
default_limit + offset, default_search_exp, _async=_async)[0]
if _async:
search_res.done()
@ -5955,12 +5956,14 @@ class TestCollectionRangeSearch(TestcaseBase):
expected: search successfully with filtered limit(topK)
"""
# 1. initialize with data
collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb=10)[0:5]
collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb=10, is_index=False)[0:5]
collection_w.create_index(field_name, {"metric_type": "L2"})
collection_w.load()
# 2. get vectors that inserted into collection
vectors = np.array(_vectors[0]).tolist()
vectors = [vectors[i][-1] for i in range(default_nq)]
# 3. range search with L2
range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0}}
range_search_params = {"metric_type": "L2", "params": {"nprobe": 10, "radius": 0}}
collection_w.search(vectors[:default_nq], default_search_field,
range_search_params, default_limit,
default_search_exp,
@ -5975,7 +5978,7 @@ class TestCollectionRangeSearch(TestcaseBase):
default_search_exp,
check_task=CheckTasks.err_res,
check_items={ct.err_code: 1,
ct.err_msg: "metric type not match: expected=COSINE, actual=IP"})
ct.err_msg: "metric type not match: expected=L2, actual=IP"})
@pytest.mark.tags(CaseLabel.L2)
def test_range_search_radius_range_filter_not_in_params(self):