From 765907ab775589effe2d0d465a4e61b4ce79b2ca Mon Sep 17 00:00:00 2001
From: Cai Yudong <yudong.cai@zilliz.com>
Date: Tue, 6 Sep 2022 10:55:12 +0800
Subject: [PATCH] Optimize segcore Reduce (#18902)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
---
 internal/core/src/segcore/Reduce.cpp  | 335 ++++++++++++--------------
 internal/core/src/segcore/Reduce.h    |  25 +-
 internal/core/unittest/test_c_api.cpp |  21 +-
 3 files changed, 173 insertions(+), 208 deletions(-)

diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/Reduce.cpp
index d0a6850361..e030632de8 100644
--- a/internal/core/src/segcore/Reduce.cpp
+++ b/internal/core/src/segcore/Reduce.cpp
@@ -28,127 +28,40 @@ ReduceHelper::Initialize() {
     AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
     AssertInfo(slice_nqs_.size() == slice_topKs_.size(), "unaligned slice_nqs and slice_topKs");

-    unify_topK_ = search_results_[0]->unity_topK_;
     total_nq_ = search_results_[0]->total_nq_;
     num_segments_ = search_results_.size();
     num_slices_ = slice_nqs_.size();

     // prefix sum, get slices offsets
     AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
-    auto slice_offsets_size = num_slices_ + 1;
-    nq_slice_offsets_ = std::vector<int32_t>(slice_offsets_size);
-
-    for (int i = 1; i < slice_offsets_size; i++) {
-        nq_slice_offsets_[i] = nq_slice_offsets_[i - 1] + slice_nqs_[i - 1];
-        for (auto j = nq_slice_offsets_[i - 1]; j < nq_slice_offsets_[i]; j++) {
-        }
-    }
-    AssertInfo(nq_slice_offsets_[num_slices_] == total_nq_,
-               "illegal req sizes"
-               ", nq_slice_offsets[last] = " +
-                   std::to_string(nq_slice_offsets_[num_slices_]) + ", total_nq = " + std::to_string(total_nq_));
+    slice_nqs_prefix_sum_.resize(num_slices_ + 1);
+    std::partial_sum(slice_nqs_.begin(), slice_nqs_.end(), slice_nqs_prefix_sum_.begin() + 1);
+    AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_, "illegal req sizes, slice_nqs_prefix_sum_[last] = " +
+                                                                    std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
+                                                                    ", total_nq = " + std::to_string(total_nq_));

     // init final_search_records and final_read_topKs
-    final_search_records_ = std::vector<std::vector<int64_t>>(num_segments_);
-    final_real_topKs_ = std::vector<std::vector<int64_t>>(num_segments_);
-    for (auto& topKs : final_real_topKs_) {
-        // `topKs` records real topK of each query
-        topKs.resize(total_nq_);
+    final_search_records_.resize(num_segments_);
+    for (auto& search_record : final_search_records_) {
+        search_record.resize(total_nq_);
     }
 }

 void
 ReduceHelper::Reduce() {
-    std::vector<SearchResult*> valid_search_results;
-    // get primary keys for duplicates removal
-    for (auto search_result : search_results_) {
-        FilterInvalidSearchResult(search_result);
-        if (search_result->get_total_result_count() > 0) {
-            auto segment = static_cast<SegmentInterface*>(search_result->segment_);
-            segment->FillPrimaryKeys(plan_, *search_result);
-            valid_search_results.emplace_back(search_result);
-        }
-    }
-    search_results_ = valid_search_results;
-    num_segments_ = search_results_.size();
-    if (valid_search_results.size() == 0) {
-        // TODO: return empty search result?
-        return;
-    }
-
-    for (int i = 0; i < num_slices_; i++) {
-        // ReduceResultData for each slice
-        ReduceResultData(i);
-    }
-    // after reduce, remove invalid primary_keys, distances and ids by `final_search_records`
-    for (int i = 0; i < num_segments_; i++) {
-        auto search_result = search_results_[i];
-        if (search_result->result_offsets_.size() != 0) {
-            std::vector<milvus::PkType> primary_keys;
-            std::vector<float> distances;
-            std::vector<int64_t> seg_offsets;
-            for (int j = 0; j < final_search_records_[i].size(); j++) {
-                auto& offset = final_search_records_[i][j];
-                primary_keys.push_back(search_result->primary_keys_[offset]);
-                distances.push_back(search_result->distances_[offset]);
-                seg_offsets.push_back(search_result->seg_offsets_[offset]);
-            }
-
-            search_result->primary_keys_ = std::move(primary_keys);
-            search_result->distances_ = std::move(distances);
-            search_result->seg_offsets_ = std::move(seg_offsets);
-        }
-        search_result->topk_per_nq_prefix_sum_.resize(final_real_topKs_[i].size() + 1);
-        std::partial_sum(final_real_topKs_[i].begin(), final_real_topKs_[i].end(),
-                         search_result->topk_per_nq_prefix_sum_.begin() + 1);
-    }
-
-    // fill target entry
-    for (auto& search_result : search_results_) {
-        auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
-        segment->FillTargetEntry(plan_, *search_result);
-    }
+    FillPrimaryKey();
+    ReduceResultData();
+    RefreshSearchResult();
+    FillEntryData();
 }

 void
 ReduceHelper::Marshal() {
-    // example:
-    //    ----------------------------------
-    //            nq0       nq1       nq2
-    //    sr0   topk00    topk01    topk02
-    //    sr1   topk10    topk11    topk12
-    //    ----------------------------------
-    // then:
-    //    result_slice_offsets[] = {
-    //        0,
-    //        == sr0->topk_per_nq_prefix_sum_[0] + sr1->topk_per_nq_prefix_sum_[0]
-    //        ((topk00) + (topk10)),
-    //        == sr0->topk_per_nq_prefix_sum_[1] + sr1->topk_per_nq_prefix_sum_[1]
-    //        ((topk00 + topk01) + (topk10 + topk11)),
-    //        == sr0->topk_per_nq_prefix_sum_[2] + sr1->topk_per_nq_prefix_sum_[2]
-    //        ((topk00 + topk01 + topk02) + (topk10 + topk11 + topk12)),
-    //        == sr0->topk_per_nq_prefix_sum_[3] + sr1->topk_per_nq_prefix_sum_[3]
-    //    }
-    auto result_slice_offsets = std::vector<int64_t>(nq_slice_offsets_.size(), 0);
-    for (auto search_result : search_results_) {
-        AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
-                   "incorrect topk_per_nq_prefix_sum_ size in search result");
-        for (int i = 1; i < nq_slice_offsets_.size(); i++) {
-            result_slice_offsets[i] += search_result->topk_per_nq_prefix_sum_[nq_slice_offsets_[i]];
-        }
-    }
-    AssertInfo(result_slice_offsets[num_slices_] <= total_nq_ * unify_topK_,
-               "illegal result_slice_offsets when Marshal, result_slice_offsets[last] = " +
-                   std::to_string(result_slice_offsets[num_slices_]) + ", total_nq = " + std::to_string(total_nq_) +
-                   ", unify_topK = " + std::to_string(unify_topK_));
-
     // get search result data blobs of slices
     search_result_data_blobs_ = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
     search_result_data_blobs_->blobs.resize(num_slices_);
-    //#pragma omp parallel for
     for (int i = 0; i < num_slices_; i++) {
-        auto result_count = result_slice_offsets[i + 1] - result_slice_offsets[i];
-        auto proto = GetSearchResultDataSlice(i, result_count);
+        auto proto = GetSearchResultDataSlice(i);
         search_result_data_blobs_->blobs[i] = proto;
     }
 }
@@ -178,102 +91,152 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
         }
     }

-    search_result->distances_ = std::move(distances);
-    search_result->seg_offsets_ = std::move(seg_offsets);
+    search_result->distances_.swap(distances);
+    search_result->seg_offsets_.swap(seg_offsets);
     search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
     std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
 }

 void
-ReduceHelper::ReduceResultData(int slice_index) {
+ReduceHelper::FillPrimaryKey() {
+    std::vector<SearchResult*> valid_search_results;
+    // get primary keys for duplicates removal
+    for (auto search_result : search_results_) {
+        FilterInvalidSearchResult(search_result);
+        if (search_result->get_total_result_count() > 0) {
+            auto segment = static_cast<SegmentInterface*>(search_result->segment_);
+            segment->FillPrimaryKeys(plan_, *search_result);
+            valid_search_results.emplace_back(search_result);
+        }
+    }
+    search_results_.swap(valid_search_results);
+    num_segments_ = search_results_.size();
+}
+
+void
+ReduceHelper::RefreshSearchResult() {
+    for (int i = 0; i < num_segments_; i++) {
+        std::vector<int64_t> real_topks(total_nq_, 0);
+        auto search_result = search_results_[i];
+        if (search_result->result_offsets_.size() != 0) {
+            std::vector<milvus::PkType> primary_keys;
+            std::vector<float> distances;
+            std::vector<int64_t> seg_offsets;
+            for (int j = 0; j < total_nq_; j++) {
+                for (auto offset : final_search_records_[i][j]) {
+                    primary_keys.push_back(search_result->primary_keys_[offset]);
+                    distances.push_back(search_result->distances_[offset]);
+                    seg_offsets.push_back(search_result->seg_offsets_[offset]);
+                    real_topks[j]++;
+                }
+            }
+            search_result->primary_keys_ = std::move(primary_keys);
+            search_result->distances_ = std::move(distances);
+            search_result->seg_offsets_ = std::move(seg_offsets);
+        }
+        std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
+    }
+}
+
+void
+ReduceHelper::FillEntryData() {
+    for (auto search_result : search_results_) {
+        auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
+        segment->FillTargetEntry(plan_, *search_result);
+    }
+}
+
+int64_t
+ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
+    std::vector<SearchResultPair> result_pairs;
+    for (int i = 0; i < num_segments_; i++) {
+        auto search_result = search_results_[i];
+        auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
+        auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
+        if (offset_beg == offset_end) {
+            continue;
+        }
+        auto primary_key = search_result->primary_keys_[offset_beg];
+        auto distance = search_result->distances_[offset_beg];
+        result_pairs.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
+    }
+
+    // nq has no results for all segments
+    if (result_pairs.size() == 0) {
+        return 0;
+    }
+
+    int64_t dup_cnt = 0;
+    std::unordered_set<milvus::PkType> pk_set;
+    int64_t prev_offset = offset;
+    while (offset - prev_offset < topk) {
+        std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
+        auto& pilot = result_pairs[0];
+        auto index = pilot.segment_index_;
+        auto pk = pilot.primary_key_;
+        // no valid search result for this nq, break to next
+        if (pk == INVALID_PK) {
+            break;
+        }
+        // remove duplicates
+        if (pk_set.count(pk) == 0) {
+            pilot.search_result_->result_offsets_.push_back(offset++);
+            final_search_records_[index][qi].push_back(pilot.offset_);
+            pk_set.insert(pk);
+        } else {
+            // skip entity with same primary key
+            dup_cnt++;
+        }
+        pilot.reset();
+    }
+    return dup_cnt;
+}
+
+void
+ReduceHelper::ReduceResultData() {
     for (int i = 0; i < num_segments_; i++) {
         auto search_result = search_results_[i];
         auto result_count = search_result->get_total_result_count();
         AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
-        AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
         AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
+        AssertInfo(search_result->seg_offsets_.size() == result_count, "incorrect search result seg offset size");
+        AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
     }

-    auto nq_offset_begin = nq_slice_offsets_[slice_index];
-    auto nq_offset_end = nq_slice_offsets_[slice_index + 1];
-    AssertInfo(nq_offset_begin < nq_offset_end,
-               "illegal nq offsets when ReduceResultData, nq_offset_begin = " + std::to_string(nq_offset_begin) +
-                   ", nq_offset_end = " + std::to_string(nq_offset_end));
-
-    // `search_records` records the search result offsets
-    std::vector<std::vector<int64_t>> search_records(num_segments_);
     int64_t skip_dup_cnt = 0;
+    for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) {
+        auto nq_begin = slice_nqs_prefix_sum_[slice_index];
+        auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];

-    // reduce search results
-    int64_t result_offset = 0;
-    for (int64_t qi = nq_offset_begin; qi < nq_offset_end; qi++) {
-        std::vector<SearchResultPair> result_pairs;
-        for (int i = 0; i < num_segments_; i++) {
-            auto search_result = search_results_[i];
-            if (search_result->topk_per_nq_prefix_sum_[qi + 1] - search_result->topk_per_nq_prefix_sum_[qi] == 0) {
-                continue;
-            }
-            auto base_offset = search_result->topk_per_nq_prefix_sum_[qi];
-            auto primary_key = search_result->primary_keys_[base_offset];
-            auto distance = search_result->distances_[base_offset];
-            result_pairs.emplace_back(primary_key, distance, search_result, i, base_offset,
-                                      search_result->topk_per_nq_prefix_sum_[qi + 1]);
-        }
-
-        // nq has no results for all segments
-        if (result_pairs.size() == 0) {
-            continue;
-        }
-        std::unordered_set<milvus::PkType> pk_set;
-        int64_t last_nq_result_offset = result_offset;
-        while (result_offset - last_nq_result_offset < slice_topKs_[slice_index]) {
-            std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
-            auto& pilot = result_pairs[0];
-            auto index = pilot.segment_index_;
-            auto curr_pk = pilot.primary_key_;
-            // no valid search result for this nq, break to next
-            if (curr_pk == INVALID_PK) {
-                break;
-            }
-            // remove duplicates
-            if (pk_set.count(curr_pk) == 0) {
-                pilot.search_result_->result_offsets_.push_back(result_offset++);
-                search_records[index].push_back(pilot.offset_);
-                pk_set.insert(curr_pk);
-                final_real_topKs_[index][qi]++;
-            } else {
-                // skip entity with same primary key
-                skip_dup_cnt++;
-            }
-            pilot.reset();
+        // reduce search results
+        int64_t result_offset = 0;
+        for (int64_t qi = nq_begin; qi < nq_end; qi++) {
+            skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], result_offset);
         }
     }
-
     if (skip_dup_cnt > 0) {
         LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
     }
-
-    // append search_records to final_search_records
-    for (int i = 0; i < num_segments_; i++) {
-        for (int j = 0; j < search_records[i].size(); j++) {
-            final_search_records_[i].emplace_back(search_records[i][j]);
-        }
-    }
 }

 std::vector<char>
-ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
-    auto nq_offset_begin = nq_slice_offsets_[slice_index_];
-    auto nq_offset_end = nq_slice_offsets_[slice_index_ + 1];
-    AssertInfo(nq_offset_begin <= nq_offset_end,
-               "illegal offsets when GetSearchResultDataSlice, nq_offset_begin = " + std::to_string(nq_offset_begin) +
-                   ", nq_offset_end = " + std::to_string(nq_offset_end));
+ReduceHelper::GetSearchResultDataSlice(int slice_index) {
+    auto nq_begin = slice_nqs_prefix_sum_[slice_index];
+    auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
+
+    int64_t result_count = 0;
+    for (auto search_result : search_results_) {
+        AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
+                   "incorrect topk_per_nq_prefix_sum_ size in search result");
+        result_count +=
+            search_result->topk_per_nq_prefix_sum_[nq_end] - search_result->topk_per_nq_prefix_sum_[nq_begin];
+    }

     auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
     // set unify_topK and total_nq
-    search_result_data->set_top_k(slice_topKs_[slice_index_]);
-    search_result_data->set_num_queries(nq_offset_end - nq_offset_begin);
-    search_result_data->mutable_topks()->Resize(nq_offset_end - nq_offset_begin, 0);
+    search_result_data->set_top_k(slice_topKs_[slice_index]);
+    search_result_data->set_num_queries(nq_end - nq_begin);
+    search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);

     // `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
     std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
@@ -306,19 +269,20 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
     search_result_data->mutable_scores()->Resize(result_count, 0);

     // fill pks and distances
-    for (auto nq_offset = nq_offset_begin; nq_offset < nq_offset_end; nq_offset++) {
-        int64_t topK_count = 0;
-        for (int i = 0; i < search_results_.size(); i++) {
-            auto search_result = search_results_[i];
+    for (auto qi = nq_begin; qi < nq_end; qi++) {
+        int64_t topk_count = 0;
+        for (auto search_result : search_results_) {
             AssertInfo(search_result != nullptr, "null search result when reorganize");
             if (search_result->result_offsets_.size() == 0) {
                 continue;
             }
-            auto result_start = search_result->topk_per_nq_prefix_sum_[nq_offset];
-            auto result_end = search_result->topk_per_nq_prefix_sum_[nq_offset + 1];
-            for (auto offset = result_start; offset < result_end; offset++) {
-                auto loc = search_result->result_offsets_[offset];
+            auto topk_start = search_result->topk_per_nq_prefix_sum_[qi];
+            auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
+            topk_count += topk_end - topk_start;
+
+            for (auto ki = topk_start; ki < topk_end; ki++) {
+                auto loc = search_result->result_offsets_[ki];
                 AssertInfo(loc < result_count && loc >= 0,
                            "invalid loc when GetSearchResultDataSlice, loc = " + std::to_string(loc) +
                                ", result_count = " + std::to_string(result_count));
@@ -326,12 +290,12 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
                 switch (pk_type) {
                     case milvus::DataType::INT64: {
                         search_result_data->mutable_ids()->mutable_int_id()->mutable_data()->Set(
-                            loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[offset]));
+                            loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[ki]));
                         break;
                     }
                     case milvus::DataType::VARCHAR: {
                         *search_result_data->mutable_ids()->mutable_str_id()->mutable_data()->Mutable(loc) =
-                            std::visit(StrPKVisitor{}, search_result->primary_keys_[offset]);
+                            std::visit(StrPKVisitor{}, search_result->primary_keys_[ki]);
                         break;
                     }
                     default: {
@@ -340,17 +304,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
                 }

                 // set result distances
-                search_result_data->mutable_scores()->Set(loc, search_result->distances_[offset]);
+                search_result_data->mutable_scores()->Set(loc, search_result->distances_[ki]);
                 // set result offset to fill output fields data
-                result_pairs[loc] = std::make_pair(search_result, offset);
+                result_pairs[loc] = std::make_pair(search_result, ki);
             }
-
-            topK_count += search_result->topk_per_nq_prefix_sum_[nq_offset + 1] -
-                          search_result->topk_per_nq_prefix_sum_[nq_offset];
         }

         // update result topKs
-        search_result_data->mutable_topks()->Set(nq_offset - nq_offset_begin, topK_count);
+        search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
     }

     AssertInfo(search_result_data->scores_size() == result_count,
diff --git a/internal/core/src/segcore/Reduce.h b/internal/core/src/segcore/Reduce.h
index fee85e25e4..9147731de7 100644
--- a/internal/core/src/segcore/Reduce.h
+++ b/internal/core/src/segcore/Reduce.h
@@ -61,15 +61,26 @@ class ReduceHelper {
     FilterInvalidSearchResult(SearchResult* search_result);

     void
-    ReduceResultData(int slice_index);
+    FillPrimaryKey();
+
+    void
+    RefreshSearchResult();
+
+    void
+    FillEntryData();
+
+    int64_t
+    ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);
+
+    void
+    ReduceResultData();

     std::vector<char>
-    GetSearchResultDataSlice(int slice_index_, int64_t result_count);
+    GetSearchResultDataSlice(int slice_index_);

  private:
     std::vector<int64_t> slice_topKs_;
     std::vector<int64_t> slice_nqs_;
-    int64_t unify_topK_;
     int64_t total_nq_;
     int64_t num_segments_;
     int64_t num_slices_;
@@ -77,10 +88,10 @@ class ReduceHelper {
     milvus::query::Plan* plan_;
     std::vector<SearchResult*>& search_results_;

-    //
-    std::vector<int32_t> nq_slice_offsets_;
-    std::vector<std::vector<int64_t>> final_search_records_;
-    std::vector<std::vector<int64_t>> final_real_topKs_;
+    std::vector<int64_t> slice_nqs_prefix_sum_;
+
+    // dim0: num_segments_; dim1: total_nq_; dim2: offset
+    std::vector<std::vector<std::vector<int64_t>>> final_search_records_;

     // output
     std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp
index 363423bc3f..55ba7fba86 100644
--- a/internal/core/unittest/test_c_api.cpp
+++ b/internal/core/unittest/test_c_api.cpp
@@ -1347,25 +1347,16 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
         auto suc = search_result_data.ParseFromArray(search_result_data_blobs->blobs[i].data(),
                                                      search_result_data_blobs->blobs[i].size());
         assert(suc);
-
-        assert(suc);
         assert(search_result_data.num_queries() == slice_nqs[i]);
         assert(search_result_data.top_k() == slice_topKs[i]);
-        assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
-        assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
+        assert(search_result_data.scores().size() == search_result_data.topks().at(0) * slice_nqs[i]);
+        assert(search_result_data.ids().int_id().data_size() == search_result_data.topks().at(0) * slice_nqs[i]);

-        // check topKs
+        // check real topks
        assert(search_result_data.topks().size() == slice_nqs[i]);
-        for (int j = 0; j < search_result_data.topks().size(); j++) {
-            assert(search_result_data.topks().at(j) == slice_topKs[i]);
+        for (auto real_topk : search_result_data.topks()) {
+            assert(real_topk <= slice_topKs[i]);
         }
-
-        // assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
-        // assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
-        // assert(search_result_data.top_k() == topK);
-        // assert(search_result_data.num_queries() == req_sizes[i]);
-        // assert(search_result_data.scores().size() == topK * req_sizes[i]);
-        // assert(search_result_data.ids().int_id().data_size() == topK * req_sizes[i]);
     }

     DeleteSearchResultDataBlobs(cSearchResultData);
@@ -1378,6 +1369,8 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
 }

 TEST(CApiTest, ReduceSearchWithExpr) {
+    testReduceSearchWithExpr(2, 1, 1);
+    testReduceSearchWithExpr(2, 10, 10);
     testReduceSearchWithExpr(100, 1, 1);
     testReduceSearchWithExpr(100, 10, 10);
     testReduceSearchWithExpr(10000, 1, 1);
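
Editor's note (illustration only, not part of the patch): the core of the refactor is ReduceSearchResultForOneNQ. For one query it repeatedly takes the best remaining hit across all segments, skips hits whose primary key has already been emitted, and stops once `topk` unique hits are collected or every segment is exhausted. The standalone C++ sketch below mirrors that merge-with-dedup loop under simplified assumptions (primary keys are plain int64_t, each segment's hits are pre-sorted, larger distance is better); `Hit` and `ReduceOneQuery` are illustrative names, not segcore APIs.

// Standalone sketch of per-query reduce with duplicate-primary-key removal.
// Mirrors the control flow of ReduceSearchResultForOneNQ, not the real types.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

struct Hit {
    int64_t pk;
    float distance;  // larger is better in this sketch
};

// Merge one query's hits from all segments, dropping duplicate primary keys,
// until `topk` unique hits are collected or every segment is exhausted.
std::vector<Hit>
ReduceOneQuery(const std::vector<std::vector<Hit>>& per_segment_hits, int64_t topk) {
    std::vector<std::size_t> cursors(per_segment_hits.size(), 0);  // next hit per segment
    std::unordered_set<int64_t> pk_set;
    std::vector<Hit> merged;

    while (static_cast<int64_t>(merged.size()) < topk) {
        int best_seg = -1;
        for (std::size_t i = 0; i < per_segment_hits.size(); i++) {
            if (cursors[i] >= per_segment_hits[i].size()) {
                continue;  // this segment is exhausted
            }
            if (best_seg < 0 || per_segment_hits[i][cursors[i]].distance >
                                    per_segment_hits[best_seg][cursors[best_seg]].distance) {
                best_seg = static_cast<int>(i);
            }
        }
        if (best_seg < 0) {
            break;  // no segment has results left for this query
        }
        const Hit& pilot = per_segment_hits[best_seg][cursors[best_seg]++];
        if (pk_set.insert(pilot.pk).second) {
            merged.push_back(pilot);  // first occurrence of this primary key
        }
        // duplicates are simply skipped, matching the dup_cnt path in the patch
    }
    return merged;
}

int main() {
    // two segments, one query, topk = 3; pk 7 appears in both segments
    std::vector<std::vector<Hit>> hits = {
        {{7, 0.9f}, {3, 0.5f}},
        {{7, 0.8f}, {9, 0.7f}, {1, 0.2f}},
    };
    for (const auto& hit : ReduceOneQuery(hits, 3)) {
        std::cout << "pk=" << hit.pk << " distance=" << hit.distance << "\n";
    }
    // expected order: pk=7 (0.9), pk=9 (0.7), pk=3 (0.5); the duplicate pk=7 is dropped
    return 0;
}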
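The bookkeeping side of the patch is two prefix sums: slice_nqs_prefix_sum_ maps a slice to its query range, and topk_per_nq_prefix_sum_ maps a query to its hit range, so a slice's result_count is just a difference of two prefix-sum entries. A small standalone illustration with made-up values follows; only std::partial_sum is the call actually used by the patch, the local names are hypothetical.

// Standalone illustration of the prefix-sum bookkeeping described above.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    // e.g. three slices covering 2 + 3 + 1 = 6 queries in total
    std::vector<int64_t> slice_nqs = {2, 3, 1};
    std::vector<int64_t> slice_nqs_prefix_sum(slice_nqs.size() + 1, 0);
    std::partial_sum(slice_nqs.begin(), slice_nqs.end(), slice_nqs_prefix_sum.begin() + 1);
    // slice_nqs_prefix_sum = {0, 2, 5, 6}; slice i owns queries [prefix[i], prefix[i+1])

    // per-query result counts after reduce (real topks may be smaller than the requested topk)
    std::vector<int64_t> real_topks = {3, 2, 3, 1, 0, 3};
    std::vector<int64_t> topk_per_nq_prefix_sum(real_topks.size() + 1, 0);
    std::partial_sum(real_topks.begin(), real_topks.end(), topk_per_nq_prefix_sum.begin() + 1);

    for (std::size_t slice = 0; slice + 1 < slice_nqs_prefix_sum.size(); slice++) {
        auto nq_begin = slice_nqs_prefix_sum[slice];
        auto nq_end = slice_nqs_prefix_sum[slice + 1];
        // result_count for the slice = hits of all queries in [nq_begin, nq_end)
        auto result_count = topk_per_nq_prefix_sum[nq_end] - topk_per_nq_prefix_sum[nq_begin];
        std::cout << "slice " << slice << ": queries [" << nq_begin << ", " << nq_end
                  << "), result_count = " << result_count << "\n";
    }
    return 0;
}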