refine scheduler

Former-commit-id: 9b772adf62a9f7f2ae349f3a2420fcecb08af6ce
pull/191/head
groot 2019-04-25 16:41:01 +08:00
parent 87d0ed293e
commit 75410aed03
3 changed files with 23 additions and 25 deletions

View File

@ -11,10 +11,11 @@ namespace vecwise {
namespace server {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group)
: task_group_(task_group),
done_(false),
error_code_(SERVER_SUCCESS) {
BaseTask::BaseTask(const std::string& task_group, bool async)
: task_group_(task_group),
async_(async),
done_(false),
error_code_(SERVER_SUCCESS) {
}
@ -73,14 +74,6 @@ void VecServiceScheduler::Stop() {
stopped_ = true;
}
ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
}
return PutTaskToQueue(task_ptr);
}
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
@ -91,7 +84,11 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
return err;
}
return task_ptr->WaitToFinish();
if(task_ptr->IsAsync()) {
return SERVER_SUCCESS;//async execution, caller need to call WaitToFinish at somewhere
}
return task_ptr->WaitToFinish();//sync execution
}
namespace {

View File

@ -17,7 +17,7 @@ namespace server {
class BaseTask {
protected:
BaseTask(const std::string& task_group);
BaseTask(const std::string& task_group, bool async = false);
virtual ~BaseTask();
public:
@ -27,6 +27,9 @@ public:
std::string TaskGroup() const { return task_group_; }
ServerError ErrorCode() const { return error_code_; }
bool IsAsync() const { return async_; }
protected:
virtual ServerError OnExecute() = 0;
@ -35,6 +38,7 @@ protected:
std::condition_variable finish_cond_;
std::string task_group_;
bool async_;
bool done_;
ServerError error_code_;
};
@ -54,9 +58,6 @@ public:
void Start();
void Stop();
//async
ServerError PushTask(const BaseTaskPtr& task_ptr);
//sync
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
protected:

View File

@ -16,8 +16,8 @@ namespace zilliz {
namespace vecwise {
namespace server {
static const std::string NORMAL_TASK_GROUP = "normal";
static const std::string SEARCH_TASK_GROUP = "search";
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
namespace {
class DBWrapper {
@ -53,7 +53,7 @@ namespace {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddGroupTask::AddGroupTask(int32_t dimension,
const std::string& group_id)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
dimension_(dimension),
group_id_(group_id) {
@ -81,7 +81,7 @@ ServerError AddGroupTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id),
dimension_(dimension) {
@ -111,7 +111,7 @@ ServerError GetGroupTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id) {
}
@ -132,7 +132,7 @@ ServerError DeleteGroupTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id,
const VecTensor &tensor)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id),
tensor_(tensor) {
@ -169,7 +169,7 @@ ServerError AddSingleVectorTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
const VecTensorList &tensor_list)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id),
tensor_list_(tensor_list) {
@ -233,7 +233,7 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id,
const VecTensorList& tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result)
: BaseTask(SEARCH_TASK_GROUP),
: BaseTask(DQL_TASK_GROUP),
group_id_(group_id),
top_k_(top_k),
tensor_list_(tensor_list),