MS-606 speed up result reduce

Former-commit-id: 775557c15789b01d63d149ae495653af3cdaca38
pull/191/head
yudong.cai 2019-10-08 14:34:17 +08:00
parent 37b519ff1e
commit e9e263fdb7
5 changed files with 216 additions and 394 deletions

View File

@ -14,15 +14,16 @@ Please mark all change in change log and use the ticket from JIRA.
## Improvement
- MS-552 - Add and change the easylogging library
- MS-553 - Refine cache code
- MS-557 - Merge Log.h
- MS-555 - Remove old scheduler
- MS-556 - Add Job Definition in Scheduler
- MS-557 - Merge Log.h
- MS-558 - Refine status code
- MS-562 - Add JobMgr and TaskCreator in Scheduler
- MS-566 - Refactor cmake
- MS-555 - Remove old scheduler
- MS-574 - Milvus configuration refactor
- MS-578 - Make sure milvus5.0 don't crack 0.3.1 data
- MS-585 - Update namespace in scheduler
- MS-606 - Speed up result reduce
- MS-608 - Update TODO names
- MS-609 - Update task construct function

View File

@ -37,8 +37,9 @@ namespace scheduler {
using engine::meta::TableFileSchemaPtr;
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
using ResultSet = std::vector<Id2DistanceMap>;
using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;
class SearchJob : public Job {
public:

View File

@ -78,18 +78,19 @@ std::mutex XSearchTask::merge_mutex_;
void
CollectFileMetrics(int file_type, size_t file_size) {
server::MetricsBase& inst = server::Metrics::GetInstance();
switch (file_type) {
case TableFileSchema::RAW:
case TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
inst.RawFileSizeHistogramObserve(file_size);
inst.RawFileSizeTotalIncrement(file_size);
inst.RawFileSizeGaugeSet(file_size);
break;
}
default: {
server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
inst.IndexFileSizeHistogramObserve(file_size);
inst.IndexFileSizeTotalIncrement(file_size);
inst.IndexFileSizeGaugeSet(file_size);
break;
}
}
@ -206,16 +207,9 @@ XSearchTask::Execute() {
double span = rc.RecordSection(hdr + ", do search");
// search_job->AccumSearchCost(span);
// step 3: cluster result
scheduler::ResultSet result_set;
// step 3: pick up topk result
auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set);
span = rc.RecordSection(hdr + ", cluster result");
// search_job->AccumReduceCost(span);
// step 4: pick up topk result
XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult());
XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());
span = rc.RecordSection(hdr + ", reduce topk");
// search_job->AccumReduceCost(span);
@ -235,142 +229,75 @@ XSearchTask::Execute() {
}
Status
XSearchTask::ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distance,
uint64_t nq, uint64_t topk, scheduler::ResultSet& result_set) {
if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) {
std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + " distance array size: " +
std::to_string(output_distance.size());
ENGINE_LOG_ERROR << msg;
return Status(DB_ERROR, msg);
}
XSearchTask::TopkResult(const std::vector<long> &input_ids,
const std::vector<float> &input_distance,
uint64_t input_k,
uint64_t nq,
uint64_t topk,
bool ascending,
scheduler::ResultSet &result) {
scheduler::ResultSet result_buf;
result_set.clear();
result_set.resize(nq);
std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
for (auto i = from_index; i < to_index; i++) {
scheduler::Id2DistanceMap id_distance;
id_distance.reserve(topk);
for (auto k = 0; k < topk; k++) {
uint64_t index = i * topk + k;
if (output_ids[index] < 0) {
continue;
if (result.empty()) {
result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
for (auto i = 0; i < nq; ++i) {
auto& result_buf_i = result_buf[i];
uint64_t input_k_multi_i = input_k * i;
for (auto k = 0; k < input_k; ++k) {
uint64_t idx = input_k_multi_i + k;
auto& result_buf_item = result_buf_i[k];
result_buf_item.first = input_ids[idx];
result_buf_item.second = input_distance[idx];
}
id_distance.push_back(std::make_pair(output_ids[index], output_distance[index]));
}
result_set[i] = id_distance;
}
};
// if (NeedParallelReduce(nq, topk)) {
// ParallelReduce(reduce_worker, nq);
// } else {
reduce_worker(0, nq);
// }
return Status::OK();
}
Status
XSearchTask::MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target,
uint64_t topk, bool ascending) {
// Note: the score_src and score_target are already arranged by score in ascending order
if (distance_src.empty()) {
ENGINE_LOG_WARNING << "Empty distance source array";
return Status::OK();
}
std::unique_lock<std::mutex> lock(merge_mutex_);
if (distance_target.empty()) {
distance_target.swap(distance_src);
return Status::OK();
}
size_t src_count = distance_src.size();
size_t target_count = distance_target.size();
scheduler::Id2DistanceMap distance_merged;
distance_merged.reserve(topk);
size_t src_index = 0, target_index = 0;
while (true) {
// all score_src items are merged, if score_merged.size() still less than topk
// move items from score_target to score_merged until score_merged.size() equal topk
if (src_index >= src_count) {
for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
distance_merged.push_back(distance_target[i]);
}
break;
}
// all score_target items are merged, if score_merged.size() still less than topk
// move items from score_src to score_merged until score_merged.size() equal topk
if (target_index >= target_count) {
for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
distance_merged.push_back(distance_src[i]);
}
break;
}
// compare score,
// if ascending = true, put smallest score to score_merged one by one
// else, put largest score to score_merged one by one
auto& src_pair = distance_src[src_index];
auto& target_pair = distance_target[target_index];
if (ascending) {
if (src_pair.second > target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
} else {
distance_merged.push_back(src_pair);
src_index++;
}
} else {
if (src_pair.second < target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
size_t tar_size = result[0].size();
uint64_t output_k = std::min(topk, input_k + tar_size);
result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
for (auto i = 0; i < nq; ++i) {
size_t buf_k = 0, src_k = 0, tar_k = 0;
uint64_t src_idx;
auto& result_i = result[i];
auto& result_buf_i = result_buf[i];
uint64_t input_k_multi_i = input_k * i;
while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
src_idx = input_k_multi_i + src_k;
auto& result_buf_item = result_buf_i[buf_k];
auto& result_item = result_i[tar_k];
if ((ascending && input_distance[src_idx] < result_item.second) ||
(!ascending && input_distance[src_idx] > result_item.second)) {
result_buf_item.first = input_ids[src_idx];
result_buf_item.second = input_distance[src_idx];
src_k++;
} else {
distance_merged.push_back(src_pair);
src_index++;
result_buf_item = result_item;
tar_k++;
}
buf_k++;
}
if (buf_k < topk) {
if (src_k < input_k) {
while (buf_k < output_k && src_k < input_k) {
src_idx = input_k_multi_i + src_k;
auto& result_buf_item = result_buf_i[buf_k];
result_buf_item.first = input_ids[src_idx];
result_buf_item.second = input_distance[src_idx];
src_k++;
buf_k++;
}
} else {
while (buf_k < output_k && tar_k < tar_size) {
result_buf_i[buf_k] = result_i[tar_k];
tar_k++;
buf_k++;
}
}
}
}
}
// score_merged.size() already equal topk
if (distance_merged.size() >= topk) {
break;
}
}
distance_target.swap(distance_merged);
return Status::OK();
}
Status
XSearchTask::TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending,
scheduler::ResultSet& result_target) {
if (result_target.empty()) {
result_target.swap(result_src);
return Status::OK();
}
if (result_src.size() != result_target.size()) {
std::string msg = "Invalid result set size";
ENGINE_LOG_ERROR << msg;
return Status(DB_ERROR, msg);
}
std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
for (size_t i = from_index; i < to_index; i++) {
scheduler::Id2DistanceMap& score_src = result_src[i];
scheduler::Id2DistanceMap& score_target = result_target[i];
XSearchTask::MergeResult(score_src, score_target, topk, ascending);
}
};
// if (NeedParallelReduce(result_src.size(), topk)) {
// ParallelReduce(ReduceWorker, result_src.size());
// } else {
ReduceWorker(0, result_src.size());
// }
result.swap(result_buf);
return Status::OK();
}

