mirror of https://github.com/milvus-io/milvus.git
reduce insert memory copy (#3842)
* avoid memory copy for insert request
* typo
* fix unittest failure
* add more log
* typo
* typo
* add log
* add log
* typo
* refine code
* format code

Signed-off-by: groot <yihua.mo@zilliz.com>
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
parent cd70c122f1
commit 7785f44ef4
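The core idea of the commit, before the per-file hunks: the request handlers stop copying every payload into an intermediate `std::vector<uint8_t>` map and instead record `(address, byte count)` pairs into an `InsertParam`; the one unavoidable copy happens later, when the chunk is materialized. Below is a minimal standalone sketch of that pattern; `InsertParamSketch`, `RecordAddr`, and `Materialize` are illustrative stand-ins, not the project's API (the real `InsertParam`, `RecordDataAddr`, and `ConvertToChunk` appear in the hunks that follow).

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical mirror of the commit's InsertParam: each field maps to a list
// of (address, byte-count) segments, so handlers record where the payload
// already lives instead of copying it into a staging buffer.
struct InsertParamSketch {
    using DataSegment = std::pair<const char*, int64_t>;
    std::unordered_map<std::string, std::vector<DataSegment>> fields_data_;
    int64_t row_count_ = 0;
};

// Record a typed buffer without copying it (mirrors the diff's RecordDataAddr<T>).
template <typename T>
void
RecordAddr(const std::string& field, int64_t num, const T* data, InsertParamSketch& param) {
    param.fields_data_[field].emplace_back(reinterpret_cast<const char*>(data),
                                           static_cast<int64_t>(num * sizeof(T)));
}

// The single copy happens once, when the request is materialized into a
// contiguous chunk (mirrors the diff's ConvertToChunk).
std::vector<char>
Materialize(const InsertParamSketch& param, const std::string& field) {
    int64_t bytes = 0;
    for (auto& seg : param.fields_data_.at(field)) {
        bytes += seg.second;
    }
    std::vector<char> chunk(bytes);
    int64_t offset = 0;
    for (auto& seg : param.fields_data_.at(field)) {
        memcpy(chunk.data() + offset, seg.first, seg.second);
        offset += seg.second;
    }
    return chunk;
}

int main() {
    std::vector<float> vectors = {0.1f, 0.2f, 0.3f, 0.4f};  // payload owned by the RPC request
    InsertParamSketch param;
    param.row_count_ = 2;
    RecordAddr<float>("vec", vectors.size(), vectors.data(), param);  // no copy yet
    auto chunk = Materialize(param, "vec");                           // exactly one copy
    std::cout << chunk.size() << " bytes copied once\n";
}
```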
@@ -517,6 +517,8 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
     // generate id
     if (auto_genid) {
+        LOG_SERVER_DEBUG_ << "Auto generate entities id";
+
         SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance();
         IDNumbers ids;
         STATUS_CHECK(id_generator.GetNextIDNumbers(consume_chunk->count_, ids));
@@ -541,6 +543,7 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
     std::vector<DataChunkPtr> chunks;
     STATUS_CHECK(utils::SplitChunk(consume_chunk, segment_row_count, chunks));

+    LOG_ENGINE_DEBUG_ << "Insert entities into mem manager";
     for (auto& chunk : chunks) {
         auto status = mem_mgr_->InsertEntities(collection_id, partition_id, chunk, op_id);
         if (!status.ok()) {
@@ -215,6 +215,8 @@ SplitChunk(const DataChunkPtr& chunk, int64_t segment_row_count, std::vector<Dat
         }
    }

+    LOG_SERVER_DEBUG_ << "Split chunk since the chunk row count greater than segment_row_limit";
+
     // secondly, copy new chunk
     int64_t copied_count = 0;
     while (copied_count < chunk_count) {
@@ -249,6 +251,11 @@ SplitChunk(const DataChunkPtr& chunk, int64_t segment_row_count, std::vector<Dat
         chunks.emplace_back(new_chunk);
     }

+    // data has been copied, do this to free memory
+    chunk->fixed_fields_.clear();
+    chunk->variable_fields_.clear();
+    chunk->count_ = 0;
+
     return Status::OK();
 }
@@ -171,7 +171,7 @@ MemCollection::ApplyDeleteToFile() {
     segment::SegmentReaderPtr segment_reader =
         std::make_shared<segment::SegmentReader>(options_.meta_.path_, seg_visitor);

-    // Step 1: Check to-delete id possissbly in this segment
+    // Step 1: Check to-delete id possibly in this segment
     std::unordered_set<idx_t> ids_to_check;
     segment::IdBloomFilterPtr pre_bloom_filter;
     STATUS_CHECK(segment_reader->LoadBloomFilter(pre_bloom_filter));
@@ -99,12 +99,6 @@ WalProxy::Insert(const std::string& collection_name, const std::string& partitio
     // split chunk according to segment row count
     std::vector<DataChunkPtr> chunks;
     STATUS_CHECK(utils::SplitChunk(data_chunk, row_count_per_segment, chunks));
-    if (chunks.size() > 0 && data_chunk != chunks[0]) {
-        // data has been copied to new chunk, do this to free memory
-        data_chunk->fixed_fields_.clear();
-        data_chunk->variable_fields_.clear();
-        data_chunk->count_ = 0;
-    }

     // write operation into wal file, and insert to memory
     for (auto& chunk : chunks) {
@@ -491,8 +491,14 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
 }

 Status
-ValidateInsertDataSize(const engine::DataChunkPtr& data) {
-    int64_t chunk_size = engine::utils::GetSizeOfChunk(data);
+ValidateInsertDataSize(const InsertParam& insert_param) {
+    int64_t chunk_size = 0;
+    for (auto& pair : insert_param.fields_data_) {
+        for (auto& data : pair.second) {
+            chunk_size += data.second;
+        }
+    }

     if (chunk_size > engine::MAX_INSERT_DATA_SIZE) {
         std::string msg = "The amount of data inserted each time cannot exceed " +
                           std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB";
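A hedged sketch of what the reworked validation computes: the total payload size is now the sum of the recorded segment lengths, so nothing has to be materialized into a chunk just to measure it. `InsertSizeOk` and `MAX_BYTES` are illustrative names; the real code compares against `engine::MAX_INSERT_DATA_SIZE`.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for engine::MAX_INSERT_DATA_SIZE (the actual limit lives in the engine).
constexpr int64_t MAX_BYTES = 256LL * 1024 * 1024;

// Sum the byte counts of all recorded segments across all fields.
bool
InsertSizeOk(const std::unordered_map<std::string, std::vector<std::pair<const char*, int64_t>>>& fields) {
    int64_t total = 0;
    for (auto& field : fields) {
        for (auto& segment : field.second) {
            total += segment.second;  // .second is the segment's byte count
        }
    }
    return total <= MAX_BYTES;
}
```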
@@ -12,6 +12,7 @@
 #pragma once

 #include "db/Types.h"
+#include "server/delivery/request/Types.h"
 #include "utils/Json.h"
 #include "utils/Status.h"

@@ -56,7 +57,7 @@ extern Status
 ValidatePartitionTags(const std::vector<std::string>& partition_tags);

 extern Status
-ValidateInsertDataSize(const engine::DataChunkPtr& data);
+ValidateInsertDataSize(const InsertParam& insert_param);

 extern Status
 ValidateCompactThreshold(double threshold);
@@ -150,8 +150,8 @@ ReqHandler::DropIndex(const ContextPtr& context, const std::string& collection_n

 Status
 ReqHandler::Insert(const ContextPtr& context, const std::string& collection_name, const std::string& partition_name,
-                   const int64_t& row_count, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data) {
-    BaseReqPtr req_ptr = InsertReq::Create(context, collection_name, partition_name, row_count, chunk_data);
+                   InsertParam& insert_param) {
+    BaseReqPtr req_ptr = InsertReq::Create(context, collection_name, partition_name, insert_param);
     ReqScheduler::ExecReq(req_ptr);
     return req_ptr->status();
 }
@@ -80,7 +80,7 @@ class ReqHandler {

     Status
     Insert(const ContextPtr& context, const std::string& collection_name, const std::string& partition_name,
-           const int64_t& row_count, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
+           InsertParam& insert_param);

     Status
     GetEntityByID(const ContextPtr& context, const std::string& collection_name, const engine::IDNumbers& ids,
@@ -32,56 +32,79 @@
 namespace milvus {
 namespace server {

+namespace {
+Status
+ConvertToChunk(const InsertParam& insert_param, engine::DataChunkPtr& data_chunk) {
+    TimeRecorderAuto rc("Copy insert data to chunk");
+    data_chunk = std::make_shared<engine::DataChunk>();
+    data_chunk->count_ = insert_param.row_count_;
+    for (auto& pair : insert_param.fields_data_) {
+        engine::BinaryDataPtr bin = std::make_shared<engine::BinaryData>();
+
+        // calculate data size
+        int64_t bytes = 0;
+        for (auto& data_segment : pair.second) {
+            bytes += data_segment.second;
+        }
+        bin->data_.resize(bytes);
+
+        // copy data
+        int64_t offset = 0;
+        for (auto& data_segment : pair.second) {
+            memcpy(bin->data_.data() + offset, data_segment.first, data_segment.second);
+            offset += data_segment.second;
+        }
+
+        data_chunk->fixed_fields_.insert(std::make_pair(pair.first, bin));
+    }
+
+    return Status::OK();
+}
+}  // namespace
+
 InsertReq::InsertReq(const ContextPtr& context, const std::string& collection_name, const std::string& partition_name,
-                     const int64_t& row_count, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data)
+                     InsertParam& insert_param)
     : BaseReq(context, ReqType::kInsert),
       collection_name_(collection_name),
       partition_name_(partition_name),
-      row_count_(row_count),
-      chunk_data_(chunk_data) {
+      insert_param_(insert_param) {
 }

 BaseReqPtr
 InsertReq::Create(const ContextPtr& context, const std::string& collection_name, const std::string& partition_name,
-                  const int64_t& row_count, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data) {
-    return std::shared_ptr<BaseReq>(new InsertReq(context, collection_name, partition_name, row_count, chunk_data));
+                  InsertParam& insert_param) {
+    return std::shared_ptr<BaseReq>(new InsertReq(context, collection_name, partition_name, insert_param));
 }

 Status
 InsertReq::OnExecute() {
     LOG_SERVER_INFO_ << LogOut("[%s][%ld] ", "insert", 0) << "Execute InsertReq.";
     try {
-        std::string hdr = "InsertReq(table=" + collection_name_ + ", partition_name=" + partition_name_ + ")";
+        std::string hdr = "InsertReq(collection=" + collection_name_ + ", partition_name=" + partition_name_ + ")";
         TimeRecorder rc(hdr);

-        if (chunk_data_.empty()) {
-            return Status{SERVER_INVALID_ARGUMENT,
-                          "The vector field is empty, Make sure you have entered vector records"};
+        if (insert_param_.row_count_ == 0 || insert_param_.fields_data_.empty()) {
+            return Status{SERVER_INVALID_ARGUMENT, "The field is empty, make sure you have entered entities"};
         }

         // step 1: check collection existence
         bool exist = false;
         STATUS_CHECK(DBWrapper::DB()->HasCollection(collection_name_, exist));
         if (!exist) {
-            return Status(SERVER_COLLECTION_NOT_EXIST, "Collection not exist: " + collection_name_);
+            return Status(SERVER_COLLECTION_NOT_EXIST, "Collection doesn't exist: " + collection_name_);
         }

-        // step 2: construct insert data
-        engine::DataChunkPtr data_chunk = std::make_shared<engine::DataChunk>();
-        data_chunk->count_ = row_count_;
-        for (auto& pair : chunk_data_) {
-            engine::BinaryDataPtr bin = std::make_shared<engine::BinaryData>();
-            bin->data_.swap(pair.second);
-            data_chunk->fixed_fields_.insert(std::make_pair(pair.first, bin));
-        }
-
-        // step 3: check insert data limitation
-        auto status = ValidateInsertDataSize(data_chunk);
+        // step 2: check insert data limitation
+        auto status = ValidateInsertDataSize(insert_param_);
         if (!status.ok()) {
             LOG_SERVER_ERROR_ << LogOut("[%s][%d] Invalid vector data: %s", "insert", 0, status.message().c_str());
             return status;
         }

+        // step 3: construct insert data
+        engine::DataChunkPtr data_chunk;
+        STATUS_CHECK(ConvertToChunk(insert_param_, data_chunk));
+
         // step 4: insert data into db
         status = DBWrapper::DB()->Insert(collection_name_, partition_name_, data_chunk);
         if (!status.ok()) {
@@ -94,7 +117,10 @@ InsertReq::OnExecute() {
         if (iter == data_chunk->fixed_fields_.end() || iter->second == nullptr) {
             return Status(SERVER_UNEXPECTED_ERROR, "Insert action return empty id array");
         }
-        chunk_data_[engine::FIELD_UID] = iter->second->data_;
+
+        int64_t num = iter->second->data_.size() / sizeof(int64_t);
+        insert_param_.id_returned_.resize(num);
+        memcpy(insert_param_.id_returned_.data(), iter->second->data_.data(), iter->second->data_.size());

         rc.ElapseFromBegin("done");
     } catch (std::exception& ex) {
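The id-return step above copies the engine's raw id buffer (`iter->second->data_`) into the typed `id_returned_` vector exactly once. A standalone sketch of that conversion; `ExtractIds` is an illustrative name, not the project's API.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Reinterpret a raw byte buffer of generated ids as int64 values with a
// single memcpy, the same shape as the hunk above.
std::vector<int64_t>
ExtractIds(const std::vector<uint8_t>& raw) {
    std::vector<int64_t> ids(raw.size() / sizeof(int64_t));
    memcpy(ids.data(), raw.data(), ids.size() * sizeof(int64_t));
    return ids;
}
```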
@@ -25,11 +25,11 @@ class InsertReq : public BaseReq {
  public:
     static BaseReqPtr
     Create(const ContextPtr& context, const std::string& collection_name, const std::string& partition_name,
-           const int64_t& row_count, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
+           InsertParam& insert_param);

  protected:
     InsertReq(const ContextPtr& context, const std::string& collection_name, const std::string& partition_name,
-              const int64_t& row_count, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
+              InsertParam& insert_param);

     Status
     OnExecute() override;
@@ -37,8 +37,7 @@ class InsertReq : public BaseReq {
  private:
     const std::string collection_name_;
     const std::string partition_name_;
-    const int64_t row_count_;
-    std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data_;
+    InsertParam& insert_param_;
 };

 }  // namespace server
@@ -85,6 +85,20 @@ struct IndexParam {
     }
 };

+struct InsertParam {
+    using DataSegment = std::pair<const char*, int64_t>;
+    using DataSegments = std::vector<DataSegment>;
+    using FieldDataMap = std::unordered_map<std::string, DataSegments>;
+
+    // to avoid data copy, fields_data_ only passes data addresses;
+    // make sure all of the addresses stay alive until the insert is done
+    FieldDataMap fields_data_;
+    int64_t row_count_ = 0;
+
+    // to return entities id
+    std::vector<int64_t> id_returned_;
+};
+
 enum class ReqType {
     // general operations
     kCmd = 0,
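The lifetime contract in the comment above is the one sharp edge of this design: `fields_data_` stores raw addresses, so whatever owns the buffers (the gRPC request object, or the REST handler's staging map) must outlive the insert. A standalone sketch of correct usage; the struct mirrors the diff but compiles independently, and `"_id"` here is just a placeholder field name.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Standalone mirror of the InsertParam declared in the hunk above.
struct InsertParam {
    using DataSegment = std::pair<const char*, int64_t>;
    using DataSegments = std::vector<DataSegment>;
    using FieldDataMap = std::unordered_map<std::string, DataSegments>;
    FieldDataMap fields_data_;
    int64_t row_count_ = 0;
    std::vector<int64_t> id_returned_;
};

int main() {
    std::vector<int64_t> ids = {1, 2, 3};  // owned by the caller
    InsertParam param;
    param.row_count_ = static_cast<int64_t>(ids.size());
    // Record the address only; `ids` must stay alive until the insert completes.
    param.fields_data_["_id"].emplace_back(reinterpret_cast<const char*>(ids.data()),
                                           static_cast<int64_t>(ids.size() * sizeof(int64_t)));
}
```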
@@ -93,35 +93,33 @@ RequestMap(ReqType req_type) {
 }

 namespace {
 template <typename T>
 void
-CopyVectorData(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records,
-               std::vector<uint8_t>& vectors_data) {
-    // calculate buffer size
+RecordDataAddr(const std::string& field_name, int32_t num, const T* data, InsertParam& insert_param) {
+    int64_t bytes = num * sizeof(T);
+    const char* data_addr = reinterpret_cast<const char*>(data);
+    auto data_segment = std::make_pair(data_addr, bytes);
+    insert_param.fields_data_[field_name].emplace_back(data_segment);
+}
+
+void
+RecordVectorDataAddr(const std::string& field_name,
+                     const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records,
+                     InsertParam& insert_param) {
+    // calculate data size
     int64_t float_data_size = 0, binary_data_size = 0;
     for (auto& record : grpc_records) {
         float_data_size += record.float_data_size();
         binary_data_size += record.binary_data().size();
     }

-    int64_t data_size = binary_data_size;
-    if (float_data_size > 0) {
-        data_size = float_data_size * sizeof(float);
-    }
-
-    // copy vector data
-    vectors_data.resize(data_size);
-    int64_t offset = 0;
     if (float_data_size > 0) {
         for (auto& record : grpc_records) {
-            int64_t single_size = record.float_data_size() * sizeof(float);
-            memcpy(&vectors_data[offset], record.float_data().data(), single_size);
-            offset += single_size;
+            RecordDataAddr<float>(field_name, record.float_data_size(), record.float_data().data(), insert_param);
         }
     } else if (binary_data_size > 0) {
         for (auto& record : grpc_records) {
-            int64_t single_size = record.binary_data().size();
-            memcpy(&vectors_data[offset], record.binary_data().data(), single_size);
-            offset += single_size;
+            RecordDataAddr<char>(field_name, record.binary_data().size(), record.binary_data().data(), insert_param);
         }
     }
 }
@@ -1321,10 +1319,6 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
         }
     }

-    auto field_size = request->fields_size();
-
-    std::unordered_map<std::string, std::vector<uint8_t>> chunk_data;
-
     auto valid_row_count = [&](int32_t& base, int32_t test) -> bool {
         if (base < 0) {
             base = test;
@@ -1341,8 +1335,10 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
         return true;
     };

-    // copy field data
+    // construct insert parameter
+    InsertParam insert_param;
     int32_t row_num = -1;
+    auto field_size = request->fields_size();
     for (int i = 0; i < field_size; i++) {
         auto grpc_int32_size = request->fields(i).attr_record().int32_value_size();
         auto grpc_int64_size = request->fields(i).attr_record().int64_value_size();
|
|||
const auto& field = request->fields(i);
|
||||
auto& field_name = field.field_name();
|
||||
|
||||
std::vector<uint8_t> temp_data;
|
||||
if (grpc_int32_size > 0) {
|
||||
if (!valid_row_count(row_num, grpc_int32_size)) {
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
temp_data.resize(grpc_int32_size * sizeof(int32_t));
|
||||
memcpy(temp_data.data(), field.attr_record().int32_value().data(), grpc_int32_size * sizeof(int32_t));
|
||||
RecordDataAddr<int32_t>(field_name, grpc_int32_size, field.attr_record().int32_value().data(),
|
||||
insert_param);
|
||||
} else if (grpc_int64_size > 0) {
|
||||
if (!valid_row_count(row_num, grpc_int64_size)) {
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
temp_data.resize(grpc_int64_size * sizeof(int64_t));
|
||||
memcpy(temp_data.data(), field.attr_record().int64_value().data(), grpc_int64_size * sizeof(int64_t));
|
||||
RecordDataAddr<int64_t>(field_name, grpc_int64_size, field.attr_record().int64_value().data(),
|
||||
insert_param);
|
||||
} else if (grpc_float_size > 0) {
|
||||
if (!valid_row_count(row_num, grpc_float_size)) {
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
temp_data.resize(grpc_float_size * sizeof(float));
|
||||
memcpy(temp_data.data(), field.attr_record().float_value().data(), grpc_float_size * sizeof(float));
|
||||
RecordDataAddr<float>(field_name, grpc_float_size, field.attr_record().float_value().data(), insert_param);
|
||||
} else if (grpc_double_size > 0) {
|
||||
if (!valid_row_count(row_num, grpc_double_size)) {
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
temp_data.resize(grpc_double_size * sizeof(double));
|
||||
memcpy(temp_data.data(), field.attr_record().double_value().data(), grpc_double_size * sizeof(double));
|
||||
RecordDataAddr<double>(field_name, grpc_double_size, field.attr_record().double_value().data(),
|
||||
insert_param);
|
||||
} else {
|
||||
if (!valid_row_count(row_num, field.vector_record().records_size())) {
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
CopyVectorData(field.vector_record().records(), temp_data);
|
||||
RecordVectorDataAddr(field_name, field.vector_record().records(), insert_param);
|
||||
}
|
||||
|
||||
chunk_data.insert(std::make_pair(field_name, temp_data));
|
||||
}
|
||||
insert_param.row_count_ = row_num;
|
||||
|
||||
// copy id array
|
||||
if (request->entity_id_array_size() > 0) {
|
||||
int64_t size = request->entity_id_array_size() * sizeof(int64_t);
|
||||
std::vector<uint8_t> temp_data(size, 0);
|
||||
memcpy(temp_data.data(), request->entity_id_array().data(), size);
|
||||
chunk_data.insert(std::make_pair(engine::FIELD_UID, temp_data));
|
||||
RecordDataAddr<int64_t>(engine::FIELD_UID, request->entity_id_array_size(), request->entity_id_array().data(),
|
||||
insert_param);
|
||||
}
|
||||
|
||||
std::string collection_name = request->collection_name();
|
||||
std::string partition_name = request->partition_tag();
|
||||
Status status = req_handler_.Insert(GetContext(context), collection_name, partition_name, row_num, chunk_data);
|
||||
Status status = req_handler_.Insert(GetContext(context), collection_name, partition_name, insert_param);
|
||||
if (!status.ok()) {
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
|
||||
// return generated ids
|
||||
auto pair = chunk_data.find(engine::FIELD_UID);
|
||||
if (pair != chunk_data.end()) {
|
||||
response->mutable_entity_id_array()->Resize(static_cast<int>(pair->second.size() / sizeof(int64_t)), 0);
|
||||
memcpy(response->mutable_entity_id_array()->mutable_data(), pair->second.data(), pair->second.size());
|
||||
if (!insert_param.id_returned_.empty()) {
|
||||
response->mutable_entity_id_array()->Resize(static_cast<int>(insert_param.id_returned_.size()), 0);
|
||||
memcpy(response->mutable_entity_id_array()->mutable_data(), insert_param.id_returned_.data(),
|
||||
insert_param.id_returned_.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
LOG_SERVER_INFO_ << LogOut("Request [%s] %s end.", request_id.c_str(), __func__);
|
||||
|
|
|
@@ -83,55 +83,70 @@ WebErrorMap(ErrorCode code) {
     }
 }

-template <typename T>
-void
-CopyStructuredData(const nlohmann::json& json, std::vector<uint8_t>& raw) {
-    std::vector<T> values;
-    auto size = json.size();
-    values.resize(size);
-    raw.resize(size * sizeof(T));
-    size_t offset = 0;
-    for (const auto& data : json) {
-        values[offset] = data.get<T>();
-        ++offset;
-    }
-    memcpy(raw.data(), values.data(), size * sizeof(T));
-}
+using ChunkDataMap = std::unordered_map<std::string, std::vector<uint8_t>>;

 void
-CopyRowVectorFromJson(const nlohmann::json& json, std::vector<uint8_t>& vectors_data, bool bin) {
-    // if (!json.is_array()) {
-    //     return Status(ILLEGAL_BODY, "field \"vectors\" must be a array");
-    // }
+CopyRowVectorFromJson(const nlohmann::json& json, const std::string& field_name, int64_t offset, int64_t row_num,
+                      bool is_binary, ChunkDataMap& chunk_data) {
+    std::vector<uint8_t> binary_data;
     std::vector<float> float_vector;
-    if (!bin) {
+    uint64_t bytes = 0;
+    if (is_binary) {
         for (auto& data : json) {
-            float_vector.emplace_back(data.get<float>());
+            binary_data.emplace_back(data.get<uint8_t>());
         }
-        auto size = float_vector.size() * sizeof(float);
-        vectors_data.resize(size);
-        memcpy(vectors_data.data(), float_vector.data(), size);
+        bytes = binary_data.size() * sizeof(uint8_t);
     } else {
         for (auto& data : json) {
-            vectors_data.emplace_back(data.get<uint8_t>());
+            float_vector.emplace_back(data.get<float>());
         }
+        bytes = float_vector.size() * sizeof(float);
     }
+
+    if (chunk_data.find(field_name) == chunk_data.end()) {
+        std::vector<uint8_t> data(row_num * bytes, 0);
+        chunk_data.insert({field_name, data});
+    }
+
+    int64_t vector_offset = offset * bytes;
+    std::vector<uint8_t>& target_data = chunk_data.at(field_name);
+    if (is_binary) {
+        memcpy(target_data.data() + vector_offset, binary_data.data(), bytes);
+    } else {
+        memcpy(target_data.data() + vector_offset, float_vector.data(), bytes);
+    }
 }

 template <typename T>
 void
-CopyRowStructuredData(const nlohmann::json& entity_json, const std::string& field_name, const int64_t offset,
-                      const int64_t row_num, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data) {
+CopyRowStructuredData(const nlohmann::json& entity_json, const std::string& field_name, int64_t offset, int64_t row_num,
+                      ChunkDataMap& chunk_data) {
     T value = entity_json.get<T>();
-    std::vector<uint8_t> temp_data(sizeof(T), 0);
-    memcpy(temp_data.data(), &value, sizeof(T));
     if (chunk_data.find(field_name) == chunk_data.end()) {
         std::vector<uint8_t> T_data(row_num * sizeof(T), 0);
-        memcpy(T_data.data(), temp_data.data(), sizeof(T));
+        memcpy(T_data.data(), &value, sizeof(T));
         chunk_data.insert({field_name, T_data});
     } else {
         int64_t T_offset = offset * sizeof(T);
-        memcpy(chunk_data.at(field_name).data() + T_offset, temp_data.data(), sizeof(T));
+        memcpy(chunk_data.at(field_name).data() + T_offset, &value, sizeof(T));
     }
 }
+
+template <typename T>
+void
+RecordDataAddr(const std::string& field_name, int32_t num, const T* data, InsertParam& insert_param) {
+    int64_t bytes = num * sizeof(T);
+    const char* data_addr = reinterpret_cast<const char*>(data);
+    auto data_segment = std::make_pair(data_addr, bytes);
+    insert_param.fields_data_[field_name].emplace_back(data_segment);
+}
+
+void
+ConvertToParam(const ChunkDataMap& data_chunk, int64_t row_num, InsertParam& insert_param) {
+    insert_param.row_count_ = row_num;
+    for (auto& pair : data_chunk) {
+        auto& bin = pair.second;
+        RecordDataAddr<uint8_t>(pair.first, bin.size(), bin.data(), insert_param);
+    }
+}
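ConvertToParam shows the REST path's compromise: JSON values have to be materialized once into `chunk_data`, but the result is then wrapped as single-segment fields rather than copied a second time downstream. A standalone mirror of that wrapping, with illustrative names (`Param`, `WrapChunk`):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Standalone stand-in for InsertParam, holding (address, byte-count) segments.
struct Param {
    std::unordered_map<std::string, std::vector<std::pair<const char*, int64_t>>> fields_data_;
    int64_t row_count_ = 0;
};

// Wrap already-materialized per-field buffers as one segment each, so the
// downstream insert path sees the same zero-copy shape as the gRPC path.
// The source map must outlive `param`, since only addresses are recorded.
void
WrapChunk(const std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data, int64_t row_num, Param& param) {
    param.row_count_ = row_num;
    for (auto& field : chunk_data) {
        param.fields_data_[field.first].emplace_back(reinterpret_cast<const char*>(field.second.data()),
                                                     static_cast<int64_t>(field.second.size()));
    }
}
```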
@@ -1623,29 +1638,24 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
         field_types.insert({field.first, field.second.field_type_});
     }

-    std::unordered_map<std::string, std::vector<uint8_t>> chunk_data;
-    int64_t row_num;
-
     auto entities_json = body_json["entities"];
     if (!entities_json.is_array()) {
         RETURN_STATUS_DTO(ILLEGAL_ARGUMENT, "Entities is not an array");
     }
-    row_num = entities_json.size();

+    // construct chunk data by json object
+    ChunkDataMap chunk_data;
+    int64_t row_num = entities_json.size();
     int64_t offset = 0;
-    std::vector<uint8_t> ids;
     for (auto& one_entity : entities_json) {
         for (auto& entity : one_entity.items()) {
             std::string field_name = entity.key();
             if (field_name == NAME_ID) {
-                if (ids.empty()) {
-                    ids.resize(row_num * sizeof(int64_t));
-                }
-                auto id = entity.value().get<int64_t>();
-                int64_t id_offset = offset * sizeof(int64_t);
-                memcpy(ids.data() + id_offset, &id, sizeof(int64_t));
+                // special handle id field
+                CopyRowStructuredData<int64_t>(entity.value(), engine::FIELD_UID, offset, row_num, chunk_data);
                 continue;
             }
-            std::vector<uint8_t> temp_data;

             switch (field_types.at(field_name)) {
                 case engine::DataType::INT32: {
                     CopyRowStructuredData<int32_t>(entity.value(), field_name, offset, row_num, chunk_data);
@@ -1666,16 +1676,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
                 case engine::DataType::VECTOR_FLOAT:
                 case engine::DataType::VECTOR_BINARY: {
                     bool is_bin = !(field_types.at(field_name) == engine::DataType::VECTOR_FLOAT);
-                    CopyRowVectorFromJson(entity.value(), temp_data, is_bin);
-                    auto size = temp_data.size();
-                    if (chunk_data.find(field_name) == chunk_data.end()) {
-                        std::vector<uint8_t> vector_data(row_num * size, 0);
-                        memcpy(vector_data.data(), temp_data.data(), size);
-                        chunk_data.insert({field_name, vector_data});
-                    } else {
-                        int64_t vector_offset = offset * size;
-                        memcpy(chunk_data.at(field_name).data() + vector_offset, temp_data.data(), size);
-                    }
+                    CopyRowVectorFromJson(entity.value(), field_name, offset, row_num, is_bin, chunk_data);
                     break;
                 }
                 default: {}
|
|||
offset++;
|
||||
}
|
||||
|
||||
if (!ids.empty()) {
|
||||
chunk_data.insert({engine::FIELD_UID, ids});
|
||||
}
|
||||
// conver to InsertParam, no memory copy, just record the data address and pass to InsertReq
|
||||
InsertParam insert_param;
|
||||
ConvertToParam(chunk_data, row_num, insert_param);
|
||||
|
||||
#if 0
|
||||
for (auto& entity : body_json["entities"].items()) {
|
||||
std::string field_name = entity.key();
|
||||
auto field_value = entity.value();
|
||||
if (!field_value.is_array()) {
|
||||
RETURN_STATUS_DTO(ILLEGAL_ROWRECORD, "Field value is not an array");
|
||||
}
|
||||
if (field_name == NAME_ID) {
|
||||
std::vector<uint8_t> temp_data(field_value.size() * sizeof(int64_t), 0);
|
||||
CopyStructuredData<int64_t>(field_value, temp_data);
|
||||
chunk_data.insert({engine::FIELD_UID, temp_data});
|
||||
continue;
|
||||
}
|
||||
row_num = field_value.size();
|
||||
|
||||
std::vector<uint8_t> temp_data;
|
||||
switch (field_types.at(field_name)) {
|
||||
case engine::DataType::INT32: {
|
||||
CopyStructuredData<int32_t>(field_value, temp_data);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::INT64: {
|
||||
CopyStructuredData<int64_t>(field_value, temp_data);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::FLOAT: {
|
||||
CopyStructuredData<float>(field_value, temp_data);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::DOUBLE: {
|
||||
CopyStructuredData<double>(field_value, temp_data);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_FLOAT: {
|
||||
CopyRecordsFromJson(field_value, temp_data, false);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_BINARY: {
|
||||
CopyRecordsFromJson(field_value, temp_data, true);
|
||||
break;
|
||||
}
|
||||
default: {}
|
||||
}
|
||||
|
||||
chunk_data.insert(std::make_pair(field_name, temp_data));
|
||||
}
|
||||
#endif
|
||||
|
||||
status = req_handler_.Insert(context_ptr_, collection_name->c_str(), partition_name, row_num, chunk_data);
|
||||
// do insert
|
||||
status = req_handler_.Insert(context_ptr_, collection_name->c_str(), partition_name, insert_param);
|
||||
if (!status.ok()) {
|
||||
RETURN_STATUS_DTO(UNEXPECTED_ERROR, "Failed to insert data");
|
||||
}
|
||||
|
||||
// return generated ids
|
||||
auto pair = chunk_data.find(engine::FIELD_UID);
|
||||
if (pair != chunk_data.end()) {
|
||||
int64_t count = pair->second.size() / 8;
|
||||
auto pdata = reinterpret_cast<int64_t*>(pair->second.data());
|
||||
if (!insert_param.id_returned_.empty()) {
|
||||
ids_dto->ids = ids_dto->ids.createShared();
|
||||
for (int64_t i = 0; i < count; ++i) {
|
||||
ids_dto->ids->push_back(std::to_string(pdata[i]).c_str());
|
||||
for (auto id : insert_param.id_returned_) {
|
||||
ids_dto->ids->push_back(std::to_string(id).c_str());
|
||||
}
|
||||
}
|
||||
ids_dto->code = status.code();
|
||||
|
|