Pass PlaceholderGroup pointer to prevent memory copy in SegCore (#17389)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/17399/head
bigsheeper 2022-06-06 21:34:05 +08:00 committed by GitHub
parent 98e95275fe
commit f38637c227
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 34 additions and 33 deletions

View File

@ -33,12 +33,13 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
public:
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
Timestamp timestamp,
const PlaceholderGroup& placeholder_group)
const PlaceholderGroup* placeholder_group)
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
}
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, Timestamp timestamp)
: segment_(segment), timestamp_(timestamp) {
placeholder_group_ = nullptr;
}
SearchResult
@ -72,7 +73,7 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
private:
const segcore::SegmentInterface& segment_;
Timestamp timestamp_;
PlaceholderGroup placeholder_group_;
const PlaceholderGroup* placeholder_group_;
SearchResultOpt search_result_opt_;
RetrieveResultOpt retrieve_result_opt_;

View File

@ -74,7 +74,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
auto segment = dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
AssertInfo(segment, "support SegmentSmallIndex Only");
SearchResult search_result;
auto& ph = placeholder_group_.at(0);
auto& ph = placeholder_group_->at(0);
auto src_data = ph.get_blob<EmbeddedType<VectorType>>();
auto num_queries = ph.num_of_queries_;

View File

@ -53,7 +53,7 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult&
std::unique_ptr<SearchResult>
SegmentInternalInterface::Search(const query::Plan* plan,
const query::PlaceholderGroup& placeholder_group,
const query::PlaceholderGroup* placeholder_group,
Timestamp timestamp) const {
std::shared_lock lck(mutex_);
check_search(plan);

View File

@ -47,7 +47,7 @@ class SegmentInterface {
FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0;
virtual std::unique_ptr<SearchResult>
Search(const query::Plan* Plan, const query::PlaceholderGroup& placeholder_group, Timestamp timestamp) const = 0;
Search(const query::Plan* Plan, const query::PlaceholderGroup* placeholder_group, Timestamp timestamp) const = 0;
virtual std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan, Timestamp timestamp) const = 0;
@ -95,7 +95,7 @@ class SegmentInternalInterface : public SegmentInterface {
std::unique_ptr<SearchResult>
Search(const query::Plan* Plan,
const query::PlaceholderGroup& placeholder_group,
const query::PlaceholderGroup* placeholder_group,
Timestamp timestamp) const override;
void

View File

@ -71,7 +71,7 @@ Search(CSegmentInterface c_segment,
auto segment = (milvus::segcore::SegmentInterface*)c_segment;
auto plan = (milvus::query::Plan*)c_plan;
auto phg_ptr = reinterpret_cast<const milvus::query::PlaceholderGroup*>(c_placeholder_group);
auto search_result = segment->Search(plan, *phg_ptr, timestamp);
auto search_result = segment->Search(plan, phg_ptr, timestamp);
if (!milvus::segcore::PositivelyRelated(plan->plan_node_->search_info_.metric_type_)) {
for (auto& dis : search_result->distances_) {
dis *= -1;

View File

@ -81,7 +81,7 @@ Search_SmallIndex(benchmark::State& state) {
Timestamp time = 10000000;
for (auto _ : state) {
auto qr = segment->Search(plan.get(), *ph_group, time);
auto qr = segment->Search(plan.get(), ph_group.get(), time);
}
}
@ -113,7 +113,7 @@ Search_Sealed(benchmark::State& state) {
}
Timestamp time = 10000000;
for (auto _ : state) {
auto qr = segment->Search(plan.get(), *ph_group, time);
auto qr = segment->Search(plan.get(), ph_group.get(), time);
}
}

View File

@ -180,7 +180,7 @@ TEST(Query, ExecWithPredicateLoader) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
int topk = 5;
Json json = SearchResultToJson(*sr);
@ -258,7 +258,7 @@ TEST(Query, ExecWithPredicateSmallN) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
int topk = 5;
Json json = SearchResultToJson(*sr);
@ -312,7 +312,7 @@ TEST(Query, ExecWithPredicate) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
int topk = 5;
Json json = SearchResultToJson(*sr);
@ -389,7 +389,7 @@ TEST(Query, ExecTerm) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = SearchResultToJson(*sr);
@ -431,7 +431,7 @@ TEST(Query, ExecEmpty) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
std::cout << SearchResultToJson(*sr);
for (auto i : sr->seg_offsets_) {
@ -482,7 +482,7 @@ TEST(Query, ExecWithoutPredicateFlat) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = SearchResultToJson(*sr);
@ -528,7 +528,7 @@ TEST(Query, ExecWithoutPredicate) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = SearchResultToJson(*sr);
@ -597,7 +597,7 @@ TEST(Indexing, InnerProduct) {
auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, col.data());
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp ts = N * 2;
auto sr = segment->Search(plan.get(), *ph_group, ts);
auto sr = segment->Search(plan.get(), ph_group.get(), ts);
std::cout << SearchResultToJson(*sr).dump(2);
}
@ -706,7 +706,7 @@ TEST(Query, FillSegment) {
plan->target_entries_.clear();
plan->target_entries_.push_back(schema->get_field_id(FieldName("fakevec")));
plan->target_entries_.push_back(schema->get_field_id(FieldName("the_value")));
auto result = segment->Search(plan.get(), *ph, ts);
auto result = segment->Search(plan.get(), ph.get(), ts);
// std::cout << SearchResultToJson(result).dump(2);
result->result_offsets_.resize(topk * num_queries);
segment->FillTargetEntry(plan.get(), *result);
@ -799,7 +799,7 @@ TEST(Query, ExecWithPredicateBinary) {
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
int topk = 5;
Json json = SearchResultToJson(*sr);

View File

@ -74,7 +74,7 @@ TEST(Sealed, without_predicate) {
Timestamp time = 1000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
auto pre_result = SearchResultToJson(*sr);
auto indexing = std::make_shared<knowhere::IVF>();
@ -115,7 +115,7 @@ TEST(Sealed, without_predicate) {
sealed_segment->DropFieldData(fake_id);
sealed_segment->LoadIndex(load_info);
sr = sealed_segment->Search(plan.get(), *ph_group, time);
sr = sealed_segment->Search(plan.get(), ph_group.get(), time);
auto post_result = SearchResultToJson(*sr);
std::cout << "ref_result" << std::endl;
@ -180,7 +180,7 @@ TEST(Sealed, with_predicate) {
Timestamp time = 10000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{{knowhere::meta::DIM, dim},
@ -211,7 +211,7 @@ TEST(Sealed, with_predicate) {
sealed_segment->DropFieldData(fake_id);
sealed_segment->LoadIndex(load_info);
sr = sealed_segment->Search(plan.get(), *ph_group, time);
sr = sealed_segment->Search(plan.get(), ph_group.get(), time);
for (int i = 0; i < num_queries; ++i) {
auto offset = i * topK;
@ -274,14 +274,14 @@ TEST(Sealed, LoadFieldData) {
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
SealedLoadFieldData(dataset, *segment);
segment->DropFieldData(nothing_id);
segment->Search(plan.get(), *ph_group, time);
segment->Search(plan.get(), ph_group.get(), time);
segment->DropFieldData(fakevec_id);
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
LoadIndexInfo vec_info;
vec_info.field_id = fakevec_id.get();
@ -304,18 +304,18 @@ TEST(Sealed, LoadFieldData) {
ASSERT_EQ(chunk_span3[i], ref3[i]);
}
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
auto json = SearchResultToJson(*sr);
std::cout << json.dump(1);
segment->DropIndex(fakevec_id);
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
segment->LoadIndex(vec_info);
auto sr2 = segment->Search(plan.get(), *ph_group, time);
auto sr2 = segment->Search(plan.get(), ph_group.get(), time);
auto json2 = SearchResultToJson(*sr);
ASSERT_EQ(json.dump(-2), json2.dump(-2));
segment->DropFieldData(double_id);
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
#ifdef __linux__
auto std_json = Json::parse(R"(
[
@ -441,7 +441,7 @@ TEST(Sealed, LoadScalarIndex) {
nothing_index.index = std::move(GenScalarIndexing<int32_t>(N, nothing_data.data()));
segment->LoadIndex(nothing_index);
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
auto json = SearchResultToJson(*sr);
std::cout << json.dump(1);
}
@ -497,7 +497,7 @@ TEST(Sealed, Delete) {
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
SealedLoadFieldData(dataset, *segment);

View File

@ -540,7 +540,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
};
auto sub_result = FloatSearchBruteForce(search_dataset, vec_col.data(), N, nullptr);
auto sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
segment->FillPrimaryKeys(plan.get(), *sr);
segment->FillTargetEntry(plan.get(), *sr);
ASSERT_EQ(sr->pk_type_, DataType::VARCHAR);