Update Search return type (#12578)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/12595/head
Cai Yudong 2021-12-02 11:45:32 +08:00 committed by GitHub
parent a37a1062e1
commit cbb01051f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 50 deletions

View File

@ -94,15 +94,16 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult&
}
}
SearchResult
std::unique_ptr<SearchResult>
SegmentInternalInterface::Search(const query::Plan* plan,
const query::PlaceholderGroup& placeholder_group,
Timestamp timestamp) const {
std::shared_lock lck(mutex_);
check_search(plan);
query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group);
auto results = visitor.get_moved_result(*plan->plan_node_);
results.segment_ = (void*)this;
auto results = std::make_unique<SearchResult>();
*results = visitor.get_moved_result(*plan->plan_node_);
results->segment_ = (void*)this;
return results;
}

View File

@ -41,7 +41,7 @@ class SegmentInterface {
virtual void
FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0;
virtual SearchResult
virtual std::unique_ptr<SearchResult>
Search(const query::Plan* Plan, const query::PlaceholderGroup& placeholder_group, Timestamp timestamp) const = 0;
virtual std::unique_ptr<proto::segcore::RetrieveResults>
@ -84,7 +84,7 @@ class SegmentInternalInterface : public SegmentInterface {
return *ptr;
}
SearchResult
std::unique_ptr<SearchResult>
Search(const query::Plan* Plan,
const query::PlaceholderGroup& placeholder_group,
Timestamp timestamp) const override;

View File

@ -62,12 +62,11 @@ Search(CSegmentInterface c_segment,
CPlaceholderGroup c_placeholder_group,
uint64_t timestamp,
CSearchResult* result) {
auto search_result = std::make_unique<milvus::SearchResult>();
try {
auto segment = (milvus::segcore::SegmentInterface*)c_segment;
auto plan = (milvus::query::Plan*)c_plan;
auto phg_ptr = reinterpret_cast<const milvus::query::PlaceholderGroup*>(c_placeholder_group);
*search_result = segment->Search(plan, *phg_ptr, timestamp);
auto search_result = segment->Search(plan, *phg_ptr, timestamp);
if (!milvus::segcore::PositivelyRelated(plan->plan_node_->search_info_.metric_type_)) {
for (auto& dis : search_result->distances_) {
dis *= -1;

View File

@ -185,7 +185,7 @@ TEST(Query, ExecWithPredicateLoader) {
auto sr = segment->Search(plan.get(), *ph_group, time);
int topk = 5;
Json json = SearchResultToJson(sr);
Json json = SearchResultToJson(*sr);
auto ref = json::parse(R"(
[
[
@ -248,7 +248,7 @@ TEST(Query, ExecWithPredicateSmallN) {
auto sr = segment->Search(plan.get(), *ph_group, time);
int topk = 5;
Json json = SearchResultToJson(sr);
Json json = SearchResultToJson(*sr);
std::cout << json.dump(2);
}
@ -300,7 +300,7 @@ TEST(Query, ExecWithPredicate) {
auto sr = segment->Search(plan.get(), *ph_group, time);
int topk = 5;
Json json = SearchResultToJson(sr);
Json json = SearchResultToJson(*sr);
auto ref = json::parse(R"(
[
[
@ -357,15 +357,14 @@ TEST(Query, ExecTerm) {
auto num_queries = 3;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
SearchResult sr;
Timestamp time = 1000000;
sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), *ph_group, time);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = SearchResultToJson(sr);
ASSERT_EQ(sr.num_queries_, num_queries);
ASSERT_EQ(sr.topk_, topk);
auto json = SearchResultToJson(*sr);
ASSERT_EQ(sr->num_queries_, num_queries);
ASSERT_EQ(sr->topk_, topk);
// for(auto x: )
}
@ -403,13 +402,13 @@ TEST(Query, ExecEmpty) {
Timestamp time = 1000000;
auto sr = segment->Search(plan.get(), *ph_group, time);
std::cout << SearchResultToJson(sr);
std::cout << SearchResultToJson(*sr);
for (auto i : sr.ids_) {
for (auto i : sr->ids_) {
ASSERT_EQ(i, -1);
}
for (auto v : sr.distances_) {
for (auto v : sr->distances_) {
ASSERT_EQ(v, std::numeric_limits<float>::max());
}
}
@ -449,13 +448,12 @@ TEST(Query, ExecWithoutPredicateFlat) {
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
SearchResult sr;
Timestamp time = 1000000;
sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), *ph_group, time);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = SearchResultToJson(sr);
auto json = SearchResultToJson(*sr);
std::cout << json.dump(2);
}
@ -494,13 +492,12 @@ TEST(Query, ExecWithoutPredicate) {
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
SearchResult sr;
Timestamp time = 1000000;
sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), *ph_group, time);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = SearchResultToJson(sr);
auto json = SearchResultToJson(*sr);
auto ref = json::parse(R"(
[
[
@ -551,9 +548,8 @@ TEST(Indexing, InnerProduct) {
auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, col.data());
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp ts = N * 2;
SearchResult sr;
sr = segment->Search(plan.get(), *ph_group, ts);
std::cout << SearchResultToJson(sr).dump(2);
auto sr = segment->Search(plan.get(), *ph_group, ts);
std::cout << SearchResultToJson(*sr).dump(2);
}
TEST(Query, FillSegment) {
@ -661,12 +657,12 @@ TEST(Query, FillSegment) {
plan->target_entries_.clear();
plan->target_entries_.push_back(schema->get_offset(FieldName("fakevec")));
plan->target_entries_.push_back(schema->get_offset(FieldName("the_value")));
SearchResult result = segment->Search(plan.get(), *ph, ts);
auto result = segment->Search(plan.get(), *ph, ts);
// std::cout << SearchResultToJson(result).dump(2);
result.result_offsets_.resize(topk * num_queries);
segment->FillTargetEntry(plan.get(), result);
result->result_offsets_.resize(topk * num_queries);
segment->FillTargetEntry(plan.get(), *result);
auto ans = result.row_data_;
auto ans = result->row_data_;
ASSERT_EQ(ans.size(), topk * num_queries);
int64_t std_index = 0;
@ -675,7 +671,7 @@ TEST(Query, FillSegment) {
int64_t val;
memcpy(&val, vec.data(), sizeof(int64_t));
auto internal_offset = result.ids_[std_index];
auto internal_offset = result->ids_[std_index];
auto std_val = std_vec[internal_offset];
auto std_i32 = std_i32_vec[internal_offset];
std::vector<float> std_vfloat(dim);
@ -739,13 +735,12 @@ TEST(Query, ExecWithPredicateBinary) {
auto num_queries = 5;
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
SearchResult sr;
Timestamp time = 1000000;
sr = segment->Search(plan.get(), *ph_group, time);
auto sr = segment->Search(plan.get(), *ph_group, time);
int topk = 5;
Json json = SearchResultToJson(sr);
Json json = SearchResultToJson(*sr);
std::cout << json.dump(2);
// ASSERT_EQ(json.dump(2), ref.dump(2));
}

View File

@ -69,12 +69,11 @@ TEST(Sealed, without_predicate) {
auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
SearchResult sr;
Timestamp time = 1000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
sr = segment->Search(plan.get(), *ph_group, time);
auto pre_result = SearchResultToJson(sr);
auto sr = segment->Search(plan.get(), *ph_group, time);
auto pre_result = SearchResultToJson(*sr);
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{{knowhere::meta::DIM, dim},
@ -100,9 +99,9 @@ TEST(Sealed, without_predicate) {
std::vector<int64_t> vec_ids(ids, ids + topK * num_queries);
std::vector<float> vec_dis(dis, dis + topK * num_queries);
sr.ids_ = vec_ids;
sr.distances_ = vec_dis;
auto ref_result = SearchResultToJson(sr);
sr->ids_ = vec_ids;
sr->distances_ = vec_dis;
auto ref_result = SearchResultToJson(*sr);
LoadIndexInfo load_info;
load_info.field_id = fake_id.get();
@ -112,7 +111,7 @@ TEST(Sealed, without_predicate) {
auto sealed_segment = SealedCreator(schema, dataset, load_info);
sr = sealed_segment->Search(plan.get(), *ph_group, time);
auto post_result = SearchResultToJson(sr);
auto post_result = SearchResultToJson(*sr);
std::cout << "ref_result" << std::endl;
std::cout << ref_result.dump(1) << std::endl;
std::cout << "post_result" << std::endl;
@ -171,12 +170,10 @@ TEST(Sealed, with_predicate) {
auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
SearchResult sr;
Timestamp time = 10000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
sr = segment->Search(plan.get(), *ph_group, time);
auto pre_sr = sr;
auto sr = segment->Search(plan.get(), *ph_group, time);
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{{knowhere::meta::DIM, dim},
@ -205,11 +202,10 @@ TEST(Sealed, with_predicate) {
auto sealed_segment = SealedCreator(schema, dataset, load_info);
sr = sealed_segment->Search(plan.get(), *ph_group, time);
auto post_sr = sr;
for (int i = 0; i < num_queries; ++i) {
auto offset = i * topK;
ASSERT_EQ(post_sr.ids_[offset], 42000 + i);
ASSERT_EQ(post_sr.distances_[offset], 0.0);
ASSERT_EQ(sr->ids_[offset], 42000 + i);
ASSERT_EQ(sr->distances_[offset], 0.0);
}
}
@ -291,14 +287,14 @@ TEST(Sealed, LoadFieldData) {
}
auto sr = segment->Search(plan.get(), *ph_group, time);
auto json = SearchResultToJson(sr);
auto json = SearchResultToJson(*sr);
std::cout << json.dump(1);
segment->DropIndex(fakevec_id);
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));
segment->LoadIndex(vec_info);
auto sr2 = segment->Search(plan.get(), *ph_group, time);
auto json2 = SearchResultToJson(sr);
auto json2 = SearchResultToJson(*sr);
ASSERT_EQ(json.dump(-2), json2.dump(-2));
segment->DropFieldData(double_id);
ASSERT_ANY_THROW(segment->Search(plan.get(), *ph_group, time));