From 6c75301c704598d99729fea6b124dd4e98d613a2 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Thu, 12 Aug 2021 18:00:11 +0800 Subject: [PATCH] optimize search reduce logic (#7066) Signed-off-by: yudong.cai --- internal/core/src/common/Types.h | 1 + internal/core/src/query/Plan.cpp | 2 +- .../core/src/segcore/SegmentInterface.cpp | 1 + internal/core/src/segcore/reduce_c.cpp | 184 ++++++------------ internal/core/src/segcore/reduce_c.h | 22 +-- internal/core/src/segcore/segment_c.cpp | 14 -- internal/core/src/segcore/segment_c.h | 3 - internal/core/unittest/test_c_api.cpp | 63 +++--- internal/querynode/query_collection.go | 45 ++--- internal/querynode/reduce.go | 66 +------ internal/querynode/reduce_test.go | 9 +- internal/querynode/segment.go | 20 -- internal/querynode/segment_test.go | 11 +- internal/querynode/streaming.go | 34 ++-- 14 files changed, 122 insertions(+), 353 deletions(-) diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 24c4b9e805..935a868728 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -75,6 +75,7 @@ struct SearchResult { public: // TODO(gexi): utilize these field + void* segment_; std::vector internal_seg_offsets_; std::vector result_offsets_; std::vector> row_data_; diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index ea6da58817..69e62b83d2 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -546,7 +546,7 @@ GetNumOfQueries(const PlaceholderGroup* group) { return group->at(0).num_of_queries_; } -[[maybe_unused]] std::unique_ptr +std::unique_ptr CreateRetrievePlan(const Schema& schema, proto::segcore::RetrieveRequest&& request) { auto plan = std::make_unique(); plan->ids_ = std::unique_ptr(request.release_ids()); diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 2dbb654a70..e282b02ae4 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -79,6 +79,7 @@ SegmentInternalInterface::Search(const query::Plan* plan, check_search(plan); query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group); auto results = visitor.get_moved_result(*plan->plan_node_); + results.segment_ = (void*)this; return results; } diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index 46a710caba..877d90569c 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -11,10 +11,12 @@ #include #include -#include "segcore/reduce_c.h" +#include "query/Plan.h" +#include "segcore/reduce_c.h" #include "segcore/Reduce.h" #include "segcore/ReduceStructure.h" +#include "segcore/SegmentInterface.h" #include "common/Types.h" #include "pb/milvus.pb.h" @@ -26,7 +28,7 @@ MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, fl return status.code(); } -struct MarshaledHitsPeerGroup { +struct MarshaledHitsPerGroup { std::vector hits_; std::vector blob_length_; }; @@ -41,7 +43,7 @@ struct MarshaledHits { return marshaled_hits_.size(); } - std::vector marshaled_hits_; + std::vector marshaled_hits_; }; void @@ -53,16 +55,16 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) { void GetResultData(std::vector>& search_records, std::vector& search_results, - int64_t query_offset, - bool* is_selected, + int64_t query_idx, int64_t topk) { auto num_segments = search_results.size(); AssertInfo(num_segments > 0, "num segment must greater than 0"); std::vector result_pairs; + int64_t query_offset = query_idx * topk; for (int j = 0; j < num_segments; ++j) { - auto distance = search_results[j]->result_distances_[query_offset]; auto search_result = search_results[j]; AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + auto distance = search_result->result_distances_[query_offset]; result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j)); } int64_t loc_offset = query_offset; @@ -72,24 +74,21 @@ GetResultData(std::vector>& search_records, std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); auto& result_pair = result_pairs[0]; auto index = result_pair.index_; - is_selected[index] = true; result_pair.search_result_->result_offsets_.push_back(loc_offset++); search_records[index].push_back(result_pair.offset_++); } } void -ResetSearchResult(std::vector>& search_records, - std::vector& search_results, - bool* is_selected) { +ResetSearchResult(std::vector>& search_records, std::vector& search_results) { auto num_segments = search_results.size(); AssertInfo(num_segments > 0, "num segment must greater than 0"); for (int i = 0; i < num_segments; i++) { - if (is_selected[i] == false) { - continue; - } auto search_result = search_results[i]; AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + if (search_result->result_offsets_.size() == 0) { + continue; + } std::vector result_distances; std::vector internal_seg_offsets; @@ -108,8 +107,9 @@ ResetSearchResult(std::vector>& search_records, } CStatus -ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool* is_selected) { +ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_results, int64_t num_segments) { try { + auto plan = (milvus::query::Plan*)c_plan; std::vector search_results; for (int i = 0; i < num_segments; ++i) { search_results.push_back((SearchResult*)c_search_results[i]); @@ -118,12 +118,17 @@ ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool* auto num_queries = search_results[0]->num_queries_; std::vector> search_records(num_segments); - int64_t query_offset = 0; - for (int j = 0; j < num_queries; ++j) { - GetResultData(search_records, search_results, query_offset, is_selected, topk); - query_offset += topk; + for (int i = 0; i < num_queries; ++i) { + GetResultData(search_records, search_results, i, topk); } - ResetSearchResult(search_records, search_results, is_selected); + ResetSearchResult(search_records, search_results); + + for (int i = 0; i < num_segments; ++i) { + auto search_result = search_results[i]; + auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_); + segment->FillTargetEntry(plan, *search_result); + } + auto status = CStatus(); status.error_code = Success; status.error_msg = ""; @@ -137,43 +142,29 @@ ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool* } CStatus -ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, - CPlaceholderGroup* c_placeholder_groups, - int64_t num_groups, - CSearchResult* c_search_results, - bool* is_selected, - int64_t num_segments, - CSearchPlan c_plan) { +ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments) { try { - auto marshaledHits = std::make_unique(num_groups); - auto topk = GetTopK(c_plan); - std::vector num_queries_peer_group(num_groups); - int64_t total_num_queries = 0; - for (int i = 0; i < num_groups; i++) { - auto num_queries = GetNumOfQueries(c_placeholder_groups[i]); - num_queries_peer_group[i] = num_queries; - total_num_queries += num_queries; - } + auto marshaledHits = std::make_unique(1); + auto sr = (SearchResult*)c_search_results[0]; + auto topk = sr->topk_; + auto num_queries = sr->num_queries_; - std::vector result_distances(total_num_queries * topk); - std::vector result_ids(total_num_queries * topk); - std::vector> row_datas(total_num_queries * topk); - std::vector temp_ids; + std::vector result_distances(num_queries * topk); + std::vector> row_datas(num_queries * topk); std::vector counts(num_segments); for (int i = 0; i < num_segments; i++) { - if (is_selected[i] == false) { - continue; - } auto search_result = (SearchResult*)c_search_results[i]; AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); auto size = search_result->result_offsets_.size(); + if (size == 0) { + continue; + } #pragma omp parallel for for (int j = 0; j < size; j++) { auto loc = search_result->result_offsets_[j]; result_distances[loc] = search_result->result_distances_[j]; row_datas[loc] = search_result->row_data_[j]; - memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t)); } counts[i] = size; } @@ -182,100 +173,35 @@ ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, for (int i = 0; i < num_segments; i++) { total_count += counts[i]; } - AssertInfo(total_count == total_num_queries * topk, - "the reduces result's size less than total_num_queries*topk"); + AssertInfo(total_count == num_queries * topk, "the reduces result's size less than total_num_queries*topk"); - int64_t last_offset = 0; - for (int i = 0; i < num_groups; i++) { - MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i]; - hits_peer_group.hits_.resize(num_queries_peer_group[i]); - hits_peer_group.blob_length_.resize(num_queries_peer_group[i]); - std::vector hits(num_queries_peer_group[i]); + MarshaledHitsPerGroup& hits_per_group = (*marshaledHits).marshaled_hits_[0]; + hits_per_group.hits_.resize(num_queries); + hits_per_group.blob_length_.resize(num_queries); + std::vector hits(num_queries); #pragma omp parallel for - for (int m = 0; m < num_queries_peer_group[i]; m++) { - for (int n = 0; n < topk; n++) { - int64_t result_offset = last_offset + m * topk + n; - hits[m].add_ids(result_ids[result_offset]); - hits[m].add_scores(result_distances[result_offset]); - auto& row_data = row_datas[result_offset]; - hits[m].add_row_data(row_data.data(), row_data.size()); - } + for (int m = 0; m < num_queries; m++) { + for (int n = 0; n < topk; n++) { + int64_t result_offset = m * topk + n; + hits[m].add_scores(result_distances[result_offset]); + auto& row_data = row_datas[result_offset]; + hits[m].add_row_data(row_data.data(), row_data.size()); + hits[m].add_ids(*(int64_t*)row_data.data()); } - last_offset = last_offset + num_queries_peer_group[i] * topk; + } #pragma omp parallel for - for (int j = 0; j < num_queries_peer_group[i]; j++) { - auto blob = hits[j].SerializeAsString(); - hits_peer_group.hits_[j] = blob; - hits_peer_group.blob_length_[j] = blob.size(); - } + for (int j = 0; j < num_queries; j++) { + auto blob = hits[j].SerializeAsString(); + hits_per_group.hits_[j] = blob; + hits_per_group.blob_length_[j] = blob.size(); } auto status = CStatus(); status.error_code = Success; status.error_msg = ""; - auto marshled_res = (CMarshaledHits)marshaledHits.release(); - *c_marshaled_hits = marshled_res; - return status; - } catch (std::exception& e) { - auto status = CStatus(); - status.error_code = UnexpectedError; - status.error_msg = strdup(e.what()); - *c_marshaled_hits = nullptr; - return status; - } -} - -CStatus -ReorganizeSingleSearchResult(CMarshaledHits* c_marshaled_hits, - CPlaceholderGroup* c_placeholder_groups, - int64_t num_groups, - CSearchResult c_search_result, - CSearchPlan c_plan) { - try { - auto marshaledHits = std::make_unique(num_groups); - auto search_result = (SearchResult*)c_search_result; - auto topk = GetTopK(c_plan); - std::vector num_queries_peer_group; - int64_t total_num_queries = 0; - for (int i = 0; i < num_groups; i++) { - auto num_queries = GetNumOfQueries(c_placeholder_groups[i]); - num_queries_peer_group.push_back(num_queries); - } - - int64_t last_offset = 0; - for (int i = 0; i < num_groups; i++) { - MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i]; - hits_peer_group.hits_.resize(num_queries_peer_group[i]); - hits_peer_group.blob_length_.resize(num_queries_peer_group[i]); - std::vector hits(num_queries_peer_group[i]); -#pragma omp parallel for - for (int m = 0; m < num_queries_peer_group[i]; m++) { - for (int n = 0; n < topk; n++) { - int64_t result_offset = last_offset + m * topk + n; - hits[m].add_scores(search_result->result_distances_[result_offset]); - auto& row_data = search_result->row_data_[result_offset]; - hits[m].add_row_data(row_data.data(), row_data.size()); - int64_t result_id; - memcpy(&result_id, row_data.data(), sizeof(int64_t)); - hits[m].add_ids(result_id); - } - } - last_offset = last_offset + num_queries_peer_group[i] * topk; - -#pragma omp parallel for - for (int j = 0; j < num_queries_peer_group[i]; j++) { - auto blob = hits[j].SerializeAsString(); - hits_peer_group.hits_[j] = blob; - hits_peer_group.blob_length_[j] = blob.size(); - } - } - - auto status = CStatus(); - status.error_code = Success; - status.error_msg = ""; - auto marshled_res = (CMarshaledHits)marshaledHits.release(); - *c_marshaled_hits = marshled_res; + auto marshaled_res = (CMarshaledHits)marshaledHits.release(); + *c_marshaled_hits = marshaled_res; return status; } catch (std::exception& e) { auto status = CStatus(); @@ -318,14 +244,14 @@ GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits) { } int64_t -GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) { +GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) { auto marshaled_hits = (MarshaledHits*)c_marshaled_hits; auto& hits = marshaled_hits->marshaled_hits_[group_index].hits_; return hits.size(); } void -GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) { +GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) { auto marshaled_hits = (MarshaledHits*)c_marshaled_hits; auto& blob_lens = marshaled_hits->marshaled_hits_[group_index].blob_length_; for (int i = 0; i < blob_lens.size(); i++) { diff --git a/internal/core/src/segcore/reduce_c.h b/internal/core/src/segcore/reduce_c.h index 9f68af3368..3acd2a01cc 100644 --- a/internal/core/src/segcore/reduce_c.h +++ b/internal/core/src/segcore/reduce_c.h @@ -15,6 +15,7 @@ extern "C" { #include #include +#include "segcore/plan_c.h" #include "segcore/segment_c.h" #include "common/type_c.h" @@ -27,23 +28,10 @@ int MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids); CStatus -ReduceSearchResults(CSearchResult* search_results, int64_t num_segments, bool* is_selected); +ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* search_results, int64_t num_segments); CStatus -ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, - CPlaceholderGroup* c_placeholder_groups, - int64_t num_groups, - CSearchResult* c_search_results, - bool* is_selected, - int64_t num_segments, - CSearchPlan c_plan); - -CStatus -ReorganizeSingleSearchResult(CMarshaledHits* c_marshaled_hits, - CPlaceholderGroup* c_placeholder_groups, - int64_t num_groups, - CSearchResult c_search_result, - CSearchPlan c_plan); +ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments); int64_t GetHitsBlobSize(CMarshaledHits c_marshaled_hits); @@ -52,10 +40,10 @@ void GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits); int64_t -GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index); +GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index); void -GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query); +GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query); #ifdef __cplusplus } diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index fe6f38078a..bf9ce12d69 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -88,20 +88,6 @@ Search(CSegmentInterface c_segment, } } -CStatus -FillTargetEntry(CSegmentInterface c_segment, CSearchPlan c_plan, CSearchResult c_result) { - auto segment = (milvus::segcore::SegmentInterface*)c_segment; - auto plan = (milvus::query::Plan*)c_plan; - auto result = (milvus::SearchResult*)c_result; - - try { - segment->FillTargetEntry(plan, *result); - return milvus::SuccessCStatus(); - } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); - } -} - int64_t GetMemoryUsageInBytes(CSegmentInterface c_segment) { auto segment = (milvus::segcore::SegmentInterface*)c_segment; diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index f24e1876b1..740db5b5c3 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -46,9 +46,6 @@ Search(CSegmentInterface c_segment, CProtoResult GetEntityByIds(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp); -CStatus -FillTargetEntry(CSegmentInterface c_segment, CSearchPlan c_plan, CSearchResult result); - int64_t GetMemoryUsageInBytes(CSegmentInterface c_segment); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 76133d95ef..12d75c7992 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -33,7 +33,6 @@ namespace chrono = std::chrono; using namespace milvus; using namespace milvus::segcore; -// using namespace milvus::proto; using namespace milvus::knowhere; namespace { @@ -203,7 +202,7 @@ TEST(CApiTest, InsertTest) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -237,7 +236,7 @@ TEST(CApiTest, SearchTest) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -294,7 +293,7 @@ TEST(CApiTest, SearchTestWithExpr) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -350,7 +349,7 @@ TEST(CApiTest, GetMemoryUsageInBytesTest) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -392,7 +391,7 @@ TEST(CApiTest, GetRowCountTest) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -454,7 +453,7 @@ TEST(CApiTest, Reduce) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -503,14 +502,10 @@ TEST(CApiTest, Reduce) { results.push_back(res1); results.push_back(res2); - bool is_selected[2] = {false, false}; - status = ReduceSearchResults(results.data(), 2, is_selected); + status = ReduceSearchResultsAndFillData(plan, results.data(), results.size()); assert(status.error_code == Success); - FillTargetEntry(segment, plan, res1); - FillTargetEntry(segment, plan, res2); void* reorganize_search_result = nullptr; - status = ReorganizeSearchResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), - is_selected, 2, plan); + status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size()); assert(status.error_code == Success); auto hits_blob_size = GetHitsBlobSize(reorganize_search_result); assert(hits_blob_size > 0); @@ -518,12 +513,12 @@ TEST(CApiTest, Reduce) { hits_blob.resize(hits_blob_size); GetHitsBlob(reorganize_search_result, hits_blob.data()); assert(hits_blob.data() != nullptr); - auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0); - assert(num_queries_group == 10); - std::vector hit_size_peer_query; - hit_size_peer_query.resize(num_queries_group); - GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data()); - assert(hit_size_peer_query[0] > 0); + auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0); + assert(num_queries_group == num_queries); + std::vector hit_size_per_query; + hit_size_per_query.resize(num_queries_group); + GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data()); + assert(hit_size_per_query[0] > 0); DeleteSearchPlan(plan); DeletePlaceholderGroup(placeholderGroup); @@ -540,7 +535,7 @@ TEST(CApiTest, ReduceSearchWithExpr) { int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); int64_t offset; PreInsert(segment, N, &offset); @@ -584,14 +579,10 @@ TEST(CApiTest, ReduceSearchWithExpr) { results.push_back(res1); results.push_back(res2); - bool is_selected[2] = {false, false}; - status = ReduceSearchResults(results.data(), 2, is_selected); + status = ReduceSearchResultsAndFillData(plan, results.data(), results.size()); assert(status.error_code == Success); - FillTargetEntry(segment, plan, res1); - FillTargetEntry(segment, plan, res2); void* reorganize_search_result = nullptr; - status = ReorganizeSearchResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), - is_selected, 2, plan); + status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size()); assert(status.error_code == Success); auto hits_blob_size = GetHitsBlobSize(reorganize_search_result); assert(hits_blob_size > 0); @@ -599,12 +590,12 @@ TEST(CApiTest, ReduceSearchWithExpr) { hits_blob.resize(hits_blob_size); GetHitsBlob(reorganize_search_result, hits_blob.data()); assert(hits_blob.data() != nullptr); - auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0); - assert(num_queries_group == 10); - std::vector hit_size_peer_query; - hit_size_peer_query.resize(num_queries_group); - GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data()); - assert(hit_size_peer_query[0] > 0); + auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0); + assert(num_queries_group == num_queries); + std::vector hit_size_per_query; + hit_size_per_query.resize(num_queries_group); + GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data()); + assert(hit_size_per_query[0] > 0); DeleteSearchPlan(plan); DeletePlaceholderGroup(placeholderGroup); @@ -1921,10 +1912,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) { std::vector results; results.push_back(c_search_result_on_bigIndex); - bool is_selected[1] = {false}; - status = ReduceSearchResults(results.data(), 1, is_selected); + status = ReduceSearchResultsAndFillData(plan, results.data(), results.size()); assert(status.error_code == Success); - FillTargetEntry(segment, plan, c_search_result_on_bigIndex); auto search_result_on_bigIndex = (*(SearchResult*)c_search_result_on_bigIndex); for (int i = 0; i < num_queries; ++i) { @@ -2073,10 +2062,8 @@ vector_anns: < std::vector results; results.push_back(c_search_result_on_bigIndex); - bool is_selected[1] = {false}; - status = ReduceSearchResults(results.data(), 1, is_selected); + status = ReduceSearchResultsAndFillData(plan, results.data(), results.size()); assert(status.error_code == Success); - FillTargetEntry(segment, plan, c_search_result_on_bigIndex); auto search_result_on_bigIndex = (*(SearchResult*)c_search_result_on_bigIndex); for (int i = 0; i < num_queries; ++i) { diff --git a/internal/querynode/query_collection.go b/internal/querynode/query_collection.go index 83ad1b3e51..027652375e 100644 --- a/internal/querynode/query_collection.go +++ b/internal/querynode/query_collection.go @@ -843,7 +843,6 @@ func (q *queryCollection) search(msg queryMsg) error { } searchResults := make([]*SearchResult, 0) - matchedSegments := make([]*Segment, 0) sealedSegmentSearched := make([]UniqueID, 0) // historical search @@ -853,7 +852,6 @@ func (q *queryCollection) search(msg queryMsg) error { return err1 } searchResults = append(searchResults, hisSearchResults...) - matchedSegments = append(matchedSegments, hisSegmentResults...) for _, seg := range hisSegmentResults { sealedSegmentSearched = append(sealedSegmentSearched, seg.segmentID) } @@ -863,14 +861,12 @@ func (q *queryCollection) search(msg queryMsg) error { var err2 error for _, channel := range collection.getVChannels() { var strSearchResults []*SearchResult - var strSegmentResults []*Segment - strSearchResults, strSegmentResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, travelTimestamp) + strSearchResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, travelTimestamp) if err2 != nil { log.Warn(err2.Error()) return err2 } searchResults = append(searchResults, strSearchResults...) - matchedSegments = append(matchedSegments, strSegmentResults...) } tr.Record("streaming search done") @@ -939,38 +935,19 @@ func (q *queryCollection) search(msg queryMsg) error { } } - inReduced := make([]bool, len(searchResults)) numSegment := int64(len(searchResults)) var marshaledHits *MarshaledHits = nil - if numSegment == 1 { - inReduced[0] = true - err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) - sp.LogFields(oplog.String("statistical time", "fillTargetEntry end")) - if err != nil { - return err - } - marshaledHits, err = reorganizeSingleSearchResult(plan, searchRequests, searchResults[0]) - sp.LogFields(oplog.String("statistical time", "reorganizeSingleSearchResult end")) - if err != nil { - return err - } - } else { - err = reduceSearchResults(searchResults, numSegment, inReduced) - sp.LogFields(oplog.String("statistical time", "reduceSearchResults end")) - if err != nil { - return err - } - err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) - sp.LogFields(oplog.String("statistical time", "fillTargetEntry end")) - if err != nil { - return err - } - marshaledHits, err = reorganizeSearchResults(plan, searchRequests, searchResults, numSegment, inReduced) - sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end")) - if err != nil { - return err - } + err = reduceSearchResultsAndFillData(plan, searchResults, numSegment) + sp.LogFields(oplog.String("statistical time", "reduceSearchResults end")) + if err != nil { + return err } + marshaledHits, err = reorganizeSearchResults(searchResults, numSegment) + sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end")) + if err != nil { + return err + } + hitsBlob, err := marshaledHits.getHitsBlob() sp.LogFields(oplog.String("statistical time", "getHitsBlob end")) if err != nil { diff --git a/internal/querynode/reduce.go b/internal/querynode/reduce.go index a8cdfcba08..5dbba1dc26 100644 --- a/internal/querynode/reduce.go +++ b/internal/querynode/reduce.go @@ -23,10 +23,7 @@ import "C" import ( "errors" "strconv" - "sync" "unsafe" - - "github.com/milvus-io/milvus/internal/log" ) type SearchResult struct { @@ -37,17 +34,15 @@ type MarshaledHits struct { cMarshaledHits C.CMarshaledHits } -func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error { +func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error { cSearchResults := make([]C.CSearchResult, 0) for _, res := range searchResults { cSearchResults = append(cSearchResults, res.cSearchResult) } cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0]) cNumSegments := C.long(numSegments) - cInReduced := (*C.bool)(&inReduced[0]) - - status := C.ReduceSearchResults(cSearchResultPtr, cNumSegments, cInReduced) + status := C.ReduceSearchResultsAndFillData(plan.cSearchPlan, cSearchResultPtr, cNumSegments) errorCode := status.error_code if errorCode != 0 { @@ -58,33 +53,7 @@ func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inRed return nil } -func fillTargetEntry(plan *SearchPlan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error { - wg := &sync.WaitGroup{} - //fmt.Println(inReduced) - for i := range inReduced { - if inReduced[i] { - wg.Add(1) - go func(i int) { - err := matchedSegments[i].fillTargetEntry(plan, searchResults[i]) - if err != nil { - log.Warn(err.Error()) - } - wg.Done() - }(i) - } - } - wg.Wait() - return nil -} - -func reorganizeSearchResults(plan *SearchPlan, searchRequests []*searchRequest, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) { - cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) - for _, pg := range searchRequests { - cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) - } - var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) - var cNumGroup = (C.long)(len(searchRequests)) - +func reorganizeSearchResults(searchResults []*SearchResult, numSegments int64) (*MarshaledHits, error) { cSearchResults := make([]C.CSearchResult, 0) for _, res := range searchResults { cSearchResults = append(cSearchResults, res.cSearchResult) @@ -92,32 +61,9 @@ func reorganizeSearchResults(plan *SearchPlan, searchRequests []*searchRequest, cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0]) var cNumSegments = C.long(numSegments) - var cInReduced = (*C.bool)(&inReduced[0]) var cMarshaledHits C.CMarshaledHits - status := C.ReorganizeSearchResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cSearchPlan) - errorCode := status.error_code - - if errorCode != 0 { - errorMsg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return nil, errors.New("reorganizeSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) - } - return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil -} - -func reorganizeSingleSearchResult(plan *SearchPlan, placeholderGroups []*searchRequest, searchResult *SearchResult) (*MarshaledHits, error) { - cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) - for _, pg := range placeholderGroups { - cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) - } - var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) - var cNumGroup = (C.long)(len(placeholderGroups)) - - cSearchResult := searchResult.cSearchResult - var cMarshaledHits C.CMarshaledHits - - status := C.ReorganizeSingleSearchResult(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResult, plan.cSearchPlan) + status := C.ReorganizeSearchResults(&cMarshaledHits, cSearchResultPtr, cNumSegments) errorCode := status.error_code if errorCode != 0 { @@ -143,10 +89,10 @@ func (mh *MarshaledHits) getHitsBlob() ([]byte, error) { func (mh *MarshaledHits) hitBlobSizeInGroup(groupOffset int64) ([]int64, error) { cGroupOffset := (C.long)(groupOffset) - numQueries := C.GetNumQueriesPeerGroup(mh.cMarshaledHits, cGroupOffset) + numQueries := C.GetNumQueriesPerGroup(mh.cMarshaledHits, cGroupOffset) result := make([]int64, int64(numQueries)) cResult := (*C.long)(&result[0]) - C.GetHitSizePeerQueries(mh.cMarshaledHits, cGroupOffset, cResult) + C.GetHitSizePerQueries(mh.cMarshaledHits, cGroupOffset, cResult) return result, nil } diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index dee528d7b1..1175e21a88 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -71,19 +71,14 @@ func TestReduce_AllFunc(t *testing.T) { placeholderGroups = append(placeholderGroups, holder) searchResults := make([]*SearchResult, 0) - matchedSegment := make([]*Segment, 0) searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0}) assert.Nil(t, err) searchResults = append(searchResults, searchResult) - matchedSegment = append(matchedSegment, segment) - testReduce := make([]bool, len(searchResults)) - err = reduceSearchResults(searchResults, 1, testReduce) - assert.Nil(t, err) - err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce) + err = reduceSearchResultsAndFillData(plan, searchResults, 1) assert.Nil(t, err) - marshaledHits, err := reorganizeSearchResults(plan, placeholderGroups, searchResults, 1, testReduce) + marshaledHits, err := reorganizeSearchResults(searchResults, 1) assert.NotNil(t, marshaledHits) assert.Nil(t, err) diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index b647c2710d..018dbe8c4b 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -316,26 +316,6 @@ func (s *Segment) getEntityByIds(plan *RetrievePlan) (*segcorepb.RetrieveResults return result, nil } -func (s *Segment) fillTargetEntry(plan *SearchPlan, result *SearchResult) error { - s.segPtrMu.RLock() - defer s.segPtrMu.RUnlock() - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } - - log.Debug("segment fill target entry, ", zap.Int64("segment ID = ", s.segmentID)) - var status = C.FillTargetEntry(s.segmentPtr, plan.cSearchPlan, result.cSearchResult) - errorCode := status.error_code - - if errorCode != 0 { - errorMsg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) - } - - return nil -} - //-------------------------------------------------------------------------------------- index info interface func (s *Segment) setIndexName(fieldID int64, name string) error { s.paramMutex.Lock() diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index 9f3957cfa8..7bf1d4c7b0 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -435,22 +435,15 @@ func TestSegment_segmentSearch(t *testing.T) { placeholderGroups = append(placeholderGroups, holder) searchResults := make([]*SearchResult, 0) - matchedSegments := make([]*Segment, 0) - searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{travelTimestamp}) assert.Nil(t, err) - searchResults = append(searchResults, searchResult) - matchedSegments = append(matchedSegments, segment) /////////////////////////////////// - inReduced := make([]bool, len(searchResults)) numSegment := int64(len(searchResults)) - err2 := reduceSearchResults(searchResults, numSegment, inReduced) - assert.NoError(t, err2) - err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) + err = reduceSearchResultsAndFillData(plan, searchResults, numSegment) assert.NoError(t, err) - marshaledHits, err := reorganizeSearchResults(plan, placeholderGroups, searchResults, numSegment, inReduced) + marshaledHits, err := reorganizeSearchResults(searchResults, numSegment) assert.NoError(t, err) hitsBlob, err := marshaledHits.getHitsBlob() assert.NoError(t, err) diff --git a/internal/querynode/streaming.go b/internal/querynode/streaming.go index c96955b9ce..e1c2dd0c94 100644 --- a/internal/querynode/streaming.go +++ b/internal/querynode/streaming.go @@ -61,15 +61,10 @@ func (s *streaming) close() { s.replica.freeAll() } -func (s *streaming) search(searchReqs []*searchRequest, - collID UniqueID, - partIDs []UniqueID, - vChannel Channel, - plan *SearchPlan, - searchTs Timestamp) ([]*SearchResult, []*Segment, error) { +func (s *streaming) search(searchReqs []*searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel, + plan *SearchPlan, searchTs Timestamp) ([]*SearchResult, error) { searchResults := make([]*SearchResult, 0) - segmentResults := make([]*Segment, 0) // get streaming partition ids var searchPartIDs []UniqueID @@ -77,10 +72,10 @@ func (s *streaming) search(searchReqs []*searchRequest, strPartIDs, err := s.replica.getPartitionIDs(collID) if len(strPartIDs) == 0 { // no partitions in collection, do empty search - return nil, nil, nil + return nil, nil } if err != nil { - return searchResults, segmentResults, err + return searchResults, err } log.Debug("no partition specified, search all partitions", zap.Any("collectionID", collID), @@ -104,22 +99,20 @@ func (s *streaming) search(searchReqs []*searchRequest, col, err := s.replica.getCollectionByID(collID) if err != nil { - return nil, nil, err + return nil, err } // all partitions have been released if len(searchPartIDs) == 0 && col.getLoadType() == loadTypePartition { - return nil, nil, errors.New("partitions have been released , collectionID = " + - fmt.Sprintln(collID) + - "target partitionIDs = " + - fmt.Sprintln(partIDs)) + err = errors.New("partitions have been released , collectionID = " + fmt.Sprintln(collID) + "target partitionIDs = " + fmt.Sprintln(partIDs)) + return nil, err } if len(searchPartIDs) == 0 && col.getLoadType() == loadTypeCollection { if err = col.checkReleasedPartitions(partIDs); err != nil { - return nil, nil, err + return nil, err } - return nil, nil, nil + return nil, nil } log.Debug("doing search in streaming", @@ -144,13 +137,13 @@ func (s *streaming) search(searchReqs []*searchRequest, ) if err != nil { log.Warn(err.Error()) - return searchResults, segmentResults, err + return searchResults, err } for _, segID := range segIDs { seg, err := s.replica.getSegmentByID(segID) if err != nil { log.Warn(err.Error()) - return searchResults, segmentResults, err + return searchResults, err } // TSafe less than searchTs means this vChannel is not available @@ -175,12 +168,11 @@ func (s *streaming) search(searchReqs []*searchRequest, searchResult, err := seg.search(plan, searchReqs, []Timestamp{searchTs}) if err != nil { - return searchResults, segmentResults, err + return searchResults, err } searchResults = append(searchResults, searchResult) - segmentResults = append(segmentResults, seg) } } - return searchResults, segmentResults, nil + return searchResults, nil }