From 3d9a81068cca8a3feed62230c517064280f85245 Mon Sep 17 00:00:00 2001 From: godchen0212 <67679556+godchen0212@users.noreply.github.com> Date: Sun, 26 Jul 2020 11:22:35 +0800 Subject: [PATCH] Improvement in GetEntityById grpc interface (#3019) * Fix bug in GetEntityById Signed-off-by: godchen0212 * format code and delete useless code Signed-off-by: godchen0212 --- .../hybrid_request/GetEntityByIDRequest.cpp | 27 +++--- .../hybrid_request/GetEntityByIDRequest.h | 14 +-- .../server/grpc_impl/GrpcRequestHandler.cpp | 88 ++++++++++++++----- 3 files changed, 89 insertions(+), 40 deletions(-) diff --git a/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.cpp b/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.cpp index 52fb3b9a30..6dd1c9d8e4 100644 --- a/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.cpp +++ b/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.cpp @@ -16,6 +16,7 @@ // under the License. #include "server/delivery/hybrid_request/GetEntityByIDRequest.h" +#include "db/meta/MetaTypes.h" #include "server/DBWrapper.h" #include "server/ValidationUtil.h" #include "utils/Log.h" @@ -30,8 +31,8 @@ namespace server { constexpr uint64_t MAX_COUNT_RETURNED = 1000; GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr& context, - const std::string collection_name, const engine::IDNumbers& id_array, - const std::vector& field_names, + const std::string& collection_name, const engine::IDNumbers& id_array, + std::vector& field_names, engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) : BaseRequest(context, BaseRequest::kGetVectorByID), @@ -43,8 +44,9 @@ GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr& context, std::string collection_name, - const engine::IDNumbers& id_array, const std::vector& field_names_, +GetEntityByIDRequest::Create(const std::shared_ptr& context, + const std::string& collection_name, const engine::IDNumbers& id_array, + std::vector& field_names_, engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) { return std::shared_ptr( new GetEntityByIDRequest(context, collection_name, id_array, field_names_, field_mappings, data_chunk)); @@ -81,19 +83,17 @@ GetEntityByIDRequest::OnExecute() { } if (field_names_.empty()) { - for (const auto& schema : field_mappings_) - for (const auto& it : schema.second) { - field_names_.emplace_back(it->GetName()); - } + for (const auto& schema : field_mappings_) { + if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID) + field_names_.emplace_back(schema.first->GetName()); + } } else { for (const auto& name : field_names_) { bool find_field_name = false; for (const auto& schema : field_mappings_) { - for (const auto& it : schema.second) { - if (name == it->GetName()) { - find_field_name = true; - break; - } + if (name == schema.first->GetName()) { + find_field_name = true; + break; } } if (not find_field_name) { @@ -107,6 +107,7 @@ GetEntityByIDRequest::OnExecute() { if (!status.ok()) { return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_)); } + return Status::OK(); } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } diff --git a/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.h b/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.h index 6cd16be233..9af70eb109 100644 --- a/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.h +++ b/core/src/server/delivery/hybrid_request/GetEntityByIDRequest.h @@ -32,13 +32,13 @@ namespace server { class GetEntityByIDRequest : public BaseRequest { public: static BaseRequestPtr - Create(const std::shared_ptr& context, std::string collection_name, - const engine::IDNumbers& id_array, const std::vector& field_names_, + Create(const std::shared_ptr& context, const std::string& collection_name, + const engine::IDNumbers& id_array, std::vector& field_names_, engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk); protected: - GetEntityByIDRequest(const std::shared_ptr& context, std::string collection_name, - const engine::IDNumbers& id_array, const std::vector& field_names, + GetEntityByIDRequest(const std::shared_ptr& context, const std::string& collection_name, + const engine::IDNumbers& id_array, std::vector& field_names, engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk); Status @@ -47,9 +47,9 @@ class GetEntityByIDRequest : public BaseRequest { private: std::string collection_name_; engine::IDNumbers id_array_; - std::vector field_names_; - engine::snapshot::CollectionMappings field_mappings_; - engine::DataChunkPtr data_chunk_; + std::vector& field_names_; + engine::snapshot::CollectionMappings& field_mappings_; + engine::DataChunkPtr& data_chunk_; }; } // namespace server diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 3ada636afd..8d3f03fabd 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -85,6 +85,7 @@ RequestMap(BaseRequest::RequestType request_type) { {BaseRequest::kSearchByID, "SearchByID"}, {BaseRequest::kHybridSearch, "HybridSearch"}, {BaseRequest::kFlush, "Flush"}, + {BaseRequest::kGetEntityByID, "GetEntityByID"}, {BaseRequest::kCompact, "Compact"}, }; @@ -756,33 +757,80 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus Status status = request_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids, field_names, field_mappings, data_chunk); - std::vector id_array = data_chunk->fixed_fields_[engine::DEFAULT_UID_NAME]; - for (const auto& it : field_mappings) { + auto type = it.first->GetFtype(); std::string name = it.first->GetName(); - uint64_t type = it.first->GetFtype(); std::vector data = data_chunk->fixed_fields_[name]; - if (type == engine::FieldType::VECTOR_BINARY) { - engine::VectorsData vectors_data; - memcpy(vectors_data.binary_data_.data(), data.data(), data.size()); - memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size()); - vectors.emplace_back(vectors_data); - } else if (type == engine::FieldType::VECTOR_FLOAT) { - engine::VectorsData vectors_data; - memcpy(vectors_data.float_data_.data(), data.data(), data.size()); - memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size()); - vectors.emplace_back(vectors_data); + + if (type == engine::meta::hybrid::DataType::UID) { + response->mutable_ids()->Resize(data.size(), 0); + memcpy(response->mutable_ids()->mutable_data(), data.data(), data.size() * sizeof(uint64_t)); + continue; + } + + auto field_value = response->add_fields(); + auto vector_record = field_value->mutable_vector_record(); + + field_value->set_field_name(name); + field_value->set_type(static_cast(type)); + // general data + if (type == engine::meta::hybrid::DataType::VECTOR_BINARY) { + // add binary vector data + auto vector_row_record = vector_record->add_records(); + + std::vector binary_vector; + binary_vector.resize(data.size()); + memcpy(binary_vector.data(), data.data(), data.size()); + vector_row_record->mutable_binary_data()->resize(binary_vector.size()); + memcpy(vector_row_record->mutable_binary_data()->data(), binary_vector.data(), binary_vector.size()); + + continue; + } else if (type == engine::meta::hybrid::DataType::VECTOR_FLOAT) { + // add float vector data + auto vector_row_record = vector_record->add_records(); + std::vector float_vector; + float_vector.resize(data.size() * sizeof(int8_t) / sizeof(float)); + memcpy(float_vector.data(), data.data(), data.size()); + vector_row_record->mutable_float_data()->Resize(float_vector.size(), 0.0); + memcpy(vector_row_record->mutable_float_data()->mutable_data(), float_vector.data(), + float_vector.size() * sizeof(float)); + + continue; } else { - engine::AttrsData attrs_data; - attrs_data.attr_type_[name] = static_cast(type); - attrs_data.attr_data_[name] = data; - memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size()); - attrs.emplace_back(attrs_data); + // add attribute data + auto attr_record = field_value->mutable_attr_record(); + if (type == engine::meta::hybrid::DataType::INT32) { + // add int32 data + std::vector int32_value; + int32_value.resize(data.size() * sizeof(int8_t) / sizeof(int32_t)); + memcpy(int32_value.data(), data.data(), data.size()); + attr_record->mutable_int32_value()->Resize(int32_value.size(), 0); + memcpy(attr_record->mutable_int32_value()->mutable_data(), int32_value.data(), int32_value.size()); + } else if (type == engine::meta::hybrid::DataType::INT64) { + // add int64 data + std::vector int64_value; + int64_value.resize(data.size() * sizeof(int8_t) / sizeof(int64_t)); + memcpy(int64_value.data(), data.data(), data.size()); + attr_record->mutable_int64_value()->Resize(int64_value.size(), 0); + memcpy(attr_record->mutable_int64_value()->mutable_data(), int64_value.data(), int64_value.size()); + } else if (type == engine::meta::hybrid::DataType::DOUBLE) { + // add double data + std::vector double_value; + double_value.resize(data.size() * sizeof(int8_t) / sizeof(double)); + memcpy(double_value.data(), data.data(), data.size()); + attr_record->mutable_double_value()->Resize(double_value.size(), 0.0); + memcpy(attr_record->mutable_double_value()->mutable_data(), double_value.data(), double_value.size()); + } else if (type == engine::meta::hybrid::DataType::FLOAT) { + // add float data + std::vector float_value; + float_value.resize(data.size() * sizeof(int8_t) / sizeof(float)); + memcpy(float_value.data(), data.data(), data.size()); + attr_record->mutable_float_value()->Resize(float_value.size(), 0.0); + memcpy(attr_record->mutable_float_value()->mutable_data(), float_value.data(), float_value.size()); + } } } - ConstructEntityResults(attrs, vectors, field_names, response); - LOG_SERVER_INFO_ << LogOut("Request [%s] %s end.", GetContext(context)->RequestID().c_str(), __func__); SET_RESPONSE(response->mutable_status(), status, context);