diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 9045c5cebf..65825c11b0 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -1698,17 +1698,17 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery( if (vector_param_it != it.value().end()) { const std::string& field_name = vector_param_it.key(); vector_query->field_name = field_name; - nlohmann::json vector_json = vector_param_it.value(); - int64_t topk = vector_json["topk"]; + nlohmann::json param_json = vector_param_it.value(); + int64_t topk = param_json["topk"]; status = server::ValidateSearchTopk(topk); if (!status.ok()) { return status; } vector_query->topk = topk; - if (vector_json.contains("metric_type")) { - std::string metric_type = vector_json["metric_type"]; + if (param_json.contains("metric_type")) { + std::string metric_type = param_json["metric_type"]; vector_query->metric_type = metric_type; - query_ptr->metric_types.insert({field_name, vector_json["metric_type"]}); + query_ptr->metric_types.insert({field_name, param_json["metric_type"]}); } if (!vector_param_it.value()["params"].empty()) { vector_query->extra_params = vector_param_it.value()["params"]; diff --git a/core/src/server/web_impl/Types.h b/core/src/server/web_impl/Types.h index 60e31ebec6..4666ff53e3 100644 --- a/core/src/server/web_impl/Types.h +++ b/core/src/server/web_impl/Types.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include @@ -72,6 +73,13 @@ enum StatusCode : int { MAX = ILLEGAL_QUERY_PARAM }; +static std::map str2type = {{"int32", engine::DataType::INT32}, + {"int64", engine::DataType::INT64}, + {"float", engine::DataType::FLOAT}, + {"double", engine::DataType::DOUBLE}, + {"vector_float", engine::DataType::VECTOR_FLOAT}, + {"vector_binary", engine::DataType::VECTOR_BINARY}}; + } // namespace web } // namespace server } // namespace milvus diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp index 96cdb11b81..09f140f8f7 100644 --- a/core/src/server/web_impl/controller/WebController.hpp +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -215,74 +215,71 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(CreateCollection) - ENDPOINT("POST", "/collections", CreateCollection, BODY_DTO(CollectionRequestDto::ObjectWrapper, body)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'"); - // tr.RecordSection("Received request."); - // - // WebRequestHandler handler = WebRequestHandler(); - // - // std::shared_ptr response; - // auto status_dto = handler.CreateCollection(body); - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createDtoResponse(Status::CODE_201, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); + ENDPOINT("POST", "/collections", CreateCollection, BODY_STRING(String, body_str)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'"); + tr.RecordSection("Received request."); + + WebRequestHandler handler = WebRequestHandler(); + + std::shared_ptr response; + auto status_dto = handler.CreateCollection(body_str); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); - StatusDto::ObjectWrapper status; - auto response = createDtoResponse(Status::CODE_200, status); return response; } ADD_CORS(ShowCollections) ENDPOINT("GET", "/collections", ShowCollections, QUERIES(const QueryParams&, query_params)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'"); - // tr.RecordSection("Received request."); - // - // WebRequestHandler handler = WebRequestHandler(); - // - // String result; - // auto status_dto = handler.ShowCollections(query_params, result); - // std::shared_ptr response; - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createResponse(Status::CODE_200, result); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'"); + tr.RecordSection("Received request."); - json result_json = R"({ - "collections": [ - { - "collection_name": "test_collection", - "fields": [ - { - "field_name": "field_vec", - "field_type": "VECTOR_FLOAT", - "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096}, - "extra_params": {"dimension": 128, "metric_type": "L2"} - } - ], - "segment_size": 1024 - } - ], - "count": 58 - })"; + WebRequestHandler handler = WebRequestHandler(); - String result = result_json.dump().c_str(); - auto response = createResponse(Status::CODE_200, result); + String result; + auto status_dto = handler.ShowCollections(query_params, result); + std::shared_ptr response; + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createResponse(Status::CODE_200, result); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); + + // json result_json = R"({ + // "collections": [ + // { + // "collection_name": "test_collection", + // "fields": [ + // { + // "field_name": "field_vec", + // "field_type": "VECTOR_FLOAT", + // "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096}, + // "extra_params": {"dimension": 128, "metric_type": "L2"} + // } + // ], + // "segment_size": 1024 + // } + // ], + // "count": 58 + // })"; + + response = createResponse(Status::CODE_200, result); return response; } @@ -296,74 +293,71 @@ class WebController : public oatpp::web::server::api::ApiController { ENDPOINT("GET", "/collections/{collection_name}", GetCollection, PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + - // "\'"); tr.RecordSection("Received request."); - // - // WebRequestHandler handler = WebRequestHandler(); - // - // String response_str; - // auto status_dto = handler.GetCollection(collection_name, query_params, response_str); - // - // std::shared_ptr response; - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createResponse(Status::CODE_200, response_str); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'"); + tr.RecordSection("Received request."); - json result_json = R"({ - "collection_name": "test_collection", - "fields": [ - { - "field_name": "field_vec", - "field_type": "VECTOR_FLOAT", - "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096}, - "extra_params": {"dimension": 128, "metric_type": "L2"} - } - ], - "row_count": 10000 - })"; + WebRequestHandler handler = WebRequestHandler(); + + String response_str; + auto status_dto = handler.GetCollection(collection_name, query_params, response_str); + + std::shared_ptr response; + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createResponse(Status::CODE_200, response_str); + break; + case StatusCode::COLLECTION_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); + + // json result_json = R"({ + // "collection_name": "test_collection", + // "fields": [ + // { + // "field_name": "field_vec", + // "field_type": "VECTOR_FLOAT", + // "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096}, + // "extra_params": {"dimension": 128, "metric_type": "L2"} + // } + // ], + // "row_count": 10000 + // })"; - auto response = createResponse(Status::CODE_200, result_json.dump().c_str()); return response; } ADD_CORS(DropCollection) ENDPOINT("DELETE", "/collections/{collection_name}", DropCollection, PATH(String, collection_name)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + - // "\'"); tr.RecordSection("Received request."); - // - // WebRequestHandler handler = WebRequestHandler(); - // - // std::shared_ptr response; - // auto status_dto = handler.DropCollection(collection_name); - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createDtoResponse(Status::CODE_204, status_dto); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "\'"); + tr.RecordSection("Received request."); + + WebRequestHandler handler = WebRequestHandler(); + + std::shared_ptr response; + auto status_dto = handler.DropCollection(collection_name); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_204, status_dto); + break; + case StatusCode::COLLECTION_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); - StatusDto::ObjectWrapper status; - auto response = createDtoResponse(Status::CODE_201, status); return response; } @@ -378,97 +372,90 @@ class WebController : public oatpp::web::server::api::ApiController { ENDPOINT("POST", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", CreateIndex, PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name), BODY_STRING(String, body)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + - // "/indexes\'"); tr.RecordSection("Received request."); - // - // auto handler = WebRequestHandler(); - // - // std::shared_ptr response; - // auto status_dto = handler.CreateIndex(collection_name, body); - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createDtoResponse(Status::CODE_201, status_dto); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'"); + tr.RecordSection("Received request."); + + auto handler = WebRequestHandler(); + + std::shared_ptr response; + auto status_dto = handler.CreateIndex(collection_name, field_name, body); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, status_dto); + break; + case StatusCode::COLLECTION_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); - StatusDto::ObjectWrapper status; - auto response = createDtoResponse(Status::CODE_201, status); return response; } - ADD_CORS(GetIndex) - - ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex, - PATH(String, collection_name), PATH(String, field_name)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + - // "/indexes\'"); - // tr.RecordSection("Received request."); - // - // auto handler = WebRequestHandler(); - // - // OString result; - // auto status_dto = handler.GetIndex(collection_name, result); - // - // std::shared_ptr response; - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createResponse(Status::CODE_200, result); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); - - json result = R"({ "index_name": "FLAT", "params": {"index_type": "IVF_FLAT", "nlist": 4096 } })"; - - auto response = createResponse(Status::CODE_200, result.dump().c_str()); - return response; - } + // ADD_CORS(GetIndex) + // + // ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex, + // PATH(String, collection_name), PATH(String, field_name)) { + // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + + // "/indexes\'"); + // tr.RecordSection("Received request."); + // + // auto handler = WebRequestHandler(); + // + // OString result; + // auto status_dto = handler.GetIndex(collection_name, result); + // + // std::shared_ptr response; + // switch (status_dto->code->getValue()) { + // case StatusCode::SUCCESS: + // response = createResponse(Status::CODE_200, result); + // break; + // case StatusCode::COLLECTION_NOT_EXISTS: + // response = createDtoResponse(Status::CODE_404, status_dto); + // break; + // default: + // response = createDtoResponse(Status::CODE_400, status_dto); + // } + // + // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + // ", reason = " + status_dto->message->std_str() + ". Total cost"; + // tr.ElapseFromBegin(ttr); + // + // return response; + // } ADD_CORS(DropIndex) ENDPOINT("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", DropIndex, PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + - // "/indexes\'"); - // tr.RecordSection("Received request."); - // - // auto handler = WebRequestHandler(); - // - // std::shared_ptr response; - // auto status_dto = handler.DropIndex(collection_name); - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createDtoResponse(Status::CODE_204, status_dto); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"; - // tr.ElapseFromBegin(ttr); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + + "/indexes\'"); + tr.RecordSection("Received request."); + + auto handler = WebRequestHandler(); + + std::shared_ptr response; + auto status_dto = handler.DropIndex(collection_name, field_name); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_204, status_dto); + break; + case StatusCode::COLLECTION_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); - StatusDto::ObjectWrapper status; - auto response = createDtoResponse(Status::CODE_204, status); return response; } @@ -574,23 +561,18 @@ class WebController : public oatpp::web::server::api::ApiController { ENDPOINT("GET", "/collections/{collection_name}/partitions/{partition_tag}/entities", GetEntities, PATH(String, collection_name), PATH(String, partition_tag), QUERIES(const QueryParams&, query_params), BODY_STRING(String, body)) { - json result = R"({ - "entities": [ - { - "__id": "1578989029645098000", - "field_1": 1, - "field_vec": [] - }, - { - "__id": "1578989029645098001", - "field_1": 2, - "field_vec": [] - } - ] - })"; + auto handler = WebRequestHandler(); - auto response = createResponse(Status::CODE_200, result.dump().c_str()); - return response; + String response; + auto status_dto = handler.GetEntity(collection_name, query_params, response); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + return createResponse(Status::CODE_200, response); + case StatusCode::COLLECTION_NOT_EXISTS: + return createDtoResponse(Status::CODE_404, status_dto); + default: + return createDtoResponse(Status::CODE_400, status_dto); + } } ADD_CORS(ShowSegments) @@ -645,75 +627,6 @@ class WebController : public oatpp::web::server::api::ApiController { return createResponse(Status::CODE_204, "No Content"); } - ADD_CORS(GetVectors) - /** - * - * GetVectorByID ?id= - */ - ENDPOINT("GET", "/collections/{collection_name}/Entities", GetVectors, PATH(String, collection_name), - QUERIES(const QueryParams&, query_params)) { - // auto handler = WebRequestHandler(); - // String response; - // auto status_dto = handler.GetVector(collection_name, query_params, response); - // - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // return createResponse(Status::CODE_200, response); - // case StatusCode::COLLECTION_NOT_EXISTS: - // return createDtoResponse(Status::CODE_404, status_dto); - // default: - // return createDtoResponse(Status::CODE_400, status_dto); - // } - json result = R"({ - "entities": [ - { - "__id": "1578989029645098000", - "field_1": 1, - "field_vec": [] - }, - { - "__id": "1578989029645098001", - "field_1": 2, - "field_vec": [] - } - ] - })"; - auto response = createResponse(Status::CODE_200, result.dump().c_str()); - return response; - } - - ADD_CORS(Insert) - - ENDPOINT("POST", "/collections/{collection_name}/entities", Insert, PATH(String, collection_name), - BODY_STRING(String, body)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + - // "/vectors\'"); - // tr.RecordSection("Received request."); - // - // auto ids_dto = VectorIdsDto::createShared(); - // WebRequestHandler handler = WebRequestHandler(); - // - // std::shared_ptr response; - // auto status_dto = handler.Insert(collection_name, body, ids_dto); - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createDtoResponse(Status::CODE_201, ids_dto); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"); - - StatusDto::ObjectWrapper status; - auto response = createDtoResponse(Status::CODE_201, status); - return response; - } - ADD_CORS(InsertEntity) ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name), @@ -756,7 +669,7 @@ class WebController : public oatpp::web::server::api::ApiController { OString result; std::shared_ptr response; - auto status_dto = handler.VectorsOp(collection_name, body, result); + auto status_dto = handler.EntityOp(collection_name, body, result); switch (status_dto->code->getValue()) { case StatusCode::SUCCESS: response = createResponse(Status::CODE_200, result); @@ -774,61 +687,6 @@ class WebController : public oatpp::web::server::api::ApiController { return response; } - ADD_CORS(VectorsOp) - - ENDPOINT("PUT", "/collections/{collection_name}/entities", VectorsOp, PATH(String, collection_name), - BODY_STRING(String, body)) { - // TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + - // "/vectors\'"); - // tr.RecordSection("Received request."); - // - // WebRequestHandler handler = WebRequestHandler(); - // - // OString result; - // std::shared_ptr response; - // auto status_dto = handler.VectorsOp(collection_name, body, result); - // switch (status_dto->code->getValue()) { - // case StatusCode::SUCCESS: - // response = createResponse(Status::CODE_200, result); - // break; - // case StatusCode::COLLECTION_NOT_EXISTS: - // response = createDtoResponse(Status::CODE_404, status_dto); - // break; - // default: - // response = createDtoResponse(Status::CODE_400, status_dto); - // } - // - // tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - // ", reason = " + status_dto->message->std_str() + ". Total cost"); - json result = R"({ - "num": 2, - "results": [ - [ - { - "id": "1578989029645098000", - "distance": "0.000000", - "entity": { - "field_1": 1, - "field_2": 2, - "field_vec": [] - } - }, - { - "id": "1578989029645098001", - "distance": "0.010000", - "entity": { - "field_1": 10, - "field_2": 20, - "field_vec": [] - } - } - ] - ] - })"; - auto response = createResponse(Status::CODE_200, result.dump().c_str()); - return response; - } - ADD_CORS(SystemOptions) ENDPOINT("OPTIONS", "/system/{info}", SystemOptions) { @@ -885,29 +743,6 @@ class WebController : public oatpp::web::server::api::ApiController { return response; } - ADD_CORS(CreateHybridCollection) - - ENDPOINT("POST", "/hybrid_collections", CreateHybridCollection, BODY_STRING(String, body_str)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections\'"); - tr.RecordSection("Received request."); - WebRequestHandler handler = WebRequestHandler(); - - std::shared_ptr response; - auto status_dto = handler.CreateHybridCollection(body_str); - switch (status_dto->code->getValue()) { - case StatusCode::SUCCESS: - response = createDtoResponse(Status::CODE_201, status_dto); - break; - default: - response = createDtoResponse(Status::CODE_400, status_dto); - } - - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + - ", reason = " + status_dto->message->std_str() + ". Total cost"; - tr.ElapseFromBegin(ttr); - return response; - } - /** * Finish ENDPOINTs generation ('ApiController' codegen) */ diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp index 0d626b37f8..2f6dcc3e6b 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.cpp +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,7 @@ #include "db/Utils.h" #include "metrics/SystemInfo.h" #include "query/BinaryQuery.h" +#include "server/ValidationUtil.h" #include "server/delivery/request/BaseReq.h" #include "server/web_impl/Constants.h" #include "server/web_impl/Types.h" @@ -117,29 +119,31 @@ WebRequestHandler::IsBinaryCollection(const std::string& collection_name, bool& } Status -WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::VectorsData& vectors, bool bin) { +WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, std::vector& vectors_data, bool bin) { if (!json.is_array()) { return Status(ILLEGAL_BODY, "field \"vectors\" must be a array"); } - vectors.vector_count_ = json.size(); - + std::vector float_vector; if (!bin) { for (auto& vec : json) { if (!vec.is_array()) { return Status(ILLEGAL_BODY, "A vector in field \"vectors\" must be a float array"); } for (auto& data : vec) { - vectors.float_data_.emplace_back(data.get()); + float_vector.emplace_back(data.get()); } } + auto size = float_vector.size() * sizeof(float); + vectors_data.resize(size); + memcpy(vectors_data.data(), float_vector.data(), size); } else { for (auto& vec : json) { if (!vec.is_array()) { return Status(ILLEGAL_BODY, "A vector in field \"vectors\" must be a float array"); } for (auto& data : vec) { - vectors.binary_data_.emplace_back(data.get()); + vectors_data.emplace_back(data.get()); } } } @@ -147,6 +151,79 @@ WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::Vecto return Status::OK(); } +Status +WebRequestHandler::CopyData2Json(const milvus::engine::DataChunkPtr& data_chunk, + const milvus::engine::snapshot::FieldElementMappings& field_mappings, + const std::vector& id_array, nlohmann::json& json_res) { + int64_t id_size = id_array.size(); + for (int i = 0; i < id_size; i++) { + nlohmann::json one_json; + nlohmann::json entity_json; + for (const auto& it : field_mappings) { + auto type = it.first->GetFtype(); + std::string name = it.first->GetName(); + + engine::BinaryDataPtr data = data_chunk->fixed_fields_[name]; + if (data == nullptr || data->data_.empty()) + continue; + + auto single_size = data->data_.size() / id_size; + + switch (type) { + case engine::DataType::INT32: { + int32_t int32_value; + int64_t offset = sizeof(int32_t) * i; + memcpy(&int32_value, data->data_.data() + offset, sizeof(int32_t)); + entity_json[name] = int32_value; + break; + } + case engine::DataType::INT64: { + int64_t int64_value; + int64_t offset = sizeof(int64_t) * i; + memcpy(&int64_value, data->data_.data() + offset, sizeof(int64_t)); + entity_json[name] = int64_value; + break; + } + case engine::DataType::FLOAT: { + float float_value; + int64_t offset = sizeof(float) * i; + memcpy(&float_value, data->data_.data() + offset, sizeof(float)); + entity_json[name] = float_value; + break; + } + case engine::DataType::DOUBLE: { + double double_value; + int64_t offset = sizeof(double) * i; + memcpy(&double_value, data->data_.data() + offset, sizeof(double)); + entity_json[name] = double_value; + break; + } + case engine::DataType::VECTOR_BINARY: { + std::vector binary_vector; + auto vector_size = single_size * sizeof(int8_t) / sizeof(int8_t); + binary_vector.resize(vector_size); + int64_t offset = vector_size * i; + memcpy(binary_vector.data(), data->data_.data() + offset, vector_size); + entity_json[name] = binary_vector; + break; + } + case engine::DataType::VECTOR_FLOAT: { + std::vector float_vector; + auto vector_size = single_size * sizeof(int8_t) / sizeof(float); + float_vector.resize(vector_size); + int64_t offset = vector_size * i; + memcpy(float_vector.data(), data->data_.data() + offset, vector_size); + entity_json[name] = float_vector; + break; + } + } + } + one_json["entity"] = entity_json; + one_json["id"] = id_array[i]; + json_res.push_back(one_json); + } +} + ///////////////////////// WebRequestHandler methods /////////////////////////////////////// Status WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlohmann::json& json_out) { @@ -157,12 +234,14 @@ WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlo STATUS_CHECK(req_handler_.CountEntities(context_ptr_, collection_name, count)); json_out["collection_name"] = schema.collection_name_; - json_out["dimension"] = schema.extra_params_[engine::PARAM_DIMENSION].get(); - json_out["segment_row_count"] = schema.extra_params_[engine::PARAM_SEGMENT_ROW_COUNT].get(); - json_out["metric_type"] = schema.extra_params_[engine::PARAM_INDEX_METRIC_TYPE].get(); - json_out["index_params"] = schema.extra_params_[engine::PARAM_INDEX_EXTRA_PARAMS].get(); - json_out["count"] = count; - + for (const auto& field : schema.fields_) { + nlohmann::json field_json; + field_json["field_name"] = field.first; + field_json["field_type"] = field.second.field_type_; + field_json["index_params"] = field.second.index_params_; + field_json["extra_params"] = field.second.field_params_; + json_out["field"].push_back(field_json); + } return Status::OK(); } @@ -194,7 +273,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t auto new_ids = std::vector(vector_ids.begin() + ids_begin, vector_ids.begin() + ids_end); nlohmann::json vectors_json; - auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json); + // auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json); nlohmann::json result_json; if (vectors_json.empty()) { @@ -204,7 +283,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t } json_out["count"] = vector_ids.size(); - AddStatusToJson(json_out, status.code(), status.message()); + // AddStatusToJson(json_out, status.code(), status.message()); return Status::OK(); } @@ -406,287 +485,162 @@ WebRequestHandler::SetConfig(const nlohmann::json& json, std::string& result_str } Status -WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query) { +WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query, + std::string& field_name, query::QueryPtr& query_ptr) { + auto status = Status::OK(); if (json.contains("term")) { auto leaf_query = std::make_shared(); - auto term_json = json["term"]; - std::string field_name = term_json["field_name"]; - auto term_value_json = term_json["values"]; - if (!term_value_json.is_array()) { - std::string msg = "Term json string is not an array"; - return Status{BODY_PARSE_FAIL, msg}; + auto term_query = std::make_shared(); + nlohmann::json json_obj = json["term"]; + JSON_NULL_CHECK(json_obj); + JSON_OBJECT_CHECK(json_obj); + term_query->json_obj = json_obj; + nlohmann::json::iterator json_it = json_obj.begin(); + field_name = json_it.key(); + + leaf_query->term_query = term_query; + query->AddLeafQuery(leaf_query); + } else if (json.contains("range")) { + auto leaf_query = std::make_shared(); + auto range_query = std::make_shared(); + nlohmann::json json_obj = json["range"]; + JSON_NULL_CHECK(json_obj); + JSON_OBJECT_CHECK(json_obj); + range_query->json_obj = json_obj; + nlohmann::json::iterator json_it = json_obj.begin(); + field_name = json_it.key(); + + leaf_query->range_query = range_query; + query->AddLeafQuery(leaf_query); + } else if (json.contains("vector")) { + auto leaf_query = std::make_shared(); + auto vector_json = json["vector"]; + JSON_NULL_CHECK(vector_json); + + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_int_distribution dist(0, 64); + int64_t place_number = dist(rng); + std::string placeholder = "placeholder" + std::to_string(place_number); + leaf_query->vector_placeholder = placeholder; + query->AddLeafQuery(leaf_query); + + auto vector_query = std::make_shared(); + json::iterator vector_param_it = vector_json.begin(); + if (vector_param_it != vector_json.end()) { + const std::string& vector_name = vector_param_it.key(); + vector_query->field_name = vector_name; + nlohmann::json param_json = vector_param_it.value(); + int64_t topk = param_json["topk"]; + status = server::ValidateSearchTopk(topk); + if (!status.ok()) { + return status; + } + vector_query->topk = topk; + if (param_json.contains("metric_type")) { + std::string metric_type = param_json["metric_type"]; + vector_query->metric_type = metric_type; + query_ptr->metric_types.insert({vector_name, param_json["metric_type"]}); + } + if (!vector_param_it.value()["params"].empty()) { + vector_query->extra_params = vector_param_it.value()["params"]; + } + engine::VectorsData vector_data; + for (auto& vector_records : vector_param_it.value()["values"]) { + // TODO: Binary vector??? + for (auto& data : vector_records) { + vector_query->query_vector.float_data.emplace_back(data.get()); + } + } + query_ptr->index_fields.insert(vector_name); } - // auto term_size = term_value_json.size(); - // auto term_query = std::make_shared(); - // term_query->field_name = field_name; - // term_query->field_value.resize(term_size * sizeof(int64_t)); - // - // switch (field_type_.at(field_name)) { - // case engine::DataType::INT8: - // case engine::DataType::INT16: - // case engine::DataType::INT32: - // case engine::DataType::INT64: { - // std::vector term_value(term_size, 0); - // for (uint64_t i = 0; i < term_size; ++i) { - // term_value[i] = term_value_json[i].get(); - // } - // memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t)); - // break; - // } - // case engine::DataType::FLOAT: - // case engine::DataType::DOUBLE: { - // std::vector term_value(term_size, 0); - // for (uint64_t i = 0; i < term_size; ++i) { - // term_value[i] = term_value_json[i].get(); - // } - // memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double)); - // break; - // } - // default: - // break; - // } - // - // leaf_query->term_query = term_query; - // query->AddLeafQuery(leaf_query); - // } else if (json.contains("range")) { - // auto leaf_query = std::make_shared(); - // auto range_query = std::make_shared(); - // - // auto range_json = json["range"]; - // std::string field_name = range_json["field_name"]; - // range_query->field_name = field_name; - // - // auto range_value_json = range_json["values"]; - // if (range_value_json.contains("lt")) { - // query::CompareExpr compare_expr; - // compare_expr.compare_operator = query::CompareOperator::LT; - // compare_expr.operand = range_value_json["lt"].get(); - // range_query->compare_expr.emplace_back(compare_expr); - // } - // if (range_value_json.contains("lte")) { - // query::CompareExpr compare_expr; - // compare_expr.compare_operator = query::CompareOperator::LTE; - // compare_expr.operand = range_value_json["lte"].get(); - // range_query->compare_expr.emplace_back(compare_expr); - // } - // if (range_value_json.contains("eq")) { - // query::CompareExpr compare_expr; - // compare_expr.compare_operator = query::CompareOperator::EQ; - // compare_expr.operand = range_value_json["eq"].get(); - // range_query->compare_expr.emplace_back(compare_expr); - // } - // if (range_value_json.contains("ne")) { - // query::CompareExpr compare_expr; - // compare_expr.compare_operator = query::CompareOperator::NE; - // compare_expr.operand = range_value_json["ne"].get(); - // range_query->compare_expr.emplace_back(compare_expr); - // } - // if (range_value_json.contains("gt")) { - // query::CompareExpr compare_expr; - // compare_expr.compare_operator = query::CompareOperator::GT; - // compare_expr.operand = range_value_json["gt"].get(); - // range_query->compare_expr.emplace_back(compare_expr); - // } - // if (range_value_json.contains("gte")) { - // query::CompareExpr compare_expr; - // compare_expr.compare_operator = query::CompareOperator::GTE; - // compare_expr.operand = range_value_json["gte"].get(); - // range_query->compare_expr.emplace_back(compare_expr); - // } - // - // leaf_query->range_query = range_query; - // query->AddLeafQuery(leaf_query); - // } else if (json.contains("vector")) { - // auto leaf_query = std::make_shared(); - // auto vector_query = std::make_shared(); - // - // auto vector_json = json["vector"]; - // std::string field_name = vector_json["field_name"]; - // vector_query->field_name = field_name; - // - // engine::VectorsData vectors; - // // TODO(yukun): process binary vector - // CopyRecordsFromJson(vector_json["values"], vectors, false); - // - // vector_query->query_vector.float_data = vectors.float_data_; - // vector_query->query_vector.binary_data = vectors.binary_data_; - // - // vector_query->topk = vector_json["topk"].get(); - // vector_query->extra_params = vector_json["extra_params"]; - // - // // TODO(yukun): remove hardcode here - // std::string vector_placeholder = "placeholder_1"; - // query_ptr_->vectors.insert(std::make_pair(vector_placeholder, vector_query)); - // leaf_query->vector_placeholder = vector_placeholder; - // query->AddLeafQuery(leaf_query); + query_ptr->vectors.insert(std::make_pair(placeholder, vector_query)); + + } else { + return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"}; } - return Status::OK(); + return status; } Status -WebRequestHandler::ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query) { - if (query_json.contains("must")) { - boolean_query->SetOccur(query::Occur::MUST); - auto must_json = query_json["must"]; - if (!must_json.is_array()) { - std::string msg = "Must json string is not an array"; - return Status{BODY_PARSE_FAIL, msg}; - } - - for (auto& json : must_json) { - auto must_query = std::make_shared(); - if (json.contains("must") || json.contains("should") || json.contains("must_not")) { - ProcessBoolQueryJson(json, must_query); - boolean_query->AddBooleanQuery(must_query); - } else { - ProcessLeafQueryJson(json, boolean_query); +WebRequestHandler::ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query, + query::QueryPtr& query_ptr) { + auto status = Status::OK(); + if (query_json.empty()) { + return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"}; + } + for (auto& el : query_json.items()) { + if (el.key() == "must") { + boolean_query->SetOccur(query::Occur::MUST); + auto must_json = el.value(); + if (!must_json.is_array()) { + std::string msg = "Must json string is not an array"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; } - } - return Status::OK(); - } else if (query_json.contains("should")) { - boolean_query->SetOccur(query::Occur::SHOULD); - auto should_json = query_json["should"]; - if (!should_json.is_array()) { - std::string msg = "Should json string is not an array"; - return Status{BODY_PARSE_FAIL, msg}; - } - for (auto& json : should_json) { - if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + for (auto& json : must_json) { + auto must_query = std::make_shared(); + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + STATUS_CHECK(ProcessBooleanQueryJson(json, must_query, query_ptr)); + boolean_query->AddBooleanQuery(must_query); + } else { + std::string field_name; + STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr)); + if (!field_name.empty()) { + query_ptr->index_fields.insert(field_name); + } + } + } + } else if (el.key() == "should") { + boolean_query->SetOccur(query::Occur::SHOULD); + auto should_json = el.value(); + if (!should_json.is_array()) { + std::string msg = "Should json string is not an array"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; + } + + for (auto& json : should_json) { auto should_query = std::make_shared(); - ProcessBoolQueryJson(json, should_query); - boolean_query->AddBooleanQuery(should_query); - } else { - ProcessLeafQueryJson(json, boolean_query); + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + STATUS_CHECK(ProcessBooleanQueryJson(json, should_query, query_ptr)); + boolean_query->AddBooleanQuery(should_query); + } else { + std::string field_name; + STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr)); + if (!field_name.empty()) { + query_ptr->index_fields.insert(field_name); + } + } + } + } else if (el.key() == "must_not") { + boolean_query->SetOccur(query::Occur::MUST_NOT); + auto should_json = el.value(); + if (!should_json.is_array()) { + std::string msg = "Must_not json string is not an array"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; } - } - return Status::OK(); - } else if (query_json.contains("must_not")) { - boolean_query->SetOccur(query::Occur::MUST_NOT); - auto should_json = query_json["must_not"]; - if (!should_json.is_array()) { - std::string msg = "Must_not json string is not an array"; - return Status{BODY_PARSE_FAIL, msg}; - } - for (auto& json : should_json) { - if (json.contains("must") || json.contains("should") || json.contains("must_not")) { - auto must_not_query = std::make_shared(); - ProcessBoolQueryJson(json, must_not_query); - boolean_query->AddBooleanQuery(must_not_query); - } else { - ProcessLeafQueryJson(json, boolean_query); - } - } - return Status::OK(); - } else { - std::string msg = "Must json string doesnot include right query"; - return Status{BODY_PARSE_FAIL, msg}; - } -} - -void -ConvertRowToColumnJson(const std::vector& row_attrs, const std::vector& field_names, - const int64_t row_num, nlohmann::json& column_attrs_json) { - // if (field_names.size() == 0) { - // if (row_attrs.size() > 0) { - // auto attr_it = row_attrs[0].attr_type_.begin(); - // for (; attr_it != row_attrs[0].attr_type_.end(); attr_it++) { - // field_names.emplace_back(attr_it->first); - // } - // } - // } - - for (uint64_t i = 0; i < field_names.size() - 1; i++) { - std::vector int_data; - std::vector double_data; - for (auto& attr : row_attrs) { - int64_t int_value; - double double_value; - auto attr_data = attr.attr_data_.at(field_names[i]); - switch (attr.attr_type_.at(field_names[i])) { - case engine::DataType::INT8: { - if (attr_data.size() == sizeof(int8_t)) { - int_value = attr_data[0]; - int_data.emplace_back(int_value); - } - break; - } - case engine::DataType::INT16: { - if (attr_data.size() == sizeof(int16_t)) { - memcpy(&int_value, attr_data.data(), sizeof(int16_t)); - int_data.emplace_back(int_value); - } - break; - } - case engine::DataType::INT32: { - if (attr_data.size() == sizeof(int32_t)) { - memcpy(&int_value, attr_data.data(), sizeof(int32_t)); - int_data.emplace_back(int_value); - } - break; - } - case engine::DataType::INT64: { - if (attr_data.size() == sizeof(int64_t)) { - memcpy(&int_value, attr_data.data(), sizeof(int64_t)); - int_data.emplace_back(int_value); - } - break; - } - case engine::DataType::FLOAT: { - if (attr_data.size() == sizeof(float)) { - float float_value; - memcpy(&float_value, attr_data.data(), sizeof(float)); - double_value = float_value; - double_data.emplace_back(double_value); - } - break; - } - case engine::DataType::DOUBLE: { - if (attr_data.size() == sizeof(double)) { - memcpy(&double_value, attr_data.data(), sizeof(double)); - double_data.emplace_back(double_value); - } - break; - } - default: { return; } - } - } - if (int_data.size() > 0) { - if (row_num == -1) { - nlohmann::json int_data_json(int_data); - column_attrs_json[field_names[i]] = int_data_json; - } else { - nlohmann::json topk_int_result; - int64_t topk = int_data.size() / row_num; - for (int64_t j = 0; j < row_num; j++) { - std::vector one_int_result(topk); - memcpy(one_int_result.data(), int_data.data() + j * topk, sizeof(int64_t) * topk); - nlohmann::json one_int_result_json(one_int_result); - std::string tag = "top" + std::to_string(j); - topk_int_result[tag] = one_int_result_json; - } - column_attrs_json[field_names[i]] = topk_int_result; - } - } else if (double_data.size() > 0) { - if (row_num == -1) { - nlohmann::json double_data_json(double_data); - column_attrs_json[field_names[i]] = double_data_json; - } else { - nlohmann::json topk_double_result; - int64_t topk = int_data.size() / row_num; - for (int64_t j = 0; j < row_num; j++) { - std::vector one_double_result(topk); - memcpy(one_double_result.data(), double_data.data() + j * topk, sizeof(double) * topk); - nlohmann::json one_double_result_json(one_double_result); - std::string tag = "top" + std::to_string(j); - topk_double_result[tag] = one_double_result_json; - } - column_attrs_json[field_names[i]] = topk_double_result; + for (auto& json : should_json) { + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + auto must_not_query = std::make_shared(); + STATUS_CHECK(ProcessBooleanQueryJson(json, must_not_query, query_ptr)); + boolean_query->AddBooleanQuery(must_not_query); + } else { + std::string field_name; + STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr)); + if (!field_name.empty()) { + query_ptr->index_fields.insert(field_name); + } + } } + } else { + std::string msg = "BoolQuery json string does not include bool query"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; } } + + return status; } Status @@ -724,7 +678,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js auto boolean_query = std::make_shared(); query_ptr_ = std::make_shared(); - status = ProcessBoolQueryJson(boolean_query_json, boolean_query); + status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr_); if (!status.ok()) { return status; } @@ -749,22 +703,75 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js return Status::OK(); } - auto step = result->result_ids_.size() / result->row_num_; - nlohmann::json search_result_json; + auto step = result->result_ids_.size() / result->row_num_; // topk + auto field_data = result->data_chunk_->fixed_fields_; for (int64_t i = 0; i < result->row_num_; i++) { nlohmann::json raw_result_json; for (size_t j = 0; j < step; j++) { nlohmann::json one_result_json; one_result_json["id"] = std::to_string(result->result_ids_.at(i * step + j)); one_result_json["distance"] = std::to_string(result->result_distances_.at(i * step + j)); - raw_result_json.emplace_back(one_result_json); + nlohmann::json one_entity_json; + for (const auto& field : field_mappings) { + auto field_name = field.first->GetName(); + switch ((int64_t)field.first->GetFtype()) { + case engine::DataType::INT32: { + int32_t int32_value; + int64_t offset = (i * step + j) * sizeof(int32_t); + memcpy(&int32_value, field_data.at(field_name)->data_.data() + offset, sizeof(int32_t)); + one_entity_json[field_name] = int32_value; + break; + } + case engine::DataType::INT64: { + int64_t int64_value; + int64_t offset = (i * step + j) * sizeof(int64_t); + memcpy(&int64_value, field_data.at(field_name)->data_.data() + offset, sizeof(int64_t)); + one_entity_json[field_name] = int64_value; + break; + } + case engine::DataType::FLOAT: { + float float_value; + int64_t offset = (i * step + j) * sizeof(float); + memcpy(&float_value, field_data.at(field_name)->data_.data() + offset, sizeof(float)); + one_entity_json[field_name] = float_value; + break; + } + case engine::DataType::DOUBLE: { + double double_value; + int64_t offset = (i * step + j) * sizeof(double); + memcpy(&double_value, field_data.at(field_name)->data_.data() + offset, sizeof(double)); + one_entity_json[field_name] = double_value; + break; + } + case engine::DataType::VECTOR_FLOAT: { + std::vector float_vector; + auto dim = + field_data.at(field_name)->data_.size() / (result->result_ids_.size() * sizeof(float)); + int64_t offset = (i * step + j) * dim * sizeof(float); + float_vector.resize(dim); + memcpy(float_vector.data(), field_data.at(field_name)->data_.data() + offset, + dim * sizeof(float)); + one_entity_json[field_name] = float_vector; + break; + } + case engine::DataType::VECTOR_BINARY: { + std::vector binary_vector; + auto dim = field_data.at(field_name)->data_.size() / (result->result_ids_.size()); + int64_t offset = (i * step + j) * dim; + binary_vector.resize(dim); + memcpy(binary_vector.data(), field_data.at(field_name)->data_.data() + offset, + dim * sizeof(int8_t)); + one_entity_json[field_name] = binary_vector; + break; + } + default: { return Status(SERVER_UNEXPECTED_ERROR, "Return field data type is wrong"); } + } + } + one_result_json["entity"] = one_entity_json; + raw_result_json.push_back(one_result_json); } - search_result_json.emplace_back(raw_result_json); + result_json.emplace_back(raw_result_json); } - nlohmann::json attr_json; - // ConvertRowToColumnJson(result->attrs_, query_ptr_->field_names, result->row_num_, attr_json); - result_json["Entity"] = attr_json; - result_json["result"] = search_result_json; result_str = result_json.dump(); } @@ -774,7 +781,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js Status WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str) { - std::vector vector_ids; + std::vector entity_ids; if (!json.contains("ids")) { return Status(BODY_FIELD_LOSS, "Field \"delete\" must contains \"ids\""); } @@ -788,10 +795,10 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman if (!ValidateStringIsNumber(id_str).ok()) { return Status(ILLEGAL_BODY, "Members in \"ids\" must be integer string"); } - vector_ids.emplace_back(std::stol(id_str)); + entity_ids.emplace_back(std::stol(id_str)); } - auto status = req_handler_.DeleteEntityByID(context_ptr_, collection_name, vector_ids); + auto status = req_handler_.DeleteEntityByID(context_ptr_, collection_name, entity_ids); nlohmann::json result_json; AddStatusToJson(result_json, status.code(), status.message()); @@ -807,89 +814,23 @@ WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std: engine::DataChunkPtr data_chunk; engine::snapshot::FieldElementMappings field_mappings; - std::vector attr_batch; - std::vector vector_batch; auto status = req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, valid_row, field_mappings, data_chunk); if (!status.ok()) { return status; } - std::vector id_array = data_chunk->fixed_fields_[engine::FIELD_UID]->data_; - for (const auto& it : field_mappings) { - std::string name = it.first->GetName(); - uint64_t type = it.first->GetFtype(); - std::vector& data = data_chunk->fixed_fields_[name]->data_; - if (type == engine::DataType::VECTOR_BINARY) { - engine::VectorsData vectors_data; - memcpy(vectors_data.binary_data_.data(), data.data(), data.size()); - memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size()); - vector_batch.emplace_back(vectors_data); - } else if (type == engine::DataType::VECTOR_FLOAT) { - engine::VectorsData vectors_data; - memcpy(vectors_data.float_data_.data(), data.data(), data.size()); - memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size()); - vector_batch.emplace_back(vectors_data); - } else { - engine::AttrsData attrs_data; - attrs_data.attr_type_[name] = static_cast(type); - attrs_data.attr_data_[name] = data; - memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size()); - attr_batch.emplace_back(attrs_data); + int64_t valid_size = 0; + for (auto row : valid_row) { + if (row) { + valid_size++; } } - bool bin; - status = IsBinaryCollection(collection_name, bin); - if (!status.ok()) { - return status; - } - - nlohmann::json vectors_json, attrs_json; - for (size_t i = 0; i < vector_batch.size(); i++) { - nlohmann::json vector_json; - if (bin) { - vector_json["vector"] = vector_batch.at(i).binary_data_; - } else { - vector_json["vector"] = vector_batch.at(i).float_data_; - } - vector_json["id"] = std::to_string(ids[i]); - vectors_json.push_back(vector_json); - } - ConvertRowToColumnJson(attr_batch, field_names, -1, attrs_json); - json_out["vectors"] = vectors_json; - json_out["attributes"] = attrs_json; - return Status::OK(); -} - -Status -WebRequestHandler::GetVectorsByIDs(const std::string& collection_name, const std::vector& ids, - nlohmann::json& json_out) { - std::vector vector_batch; - auto status = Status::OK(); - // auto status = req_handler_.GetVectorsByID(context_ptr_, collection_name, ids, vector_batch); - if (!status.ok()) { - return status; - } - - bool bin; - status = IsBinaryCollection(collection_name, bin); - if (!status.ok()) { - return status; - } - - nlohmann::json vectors_json; - for (size_t i = 0; i < vector_batch.size(); i++) { - nlohmann::json vector_json; - if (bin) { - vector_json["vector"] = vector_batch.at(i).binary_data_; - } else { - vector_json["vector"] = vector_batch.at(i).float_data_; - } - vector_json["id"] = std::to_string(ids[i]); - json_out.push_back(vector_json); - } - + std::vector id_data = data_chunk->fixed_fields_[engine::FIELD_UID]->data_; + std::vector id_array(valid_size); + memcpy(id_array.data(), id_data.data(), valid_size * sizeof(int64_t)); + CopyData2Json(data_chunk, field_mappings, id_array, json_out); return Status::OK(); } @@ -1169,34 +1110,7 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt * Collection { */ StatusDto::ObjectWrapper -WebRequestHandler::CreateCollection(const CollectionRequestDto::ObjectWrapper& collection_schema) { - if (nullptr == collection_schema->collection_name.get()) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'collection_name\' is missing") - } - - if (nullptr == collection_schema->dimension.get()) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'dimension\' is missing") - } - - if (nullptr == collection_schema->index_file_size.get()) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_file_size\' is missing") - } - - if (nullptr == collection_schema->metric_type.get()) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'metric_type\' is missing") - } - - auto status = Status::OK(); - // auto status = req_handler_.CreateCollection( - // context_ptr_, collection_schema->collection_name->std_str(), collection_schema->dimension, - // collection_schema->index_file_size, - // static_cast(MetricNameMap.at(collection_schema->metric_type->std_str()))); - - ASSIGN_RETURN_STATUS_DTO(status) -} - -StatusDto::ObjectWrapper -WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& body) { +WebRequestHandler::CreateCollection(const milvus::server::web::OString& body) { auto json_str = nlohmann::json::parse(body->c_str()); std::string collection_name = json_str["collection_name"]; @@ -1208,24 +1122,14 @@ WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& bo field_schema.field_params_ = field["extra_params"]; - const std::string& field_type = field["field_type"]; - if (field_type == "int8") { - field_schema.field_type_ = engine::DataType::INT8; - } else if (field_type == "int16") { - field_schema.field_type_ = engine::DataType::INT16; - } else if (field_type == "int32") { - field_schema.field_type_ = engine::DataType::INT32; - } else if (field_type == "int64") { - field_schema.field_type_ = engine::DataType::INT64; - } else if (field_type == "float") { - field_schema.field_type_ = engine::DataType::FLOAT; - } else if (field_type == "double") { - field_schema.field_type_ = engine::DataType::DOUBLE; - } else if (field_type == "vector") { - } else { + std::string field_type = field["field_type"]; + std::transform(field_type.begin(), field_type.end(), field_type.begin(), ::tolower); + + if (str2type.find(field_type) == str2type.end()) { std::string msg = field_name + " has wrong field_type"; RETURN_STATUS_DTO(BODY_PARSE_FAIL, msg.c_str()); } + field_schema.field_type_ = str2type.at(field_type); fields[field_name] = field_schema; } @@ -1336,18 +1240,15 @@ WebRequestHandler::DropCollection(const OString& collection_name) { */ StatusDto::ObjectWrapper -WebRequestHandler::CreateIndex(const OString& collection_name, const OString& body) { +WebRequestHandler::CreateIndex(const OString& collection_name, const OString& field_name, const OString& body) { try { auto request_json = nlohmann::json::parse(body->std_str()); - std::string field_name, index_name; if (!request_json.contains("index_type")) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required"); } - auto status = Status::OK(); - // auto status = - // req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), index, - // request_json["params"]); + auto status = + req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), field_name->std_str(), "", request_json); ASSIGN_RETURN_STATUS_DTO(status); } catch (nlohmann::detail::parse_error& e) { RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what()) @@ -1359,10 +1260,8 @@ WebRequestHandler::CreateIndex(const OString& collection_name, const OString& bo } StatusDto::ObjectWrapper -WebRequestHandler::DropIndex(const OString& collection_name) { - auto status = Status::OK(); - // auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str()); - +WebRequestHandler::DropIndex(const OString& collection_name, const OString& field_name) { + auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str(), field_name->std_str(), ""); ASSIGN_RETURN_STATUS_DTO(status) } @@ -1583,9 +1482,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se std::string partition_name = body_json["partition_tag"]; int32_t row_num = body_json["row_num"]; + CollectionSchema collection_schema; std::unordered_map field_types; - auto status = Status::OK(); - // auto status = req_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types); + auto status = req_handler_.GetCollectionInfo(context_ptr_, collection_name->std_str(), collection_schema); + for (const auto& field : collection_schema.fields_) { + field_types.insert({field.first, field.second.field_type_}); + } auto entities = body_json["entity"]; if (!entities.is_array()) { @@ -1621,15 +1523,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se break; } case engine::DataType::VECTOR_FLOAT: { - bool bin_flag; - status = IsBinaryCollection(collection_name->c_str(), bin_flag); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } - - // engine::VectorsData vectors; - // CopyRecordsFromJson(field_value, vectors, bin_flag); - // vector_datas.insert(std::make_pair(field_name, vectors)); + CopyRecordsFromJson(field_value, temp_data, false); + break; + } + case engine::DataType::VECTOR_BINARY: { + CopyRecordsFromJson(field_value, temp_data, true); + break; } default: {} } @@ -1702,47 +1601,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name } StatusDto::ObjectWrapper -WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) { - auto status = Status::OK(); - try { - auto query_ids = query_params.get("ids"); - if (query_ids == nullptr || query_ids.get() == nullptr) { - RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param ids is required."); - } - - std::vector ids; - StringHelpFunctions::SplitStringByDelimeter(query_ids->c_str(), ",", ids); - - std::vector vector_ids; - for (auto& id : ids) { - vector_ids.push_back(std::stol(id)); - } - engine::VectorsData vectors; - nlohmann::json vectors_json; - status = GetVectorsByIDs(collection_name->std_str(), vector_ids, vectors_json); - if (!status.ok()) { - response = "NULL"; - ASSIGN_RETURN_STATUS_DTO(status) - } - - FloatJson json; - json["code"] = (int64_t)status.code(); - json["message"] = status.message(); - if (vectors_json.empty()) { - json["vectors"] = std::vector(); - } else { - json["vectors"] = vectors_json; - } - response = json.dump().c_str(); - } catch (std::exception& e) { - RETURN_STATUS_DTO(SERVER_UNEXPECTED_ERROR, e.what()); - } - - ASSIGN_RETURN_STATUS_DTO(status); -} - -StatusDto::ObjectWrapper -WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payload, OString& response) { +WebRequestHandler::EntityOp(const OString& collection_name, const OString& payload, OString& response) { auto status = Status::OK(); std::string result_str; diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h index 9cd198e153..6e9bea5a22 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.h +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -85,7 +85,11 @@ class WebRequestHandler { IsBinaryCollection(const std::string& collection_name, bool& bin); Status - CopyRecordsFromJson(const nlohmann::json& json, engine::VectorsData& vectors, bool bin); + CopyRecordsFromJson(const nlohmann::json& json, std::vector& vectors_data, bool bin); + + Status + CopyData2Json(const engine::DataChunkPtr& data_chunk, const engine::snapshot::FieldElementMappings& field_mappings, + const std::vector& id_array, nlohmann::json& json_res); protected: Status @@ -124,10 +128,12 @@ class WebRequestHandler { SetConfig(const nlohmann::json& json, std::string& result_str); Status - ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query); + ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query, std::string& field_name, + query::QueryPtr& query_ptr); Status - ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query); + ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query, + query::QueryPtr& query_ptr); Status Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str); @@ -135,9 +141,6 @@ class WebRequestHandler { Status DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str); - Status - GetVectorsByIDs(const std::string& collection_name, const std::vector& ids, nlohmann::json& json_out); - Status GetEntityByIDs(const std::string& collection_name, const std::vector& ids, std::vector& field_names, nlohmann::json& json_out); @@ -167,12 +170,10 @@ class WebRequestHandler { #endif StatusDto::ObjectWrapper - CreateCollection(const CollectionRequestDto::ObjectWrapper& table_schema); - StatusDto::ObjectWrapper - ShowCollections(const OQueryParams& query_params, OString& result); + CreateCollection(const milvus::server::web::OString& body); StatusDto::ObjectWrapper - CreateHybridCollection(const OString& body); + ShowCollections(const OQueryParams& query_params, OString& result); StatusDto::ObjectWrapper GetCollection(const OString& collection_name, const OQueryParams& query_params, OString& result); @@ -181,10 +182,10 @@ class WebRequestHandler { DropCollection(const OString& collection_name); StatusDto::ObjectWrapper - CreateIndex(const OString& collection_name, const OString& body); + CreateIndex(const OString& collection_name, const OString& field_name, const OString& body); StatusDto::ObjectWrapper - DropIndex(const OString& collection_name); + DropIndex(const OString& collection_name, const OString& field_name); StatusDto::ObjectWrapper CreatePartition(const OString& collection_name, const PartitionRequestDto::ObjectWrapper& param); @@ -221,7 +222,7 @@ class WebRequestHandler { GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response); StatusDto::ObjectWrapper - VectorsOp(const OString& collection_name, const OString& payload, OString& response); + EntityOp(const OString& collection_name, const OString& payload, OString& response); /** * diff --git a/sdk/examples/simple/src/ClientTest.cpp b/sdk/examples/simple/src/ClientTest.cpp index 65d3a04dc2..0436816de2 100644 --- a/sdk/examples/simple/src/ClientTest.cpp +++ b/sdk/examples/simple/src/ClientTest.cpp @@ -27,7 +27,7 @@ const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str(); constexpr int64_t COLLECTION_DIMENSION = 512; constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 1024; constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2; -constexpr int64_t BATCH_ENTITY_COUNT = 4000; +constexpr int64_t BATCH_ENTITY_COUNT = 10000; constexpr int64_t NQ = 5; constexpr int64_t TOP_K = 10; constexpr int64_t NPROBE = 32;