View File

@ -39,15 +39,13 @@ class XSearchTask : public Task {
public:
static Status
ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distence, uint64_t nq,
uint64_t topk, scheduler::ResultSet& result_set);
static Status
MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, uint64_t topk,
bool ascending);
static Status
TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, scheduler::ResultSet& result_target);
TopkResult(const std::vector<long> &input_ids,
const std::vector<float> &input_distance,
uint64_t input_k,
uint64_t nq,
uint64_t topk,
bool ascending,
scheduler::ResultSet &result);
public:
TableFileSchemaPtr file_;

View File

@ -22,13 +22,10 @@
#include "scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
using namespace milvus::scheduler;
namespace {
namespace ms = milvus;
static constexpr uint64_t NQ = 15;
static constexpr uint64_t TOP_K = 64;
void
BuildResult(uint64_t nq,
uint64_t topk,
@ -48,76 +45,36 @@ BuildResult(uint64_t nq,
}
}
void
CheckResult(const ms::scheduler::Id2DistanceMap &src_1,
const ms::scheduler::Id2DistanceMap &src_2,
const ms::scheduler::Id2DistanceMap &target,
bool ascending) {
for (uint64_t i = 0; i < target.size() - 1; i++) {
if (ascending) {
ASSERT_LE(target[i].second, target[i + 1].second);
} else {
ASSERT_GE(target[i].second, target[i + 1].second);
}
}
using ID2DistMap = std::map<int64_t, float>;
ID2DistMap src_map_1, src_map_2;
for (const auto &pair : src_1) {
src_map_1.insert(pair);
}
for (const auto &pair : src_2) {
src_map_2.insert(pair);
}
for (const auto &pair : target) {
ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end());
float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first];
ASSERT_LT(fabs(pair.second - dist), std::numeric_limits<float>::epsilon());
}
}
void
CheckCluster(const std::vector<int64_t> &target_ids,
const std::vector<float> &target_distence,
const ms::scheduler::ResultSet &src_result,
int64_t nq,
int64_t topk) {
ASSERT_EQ(src_result.size(), nq);
for (int64_t i = 0; i < nq; i++) {
auto &res = src_result[i];
ASSERT_EQ(res.size(), topk);
if (res.empty()) {
continue;
}
ASSERT_EQ(res[0].first, target_ids[i * topk]);
ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]);
}
}
void
CheckTopkResult(const ms::scheduler::ResultSet &src_result,
void CheckTopkResult(const std::vector<long> &input_ids_1,
const std::vector<float> &input_distance_1,
const std::vector<long> &input_ids_2,
const std::vector<float> &input_distance_2,
uint64_t nq,
uint64_t topk,
bool ascending,
int64_t nq,
int64_t topk) {
ASSERT_EQ(src_result.size(), nq);
const ResultSet& result) {
ASSERT_EQ(result.size(), nq);
ASSERT_EQ(input_ids_1.size(), input_distance_1.size());
ASSERT_EQ(input_ids_2.size(), input_distance_2.size());
uint64_t input_k1 = input_ids_1.size() / nq;
uint64_t input_k2 = input_ids_2.size() / nq;
for (int64_t i = 0; i < nq; i++) {
auto &res = src_result[i];
ASSERT_EQ(res.size(), topk);
if (res.empty()) {
continue;
}
for (int64_t k = 0; k < topk - 1; k++) {
std::vector<float> src_vec(input_distance_1.begin()+i*input_k1, input_distance_1.begin()+(i+1)*input_k1);
src_vec.insert(src_vec.end(), input_distance_2.begin()+i*input_k2, input_distance_2.begin()+(i+1)*input_k2);
if (ascending) {
ASSERT_LE(res[k].second, res[k + 1].second);
std::sort(src_vec.begin(), src_vec.end());
} else {
ASSERT_GE(res[k].second, res[k + 1].second);
std::sort(src_vec.begin(), src_vec.end(), std::greater<float>());
}
uint64_t n = std::min(topk, input_k1+input_k2);
for (uint64_t j = 0; j < n; j++) {
if (src_vec[j] != result[i][j].second) {
std::cout << src_vec[j] << " " << result[i][j].second << std::endl;
}
ASSERT_TRUE(src_vec[j] == result[i][j].second);
}
}
}
@ -125,179 +82,117 @@ CheckTopkResult(const ms::scheduler::ResultSet &src_result,
} // namespace
TEST(DBSearchTest, TOPK_TEST) {
uint64_t NQ = 15;
uint64_t TOP_K = 64;
bool ascending;
std::vector<long> ids1, ids2;
std::vector<float> dist1, dist2;
ResultSet result;
milvus::Status status;
/* test1, id1/dist1 valid, id2/dist2 empty */
ascending = true;
BuildResult(NQ, TOP_K, ascending, ids1, dist1);
status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/* test2, id1/dist1 valid, id2/dist2 valid */
BuildResult(NQ, TOP_K, ascending, ids2, dist2);
status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/* test3, id1/dist1 small topk */
ids1.clear();
dist1.clear();
result.clear();
BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/* test4, id1/dist1 small topk, id2/dist2 small topk */
ids2.clear();
dist2.clear();
result.clear();
BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/////////////////////////////////////////////////////////////////////////////////////////
ascending = false;
ids1.clear();
dist1.clear();
ids2.clear();
dist2.clear();
result.clear();
/* test1, id1/dist1 valid, id2/dist2 empty */
BuildResult(NQ, TOP_K, ascending, ids1, dist1);
status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/* test2, id1/dist1 valid, id2/dist2 valid */
BuildResult(NQ, TOP_K, ascending, ids2, dist2);
status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/* test3, id1/dist1 small topk */
ids1.clear();
dist1.clear();
result.clear();
BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
/* test4, id1/dist1 small topk, id2/dist2 small topk */
ids2.clear();
dist2.clear();
result.clear();
BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
ASSERT_TRUE(status.ok());
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
}
TEST(DBSearchTest, REDUCE_PERF_TEST) {
int32_t nq = 100;
int32_t top_k = 1000;
int32_t index_file_num = 478; /* sift1B dataset, index files num */
bool ascending = true;
std::vector<int64_t> target_ids;
std::vector<float> target_distence;
ms::scheduler::ResultSet src_result;
auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
ASSERT_FALSE(status.ok());
ASSERT_TRUE(src_result.empty());
std::vector<long> input_ids;
std::vector<float> input_distance;
ResultSet final_result;
milvus::Status status;
BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
double span, reduce_cost = 0.0;
milvus::TimeRecorder rc("");
for (int32_t i = 0; i < index_file_num; i++) {
BuildResult(nq, top_k, ascending, input_ids, input_distance);
rc.RecordSection("do search for context: " + std::to_string(i));
// pick up topk result
status = XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result);
ASSERT_TRUE(status.ok());
ASSERT_EQ(src_result.size(), NQ);
ASSERT_EQ(final_result.size(), nq);
ms::scheduler::ResultSet target_result;
status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
ASSERT_FALSE(status.ok());
status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
ASSERT_TRUE(src_result.empty());
ASSERT_EQ(target_result.size(), NQ);
std::vector<int64_t> src_ids;
std::vector<float> src_distence;
uint64_t wrong_topk = TOP_K - 10;
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
ASSERT_TRUE(status.ok());
status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
for (uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
}
wrong_topk = TOP_K + 10;
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
for (uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
reduce_cost += span;
}
}
TEST(DBSearchTest, MERGE_TEST) {
bool ascending = true;
std::vector<int64_t> target_ids;
std::vector<float> target_distence;
std::vector<int64_t> src_ids;
std::vector<float> src_distence;
ms::scheduler::ResultSet src_result, target_result;
uint64_t src_count = 5, target_count = 8;
BuildResult(1, src_count, ascending, src_ids, src_distence);
BuildResult(1, target_count, ascending, target_ids, target_distence);
auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
ASSERT_TRUE(status.ok());
status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
ASSERT_TRUE(status.ok());
{
ms::scheduler::Id2DistanceMap src = src_result[0];
ms::scheduler::Id2DistanceMap target = target_result[0];
status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), 10);
CheckResult(src_result[0], target_result[0], target, ascending);
}
{
ms::scheduler::Id2DistanceMap src = src_result[0];
ms::scheduler::Id2DistanceMap target;
status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count);
ASSERT_TRUE(src.empty());
CheckResult(src_result[0], target_result[0], target, ascending);
}
{
ms::scheduler::Id2DistanceMap src = src_result[0];
ms::scheduler::Id2DistanceMap target = target_result[0];
status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target, ascending);
}
{
ms::scheduler::Id2DistanceMap target = src_result[0];
ms::scheduler::Id2DistanceMap src = target_result[0];
status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target, ascending);
}
}
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
bool ascending = true;
std::vector<int64_t> target_ids;
std::vector<float> target_distence;
ms::scheduler::ResultSet src_result;
auto DoCluster = [&](int64_t nq, int64_t topk) {
ms::TimeRecorder rc("DoCluster");
src_result.clear();
BuildResult(nq, topk, ascending, target_ids, target_distence);
rc.RecordSection("build id/dietance map");
auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
ASSERT_TRUE(status.ok());
ASSERT_EQ(src_result.size(), nq);
rc.RecordSection("cluster result");
CheckCluster(target_ids, target_distence, src_result, nq, topk);
rc.RecordSection("check result");
};
DoCluster(10000, 1000);
DoCluster(333, 999);
DoCluster(1, 1000);
DoCluster(1, 1);
DoCluster(7, 0);
DoCluster(9999, 1);
DoCluster(10001, 1);
DoCluster(58273, 1234);
}
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
std::vector<int64_t> target_ids;
std::vector<float> target_distence;
ms::scheduler::ResultSet src_result;
std::vector<int64_t> insufficient_ids;
std::vector<float> insufficient_distence;
ms::scheduler::ResultSet insufficient_result;
auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
src_result.clear();
insufficient_result.clear();
ms::TimeRecorder rc("DoCluster");
BuildResult(nq, topk, ascending, target_ids, target_distence);
auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
rc.RecordSection("cluster result");
BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
status = ms::scheduler::XSearchTask::ClusterResult(target_ids,
target_distence,
nq,
insufficient_topk,
insufficient_result);
rc.RecordSection("cluster result");
ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
ASSERT_TRUE(status.ok());
rc.RecordSection("topk");
CheckTopkResult(src_result, ascending, nq, topk);
rc.RecordSection("check result");
};
DoTopk(5, 10, 4, false);
DoTopk(20005, 998, 123, true);
// DoTopk(9987, 12, 10, false);
// DoTopk(77777, 1000, 1, false);
// DoTopk(5432, 8899, 8899, true);
std::cout << "total reduce time: " << reduce_cost/1000 << " ms" << std::endl;
}