mirror of https://github.com/milvus-io/milvus.git
Merge branch 'branch-0.4.0' into 'branch-0.4.0'
MS-539 Remove old task code See merge request megasearch/milvus!540 Former-commit-id: 7021b1a8082102e8a8f2cf969aa4c837da4bafa0pull/191/head
commit
6e28241f2f
|
@ -114,6 +114,7 @@ Please mark all change in change log and use the ticket from JIRA.
|
||||||
- MS-531 - Disable next version code
|
- MS-531 - Disable next version code
|
||||||
- MS-533 - Update resource_test to cover dump function
|
- MS-533 - Update resource_test to cover dump function
|
||||||
- MS-523 - Config file validation
|
- MS-523 - Config file validation
|
||||||
|
- MS-539 - Remove old task code
|
||||||
|
|
||||||
## New Feature
|
## New Feature
|
||||||
- MS-343 - Implement ResourceMgr
|
- MS-343 - Implement ResourceMgr
|
||||||
|
|
|
@ -89,7 +89,6 @@ TaskScheduler::TaskDispatchWorker() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef NEW_SCHEDULER
|
|
||||||
// TODO: Put task into Disk-TaskTable
|
// TODO: Put task into Disk-TaskTable
|
||||||
auto task = TaskConvert(task_ptr);
|
auto task = TaskConvert(task_ptr);
|
||||||
auto disk_list = ResMgrInst::GetInstance()->GetDiskResources();
|
auto disk_list = ResMgrInst::GetInstance()->GetDiskResources();
|
||||||
|
@ -98,16 +97,7 @@ TaskScheduler::TaskDispatchWorker() {
|
||||||
disk->task_table().Put(task);
|
disk->task_table().Put(task);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
//execute task
|
|
||||||
ScheduleTaskPtr next_task = task_ptr->Execute();
|
|
||||||
if(next_task != nullptr) {
|
|
||||||
task_queue_.Put(next_task);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool
|
bool
|
||||||
|
@ -126,8 +116,6 @@ TaskScheduler::TaskWorker() {
|
||||||
task_queue_.Put(next_task);
|
task_queue_.Put(next_task);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,6 @@ DeleteTask::DeleteTask(const DeleteContextPtr& context)
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<IScheduleTask> DeleteTask::Execute() {
|
std::shared_ptr<IScheduleTask> DeleteTask::Execute() {
|
||||||
|
|
||||||
if(context_ != nullptr && context_->meta() != nullptr) {
|
|
||||||
context_->meta()->DeleteTableFiles(context_->table_id());
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,82 +15,13 @@ namespace zilliz {
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
|
|
||||||
namespace {
|
|
||||||
void CollectFileMetrics(int file_type, size_t file_size) {
|
|
||||||
switch(file_type) {
|
|
||||||
case meta::TableFileSchema::RAW:
|
|
||||||
case meta::TableFileSchema::TO_INDEX: {
|
|
||||||
server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
|
|
||||||
server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
|
|
||||||
server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
|
|
||||||
server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
|
|
||||||
server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
IndexLoadTask::IndexLoadTask()
|
IndexLoadTask::IndexLoadTask()
|
||||||
: IScheduleTask(ScheduleTaskType::kIndexLoad) {
|
: IScheduleTask(ScheduleTaskType::kIndexLoad) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<IScheduleTask> IndexLoadTask::Execute() {
|
std::shared_ptr<IScheduleTask> IndexLoadTask::Execute() {
|
||||||
server::TimeRecorder rc("");
|
return nullptr;
|
||||||
//step 1: load index
|
|
||||||
ExecutionEnginePtr index_ptr = EngineFactory::Build(file_->dimension_,
|
|
||||||
file_->location_,
|
|
||||||
(EngineType)file_->engine_type_,
|
|
||||||
(MetricType)file_->metric_type_,
|
|
||||||
file_->nlist_);
|
|
||||||
|
|
||||||
try {
|
|
||||||
auto stat = index_ptr->Load();
|
|
||||||
if(!stat.ok()) {
|
|
||||||
//typical error: file not available
|
|
||||||
ENGINE_LOG_ERROR << "Failed to load index file: file not available";
|
|
||||||
|
|
||||||
for(auto& context : search_contexts_) {
|
|
||||||
context->IndexSearchDone(file_->id_);//mark as done avoid dead lock, even failed
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
} catch (std::exception& ex) {
|
|
||||||
//typical error: out of disk space or permition denied
|
|
||||||
std::string msg = "Failed to load index file: " + std::string(ex.what());
|
|
||||||
ENGINE_LOG_ERROR << msg;
|
|
||||||
|
|
||||||
for(auto& context : search_contexts_) {
|
|
||||||
context->IndexSearchDone(file_->id_);//mark as done avoid dead lock, even failed
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t file_size = index_ptr->PhysicalSize();
|
|
||||||
|
|
||||||
std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" + std::to_string(file_->file_type_)
|
|
||||||
+ " size:" + std::to_string(file_size) + " bytes from location: " + file_->location_ + " totally cost";
|
|
||||||
double span = rc.ElapseFromBegin(info);
|
|
||||||
for(auto& context : search_contexts_) {
|
|
||||||
context->AccumLoadCost(span);
|
|
||||||
}
|
|
||||||
|
|
||||||
CollectFileMetrics(file_->file_type_, file_size);
|
|
||||||
|
|
||||||
//step 2: return search task for later execution
|
|
||||||
SearchTaskPtr task_ptr = std::make_shared<SearchTask>();
|
|
||||||
task_ptr->index_id_ = file_->id_;
|
|
||||||
task_ptr->file_type_ = file_->file_type_;
|
|
||||||
task_ptr->index_engine_ = index_ptr;
|
|
||||||
task_ptr->search_contexts_.swap(search_contexts_);
|
|
||||||
return std::static_pointer_cast<IScheduleTask>(task_ptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,259 +14,12 @@ namespace zilliz {
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 1000000;
|
|
||||||
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;
|
|
||||||
|
|
||||||
bool NeedParallelReduce(uint64_t nq, uint64_t topk) {
|
|
||||||
server::ServerConfig &config = server::ServerConfig::GetInstance();
|
|
||||||
server::ConfigNode& db_config = config.GetConfig(server::CONFIG_DB);
|
|
||||||
bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, false);
|
|
||||||
if(!need_parallel) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return nq*topk >= PARALLEL_REDUCE_THRESHOLD;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ParallelReduce(std::function<void(size_t, size_t)>& reduce_function, size_t max_index) {
|
|
||||||
size_t reduce_batch = PARALLEL_REDUCE_BATCH;
|
|
||||||
|
|
||||||
auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work
|
|
||||||
if(thread_count > 0) {
|
|
||||||
reduce_batch = max_index/thread_count + 1;
|
|
||||||
}
|
|
||||||
ENGINE_LOG_DEBUG << "use " << thread_count <<
|
|
||||||
" thread parallelly do reduce, each thread process " << reduce_batch << " vectors";
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<std::thread> > thread_array;
|
|
||||||
size_t from_index = 0;
|
|
||||||
while(from_index < max_index) {
|
|
||||||
size_t to_index = from_index + reduce_batch;
|
|
||||||
if(to_index > max_index) {
|
|
||||||
to_index = max_index;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto reduce_thread = std::make_shared<std::thread>(reduce_function, from_index, to_index);
|
|
||||||
thread_array.push_back(reduce_thread);
|
|
||||||
|
|
||||||
from_index = to_index;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(auto& thread_ptr : thread_array) {
|
|
||||||
thread_ptr->join();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
SearchTask::SearchTask()
|
SearchTask::SearchTask()
|
||||||
: IScheduleTask(ScheduleTaskType::kSearch) {
|
: IScheduleTask(ScheduleTaskType::kSearch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<IScheduleTask> SearchTask::Execute() {
|
std::shared_ptr<IScheduleTask> SearchTask::Execute() {
|
||||||
if(index_engine_ == nullptr) {
|
return nullptr;
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
|
|
||||||
<< search_contexts_.size() << " tasks";
|
|
||||||
|
|
||||||
server::TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_));
|
|
||||||
|
|
||||||
server::CollectSearchTaskMetrics metrics(file_type_);
|
|
||||||
|
|
||||||
bool metric_l2 = (index_engine_->IndexMetricType() == MetricType::L2);
|
|
||||||
|
|
||||||
std::vector<long> output_ids;
|
|
||||||
std::vector<float> output_distence;
|
|
||||||
for(auto& context : search_contexts_) {
|
|
||||||
//step 1: allocate memory
|
|
||||||
auto inner_k = context->topk();
|
|
||||||
auto nprobe = context->nprobe();
|
|
||||||
output_ids.resize(inner_k*context->nq());
|
|
||||||
output_distence.resize(inner_k*context->nq());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//step 2: search
|
|
||||||
index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distence.data(),
|
|
||||||
output_ids.data());
|
|
||||||
|
|
||||||
double span = rc.RecordSection("do search for context:" + context->Identity());
|
|
||||||
context->AccumSearchCost(span);
|
|
||||||
|
|
||||||
//step 3: cluster result
|
|
||||||
SearchContext::ResultSet result_set;
|
|
||||||
auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
|
|
||||||
SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
|
|
||||||
|
|
||||||
span = rc.RecordSection("cluster result for context:" + context->Identity());
|
|
||||||
context->AccumReduceCost(span);
|
|
||||||
|
|
||||||
//step 4: pick up topk result
|
|
||||||
SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
|
|
||||||
|
|
||||||
span = rc.RecordSection("reduce topk for context:" + context->Identity());
|
|
||||||
context->AccumReduceCost(span);
|
|
||||||
|
|
||||||
} catch (std::exception& ex) {
|
|
||||||
ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
|
|
||||||
context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
//step 5: notify to send result to client
|
|
||||||
context->IndexSearchDone(index_id_);
|
|
||||||
}
|
|
||||||
|
|
||||||
rc.ElapseFromBegin("totally cost");
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
|
|
||||||
const std::vector<float> &output_distence,
|
|
||||||
uint64_t nq,
|
|
||||||
uint64_t topk,
|
|
||||||
SearchContext::ResultSet &result_set) {
|
|
||||||
if(output_ids.size() < nq*topk || output_distence.size() < nq*topk) {
|
|
||||||
std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
|
|
||||||
" distance array size: " + std::to_string(output_distence.size());
|
|
||||||
ENGINE_LOG_ERROR << msg;
|
|
||||||
return Status(DB_ERROR, msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
result_set.clear();
|
|
||||||
result_set.resize(nq);
|
|
||||||
|
|
||||||
std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
|
|
||||||
for (auto i = from_index; i < to_index; i++) {
|
|
||||||
SearchContext::Id2DistanceMap id_distance;
|
|
||||||
id_distance.reserve(topk);
|
|
||||||
for (auto k = 0; k < topk; k++) {
|
|
||||||
uint64_t index = i * topk + k;
|
|
||||||
if(output_ids[index] < 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
|
|
||||||
}
|
|
||||||
result_set[i] = id_distance;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if(NeedParallelReduce(nq, topk)) {
|
|
||||||
ParallelReduce(reduce_worker, nq);
|
|
||||||
} else {
|
|
||||||
reduce_worker(0, nq);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
|
||||||
SearchContext::Id2DistanceMap &distance_target,
|
|
||||||
uint64_t topk,
|
|
||||||
bool ascending) {
|
|
||||||
//Note: the score_src and score_target are already arranged by score in ascending order
|
|
||||||
if(distance_src.empty()) {
|
|
||||||
ENGINE_LOG_WARNING << "Empty distance source array";
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if(distance_target.empty()) {
|
|
||||||
distance_target.swap(distance_src);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t src_count = distance_src.size();
|
|
||||||
size_t target_count = distance_target.size();
|
|
||||||
SearchContext::Id2DistanceMap distance_merged;
|
|
||||||
distance_merged.reserve(topk);
|
|
||||||
size_t src_index = 0, target_index = 0;
|
|
||||||
while(true) {
|
|
||||||
//all score_src items are merged, if score_merged.size() still less than topk
|
|
||||||
//move items from score_target to score_merged until score_merged.size() equal topk
|
|
||||||
if(src_index >= src_count) {
|
|
||||||
for(size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
|
|
||||||
distance_merged.push_back(distance_target[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
//all score_target items are merged, if score_merged.size() still less than topk
|
|
||||||
//move items from score_src to score_merged until score_merged.size() equal topk
|
|
||||||
if(target_index >= target_count) {
|
|
||||||
for(size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
|
|
||||||
distance_merged.push_back(distance_src[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
//compare score,
|
|
||||||
// if ascending = true, put smallest score to score_merged one by one
|
|
||||||
// else, put largest score to score_merged one by one
|
|
||||||
auto& src_pair = distance_src[src_index];
|
|
||||||
auto& target_pair = distance_target[target_index];
|
|
||||||
if(ascending){
|
|
||||||
if(src_pair.second > target_pair.second) {
|
|
||||||
distance_merged.push_back(target_pair);
|
|
||||||
target_index++;
|
|
||||||
} else {
|
|
||||||
distance_merged.push_back(src_pair);
|
|
||||||
src_index++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if(src_pair.second < target_pair.second) {
|
|
||||||
distance_merged.push_back(target_pair);
|
|
||||||
target_index++;
|
|
||||||
} else {
|
|
||||||
distance_merged.push_back(src_pair);
|
|
||||||
src_index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//score_merged.size() already equal topk
|
|
||||||
if(distance_merged.size() >= topk) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
distance_target.swap(distance_merged);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
|
|
||||||
uint64_t topk,
|
|
||||||
bool ascending,
|
|
||||||
SearchContext::ResultSet &result_target) {
|
|
||||||
if (result_target.empty()) {
|
|
||||||
result_target.swap(result_src);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (result_src.size() != result_target.size()) {
|
|
||||||
std::string msg = "Invalid result set size";
|
|
||||||
ENGINE_LOG_ERROR << msg;
|
|
||||||
return Status(DB_ERROR, msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
|
|
||||||
for (size_t i = from_index; i < to_index; i++) {
|
|
||||||
SearchContext::Id2DistanceMap &score_src = result_src[i];
|
|
||||||
SearchContext::Id2DistanceMap &score_target = result_target[i];
|
|
||||||
SearchTask::MergeResult(score_src, score_target, topk, ascending);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if(NeedParallelReduce(result_src.size(), topk)) {
|
|
||||||
ParallelReduce(ReduceWorker, result_src.size());
|
|
||||||
} else {
|
|
||||||
ReduceWorker(0, result_src.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,22 +19,6 @@ public:
|
||||||
|
|
||||||
virtual std::shared_ptr<IScheduleTask> Execute() override;
|
virtual std::shared_ptr<IScheduleTask> Execute() override;
|
||||||
|
|
||||||
static Status ClusterResult(const std::vector<long> &output_ids,
|
|
||||||
const std::vector<float> &output_distence,
|
|
||||||
uint64_t nq,
|
|
||||||
uint64_t topk,
|
|
||||||
SearchContext::ResultSet &result_set);
|
|
||||||
|
|
||||||
static Status MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
|
||||||
SearchContext::Id2DistanceMap &distance_target,
|
|
||||||
uint64_t topk,
|
|
||||||
bool ascending);
|
|
||||||
|
|
||||||
static Status TopkResult(SearchContext::ResultSet &result_src,
|
|
||||||
uint64_t topk,
|
|
||||||
bool ascending,
|
|
||||||
SearchContext::ResultSet &result_target);
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
size_t index_id_ = 0;
|
size_t index_id_ = 0;
|
||||||
int file_type_ = 0; //for metrics
|
int file_type_ = 0; //for metrics
|
||||||
|
|
|
@ -10,6 +10,8 @@
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <src/scheduler/task/SearchTask.h>
|
||||||
|
|
||||||
|
|
||||||
using namespace zilliz::milvus;
|
using namespace zilliz::milvus;
|
||||||
|
|
||||||
|
@ -114,23 +116,23 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||||
std::vector<long> target_ids;
|
std::vector<long> target_ids;
|
||||||
std::vector<float> target_distence;
|
std::vector<float> target_distence;
|
||||||
engine::SearchContext::ResultSet src_result;
|
engine::SearchContext::ResultSet src_result;
|
||||||
auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
|
auto status = engine::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
ASSERT_TRUE(src_result.empty());
|
ASSERT_TRUE(src_result.empty());
|
||||||
|
|
||||||
BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
|
BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
|
||||||
status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
|
status = engine::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(src_result.size(), NQ);
|
ASSERT_EQ(src_result.size(), NQ);
|
||||||
|
|
||||||
engine::SearchContext::ResultSet target_result;
|
engine::SearchContext::ResultSet target_result;
|
||||||
status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
|
status = engine::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
|
|
||||||
status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
|
status = engine::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
|
|
||||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
|
status = engine::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_TRUE(src_result.empty());
|
ASSERT_TRUE(src_result.empty());
|
||||||
ASSERT_EQ(target_result.size(), NQ);
|
ASSERT_EQ(target_result.size(), NQ);
|
||||||
|
@ -140,10 +142,10 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||||
uint64_t wrong_topk = TOP_K - 10;
|
uint64_t wrong_topk = TOP_K - 10;
|
||||||
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
|
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
|
||||||
|
|
||||||
status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
|
status = engine::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
|
|
||||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
|
status = engine::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
for(uint64_t i = 0; i < NQ; i++) {
|
for(uint64_t i = 0; i < NQ; i++) {
|
||||||
ASSERT_EQ(target_result[i].size(), TOP_K);
|
ASSERT_EQ(target_result[i].size(), TOP_K);
|
||||||
|
@ -152,7 +154,7 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||||
wrong_topk = TOP_K + 10;
|
wrong_topk = TOP_K + 10;
|
||||||
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
|
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
|
||||||
|
|
||||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
|
status = engine::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
for(uint64_t i = 0; i < NQ; i++) {
|
for(uint64_t i = 0; i < NQ; i++) {
|
||||||
ASSERT_EQ(target_result[i].size(), TOP_K);
|
ASSERT_EQ(target_result[i].size(), TOP_K);
|
||||||
|
@ -170,15 +172,15 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||||
uint64_t src_count = 5, target_count = 8;
|
uint64_t src_count = 5, target_count = 8;
|
||||||
BuildResult(1, src_count, ascending, src_ids, src_distence);
|
BuildResult(1, src_count, ascending, src_ids, src_distence);
|
||||||
BuildResult(1, target_count, ascending, target_ids, target_distence);
|
BuildResult(1, target_count, ascending, target_ids, target_distence);
|
||||||
auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
|
auto status = engine::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
|
status = engine::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
|
|
||||||
{
|
{
|
||||||
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
||||||
engine::SearchContext::Id2DistanceMap target = target_result[0];
|
engine::SearchContext::Id2DistanceMap target = target_result[0];
|
||||||
status = engine::SearchTask::MergeResult(src, target, 10, ascending);
|
status = engine::XSearchTask::MergeResult(src, target, 10, ascending);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(target.size(), 10);
|
ASSERT_EQ(target.size(), 10);
|
||||||
CheckResult(src_result[0], target_result[0], target, ascending);
|
CheckResult(src_result[0], target_result[0], target, ascending);
|
||||||
|
@ -187,7 +189,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||||
{
|
{
|
||||||
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
||||||
engine::SearchContext::Id2DistanceMap target;
|
engine::SearchContext::Id2DistanceMap target;
|
||||||
status = engine::SearchTask::MergeResult(src, target, 10, ascending);
|
status = engine::XSearchTask::MergeResult(src, target, 10, ascending);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(target.size(), src_count);
|
ASSERT_EQ(target.size(), src_count);
|
||||||
ASSERT_TRUE(src.empty());
|
ASSERT_TRUE(src.empty());
|
||||||
|
@ -197,7 +199,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||||
{
|
{
|
||||||
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
||||||
engine::SearchContext::Id2DistanceMap target = target_result[0];
|
engine::SearchContext::Id2DistanceMap target = target_result[0];
|
||||||
status = engine::SearchTask::MergeResult(src, target, 30, ascending);
|
status = engine::XSearchTask::MergeResult(src, target, 30, ascending);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(target.size(), src_count + target_count);
|
ASSERT_EQ(target.size(), src_count + target_count);
|
||||||
CheckResult(src_result[0], target_result[0], target, ascending);
|
CheckResult(src_result[0], target_result[0], target, ascending);
|
||||||
|
@ -206,7 +208,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||||
{
|
{
|
||||||
engine::SearchContext::Id2DistanceMap target = src_result[0];
|
engine::SearchContext::Id2DistanceMap target = src_result[0];
|
||||||
engine::SearchContext::Id2DistanceMap src = target_result[0];
|
engine::SearchContext::Id2DistanceMap src = target_result[0];
|
||||||
status = engine::SearchTask::MergeResult(src, target, 30, ascending);
|
status = engine::XSearchTask::MergeResult(src, target, 30, ascending);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(target.size(), src_count + target_count);
|
ASSERT_EQ(target.size(), src_count + target_count);
|
||||||
CheckResult(src_result[0], target_result[0], target, ascending);
|
CheckResult(src_result[0], target_result[0], target, ascending);
|
||||||
|
@ -229,7 +231,7 @@ TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
|
||||||
BuildResult(nq, topk, ascending, target_ids, target_distence);
|
BuildResult(nq, topk, ascending, target_ids, target_distence);
|
||||||
rc.RecordSection("build id/dietance map");
|
rc.RecordSection("build id/dietance map");
|
||||||
|
|
||||||
auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
|
auto status = engine::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(src_result.size(), nq);
|
ASSERT_EQ(src_result.size(), nq);
|
||||||
|
|
||||||
|
@ -269,14 +271,14 @@ TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
|
||||||
server::TimeRecorder rc("DoCluster");
|
server::TimeRecorder rc("DoCluster");
|
||||||
|
|
||||||
BuildResult(nq, topk, ascending, target_ids, target_distence);
|
BuildResult(nq, topk, ascending, target_ids, target_distence);
|
||||||
auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
|
auto status = engine::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
|
||||||
rc.RecordSection("cluster result");
|
rc.RecordSection("cluster result");
|
||||||
|
|
||||||
BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
|
BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
|
||||||
status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, insufficient_topk, insufficient_result);
|
status = engine::XSearchTask::ClusterResult(target_ids, target_distence, nq, insufficient_topk, insufficient_result);
|
||||||
rc.RecordSection("cluster result");
|
rc.RecordSection("cluster result");
|
||||||
|
|
||||||
engine::SearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
|
engine::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
rc.RecordSection("topk");
|
rc.RecordSection("topk");
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue