mirror of https://github.com/milvus-io/milvus.git
Improvement in GetEntityById grpc interface (#3019)
* Fix bug in GetEntityById Signed-off-by: godchen0212 <qingxiang.chen@zilliz.com> * format code and delete useless code Signed-off-by: godchen0212 <qingxiang.chen@zilliz.com>pull/3022/head
parent
10799ce1f2
commit
3d9a81068c
|
@ -16,6 +16,7 @@
|
|||
// under the License.
|
||||
|
||||
#include "server/delivery/hybrid_request/GetEntityByIDRequest.h"
|
||||
#include "db/meta/MetaTypes.h"
|
||||
#include "server/DBWrapper.h"
|
||||
#include "server/ValidationUtil.h"
|
||||
#include "utils/Log.h"
|
||||
|
@ -30,8 +31,8 @@ namespace server {
|
|||
constexpr uint64_t MAX_COUNT_RETURNED = 1000;
|
||||
|
||||
GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context,
|
||||
const std::string collection_name, const engine::IDNumbers& id_array,
|
||||
const std::vector<std::string>& field_names,
|
||||
const std::string& collection_name, const engine::IDNumbers& id_array,
|
||||
std::vector<std::string>& field_names,
|
||||
engine::snapshot::CollectionMappings& field_mappings,
|
||||
engine::DataChunkPtr& data_chunk)
|
||||
: BaseRequest(context, BaseRequest::kGetVectorByID),
|
||||
|
@ -43,8 +44,9 @@ GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr<milvus::server:
|
|||
}
|
||||
|
||||
BaseRequestPtr
|
||||
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
|
||||
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names_,
|
||||
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context,
|
||||
const std::string& collection_name, const engine::IDNumbers& id_array,
|
||||
std::vector<std::string>& field_names_,
|
||||
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) {
|
||||
return std::shared_ptr<BaseRequest>(
|
||||
new GetEntityByIDRequest(context, collection_name, id_array, field_names_, field_mappings, data_chunk));
|
||||
|
@ -81,19 +83,17 @@ GetEntityByIDRequest::OnExecute() {
|
|||
}
|
||||
|
||||
if (field_names_.empty()) {
|
||||
for (const auto& schema : field_mappings_)
|
||||
for (const auto& it : schema.second) {
|
||||
field_names_.emplace_back(it->GetName());
|
||||
}
|
||||
for (const auto& schema : field_mappings_) {
|
||||
if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID)
|
||||
field_names_.emplace_back(schema.first->GetName());
|
||||
}
|
||||
} else {
|
||||
for (const auto& name : field_names_) {
|
||||
bool find_field_name = false;
|
||||
for (const auto& schema : field_mappings_) {
|
||||
for (const auto& it : schema.second) {
|
||||
if (name == it->GetName()) {
|
||||
find_field_name = true;
|
||||
break;
|
||||
}
|
||||
if (name == schema.first->GetName()) {
|
||||
find_field_name = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (not find_field_name) {
|
||||
|
@ -107,6 +107,7 @@ GetEntityByIDRequest::OnExecute() {
|
|||
if (!status.ok()) {
|
||||
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
|
||||
}
|
||||
return Status::OK();
|
||||
} catch (std::exception& ex) {
|
||||
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
|
||||
}
|
||||
|
|
|
@ -32,13 +32,13 @@ namespace server {
|
|||
class GetEntityByIDRequest : public BaseRequest {
|
||||
public:
|
||||
static BaseRequestPtr
|
||||
Create(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
|
||||
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names_,
|
||||
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
|
||||
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_,
|
||||
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
|
||||
|
||||
protected:
|
||||
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
|
||||
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names,
|
||||
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
|
||||
const engine::IDNumbers& id_array, std::vector<std::string>& field_names,
|
||||
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
|
||||
|
||||
Status
|
||||
|
@ -47,9 +47,9 @@ class GetEntityByIDRequest : public BaseRequest {
|
|||
private:
|
||||
std::string collection_name_;
|
||||
engine::IDNumbers id_array_;
|
||||
std::vector<std::string> field_names_;
|
||||
engine::snapshot::CollectionMappings field_mappings_;
|
||||
engine::DataChunkPtr data_chunk_;
|
||||
std::vector<std::string>& field_names_;
|
||||
engine::snapshot::CollectionMappings& field_mappings_;
|
||||
engine::DataChunkPtr& data_chunk_;
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
|
|
|
@ -85,6 +85,7 @@ RequestMap(BaseRequest::RequestType request_type) {
|
|||
{BaseRequest::kSearchByID, "SearchByID"},
|
||||
{BaseRequest::kHybridSearch, "HybridSearch"},
|
||||
{BaseRequest::kFlush, "Flush"},
|
||||
{BaseRequest::kGetEntityByID, "GetEntityByID"},
|
||||
{BaseRequest::kCompact, "Compact"},
|
||||
};
|
||||
|
||||
|
@ -756,33 +757,80 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
|
|||
Status status = request_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids,
|
||||
field_names, field_mappings, data_chunk);
|
||||
|
||||
std::vector<uint8_t> id_array = data_chunk->fixed_fields_[engine::DEFAULT_UID_NAME];
|
||||
|
||||
for (const auto& it : field_mappings) {
|
||||
auto type = it.first->GetFtype();
|
||||
std::string name = it.first->GetName();
|
||||
uint64_t type = it.first->GetFtype();
|
||||
std::vector<uint8_t> data = data_chunk->fixed_fields_[name];
|
||||
if (type == engine::FieldType::VECTOR_BINARY) {
|
||||
engine::VectorsData vectors_data;
|
||||
memcpy(vectors_data.binary_data_.data(), data.data(), data.size());
|
||||
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
|
||||
vectors.emplace_back(vectors_data);
|
||||
} else if (type == engine::FieldType::VECTOR_FLOAT) {
|
||||
engine::VectorsData vectors_data;
|
||||
memcpy(vectors_data.float_data_.data(), data.data(), data.size());
|
||||
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
|
||||
vectors.emplace_back(vectors_data);
|
||||
|
||||
if (type == engine::meta::hybrid::DataType::UID) {
|
||||
response->mutable_ids()->Resize(data.size(), 0);
|
||||
memcpy(response->mutable_ids()->mutable_data(), data.data(), data.size() * sizeof(uint64_t));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto field_value = response->add_fields();
|
||||
auto vector_record = field_value->mutable_vector_record();
|
||||
|
||||
field_value->set_field_name(name);
|
||||
field_value->set_type(static_cast<milvus::grpc::DataType>(type));
|
||||
// general data
|
||||
if (type == engine::meta::hybrid::DataType::VECTOR_BINARY) {
|
||||
// add binary vector data
|
||||
auto vector_row_record = vector_record->add_records();
|
||||
|
||||
std::vector<int8_t> binary_vector;
|
||||
binary_vector.resize(data.size());
|
||||
memcpy(binary_vector.data(), data.data(), data.size());
|
||||
vector_row_record->mutable_binary_data()->resize(binary_vector.size());
|
||||
memcpy(vector_row_record->mutable_binary_data()->data(), binary_vector.data(), binary_vector.size());
|
||||
|
||||
continue;
|
||||
} else if (type == engine::meta::hybrid::DataType::VECTOR_FLOAT) {
|
||||
// add float vector data
|
||||
auto vector_row_record = vector_record->add_records();
|
||||
std::vector<float> float_vector;
|
||||
float_vector.resize(data.size() * sizeof(int8_t) / sizeof(float));
|
||||
memcpy(float_vector.data(), data.data(), data.size());
|
||||
vector_row_record->mutable_float_data()->Resize(float_vector.size(), 0.0);
|
||||
memcpy(vector_row_record->mutable_float_data()->mutable_data(), float_vector.data(),
|
||||
float_vector.size() * sizeof(float));
|
||||
|
||||
continue;
|
||||
} else {
|
||||
engine::AttrsData attrs_data;
|
||||
attrs_data.attr_type_[name] = static_cast<engine::meta::hybrid::DataType>(type);
|
||||
attrs_data.attr_data_[name] = data;
|
||||
memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size());
|
||||
attrs.emplace_back(attrs_data);
|
||||
// add attribute data
|
||||
auto attr_record = field_value->mutable_attr_record();
|
||||
if (type == engine::meta::hybrid::DataType::INT32) {
|
||||
// add int32 data
|
||||
std::vector<int32_t> int32_value;
|
||||
int32_value.resize(data.size() * sizeof(int8_t) / sizeof(int32_t));
|
||||
memcpy(int32_value.data(), data.data(), data.size());
|
||||
attr_record->mutable_int32_value()->Resize(int32_value.size(), 0);
|
||||
memcpy(attr_record->mutable_int32_value()->mutable_data(), int32_value.data(), int32_value.size());
|
||||
} else if (type == engine::meta::hybrid::DataType::INT64) {
|
||||
// add int64 data
|
||||
std::vector<int64_t> int64_value;
|
||||
int64_value.resize(data.size() * sizeof(int8_t) / sizeof(int64_t));
|
||||
memcpy(int64_value.data(), data.data(), data.size());
|
||||
attr_record->mutable_int64_value()->Resize(int64_value.size(), 0);
|
||||
memcpy(attr_record->mutable_int64_value()->mutable_data(), int64_value.data(), int64_value.size());
|
||||
} else if (type == engine::meta::hybrid::DataType::DOUBLE) {
|
||||
// add double data
|
||||
std::vector<double> double_value;
|
||||
double_value.resize(data.size() * sizeof(int8_t) / sizeof(double));
|
||||
memcpy(double_value.data(), data.data(), data.size());
|
||||
attr_record->mutable_double_value()->Resize(double_value.size(), 0.0);
|
||||
memcpy(attr_record->mutable_double_value()->mutable_data(), double_value.data(), double_value.size());
|
||||
} else if (type == engine::meta::hybrid::DataType::FLOAT) {
|
||||
// add float data
|
||||
std::vector<float> float_value;
|
||||
float_value.resize(data.size() * sizeof(int8_t) / sizeof(float));
|
||||
memcpy(float_value.data(), data.data(), data.size());
|
||||
attr_record->mutable_float_value()->Resize(float_value.size(), 0.0);
|
||||
memcpy(attr_record->mutable_float_value()->mutable_data(), float_value.data(), float_value.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ConstructEntityResults(attrs, vectors, field_names, response);
|
||||
|
||||
LOG_SERVER_INFO_ << LogOut("Request [%s] %s end.", GetContext(context)->RequestID().c_str(), __func__);
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
|
||||
|
|
Loading…
Reference in New Issue