mirror of https://github.com/milvus-io/milvus.git
MS-606 speed up result reduce

Former-commit-id: 775557c15789b01d63d149ae495653af3cdaca38
pull/191/head
parent 37b519ff1e
commit e9e263fdb7
@@ -14,15 +14,16 @@ Please mark all change in change log and use the ticket from JIRA.

## Improvement

- MS-552 - Add and change the easylogging library
- MS-553 - Refine cache code
- MS-555 - Remove old scheduler
- MS-556 - Add Job Definition in Scheduler
- MS-557 - Merge Log.h
- MS-558 - Refine status code
- MS-562 - Add JobMgr and TaskCreator in Scheduler
- MS-566 - Refactor cmake
- MS-574 - Milvus configuration refactor
- MS-578 - Make sure milvus5.0 don't crack 0.3.1 data
- MS-585 - Update namespace in scheduler
- MS-606 - Speed up result reduce
- MS-608 - Update TODO names
- MS-609 - Update task construct function

@@ -37,8 +37,9 @@ namespace scheduler {

 using engine::meta::TableFileSchemaPtr;

 using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
-using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
-using ResultSet = std::vector<Id2DistanceMap>;
+using IdDistPair = std::pair<int64_t, double>;
+using Id2DistVec = std::vector<IdDistPair>;
+using ResultSet = std::vector<Id2DistVec>;

 class SearchJob : public Job {
  public:
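For illustration only (not part of the diff): a minimal sketch of the new flat-pair result layout. Each query now keeps a plain vector of (id, distance) pairs sized to its k, which can be filled and merged by index instead of going through a map-like structure. The aliases and sample data below are local stand-ins, not the scheduler headers.

// Standalone illustration of the flat (id, distance) result layout.
// The aliases mirror IdDistPair/Id2DistVec/ResultSet from the hunk above
// but are local stand-ins, not the real scheduler types.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;

int main() {
    const uint64_t nq = 2, topk = 3;
    // Flat arrays as a search engine would return them: nq * topk entries.
    std::vector<int64_t> ids = {11, 12, 13, 21, 22, 23};
    std::vector<float> dist = {0.1f, 0.4f, 0.9f, 0.2f, 0.3f, 0.8f};

    // One fixed-size pair vector per query, filled straight from the flat arrays.
    ResultSet result(nq, Id2DistVec(topk, IdDistPair(-1, 0.0)));
    for (uint64_t i = 0; i < nq; ++i) {
        for (uint64_t k = 0; k < topk; ++k) {
            result[i][k] = {ids[i * topk + k], dist[i * topk + k]};
        }
    }

    for (const auto& row : result) {
        for (const auto& p : row) {
            std::cout << p.first << ":" << p.second << " ";
        }
        std::cout << "\n";
    }
    return 0;
}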
@@ -78,18 +78,19 @@ std::mutex XSearchTask::merge_mutex_;

 void
 CollectFileMetrics(int file_type, size_t file_size) {
+    server::MetricsBase& inst = server::Metrics::GetInstance();
     switch (file_type) {
         case TableFileSchema::RAW:
         case TableFileSchema::TO_INDEX: {
-            server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
-            server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
-            server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
+            inst.RawFileSizeHistogramObserve(file_size);
+            inst.RawFileSizeTotalIncrement(file_size);
+            inst.RawFileSizeGaugeSet(file_size);
             break;
         }
         default: {
-            server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
-            server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
-            server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
+            inst.IndexFileSizeHistogramObserve(file_size);
+            inst.IndexFileSizeTotalIncrement(file_size);
+            inst.IndexFileSizeGaugeSet(file_size);
             break;
         }
     }
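Not part of the diff: a minimal, self-contained sketch of the pattern the hunk above applies, assuming a hypothetical Metrics singleton. Binding the instance to a local reference once avoids repeating the GetInstance() lookup on every observation.

#include <cstddef>
#include <iostream>

// Hypothetical stand-in for a metrics singleton such as server::Metrics.
class Metrics {
 public:
    static Metrics& GetInstance() {
        static Metrics instance;
        return instance;
    }
    void RawFileSizeGaugeSet(std::size_t size) { std::cout << "raw gauge: " << size << "\n"; }
    void IndexFileSizeGaugeSet(std::size_t size) { std::cout << "index gauge: " << size << "\n"; }
};

// Resolve the singleton once and reuse the reference, as the hunk above does.
void CollectFileMetrics(bool is_raw, std::size_t file_size) {
    Metrics& inst = Metrics::GetInstance();
    if (is_raw) {
        inst.RawFileSizeGaugeSet(file_size);
    } else {
        inst.IndexFileSizeGaugeSet(file_size);
    }
}

int main() {
    CollectFileMetrics(true, 1024);
    CollectFileMetrics(false, 4096);
    return 0;
}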
@@ -206,16 +207,9 @@ XSearchTask::Execute() {

     double span = rc.RecordSection(hdr + ", do search");
     // search_job->AccumSearchCost(span);

-    // step 3: cluster result
-    scheduler::ResultSet result_set;
+    // step 3: pick up topk result
     auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
-    XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set);
-
-    span = rc.RecordSection(hdr + ", cluster result");
-    // search_job->AccumReduceCost(span);
-
-    // step 4: pick up topk result
-    XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult());
+    XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());

     span = rc.RecordSection(hdr + ", reduce topk");
     // search_job->AccumReduceCost(span);
@@ -235,142 +229,75 @@ XSearchTask::Execute() {
 }

 Status
-XSearchTask::ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distance,
-                           uint64_t nq, uint64_t topk, scheduler::ResultSet& result_set) {
-    if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) {
-        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + " distance array size: " +
-                          std::to_string(output_distance.size());
-        ENGINE_LOG_ERROR << msg;
-        return Status(DB_ERROR, msg);
-    }
-
-    result_set.clear();
-    result_set.resize(nq);
-
-    std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
-        for (auto i = from_index; i < to_index; i++) {
-            scheduler::Id2DistanceMap id_distance;
-            id_distance.reserve(topk);
-            for (auto k = 0; k < topk; k++) {
-                uint64_t index = i * topk + k;
-                if (output_ids[index] < 0) {
-                    continue;
-                }
-                id_distance.push_back(std::make_pair(output_ids[index], output_distance[index]));
-            }
-            result_set[i] = id_distance;
-        }
-    };
-
-    // if (NeedParallelReduce(nq, topk)) {
-    //     ParallelReduce(reduce_worker, nq);
-    // } else {
-    reduce_worker(0, nq);
-    // }
-
-    return Status::OK();
-}
-
-Status
-XSearchTask::MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target,
-                         uint64_t topk, bool ascending) {
-    // Note: the score_src and score_target are already arranged by score in ascending order
-    if (distance_src.empty()) {
-        ENGINE_LOG_WARNING << "Empty distance source array";
-        return Status::OK();
-    }
-
-    std::unique_lock<std::mutex> lock(merge_mutex_);
-    if (distance_target.empty()) {
-        distance_target.swap(distance_src);
-        return Status::OK();
-    }
-
-    size_t src_count = distance_src.size();
-    size_t target_count = distance_target.size();
-    scheduler::Id2DistanceMap distance_merged;
-    distance_merged.reserve(topk);
-    size_t src_index = 0, target_index = 0;
-    while (true) {
-        // all score_src items are merged, if score_merged.size() still less than topk
-        // move items from score_target to score_merged until score_merged.size() equal topk
-        if (src_index >= src_count) {
-            for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
-                distance_merged.push_back(distance_target[i]);
-            }
-            break;
-        }
-
-        // all score_target items are merged, if score_merged.size() still less than topk
-        // move items from score_src to score_merged until score_merged.size() equal topk
-        if (target_index >= target_count) {
-            for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
-                distance_merged.push_back(distance_src[i]);
-            }
-            break;
-        }
-
-        // compare score,
-        // if ascending = true, put smallest score to score_merged one by one
-        // else, put largest score to score_merged one by one
-        auto& src_pair = distance_src[src_index];
-        auto& target_pair = distance_target[target_index];
-        if (ascending) {
-            if (src_pair.second > target_pair.second) {
-                distance_merged.push_back(target_pair);
-                target_index++;
-            } else {
-                distance_merged.push_back(src_pair);
-                src_index++;
-            }
-        } else {
-            if (src_pair.second < target_pair.second) {
-                distance_merged.push_back(target_pair);
-                target_index++;
-            } else {
-                distance_merged.push_back(src_pair);
-                src_index++;
-            }
-        }
-
-        // score_merged.size() already equal topk
-        if (distance_merged.size() >= topk) {
-            break;
-        }
-    }
-
-    distance_target.swap(distance_merged);
-
-    return Status::OK();
-}
-
-Status
-XSearchTask::TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending,
-                        scheduler::ResultSet& result_target) {
-    if (result_target.empty()) {
-        result_target.swap(result_src);
-        return Status::OK();
-    }
-
-    if (result_src.size() != result_target.size()) {
-        std::string msg = "Invalid result set size";
-        ENGINE_LOG_ERROR << msg;
-        return Status(DB_ERROR, msg);
-    }
-
-    std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
-        for (size_t i = from_index; i < to_index; i++) {
-            scheduler::Id2DistanceMap& score_src = result_src[i];
-            scheduler::Id2DistanceMap& score_target = result_target[i];
-            XSearchTask::MergeResult(score_src, score_target, topk, ascending);
-        }
-    };
-
-    // if (NeedParallelReduce(result_src.size(), topk)) {
-    //     ParallelReduce(ReduceWorker, result_src.size());
-    // } else {
-    ReduceWorker(0, result_src.size());
-    // }
+XSearchTask::TopkResult(const std::vector<long> &input_ids,
+                        const std::vector<float> &input_distance,
+                        uint64_t input_k,
+                        uint64_t nq,
+                        uint64_t topk,
+                        bool ascending,
+                        scheduler::ResultSet &result) {
+    scheduler::ResultSet result_buf;
+
+    if (result.empty()) {
+        result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
+        for (auto i = 0; i < nq; ++i) {
+            auto& result_buf_i = result_buf[i];
+            uint64_t input_k_multi_i = input_k * i;
+            for (auto k = 0; k < input_k; ++k) {
+                uint64_t idx = input_k_multi_i + k;
+                auto& result_buf_item = result_buf_i[k];
+                result_buf_item.first = input_ids[idx];
+                result_buf_item.second = input_distance[idx];
+            }
+        }
+    } else {
+        size_t tar_size = result[0].size();
+        uint64_t output_k = std::min(topk, input_k + tar_size);
+        result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
+        for (auto i = 0; i < nq; ++i) {
+            size_t buf_k = 0, src_k = 0, tar_k = 0;
+            uint64_t src_idx;
+            auto& result_i = result[i];
+            auto& result_buf_i = result_buf[i];
+            uint64_t input_k_multi_i = input_k * i;
+            while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
+                src_idx = input_k_multi_i + src_k;
+                auto& result_buf_item = result_buf_i[buf_k];
+                auto& result_item = result_i[tar_k];
+                if ((ascending && input_distance[src_idx] < result_item.second) ||
+                    (!ascending && input_distance[src_idx] > result_item.second)) {
+                    result_buf_item.first = input_ids[src_idx];
+                    result_buf_item.second = input_distance[src_idx];
+                    src_k++;
+                } else {
+                    result_buf_item = result_item;
+                    tar_k++;
+                }
+                buf_k++;
+            }
+
+            if (buf_k < topk) {
+                if (src_k < input_k) {
+                    while (buf_k < output_k && src_k < input_k) {
+                        src_idx = input_k_multi_i + src_k;
+                        auto& result_buf_item = result_buf_i[buf_k];
+                        result_buf_item.first = input_ids[src_idx];
+                        result_buf_item.second = input_distance[src_idx];
+                        src_k++;
+                        buf_k++;
+                    }
+                } else {
+                    while (buf_k < output_k && tar_k < tar_size) {
+                        result_buf_i[buf_k] = result_i[tar_k];
+                        tar_k++;
+                        buf_k++;
+                    }
+                }
+            }
+        }
+    }
+
+    result.swap(result_buf);

     return Status::OK();
 }
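Not part of the commit: a compact, self-contained sketch of the two-pointer merge idea that the new TopkResult above implements. One file's sorted top-k distances are merged into the accumulated sorted list in a single pass, without building intermediate per-query pair maps; the names and the simplified float-only signature here are illustrative.

// Illustrative two-pointer top-k merge of two already-sorted lists.
// Simplified: distances only, ascending order, no ids.
#include <cassert>
#include <vector>

std::vector<float> MergeTopk(const std::vector<float>& input,
                             const std::vector<float>& accum,
                             size_t topk) {
    std::vector<float> merged;
    merged.reserve(topk);
    size_t i = 0, j = 0;
    // Walk both sorted lists once, always taking the smaller head.
    while (merged.size() < topk && (i < input.size() || j < accum.size())) {
        if (j >= accum.size() || (i < input.size() && input[i] < accum[j])) {
            merged.push_back(input[i++]);
        } else {
            merged.push_back(accum[j++]);
        }
    }
    return merged;
}

int main() {
    std::vector<float> accum = {0.1f, 0.5f, 0.9f};         // previously reduced results
    std::vector<float> input = {0.2f, 0.3f, 0.7f, 1.2f};   // one more file's sorted top-k
    auto merged = MergeTopk(input, accum, 4);
    assert((merged == std::vector<float>{0.1f, 0.2f, 0.3f, 0.5f}));
    return 0;
}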
@@ -39,15 +39,13 @@ class XSearchTask : public Task {

  public:
     static Status
-    ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distence, uint64_t nq,
-                  uint64_t topk, scheduler::ResultSet& result_set);
-
-    static Status
-    MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, uint64_t topk,
-                bool ascending);
-
-    static Status
-    TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, scheduler::ResultSet& result_target);
+    TopkResult(const std::vector<long> &input_ids,
+               const std::vector<float> &input_distance,
+               uint64_t input_k,
+               uint64_t nq,
+               uint64_t topk,
+               bool ascending,
+               scheduler::ResultSet &result);

  public:
     TableFileSchemaPtr file_;
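For illustration only: a minimal sketch of how a caller can fold several files' flat results into one ResultSet through repeated calls, using the same in/out result parameter shape as the declaration above. SimpleTopkResult is a hypothetical brute-force stand-in (concatenate, sort, truncate), not the merge implemented by this commit; it only demonstrates the calling convention.

// Illustrative accumulation pattern for a static TopkResult-style API.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;

void SimpleTopkResult(const std::vector<long>& input_ids, const std::vector<float>& input_distance,
                      uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, ResultSet& result) {
    result.resize(nq);
    for (uint64_t i = 0; i < nq; ++i) {
        Id2DistVec& row = result[i];
        // Append this file's block for query i, then re-sort and truncate to topk.
        for (uint64_t k = 0; k < input_k; ++k) {
            row.emplace_back(input_ids[i * input_k + k], input_distance[i * input_k + k]);
        }
        std::sort(row.begin(), row.end(), [&](const IdDistPair& a, const IdDistPair& b) {
            return ascending ? a.second < b.second : a.second > b.second;
        });
        if (row.size() > topk) {
            row.resize(topk);
        }
    }
}

int main() {
    const uint64_t nq = 1, topk = 3;
    ResultSet result;  // starts empty, accumulates across files

    // Each "file" contributes a flat, per-query-sorted id/distance block.
    SimpleTopkResult({1, 2, 3}, {0.3f, 0.6f, 0.9f}, 3, nq, topk, true, result);
    SimpleTopkResult({4, 5, 6}, {0.1f, 0.5f, 0.8f}, 3, nq, topk, true, result);

    for (const auto& p : result[0]) {
        std::cout << p.first << ":" << p.second << " ";   // expect 4:0.1 1:0.3 5:0.5
    }
    std::cout << "\n";
    return 0;
}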
@@ -22,13 +22,10 @@
 #include "scheduler/task/SearchTask.h"
 #include "utils/TimeRecorder.h"

+using namespace milvus::scheduler;
+
 namespace {

-namespace ms = milvus;
-
-static constexpr uint64_t NQ = 15;
-static constexpr uint64_t TOP_K = 64;
-
 void
 BuildResult(uint64_t nq,
             uint64_t topk,
@@ -48,76 +45,36 @@ BuildResult(uint64_t nq,
     }
 }

-void
-CheckResult(const ms::scheduler::Id2DistanceMap &src_1,
-            const ms::scheduler::Id2DistanceMap &src_2,
-            const ms::scheduler::Id2DistanceMap &target,
-            bool ascending) {
-    for (uint64_t i = 0; i < target.size() - 1; i++) {
-        if (ascending) {
-            ASSERT_LE(target[i].second, target[i + 1].second);
-        } else {
-            ASSERT_GE(target[i].second, target[i + 1].second);
-        }
-    }
-
-    using ID2DistMap = std::map<int64_t, float>;
-    ID2DistMap src_map_1, src_map_2;
-    for (const auto &pair : src_1) {
-        src_map_1.insert(pair);
-    }
-    for (const auto &pair : src_2) {
-        src_map_2.insert(pair);
-    }
-
-    for (const auto &pair : target) {
-        ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end());
-
-        float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first];
-        ASSERT_LT(fabs(pair.second - dist), std::numeric_limits<float>::epsilon());
-    }
-}
-
-void
-CheckCluster(const std::vector<int64_t> &target_ids,
-             const std::vector<float> &target_distence,
-             const ms::scheduler::ResultSet &src_result,
-             int64_t nq,
-             int64_t topk) {
-    ASSERT_EQ(src_result.size(), nq);
-    for (int64_t i = 0; i < nq; i++) {
-        auto &res = src_result[i];
-        ASSERT_EQ(res.size(), topk);
-
-        if (res.empty()) {
-            continue;
-        }
-
-        ASSERT_EQ(res[0].first, target_ids[i * topk]);
-        ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]);
-    }
-}
-
-void
-CheckTopkResult(const ms::scheduler::ResultSet &src_result,
-                bool ascending,
-                int64_t nq,
-                int64_t topk) {
-    ASSERT_EQ(src_result.size(), nq);
-    for (int64_t i = 0; i < nq; i++) {
-        auto &res = src_result[i];
-        ASSERT_EQ(res.size(), topk);
-
-        if (res.empty()) {
-            continue;
-        }
-
-        for (int64_t k = 0; k < topk - 1; k++) {
-            if (ascending) {
-                ASSERT_LE(res[k].second, res[k + 1].second);
-            } else {
-                ASSERT_GE(res[k].second, res[k + 1].second);
-            }
-        }
-    }
-}
+void CheckTopkResult(const std::vector<long> &input_ids_1,
+                     const std::vector<float> &input_distance_1,
+                     const std::vector<long> &input_ids_2,
+                     const std::vector<float> &input_distance_2,
+                     uint64_t nq,
+                     uint64_t topk,
+                     bool ascending,
+                     const ResultSet& result) {
+    ASSERT_EQ(result.size(), nq);
+    ASSERT_EQ(input_ids_1.size(), input_distance_1.size());
+    ASSERT_EQ(input_ids_2.size(), input_distance_2.size());
+
+    uint64_t input_k1 = input_ids_1.size() / nq;
+    uint64_t input_k2 = input_ids_2.size() / nq;
+
+    for (int64_t i = 0; i < nq; i++) {
+        std::vector<float> src_vec(input_distance_1.begin()+i*input_k1, input_distance_1.begin()+(i+1)*input_k1);
+        src_vec.insert(src_vec.end(), input_distance_2.begin()+i*input_k2, input_distance_2.begin()+(i+1)*input_k2);
+        if (ascending) {
+            std::sort(src_vec.begin(), src_vec.end());
+        } else {
+            std::sort(src_vec.begin(), src_vec.end(), std::greater<float>());
+        }
+
+        uint64_t n = std::min(topk, input_k1+input_k2);
+        for (uint64_t j = 0; j < n; j++) {
+            if (src_vec[j] != result[i][j].second) {
+                std::cout << src_vec[j] << " " << result[i][j].second << std::endl;
+            }
+            ASSERT_TRUE(src_vec[j] == result[i][j].second);
+        }
+    }
+}
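Not part of the diff: the check above validates a merged result against a brute-force oracle (concatenate both distance lists, sort, compare the top-k prefix). A minimal standalone version of that oracle, with hypothetical inputs:

// Brute-force reference check for a top-k merge: concatenate, sort, compare prefix.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
    const size_t topk = 4;
    std::vector<float> dist1 = {0.1f, 0.5f, 0.9f};          // first source, sorted ascending
    std::vector<float> dist2 = {0.2f, 0.3f, 0.7f, 1.2f};    // second source, sorted ascending
    std::vector<float> merged = {0.1f, 0.2f, 0.3f, 0.5f};   // candidate merged top-k to verify

    // Oracle: sort the union of both sources and keep the smallest topk values.
    std::vector<float> reference = dist1;
    reference.insert(reference.end(), dist2.begin(), dist2.end());
    std::sort(reference.begin(), reference.end());
    reference.resize(std::min(topk, reference.size()));

    assert(merged == reference);   // the merge result must equal the brute-force prefix
    return 0;
}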
@@ -125,179 +82,117 @@ CheckTopkResult(const ms::scheduler::ResultSet &src_result,

 } // namespace

 TEST(DBSearchTest, TOPK_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-    auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
-    ASSERT_FALSE(status.ok());
-    ASSERT_TRUE(src_result.empty());
-
-    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
-    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
-    ASSERT_TRUE(status.ok());
-    ASSERT_EQ(src_result.size(), NQ);
-
-    ms::scheduler::ResultSet target_result;
-    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
-    ASSERT_FALSE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    ASSERT_TRUE(src_result.empty());
-    ASSERT_EQ(target_result.size(), NQ);
-
-    std::vector<int64_t> src_ids;
-    std::vector<float> src_distence;
-    uint64_t wrong_topk = TOP_K - 10;
-    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
-
-    status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
-    ASSERT_TRUE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    for (uint64_t i = 0; i < NQ; i++) {
-        ASSERT_EQ(target_result[i].size(), TOP_K);
-    }
-
-    wrong_topk = TOP_K + 10;
-    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    for (uint64_t i = 0; i < NQ; i++) {
-        ASSERT_EQ(target_result[i].size(), TOP_K);
-    }
-}
+    uint64_t NQ = 15;
+    uint64_t TOP_K = 64;
+    bool ascending;
+    std::vector<long> ids1, ids2;
+    std::vector<float> dist1, dist2;
+    ResultSet result;
+    milvus::Status status;
+
+    /* test1, id1/dist1 valid, id2/dist2 empty */
+    ascending = true;
+    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test2, id1/dist1 valid, id2/dist2 valid */
+    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test3, id1/dist1 small topk */
+    ids1.clear();
+    dist1.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test4, id1/dist1 small topk, id2/dist2 small topk */
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /////////////////////////////////////////////////////////////////////////////////////////
+    ascending = false;
+    ids1.clear();
+    dist1.clear();
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+
+    /* test1, id1/dist1 valid, id2/dist2 empty */
+    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test2, id1/dist1 valid, id2/dist2 valid */
+    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test3, id1/dist1 small topk */
+    ids1.clear();
+    dist1.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test4, id1/dist1 small topk, id2/dist2 small topk */
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+}

-TEST(DBSearchTest, MERGE_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    std::vector<int64_t> src_ids;
-    std::vector<float> src_distence;
-    ms::scheduler::ResultSet src_result, target_result;
-
-    uint64_t src_count = 5, target_count = 8;
-    BuildResult(1, src_count, ascending, src_ids, src_distence);
-    BuildResult(1, target_count, ascending, target_ids, target_distence);
-    auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
-    ASSERT_TRUE(status.ok());
-    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
-    ASSERT_TRUE(status.ok());
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), 10);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target;
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count);
-        ASSERT_TRUE(src.empty());
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap target = src_result[0];
-        ms::scheduler::Id2DistanceMap src = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-}
-
-TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-
-    auto DoCluster = [&](int64_t nq, int64_t topk) {
-        ms::TimeRecorder rc("DoCluster");
-        src_result.clear();
-        BuildResult(nq, topk, ascending, target_ids, target_distence);
-        rc.RecordSection("build id/dietance map");
-
-        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(src_result.size(), nq);
-
-        rc.RecordSection("cluster result");
-
-        CheckCluster(target_ids, target_distence, src_result, nq, topk);
-        rc.RecordSection("check result");
-    };
-
-    DoCluster(10000, 1000);
-    DoCluster(333, 999);
-    DoCluster(1, 1000);
-    DoCluster(1, 1);
-    DoCluster(7, 0);
-    DoCluster(9999, 1);
-    DoCluster(10001, 1);
-    DoCluster(58273, 1234);
-}
-
-TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-
-    std::vector<int64_t> insufficient_ids;
-    std::vector<float> insufficient_distence;
-    ms::scheduler::ResultSet insufficient_result;
-
-    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
-        src_result.clear();
-        insufficient_result.clear();
-
-        ms::TimeRecorder rc("DoCluster");
-
-        BuildResult(nq, topk, ascending, target_ids, target_distence);
-        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
-        rc.RecordSection("cluster result");
-
-        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
-        status = ms::scheduler::XSearchTask::ClusterResult(target_ids,
-                                                           target_distence,
-                                                           nq,
-                                                           insufficient_topk,
-                                                           insufficient_result);
-        rc.RecordSection("cluster result");
-
-        ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
-        ASSERT_TRUE(status.ok());
-        rc.RecordSection("topk");
-
-        CheckTopkResult(src_result, ascending, nq, topk);
-        rc.RecordSection("check result");
-    };
-
-    DoTopk(5, 10, 4, false);
-    DoTopk(20005, 998, 123, true);
-    // DoTopk(9987, 12, 10, false);
-    // DoTopk(77777, 1000, 1, false);
-    // DoTopk(5432, 8899, 8899, true);
-}
+TEST(DBSearchTest, REDUCE_PERF_TEST) {
+    int32_t nq = 100;
+    int32_t top_k = 1000;
+    int32_t index_file_num = 478;   /* sift1B dataset, index files num */
+    bool ascending = true;
+    std::vector<long> input_ids;
+    std::vector<float> input_distance;
+    ResultSet final_result;
+    milvus::Status status;
+
+    double span, reduce_cost = 0.0;
+    milvus::TimeRecorder rc("");
+
+    for (int32_t i = 0; i < index_file_num; i++) {
+        BuildResult(nq, top_k, ascending, input_ids, input_distance);
+
+        rc.RecordSection("do search for context: " + std::to_string(i));
+
+        // pick up topk result
+        status = XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result);
+        ASSERT_TRUE(status.ok());
+        ASSERT_EQ(final_result.size(), nq);
+
+        span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
+        reduce_cost += span;
+    }
+    std::cout << "total reduce time: " << reduce_cost/1000 << " ms" << std::endl;
+}