diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index a026353d0a..909e281719 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -14,15 +14,16 @@ Please mark all change in change log and use the ticket from JIRA. ## Improvement - MS-552 - Add and change the easylogging library - MS-553 - Refine cache code -- MS-557 - Merge Log.h +- MS-555 - Remove old scheduler - MS-556 - Add Job Definition in Scheduler +- MS-557 - Merge Log.h - MS-558 - Refine status code - MS-562 - Add JobMgr and TaskCreator in Scheduler - MS-566 - Refactor cmake -- MS-555 - Remove old scheduler - MS-574 - Milvus configuration refactor - MS-578 - Make sure milvus5.0 don't crack 0.3.1 data - MS-585 - Update namespace in scheduler +- MS-606 - Speed up result reduce - MS-608 - Update TODO names - MS-609 - Update task construct function diff --git a/cpp/src/scheduler/job/SearchJob.h b/cpp/src/scheduler/job/SearchJob.h index aed40cd942..fb2d87d876 100644 --- a/cpp/src/scheduler/job/SearchJob.h +++ b/cpp/src/scheduler/job/SearchJob.h @@ -37,8 +37,9 @@ namespace scheduler { using engine::meta::TableFileSchemaPtr; using Id2IndexMap = std::unordered_map; -using Id2DistanceMap = std::vector>; -using ResultSet = std::vector; +using IdDistPair = std::pair; +using Id2DistVec = std::vector; +using ResultSet = std::vector; class SearchJob : public Job { public: diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index 2beff8f4c3..20962d8a10 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -78,18 +78,19 @@ std::mutex XSearchTask::merge_mutex_; void CollectFileMetrics(int file_type, size_t file_size) { + server::MetricsBase& inst = server::Metrics::GetInstance(); switch (file_type) { case TableFileSchema::RAW: case TableFileSchema::TO_INDEX: { - server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size); - server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size); - server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size); + inst.RawFileSizeHistogramObserve(file_size); + inst.RawFileSizeTotalIncrement(file_size); + inst.RawFileSizeGaugeSet(file_size); break; } default: { - server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size); - server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size); - server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size); + inst.IndexFileSizeHistogramObserve(file_size); + inst.IndexFileSizeTotalIncrement(file_size); + inst.IndexFileSizeGaugeSet(file_size); break; } } @@ -206,16 +207,9 @@ XSearchTask::Execute() { double span = rc.RecordSection(hdr + ", do search"); // search_job->AccumSearchCost(span); - // step 3: cluster result - scheduler::ResultSet result_set; + // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; - XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set); - - span = rc.RecordSection(hdr + ", cluster result"); - // search_job->AccumReduceCost(span); - - // step 4: pick up topk result - XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult()); + XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); // search_job->AccumReduceCost(span); @@ -235,142 +229,75 @@ XSearchTask::Execute() { } Status -XSearchTask::ClusterResult(const std::vector& output_ids, const std::vector& output_distance, - uint64_t nq, uint64_t topk, scheduler::ResultSet& result_set) { - if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) { - std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + " distance array size: " + - std::to_string(output_distance.size()); - ENGINE_LOG_ERROR << msg; - return Status(DB_ERROR, msg); - } +XSearchTask::TopkResult(const std::vector &input_ids, + const std::vector &input_distance, + uint64_t input_k, + uint64_t nq, + uint64_t topk, + bool ascending, + scheduler::ResultSet &result) { + scheduler::ResultSet result_buf; - result_set.clear(); - result_set.resize(nq); - - std::function reduce_worker = [&](size_t from_index, size_t to_index) { - for (auto i = from_index; i < to_index; i++) { - scheduler::Id2DistanceMap id_distance; - id_distance.reserve(topk); - for (auto k = 0; k < topk; k++) { - uint64_t index = i * topk + k; - if (output_ids[index] < 0) { - continue; + if (result.empty()) { + result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); + for (auto i = 0; i < nq; ++i) { + auto& result_buf_i = result_buf[i]; + uint64_t input_k_multi_i = input_k * i; + for (auto k = 0; k < input_k; ++k) { + uint64_t idx = input_k_multi_i + k; + auto& result_buf_item = result_buf_i[k]; + result_buf_item.first = input_ids[idx]; + result_buf_item.second = input_distance[idx]; + } + } + } else { + size_t tar_size = result[0].size(); + uint64_t output_k = std::min(topk, input_k + tar_size); + result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0))); + for (auto i = 0; i < nq; ++i) { + size_t buf_k = 0, src_k = 0, tar_k = 0; + uint64_t src_idx; + auto& result_i = result[i]; + auto& result_buf_i = result_buf[i]; + uint64_t input_k_multi_i = input_k * i; + while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { + src_idx = input_k_multi_i + src_k; + auto& result_buf_item = result_buf_i[buf_k]; + auto& result_item = result_i[tar_k]; + if ((ascending && input_distance[src_idx] < result_item.second) || + (!ascending && input_distance[src_idx] > result_item.second)) { + result_buf_item.first = input_ids[src_idx]; + result_buf_item.second = input_distance[src_idx]; + src_k++; + } else { + result_buf_item = result_item; + tar_k++; } - id_distance.push_back(std::make_pair(output_ids[index], output_distance[index])); + buf_k++; } - result_set[i] = id_distance; - } - }; - // if (NeedParallelReduce(nq, topk)) { - // ParallelReduce(reduce_worker, nq); - // } else { - reduce_worker(0, nq); - // } - - return Status::OK(); -} - -Status -XSearchTask::MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, - uint64_t topk, bool ascending) { - // Note: the score_src and score_target are already arranged by score in ascending order - if (distance_src.empty()) { - ENGINE_LOG_WARNING << "Empty distance source array"; - return Status::OK(); - } - - std::unique_lock lock(merge_mutex_); - if (distance_target.empty()) { - distance_target.swap(distance_src); - return Status::OK(); - } - - size_t src_count = distance_src.size(); - size_t target_count = distance_target.size(); - scheduler::Id2DistanceMap distance_merged; - distance_merged.reserve(topk); - size_t src_index = 0, target_index = 0; - while (true) { - // all score_src items are merged, if score_merged.size() still less than topk - // move items from score_target to score_merged until score_merged.size() equal topk - if (src_index >= src_count) { - for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) { - distance_merged.push_back(distance_target[i]); + if (buf_k < topk) { + if (src_k < input_k) { + while (buf_k < output_k && src_k < input_k) { + src_idx = input_k_multi_i + src_k; + auto& result_buf_item = result_buf_i[buf_k]; + result_buf_item.first = input_ids[src_idx]; + result_buf_item.second = input_distance[src_idx]; + src_k++; + buf_k++; + } + } else { + while (buf_k < output_k && tar_k < tar_size) { + result_buf_i[buf_k] = result_i[tar_k]; + tar_k++; + buf_k++; + } + } } - break; - } - - // all score_target items are merged, if score_merged.size() still less than topk - // move items from score_src to score_merged until score_merged.size() equal topk - if (target_index >= target_count) { - for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) { - distance_merged.push_back(distance_src[i]); - } - break; - } - - // compare score, - // if ascending = true, put smallest score to score_merged one by one - // else, put largest score to score_merged one by one - auto& src_pair = distance_src[src_index]; - auto& target_pair = distance_target[target_index]; - if (ascending) { - if (src_pair.second > target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; - } else { - distance_merged.push_back(src_pair); - src_index++; - } - } else { - if (src_pair.second < target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; - } else { - distance_merged.push_back(src_pair); - src_index++; - } - } - - // score_merged.size() already equal topk - if (distance_merged.size() >= topk) { - break; } } - distance_target.swap(distance_merged); - - return Status::OK(); -} - -Status -XSearchTask::TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, - scheduler::ResultSet& result_target) { - if (result_target.empty()) { - result_target.swap(result_src); - return Status::OK(); - } - - if (result_src.size() != result_target.size()) { - std::string msg = "Invalid result set size"; - ENGINE_LOG_ERROR << msg; - return Status(DB_ERROR, msg); - } - - std::function ReduceWorker = [&](size_t from_index, size_t to_index) { - for (size_t i = from_index; i < to_index; i++) { - scheduler::Id2DistanceMap& score_src = result_src[i]; - scheduler::Id2DistanceMap& score_target = result_target[i]; - XSearchTask::MergeResult(score_src, score_target, topk, ascending); - } - }; - - // if (NeedParallelReduce(result_src.size(), topk)) { - // ParallelReduce(ReduceWorker, result_src.size()); - // } else { - ReduceWorker(0, result_src.size()); - // } + result.swap(result_buf); return Status::OK(); } diff --git a/cpp/src/scheduler/task/SearchTask.h b/cpp/src/scheduler/task/SearchTask.h index bd48d9244e..92d7235c6b 100644 --- a/cpp/src/scheduler/task/SearchTask.h +++ b/cpp/src/scheduler/task/SearchTask.h @@ -39,15 +39,13 @@ class XSearchTask : public Task { public: static Status - ClusterResult(const std::vector& output_ids, const std::vector& output_distence, uint64_t nq, - uint64_t topk, scheduler::ResultSet& result_set); - - static Status - MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, uint64_t topk, - bool ascending); - - static Status - TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, scheduler::ResultSet& result_target); + TopkResult(const std::vector &input_ids, + const std::vector &input_distance, + uint64_t input_k, + uint64_t nq, + uint64_t topk, + bool ascending, + scheduler::ResultSet &result); public: TableFileSchemaPtr file_; diff --git a/cpp/unittest/db/test_search.cpp b/cpp/unittest/db/test_search.cpp index 12fc8e277a..0b13af0c51 100644 --- a/cpp/unittest/db/test_search.cpp +++ b/cpp/unittest/db/test_search.cpp @@ -22,13 +22,10 @@ #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" +using namespace milvus::scheduler; + namespace { -namespace ms = milvus; - -static constexpr uint64_t NQ = 15; -static constexpr uint64_t TOP_K = 64; - void BuildResult(uint64_t nq, uint64_t topk, @@ -48,76 +45,36 @@ BuildResult(uint64_t nq, } } -void -CheckResult(const ms::scheduler::Id2DistanceMap &src_1, - const ms::scheduler::Id2DistanceMap &src_2, - const ms::scheduler::Id2DistanceMap &target, - bool ascending) { - for (uint64_t i = 0; i < target.size() - 1; i++) { +void CheckTopkResult(const std::vector &input_ids_1, + const std::vector &input_distance_1, + const std::vector &input_ids_2, + const std::vector &input_distance_2, + uint64_t nq, + uint64_t topk, + bool ascending, + const ResultSet& result) { + ASSERT_EQ(result.size(), nq); + ASSERT_EQ(input_ids_1.size(), input_distance_1.size()); + ASSERT_EQ(input_ids_2.size(), input_distance_2.size()); + + uint64_t input_k1 = input_ids_1.size() / nq; + uint64_t input_k2 = input_ids_2.size() / nq; + + for (int64_t i = 0; i < nq; i++) { + std::vector src_vec(input_distance_1.begin()+i*input_k1, input_distance_1.begin()+(i+1)*input_k1); + src_vec.insert(src_vec.end(), input_distance_2.begin()+i*input_k2, input_distance_2.begin()+(i+1)*input_k2); if (ascending) { - ASSERT_LE(target[i].second, target[i + 1].second); + std::sort(src_vec.begin(), src_vec.end()); } else { - ASSERT_GE(target[i].second, target[i + 1].second); - } - } - - using ID2DistMap = std::map; - ID2DistMap src_map_1, src_map_2; - for (const auto &pair : src_1) { - src_map_1.insert(pair); - } - for (const auto &pair : src_2) { - src_map_2.insert(pair); - } - - for (const auto &pair : target) { - ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end()); - - float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first]; - ASSERT_LT(fabs(pair.second - dist), std::numeric_limits::epsilon()); - } -} - -void -CheckCluster(const std::vector &target_ids, - const std::vector &target_distence, - const ms::scheduler::ResultSet &src_result, - int64_t nq, - int64_t topk) { - ASSERT_EQ(src_result.size(), nq); - for (int64_t i = 0; i < nq; i++) { - auto &res = src_result[i]; - ASSERT_EQ(res.size(), topk); - - if (res.empty()) { - continue; + std::sort(src_vec.begin(), src_vec.end(), std::greater()); } - ASSERT_EQ(res[0].first, target_ids[i * topk]); - ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]); - } -} - -void -CheckTopkResult(const ms::scheduler::ResultSet &src_result, - bool ascending, - int64_t nq, - int64_t topk) { - ASSERT_EQ(src_result.size(), nq); - for (int64_t i = 0; i < nq; i++) { - auto &res = src_result[i]; - ASSERT_EQ(res.size(), topk); - - if (res.empty()) { - continue; - } - - for (int64_t k = 0; k < topk - 1; k++) { - if (ascending) { - ASSERT_LE(res[k].second, res[k + 1].second); - } else { - ASSERT_GE(res[k].second, res[k + 1].second); + uint64_t n = std::min(topk, input_k1+input_k2); + for (uint64_t j = 0; j < n; j++) { + if (src_vec[j] != result[i][j].second) { + std::cout << src_vec[j] << " " << result[i][j].second << std::endl; } + ASSERT_TRUE(src_vec[j] == result[i][j].second); } } } @@ -125,179 +82,117 @@ CheckTopkResult(const ms::scheduler::ResultSet &src_result, } // namespace TEST(DBSearchTest, TOPK_TEST) { + uint64_t NQ = 15; + uint64_t TOP_K = 64; + bool ascending; + std::vector ids1, ids2; + std::vector dist1, dist2; + ResultSet result; + milvus::Status status; + + /* test1, id1/dist1 valid, id2/dist2 empty */ + ascending = true; + BuildResult(NQ, TOP_K, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test2, id1/dist1 valid, id2/dist2 valid */ + BuildResult(NQ, TOP_K, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test3, id1/dist1 small topk */ + ids1.clear(); + dist1.clear(); + result.clear(); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test4, id1/dist1 small topk, id2/dist2 small topk */ + ids2.clear(); + dist2.clear(); + result.clear(); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + +///////////////////////////////////////////////////////////////////////////////////////// + ascending = false; + ids1.clear(); + dist1.clear(); + ids2.clear(); + dist2.clear(); + result.clear(); + + /* test1, id1/dist1 valid, id2/dist2 empty */ + BuildResult(NQ, TOP_K, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test2, id1/dist1 valid, id2/dist2 valid */ + BuildResult(NQ, TOP_K, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test3, id1/dist1 small topk */ + ids1.clear(); + dist1.clear(); + result.clear(); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test4, id1/dist1 small topk, id2/dist2 small topk */ + ids2.clear(); + dist2.clear(); + result.clear(); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); +} + +TEST(DBSearchTest, REDUCE_PERF_TEST) { + int32_t nq = 100; + int32_t top_k = 1000; + int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; - std::vector target_ids; - std::vector target_distence; - ms::scheduler::ResultSet src_result; - auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); - ASSERT_FALSE(status.ok()); - ASSERT_TRUE(src_result.empty()); + std::vector input_ids; + std::vector input_distance; + ResultSet final_result; + milvus::Status status; - BuildResult(NQ, TOP_K, ascending, target_ids, target_distence); - status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(src_result.size(), NQ); + double span, reduce_cost = 0.0; + milvus::TimeRecorder rc(""); - ms::scheduler::ResultSet target_result; - status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); + for (int32_t i = 0; i < index_file_num; i++) { + BuildResult(nq, top_k, ascending, input_ids, input_distance); - status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result); - ASSERT_FALSE(status.ok()); + rc.RecordSection("do search for context: " + std::to_string(i)); - status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - ASSERT_TRUE(src_result.empty()); - ASSERT_EQ(target_result.size(), NQ); + // pick up topk result + status = XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(final_result.size(), nq); - std::vector src_ids; - std::vector src_distence; - uint64_t wrong_topk = TOP_K - 10; - BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence); - - status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result); - ASSERT_TRUE(status.ok()); - - status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - for (uint64_t i = 0; i < NQ; i++) { - ASSERT_EQ(target_result[i].size(), TOP_K); - } - - wrong_topk = TOP_K + 10; - BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence); - - status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - for (uint64_t i = 0; i < NQ; i++) { - ASSERT_EQ(target_result[i].size(), TOP_K); + span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); + reduce_cost += span; } -} - -TEST(DBSearchTest, MERGE_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - std::vector src_ids; - std::vector src_distence; - ms::scheduler::ResultSet src_result, target_result; - - uint64_t src_count = 5, target_count = 8; - BuildResult(1, src_count, ascending, src_ids, src_distence); - BuildResult(1, target_count, ascending, target_ids, target_distence); - auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result); - ASSERT_TRUE(status.ok()); - status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result); - ASSERT_TRUE(status.ok()); - - { - ms::scheduler::Id2DistanceMap src = src_result[0]; - ms::scheduler::Id2DistanceMap target = target_result[0]; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), 10); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - ms::scheduler::Id2DistanceMap src = src_result[0]; - ms::scheduler::Id2DistanceMap target; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count); - ASSERT_TRUE(src.empty()); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - ms::scheduler::Id2DistanceMap src = src_result[0]; - ms::scheduler::Id2DistanceMap target = target_result[0]; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count + target_count); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - ms::scheduler::Id2DistanceMap target = src_result[0]; - ms::scheduler::Id2DistanceMap src = target_result[0]; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count + target_count); - CheckResult(src_result[0], target_result[0], target, ascending); - } -} - -TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - ms::scheduler::ResultSet src_result; - - auto DoCluster = [&](int64_t nq, int64_t topk) { - ms::TimeRecorder rc("DoCluster"); - src_result.clear(); - BuildResult(nq, topk, ascending, target_ids, target_distence); - rc.RecordSection("build id/dietance map"); - - auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(src_result.size(), nq); - - rc.RecordSection("cluster result"); - - CheckCluster(target_ids, target_distence, src_result, nq, topk); - rc.RecordSection("check result"); - }; - - DoCluster(10000, 1000); - DoCluster(333, 999); - DoCluster(1, 1000); - DoCluster(1, 1); - DoCluster(7, 0); - DoCluster(9999, 1); - DoCluster(10001, 1); - DoCluster(58273, 1234); -} - -TEST(DBSearchTest, PARALLEL_TOPK_TEST) { - std::vector target_ids; - std::vector target_distence; - ms::scheduler::ResultSet src_result; - - std::vector insufficient_ids; - std::vector insufficient_distence; - ms::scheduler::ResultSet insufficient_result; - - auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) { - src_result.clear(); - insufficient_result.clear(); - - ms::TimeRecorder rc("DoCluster"); - - BuildResult(nq, topk, ascending, target_ids, target_distence); - auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result); - rc.RecordSection("cluster result"); - - BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence); - status = ms::scheduler::XSearchTask::ClusterResult(target_ids, - target_distence, - nq, - insufficient_topk, - insufficient_result); - rc.RecordSection("cluster result"); - - ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result); - ASSERT_TRUE(status.ok()); - rc.RecordSection("topk"); - - CheckTopkResult(src_result, ascending, nq, topk); - rc.RecordSection("check result"); - }; - - DoTopk(5, 10, 4, false); - DoTopk(20005, 998, 123, true); -// DoTopk(9987, 12, 10, false); -// DoTopk(77777, 1000, 1, false); -// DoTopk(5432, 8899, 8899, true); + std::cout << "total reduce time: " << reduce_cost/1000 << " ms" << std::endl; }