Fix search binary pagination failure (#22477)

See also: #22168

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
pull/22628/head
XuanYang-cn 2023-03-08 11:03:51 +08:00 committed by GitHub
parent a413d8a803
commit 955bc06165
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 5 deletions

View File

@ -52,7 +52,7 @@ struct SearchResult {
std::vector<float> distances_;
std::vector<int64_t> seg_offsets_;
// fist fill data during fillPrimaryKey, and then update data after reducing search results
// first fill data during fillPrimaryKey, and then update data after reducing search results
std::vector<PkType> primary_keys_;
DataType pk_type_;

View File

@ -43,6 +43,9 @@ struct SearchResultPair {
bool
operator>(const SearchResultPair& other) const {
if (fabs(distance_ - other.distance_) < 0.000001f) {
return primary_key_ < other.primary_key_;
}
return distance_ > other.distance_;
}

View File

@ -23,3 +23,13 @@ TEST(SearchResultPair, Greater) {
ASSERT_EQ(pair1 > pair2, true);
ASSERT_EQ(pair2.primary_key_, INVALID_PK);
}
TEST(SearchResultPair, SameDistance) {
auto pair1 = SearchResultPair(0, 1.0, nullptr, 0, 0, 1);
auto pair2 = SearchResultPair(1, 1.0, nullptr, 1, 0, 1);
ASSERT_EQ(pair1 > pair2, true);
pair1.advance();
ASSERT_EQ(pair2 > pair1, true);
ASSERT_EQ(pair1.primary_key_, INVALID_PK);
}

View File

@ -3971,10 +3971,19 @@ class TestsearchPagination(TestcaseBase):
res = collection_w.search(binary_vectors[:default_nq], "binary_vector", search_binary_param,
default_limit + offset)[0]
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
assert set(search_res[0].ids) == set(res[0].ids[offset:])
assert len(search_res[0].ids) == len(res[0].ids[offset:])
assert sorted(search_res[0].distances, key=numpy.float32) == sorted(res[0].distances[offset:], key=numpy.float32)
unique_a, unique_b = set(search_res[0].ids), set(res[0].ids[offset:])
diff = unique_a ^ unique_b
assert len(diff) <= 2
if len(diff) == 2:
i = search_res[0].ids.index((unique_a-unique_b).pop())
i2 = res[0].ids[offset:].index((unique_b-unique_a).pop())
assert search_res[0].distances[i] == res[0].distances[offset:][i2]
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("limit", [100, 3000, 10000])