diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index bc3d41556e..7f860baf7c 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -93,7 +93,11 @@ Please mark all change in change log and use the ticket from JIRA. - MS-487 - Define metric type in CreateTable - MS-488 - Improve code format in scheduler - MS-495 - cmake: integrated knowhere +- MS-496 - Change the top_k limitation from 1024 to 2048 +- MS-502 - Update tasktable_test in scheduler +- MS-504 - Update node_test in scheduler - MS-505 - Install core unit test and add to coverage +- MS-508 - Update normal_test in scheduler ## New Feature - MS-343 - Implement ResourceMgr diff --git a/cpp/src/scheduler/Scheduler.cpp b/cpp/src/scheduler/Scheduler.cpp index dcd17e31cf..c77b37648d 100644 --- a/cpp/src/scheduler/Scheduler.cpp +++ b/cpp/src/scheduler/Scheduler.cpp @@ -108,6 +108,7 @@ void Scheduler::OnFinishTask(const EventPtr &event) { } +// TODO: refactor the function void Scheduler::OnLoadCompleted(const EventPtr &event) { auto load_completed_event = std::static_pointer_cast(event); @@ -120,18 +121,23 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { if (not resource->HasExecutor() && load_completed_event->task_table_item_->Move()) { auto task = load_completed_event->task_table_item_->task; auto search_task = std::static_pointer_cast(task); - auto location = search_task->index_engine_->GetLocation(); bool moved = false; - for (auto i = 0; i < res_mgr_.lock()->GetNumGpuResource(); ++i) { - auto index = zilliz::milvus::cache::GpuCacheMgr::GetInstance(i)->GetIndex(location); - if (index != nullptr) { - moved = true; - auto dest_resource = res_mgr_.lock()->GetResource(ResourceType::GPU, i); - Action::PushTaskToResource(load_completed_event->task_table_item_->task, dest_resource); - break; + // to support test task, REFACTOR + if (auto index_engine = search_task->index_engine_) { + auto location = index_engine->GetLocation(); + + for (auto i = 0; i < res_mgr_.lock()->GetNumGpuResource(); ++i) { + auto index = zilliz::milvus::cache::GpuCacheMgr::GetInstance(i)->GetIndex(location); + if (index != nullptr) { + moved = true; + auto dest_resource = res_mgr_.lock()->GetResource(ResourceType::GPU, i); + Action::PushTaskToResource(load_completed_event->task_table_item_->task, dest_resource); + break; + } } } + if (not moved) { Action::PushTaskToNeighbourRandomly(task, resource); } @@ -147,7 +153,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { // step 1: calculate shortest path per resource, from disk to compute resource auto compute_resources = res_mgr_.lock()->GetComputeResource(); std::vector> paths; - std::vector transport_costs; + std::vector transport_costs; for (auto &res : compute_resources) { std::vector path; uint64_t transport_cost = ShortestPath(self, res, res_mgr_.lock(), path); @@ -176,7 +182,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { task->path() = task_path; } - if(self->name() == task->path().Last()) { + if (self->name() == task->path().Last()) { self->WakeupLoader(); } else { auto next_res_name = task->path().Next(); diff --git a/cpp/src/scheduler/Scheduler.h b/cpp/src/scheduler/Scheduler.h index e2d51ee31d..c2a36069b9 100644 --- a/cpp/src/scheduler/Scheduler.h +++ b/cpp/src/scheduler/Scheduler.h @@ -21,6 +21,7 @@ namespace milvus { namespace engine { +// TODO: refactor, not friendly to unittest, logical in framework code class Scheduler { public: explicit diff --git a/cpp/src/scheduler/TaskTable.cpp b/cpp/src/scheduler/TaskTable.cpp index 086bf06835..91d0bd7052 100644 --- a/cpp/src/scheduler/TaskTable.cpp +++ b/cpp/src/scheduler/TaskTable.cpp @@ -136,7 +136,7 @@ std::vector TaskTable::PickToLoad(uint64_t limit) { std::vector indexes; bool cross = false; - for (uint64_t i = last_finish_, count = 0; i < table_.size() && count < limit; ++i) { + for (uint64_t i = last_finish_ + 1, count = 0; i < table_.size() && count < limit; ++i) { if (not cross && table_[i]->IsFinish()) { last_finish_ = i; } else if (table_[i]->state == TaskTableItemState::START) { @@ -152,7 +152,7 @@ std::vector TaskTable::PickToExecute(uint64_t limit) { std::vector indexes; bool cross = false; - for (uint64_t i = last_finish_, count = 0; i < table_.size() && count < limit; ++i) { + for (uint64_t i = last_finish_ + 1, count = 0; i < table_.size() && count < limit; ++i) { if (not cross && table_[i]->IsFinish()) { last_finish_ = i; } else if (table_[i]->state == TaskTableItemState::LOADED) { @@ -200,15 +200,15 @@ TaskTable::Get(uint64_t index) { return table_[index]; } -void -TaskTable::Clear() { -// find first task is NOT (done or moved), erase from begin to it; -// auto iterator = table_.begin(); -// while (iterator->state == TaskTableItemState::EXECUTED or -// iterator->state == TaskTableItemState::MOVED) -// iterator++; -// table_.erase(table_.begin(), iterator); -} +//void +//TaskTable::Clear() { +//// find first task is NOT (done or moved), erase from begin to it; +//// auto iterator = table_.begin(); +//// while (iterator->state == TaskTableItemState::EXECUTED or +//// iterator->state == TaskTableItemState::MOVED) +//// iterator++; +//// table_.erase(table_.begin(), iterator); +//} std::string diff --git a/cpp/src/scheduler/TaskTable.h b/cpp/src/scheduler/TaskTable.h index f5c151f4ca..7b064f20d4 100644 --- a/cpp/src/scheduler/TaskTable.h +++ b/cpp/src/scheduler/TaskTable.h @@ -40,10 +40,10 @@ struct TaskTimestamp { }; struct TaskTableItem { - TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex() {} + TaskTableItem() : id(0), task(nullptr), state(TaskTableItemState::INVALID), mutex() {} - TaskTableItem(const TaskTableItem &src) - : id(src.id), state(src.state), mutex() {} + TaskTableItem(const TaskTableItem &src) = delete; + TaskTableItem(TaskTableItem &&) = delete; uint64_t id; // auto increment from 0; TaskPtr task; // the task; @@ -114,8 +114,8 @@ public: * Remove sequence task which is DONE or MOVED from front; * Called by ? */ - void - Clear(); +// void +// Clear(); /* * Return true if task table empty, otherwise false; @@ -229,7 +229,9 @@ private: std::function subscriber_ = nullptr; // cache last finish avoid Pick task from begin always - uint64_t last_finish_ = 0; + // pick from (last_finish_ + 1) + // init with -1, pick from (last_finish_ + 1) = 0 + uint64_t last_finish_ = -1; }; diff --git a/cpp/src/scheduler/resource/Node.cpp b/cpp/src/scheduler/resource/Node.cpp index 8e3db29ea2..0b322abb57 100644 --- a/cpp/src/scheduler/resource/Node.cpp +++ b/cpp/src/scheduler/resource/Node.cpp @@ -17,27 +17,6 @@ Node::Node() { id_ = counter++; } -void Node::DelNeighbour(const NeighbourNodePtr &neighbour_ptr) { - std::lock_guard lk(mutex_); - if (auto s = neighbour_ptr.lock()) { - auto search = neighbours_.find(s->id_); - if (search != neighbours_.end()) { - neighbours_.erase(search); - } - } -} - -bool Node::IsNeighbour(const NeighbourNodePtr &neighbour_ptr) { - std::lock_guard lk(mutex_); - if (auto s = neighbour_ptr.lock()) { - auto search = neighbours_.find(s->id_); - if (search != neighbours_.end()) { - return true; - } - } - return false; -} - std::vector Node::GetNeighbours() { std::lock_guard lk(mutex_); std::vector ret; @@ -48,8 +27,13 @@ std::vector Node::GetNeighbours() { } std::string Node::Dump() { - // TODO(linxj): what's that? - return std::__cxx11::string(); + std::stringstream ss; + ss << "::neighbours:" << std::endl; + for (auto &neighbour : neighbours_) { + ss << "\t" << std::endl; + } + return ss.str(); } void Node::AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection) { diff --git a/cpp/src/scheduler/resource/Node.h b/cpp/src/scheduler/resource/Node.h index a57987ca9c..568aaf93e7 100644 --- a/cpp/src/scheduler/resource/Node.h +++ b/cpp/src/scheduler/resource/Node.h @@ -37,12 +37,6 @@ public: void AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection); - void - DelNeighbour(const NeighbourNodePtr &neighbour_ptr); - - bool - IsNeighbour(const NeighbourNodePtr& neighbour_ptr); - std::vector GetNeighbours(); diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index af13724c6e..ec603ac328 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -83,11 +83,14 @@ CollectFileMetrics(int file_type, size_t file_size) { XSearchTask::XSearchTask(TableFileSchemaPtr file) : Task(TaskType::SearchTask), file_(file) { - index_engine_ = EngineFactory::Build(file_->dimension_, - file_->location_, - (EngineType) file_->engine_type_, - (MetricType) file_->metric_type_, - file_->nlist_); + if (file_) { + index_engine_ = EngineFactory::Build(file_->dimension_, + file_->location_, + (EngineType) file_->engine_type_, + (MetricType) file_->metric_type_, + file_->nlist_); + } + } void @@ -103,6 +106,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) { index_engine_->CopyToCpu(); } else { // TODO: exception + std::string msg = "Wrong load type"; + ENGINE_LOG_ERROR << msg; } } catch (std::exception &ex) { //typical error: out of disk space or permition denied @@ -147,17 +152,17 @@ XSearchTask::Execute() { server::CollectDurationMetrics metrics(index_type_); std::vector output_ids; - std::vector output_distence; + std::vector output_distance; for (auto &context : search_contexts_) { //step 1: allocate memory auto inner_k = context->topk(); auto nprobe = context->nprobe(); output_ids.resize(inner_k * context->nq()); - output_distence.resize(inner_k * context->nq()); + output_distance.resize(inner_k * context->nq()); try { //step 2: search - index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distence.data(), + index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distance.data(), output_ids.data()); double span = rc.RecordSection("do search for context:" + context->Identity()); @@ -167,12 +172,12 @@ XSearchTask::Execute() { //step 3: cluster result SearchContext::ResultSet result_set; auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk(); - XSearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set); + XSearchTask::ClusterResult(output_ids, output_distance, context->nq(), spec_k, result_set); span = rc.RecordSection("cluster result for context:" + context->Identity()); context->AccumReduceCost(span); - //step 4: pick up topk result + // step 4: pick up topk result XSearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult()); span = rc.RecordSection("reduce topk for context:" + context->Identity()); @@ -194,13 +199,13 @@ XSearchTask::Execute() { } Status XSearchTask::ClusterResult(const std::vector &output_ids, - const std::vector &output_distence, + const std::vector &output_distance, uint64_t nq, uint64_t topk, SearchContext::ResultSet &result_set) { - if (output_ids.size() < nq * topk || output_distence.size() < nq * topk) { + if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) { std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + - " distance array size: " + std::to_string(output_distence.size()); + " distance array size: " + std::to_string(output_distance.size()); ENGINE_LOG_ERROR << msg; return Status(DB_ERROR, msg); } @@ -217,7 +222,7 @@ Status XSearchTask::ClusterResult(const std::vector &output_ids, if (output_ids[index] < 0) { continue; } - id_distance.push_back(std::make_pair(output_ids[index], output_distence[index])); + id_distance.push_back(std::make_pair(output_ids[index], output_distance[index])); } result_set[i] = id_distance; } diff --git a/cpp/src/scheduler/task/SearchTask.h b/cpp/src/scheduler/task/SearchTask.h index b45eea48dc..2370e26361 100644 --- a/cpp/src/scheduler/task/SearchTask.h +++ b/cpp/src/scheduler/task/SearchTask.h @@ -12,6 +12,7 @@ namespace zilliz { namespace milvus { namespace engine { +// TODO: rewrite class XSearchTask : public Task { public: explicit diff --git a/cpp/src/scheduler/task/Task.h b/cpp/src/scheduler/task/Task.h index 01b2c8eb1b..53893aacaa 100644 --- a/cpp/src/scheduler/task/Task.h +++ b/cpp/src/scheduler/task/Task.h @@ -34,6 +34,7 @@ class Task; using TaskPtr = std::shared_ptr; +// TODO: re-design class Task { public: explicit diff --git a/cpp/src/scheduler/task/TestTask.cpp b/cpp/src/scheduler/task/TestTask.cpp index 15f60baa95..1078da583f 100644 --- a/cpp/src/scheduler/task/TestTask.cpp +++ b/cpp/src/scheduler/task/TestTask.cpp @@ -13,7 +13,7 @@ namespace milvus { namespace engine { -TestTask::TestTask(TableFileSchemaPtr& file) : XSearchTask(file) {} +TestTask::TestTask(TableFileSchemaPtr &file) : XSearchTask(file) {} void TestTask::Load(LoadType type, uint8_t device_id) { @@ -22,9 +22,12 @@ TestTask::Load(LoadType type, uint8_t device_id) { void TestTask::Execute() { - std::lock_guard lock(mutex_); - exec_count_++; - done_ = true; + { + std::lock_guard lock(mutex_); + exec_count_++; + done_ = true; + } + cv_.notify_one(); } void diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp index a6308de3c8..3cdef8a5b9 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp @@ -646,7 +646,7 @@ SearchTask::OnExecute() { search_param_->query_record_array(i).vector_data().data(), table_info.dimension_ * sizeof(float)); } - rc.ElapseFromBegin("prepare vector data"); + rc.RecordSection("prepare vector data"); //step 6: search vectors engine::QueryResults results; @@ -669,7 +669,7 @@ SearchTask::OnExecute() { ProfilerStop(); #endif - rc.ElapseFromBegin("search vectors from engine"); + rc.RecordSection("search vectors from engine"); if (!stat.ok()) { return SetError(DB_META_TRANSACTION_FAILED, stat.ToString()); } @@ -684,8 +684,6 @@ SearchTask::OnExecute() { return SetError(SERVER_ILLEGAL_SEARCH_RESULT, msg); } - rc.ElapseFromBegin("do search"); - //step 7: construct result array for (auto &result : results) { ::milvus::grpc::TopKQueryResult *topk_query_result = topk_result_list->add_topk_query_result(); @@ -697,7 +695,7 @@ SearchTask::OnExecute() { } //step 8: print time cost percent - double span_result = rc.RecordSection("construct result"); + rc.RecordSection("construct result and send"); rc.ElapseFromBegin("totally cost"); @@ -969,4 +967,4 @@ DropIndexTask::OnExecute() { } } } -} \ No newline at end of file +} diff --git a/cpp/src/utils/ValidationUtil.cpp b/cpp/src/utils/ValidationUtil.cpp index f12d68b364..bce8143c1a 100644 --- a/cpp/src/utils/ValidationUtil.cpp +++ b/cpp/src/utils/ValidationUtil.cpp @@ -94,7 +94,7 @@ ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { ErrorCode ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) { - if (top_k <= 0 || top_k > 1024) { + if (top_k <= 0 || top_k > 2048) { return SERVER_INVALID_TOPK; } diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 685cc8f84a..c012ea60fb 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -103,6 +103,7 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon // TODO(linxj): avoid copy here. memcpy(ids, p_ids, sizeof(int64_t) * nq * k); memcpy(dist, p_dist, sizeof(float) * nq * k); + } catch (KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); return KNOWHERE_UNEXPECTED_ERROR; diff --git a/cpp/src/wrapper/knowhere/vec_index.cpp b/cpp/src/wrapper/knowhere/vec_index.cpp index e14496884d..1ee8697d64 100644 --- a/cpp/src/wrapper/knowhere/vec_index.cpp +++ b/cpp/src/wrapper/knowhere/vec_index.cpp @@ -14,6 +14,8 @@ #include "vec_impl.h" #include "wrapper_log.h" +#include + namespace zilliz { namespace milvus { @@ -246,11 +248,13 @@ void ParameterValidation(const IndexType &type, Config &cfg) { case IndexType::FAISS_IVFSQ8_GPU: case IndexType::FAISS_IVFFLAT_GPU: case IndexType::FAISS_IVFPQ_GPU: { + //search on GPU if (cfg.get_with_default("nprobe", 0) != 0) { auto nprobe = cfg["nprobe"].as(); if (nprobe > GPU_MAX_NRPOBE) { - WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE << ", but you passed " << nprobe - << ". Search with " << GPU_MAX_NRPOBE << " instead"; + WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE + << ", but you passed " << nprobe + << ". Search with " << GPU_MAX_NRPOBE << " instead"; cfg.insert_or_assign("nprobe", GPU_MAX_NRPOBE); } } diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 137429d1ad..a5cdcabf04 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -14,8 +14,6 @@ #include "knowhere/common/config.h" #include "knowhere/common/binary_set.h" -#include "cuda.h" - namespace zilliz { namespace milvus { @@ -62,7 +60,7 @@ class VecIndex { long *ids, const Config &cfg = Config()) = 0; - virtual VecIndexPtr CopyToGpu(const int64_t& device_id, + virtual VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg = Config()) = 0; virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0; @@ -86,16 +84,16 @@ extern ErrorCode write_index(VecIndexPtr index, const std::string &location); extern VecIndexPtr read_index(const std::string &location); -extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config& cfg = Config()); +extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg = Config()); extern VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary); -extern void AutoGenParams(const IndexType& type, const long& size, Config& cfg); +extern void AutoGenParams(const IndexType &type, const long &size, Config &cfg); -extern void ParameterValidation(const IndexType& type, Config& cfg); +extern void ParameterValidation(const IndexType &type, Config &cfg); -extern IndexType ConvertToCpuIndexType(const IndexType& type); -extern IndexType ConvertToGpuIndexType(const IndexType& type); +extern IndexType ConvertToCpuIndexType(const IndexType &type); +extern IndexType ConvertToGpuIndexType(const IndexType &type); } } diff --git a/cpp/unittest/scheduler/node_test.cpp b/cpp/unittest/scheduler/node_test.cpp index f0621043db..642bdce773 100644 --- a/cpp/unittest/scheduler/node_test.cpp +++ b/cpp/unittest/scheduler/node_test.cpp @@ -11,58 +11,71 @@ protected: node1_ = std::make_shared(); node2_ = std::make_shared(); node3_ = std::make_shared(); - node4_ = std::make_shared(); + isolated_node1_ = std::make_shared(); + isolated_node2_ = std::make_shared(); auto pcie = Connection("PCIe", 11.0); node1_->AddNeighbour(node2_, pcie); + node1_->AddNeighbour(node3_, pcie); node2_->AddNeighbour(node1_, pcie); } NodePtr node1_; NodePtr node2_; NodePtr node3_; - NodePtr node4_; + NodePtr isolated_node1_; + NodePtr isolated_node2_; }; TEST_F(NodeTest, add_neighbour) { - ASSERT_EQ(node3_->GetNeighbours().size(), 0); - ASSERT_EQ(node4_->GetNeighbours().size(), 0); + ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 0); + ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); auto pcie = Connection("PCIe", 11.0); - node3_->AddNeighbour(node4_, pcie); - node4_->AddNeighbour(node3_, pcie); - ASSERT_EQ(node3_->GetNeighbours().size(), 1); - ASSERT_EQ(node4_->GetNeighbours().size(), 1); + isolated_node1_->AddNeighbour(isolated_node2_, pcie); + ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 1); + ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); } -TEST_F(NodeTest, del_neighbour) { - ASSERT_EQ(node1_->GetNeighbours().size(), 1); - ASSERT_EQ(node2_->GetNeighbours().size(), 1); - ASSERT_EQ(node3_->GetNeighbours().size(), 0); - node1_->DelNeighbour(node2_); - node2_->DelNeighbour(node2_); - node3_->DelNeighbour(node2_); - ASSERT_EQ(node1_->GetNeighbours().size(), 0); - ASSERT_EQ(node2_->GetNeighbours().size(), 1); - ASSERT_EQ(node3_->GetNeighbours().size(), 0); -} - -TEST_F(NodeTest, is_neighbour) { - ASSERT_TRUE(node1_->IsNeighbour(node2_)); - ASSERT_TRUE(node2_->IsNeighbour(node1_)); - - ASSERT_FALSE(node1_->IsNeighbour(node3_)); - ASSERT_FALSE(node2_->IsNeighbour(node3_)); - ASSERT_FALSE(node3_->IsNeighbour(node1_)); - ASSERT_FALSE(node3_->IsNeighbour(node2_)); +TEST_F(NodeTest, repeat_add_neighbour) { + ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 0); + ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); + auto pcie = Connection("PCIe", 11.0); + isolated_node1_->AddNeighbour(isolated_node2_, pcie); + isolated_node1_->AddNeighbour(isolated_node2_, pcie); + ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 1); + ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); } TEST_F(NodeTest, get_neighbours) { - auto node1_neighbours = node1_->GetNeighbours(); - ASSERT_EQ(node1_neighbours.size(), 1); - ASSERT_EQ(node1_neighbours[0].neighbour_node.lock(), node2_); + { + bool n2 = false, n3 = false; + auto node1_neighbours = node1_->GetNeighbours(); + ASSERT_EQ(node1_neighbours.size(), 2); + for (auto &n : node1_neighbours) { + if (n.neighbour_node.lock() == node2_) n2 = true; + if (n.neighbour_node.lock() == node3_) n3 = true; + } + ASSERT_TRUE(n2); + ASSERT_TRUE(n3); + } - auto node2_neighbours = node2_->GetNeighbours(); - ASSERT_EQ(node2_neighbours.size(), 1); - ASSERT_EQ(node2_neighbours[0].neighbour_node.lock(), node1_); + { + auto node2_neighbours = node2_->GetNeighbours(); + ASSERT_EQ(node2_neighbours.size(), 1); + ASSERT_EQ(node2_neighbours[0].neighbour_node.lock(), node1_); + } + + { + auto node3_neighbours = node3_->GetNeighbours(); + ASSERT_EQ(node3_neighbours.size(), 0); + } +} + +TEST_F(NodeTest, dump) { + std::cout << node1_->Dump(); + ASSERT_FALSE(node1_->Dump().empty()); + + std::cout << node2_->Dump(); + ASSERT_FALSE(node2_->Dump().empty()); } diff --git a/cpp/unittest/scheduler/normal_test.cpp b/cpp/unittest/scheduler/normal_test.cpp index 576ed3ee2a..c679a356bd 100644 --- a/cpp/unittest/scheduler/normal_test.cpp +++ b/cpp/unittest/scheduler/normal_test.cpp @@ -2,6 +2,7 @@ #include "scheduler/ResourceMgr.h" #include "scheduler/Scheduler.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include "scheduler/SchedInst.h" #include "utils/Log.h" #include @@ -9,48 +10,44 @@ using namespace zilliz::milvus::engine; -TEST(normal_test, test1) { + +TEST(normal_test, inst_test) { // ResourceMgr only compose resources, provide unified event -// auto res_mgr = std::make_shared(); auto res_mgr = ResMgrInst::GetInstance(); - auto disk = res_mgr->Add(ResourceFactory::Create("disk", "ssd", true, false)); - auto cpu = res_mgr->Add(ResourceFactory::Create("cpu", "CPU", 0)); - auto gpu1 = res_mgr->Add(ResourceFactory::Create("gpu", "gpu0", false, false)); - auto gpu2 = res_mgr->Add(ResourceFactory::Create("gpu", "gpu2", false, false)); + + res_mgr->Add(ResourceFactory::Create("disk", "DISK", 0, true, false)); + res_mgr->Add(ResourceFactory::Create("cpu", "CPU", 0, true, true)); auto IO = Connection("IO", 500.0); - auto PCIE = Connection("IO", 11000.0); - res_mgr->Connect(disk, cpu, IO); - res_mgr->Connect(cpu, gpu1, PCIE); - res_mgr->Connect(cpu, gpu2, PCIE); + res_mgr->Connect("disk", "cpu", IO); + + auto scheduler = SchedInst::GetInstance(); res_mgr->Start(); - -// auto scheduler = new Scheduler(res_mgr); - auto scheduler = SchedInst::GetInstance(); scheduler->Start(); const uint64_t NUM_TASK = 1000; std::vector> tasks; TableFileSchemaPtr dummy = nullptr; - for (uint64_t i = 0; i < NUM_TASK; ++i) { - if (auto observe = disk.lock()) { + auto disks = res_mgr->GetDiskResources(); + ASSERT_FALSE(disks.empty()); + if (auto observe = disks[0].lock()) { + for (uint64_t i = 0; i < NUM_TASK; ++i) { auto task = std::make_shared(dummy); + task->label() = std::make_shared(); tasks.push_back(task); observe->task_table().Put(task); } } - sleep(1); + for (auto &task : tasks) { + task->Wait(); + ASSERT_EQ(task->load_count_, 1); + ASSERT_EQ(task->exec_count_, 1); + } scheduler->Stop(); res_mgr->Stop(); - auto pcpu = cpu.lock(); - for (uint64_t i = 0; i < NUM_TASK; ++i) { - auto task = std::static_pointer_cast(pcpu->task_table()[i]->task); - ASSERT_EQ(task->load_count_, 1); - ASSERT_EQ(task->exec_count_, 1); - } } diff --git a/cpp/unittest/scheduler/tasktable_test.cpp b/cpp/unittest/scheduler/tasktable_test.cpp index 5a1094b0ad..8710aff5b5 100644 --- a/cpp/unittest/scheduler/tasktable_test.cpp +++ b/cpp/unittest/scheduler/tasktable_test.cpp @@ -5,30 +5,37 @@ using namespace zilliz::milvus::engine; + +/************ TaskTableBaseTest ************/ + class TaskTableItemTest : public ::testing::Test { protected: void SetUp() override { - item1_.id = 0; - item1_.state = TaskTableItemState::MOVED; - item1_.priority = 10; + std::vector states{ + TaskTableItemState::INVALID, + TaskTableItemState::START, + TaskTableItemState::LOADING, + TaskTableItemState::LOADED, + TaskTableItemState::EXECUTING, + TaskTableItemState::EXECUTED, + TaskTableItemState::MOVING, + TaskTableItemState::MOVED}; + for (auto &state : states) { + auto item = std::make_shared(); + item->state = state; + items_.emplace_back(item); + } } TaskTableItem default_; - TaskTableItem item1_; + std::vector items_; }; TEST_F(TaskTableItemTest, construct) { ASSERT_EQ(default_.id, 0); + ASSERT_EQ(default_.task, nullptr); ASSERT_EQ(default_.state, TaskTableItemState::INVALID); - ASSERT_EQ(default_.priority, 0); -} - -TEST_F(TaskTableItemTest, copy) { - TaskTableItem another(item1_); - ASSERT_EQ(another.id, item1_.id); - ASSERT_EQ(another.state, item1_.state); - ASSERT_EQ(another.priority, item1_.priority); } TEST_F(TaskTableItemTest, destruct) { @@ -36,6 +43,107 @@ TEST_F(TaskTableItemTest, destruct) { delete p_item; } +TEST_F(TaskTableItemTest, is_finish) { + for (auto &item : items_) { + if (item->state == TaskTableItemState::EXECUTED + || item->state == TaskTableItemState::MOVED) { + ASSERT_TRUE(item->IsFinish()); + } else { + ASSERT_FALSE(item->IsFinish()); + } + } +} + +TEST_F(TaskTableItemTest, dump) { + for (auto &item : items_) { + ASSERT_FALSE(item->Dump().empty()); + } +} + +TEST_F(TaskTableItemTest, load) { + for (auto &item : items_) { + auto before_state = item->state; + auto ret = item->Load(); + if (before_state == TaskTableItemState::START) { + ASSERT_TRUE(ret); + ASSERT_EQ(item->state, TaskTableItemState::LOADING); + } else { + ASSERT_FALSE(ret); + ASSERT_EQ(item->state, before_state); + } + } +} + +TEST_F(TaskTableItemTest, loaded) { + for (auto &item : items_) { + auto before_state = item->state; + auto ret = item->Loaded(); + if (before_state == TaskTableItemState::LOADING) { + ASSERT_TRUE(ret); + ASSERT_EQ(item->state, TaskTableItemState::LOADED); + } else { + ASSERT_FALSE(ret); + ASSERT_EQ(item->state, before_state); + } + } +} + +TEST_F(TaskTableItemTest, execute) { + for (auto &item : items_) { + auto before_state = item->state; + auto ret = item->Execute(); + if (before_state == TaskTableItemState::LOADED) { + ASSERT_TRUE(ret); + ASSERT_EQ(item->state, TaskTableItemState::EXECUTING); + } else { + ASSERT_FALSE(ret); + ASSERT_EQ(item->state, before_state); + } + } +} + + +TEST_F(TaskTableItemTest, executed) { + for (auto &item : items_) { + auto before_state = item->state; + auto ret = item->Executed(); + if (before_state == TaskTableItemState::EXECUTING) { + ASSERT_TRUE(ret); + ASSERT_EQ(item->state, TaskTableItemState::EXECUTED); + } else { + ASSERT_FALSE(ret); + ASSERT_EQ(item->state, before_state); + } + } +} + +TEST_F(TaskTableItemTest, move) { + for (auto &item : items_) { + auto before_state = item->state; + auto ret = item->Move(); + if (before_state == TaskTableItemState::LOADED) { + ASSERT_TRUE(ret); + ASSERT_EQ(item->state, TaskTableItemState::MOVING); + } else { + ASSERT_FALSE(ret); + ASSERT_EQ(item->state, before_state); + } + } +} + +TEST_F(TaskTableItemTest, moved) { + for (auto &item : items_) { + auto before_state = item->state; + auto ret = item->Moved(); + if (before_state == TaskTableItemState::MOVING) { + ASSERT_TRUE(ret); + ASSERT_EQ(item->state, TaskTableItemState::MOVED); + } else { + ASSERT_FALSE(ret); + ASSERT_EQ(item->state, before_state); + } + } +} /************ TaskTableBaseTest ************/ @@ -55,6 +163,16 @@ protected: TaskTable empty_table_; }; +TEST_F(TaskTableBaseTest, subscriber) { + bool flag = false; + auto callback = [&]() { + flag = true; + }; + empty_table_.RegisterSubscriber(callback); + empty_table_.Put(task1_); + ASSERT_TRUE(flag); +} + TEST_F(TaskTableBaseTest, put_task) { empty_table_.Put(task1_); @@ -78,6 +196,125 @@ TEST_F(TaskTableBaseTest, put_empty_batch) { empty_table_.Put(tasks); } +TEST_F(TaskTableBaseTest, empty) { + ASSERT_TRUE(empty_table_.Empty()); + empty_table_.Put(task1_); + ASSERT_FALSE(empty_table_.Empty()); +} + +TEST_F(TaskTableBaseTest, size) { + ASSERT_EQ(empty_table_.Size(), 0); + empty_table_.Put(task1_); + ASSERT_EQ(empty_table_.Size(), 1); +} + +TEST_F(TaskTableBaseTest, operator_) { + empty_table_.Put(task1_); + ASSERT_EQ(empty_table_.Get(0), empty_table_[0]); +} + +TEST_F(TaskTableBaseTest, pick_to_load) { + const size_t NUM_TASKS = 10; + for (size_t i = 0; i < NUM_TASKS; ++i) { + empty_table_.Put(task1_); + } + empty_table_[0]->state = TaskTableItemState::MOVED; + empty_table_[1]->state = TaskTableItemState::EXECUTED; + + auto indexes = empty_table_.PickToLoad(1); + ASSERT_EQ(indexes.size(), 1); + ASSERT_EQ(indexes[0], 2); +} + +TEST_F(TaskTableBaseTest, pick_to_load_limit) { + const size_t NUM_TASKS = 10; + for (size_t i = 0; i < NUM_TASKS; ++i) { + empty_table_.Put(task1_); + } + empty_table_[0]->state = TaskTableItemState::MOVED; + empty_table_[1]->state = TaskTableItemState::EXECUTED; + + auto indexes = empty_table_.PickToLoad(3); + ASSERT_EQ(indexes.size(), 3); + ASSERT_EQ(indexes[0], 2); + ASSERT_EQ(indexes[1], 3); + ASSERT_EQ(indexes[2], 4); +} + +TEST_F(TaskTableBaseTest, pick_to_load_cache) { + const size_t NUM_TASKS = 10; + for (size_t i = 0; i < NUM_TASKS; ++i) { + empty_table_.Put(task1_); + } + empty_table_[0]->state = TaskTableItemState::MOVED; + empty_table_[1]->state = TaskTableItemState::EXECUTED; + + // first pick, non-cache + auto indexes = empty_table_.PickToLoad(1); + ASSERT_EQ(indexes.size(), 1); + ASSERT_EQ(indexes[0], 2); + + // second pick, iterate from 2 + // invalid state change + empty_table_[1]->state = TaskTableItemState::START; + indexes = empty_table_.PickToLoad(1); + ASSERT_EQ(indexes.size(), 1); + ASSERT_EQ(indexes[0], 2); +} + +TEST_F(TaskTableBaseTest, pick_to_execute) { + const size_t NUM_TASKS = 10; + for (size_t i = 0; i < NUM_TASKS; ++i) { + empty_table_.Put(task1_); + } + empty_table_[0]->state = TaskTableItemState::MOVED; + empty_table_[1]->state = TaskTableItemState::EXECUTED; + empty_table_[2]->state = TaskTableItemState::LOADED; + + auto indexes = empty_table_.PickToExecute(1); + ASSERT_EQ(indexes.size(), 1); + ASSERT_EQ(indexes[0], 2); +} + +TEST_F(TaskTableBaseTest, pick_to_execute_limit) { + const size_t NUM_TASKS = 10; + for (size_t i = 0; i < NUM_TASKS; ++i) { + empty_table_.Put(task1_); + } + empty_table_[0]->state = TaskTableItemState::MOVED; + empty_table_[1]->state = TaskTableItemState::EXECUTED; + empty_table_[2]->state = TaskTableItemState::LOADED; + empty_table_[3]->state = TaskTableItemState::LOADED; + + auto indexes = empty_table_.PickToExecute(3); + ASSERT_EQ(indexes.size(), 2); + ASSERT_EQ(indexes[0], 2); + ASSERT_EQ(indexes[1], 3); +} + +TEST_F(TaskTableBaseTest, pick_to_execute_cache) { + const size_t NUM_TASKS = 10; + for (size_t i = 0; i < NUM_TASKS; ++i) { + empty_table_.Put(task1_); + } + empty_table_[0]->state = TaskTableItemState::MOVED; + empty_table_[1]->state = TaskTableItemState::EXECUTED; + empty_table_[2]->state = TaskTableItemState::LOADED; + + // first pick, non-cache + auto indexes = empty_table_.PickToExecute(1); + ASSERT_EQ(indexes.size(), 1); + ASSERT_EQ(indexes[0], 2); + + // second pick, iterate from 2 + // invalid state change + empty_table_[1]->state = TaskTableItemState::START; + indexes = empty_table_.PickToExecute(1); + ASSERT_EQ(indexes.size(), 1); + ASSERT_EQ(indexes[0], 2); +} + + /************ TaskTableAdvanceTest ************/ class TaskTableAdvanceTest : public ::testing::Test { @@ -104,25 +341,116 @@ protected: }; TEST_F(TaskTableAdvanceTest, load) { - table1_.Load(1); - table1_.Loaded(2); + std::vector before_state; + for (auto &task : table1_) { + before_state.push_back(task->state); + } - ASSERT_EQ(table1_.Get(1)->state, TaskTableItemState::LOADING); - ASSERT_EQ(table1_.Get(2)->state, TaskTableItemState::LOADED); + for (size_t i = 0; i < table1_.Size(); ++i) { + table1_.Load(i); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + if (before_state[i] == TaskTableItemState::START) { + ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::LOADING); + } else { + ASSERT_EQ(table1_.Get(i)->state, before_state[i]); + } + } +} + +TEST_F(TaskTableAdvanceTest, loaded) { + std::vector before_state; + for (auto &task : table1_) { + before_state.push_back(task->state); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + table1_.Loaded(i); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + if (before_state[i] == TaskTableItemState::LOADING) { + ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::LOADED); + } else { + ASSERT_EQ(table1_.Get(i)->state, before_state[i]); + } + } } TEST_F(TaskTableAdvanceTest, execute) { - table1_.Execute(3); - table1_.Executed(4); + std::vector before_state; + for (auto &task : table1_) { + before_state.push_back(task->state); + } - ASSERT_EQ(table1_.Get(3)->state, TaskTableItemState::EXECUTING); - ASSERT_EQ(table1_.Get(4)->state, TaskTableItemState::EXECUTED); + for (size_t i = 0; i < table1_.Size(); ++i) { + table1_.Execute(i); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + if (before_state[i] == TaskTableItemState::LOADED) { + ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::EXECUTING); + } else { + ASSERT_EQ(table1_.Get(i)->state, before_state[i]); + } + } +} + +TEST_F(TaskTableAdvanceTest, executed) { + std::vector before_state; + for (auto &task : table1_) { + before_state.push_back(task->state); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + table1_.Executed(i); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + if (before_state[i] == TaskTableItemState::EXECUTING) { + ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::EXECUTED); + } else { + ASSERT_EQ(table1_.Get(i)->state, before_state[i]); + } + } } TEST_F(TaskTableAdvanceTest, move) { - table1_.Move(3); - table1_.Moved(6); + std::vector before_state; + for (auto &task : table1_) { + before_state.push_back(task->state); + } - ASSERT_EQ(table1_.Get(3)->state, TaskTableItemState::MOVING); - ASSERT_EQ(table1_.Get(6)->state, TaskTableItemState::MOVED); + for (size_t i = 0; i < table1_.Size(); ++i) { + table1_.Move(i); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + if (before_state[i] == TaskTableItemState::LOADED) { + ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::MOVING); + } else { + ASSERT_EQ(table1_.Get(i)->state, before_state[i]); + } + } } + +TEST_F(TaskTableAdvanceTest, moved) { + std::vector before_state; + for (auto &task : table1_) { + before_state.push_back(task->state); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + table1_.Moved(i); + } + + for (size_t i = 0; i < table1_.Size(); ++i) { + if (before_state[i] == TaskTableItemState::MOVING) { + ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::MOVED); + } else { + ASSERT_EQ(table1_.Get(i)->state, before_state[i]); + } + } +} +