mirror of https://github.com/milvus-io/milvus.git
Merge branch 'scheduler' into 'jinhai'
Scheduler. See merge request jinhai/vecwise_engine!16. Former-commit-id: 1adc4d5f745ac461680ce6f7d7bb14d8044ef87a (pull/191/head)
commit
87d0ed293e
|
@ -40,8 +40,11 @@ rm -rf ./cmake_build
|
|||
mkdir cmake_build
|
||||
cd cmake_build
|
||||
|
||||
CUDA_COMPILER=/usr/local/cuda/bin/nvcc
|
||||
|
||||
CMAKE_CMD="cmake -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \
|
||||
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
||||
-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \
|
||||
$@ ../"
|
||||
echo ${CMAKE_CMD}
|
||||
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceHandler.h"
|
||||
#include "VecServiceTask.h"
|
||||
#include "ServerConfig.h"
|
||||
#include "VecIdMapper.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
|
||||
#include "db/DB.h"
|
||||
#include "db/Env.h"
|
||||
|
@ -34,19 +36,11 @@ VecServiceHandler::add_group(const VecGroup &group) {
|
|||
SERVER_LOG_TRACE << "group.id = " << group.id << ", group.dimension = " << group.dimension
|
||||
<< ", group.index_type = " << group.index_type;
|
||||
|
||||
try {
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.dimension = (size_t)group.dimension;
|
||||
group_info.group_id = group.id;
|
||||
engine::Status stat = db_->add_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
}
|
||||
BaseTaskPtr task_ptr = AddGroupTask::Create(group.dimension, group.id);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
SERVER_LOG_INFO << "add_group() finished";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
}
|
||||
SERVER_LOG_INFO << "add_group() finished";
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -54,21 +48,12 @@ VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) {
|
|||
SERVER_LOG_INFO << "get_group() called";
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id;
|
||||
|
||||
try {
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = group_id;
|
||||
engine::Status stat = db_->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
} else {
|
||||
_return.id = group_info.group_id;
|
||||
_return.dimension = (int32_t)group_info.dimension;
|
||||
}
|
||||
_return.id = group_id;
|
||||
BaseTaskPtr task_ptr = GetGroupTask::Create(group_id, _return.dimension);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
SERVER_LOG_INFO << "get_group() finished";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
}
|
||||
SERVER_LOG_INFO << "get_group() finished";
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -76,12 +61,11 @@ VecServiceHandler::del_group(const std::string &group_id) {
|
|||
SERVER_LOG_INFO << "del_group() called";
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id;
|
||||
|
||||
try {
|
||||
BaseTaskPtr task_ptr = DeleteGroupTask::Create(group_id);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
SERVER_LOG_INFO << "del_group() not implemented";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
}
|
||||
SERVER_LOG_INFO << "del_group() not implemented";
|
||||
}
|
||||
|
||||
|
||||
|
@ -90,25 +74,11 @@ VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tens
|
|||
SERVER_LOG_INFO << "add_vector() called";
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size();
|
||||
|
||||
try {
|
||||
engine::IDNumbers vector_ids;
|
||||
std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
|
||||
engine::Status stat = db_->add_vectors(group_id, 1, vec_f.data(), vector_ids);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
} else {
|
||||
if(vector_ids.size() != 1) {
|
||||
SERVER_LOG_ERROR << "Vector ID not returned";
|
||||
} else {
|
||||
std::string nid = group_id + "_" + std::to_string(vector_ids[0]);
|
||||
IVecIdMapper::GetInstance()->Put(nid, tensor.uid);
|
||||
}
|
||||
}
|
||||
BaseTaskPtr task_ptr = AddSingleVectorTask::Create(group_id, tensor);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
SERVER_LOG_INFO << "add_vector() finished";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
}
|
||||
SERVER_LOG_INFO << "add_vector() finished";
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -117,33 +87,13 @@ VecServiceHandler::add_vector_batch(const std::string &group_id,
|
|||
SERVER_LOG_INFO << "add_vector_batch() called";
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id << ", vector list size = "
|
||||
<< tensor_list.tensor_list.size();
|
||||
TimeRecorder rc("Add VECTOR BATCH");
|
||||
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, tensor_list);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
rc.Elapse("DONE!");
|
||||
|
||||
try {
|
||||
std::vector<float> vec_f;
|
||||
for(const VecTensor& tensor : tensor_list.tensor_list) {
|
||||
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
|
||||
}
|
||||
|
||||
engine::IDNumbers vector_ids;
|
||||
engine::Status stat = db_->add_vectors(group_id, tensor_list.tensor_list.size(), vec_f.data(), vector_ids);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
} else {
|
||||
if(vector_ids.size() != tensor_list.tensor_list.size()) {
|
||||
SERVER_LOG_ERROR << "Vector ID not returned";
|
||||
} else {
|
||||
std::string nid_prefix = group_id + "_";
|
||||
for(size_t i = 0; i < vector_ids.size(); i++) {
|
||||
std::string nid = nid_prefix + std::to_string(vector_ids[i]);
|
||||
IVecIdMapper::GetInstance()->Put(nid, tensor_list.tensor_list[i].uid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SERVER_LOG_INFO << "add_vector_batch() finished";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
}
|
||||
SERVER_LOG_INFO << "add_vector_batch() finished";
|
||||
}
|
||||
|
||||
|
||||
|
@ -158,29 +108,20 @@ VecServiceHandler::search_vector(VecSearchResult &_return,
|
|||
<< ", vector size = " << tensor.tensor.size()
|
||||
<< ", time range list size = " << time_range_list.range_list.size();
|
||||
|
||||
try {
|
||||
engine::QueryResults results;
|
||||
std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
|
||||
engine::Status stat = db_->search(group_id, (size_t)top_k, 1, vec_f.data(), results);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
} else {
|
||||
if(!results.empty()) {
|
||||
std::string nid_prefix = group_id + "_";
|
||||
for(auto id : results[0]) {
|
||||
std::string sid;
|
||||
std::string nid = nid_prefix + std::to_string(id);
|
||||
IVecIdMapper::GetInstance()->Get(nid, sid);
|
||||
_return.id_list.push_back(sid);
|
||||
_return.distance_list.push_back(0.0);//TODO: return distance
|
||||
}
|
||||
}
|
||||
}
|
||||
VecTensorList tensor_list;
|
||||
tensor_list.tensor_list.push_back(tensor);
|
||||
VecSearchResultList result;
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, result);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
SERVER_LOG_INFO << "search_vector() finished";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
if(!result.result_list.empty()) {
|
||||
_return = result.result_list[0];
|
||||
} else {
|
||||
SERVER_LOG_ERROR << "No search result returned";
|
||||
}
|
||||
|
||||
SERVER_LOG_INFO << "search_vector() finished";
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -194,36 +135,11 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
|
|||
<< ", vector list size = " << tensor_list.tensor_list.size()
|
||||
<< ", time range list size = " << time_range_list.range_list.size();
|
||||
|
||||
try {
|
||||
std::vector<float> vec_f;
|
||||
for(const VecTensor& tensor : tensor_list.tensor_list) {
|
||||
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
|
||||
}
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, _return);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
engine::QueryResults results;
|
||||
engine::Status stat = db_->search(group_id, (size_t)top_k, tensor_list.tensor_list.size(), vec_f.data(), results);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
} else {
|
||||
for(engine::QueryResult& res : results){
|
||||
VecSearchResult v_res;
|
||||
std::string nid_prefix = group_id + "_";
|
||||
for(auto id : results[0]) {
|
||||
std::string sid;
|
||||
std::string nid = nid_prefix + std::to_string(id);
|
||||
IVecIdMapper::GetInstance()->Get(nid, sid);
|
||||
v_res.id_list.push_back(sid);
|
||||
v_res.distance_list.push_back(0.0);//TODO: return distance
|
||||
}
|
||||
|
||||
_return.result_list.push_back(v_res);
|
||||
}
|
||||
}
|
||||
|
||||
SERVER_LOG_INFO << "search_vector_batch() finished";
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << ex.what();
|
||||
}
|
||||
SERVER_LOG_INFO << "search_vector_batch() finished";
|
||||
}
|
||||
|
||||
VecServiceHandler::~VecServiceHandler() {
|
||||
|
|
|
@ -4,17 +4,138 @@
|
|||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceScheduler.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
VecServiceScheduler::VecServiceScheduler() {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Creates a task assigned to the given task group. A new task is not
// finished yet (done_ == false) and its error code starts as SERVER_SUCCESS.
BaseTask::BaseTask(const std::string& task_group)
    : task_group_(task_group), done_(false), error_code_(SERVER_SUCCESS) {
}
|
||||
|
||||
// Destruction blocks in WaitToFinish() until the task has been marked done.
// NOTE(review): a task that is created but never executed would block here
// indefinitely - confirm every created task is always queued and run.
BaseTask::~BaseTask() {
    (void)WaitToFinish();
}
|
||||
|
||||
// Runs the task body (OnExecute) and publishes the outcome to any thread
// blocked in WaitToFinish().
//
// The completion state (error_code_ and done_) is written while holding
// finish_mtx_, the same mutex WaitToFinish() uses for its predicate, so the
// waiter can neither miss the wake-up nor race on those fields.
//
// If OnExecute() throws, the task is still marked finished (with
// SERVER_UNEXPECTED_ERROR) before the exception is re-thrown to the caller;
// otherwise a synchronous caller would wait forever.
//
// Returns the error code produced by OnExecute().
ServerError BaseTask::Execute() {
    // Publishes the final state under the lock, then wakes all waiters.
    auto finish = [this](ServerError code) {
        {
            std::lock_guard<std::mutex> lock(finish_mtx_);
            error_code_ = code;
            done_ = true;
        }
        finish_cond_.notify_all();
    };

    ServerError result = SERVER_UNEXPECTED_ERROR;
    try {
        result = OnExecute();
    } catch (...) {
        finish(SERVER_UNEXPECTED_ERROR);
        throw; // let the worker loop log the original exception
    }

    finish(result);
    return result;
}
|
||||
|
||||
// Blocks the calling thread until done_ becomes true, then returns the
// error code recorded for the task.
ServerError BaseTask::WaitToFinish() {
    std::unique_lock<std::mutex> guard(finish_mtx_);
    // Explicit loop form of a predicate wait: re-check after every wake-up.
    while(!done_) {
        finish_cond_.wait(guard);
    }

    return error_code_;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Constructs the scheduler in the running state (stopped_ == false) and
// then calls Start().
VecServiceScheduler::VecServiceScheduler() : stopped_(false) {
    Start();
}
|
||||
|
||||
// Shuts the scheduler down on destruction by delegating to Stop().
VecServiceScheduler::~VecServiceScheduler() {
    Stop();
}
|
||||
|
||||
// Clears the stopped flag. When the scheduler is not stopped this is a
// no-op. No threads are created here; worker threads are created in
// PutTaskToQueue() the first time a task group is seen.
void VecServiceScheduler::Start() {
    if(stopped_) {
        stopped_ = false;
    }
}
|
||||
|
||||
void VecServiceScheduler::Stop() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(queue_mtx_);
|
||||
for(auto iter : task_groups_) {
|
||||
if(iter.second != nullptr) {
|
||||
iter.second->Put(nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(auto iter : execute_threads_) {
|
||||
if(iter == nullptr)
|
||||
continue;
|
||||
|
||||
iter->join();
|
||||
}
|
||||
stopped_ = true;
|
||||
}
|
||||
|
||||
// Asynchronous submission: queues the task and returns without waiting for
// it to run. A null task is rejected with SERVER_NULL_POINTER.
ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) {
    if(task_ptr != nullptr) {
        return PutTaskToQueue(task_ptr);
    }

    return SERVER_NULL_POINTER;
}
|
||||
|
||||
// Synchronous submission: queues the task, then blocks until it has
// finished and returns the task's error code. A null task is rejected with
// SERVER_NULL_POINTER; a queueing failure is returned without waiting.
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
    if(task_ptr == nullptr) {
        return SERVER_NULL_POINTER;
    }

    const ServerError queue_status = PutTaskToQueue(task_ptr);
    if(queue_status == SERVER_SUCCESS) {
        return task_ptr->WaitToFinish();
    }

    return queue_status;
}
|
||||
|
||||
namespace {
|
||||
void TakeTaskToExecute(TaskQueuePtr task_queue) {
|
||||
if(task_queue == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
while(true) {
|
||||
BaseTaskPtr task = task_queue->Take();
|
||||
if (task == nullptr) {
|
||||
break;//stop the thread
|
||||
}
|
||||
|
||||
try {
|
||||
ServerError err = task->Execute();
|
||||
if(err != SERVER_SUCCESS) {
|
||||
SERVER_LOG_ERROR << "Task failed with code: " << err;
|
||||
}
|
||||
} catch (std::exception& ex) {
|
||||
SERVER_LOG_ERROR << "Task failed to execute: " << ex.what();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Routes a task to the queue of its task group, under queue_mtx_.
// The first task of a previously unseen group creates that group's queue
// and starts one worker thread (TakeTaskToExecute) to serve it.
// Always returns SERVER_SUCCESS.
ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
    std::lock_guard<std::mutex> guard(queue_mtx_);

    const std::string group = task_ptr->TaskGroup();
    auto existing = task_groups_.find(group);
    if(existing != task_groups_.end()) {
        // Known group: hand the task to its existing queue.
        existing->second->Put(task_ptr);
        return SERVER_SUCCESS;
    }

    // New group: create the queue, enqueue the task, register the group.
    TaskQueuePtr new_queue = std::make_shared<TaskQueue>();
    new_queue->Put(task_ptr);
    task_groups_.insert(std::make_pair(group, new_queue));

    //start a thread
    ThreadPtr worker = std::make_shared<std::thread>(&TakeTaskToExecute, new_queue);
    execute_threads_.push_back(worker);
    SERVER_LOG_INFO << "Create new thread for task group: " << group;

    return SERVER_SUCCESS;
}
|
||||
|
||||
}
|
||||
|
|
|
@ -5,15 +5,74 @@
|
|||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "utils/BlockingQueue.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
class BaseTask {
|
||||
protected:
|
||||
BaseTask(const std::string& task_group);
|
||||
virtual ~BaseTask();
|
||||
|
||||
public:
|
||||
ServerError Execute();
|
||||
ServerError WaitToFinish();
|
||||
|
||||
std::string TaskGroup() const { return task_group_; }
|
||||
|
||||
ServerError ErrorCode() const { return error_code_; }
|
||||
protected:
|
||||
virtual ServerError OnExecute() = 0;
|
||||
|
||||
protected:
|
||||
mutable std::mutex finish_mtx_;
|
||||
std::condition_variable finish_cond_;
|
||||
|
||||
std::string task_group_;
|
||||
bool done_;
|
||||
ServerError error_code_;
|
||||
};
|
||||
|
||||
using BaseTaskPtr = std::shared_ptr<BaseTask>;
|
||||
using TaskQueue = BlockingQueue<BaseTaskPtr>;
|
||||
using TaskQueuePtr = std::shared_ptr<TaskQueue>;
|
||||
using ThreadPtr = std::shared_ptr<std::thread>;
|
||||
|
||||
class VecServiceScheduler {
|
||||
public:
|
||||
static VecServiceScheduler& GetInstance() {
|
||||
static VecServiceScheduler scheduler;
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
void Start();
|
||||
void Stop();
|
||||
|
||||
//async
|
||||
ServerError PushTask(const BaseTaskPtr& task_ptr);
|
||||
//sync
|
||||
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
|
||||
|
||||
protected:
|
||||
VecServiceScheduler();
|
||||
virtual ~VecServiceScheduler();
|
||||
|
||||
ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr);
|
||||
|
||||
private:
|
||||
mutable std::mutex queue_mtx_;
|
||||
|
||||
std::map<std::string, TaskQueuePtr> task_groups_;
|
||||
|
||||
std::vector<ThreadPtr> execute_threads_;
|
||||
|
||||
bool stopped_;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,290 @@
|
|||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceTask.h"
|
||||
#include "ServerConfig.h"
|
||||
#include "VecIdMapper.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "db/DB.h"
|
||||
#include "db/Env.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
// Task group names. Every task in this file uses the "normal" group except
// SearchVectorTask, which uses "search".
static const std::string NORMAL_TASK_GROUP("normal");
static const std::string SEARCH_TASK_GROUP("search");
|
||||
|
||||
namespace {
|
||||
class DBWrapper {
|
||||
public:
|
||||
DBWrapper() {
|
||||
zilliz::vecwise::engine::Options opt;
|
||||
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_SERVER);
|
||||
opt.meta.backend_uri = config.GetValue(CONFIG_SERVER_DB_URL);
|
||||
std::string db_path = config.GetValue(CONFIG_SERVER_DB_PATH);
|
||||
opt.meta.path = db_path + "/db";
|
||||
|
||||
CommonUtil::CreateDirectory(opt.meta.path);
|
||||
|
||||
zilliz::vecwise::engine::DB::Open(opt, &db_);
|
||||
if(db_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Failed to open db";
|
||||
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
|
||||
}
|
||||
}
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() { return db_; }
|
||||
|
||||
private:
|
||||
zilliz::vecwise::engine::DB* db_ = nullptr;
|
||||
};
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() {
|
||||
static DBWrapper db_wrapper;
|
||||
return db_wrapper.DB();
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Stores the dimension and id of the group to create; the task runs in the
// normal task group.
AddGroupTask::AddGroupTask(int32_t dimension, const std::string& group_id)
    : BaseTask(NORMAL_TASK_GROUP), dimension_(dimension), group_id_(group_id) {
}
|
||||
|
||||
// Factory: the constructor is protected, so tasks are created through here
// and handed out as a shared BaseTask pointer.
BaseTaskPtr AddGroupTask::Create(int32_t dimension, const std::string& group_id) {
    BaseTaskPtr task(new AddGroupTask(dimension, group_id));
    return task;
}
|
||||
|
||||
// Creates the group in the engine database.
//
// Returns SERVER_SUCCESS when the engine accepts the group, or
// SERVER_UNEXPECTED_ERROR when the engine reports a failure or an exception
// is caught. Every path returns a value; flowing off the end of a non-void
// function is undefined behaviour.
ServerError AddGroupTask::OnExecute() {
    try {
        engine::meta::GroupSchema group_info;
        group_info.dimension = static_cast<size_t>(dimension_);
        group_info.group_id = group_id_;

        engine::Status stat = DB()->add_group(group_info);
        if(!stat.ok()) {
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_UNEXPECTED_ERROR;
        }
    } catch (const std::exception& ex) {
        SERVER_LOG_ERROR << ex.what();
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Stores the group id to look up and binds the caller's dimension variable
// by reference; OnExecute() writes the result into it.
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
    : BaseTask(NORMAL_TASK_GROUP), group_id_(group_id), dimension_(dimension) {
}
|
||||
|
||||
// Factory for GetGroupTask; dimension is the output variable the task fills.
BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t& dimension) {
    BaseTaskPtr task(new GetGroupTask(group_id, dimension));
    return task;
}
|
||||
|
||||
// Looks the group up in the engine database and writes its dimension into
// the bound output variable.
//
// The output is reset to 0 first, so it stays 0 when the lookup fails.
// Returns SERVER_SUCCESS on success, or SERVER_UNEXPECTED_ERROR when the
// engine reports a failure or an exception is caught. Every path returns a
// value; the function previously flowed off the end, which is undefined
// behaviour for a non-void function.
ServerError GetGroupTask::OnExecute() {
    try {
        dimension_ = 0;

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;

        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_UNEXPECTED_ERROR;
        }

        dimension_ = static_cast<int32_t>(group_info.dimension);
    } catch (const std::exception& ex) {
        SERVER_LOG_ERROR << ex.what();
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Stores the id of the group to delete; the task runs in the normal group.
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
    : BaseTask(NORMAL_TASK_GROUP), group_id_(group_id) {
}
|
||||
|
||||
// Factory for DeleteGroupTask.
BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) {
    BaseTaskPtr task(new DeleteGroupTask(group_id));
    return task;
}
|
||||
|
||||
// Group deletion is not implemented: no engine call is made and the task
// completes immediately.
//
// Returns SERVER_SUCCESS. The previous body was an empty try block with no
// return statement, which is undefined behaviour for a non-void function.
ServerError DeleteGroupTask::OnExecute() {
    // TODO: call the engine to delete group_id_ once deletion is supported.
    return SERVER_SUCCESS;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Stores the target group id and binds the tensor by reference (no copy).
// NOTE(review): the tensor must outlive the task - confirm for async use.
AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id, const VecTensor &tensor)
    : BaseTask(NORMAL_TASK_GROUP), group_id_(group_id), tensor_(tensor) {
}
|
||||
|
||||
// Factory for AddSingleVectorTask.
BaseTaskPtr AddSingleVectorTask::Create(const std::string& group_id, const VecTensor &tensor) {
    BaseTaskPtr task(new AddSingleVectorTask(group_id, tensor));
    return task;
}
|
||||
|
||||
// Adds one vector to the group and records the mapping from the engine's
// numeric id ("<group_id>_<number>") to the caller-supplied uid.
//
// Returns SERVER_SUCCESS when the vector was added and mapped, or
// SERVER_UNEXPECTED_ERROR when the engine fails, returns no id, or an
// exception is caught. Every path returns a value; the function previously
// flowed off the end, which is undefined behaviour for a non-void function.
//
// NOTE(review): unlike AddBatchVectorTask, the tensor length is not checked
// against the group dimension before it is handed to the engine - confirm
// whether the engine validates it.
ServerError AddSingleVectorTask::OnExecute() {
    try {
        engine::IDNumbers vector_ids;
        // The engine takes float data; convert from the tensor's element type.
        std::vector<float> vec_f(tensor_.tensor.begin(), tensor_.tensor.end());

        engine::Status stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids);
        if(!stat.ok()) {
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_UNEXPECTED_ERROR;
        }

        if(vector_ids.empty()) {
            SERVER_LOG_ERROR << "Vector ID not returned";
            return SERVER_UNEXPECTED_ERROR;
        }

        std::string nid = group_id_ + "_" + std::to_string(vector_ids[0]);
        IVecIdMapper::GetInstance()->Put(nid, tensor_.uid);
        SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", sid = " << tensor_.uid;
    } catch (const std::exception& ex) {
        SERVER_LOG_ERROR << ex.what();
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Stores the target group id and binds the tensor list by reference.
// NOTE(review): the list must outlive the task - confirm for async use.
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id, const VecTensorList &tensor_list)
    : BaseTask(NORMAL_TASK_GROUP), group_id_(group_id), tensor_list_(tensor_list) {
}
|
||||
|
||||
// Factory for AddBatchVectorTask.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id, const VecTensorList &tensor_list) {
    BaseTaskPtr task(new AddBatchVectorTask(group_id, tensor_list));
    return task;
}
|
||||
|
||||
// Adds a batch of vectors to the group and records, for each one, the
// mapping from the engine's numeric id ("<group_id>_<number>") to the
// caller-supplied uid.
//
// Steps: look up the group to learn its dimension, validate every tensor
// against it, flatten the tensors into one contiguous float buffer in list
// order, hand the buffer to the engine, then build the id mapping.
//
// Tensors are appended to the buffer, so the i-th returned id belongs to
// the i-th tensor. Inserting each tensor at the front reversed the batch
// and paired ids with the wrong uids.
//
// Returns SERVER_SUCCESS on success, or SERVER_UNEXPECTED_ERROR on an
// engine failure, a tensor of the wrong size, too few returned ids, or a
// caught exception.
ServerError AddBatchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("Add vector batch");

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_UNEXPECTED_ERROR;
        }

        const size_t vector_count = tensor_list_.tensor_list.size();

        std::vector<float> vec_f;
        // reserve() counts elements, not bytes.
        vec_f.reserve(vector_count * group_info.dimension);
        for(const VecTensor& tensor : tensor_list_.tensor_list) {
            if(tensor.tensor.size() != group_info.dimension) {
                SERVER_LOG_ERROR << "Invalid vector data size: " << tensor.tensor.size()
                                 << " vs. group dimension:" << group_info.dimension;
                return SERVER_UNEXPECTED_ERROR;
            }
            // Append, so buffer order matches tensor_list order.
            vec_f.insert(vec_f.end(), tensor.tensor.begin(), tensor.tensor.end());
        }
        rc.Record("prepare vectors data");

        engine::IDNumbers vector_ids;
        stat = DB()->add_vectors(group_id_, vector_count, vec_f.data(), vector_ids);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_UNEXPECTED_ERROR;
        }

        if(vector_ids.size() < vector_count) {
            SERVER_LOG_ERROR << "Vector ID not returned";
            return SERVER_UNEXPECTED_ERROR;
        }

        const std::string nid_prefix = group_id_ + "_";
        for(size_t i = 0; i < vector_count; i++) {
            std::string nid = nid_prefix + std::to_string(vector_ids[i]);
            IVecIdMapper::GetInstance()->Put(nid, tensor_list_.tensor_list[i].uid);
        }
        rc.Record("build id mapping");
    } catch (const std::exception& ex) {
        SERVER_LOG_ERROR << ex.what();
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Stores the search parameters. The tensor list, time range list and result
// list are bound by reference; the task runs in the search task group.
// NOTE(review): the referenced arguments must outlive the task - confirm
// for async use.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecTensorList& tensor_list,
                                   const VecTimeRangeList& time_range_list,
                                   VecSearchResultList& result)
    : BaseTask(SEARCH_TASK_GROUP), group_id_(group_id), top_k_(top_k),
      tensor_list_(tensor_list), time_range_list_(time_range_list), result_(result) {
}
|
||||
|
||||
// Factory for SearchVectorTask; result is the output list the task fills.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecTensorList& tensor_list,
                                     const VecTimeRangeList& time_range_list,
                                     VecSearchResultList& result) {
    BaseTaskPtr task(new SearchVectorTask(group_id, top_k, tensor_list, time_range_list, result));
    return task;
}
|
||||
|
||||
// Runs a top-k search for every query tensor and appends one
// VecSearchResult per query to the bound result list.
//
// Query tensors are flattened into one contiguous float buffer in list
// order; each engine id is translated back to the caller's string id
// through the id mapper ("<group_id>_<number>" -> uid).
//
// Two defects are fixed here:
//  * tensors were inserted at the front of the buffer, reversing the query
//    order relative to the returned results;
//  * the inner loop read results[0] for every query, so each query received
//    the hits of the first one. It now reads the current query's result.
//
// Returns SERVER_SUCCESS on success, or SERVER_UNEXPECTED_ERROR on an
// engine failure or a caught exception.
ServerError SearchVectorTask::OnExecute() {
    try {
        std::vector<float> vec_f;
        for(const VecTensor& tensor : tensor_list_.tensor_list) {
            // Append, so buffer order matches query order.
            vec_f.insert(vec_f.end(), tensor.tensor.begin(), tensor.tensor.end());
        }

        engine::QueryResults results;
        engine::Status stat = DB()->search(group_id_, static_cast<size_t>(top_k_),
                                           tensor_list_.tensor_list.size(), vec_f.data(), results);
        if(!stat.ok()) {
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_UNEXPECTED_ERROR;
        }

        const std::string nid_prefix = group_id_ + "_";
        for(const engine::QueryResult& res : results) {
            VecSearchResult v_res;
            for(auto id : res) {
                std::string sid;
                std::string nid = nid_prefix + std::to_string(id);
                IVecIdMapper::GetInstance()->Get(nid, sid);
                v_res.id_list.push_back(sid);
                v_res.distance_list.push_back(0.0);//TODO: return distance

                SERVER_LOG_TRACE << "nid = " << nid << ", string id = " << sid;
            }

            result_.result_list.push_back(v_res);
        }
    } catch (const std::exception& ex) {
        SERVER_LOG_ERROR << ex.what();
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "VecServiceScheduler.h"
|
||||
#include "utils/Error.h"
|
||||
#include "db/Types.h"
|
||||
|
||||
#include "thrift/gen-cpp/VectorService_types.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddGroupTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(int32_t dimension,
|
||||
const std::string& group_id);
|
||||
|
||||
protected:
|
||||
AddGroupTask(int32_t dimension,
|
||||
const std::string& group_id);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
int32_t dimension_;
|
||||
std::string group_id_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class GetGroupTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id, int32_t& dimension);
|
||||
|
||||
protected:
|
||||
GetGroupTask(const std::string& group_id, int32_t& dimension);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
int32_t& dimension_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class DeleteGroupTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id);
|
||||
|
||||
protected:
|
||||
DeleteGroupTask(const std::string& group_id);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddSingleVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const VecTensor &tensor);
|
||||
|
||||
protected:
|
||||
AddSingleVectorTask(const std::string& group_id,
|
||||
const VecTensor &tensor);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
const VecTensor& tensor_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddBatchVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const VecTensorList &tensor_list);
|
||||
|
||||
protected:
|
||||
AddBatchVectorTask(const std::string& group_id,
|
||||
const VecTensorList &tensor_list);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
const VecTensorList& tensor_list_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class SearchVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList& tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result);
|
||||
|
||||
protected:
|
||||
SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList& tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
int64_t top_k_;
|
||||
const VecTensorList& tensor_list_;
|
||||
const VecTimeRangeList& time_range_list_;
|
||||
VecSearchResultList& result_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@
|
|||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include <utils/TimeRecorder.h>
|
||||
#include "ClientApp.h"
|
||||
#include "ClientSession.h"
|
||||
#include "server/ServerConfig.h"
|
||||
|
@ -37,21 +38,44 @@ void ClientApp::Run(const std::string &config_file) {
|
|||
group.index_type = 0;
|
||||
session.interface()->add_group(group);
|
||||
|
||||
//add vectors
|
||||
for(int64_t k = 0; k < 10000; k++) {
|
||||
VecTensor tensor;
|
||||
for(int32_t i = 0; i < dim; i++) {
|
||||
tensor.tensor.push_back((double)(i + k));
|
||||
const int64_t count = 500;
|
||||
//add vectors one by one
|
||||
{
|
||||
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
VecTensor tensor;
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
tensor.tensor.push_back((double) (i + k));
|
||||
}
|
||||
tensor.uid = "vec_" + std::to_string(k);
|
||||
|
||||
session.interface()->add_vector(group.id, tensor);
|
||||
|
||||
CLIENT_LOG_INFO << "add vector no." << k;
|
||||
}
|
||||
tensor.uid = "vec_" + std::to_string(k);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
session.interface()->add_vector(group.id, tensor);
|
||||
|
||||
CLIENT_LOG_INFO << "add vector no." << k;
|
||||
//add vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
|
||||
VecTensorList vec_list;
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
VecTensor tensor;
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
tensor.tensor.push_back((double) (i + k));
|
||||
}
|
||||
tensor.uid = "vec_" + std::to_string(k);
|
||||
vec_list.tensor_list.push_back(tensor);
|
||||
}
|
||||
session.interface()->add_vector_batch(group.id, vec_list);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//search vector
|
||||
{
|
||||
server::TimeRecorder rc("Search top_k");
|
||||
VecTensor tensor;
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
tensor.tensor.push_back((double) (i + 100));
|
||||
|
@ -65,6 +89,7 @@ void ClientApp::Run(const std::string &config_file) {
|
|||
for(auto id : res.id_list) {
|
||||
std::cout << id << std::endl;
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
|
|
Loading…
Reference in New Issue