refine code

Former-commit-id: 964034f46d6ca2b81817a2ab089f16c8664bfd02
pull/191/head
starlord 2019-08-26 12:31:13 +08:00
parent 1b795cd463
commit cd36c006f2
3 changed files with 152 additions and 127 deletions

View File

@ -18,7 +18,7 @@ GrpcRequestHandler::CreateTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableSchema *request, const ::milvus::grpc::TableSchema *request,
::milvus::grpc::Status *response) { ::milvus::grpc::Status *response) {
BaseTaskPtr task_ptr = CreateTableTask::Create(*request); BaseTaskPtr task_ptr = CreateTableTask::Create(request);
GrpcRequestScheduler::ExecTask(task_ptr, response); GrpcRequestScheduler::ExecTask(task_ptr, response);
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
@ -52,7 +52,7 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext *context,
const ::milvus::grpc::IndexParam *request, const ::milvus::grpc::IndexParam *request,
::milvus::grpc::Status *response) { ::milvus::grpc::Status *response) {
BaseTaskPtr task_ptr = CreateIndexTask::Create(*request); BaseTaskPtr task_ptr = CreateIndexTask::Create(request);
GrpcRequestScheduler::ExecTask(task_ptr, response); GrpcRequestScheduler::ExecTask(task_ptr, response);
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
@ -62,7 +62,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext *context,
const ::milvus::grpc::InsertParam *request, const ::milvus::grpc::InsertParam *request,
::milvus::grpc::VectorIds *response) { ::milvus::grpc::VectorIds *response) {
BaseTaskPtr task_ptr = InsertTask::Create(*request, *response); BaseTaskPtr task_ptr = InsertTask::Create(request, response);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->mutable_status()->set_reason(grpc_status.reason()); response->mutable_status()->set_reason(grpc_status.reason());
@ -76,7 +76,7 @@ GrpcRequestHandler::Search(::grpc::ServerContext *context,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) { ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) {
std::vector<std::string> file_id_array; std::vector<std::string> file_id_array;
BaseTaskPtr task_ptr = SearchTask::Create(*request, file_id_array, *writer); BaseTaskPtr task_ptr = SearchTask::Create(request, file_id_array, writer);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
if (grpc_status.error_code() != SERVER_SUCCESS) { if (grpc_status.error_code() != SERVER_SUCCESS) {
@ -93,7 +93,11 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext *context,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) { ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) {
std::vector<std::string> file_id_array; std::vector<std::string> file_id_array;
BaseTaskPtr task_ptr = SearchTask::Create(request->search_param(), file_id_array, *writer); for(int i = 0; i < request->file_id_array_size(); i++) {
file_id_array.push_back(request->file_id_array(i));
}
::milvus::grpc::SearchInFilesParam *request_mutable = const_cast<::milvus::grpc::SearchInFilesParam *>(request);
BaseTaskPtr task_ptr = SearchTask::Create(request_mutable->mutable_search_param(), file_id_array, writer);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
if (grpc_status.error_code() != SERVER_SUCCESS) { if (grpc_status.error_code() != SERVER_SUCCESS) {
@ -109,7 +113,7 @@ GrpcRequestHandler::DescribeTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName *request, const ::milvus::grpc::TableName *request,
::milvus::grpc::TableSchema *response) { ::milvus::grpc::TableSchema *response) {
BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), *response); BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), response);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->mutable_table_name()->mutable_status()->set_error_code(grpc_status.error_code()); response->mutable_table_name()->mutable_status()->set_error_code(grpc_status.error_code());
@ -137,7 +141,7 @@ GrpcRequestHandler::ShowTables(::grpc::ServerContext *context,
const ::milvus::grpc::Command *request, const ::milvus::grpc::Command *request,
::grpc::ServerWriter<::milvus::grpc::TableName> *writer) { ::grpc::ServerWriter<::milvus::grpc::TableName> *writer) {
BaseTaskPtr task_ptr = ShowTablesTask::Create(*writer); BaseTaskPtr task_ptr = ShowTablesTask::Create(writer);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
if (grpc_status.error_code() != SERVER_SUCCESS) { if (grpc_status.error_code() != SERVER_SUCCESS) {
@ -167,7 +171,7 @@ GrpcRequestHandler::Cmd(::grpc::ServerContext *context,
GrpcRequestHandler::DeleteByRange(::grpc::ServerContext *context, GrpcRequestHandler::DeleteByRange(::grpc::ServerContext *context,
const ::milvus::grpc::DeleteByRangeParam *request, const ::milvus::grpc::DeleteByRangeParam *request,
::milvus::grpc::Status *response) { ::milvus::grpc::Status *response) {
BaseTaskPtr task_ptr = DeleteByRangeTask::Create(*request); BaseTaskPtr task_ptr = DeleteByRangeTask::Create(request);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->set_error_code(grpc_status.error_code()); response->set_error_code(grpc_status.error_code());
@ -191,7 +195,7 @@ GrpcRequestHandler::PreloadTable(::grpc::ServerContext *context,
GrpcRequestHandler::DescribeIndex(::grpc::ServerContext *context, GrpcRequestHandler::DescribeIndex(::grpc::ServerContext *context,
const ::milvus::grpc::TableName *request, const ::milvus::grpc::TableName *request,
::milvus::grpc::IndexParam *response) { ::milvus::grpc::IndexParam *response) {
BaseTaskPtr task_ptr = DescribeIndexTask::Create(request->table_name(), *response); BaseTaskPtr task_ptr = DescribeIndexTask::Create(request->table_name(), response);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->mutable_table_name()->mutable_status()->set_reason(grpc_status.reason()); response->mutable_table_name()->mutable_status()->set_reason(grpc_status.reason());

View File

@ -107,14 +107,18 @@ namespace {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema &schema) CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema *schema)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
schema_(schema) { schema_(schema) {
} }
BaseTaskPtr BaseTaskPtr
CreateTableTask::Create(const ::milvus::grpc::TableSchema &schema) { CreateTableTask::Create(const ::milvus::grpc::TableSchema *schema) {
if(schema == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!";
return nullptr;
}
return std::shared_ptr<GrpcBaseTask>(new CreateTableTask(schema)); return std::shared_ptr<GrpcBaseTask>(new CreateTableTask(schema));
} }
@ -124,26 +128,26 @@ CreateTableTask::OnExecute() {
try { try {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(schema_.table_name().table_name()); ServerError res = ValidationUtil::ValidateTableName(schema_->table_name().table_name());
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + schema_.table_name().table_name()); return SetError(res, "Invalid table name: " + schema_->table_name().table_name());
} }
res = ValidationUtil::ValidateTableDimension(schema_.dimension()); res = ValidationUtil::ValidateTableDimension(schema_->dimension());
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table dimension: " + std::to_string(schema_.dimension())); return SetError(res, "Invalid table dimension: " + std::to_string(schema_->dimension()));
} }
res = ValidationUtil::ValidateTableIndexFileSize(schema_.index_file_size()); res = ValidationUtil::ValidateTableIndexFileSize(schema_->index_file_size());
if(res != SERVER_SUCCESS) { if(res != SERVER_SUCCESS) {
return SetError(res, "Invalid index file size: " + std::to_string(schema_.index_file_size())); return SetError(res, "Invalid index file size: " + std::to_string(schema_->index_file_size()));
} }
//step 2: construct table schema //step 2: construct table schema
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = schema_.table_name().table_name(); table_info.table_id_ = schema_->table_name().table_name();
table_info.dimension_ = (uint16_t) schema_.dimension(); table_info.dimension_ = (uint16_t) schema_->dimension();
table_info.index_file_size_ = schema_.index_file_size(); table_info.index_file_size_ = schema_->index_file_size();
//step 3: create table //step 3: create table
engine::Status stat = DBWrapper::DB()->CreateTable(table_info); engine::Status stat = DBWrapper::DB()->CreateTable(table_info);
@ -162,14 +166,14 @@ CreateTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema &schema) DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema *schema)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name), table_name_(table_name),
schema_(schema) { schema_(schema) {
} }
BaseTaskPtr BaseTaskPtr
DescribeTableTask::Create(const std::string &table_name, ::milvus::grpc::TableSchema &schema) { DescribeTableTask::Create(const std::string &table_name, ::milvus::grpc::TableSchema *schema) {
return std::shared_ptr<GrpcBaseTask>(new DescribeTableTask(table_name, schema)); return std::shared_ptr<GrpcBaseTask>(new DescribeTableTask(table_name, schema));
} }
@ -192,8 +196,8 @@ DescribeTableTask::OnExecute() {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
schema_.mutable_table_name()->set_table_name(table_info.table_id_); schema_->mutable_table_name()->set_table_name(table_info.table_id_);
schema_.set_dimension(table_info.dimension_); schema_->set_dimension(table_info.dimension_);
} catch (std::exception &ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
@ -205,13 +209,17 @@ DescribeTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
CreateIndexTask::CreateIndexTask(const ::milvus::grpc::IndexParam &index_param) CreateIndexTask::CreateIndexTask(const ::milvus::grpc::IndexParam *index_param)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
index_param_(index_param) { index_param_(index_param) {
} }
BaseTaskPtr BaseTaskPtr
CreateIndexTask::Create(const ::milvus::grpc::IndexParam &index_param) { CreateIndexTask::Create(const ::milvus::grpc::IndexParam *index_param) {
if(index_param == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!";
return nullptr;
}
return std::shared_ptr<GrpcBaseTask>(new CreateIndexTask(index_param)); return std::shared_ptr<GrpcBaseTask>(new CreateIndexTask(index_param));
} }
@ -221,7 +229,7 @@ CreateIndexTask::OnExecute() {
TimeRecorder rc("CreateIndexTask"); TimeRecorder rc("CreateIndexTask");
//step 1: check arguments //step 1: check arguments
std::string table_name_ = index_param_.table_name().table_name(); std::string table_name_ = index_param_->table_name().table_name();
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
@ -237,26 +245,27 @@ CreateIndexTask::OnExecute() {
return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
} }
res = ValidationUtil::ValidateTableIndexType(index_param_.mutable_index()->index_type()); auto &grpc_index = index_param_->index();
res = ValidationUtil::ValidateTableIndexType(grpc_index.index_type());
if(res != SERVER_SUCCESS) { if(res != SERVER_SUCCESS) {
return SetError(res, "Invalid index type: " + std::to_string(index_param_.mutable_index()->index_type())); return SetError(res, "Invalid index type: " + std::to_string(grpc_index.index_type()));
} }
res = ValidationUtil::ValidateTableIndexNlist(index_param_.mutable_index()->nlist()); res = ValidationUtil::ValidateTableIndexNlist(grpc_index.nlist());
if(res != SERVER_SUCCESS) { if(res != SERVER_SUCCESS) {
return SetError(res, "Invalid index nlist: " + std::to_string(index_param_.mutable_index()->nlist())); return SetError(res, "Invalid index nlist: " + std::to_string(grpc_index.nlist()));
} }
res = ValidationUtil::ValidateTableIndexMetricType(index_param_.mutable_index()->metric_type()); res = ValidationUtil::ValidateTableIndexMetricType(grpc_index.metric_type());
if(res != SERVER_SUCCESS) { if(res != SERVER_SUCCESS) {
return SetError(res, "Invalid index metric type: " + std::to_string(index_param_.mutable_index()->metric_type())); return SetError(res, "Invalid index metric type: " + std::to_string(grpc_index.metric_type()));
} }
//step 2: check table existence //step 2: check table existence
engine::TableIndex index; engine::TableIndex index;
index.engine_type_ = index_param_.mutable_index()->index_type(); index.engine_type_ = grpc_index.index_type();
index.nlist_ = index_param_.mutable_index()->nlist(); index.nlist_ = grpc_index.nlist();
index.metric_type_ = index_param_.mutable_index()->metric_type(); index.metric_type_ = grpc_index.metric_type();
stat = DBWrapper::DB()->CreateIndex(table_name_, index); stat = DBWrapper::DB()->CreateIndex(table_name_, index);
if (!stat.ok()) { if (!stat.ok()) {
return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString()); return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString());
@ -361,14 +370,14 @@ DropTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
ShowTablesTask::ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> &writer) ShowTablesTask::ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> *writer)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
writer_(writer) { writer_(writer) {
} }
BaseTaskPtr BaseTaskPtr
ShowTablesTask::Create(::grpc::ServerWriter<::milvus::grpc::TableName> &writer) { ShowTablesTask::Create(::grpc::ServerWriter<::milvus::grpc::TableName> *writer) {
return std::shared_ptr<GrpcBaseTask>(new ShowTablesTask(writer)); return std::shared_ptr<GrpcBaseTask>(new ShowTablesTask(writer));
} }
@ -383,7 +392,7 @@ ShowTablesTask::OnExecute() {
for (auto &schema : schema_array) { for (auto &schema : schema_array) {
::milvus::grpc::TableName tableName; ::milvus::grpc::TableName tableName;
tableName.set_table_name(schema.table_id_); tableName.set_table_name(schema.table_id_);
if (!writer_.Write(tableName)) { if (!writer_->Write(tableName)) {
return SetError(SERVER_WRITE_ERROR, "Write table name failed!"); return SetError(SERVER_WRITE_ERROR, "Write table name failed!");
} }
} }
@ -391,17 +400,21 @@ ShowTablesTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
InsertTask::InsertTask(const ::milvus::grpc::InsertParam &insert_param, InsertTask::InsertTask(const ::milvus::grpc::InsertParam *insert_param,
::milvus::grpc::VectorIds &record_ids) ::milvus::grpc::VectorIds *record_ids)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
insert_param_(insert_param), insert_param_(insert_param),
record_ids_(record_ids) { record_ids_(record_ids) {
record_ids_.Clear(); record_ids_->Clear();
} }
BaseTaskPtr BaseTaskPtr
InsertTask::Create(const ::milvus::grpc::InsertParam &insert_param, InsertTask::Create(const ::milvus::grpc::InsertParam *insert_param,
::milvus::grpc::VectorIds &record_ids) { ::milvus::grpc::VectorIds *record_ids) {
if(insert_param == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!";
return nullptr;
}
return std::shared_ptr<GrpcBaseTask>(new InsertTask(insert_param, record_ids)); return std::shared_ptr<GrpcBaseTask>(new InsertTask(insert_param, record_ids));
} }
@ -411,16 +424,16 @@ InsertTask::OnExecute() {
TimeRecorder rc("InsertVectorTask"); TimeRecorder rc("InsertVectorTask");
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(insert_param_.table_name()); ServerError res = ValidationUtil::ValidateTableName(insert_param_->table_name());
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + insert_param_.table_name()); return SetError(res, "Invalid table name: " + insert_param_->table_name());
} }
if (insert_param_.row_record_array().empty()) { if (insert_param_->row_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
} }
if (!record_ids_.vector_id_array().empty()) { if (!record_ids_->vector_id_array().empty()) {
if (record_ids_.vector_id_array().size() != insert_param_.row_record_array_size()) { if (record_ids_->vector_id_array().size() != insert_param_->row_record_array_size()) {
return SetError(SERVER_ILLEGAL_VECTOR_ID, return SetError(SERVER_ILLEGAL_VECTOR_ID,
"Size of vector ids is not equal to row record array size"); "Size of vector ids is not equal to row record array size");
} }
@ -428,12 +441,12 @@ InsertTask::OnExecute() {
//step 2: check table existence //step 2: check table existence
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = insert_param_.table_name(); table_info.table_id_ = insert_param_->table_name();
engine::Status stat = DBWrapper::DB()->DescribeTable(table_info); engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
if (!stat.ok()) { if (!stat.ok()) {
if (stat.IsNotFound()) { if (stat.IsNotFound()) {
return SetError(SERVER_TABLE_NOT_EXIST, return SetError(SERVER_TABLE_NOT_EXIST,
"Table " + insert_param_.table_name() + " not exists"); "Table " + insert_param_->table_name() + " not exists");
} else { } else {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
@ -443,7 +456,7 @@ InsertTask::OnExecute() {
uint64_t row_count = 0; uint64_t row_count = 0;
DBWrapper::DB()->GetTableRowCount(table_info.table_id_, row_count); DBWrapper::DB()->GetTableRowCount(table_info.table_id_, row_count);
bool empty_table = (row_count == 0); bool empty_table = (row_count == 0);
bool user_provide_ids = !insert_param_.row_id_array().empty(); bool user_provide_ids = !insert_param_->row_id_array().empty();
if(!empty_table) { if(!empty_table) {
//user already provided id before, all insert action require user id //user already provided id before, all insert action require user id
if(engine::utils::UserDefinedId(table_info.flag_) && !user_provide_ids) { if(engine::utils::UserDefinedId(table_info.flag_) && !user_provide_ids) {
@ -465,14 +478,14 @@ InsertTask::OnExecute() {
#endif #endif
//step 3: prepare float data //step 3: prepare float data
std::vector<float> vec_f(insert_param_.row_record_array_size() * table_info.dimension_, 0); std::vector<float> vec_f(insert_param_->row_record_array_size() * table_info.dimension_, 0);
// TODO: change to one dimension array in protobuf or use multiple-thread to copy the data // TODO: change to one dimension array in protobuf or use multiple-thread to copy the data
for (size_t i = 0; i < insert_param_.row_record_array_size(); i++) { for (size_t i = 0; i < insert_param_->row_record_array_size(); i++) {
if (insert_param_.row_record_array(i).vector_data().empty()) { if (insert_param_->row_record_array(i).vector_data().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record float array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record float array is empty");
} }
uint64_t vec_dim = insert_param_.row_record_array(i).vector_data().size(); uint64_t vec_dim = insert_param_->row_record_array(i).vector_data().size();
if (vec_dim != table_info.dimension_) { if (vec_dim != table_info.dimension_) {
ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION;
std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(vec_dim) std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(vec_dim)
@ -481,31 +494,31 @@ InsertTask::OnExecute() {
return SetError(error_code, error_msg); return SetError(error_code, error_msg);
} }
memcpy(&vec_f[i * table_info.dimension_], memcpy(&vec_f[i * table_info.dimension_],
insert_param_.row_record_array(i).vector_data().data(), insert_param_->row_record_array(i).vector_data().data(),
table_info.dimension_ * sizeof(float)); table_info.dimension_ * sizeof(float));
} }
rc.ElapseFromBegin("prepare vectors data"); rc.ElapseFromBegin("prepare vectors data");
//step 4: insert vectors //step 4: insert vectors
auto vec_count = (uint64_t) insert_param_.row_record_array_size(); auto vec_count = (uint64_t) insert_param_->row_record_array_size();
std::vector<int64_t> vec_ids(insert_param_.row_id_array_size(), 0); std::vector<int64_t> vec_ids(insert_param_->row_id_array_size(), 0);
if(!insert_param_.row_id_array().empty()) { if(!insert_param_->row_id_array().empty()) {
const int64_t* src_data = insert_param_.row_id_array().data(); const int64_t* src_data = insert_param_->row_id_array().data();
int64_t* target_data = vec_ids.data(); int64_t* target_data = vec_ids.data();
memcpy(target_data, src_data, (size_t)(sizeof(int64_t)*insert_param_.row_id_array_size())); memcpy(target_data, src_data, (size_t)(sizeof(int64_t)*insert_param_->row_id_array_size()));
} }
stat = DBWrapper::DB()->InsertVectors(insert_param_.table_name(), vec_count, vec_f.data(), vec_ids); stat = DBWrapper::DB()->InsertVectors(insert_param_->table_name(), vec_count, vec_f.data(), vec_ids);
rc.ElapseFromBegin("add vectors to engine"); rc.ElapseFromBegin("add vectors to engine");
if (!stat.ok()) { if (!stat.ok()) {
return SetError(SERVER_CACHE_ERROR, "Cache error: " + stat.ToString()); return SetError(SERVER_CACHE_ERROR, "Cache error: " + stat.ToString());
} }
for (int64_t id : vec_ids) { for (int64_t id : vec_ids) {
record_ids_.add_vector_id_array(id); record_ids_->add_vector_id_array(id);
} }
auto ids_size = record_ids_.vector_id_array_size(); auto ids_size = record_ids_->vector_id_array_size();
if (ids_size != vec_count) { if (ids_size != vec_count) {
std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return " std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return "
+ std::to_string(ids_size) + " id"; + std::to_string(ids_size) + " id";
@ -514,7 +527,7 @@ InsertTask::OnExecute() {
//step 5: update table flag //step 5: update table flag
if(empty_table && user_provide_ids) { if(empty_table && user_provide_ids) {
stat = DBWrapper::DB()->UpdateTableFlag(insert_param_.table_name(), stat = DBWrapper::DB()->UpdateTableFlag(insert_param_->table_name(),
table_info.flag_ | engine::meta::FLAG_MASK_USERID); table_info.flag_ | engine::meta::FLAG_MASK_USERID);
} }
@ -533,9 +546,9 @@ InsertTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchTask::SearchTask(const ::milvus::grpc::SearchParam &search_vector_infos, SearchTask::SearchTask(const ::milvus::grpc::SearchParam *search_vector_infos,
const std::vector<std::string> &file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer) ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer)
: GrpcBaseTask(DQL_TASK_GROUP), : GrpcBaseTask(DQL_TASK_GROUP),
search_param_(search_vector_infos), search_param_(search_vector_infos),
file_id_array_(file_id_array), file_id_array_(file_id_array),
@ -544,9 +557,13 @@ SearchTask::SearchTask(const ::milvus::grpc::SearchParam &search_vector_infos,
} }
BaseTaskPtr BaseTaskPtr
SearchTask::Create(const ::milvus::grpc::SearchParam &search_vector_infos, SearchTask::Create(const ::milvus::grpc::SearchParam *search_vector_infos,
const std::vector<std::string> &file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer) { ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) {
if(search_vector_infos == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!";
return nullptr;
}
return std::shared_ptr<GrpcBaseTask>(new SearchTask(search_vector_infos, file_id_array, return std::shared_ptr<GrpcBaseTask>(new SearchTask(search_vector_infos, file_id_array,
writer)); writer));
} }
@ -557,24 +574,24 @@ SearchTask::OnExecute() {
TimeRecorder rc("SearchTask"); TimeRecorder rc("SearchTask");
//step 1: check arguments //step 1: check arguments
std::string table_name_ = search_param_.table_name(); std::string table_name_ = search_param_->table_name();
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
int64_t top_k_ = search_param_.topk(); int64_t top_k_ = search_param_->topk();
if (top_k_ <= 0 || top_k_ > 1024) { if (top_k_ <= 0 || top_k_ > 1024) {
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_)); return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_));
} }
int64_t nprobe = search_param_.nprobe(); int64_t nprobe = search_param_->nprobe();
if (nprobe <= 0) { if (nprobe <= 0) {
return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe)); return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe));
} }
if (search_param_.query_record_array().empty()) { if (search_param_->query_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
} }
@ -596,8 +613,8 @@ SearchTask::OnExecute() {
std::string error_msg; std::string error_msg;
std::vector<::milvus::grpc::Range> range_array; std::vector<::milvus::grpc::Range> range_array;
for (size_t i = 0; i < search_param_.query_range_array_size(); i++) { for (size_t i = 0; i < search_param_->query_range_array_size(); i++) {
range_array.emplace_back(search_param_.query_range_array(i)); range_array.emplace_back(search_param_->query_range_array(i));
} }
ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg); ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg);
if (error_code != SERVER_SUCCESS) { if (error_code != SERVER_SUCCESS) {
@ -614,13 +631,13 @@ SearchTask::OnExecute() {
#endif #endif
//step 3: prepare float data //step 3: prepare float data
auto record_array_size = search_param_.query_record_array_size(); auto record_array_size = search_param_->query_record_array_size();
std::vector<float> vec_f(record_array_size * table_info.dimension_, 0); std::vector<float> vec_f(record_array_size * table_info.dimension_, 0);
for (size_t i = 0; i < record_array_size; i++) { for (size_t i = 0; i < record_array_size; i++) {
if (search_param_.query_record_array(i).vector_data().empty()) { if (search_param_->query_record_array(i).vector_data().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Query record float array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Query record float array is empty");
} }
uint64_t query_vec_dim = search_param_.query_record_array(i).vector_data().size(); uint64_t query_vec_dim = search_param_->query_record_array(i).vector_data().size();
if (query_vec_dim != table_info.dimension_) { if (query_vec_dim != table_info.dimension_) {
ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION;
std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(query_vec_dim) std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(query_vec_dim)
@ -629,14 +646,14 @@ SearchTask::OnExecute() {
} }
memcpy(&vec_f[i * table_info.dimension_], memcpy(&vec_f[i * table_info.dimension_],
search_param_.query_record_array(i).vector_data().data(), search_param_->query_record_array(i).vector_data().data(),
table_info.dimension_ * sizeof(float)); table_info.dimension_ * sizeof(float));
} }
rc.ElapseFromBegin("prepare vector data"); rc.ElapseFromBegin("prepare vector data");
//step 4: search vectors //step 4: search vectors
engine::QueryResults results; engine::QueryResults results;
auto record_count = (uint64_t) search_param_.query_record_array().size(); auto record_count = (uint64_t) search_param_->query_record_array().size();
if (file_id_array_.empty()) { if (file_id_array_.empty()) {
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(), stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(),
@ -666,14 +683,14 @@ SearchTask::OnExecute() {
//step 5: construct result array //step 5: construct result array
for (uint64_t i = 0; i < record_count; i++) { for (uint64_t i = 0; i < record_count; i++) {
auto &result = results[i]; auto &result = results[i];
const auto &record = search_param_.query_record_array(i); const auto &record = search_param_->query_record_array(i);
::milvus::grpc::TopKQueryResult grpc_topk_result; ::milvus::grpc::TopKQueryResult grpc_topk_result;
for (auto &pair : result) { for (auto &pair : result) {
::milvus::grpc::QueryResult *grpc_result = grpc_topk_result.add_query_result_arrays(); ::milvus::grpc::QueryResult *grpc_result = grpc_topk_result.add_query_result_arrays();
grpc_result->set_id(pair.first); grpc_result->set_id(pair.first);
grpc_result->set_distance(pair.second); grpc_result->set_distance(pair.second);
} }
if (!writer_.Write(grpc_topk_result)) { if (!writer_->Write(grpc_topk_result)) {
return SetError(SERVER_WRITE_ERROR, "Write topk result failed!"); return SetError(SERVER_WRITE_ERROR, "Write topk result failed!");
} }
} }
@ -765,13 +782,17 @@ CmdTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteByRangeTask::DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param) DeleteByRangeTask::DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
delete_by_range_param_(delete_by_range_param){ delete_by_range_param_(delete_by_range_param){
} }
BaseTaskPtr BaseTaskPtr
DeleteByRangeTask::Create(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param) { DeleteByRangeTask::Create(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param) {
if(delete_by_range_param == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!";
return nullptr;
}
return std::shared_ptr<GrpcBaseTask>(new DeleteByRangeTask(delete_by_range_param)); return std::shared_ptr<GrpcBaseTask>(new DeleteByRangeTask(delete_by_range_param));
} }
@ -781,7 +802,7 @@ DeleteByRangeTask::OnExecute() {
TimeRecorder rc("DeleteByRangeTask"); TimeRecorder rc("DeleteByRangeTask");
//step 1: check arguments //step 1: check arguments
std::string table_name = delete_by_range_param_.table_name(); std::string table_name = delete_by_range_param_->table_name();
ServerError res = ValidationUtil::ValidateTableName(table_name); ServerError res = ValidationUtil::ValidateTableName(table_name);
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name); return SetError(res, "Invalid table name: " + table_name);
@ -807,7 +828,7 @@ DeleteByRangeTask::OnExecute() {
std::string error_msg; std::string error_msg;
std::vector<::milvus::grpc::Range> range_array; std::vector<::milvus::grpc::Range> range_array;
range_array.emplace_back(delete_by_range_param_.range()); range_array.emplace_back(delete_by_range_param_->range());
ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg); ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg);
if (error_code != SERVER_SUCCESS) { if (error_code != SERVER_SUCCESS) {
return SetError(error_code, error_msg); return SetError(error_code, error_msg);
@ -870,7 +891,7 @@ PreloadTableTask::OnExecute() {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DescribeIndexTask::DescribeIndexTask(const std::string &table_name, DescribeIndexTask::DescribeIndexTask(const std::string &table_name,
::milvus::grpc::IndexParam &index_param) ::milvus::grpc::IndexParam *index_param)
: GrpcBaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name), table_name_(table_name),
index_param_(index_param) { index_param_(index_param) {
@ -879,7 +900,7 @@ DescribeIndexTask::DescribeIndexTask(const std::string &table_name,
BaseTaskPtr BaseTaskPtr
DescribeIndexTask::Create(const std::string &table_name, DescribeIndexTask::Create(const std::string &table_name,
::milvus::grpc::IndexParam &index_param){ ::milvus::grpc::IndexParam *index_param){
return std::shared_ptr<GrpcBaseTask>(new DescribeIndexTask(table_name, index_param)); return std::shared_ptr<GrpcBaseTask>(new DescribeIndexTask(table_name, index_param));
} }
@ -901,10 +922,10 @@ DescribeIndexTask::OnExecute() {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
index_param_.mutable_table_name()->set_table_name(table_name_); index_param_->mutable_table_name()->set_table_name(table_name_);
index_param_.mutable_index()->set_index_type(index.engine_type_); index_param_->mutable_index()->set_index_type(index.engine_type_);
index_param_.mutable_index()->set_nlist(index.nlist_); index_param_->mutable_index()->set_nlist(index.nlist_);
index_param_.mutable_index()->set_metric_type(index.metric_type_); index_param_->mutable_index()->set_metric_type(index.metric_type_);
rc.ElapseFromBegin("totally cost"); rc.ElapseFromBegin("totally cost");
} catch (std::exception &ex) { } catch (std::exception &ex) {

View File

@ -23,17 +23,17 @@ namespace grpc {
class CreateTableTask : public GrpcBaseTask { class CreateTableTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::TableSchema &schema); Create(const ::milvus::grpc::TableSchema *schema);
protected: protected:
explicit explicit
CreateTableTask(const ::milvus::grpc::TableSchema &request); CreateTableTask(const ::milvus::grpc::TableSchema *request);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
const ::milvus::grpc::TableSchema schema_; const ::milvus::grpc::TableSchema *schema_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -58,10 +58,10 @@ private:
class DescribeTableTask : public GrpcBaseTask { class DescribeTableTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string &table_name, ::milvus::grpc::TableSchema &schema); Create(const std::string &table_name, ::milvus::grpc::TableSchema *schema);
protected: protected:
DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema &schema); DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema *schema);
ServerError ServerError
OnExecute() override; OnExecute() override;
@ -69,7 +69,7 @@ protected:
private: private:
std::string table_name_; std::string table_name_;
::milvus::grpc::TableSchema &schema_; ::milvus::grpc::TableSchema *schema_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -94,76 +94,76 @@ private:
class CreateIndexTask : public GrpcBaseTask { class CreateIndexTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::IndexParam &index_Param); Create(const ::milvus::grpc::IndexParam *index_Param);
protected: protected:
explicit explicit
CreateIndexTask(const ::milvus::grpc::IndexParam &index_Param); CreateIndexTask(const ::milvus::grpc::IndexParam *index_Param);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
::milvus::grpc::IndexParam index_param_; const ::milvus::grpc::IndexParam *index_param_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class ShowTablesTask : public GrpcBaseTask { class ShowTablesTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(::grpc::ServerWriter<::milvus::grpc::TableName> &writer); Create(::grpc::ServerWriter<::milvus::grpc::TableName> *writer);
protected: protected:
explicit explicit
ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> &writer); ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> *writer);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
::grpc::ServerWriter<::milvus::grpc::TableName> writer_; ::grpc::ServerWriter<::milvus::grpc::TableName> *writer_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class InsertTask : public GrpcBaseTask { class InsertTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::InsertParam &insert_Param, Create(const ::milvus::grpc::InsertParam *insert_Param,
::milvus::grpc::VectorIds &record_ids_); ::milvus::grpc::VectorIds *record_ids_);
protected: protected:
InsertTask(const ::milvus::grpc::InsertParam &insert_Param, InsertTask(const ::milvus::grpc::InsertParam *insert_Param,
::milvus::grpc::VectorIds &record_ids_); ::milvus::grpc::VectorIds *record_ids_);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
const ::milvus::grpc::InsertParam insert_param_; const ::milvus::grpc::InsertParam *insert_param_;
::milvus::grpc::VectorIds &record_ids_; ::milvus::grpc::VectorIds *record_ids_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SearchTask : public GrpcBaseTask { class SearchTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::SearchParam &search_param, Create(const ::milvus::grpc::SearchParam *search_param,
const std::vector<std::string> &file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer); ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer);
protected: protected:
SearchTask(const ::milvus::grpc::SearchParam &search_param, SearchTask(const ::milvus::grpc::SearchParam *search_param,
const std::vector<std::string> &file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer); ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
const ::milvus::grpc::SearchParam search_param_; const ::milvus::grpc::SearchParam *search_param_;
std::vector<std::string> file_id_array_; std::vector<std::string> file_id_array_;
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> writer_; ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -204,16 +204,16 @@ private:
class DeleteByRangeTask : public GrpcBaseTask { class DeleteByRangeTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param); Create(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param);
protected: protected:
DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param); DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
::milvus::grpc::DeleteByRangeParam delete_by_range_param_; const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -237,18 +237,18 @@ class DescribeIndexTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string &table_name, Create(const std::string &table_name,
::milvus::grpc::IndexParam &index_param); ::milvus::grpc::IndexParam *index_param);
protected: protected:
DescribeIndexTask(const std::string &table_name, DescribeIndexTask(const std::string &table_name,
::milvus::grpc::IndexParam &index_param); ::milvus::grpc::IndexParam *index_param);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
std::string table_name_; std::string table_name_;
::milvus::grpc::IndexParam& index_param_; ::milvus::grpc::IndexParam *index_param_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////