optimize search reduce logic (#7066)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
Cai Yudong 2021-08-12 18:00:11 +08:00 committed by GitHub
parent b3f10ae5cc
commit 6c75301c70
14 changed files with 122 additions and 353 deletions


@ -75,6 +75,7 @@ struct SearchResult {
public:
// TODO(gexi): utilize these field
void* segment_;
std::vector<int64_t> internal_seg_offsets_;
std::vector<int64_t> result_offsets_;
std::vector<std::vector<char>> row_data_;
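
The new segment_ back-pointer is the hinge of this refactor: every SearchResult now remembers which segment produced it, so the C++ reduce step can fill target entries itself instead of requiring the Go layer to carry (result, segment) pairs. A minimal sketch of the pattern, using illustrative stand-in types rather than the real milvus ones:

#include <vector>

struct SearchResult;  // forward declaration

// Stand-in for milvus::segcore::SegmentInterface, illustration only.
struct Segment {
    void FillTargetEntry(SearchResult& result) { /* copies row data into result */ }
};

struct SearchResult {
    void* segment_ = nullptr;  // set by Search(), consumed during reduce
};

// After reduce, each surviving result can reach its own segment:
void FillAll(std::vector<SearchResult*>& results) {
    for (auto* r : results) {
        static_cast<Segment*>(r->segment_)->FillTargetEntry(*r);
    }
}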


@ -546,7 +546,7 @@ GetNumOfQueries(const PlaceholderGroup* group) {
return group->at(0).num_of_queries_;
}
[[maybe_unused]] std::unique_ptr<RetrievePlan>
std::unique_ptr<RetrievePlan>
CreateRetrievePlan(const Schema& schema, proto::segcore::RetrieveRequest&& request) {
auto plan = std::make_unique<RetrievePlan>();
plan->ids_ = std::unique_ptr<proto::schema::IDs>(request.release_ids());


@ -79,6 +79,7 @@ SegmentInternalInterface::Search(const query::Plan* plan,
check_search(plan);
query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group);
auto results = visitor.get_moved_result(*plan->plan_node_);
results.segment_ = (void*)this;
return results;
}


@ -11,10 +11,12 @@
#include <vector>
#include <exceptions/EasyAssert.h>
#include "segcore/reduce_c.h"
#include "query/Plan.h"
#include "segcore/reduce_c.h"
#include "segcore/Reduce.h"
#include "segcore/ReduceStructure.h"
#include "segcore/SegmentInterface.h"
#include "common/Types.h"
#include "pb/milvus.pb.h"
@ -26,7 +28,7 @@ MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, fl
return status.code();
}
struct MarshaledHitsPeerGroup {
struct MarshaledHitsPerGroup {
std::vector<std::string> hits_;
std::vector<int64_t> blob_length_;
};
@ -41,7 +43,7 @@ struct MarshaledHits {
return marshaled_hits_.size();
}
std::vector<MarshaledHitsPeerGroup> marshaled_hits_;
std::vector<MarshaledHitsPerGroup> marshaled_hits_;
};
void
@ -53,16 +55,16 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
void
GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
int64_t query_offset,
bool* is_selected,
int64_t query_idx,
int64_t topk) {
auto num_segments = search_results.size();
AssertInfo(num_segments > 0, "num segment must greater than 0");
std::vector<SearchResultPair> result_pairs;
int64_t query_offset = query_idx * topk;
for (int j = 0; j < num_segments; ++j) {
auto distance = search_results[j]->result_distances_[query_offset];
auto search_result = search_results[j];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
auto distance = search_result->result_distances_[query_offset];
result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
}
int64_t loc_offset = query_offset;
@ -72,24 +74,21 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& result_pair = result_pairs[0];
auto index = result_pair.index_;
is_selected[index] = true;
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
search_records[index].push_back(result_pair.offset_++);
}
}
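
GetResultData is the per-query merge: it now takes query_idx and derives query_offset = query_idx * topk itself, and the is_selected bitmap is gone because winning segments are recorded implicitly through their result_offsets_. Each of the topk rounds re-sorts the per-segment head candidates with std::greater<> (larger distance wins, as for an inner-product metric) and advances the winner's cursor. A self-contained sketch of the same tournament, using a linear max-scan in place of the re-sort and assuming every segment supplies at least topk candidates:

#include <cstdint>
#include <vector>

// Illustrative stand-in for SearchResultPair.
struct Cand {
    float dist;
    int64_t seg_offset;
};

// Merge topk winners for one query across per-segment candidate streams,
// each sorted best-first.
std::vector<Cand> MergeOneQuery(const std::vector<std::vector<Cand>>& streams, int64_t topk) {
    std::vector<size_t> head(streams.size(), 0);  // per-segment cursor
    std::vector<Cand> winners;
    for (int64_t k = 0; k < topk; ++k) {
        size_t best = 0;
        for (size_t s = 1; s < streams.size(); ++s) {
            if (streams[s][head[s]].dist > streams[best][head[best]].dist) {
                best = s;  // larger distance wins, mirroring std::greater<>
            }
        }
        winners.push_back(streams[best][head[best]++]);  // advance the winner
    }
    return winners;
}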
void
ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
bool* is_selected) {
ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector<SearchResult*>& search_results) {
auto num_segments = search_results.size();
AssertInfo(num_segments > 0, "num segment must greater than 0");
for (int i = 0; i < num_segments; i++) {
if (is_selected[i] == false) {
continue;
}
auto search_result = search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
if (search_result->result_offsets_.size() == 0) {
continue;
}
std::vector<float> result_distances;
std::vector<int64_t> internal_seg_offsets;
@ -108,8 +107,9 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
}
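
With is_selected removed, ResetSearchResult infers participation from the results themselves: a segment that won no slot in any query has an empty result_offsets_ and is skipped outright. A sketch of just that guard, under the same illustrative types:

#include <cstdint>
#include <vector>

struct SearchResult {
    std::vector<int64_t> result_offsets_;  // filled only for winning segments
};

void CompactResults(std::vector<SearchResult*>& search_results) {
    for (auto* sr : search_results) {
        if (sr == nullptr || sr->result_offsets_.empty()) {
            continue;  // replaces the old is_selected[i] == false check
        }
        // ... rewrite distances / segment offsets / row data in winner order ...
    }
}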
CStatus
ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool* is_selected) {
ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_results, int64_t num_segments) {
try {
auto plan = (milvus::query::Plan*)c_plan;
std::vector<SearchResult*> search_results;
for (int i = 0; i < num_segments; ++i) {
search_results.push_back((SearchResult*)c_search_results[i]);
@ -118,12 +118,17 @@ ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool*
auto num_queries = search_results[0]->num_queries_;
std::vector<std::vector<int64_t>> search_records(num_segments);
int64_t query_offset = 0;
for (int j = 0; j < num_queries; ++j) {
GetResultData(search_records, search_results, query_offset, is_selected, topk);
query_offset += topk;
for (int i = 0; i < num_queries; ++i) {
GetResultData(search_records, search_results, i, topk);
}
ResetSearchResult(search_records, search_results, is_selected);
ResetSearchResult(search_records, search_results);
for (int i = 0; i < num_segments; ++i) {
auto search_result = search_results[i];
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
segment->FillTargetEntry(plan, *search_result);
}
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
@ -137,43 +142,29 @@ ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool*
}
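
Caller-wise, the fused entry point collapses reduce plus the per-segment FillTargetEntry calls into a single C API round-trip. A hedged sketch of the new call, assuming plan, res1, and res2 came from the usual CreateSearchPlan / Search calls as in the tests further down:

#include <cassert>
#include <vector>
#include "segcore/reduce_c.h"  // ReduceSearchResultsAndFillData, CStatus

void ReduceAndFill(CSearchPlan plan, CSearchResult res1, CSearchResult res2) {
    std::vector<CSearchResult> results{res1, res2};
    auto status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
    assert(status.error_code == Success);
    // No separate FillTargetEntry(segment, plan, res) calls remain.
}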
CStatus
ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups,
CSearchResult* c_search_results,
bool* is_selected,
int64_t num_segments,
CSearchPlan c_plan) {
ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments) {
try {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto topk = GetTopK(c_plan);
std::vector<int64_t> num_queries_peer_group(num_groups);
int64_t total_num_queries = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group[i] = num_queries;
total_num_queries += num_queries;
}
auto marshaledHits = std::make_unique<MarshaledHits>(1);
auto sr = (SearchResult*)c_search_results[0];
auto topk = sr->topk_;
auto num_queries = sr->num_queries_;
std::vector<float> result_distances(total_num_queries * topk);
std::vector<int64_t> result_ids(total_num_queries * topk);
std::vector<std::vector<char>> row_datas(total_num_queries * topk);
std::vector<char> temp_ids;
std::vector<float> result_distances(num_queries * topk);
std::vector<std::vector<char>> row_datas(num_queries * topk);
std::vector<int64_t> counts(num_segments);
for (int i = 0; i < num_segments; i++) {
if (is_selected[i] == false) {
continue;
}
auto search_result = (SearchResult*)c_search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
auto size = search_result->result_offsets_.size();
if (size == 0) {
continue;
}
#pragma omp parallel for
for (int j = 0; j < size; j++) {
auto loc = search_result->result_offsets_[j];
result_distances[loc] = search_result->result_distances_[j];
row_datas[loc] = search_result->row_data_[j];
memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t));
}
counts[i] = size;
}
@ -182,100 +173,35 @@ ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits,
for (int i = 0; i < num_segments; i++) {
total_count += counts[i];
}
AssertInfo(total_count == total_num_queries * topk,
"the reduces result's size less than total_num_queries*topk");
AssertInfo(total_count == num_queries * topk, "the reduces result's size less than total_num_queries*topk");
int64_t last_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
hits_peer_group.hits_.resize(num_queries_peer_group[i]);
hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
MarshaledHitsPerGroup& hits_per_group = (*marshaledHits).marshaled_hits_[0];
hits_per_group.hits_.resize(num_queries);
hits_per_group.blob_length_.resize(num_queries);
std::vector<milvus::proto::milvus::Hits> hits(num_queries);
#pragma omp parallel for
for (int m = 0; m < num_queries_peer_group[i]; m++) {
for (int n = 0; n < topk; n++) {
int64_t result_offset = last_offset + m * topk + n;
hits[m].add_ids(result_ids[result_offset]);
hits[m].add_scores(result_distances[result_offset]);
auto& row_data = row_datas[result_offset];
hits[m].add_row_data(row_data.data(), row_data.size());
}
for (int m = 0; m < num_queries; m++) {
for (int n = 0; n < topk; n++) {
int64_t result_offset = m * topk + n;
hits[m].add_scores(result_distances[result_offset]);
auto& row_data = row_datas[result_offset];
hits[m].add_row_data(row_data.data(), row_data.size());
hits[m].add_ids(*(int64_t*)row_data.data());
}
last_offset = last_offset + num_queries_peer_group[i] * topk;
}
#pragma omp parallel for
for (int j = 0; j < num_queries_peer_group[i]; j++) {
auto blob = hits[j].SerializeAsString();
hits_peer_group.hits_[j] = blob;
hits_peer_group.blob_length_[j] = blob.size();
}
for (int j = 0; j < num_queries; j++) {
auto blob = hits[j].SerializeAsString();
hits_per_group.hits_[j] = blob;
hits_per_group.blob_length_[j] = blob.size();
}
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
auto marshled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshled_res;
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
*c_marshaled_hits = nullptr;
return status;
}
}
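
With a single placeholder group, the marshaled layout reduces to a row-major (query m, rank n) grid of num_queries * topk slots, and the primary key is recovered from the first eight bytes of each row blob, exactly as the memcpy and add_ids lines above do. A standalone sketch of that addressing, with the leading-int64 row layout stated as an assumption:

#include <cstdint>
#include <cstring>
#include <vector>

// Assumes each row blob begins with the int64 primary key, matching the
// row_data_ layout the reduce path produces.
int64_t IdAt(const std::vector<std::vector<char>>& rows, int64_t topk, int64_t m, int64_t n) {
    const auto& row = rows[m * topk + n];  // row-major: query m, rank n
    int64_t id = 0;
    std::memcpy(&id, row.data(), sizeof(int64_t));
    return id;
}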
CStatus
ReorganizeSingleSearchResult(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups,
CSearchResult c_search_result,
CSearchPlan c_plan) {
try {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto search_result = (SearchResult*)c_search_result;
auto topk = GetTopK(c_plan);
std::vector<int64_t> num_queries_peer_group;
int64_t total_num_queries = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group.push_back(num_queries);
}
int64_t last_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
hits_peer_group.hits_.resize(num_queries_peer_group[i]);
hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
#pragma omp parallel for
for (int m = 0; m < num_queries_peer_group[i]; m++) {
for (int n = 0; n < topk; n++) {
int64_t result_offset = last_offset + m * topk + n;
hits[m].add_scores(search_result->result_distances_[result_offset]);
auto& row_data = search_result->row_data_[result_offset];
hits[m].add_row_data(row_data.data(), row_data.size());
int64_t result_id;
memcpy(&result_id, row_data.data(), sizeof(int64_t));
hits[m].add_ids(result_id);
}
}
last_offset = last_offset + num_queries_peer_group[i] * topk;
#pragma omp parallel for
for (int j = 0; j < num_queries_peer_group[i]; j++) {
auto blob = hits[j].SerializeAsString();
hits_peer_group.hits_[j] = blob;
hits_peer_group.blob_length_[j] = blob.size();
}
}
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
auto marshled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshled_res;
auto marshaled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshaled_res;
return status;
} catch (std::exception& e) {
auto status = CStatus();
@ -318,14 +244,14 @@ GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits) {
}
int64_t
GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
auto& hits = marshaled_hits->marshaled_hits_[group_index].hits_;
return hits.size();
}
void
GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
auto& blob_lens = marshaled_hits->marshaled_hits_[group_index].blob_length_;
for (int i = 0; i < blob_lens.size(); i++) {


@ -15,6 +15,7 @@ extern "C" {
#include <stdbool.h>
#include <stdint.h>
#include "segcore/plan_c.h"
#include "segcore/segment_c.h"
#include "common/type_c.h"
@ -27,23 +28,10 @@ int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
CStatus
ReduceSearchResults(CSearchResult* search_results, int64_t num_segments, bool* is_selected);
ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* search_results, int64_t num_segments);
CStatus
ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups,
CSearchResult* c_search_results,
bool* is_selected,
int64_t num_segments,
CSearchPlan c_plan);
CStatus
ReorganizeSingleSearchResult(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups,
CSearchResult c_search_result,
CSearchPlan c_plan);
ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments);
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
@ -52,10 +40,10 @@ void
GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits);
int64_t
GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
void
GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
#ifdef __cplusplus
}
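
Taken together, the slimmed header leaves a three-step consumer sequence: reduce-and-fill, reorganize, extract the hits blob. A hedged end-to-end sketch (error handling elided; plan and results assumed to come from the usual plan and search calls):

#include <vector>
#include "segcore/reduce_c.h"

void MarshalHits(CSearchPlan plan, std::vector<CSearchResult>& results) {
    ReduceSearchResultsAndFillData(plan, results.data(), results.size());
    CMarshaledHits hits = nullptr;
    ReorganizeSearchResults(&hits, results.data(), results.size());
    std::vector<char> blob(GetHitsBlobSize(hits));
    GetHitsBlob(hits, blob.data());   // serialized milvus.proto.milvus.Hits
    DeleteMarshaledHits(hits);        // release the C-side hits after copying
}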


@ -88,20 +88,6 @@ Search(CSegmentInterface c_segment,
}
}
CStatus
FillTargetEntry(CSegmentInterface c_segment, CSearchPlan c_plan, CSearchResult c_result) {
auto segment = (milvus::segcore::SegmentInterface*)c_segment;
auto plan = (milvus::query::Plan*)c_plan;
auto result = (milvus::SearchResult*)c_result;
try {
segment->FillTargetEntry(plan, *result);
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
int64_t
GetMemoryUsageInBytes(CSegmentInterface c_segment) {
auto segment = (milvus::segcore::SegmentInterface*)c_segment;


@ -46,9 +46,6 @@ Search(CSegmentInterface c_segment,
CProtoResult
GetEntityByIds(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp);
CStatus
FillTargetEntry(CSegmentInterface c_segment, CSearchPlan c_plan, CSearchResult result);
int64_t
GetMemoryUsageInBytes(CSegmentInterface c_segment);


@ -33,7 +33,6 @@ namespace chrono = std::chrono;
using namespace milvus;
using namespace milvus::segcore;
// using namespace milvus::proto;
using namespace milvus::knowhere;
namespace {
@ -203,7 +202,7 @@ TEST(CApiTest, InsertTest) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -237,7 +236,7 @@ TEST(CApiTest, SearchTest) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -294,7 +293,7 @@ TEST(CApiTest, SearchTestWithExpr) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -350,7 +349,7 @@ TEST(CApiTest, GetMemoryUsageInBytesTest) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -392,7 +391,7 @@ TEST(CApiTest, GetRowCountTest) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -454,7 +453,7 @@ TEST(CApiTest, Reduce) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -503,14 +502,10 @@ TEST(CApiTest, Reduce) {
results.push_back(res1);
results.push_back(res2);
bool is_selected[2] = {false, false};
status = ReduceSearchResults(results.data(), 2, is_selected);
status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
assert(status.error_code == Success);
FillTargetEntry(segment, plan, res1);
FillTargetEntry(segment, plan, res2);
void* reorganize_search_result = nullptr;
status = ReorganizeSearchResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(),
is_selected, 2, plan);
status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
assert(status.error_code == Success);
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
@ -518,12 +513,12 @@ TEST(CApiTest, Reduce) {
hits_blob.resize(hits_blob_size);
GetHitsBlob(reorganize_search_result, hits_blob.data());
assert(hits_blob.data() != nullptr);
auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0);
assert(num_queries_group == 10);
std::vector<int64_t> hit_size_peer_query;
hit_size_peer_query.resize(num_queries_group);
GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data());
assert(hit_size_peer_query[0] > 0);
auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
assert(num_queries_group == num_queries);
std::vector<int64_t> hit_size_per_query;
hit_size_per_query.resize(num_queries_group);
GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
assert(hit_size_per_query[0] > 0);
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
@ -540,7 +535,7 @@ TEST(CApiTest, ReduceSearchWithExpr) {
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
@ -584,14 +579,10 @@ TEST(CApiTest, ReduceSearchWithExpr) {
results.push_back(res1);
results.push_back(res2);
bool is_selected[2] = {false, false};
status = ReduceSearchResults(results.data(), 2, is_selected);
status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
assert(status.error_code == Success);
FillTargetEntry(segment, plan, res1);
FillTargetEntry(segment, plan, res2);
void* reorganize_search_result = nullptr;
status = ReorganizeSearchResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(),
is_selected, 2, plan);
status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
assert(status.error_code == Success);
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
@ -599,12 +590,12 @@ TEST(CApiTest, ReduceSearchWithExpr) {
hits_blob.resize(hits_blob_size);
GetHitsBlob(reorganize_search_result, hits_blob.data());
assert(hits_blob.data() != nullptr);
auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0);
assert(num_queries_group == 10);
std::vector<int64_t> hit_size_peer_query;
hit_size_peer_query.resize(num_queries_group);
GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data());
assert(hit_size_peer_query[0] > 0);
auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
assert(num_queries_group == num_queries);
std::vector<int64_t> hit_size_per_query;
hit_size_per_query.resize(num_queries_group);
GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
assert(hit_size_per_query[0] > 0);
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
@ -1921,10 +1912,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
std::vector<CSearchResult> results;
results.push_back(c_search_result_on_bigIndex);
bool is_selected[1] = {false};
status = ReduceSearchResults(results.data(), 1, is_selected);
status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
assert(status.error_code == Success);
FillTargetEntry(segment, plan, c_search_result_on_bigIndex);
auto search_result_on_bigIndex = (*(SearchResult*)c_search_result_on_bigIndex);
for (int i = 0; i < num_queries; ++i) {
@ -2073,10 +2062,8 @@ vector_anns: <
std::vector<CSearchResult> results;
results.push_back(c_search_result_on_bigIndex);
bool is_selected[1] = {false};
status = ReduceSearchResults(results.data(), 1, is_selected);
status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
assert(status.error_code == Success);
FillTargetEntry(segment, plan, c_search_result_on_bigIndex);
auto search_result_on_bigIndex = (*(SearchResult*)c_search_result_on_bigIndex);
for (int i = 0; i < num_queries; ++i) {


@ -843,7 +843,6 @@ func (q *queryCollection) search(msg queryMsg) error {
}
searchResults := make([]*SearchResult, 0)
matchedSegments := make([]*Segment, 0)
sealedSegmentSearched := make([]UniqueID, 0)
// historical search
@ -853,7 +852,6 @@ func (q *queryCollection) search(msg queryMsg) error {
return err1
}
searchResults = append(searchResults, hisSearchResults...)
matchedSegments = append(matchedSegments, hisSegmentResults...)
for _, seg := range hisSegmentResults {
sealedSegmentSearched = append(sealedSegmentSearched, seg.segmentID)
}
@ -863,14 +861,12 @@ func (q *queryCollection) search(msg queryMsg) error {
var err2 error
for _, channel := range collection.getVChannels() {
var strSearchResults []*SearchResult
var strSegmentResults []*Segment
strSearchResults, strSegmentResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
strSearchResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
if err2 != nil {
log.Warn(err2.Error())
return err2
}
searchResults = append(searchResults, strSearchResults...)
matchedSegments = append(matchedSegments, strSegmentResults...)
}
tr.Record("streaming search done")
@ -939,38 +935,19 @@ func (q *queryCollection) search(msg queryMsg) error {
}
}
inReduced := make([]bool, len(searchResults))
numSegment := int64(len(searchResults))
var marshaledHits *MarshaledHits = nil
if numSegment == 1 {
inReduced[0] = true
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
if err != nil {
return err
}
marshaledHits, err = reorganizeSingleSearchResult(plan, searchRequests, searchResults[0])
sp.LogFields(oplog.String("statistical time", "reorganizeSingleSearchResult end"))
if err != nil {
return err
}
} else {
err = reduceSearchResults(searchResults, numSegment, inReduced)
sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
if err != nil {
return err
}
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
if err != nil {
return err
}
marshaledHits, err = reorganizeSearchResults(plan, searchRequests, searchResults, numSegment, inReduced)
sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
if err != nil {
return err
}
err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
if err != nil {
return err
}
marshaledHits, err = reorganizeSearchResults(searchResults, numSegment)
sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
if err != nil {
return err
}
hitsBlob, err := marshaledHits.getHitsBlob()
sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
if err != nil {


@ -23,10 +23,7 @@ import "C"
import (
"errors"
"strconv"
"sync"
"unsafe"
"github.com/milvus-io/milvus/internal/log"
)
type SearchResult struct {
@ -37,17 +34,15 @@ type MarshaledHits struct {
cMarshaledHits C.CMarshaledHits
}
func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error {
func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error {
cSearchResults := make([]C.CSearchResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cSearchResult)
}
cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
cNumSegments := C.long(numSegments)
cInReduced := (*C.bool)(&inReduced[0])
status := C.ReduceSearchResults(cSearchResultPtr, cNumSegments, cInReduced)
status := C.ReduceSearchResultsAndFillData(plan.cSearchPlan, cSearchResultPtr, cNumSegments)
errorCode := status.error_code
if errorCode != 0 {
@ -58,33 +53,7 @@ func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inRed
return nil
}
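
For reference while reading the cgo wrappers here, these are the C declarations from the updated reduce_c.h that the calls now bind to:

CStatus ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* search_results, int64_t num_segments);

CStatus ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments);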
func fillTargetEntry(plan *SearchPlan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
wg := &sync.WaitGroup{}
//fmt.Println(inReduced)
for i := range inReduced {
if inReduced[i] {
wg.Add(1)
go func(i int) {
err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
if err != nil {
log.Warn(err.Error())
}
wg.Done()
}(i)
}
}
wg.Wait()
return nil
}
func reorganizeSearchResults(plan *SearchPlan, searchRequests []*searchRequest, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range searchRequests {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroup = (C.long)(len(searchRequests))
func reorganizeSearchResults(searchResults []*SearchResult, numSegments int64) (*MarshaledHits, error) {
cSearchResults := make([]C.CSearchResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cSearchResult)
@ -92,32 +61,9 @@ func reorganizeSearchResults(plan *SearchPlan, searchRequests []*searchRequest,
cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
var cNumSegments = C.long(numSegments)
var cInReduced = (*C.bool)(&inReduced[0])
var cMarshaledHits C.CMarshaledHits
status := C.ReorganizeSearchResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cSearchPlan)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("reorganizeSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
}
func reorganizeSingleSearchResult(plan *SearchPlan, placeholderGroups []*searchRequest, searchResult *SearchResult) (*MarshaledHits, error) {
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeholderGroups {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroup = (C.long)(len(placeholderGroups))
cSearchResult := searchResult.cSearchResult
var cMarshaledHits C.CMarshaledHits
status := C.ReorganizeSingleSearchResult(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResult, plan.cSearchPlan)
status := C.ReorganizeSearchResults(&cMarshaledHits, cSearchResultPtr, cNumSegments)
errorCode := status.error_code
if errorCode != 0 {
@ -143,10 +89,10 @@ func (mh *MarshaledHits) getHitsBlob() ([]byte, error) {
func (mh *MarshaledHits) hitBlobSizeInGroup(groupOffset int64) ([]int64, error) {
cGroupOffset := (C.long)(groupOffset)
numQueries := C.GetNumQueriesPeerGroup(mh.cMarshaledHits, cGroupOffset)
numQueries := C.GetNumQueriesPerGroup(mh.cMarshaledHits, cGroupOffset)
result := make([]int64, int64(numQueries))
cResult := (*C.long)(&result[0])
C.GetHitSizePeerQueries(mh.cMarshaledHits, cGroupOffset, cResult)
C.GetHitSizePerQueries(mh.cMarshaledHits, cGroupOffset, cResult)
return result, nil
}


@ -71,19 +71,14 @@ func TestReduce_AllFunc(t *testing.T) {
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
matchedSegment := make([]*Segment, 0)
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
assert.Nil(t, err)
searchResults = append(searchResults, searchResult)
matchedSegment = append(matchedSegment, segment)
testReduce := make([]bool, len(searchResults))
err = reduceSearchResults(searchResults, 1, testReduce)
assert.Nil(t, err)
err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce)
err = reduceSearchResultsAndFillData(plan, searchResults, 1)
assert.Nil(t, err)
marshaledHits, err := reorganizeSearchResults(plan, placeholderGroups, searchResults, 1, testReduce)
marshaledHits, err := reorganizeSearchResults(searchResults, 1)
assert.NotNil(t, marshaledHits)
assert.Nil(t, err)


@ -316,26 +316,6 @@ func (s *Segment) getEntityByIds(plan *RetrievePlan) (*segcorepb.RetrieveResults
return result, nil
}
func (s *Segment) fillTargetEntry(plan *SearchPlan, result *SearchResult) error {
s.segPtrMu.RLock()
defer s.segPtrMu.RUnlock()
if s.segmentPtr == nil {
return errors.New("null seg core pointer")
}
log.Debug("segment fill target entry, ", zap.Int64("segment ID = ", s.segmentID))
var status = C.FillTargetEntry(s.segmentPtr, plan.cSearchPlan, result.cSearchResult)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return nil
}
//-------------------------------------------------------------------------------------- index info interface
func (s *Segment) setIndexName(fieldID int64, name string) error {
s.paramMutex.Lock()


@ -435,22 +435,15 @@ func TestSegment_segmentSearch(t *testing.T) {
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
matchedSegments := make([]*Segment, 0)
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{travelTimestamp})
assert.Nil(t, err)
searchResults = append(searchResults, searchResult)
matchedSegments = append(matchedSegments, segment)
///////////////////////////////////
inReduced := make([]bool, len(searchResults))
numSegment := int64(len(searchResults))
err2 := reduceSearchResults(searchResults, numSegment, inReduced)
assert.NoError(t, err2)
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
assert.NoError(t, err)
marshaledHits, err := reorganizeSearchResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
marshaledHits, err := reorganizeSearchResults(searchResults, numSegment)
assert.NoError(t, err)
hitsBlob, err := marshaledHits.getHitsBlob()
assert.NoError(t, err)


@ -61,15 +61,10 @@ func (s *streaming) close() {
s.replica.freeAll()
}
func (s *streaming) search(searchReqs []*searchRequest,
collID UniqueID,
partIDs []UniqueID,
vChannel Channel,
plan *SearchPlan,
searchTs Timestamp) ([]*SearchResult, []*Segment, error) {
func (s *streaming) search(searchReqs []*searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel,
plan *SearchPlan, searchTs Timestamp) ([]*SearchResult, error) {
searchResults := make([]*SearchResult, 0)
segmentResults := make([]*Segment, 0)
// get streaming partition ids
var searchPartIDs []UniqueID
@ -77,10 +72,10 @@ func (s *streaming) search(searchReqs []*searchRequest,
strPartIDs, err := s.replica.getPartitionIDs(collID)
if len(strPartIDs) == 0 {
// no partitions in collection, do empty search
return nil, nil, nil
return nil, nil
}
if err != nil {
return searchResults, segmentResults, err
return searchResults, err
}
log.Debug("no partition specified, search all partitions",
zap.Any("collectionID", collID),
@ -104,22 +99,20 @@ func (s *streaming) search(searchReqs []*searchRequest,
col, err := s.replica.getCollectionByID(collID)
if err != nil {
return nil, nil, err
return nil, err
}
// all partitions have been released
if len(searchPartIDs) == 0 && col.getLoadType() == loadTypePartition {
return nil, nil, errors.New("partitions have been released , collectionID = " +
fmt.Sprintln(collID) +
"target partitionIDs = " +
fmt.Sprintln(partIDs))
err = errors.New("partitions have been released , collectionID = " + fmt.Sprintln(collID) + "target partitionIDs = " + fmt.Sprintln(partIDs))
return nil, err
}
if len(searchPartIDs) == 0 && col.getLoadType() == loadTypeCollection {
if err = col.checkReleasedPartitions(partIDs); err != nil {
return nil, nil, err
return nil, err
}
return nil, nil, nil
return nil, nil
}
log.Debug("doing search in streaming",
@ -144,13 +137,13 @@ func (s *streaming) search(searchReqs []*searchRequest,
)
if err != nil {
log.Warn(err.Error())
return searchResults, segmentResults, err
return searchResults, err
}
for _, segID := range segIDs {
seg, err := s.replica.getSegmentByID(segID)
if err != nil {
log.Warn(err.Error())
return searchResults, segmentResults, err
return searchResults, err
}
// TSafe less than searchTs means this vChannel is not available
@ -175,12 +168,11 @@ func (s *streaming) search(searchReqs []*searchRequest,
searchResult, err := seg.search(plan, searchReqs, []Timestamp{searchTs})
if err != nil {
return searchResults, segmentResults, err
return searchResults, err
}
searchResults = append(searchResults, searchResult)
segmentResults = append(segmentResults, seg)
}
}
return searchResults, segmentResults, nil
return searchResults, nil
}