mirror of https://github.com/milvus-io/milvus.git
Use primary key only to check search result duplicate (#10949)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10977/head
parent
0dc5a86606
commit
da0cb4a702
|
@ -64,6 +64,7 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
|||
auto num_segments = search_results.size();
|
||||
AssertInfo(num_segments > 0, "num segment must greater than 0");
|
||||
|
||||
std::unordered_set<int64_t> pk_set;
|
||||
int64_t skip_dup_cnt = 0;
|
||||
for (int64_t qi = 0; qi < nq; qi++) {
|
||||
std::vector<SearchResultPair> result_pairs;
|
||||
|
@ -86,38 +87,25 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
|||
search_records[index].push_back(result_pair.offset_++);
|
||||
}
|
||||
#else
|
||||
float prev_dis = MAXFLOAT;
|
||||
std::unordered_set<int64_t> prev_pk_set;
|
||||
pk_set.clear();
|
||||
while (curr_offset - base_offset < topk) {
|
||||
result_pairs[0].reset_distance();
|
||||
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
||||
auto& result_pair = result_pairs[0];
|
||||
auto index = result_pair.index_;
|
||||
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
|
||||
float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
|
||||
// remove duplicates
|
||||
if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) {
|
||||
if (curr_pk == INVALID_ID || pk_set.count(curr_pk) == 0) {
|
||||
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
|
||||
search_records[index].push_back(result_pair.offset_);
|
||||
prev_dis = curr_dis;
|
||||
prev_pk_set.clear();
|
||||
prev_pk_set.insert(curr_pk);
|
||||
} else {
|
||||
// To handle this case:
|
||||
// e1: [100, 0.99]
|
||||
// e2: [101, 0.99] ==> not duplicated, should keep
|
||||
// e3: [100, 0.99] ==> duplicated, should remove
|
||||
if (prev_pk_set.count(curr_pk) == 0) {
|
||||
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
|
||||
search_records[index].push_back(result_pair.offset_);
|
||||
// prev_pk_set keeps all primary keys with same distance
|
||||
prev_pk_set.insert(curr_pk);
|
||||
} else {
|
||||
// the entity with same distance and same primary key must be duplicated
|
||||
skip_dup_cnt++;
|
||||
search_records[index].push_back(result_pair.offset_++);
|
||||
if (curr_pk != INVALID_ID) {
|
||||
pk_set.insert(curr_pk);
|
||||
}
|
||||
} else {
|
||||
// skip entity with same primary key
|
||||
result_pair.offset_++;
|
||||
skip_dup_cnt++;
|
||||
}
|
||||
result_pair.offset_++;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -1812,8 +1812,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
|
|||
for i := int64(0); i < nq; i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
|
||||
var prevIDSet = make(map[int64]struct{})
|
||||
var prevScore float32 = math.MaxFloat32
|
||||
var idSet = make(map[int64]struct{})
|
||||
var j int64
|
||||
for j = 0; j < topk; {
|
||||
sel := selectSearchResultData(searchResultData, offsets, topk, i)
|
||||
|
@ -1830,28 +1829,15 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
|
|||
}
|
||||
|
||||
// remove duplicates
|
||||
if math.Abs(float64(score)-float64(prevScore)) > 0.00001 {
|
||||
if _, ok := idSet[id]; !ok {
|
||||
typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
prevScore = score
|
||||
prevIDSet = map[int64]struct{}{id: {}}
|
||||
idSet[id] = struct{}{}
|
||||
j++
|
||||
} else {
|
||||
// To handle this case:
|
||||
// e1: [100, 0.99]
|
||||
// e2: [101, 0.99] ==> not duplicated, should keep
|
||||
// e3: [100, 0.99] ==> duplicated, should remove
|
||||
if _, ok := prevIDSet[id]; !ok {
|
||||
typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
prevIDSet[id] = struct{}{}
|
||||
j++
|
||||
} else {
|
||||
// entity with same id and same score must be duplicated
|
||||
skipDupCnt++
|
||||
}
|
||||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
offsets[sel]++
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue