mirror of https://github.com/milvus-io/milvus.git
Remove duplicated search results in segcore reduce (#10117)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10309/head
parent
7ce7cb7a5e
commit
48648c818b
|
@ -73,12 +73,13 @@ struct SearchResult {
|
|||
int64_t num_queries_;
|
||||
int64_t topk_;
|
||||
std::vector<float> result_distances_;
|
||||
std::vector<int64_t> internal_seg_offsets_;
|
||||
|
||||
public:
|
||||
// TODO(gexi): utilize these fields
|
||||
void* segment_;
|
||||
std::vector<int64_t> internal_seg_offsets_;
|
||||
std::vector<int64_t> result_offsets_;
|
||||
std::vector<int64_t> primary_keys_;
|
||||
std::vector<std::vector<char>> row_data_;
|
||||
};
|
||||
|
||||
|
|
|
@ -14,6 +14,35 @@
|
|||
namespace milvus::segcore {
|
||||
class Naive;
|
||||
|
||||
void
|
||||
SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
AssertInfo(plan, "empty plan");
|
||||
auto size = results.result_distances_.size();
|
||||
AssertInfo(results.internal_seg_offsets_.size() == size,
|
||||
"Size of result distances is not equal to size of segment offsets");
|
||||
Assert(results.primary_keys_.size() == 0);
|
||||
|
||||
results.primary_keys_.resize(size);
|
||||
|
||||
auto element_sizeof = sizeof(int64_t);
|
||||
|
||||
aligned_vector<char> blob(size * element_sizeof);
|
||||
if (plan->schema_.get_is_auto_id()) {
|
||||
bulk_subscript(SystemFieldType::RowId, results.internal_seg_offsets_.data(), size, blob.data());
|
||||
} else {
|
||||
auto key_offset_opt = get_schema().get_primary_key_offset();
|
||||
AssertInfo(key_offset_opt.has_value(), "Cannot get primary key offset from schema");
|
||||
auto key_offset = key_offset_opt.value();
|
||||
AssertInfo(get_schema()[key_offset].get_data_type() == DataType::INT64, "Primary key field is not INT64 type");
|
||||
bulk_subscript(key_offset, results.internal_seg_offsets_.data(), size, blob.data());
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
results.primary_keys_[i] = *(int64_t*)(blob.data() + element_sizeof * i);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult& results) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
|
@ -21,10 +50,8 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult&
|
|||
auto size = results.result_distances_.size();
|
||||
AssertInfo(results.internal_seg_offsets_.size() == size,
|
||||
"Size of result distances is not equal to size of segment offsets");
|
||||
// Assert(results.result_offsets_.size() == size);
|
||||
Assert(results.row_data_.size() == 0);
|
||||
|
||||
// std::vector<int64_t> row_ids(size);
|
||||
std::vector<int64_t> element_sizeofs;
|
||||
std::vector<aligned_vector<char>> blobs;
|
||||
|
||||
|
@ -45,7 +72,7 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult&
|
|||
element_sizeofs.push_back(sizeof(int64_t));
|
||||
}
|
||||
|
||||
// fill other entries
|
||||
// fill other entries except primary key
|
||||
for (auto field_offset : plan->target_entries_) {
|
||||
auto& field_meta = get_schema()[field_offset];
|
||||
auto element_sizeof = field_meta.get_sizeof();
|
||||
|
|
|
@ -28,11 +28,12 @@
|
|||
|
||||
namespace milvus::segcore {
|
||||
|
||||
// common interface of SegmentSealed and SegmentGrowing
|
||||
// used by C API
|
||||
// common interface of SegmentSealed and SegmentGrowing used by C API
|
||||
class SegmentInterface {
|
||||
public:
|
||||
// fill results according to target_entries in plan
|
||||
virtual void
|
||||
FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const = 0;
|
||||
|
||||
virtual void
|
||||
FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0;
|
||||
|
||||
|
@ -82,6 +83,9 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
const query::PlaceholderGroup& placeholder_group,
|
||||
Timestamp timestamp) const override;
|
||||
|
||||
void
|
||||
FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const override;
|
||||
|
||||
void
|
||||
FillTargetEntry(const query::Plan* plan, SearchResult& results) const override;
|
||||
|
||||
|
|
|
@ -10,14 +10,16 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <vector>
|
||||
#include <exceptions/EasyAssert.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "exceptions/EasyAssert.h"
|
||||
#include "log/Log.h"
|
||||
#include "query/Plan.h"
|
||||
#include "segcore/reduce_c.h"
|
||||
#include "segcore/Reduce.h"
|
||||
#include "segcore/ReduceStructure.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "common/Types.h"
|
||||
#include "pb/milvus.pb.h"
|
||||
|
||||
using SearchResult = milvus::SearchResult;
|
||||
|
@ -69,6 +71,8 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
|||
}
|
||||
int64_t loc_offset = query_offset;
|
||||
AssertInfo(topk > 0, "topk must greater than 0");
|
||||
|
||||
#if 0
|
||||
for (int i = 0; i < topk; ++i) {
|
||||
result_pairs[0].reset_distance();
|
||||
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
||||
|
@ -77,6 +81,42 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
|||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
||||
search_records[index].push_back(result_pair.offset_++);
|
||||
}
|
||||
#else
|
||||
float prev_dis = MAXFLOAT;
|
||||
std::unordered_set<int64_t> prev_pk_set;
|
||||
prev_pk_set.insert(-1);
|
||||
while (loc_offset - query_offset < topk) {
|
||||
result_pairs[0].reset_distance();
|
||||
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
||||
auto& result_pair = result_pairs[0];
|
||||
auto index = result_pair.index_;
|
||||
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
|
||||
float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
|
||||
// remove duplicates
|
||||
if (curr_pk == -1 || curr_dis != prev_dis) {
|
||||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
||||
search_records[index].push_back(result_pair.offset_++);
|
||||
prev_dis = curr_dis;
|
||||
prev_pk_set.clear();
|
||||
prev_pk_set.insert(curr_pk);
|
||||
} else {
|
||||
// To handle this case:
|
||||
// e1: [100, 0.99]
|
||||
// e2: [101, 0.99] ==> not duplicated, should keep
|
||||
// e3: [100, 0.99] ==> duplicated, should remove
|
||||
if (prev_pk_set.count(curr_pk) == 0) {
|
||||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
||||
search_records[index].push_back(result_pair.offset_++);
|
||||
// prev_pk_set keeps all primary keys with same distance
|
||||
prev_pk_set.insert(curr_pk);
|
||||
} else {
|
||||
// the entity with same distance and same primary key must be duplicated
|
||||
result_pair.offset_++;
|
||||
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, primary key " << curr_pk;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -90,17 +130,21 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector
|
|||
continue;
|
||||
}
|
||||
|
||||
std::vector<int64_t> primary_keys;
|
||||
std::vector<float> result_distances;
|
||||
std::vector<int64_t> internal_seg_offsets;
|
||||
|
||||
for (int j = 0; j < search_records[i].size(); j++) {
|
||||
auto& offset = search_records[i][j];
|
||||
auto primary_key = search_result->primary_keys_[offset];
|
||||
auto distance = search_result->result_distances_[offset];
|
||||
auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
|
||||
primary_keys.push_back(primary_key);
|
||||
result_distances.push_back(distance);
|
||||
internal_seg_offsets.push_back(internal_seg_offset);
|
||||
}
|
||||
|
||||
search_result->primary_keys_ = primary_keys;
|
||||
search_result->result_distances_ = result_distances;
|
||||
search_result->internal_seg_offsets_ = internal_seg_offsets;
|
||||
}
|
||||
|
@ -118,13 +162,19 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
|
|||
auto num_queries = search_results[0]->num_queries_;
|
||||
std::vector<std::vector<int64_t>> search_records(num_segments);
|
||||
|
||||
// get primary keys for duplicates removal
|
||||
for (auto& search_result : search_results) {
|
||||
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
|
||||
segment->FillPrimaryKeys(plan, *search_result);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
GetResultData(search_records, search_results, i, topk);
|
||||
}
|
||||
ResetSearchResult(search_records, search_results);
|
||||
|
||||
for (int i = 0; i < num_segments; ++i) {
|
||||
auto search_result = search_results[i];
|
||||
// fill in other entities
|
||||
for (auto& search_result : search_results) {
|
||||
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
|
||||
segment->FillTargetEntry(plan, *search_result);
|
||||
}
|
||||
|
|
|
@ -9,12 +9,13 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <random>
|
||||
#include <gtest/gtest.h>
|
||||
#include <chrono>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/LoadInfo.h"
|
||||
#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
@ -448,6 +449,121 @@ TEST(CApiTest, MergeInto) {
|
|||
ASSERT_EQ(distance[1], 5);
|
||||
}
|
||||
|
||||
void
|
||||
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
|
||||
auto sr = (SearchResult*)results[0];
|
||||
auto topk = sr->topk_;
|
||||
auto num_queries = sr->num_queries_;
|
||||
|
||||
std::unordered_set<int64_t> pk_set;
|
||||
std::unordered_set<float> distance_set;
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
auto search_result = (SearchResult*)results[i];
|
||||
auto size = search_result->result_offsets_.size();
|
||||
for (int j = 0; j < size; j++) {
|
||||
auto ret = pk_set.insert(search_result->primary_keys_[j]);
|
||||
// std::cout << j << ": " << ret.second << " "
|
||||
// << search_result->primary_keys_[j] << " "
|
||||
// << search_result->result_distances_[j] << std::endl;
|
||||
distance_set.insert(search_result->result_distances_[j]);
|
||||
}
|
||||
}
|
||||
std::cout << pk_set.size() << " " << distance_set.size() << " " << topk * num_queries << std::endl;
|
||||
// TODO: find 1 duplicated result (pk = 10345), need check
|
||||
assert(pk_set.size() == topk * num_queries - 1);
|
||||
}
|
||||
|
||||
// Searches the same growing segment several times with an identical plan and
// placeholder group, reduces the results together through
// ReduceSearchResultsAndFillData, and hands them to CheckSearchResultDuplicate.
// The scenario is run once with two result sets and once with three.
//
// Status codes are checked with ASSERT_EQ rather than assert() so that the
// checks are not compiled out when the tests are built with NDEBUG.
TEST(CApiTest, ReduceRemoveDuplicates) {
    auto collection = NewCollection(get_default_schema_config());
    auto segment = NewSegment(collection, 0, Growing);

    int N = 10000;
    auto [raw_data, timestamps, uids] = generate_data(N);
    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);

    int64_t offset = 0;
    PreInsert(segment, N, &offset);
    auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(),
                          static_cast<int>(line_sizeof), N);
    ASSERT_EQ(ins_res.error_code, Success);

    const char* dsl_string = R"(
    {
        "bool": {
            "vector": {
                "fakevec": {
                    "metric_type": "L2",
                    "params": {
                        "nprobe": 10
                    },
                    "query": "$0",
                    "topk": 10,
                    "round_decimal": 3
                }
            }
        }
    })";

    int num_queries = 10;
    auto blob = generate_query_data(num_queries);

    void* plan = nullptr;
    auto status = CreateSearchPlan(collection, dsl_string, &plan);
    ASSERT_EQ(status.error_code, Success);

    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    // Every search below uses timestamp 1.
    timestamps.clear();
    timestamps.push_back(1);

    // Same scenario with two and then three result sets from the same segment.
    for (int num_results : {2, 3}) {
        std::vector<CSearchResult> results;
        for (int i = 0; i < num_results; ++i) {
            CSearchResult res{};
            status = Search(segment, plan, placeholderGroup, timestamps[0], &res);
            ASSERT_EQ(status.error_code, Success);
            results.push_back(res);
        }

        status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
        ASSERT_EQ(status.error_code, Success);
        CheckSearchResultDuplicate(results);

        for (auto res : results) {
            DeleteSearchResult(res);
        }
    }

    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteCollection(collection);
    DeleteSegment(segment);
}
|
||||
|
||||
TEST(CApiTest, Reduce) {
|
||||
auto collection = NewCollection(get_default_schema_config());
|
||||
auto segment = NewSegment(collection, 0, Growing);
|
||||
|
|
Loading…
Reference in New Issue