mirror of https://github.com/milvus-io/milvus.git
Optimize segcore Reduce (#18902)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/19040/head
parent 495b214dd0
commit 765907ab77
@@ -28,127 +28,40 @@ ReduceHelper::Initialize() {
AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
AssertInfo(slice_nqs_.size() == slice_topKs_.size(), "unaligned slice_nqs and slice_topKs");

unify_topK_ = search_results_[0]->unity_topK_;
total_nq_ = search_results_[0]->total_nq_;
num_segments_ = search_results_.size();
num_slices_ = slice_nqs_.size();

// prefix sum, get slices offsets
AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
auto slice_offsets_size = num_slices_ + 1;
nq_slice_offsets_ = std::vector<int32_t>(slice_offsets_size);

for (int i = 1; i < slice_offsets_size; i++) {
nq_slice_offsets_[i] = nq_slice_offsets_[i - 1] + slice_nqs_[i - 1];
for (auto j = nq_slice_offsets_[i - 1]; j < nq_slice_offsets_[i]; j++) {
}
}
AssertInfo(nq_slice_offsets_[num_slices_] == total_nq_,
"illegal req sizes"
", nq_slice_offsets[last] = " +
std::to_string(nq_slice_offsets_[num_slices_]) + ", total_nq = " + std::to_string(total_nq_));
slice_nqs_prefix_sum_.resize(num_slices_ + 1);
std::partial_sum(slice_nqs_.begin(), slice_nqs_.end(), slice_nqs_prefix_sum_.begin() + 1);
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_, "illegal req sizes, slice_nqs_prefix_sum_[last] = " +
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
", total_nq = " + std::to_string(total_nq_));

// init final_search_records and final_read_topKs
final_search_records_ = std::vector<std::vector<int64_t>>(num_segments_);
final_real_topKs_ = std::vector<std::vector<int64_t>>(num_segments_);
for (auto& topKs : final_real_topKs_) {
// `topKs` records real topK of each query
topKs.resize(total_nq_);
final_search_records_.resize(num_segments_);
for (auto& search_record : final_search_records_) {
search_record.resize(total_nq_);
}
}
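The slice offsets built above are just an exclusive prefix sum over slice_nqs_: slice i covers the global query range [slice_nqs_prefix_sum_[i], slice_nqs_prefix_sum_[i + 1]). A minimal standalone sketch of that std::partial_sum idiom, with made-up per-slice nq values (illustrative only, not code from this commit):

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
    // hypothetical request split into three slices of 2, 3 and 5 queries
    std::vector<int64_t> slice_nqs{2, 3, 5};

    // prefix_sum[i] is the first global query index of slice i; the last entry equals total_nq
    std::vector<int64_t> prefix_sum(slice_nqs.size() + 1, 0);
    std::partial_sum(slice_nqs.begin(), slice_nqs.end(), prefix_sum.begin() + 1);

    assert(prefix_sum[1] == 2 && prefix_sum[2] == 5 && prefix_sum[3] == 10);
    return 0;
}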
void
ReduceHelper::Reduce() {
std::vector<SearchResult*> valid_search_results;
// get primary keys for duplicates removal
for (auto search_result : search_results_) {
FilterInvalidSearchResult(search_result);
if (search_result->get_total_result_count() > 0) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
valid_search_results.emplace_back(search_result);
}
}
search_results_ = valid_search_results;
num_segments_ = search_results_.size();
if (valid_search_results.size() == 0) {
// TODO: return empty search result?
return;
}

for (int i = 0; i < num_slices_; i++) {
// ReduceResultData for each slice
ReduceResultData(i);
}
// after reduce, remove invalid primary_keys, distances and ids by `final_search_records`
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
if (search_result->result_offsets_.size() != 0) {
std::vector<milvus::PkType> primary_keys;
std::vector<float> distances;
std::vector<int64_t> seg_offsets;
for (int j = 0; j < final_search_records_[i].size(); j++) {
auto& offset = final_search_records_[i][j];
primary_keys.push_back(search_result->primary_keys_[offset]);
distances.push_back(search_result->distances_[offset]);
seg_offsets.push_back(search_result->seg_offsets_[offset]);
}

search_result->primary_keys_ = std::move(primary_keys);
search_result->distances_ = std::move(distances);
search_result->seg_offsets_ = std::move(seg_offsets);
}
search_result->topk_per_nq_prefix_sum_.resize(final_real_topKs_[i].size() + 1);
std::partial_sum(final_real_topKs_[i].begin(), final_real_topKs_[i].end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}

// fill target entry
for (auto& search_result : search_results_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
FillPrimaryKey();
ReduceResultData();
RefreshSearchResult();
FillEntryData();
}

void
ReduceHelper::Marshal() {
// example:
// ----------------------------------
// nq0 nq1 nq2
// sr0 topk00 topk01 topk02
// sr1 topk10 topk11 topk12
// ----------------------------------
// then:
// result_slice_offsets[] = {
// 0,
// == sr0->topk_per_nq_prefix_sum_[0] + sr1->topk_per_nq_prefix_sum_[0]
// ((topk00) + (topk10)),
// == sr0->topk_per_nq_prefix_sum_[1] + sr1->topk_per_nq_prefix_sum_[1]
// ((topk00 + topk01) + (topk10 + topk11)),
// == sr0->topk_per_nq_prefix_sum_[2] + sr1->topk_per_nq_prefix_sum_[2]
// ((topk00 + topk01 + topk02) + (topk10 + topk11 + topk12)),
// == sr0->topk_per_nq_prefix_sum_[3] + sr1->topk_per_nq_prefix_sum_[3]
// }
auto result_slice_offsets = std::vector<int64_t>(nq_slice_offsets_.size(), 0);
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
for (int i = 1; i < nq_slice_offsets_.size(); i++) {
result_slice_offsets[i] += search_result->topk_per_nq_prefix_sum_[nq_slice_offsets_[i]];
}
}
AssertInfo(result_slice_offsets[num_slices_] <= total_nq_ * unify_topK_,
"illegal result_slice_offsets when Marshal, result_slice_offsets[last] = " +
std::to_string(result_slice_offsets[num_slices_]) + ", total_nq = " + std::to_string(total_nq_) +
", unify_topK = " + std::to_string(unify_topK_));

// get search result data blobs of slices
search_result_data_blobs_ = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
search_result_data_blobs_->blobs.resize(num_slices_);
//#pragma omp parallel for
for (int i = 0; i < num_slices_; i++) {
auto result_count = result_slice_offsets[i + 1] - result_slice_offsets[i];
auto proto = GetSearchResultDataSlice(i, result_count);
auto proto = GetSearchResultDataSlice(i);
search_result_data_blobs_->blobs[i] = proto;
}
}
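The worked comment in Marshal above reduces to a simple rule: the number of results that land in a slice is the sum, over all segments, of differences of topk_per_nq_prefix_sum_ taken at the slice's query boundaries. A small standalone sketch of that arithmetic with two hypothetical segments (illustrative only, not code from this commit):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // per-segment prefix sums of real topk per query (3 queries -> 4 entries each);
    // sr0 returned 3, 2, 4 hits and sr1 returned 1, 0, 2 hits for queries 0..2
    std::vector<int64_t> sr0_prefix{0, 3, 5, 9};
    std::vector<int64_t> sr1_prefix{0, 1, 1, 3};

    // result count of a slice covering queries [nq_begin, nq_end)
    auto slice_count = [&](int64_t nq_begin, int64_t nq_end) {
        int64_t count = 0;
        for (const auto* prefix : {&sr0_prefix, &sr1_prefix}) {
            count += (*prefix)[nq_end] - (*prefix)[nq_begin];
        }
        return count;
    };

    assert(slice_count(0, 3) == 12);  // whole request
    assert(slice_count(1, 3) == 8);   // a slice holding queries 1 and 2
    return 0;
}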
@@ -178,102 +91,152 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
}
}

search_result->distances_ = std::move(distances);
search_result->seg_offsets_ = std::move(seg_offsets);
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
}

void
ReduceHelper::ReduceResultData(int slice_index) {
ReduceHelper::FillPrimaryKey() {
std::vector<SearchResult*> valid_search_results;
// get primary keys for duplicates removal
for (auto search_result : search_results_) {
FilterInvalidSearchResult(search_result);
if (search_result->get_total_result_count() > 0) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
valid_search_results.emplace_back(search_result);
}
}
search_results_.swap(valid_search_results);
num_segments_ = search_results_.size();
}

void
ReduceHelper::RefreshSearchResult() {
for (int i = 0; i < num_segments_; i++) {
std::vector<int64_t> real_topks(total_nq_, 0);
auto search_result = search_results_[i];
if (search_result->result_offsets_.size() != 0) {
std::vector<milvus::PkType> primary_keys;
std::vector<float> distances;
std::vector<int64_t> seg_offsets;
for (int j = 0; j < total_nq_; j++) {
for (auto offset : final_search_records_[i][j]) {
primary_keys.push_back(search_result->primary_keys_[offset]);
distances.push_back(search_result->distances_[offset]);
seg_offsets.push_back(search_result->seg_offsets_[offset]);
real_topks[j]++;
}
}
search_result->primary_keys_ = std::move(primary_keys);
search_result->distances_ = std::move(distances);
search_result->seg_offsets_ = std::move(seg_offsets);
}
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
}

void
ReduceHelper::FillEntryData() {
for (auto search_result : search_results_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
}

int64_t
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
std::vector<SearchResultPair> result_pairs;
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
if (offset_beg == offset_end) {
continue;
}
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
result_pairs.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
}

// nq has no results for all segments
if (result_pairs.size() == 0) {
return 0;
}

int64_t dup_cnt = 0;
std::unordered_set<milvus::PkType> pk_set;
int64_t prev_offset = offset;
while (offset - prev_offset < topk) {
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& pilot = result_pairs[0];
auto index = pilot.segment_index_;
auto pk = pilot.primary_key_;
// no valid search result for this nq, break to next
if (pk == INVALID_PK) {
break;
}
// remove duplicates
if (pk_set.count(pk) == 0) {
pilot.search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot.offset_);
pk_set.insert(pk);
} else {
// skip entity with same primary key
dup_cnt++;
}
pilot.reset();
}
return dup_cnt;
}
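The new per-query reduce above is essentially a k-way merge with de-duplication: repeatedly take the best head element across segments, skip primary keys already emitted, and stop once topk hits are accepted or every segment is exhausted. A simplified standalone sketch of that loop, using plain integers as primary keys instead of SearchResultPair (illustrative only, not code from this commit):

#include <cstdint>
#include <cstdio>
#include <unordered_set>
#include <vector>

// one candidate of a per-segment, per-query hit list, sorted best-first
struct Candidate {
    int64_t pk;
    float score;
};

int main() {
    // two hypothetical segments answering the same query; pk 7 appears in both
    std::vector<std::vector<Candidate>> segments = {
        {{7, 0.9f}, {3, 0.5f}},
        {{7, 0.8f}, {9, 0.4f}},
    };
    std::vector<size_t> cursor(segments.size(), 0);

    const size_t topk = 3;
    std::unordered_set<int64_t> seen;
    std::vector<int64_t> merged;

    while (merged.size() < topk) {
        // pick the segment whose current head has the best score
        int best = -1;
        for (int i = 0; i < static_cast<int>(segments.size()); i++) {
            if (cursor[i] >= segments[i].size()) {
                continue;
            }
            if (best < 0 || segments[i][cursor[i]].score > segments[best][cursor[best]].score) {
                best = i;
            }
        }
        if (best < 0) {
            break;  // every segment exhausted for this query
        }
        const auto& cand = segments[best][cursor[best]++];
        if (seen.insert(cand.pk).second) {
            merged.push_back(cand.pk);  // accept
        }  // else: same primary key already taken, the dup_cnt branch above
    }

    for (auto pk : merged) {
        std::printf("%lld\n", static_cast<long long>(pk));  // prints 7, 3, 9
    }
    return 0;
}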
void
ReduceHelper::ReduceResultData() {
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto result_count = search_result->get_total_result_count();
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
AssertInfo(search_result->seg_offsets_.size() == result_count, "incorrect search result seg offset size");
AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
}

auto nq_offset_begin = nq_slice_offsets_[slice_index];
auto nq_offset_end = nq_slice_offsets_[slice_index + 1];
AssertInfo(nq_offset_begin < nq_offset_end,
"illegal nq offsets when ReduceResultData, nq_offset_begin = " + std::to_string(nq_offset_begin) +
", nq_offset_end = " + std::to_string(nq_offset_end));

// `search_records` records the search result offsets
std::vector<std::vector<int64_t>> search_records(num_segments_);
int64_t skip_dup_cnt = 0;
for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];

// reduce search results
int64_t result_offset = 0;
for (int64_t qi = nq_offset_begin; qi < nq_offset_end; qi++) {
std::vector<SearchResultPair> result_pairs;
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
if (search_result->topk_per_nq_prefix_sum_[qi + 1] - search_result->topk_per_nq_prefix_sum_[qi] == 0) {
continue;
}
auto base_offset = search_result->topk_per_nq_prefix_sum_[qi];
auto primary_key = search_result->primary_keys_[base_offset];
auto distance = search_result->distances_[base_offset];
result_pairs.emplace_back(primary_key, distance, search_result, i, base_offset,
search_result->topk_per_nq_prefix_sum_[qi + 1]);
}

// nq has no results for all segments
if (result_pairs.size() == 0) {
continue;
}
std::unordered_set<milvus::PkType> pk_set;
int64_t last_nq_result_offset = result_offset;
while (result_offset - last_nq_result_offset < slice_topKs_[slice_index]) {
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& pilot = result_pairs[0];
auto index = pilot.segment_index_;
auto curr_pk = pilot.primary_key_;
// no valid search result for this nq, break to next
if (curr_pk == INVALID_PK) {
break;
}
// remove duplicates
if (pk_set.count(curr_pk) == 0) {
pilot.search_result_->result_offsets_.push_back(result_offset++);
search_records[index].push_back(pilot.offset_);
pk_set.insert(curr_pk);
final_real_topKs_[index][qi]++;
} else {
// skip entity with same primary key
skip_dup_cnt++;
}
pilot.reset();
// reduce search results
int64_t result_offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], result_offset);
}
}

if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
}

// append search_records to final_search_records
for (int i = 0; i < num_segments_; i++) {
for (int j = 0; j < search_records[i].size(); j++) {
final_search_records_[i].emplace_back(search_records[i][j]);
}
}
}

std::vector<char>
ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
auto nq_offset_begin = nq_slice_offsets_[slice_index_];
auto nq_offset_end = nq_slice_offsets_[slice_index_ + 1];
AssertInfo(nq_offset_begin <= nq_offset_end,
"illegal offsets when GetSearchResultDataSlice, nq_offset_begin = " + std::to_string(nq_offset_begin) +
", nq_offset_end = " + std::to_string(nq_offset_end));
ReduceHelper::GetSearchResultDataSlice(int slice_index) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];

int64_t result_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
result_count +=
search_result->topk_per_nq_prefix_sum_[nq_end] - search_result->topk_per_nq_prefix_sum_[nq_begin];
}

auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
// set unify_topK and total_nq
search_result_data->set_top_k(slice_topKs_[slice_index_]);
search_result_data->set_num_queries(nq_offset_end - nq_offset_begin);
search_result_data->mutable_topks()->Resize(nq_offset_end - nq_offset_begin, 0);
search_result_data->set_top_k(slice_topKs_[slice_index]);
search_result_data->set_num_queries(nq_end - nq_begin);
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);

// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
@@ -306,19 +269,20 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
search_result_data->mutable_scores()->Resize(result_count, 0);

// fill pks and distances
for (auto nq_offset = nq_offset_begin; nq_offset < nq_offset_end; nq_offset++) {
int64_t topK_count = 0;
for (int i = 0; i < search_results_.size(); i++) {
auto search_result = search_results_[i];
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result != nullptr, "null search result when reorganize");
if (search_result->result_offsets_.size() == 0) {
continue;
}

auto result_start = search_result->topk_per_nq_prefix_sum_[nq_offset];
auto result_end = search_result->topk_per_nq_prefix_sum_[nq_offset + 1];
for (auto offset = result_start; offset < result_end; offset++) {
auto loc = search_result->result_offsets_[offset];
auto topk_start = search_result->topk_per_nq_prefix_sum_[qi];
auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
topk_count += topk_end - topk_start;

for (auto ki = topk_start; ki < topk_end; ki++) {
auto loc = search_result->result_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, loc = " + std::to_string(loc) +
", result_count = " + std::to_string(result_count));
@@ -326,12 +290,12 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
switch (pk_type) {
case milvus::DataType::INT64: {
search_result_data->mutable_ids()->mutable_int_id()->mutable_data()->Set(
loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[offset]));
loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[ki]));
break;
}
case milvus::DataType::VARCHAR: {
*search_result_data->mutable_ids()->mutable_str_id()->mutable_data()->Mutable(loc) =
std::visit(StrPKVisitor{}, search_result->primary_keys_[offset]);
std::visit(StrPKVisitor{}, search_result->primary_keys_[ki]);
break;
}
default: {
@@ -340,17 +304,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
}

// set result distances
search_result_data->mutable_scores()->Set(loc, search_result->distances_[offset]);
search_result_data->mutable_scores()->Set(loc, search_result->distances_[ki]);
// set result offset to fill output fields data
result_pairs[loc] = std::make_pair(search_result, offset);
result_pairs[loc] = std::make_pair(search_result, ki);
}

topK_count += search_result->topk_per_nq_prefix_sum_[nq_offset + 1] -
search_result->topk_per_nq_prefix_sum_[nq_offset];
}

// update result topKs
search_result_data->mutable_topks()->Set(nq_offset - nq_offset_begin, topK_count);
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
}

AssertInfo(search_result_data->scores_size() == result_count,
@@ -61,15 +61,26 @@ class ReduceHelper {
FilterInvalidSearchResult(SearchResult* search_result);

void
ReduceResultData(int slice_index);
FillPrimaryKey();

void
RefreshSearchResult();

void
FillEntryData();

int64_t
ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);

void
ReduceResultData();

std::vector<char>
GetSearchResultDataSlice(int slice_index_, int64_t result_count);
GetSearchResultDataSlice(int slice_index_);

private:
std::vector<int64_t> slice_topKs_;
std::vector<int64_t> slice_nqs_;
int64_t unify_topK_;
int64_t total_nq_;
int64_t num_segments_;
int64_t num_slices_;
@@ -77,10 +88,10 @@ class ReduceHelper {
milvus::query::Plan* plan_;
std::vector<SearchResult*>& search_results_;

//
std::vector<int32_t> nq_slice_offsets_;
std::vector<std::vector<int64_t>> final_search_records_;
std::vector<std::vector<int64_t>> final_real_topKs_;
std::vector<int64_t> slice_nqs_prefix_sum_;

// dim0: num_segments_; dim1: total_nq_; dim2: offset
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;

// output
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
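Regarding the reshaped final_search_records_ member: per the dim0/dim1/dim2 comment above, it is now one offset list per (segment, query) pair rather than one flat list per segment. A tiny standalone sketch of allocating and filling that shape, with made-up sizes (illustrative only, not code from this commit):

#include <cstdint>
#include <vector>

int main() {
    const int64_t num_segments = 2;
    const int64_t total_nq = 3;

    // dim0: segment, dim1: query, dim2: accepted offsets within that segment's hits
    std::vector<std::vector<std::vector<int64_t>>> final_search_records(
        num_segments, std::vector<std::vector<int64_t>>(total_nq));

    // e.g. segment 1 contributed its local offsets 0 and 4 to query 2
    final_search_records[1][2].push_back(0);
    final_search_records[1][2].push_back(4);
    return 0;
}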
@@ -1347,25 +1347,16 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
auto suc = search_result_data.ParseFromArray(search_result_data_blobs->blobs[i].data(),
search_result_data_blobs->blobs[i].size());
assert(suc);

assert(suc);
assert(search_result_data.num_queries() == slice_nqs[i]);
assert(search_result_data.top_k() == slice_topKs[i]);
assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
assert(search_result_data.scores().size() == search_result_data.topks().at(0) * slice_nqs[i]);
assert(search_result_data.ids().int_id().data_size() == search_result_data.topks().at(0) * slice_nqs[i]);

// check topKs
// check real topks
assert(search_result_data.topks().size() == slice_nqs[i]);
for (int j = 0; j < search_result_data.topks().size(); j++) {
assert(search_result_data.topks().at(j) == slice_topKs[i]);
for (auto real_topk : search_result_data.topks()) {
assert(real_topk <= slice_topKs[i]);
}

// assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
// assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
// assert(search_result_data.top_k() == topK);
// assert(search_result_data.num_queries() == req_sizes[i]);
// assert(search_result_data.scores().size() == topK * req_sizes[i]);
// assert(search_result_data.ids().int_id().data_size() == topK * req_sizes[i]);
}

DeleteSearchResultDataBlobs(cSearchResultData);
@@ -1378,6 +1369,8 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
}

TEST(CApiTest, ReduceSearchWithExpr) {
testReduceSearchWithExpr(2, 1, 1);
testReduceSearchWithExpr(2, 10, 10);
testReduceSearchWithExpr(100, 1, 1);
testReduceSearchWithExpr(100, 10, 10);
testReduceSearchWithExpr(10000, 1, 1);