test: relax the checks on range search (#36542) (#38234)

/kind improvement

pr: #36542
pull/38253/head
zhuwenxing 2024-12-05 14:36:41 +08:00 committed by GitHub
parent 3d98e8e690
commit d4ef89f1c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 3 deletions

View File

@ -1580,10 +1580,11 @@ class TestSearchVector(TestBase):
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
training_data = [item[vector_field] for item in data]
distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.5*limit))] # recall is not 100% so add 50% to make sure the range is more than limit
if metric_type == "L2":
r1, r2 = r2, r1
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
logger.info(f"r1: {r1}, r2: {r2}")
payload = {
"collectionName": name,
"data": [vector_to_search],
@ -1601,7 +1602,14 @@ class TestSearchVector(TestBase):
assert rsp['code'] == 0
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) == limit
assert len(res) >= limit*0.8
# add buffer to the distance of comparison
if metric_type == "L2":
r1 = r1 + 10**-6
r2 = r2 - 10**-6
else:
r1 = r1 - 10**-6
r2 = r2 + 10**-6
for item in res:
distance = item.get("distance")
if metric_type == "L2":

View File

@ -262,6 +262,6 @@ def get_sorted_distance(train_emb, test_emb, metric_type):
"IP": ip_distance
}
distance = pairwise_distances(train_emb, Y=test_emb, metric=milvus_sklearn_metric_map[metric_type], n_jobs=-1)
distance = np.array(distance.T, order='C', dtype=np.float16)
distance = np.array(distance.T, order='C', dtype=np.float32)
distance_sorted = np.sort(distance, axis=1).tolist()
return distance_sorted