mirror of https://github.com/milvus-io/milvus.git
refine scheduler
Former-commit-id: 9b772adf62a9f7f2ae349f3a2420fcecb08af6cepull/191/head
parent
87d0ed293e
commit
75410aed03
|
@ -11,10 +11,11 @@ namespace vecwise {
|
|||
namespace server {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
BaseTask::BaseTask(const std::string& task_group)
|
||||
: task_group_(task_group),
|
||||
done_(false),
|
||||
error_code_(SERVER_SUCCESS) {
|
||||
BaseTask::BaseTask(const std::string& task_group, bool async)
|
||||
: task_group_(task_group),
|
||||
async_(async),
|
||||
done_(false),
|
||||
error_code_(SERVER_SUCCESS) {
|
||||
|
||||
}
|
||||
|
||||
|
@ -73,14 +74,6 @@ void VecServiceScheduler::Stop() {
|
|||
stopped_ = true;
|
||||
}
|
||||
|
||||
ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) {
|
||||
if(task_ptr == nullptr) {
|
||||
return SERVER_NULL_POINTER;
|
||||
}
|
||||
|
||||
return PutTaskToQueue(task_ptr);
|
||||
}
|
||||
|
||||
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
|
||||
if(task_ptr == nullptr) {
|
||||
return SERVER_NULL_POINTER;
|
||||
|
@ -91,7 +84,11 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
|
|||
return err;
|
||||
}
|
||||
|
||||
return task_ptr->WaitToFinish();
|
||||
if(task_ptr->IsAsync()) {
|
||||
return SERVER_SUCCESS;//async execution, caller need to call WaitToFinish at somewhere
|
||||
}
|
||||
|
||||
return task_ptr->WaitToFinish();//sync execution
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace server {
|
|||
|
||||
class BaseTask {
|
||||
protected:
|
||||
BaseTask(const std::string& task_group);
|
||||
BaseTask(const std::string& task_group, bool async = false);
|
||||
virtual ~BaseTask();
|
||||
|
||||
public:
|
||||
|
@ -27,6 +27,9 @@ public:
|
|||
std::string TaskGroup() const { return task_group_; }
|
||||
|
||||
ServerError ErrorCode() const { return error_code_; }
|
||||
|
||||
bool IsAsync() const { return async_; }
|
||||
|
||||
protected:
|
||||
virtual ServerError OnExecute() = 0;
|
||||
|
||||
|
@ -35,6 +38,7 @@ protected:
|
|||
std::condition_variable finish_cond_;
|
||||
|
||||
std::string task_group_;
|
||||
bool async_;
|
||||
bool done_;
|
||||
ServerError error_code_;
|
||||
};
|
||||
|
@ -54,9 +58,6 @@ public:
|
|||
void Start();
|
||||
void Stop();
|
||||
|
||||
//async
|
||||
ServerError PushTask(const BaseTaskPtr& task_ptr);
|
||||
//sync
|
||||
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
|
||||
|
||||
protected:
|
||||
|
|
|
@ -16,8 +16,8 @@ namespace zilliz {
|
|||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
static const std::string NORMAL_TASK_GROUP = "normal";
|
||||
static const std::string SEARCH_TASK_GROUP = "search";
|
||||
static const std::string DQL_TASK_GROUP = "dql";
|
||||
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
|
||||
|
||||
namespace {
|
||||
class DBWrapper {
|
||||
|
@ -53,7 +53,7 @@ namespace {
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddGroupTask::AddGroupTask(int32_t dimension,
|
||||
const std::string& group_id)
|
||||
: BaseTask(NORMAL_TASK_GROUP),
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
dimension_(dimension),
|
||||
group_id_(group_id) {
|
||||
|
||||
|
@ -81,7 +81,7 @@ ServerError AddGroupTask::OnExecute() {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
|
||||
: BaseTask(NORMAL_TASK_GROUP),
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
dimension_(dimension) {
|
||||
|
||||
|
@ -111,7 +111,7 @@ ServerError GetGroupTask::OnExecute() {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
|
||||
: BaseTask(NORMAL_TASK_GROUP),
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id) {
|
||||
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ ServerError DeleteGroupTask::OnExecute() {
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id,
|
||||
const VecTensor &tensor)
|
||||
: BaseTask(NORMAL_TASK_GROUP),
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
tensor_(tensor) {
|
||||
|
||||
|
@ -169,7 +169,7 @@ ServerError AddSingleVectorTask::OnExecute() {
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
|
||||
const VecTensorList &tensor_list)
|
||||
: BaseTask(NORMAL_TASK_GROUP),
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
tensor_list_(tensor_list) {
|
||||
|
||||
|
@ -233,7 +233,7 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id,
|
|||
const VecTensorList& tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result)
|
||||
: BaseTask(SEARCH_TASK_GROUP),
|
||||
: BaseTask(DQL_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
top_k_(top_k),
|
||||
tensor_list_(tensor_list),
|
||||
|
|
Loading…
Reference in New Issue