mirror of https://github.com/milvus-io/milvus.git
parent
6f5be4b54f
commit
4beb05499d
|
@ -1698,17 +1698,17 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
|
|||
if (vector_param_it != it.value().end()) {
|
||||
const std::string& field_name = vector_param_it.key();
|
||||
vector_query->field_name = field_name;
|
||||
nlohmann::json vector_json = vector_param_it.value();
|
||||
int64_t topk = vector_json["topk"];
|
||||
nlohmann::json param_json = vector_param_it.value();
|
||||
int64_t topk = param_json["topk"];
|
||||
status = server::ValidateSearchTopk(topk);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
vector_query->topk = topk;
|
||||
if (vector_json.contains("metric_type")) {
|
||||
std::string metric_type = vector_json["metric_type"];
|
||||
if (param_json.contains("metric_type")) {
|
||||
std::string metric_type = param_json["metric_type"];
|
||||
vector_query->metric_type = metric_type;
|
||||
query_ptr->metric_types.insert({field_name, vector_json["metric_type"]});
|
||||
query_ptr->metric_types.insert({field_name, param_json["metric_type"]});
|
||||
}
|
||||
if (!vector_param_it.value()["params"].empty()) {
|
||||
vector_query->extra_params = vector_param_it.value()["params"];
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
|
@ -72,6 +73,13 @@ enum StatusCode : int {
|
|||
MAX = ILLEGAL_QUERY_PARAM
|
||||
};
|
||||
|
||||
static std::map<std::string, engine::DataType> str2type = {{"int32", engine::DataType::INT32},
|
||||
{"int64", engine::DataType::INT64},
|
||||
{"float", engine::DataType::FLOAT},
|
||||
{"double", engine::DataType::DOUBLE},
|
||||
{"vector_float", engine::DataType::VECTOR_FLOAT},
|
||||
{"vector_binary", engine::DataType::VECTOR_BINARY}};
|
||||
|
||||
} // namespace web
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
|
|
|
@ -215,74 +215,71 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(CreateCollection)
|
||||
|
||||
ENDPOINT("POST", "/collections", CreateCollection, BODY_DTO(CollectionRequestDto::ObjectWrapper, body)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// WebRequestHandler handler = WebRequestHandler();
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// auto status_dto = handler.CreateCollection(body);
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
ENDPOINT("POST", "/collections", CreateCollection, BODY_STRING(String, body_str)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.CreateCollection(body_str);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
StatusDto::ObjectWrapper status;
|
||||
auto response = createDtoResponse(Status::CODE_200, status);
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(ShowCollections)
|
||||
|
||||
ENDPOINT("GET", "/collections", ShowCollections, QUERIES(const QueryParams&, query_params)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// WebRequestHandler handler = WebRequestHandler();
|
||||
//
|
||||
// String result;
|
||||
// auto status_dto = handler.ShowCollections(query_params, result);
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createResponse(Status::CODE_200, result);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
json result_json = R"({
|
||||
"collections": [
|
||||
{
|
||||
"collection_name": "test_collection",
|
||||
"fields": [
|
||||
{
|
||||
"field_name": "field_vec",
|
||||
"field_type": "VECTOR_FLOAT",
|
||||
"index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
|
||||
"extra_params": {"dimension": 128, "metric_type": "L2"}
|
||||
}
|
||||
],
|
||||
"segment_size": 1024
|
||||
}
|
||||
],
|
||||
"count": 58
|
||||
})";
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
String result = result_json.dump().c_str();
|
||||
auto response = createResponse(Status::CODE_200, result);
|
||||
String result;
|
||||
auto status_dto = handler.ShowCollections(query_params, result);
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createResponse(Status::CODE_200, result);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
// json result_json = R"({
|
||||
// "collections": [
|
||||
// {
|
||||
// "collection_name": "test_collection",
|
||||
// "fields": [
|
||||
// {
|
||||
// "field_name": "field_vec",
|
||||
// "field_type": "VECTOR_FLOAT",
|
||||
// "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
|
||||
// "extra_params": {"dimension": 128, "metric_type": "L2"}
|
||||
// }
|
||||
// ],
|
||||
// "segment_size": 1024
|
||||
// }
|
||||
// ],
|
||||
// "count": 58
|
||||
// })";
|
||||
|
||||
response = createResponse(Status::CODE_200, result);
|
||||
return response;
|
||||
}
|
||||
|
||||
|
@ -296,74 +293,71 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ENDPOINT("GET", "/collections/{collection_name}", GetCollection, PATH(String, collection_name),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
|
||||
// "\'"); tr.RecordSection("Received request.");
|
||||
//
|
||||
// WebRequestHandler handler = WebRequestHandler();
|
||||
//
|
||||
// String response_str;
|
||||
// auto status_dto = handler.GetCollection(collection_name, query_params, response_str);
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createResponse(Status::CODE_200, response_str);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
json result_json = R"({
|
||||
"collection_name": "test_collection",
|
||||
"fields": [
|
||||
{
|
||||
"field_name": "field_vec",
|
||||
"field_type": "VECTOR_FLOAT",
|
||||
"index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
|
||||
"extra_params": {"dimension": 128, "metric_type": "L2"}
|
||||
}
|
||||
],
|
||||
"row_count": 10000
|
||||
})";
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
String response_str;
|
||||
auto status_dto = handler.GetCollection(collection_name, query_params, response_str);
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createResponse(Status::CODE_200, response_str);
|
||||
break;
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
// json result_json = R"({
|
||||
// "collection_name": "test_collection",
|
||||
// "fields": [
|
||||
// {
|
||||
// "field_name": "field_vec",
|
||||
// "field_type": "VECTOR_FLOAT",
|
||||
// "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
|
||||
// "extra_params": {"dimension": 128, "metric_type": "L2"}
|
||||
// }
|
||||
// ],
|
||||
// "row_count": 10000
|
||||
// })";
|
||||
|
||||
auto response = createResponse(Status::CODE_200, result_json.dump().c_str());
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(DropCollection)
|
||||
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}", DropCollection, PATH(String, collection_name)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
|
||||
// "\'"); tr.RecordSection("Received request.");
|
||||
//
|
||||
// WebRequestHandler handler = WebRequestHandler();
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// auto status_dto = handler.DropCollection(collection_name);
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.DropCollection(collection_name);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
break;
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
StatusDto::ObjectWrapper status;
|
||||
auto response = createDtoResponse(Status::CODE_201, status);
|
||||
return response;
|
||||
}
|
||||
|
||||
|
@ -378,97 +372,90 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
ENDPOINT("POST", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", CreateIndex,
|
||||
PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name),
|
||||
BODY_STRING(String, body)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() +
|
||||
// "/indexes\'"); tr.RecordSection("Received request.");
|
||||
//
|
||||
// auto handler = WebRequestHandler();
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// auto status_dto = handler.CreateIndex(collection_name, body);
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.CreateIndex(collection_name, field_name, body);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
break;
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
StatusDto::ObjectWrapper status;
|
||||
auto response = createDtoResponse(Status::CODE_201, status);
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(GetIndex)
|
||||
|
||||
ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex,
|
||||
PATH(String, collection_name), PATH(String, field_name)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
|
||||
// "/indexes\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// auto handler = WebRequestHandler();
|
||||
//
|
||||
// OString result;
|
||||
// auto status_dto = handler.GetIndex(collection_name, result);
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createResponse(Status::CODE_200, result);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
|
||||
json result = R"({ "index_name": "FLAT", "params": {"index_type": "IVF_FLAT", "nlist": 4096 } })";
|
||||
|
||||
auto response = createResponse(Status::CODE_200, result.dump().c_str());
|
||||
return response;
|
||||
}
|
||||
// ADD_CORS(GetIndex)
|
||||
//
|
||||
// ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex,
|
||||
// PATH(String, collection_name), PATH(String, field_name)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
|
||||
// "/indexes\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// auto handler = WebRequestHandler();
|
||||
//
|
||||
// OString result;
|
||||
// auto status_dto = handler.GetIndex(collection_name, result);
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createResponse(Status::CODE_200, result);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
//
|
||||
// return response;
|
||||
// }
|
||||
|
||||
ADD_CORS(DropIndex)
|
||||
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", DropIndex,
|
||||
PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
|
||||
// "/indexes\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// auto handler = WebRequestHandler();
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// auto status_dto = handler.DropIndex(collection_name);
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
// tr.ElapseFromBegin(ttr);
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
|
||||
"/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.DropIndex(collection_name, field_name);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
break;
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
StatusDto::ObjectWrapper status;
|
||||
auto response = createDtoResponse(Status::CODE_204, status);
|
||||
return response;
|
||||
}
|
||||
|
||||
|
@ -574,23 +561,18 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
ENDPOINT("GET", "/collections/{collection_name}/partitions/{partition_tag}/entities", GetEntities,
|
||||
PATH(String, collection_name), PATH(String, partition_tag), QUERIES(const QueryParams&, query_params),
|
||||
BODY_STRING(String, body)) {
|
||||
json result = R"({
|
||||
"entities": [
|
||||
{
|
||||
"__id": "1578989029645098000",
|
||||
"field_1": 1,
|
||||
"field_vec": []
|
||||
},
|
||||
{
|
||||
"__id": "1578989029645098001",
|
||||
"field_1": 2,
|
||||
"field_vec": []
|
||||
}
|
||||
]
|
||||
})";
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
auto response = createResponse(Status::CODE_200, result.dump().c_str());
|
||||
return response;
|
||||
String response;
|
||||
auto status_dto = handler.GetEntity(collection_name, query_params, response);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
return createResponse(Status::CODE_200, response);
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
return createDtoResponse(Status::CODE_404, status_dto);
|
||||
default:
|
||||
return createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
}
|
||||
|
||||
ADD_CORS(ShowSegments)
|
||||
|
@ -645,75 +627,6 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
return createResponse(Status::CODE_204, "No Content");
|
||||
}
|
||||
|
||||
ADD_CORS(GetVectors)
|
||||
/**
|
||||
*
|
||||
* GetVectorByID ?id=
|
||||
*/
|
||||
ENDPOINT("GET", "/collections/{collection_name}/Entities", GetVectors, PATH(String, collection_name),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
// auto handler = WebRequestHandler();
|
||||
// String response;
|
||||
// auto status_dto = handler.GetVector(collection_name, query_params, response);
|
||||
//
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// return createResponse(Status::CODE_200, response);
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// return createDtoResponse(Status::CODE_404, status_dto);
|
||||
// default:
|
||||
// return createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
json result = R"({
|
||||
"entities": [
|
||||
{
|
||||
"__id": "1578989029645098000",
|
||||
"field_1": 1,
|
||||
"field_vec": []
|
||||
},
|
||||
{
|
||||
"__id": "1578989029645098001",
|
||||
"field_1": 2,
|
||||
"field_vec": []
|
||||
}
|
||||
]
|
||||
})";
|
||||
auto response = createResponse(Status::CODE_200, result.dump().c_str());
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(Insert)
|
||||
|
||||
ENDPOINT("POST", "/collections/{collection_name}/entities", Insert, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
|
||||
// "/vectors\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// auto ids_dto = VectorIdsDto::createShared();
|
||||
// WebRequestHandler handler = WebRequestHandler();
|
||||
//
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// auto status_dto = handler.Insert(collection_name, body, ids_dto);
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createDtoResponse(Status::CODE_201, ids_dto);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
StatusDto::ObjectWrapper status;
|
||||
auto response = createDtoResponse(Status::CODE_201, status);
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(InsertEntity)
|
||||
|
||||
ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name),
|
||||
|
@ -756,7 +669,7 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
OString result;
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.VectorsOp(collection_name, body, result);
|
||||
auto status_dto = handler.EntityOp(collection_name, body, result);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createResponse(Status::CODE_200, result);
|
||||
|
@ -774,61 +687,6 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(VectorsOp)
|
||||
|
||||
ENDPOINT("PUT", "/collections/{collection_name}/entities", VectorsOp, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() +
|
||||
// "/vectors\'");
|
||||
// tr.RecordSection("Received request.");
|
||||
//
|
||||
// WebRequestHandler handler = WebRequestHandler();
|
||||
//
|
||||
// OString result;
|
||||
// std::shared_ptr<OutgoingResponse> response;
|
||||
// auto status_dto = handler.VectorsOp(collection_name, body, result);
|
||||
// switch (status_dto->code->getValue()) {
|
||||
// case StatusCode::SUCCESS:
|
||||
// response = createResponse(Status::CODE_200, result);
|
||||
// break;
|
||||
// case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
// response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
// break;
|
||||
// default:
|
||||
// response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
// }
|
||||
//
|
||||
// tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
// ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
json result = R"({
|
||||
"num": 2,
|
||||
"results": [
|
||||
[
|
||||
{
|
||||
"id": "1578989029645098000",
|
||||
"distance": "0.000000",
|
||||
"entity": {
|
||||
"field_1": 1,
|
||||
"field_2": 2,
|
||||
"field_vec": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "1578989029645098001",
|
||||
"distance": "0.010000",
|
||||
"entity": {
|
||||
"field_1": 10,
|
||||
"field_2": 20,
|
||||
"field_vec": []
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
})";
|
||||
auto response = createResponse(Status::CODE_200, result.dump().c_str());
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(SystemOptions)
|
||||
|
||||
ENDPOINT("OPTIONS", "/system/{info}", SystemOptions) {
|
||||
|
@ -885,29 +743,6 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(CreateHybridCollection)
|
||||
|
||||
ENDPOINT("POST", "/hybrid_collections", CreateHybridCollection, BODY_STRING(String, body_str)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections\'");
|
||||
tr.RecordSection("Received request.");
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.CreateHybridCollection(body_str);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finish ENDPOINTs generation ('ApiController' codegen)
|
||||
*/
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <ctime>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
@ -23,6 +24,7 @@
|
|||
#include "db/Utils.h"
|
||||
#include "metrics/SystemInfo.h"
|
||||
#include "query/BinaryQuery.h"
|
||||
#include "server/ValidationUtil.h"
|
||||
#include "server/delivery/request/BaseReq.h"
|
||||
#include "server/web_impl/Constants.h"
|
||||
#include "server/web_impl/Types.h"
|
||||
|
@ -117,29 +119,31 @@ WebRequestHandler::IsBinaryCollection(const std::string& collection_name, bool&
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::VectorsData& vectors, bool bin) {
|
||||
WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, std::vector<uint8_t>& vectors_data, bool bin) {
|
||||
if (!json.is_array()) {
|
||||
return Status(ILLEGAL_BODY, "field \"vectors\" must be a array");
|
||||
}
|
||||
|
||||
vectors.vector_count_ = json.size();
|
||||
|
||||
std::vector<float> float_vector;
|
||||
if (!bin) {
|
||||
for (auto& vec : json) {
|
||||
if (!vec.is_array()) {
|
||||
return Status(ILLEGAL_BODY, "A vector in field \"vectors\" must be a float array");
|
||||
}
|
||||
for (auto& data : vec) {
|
||||
vectors.float_data_.emplace_back(data.get<float>());
|
||||
float_vector.emplace_back(data.get<float>());
|
||||
}
|
||||
}
|
||||
auto size = float_vector.size() * sizeof(float);
|
||||
vectors_data.resize(size);
|
||||
memcpy(vectors_data.data(), float_vector.data(), size);
|
||||
} else {
|
||||
for (auto& vec : json) {
|
||||
if (!vec.is_array()) {
|
||||
return Status(ILLEGAL_BODY, "A vector in field \"vectors\" must be a float array");
|
||||
}
|
||||
for (auto& data : vec) {
|
||||
vectors.binary_data_.emplace_back(data.get<uint8_t>());
|
||||
vectors_data.emplace_back(data.get<uint8_t>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -147,6 +151,79 @@ WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::Vecto
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::CopyData2Json(const milvus::engine::DataChunkPtr& data_chunk,
|
||||
const milvus::engine::snapshot::FieldElementMappings& field_mappings,
|
||||
const std::vector<int64_t>& id_array, nlohmann::json& json_res) {
|
||||
int64_t id_size = id_array.size();
|
||||
for (int i = 0; i < id_size; i++) {
|
||||
nlohmann::json one_json;
|
||||
nlohmann::json entity_json;
|
||||
for (const auto& it : field_mappings) {
|
||||
auto type = it.first->GetFtype();
|
||||
std::string name = it.first->GetName();
|
||||
|
||||
engine::BinaryDataPtr data = data_chunk->fixed_fields_[name];
|
||||
if (data == nullptr || data->data_.empty())
|
||||
continue;
|
||||
|
||||
auto single_size = data->data_.size() / id_size;
|
||||
|
||||
switch (type) {
|
||||
case engine::DataType::INT32: {
|
||||
int32_t int32_value;
|
||||
int64_t offset = sizeof(int32_t) * i;
|
||||
memcpy(&int32_value, data->data_.data() + offset, sizeof(int32_t));
|
||||
entity_json[name] = int32_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::INT64: {
|
||||
int64_t int64_value;
|
||||
int64_t offset = sizeof(int64_t) * i;
|
||||
memcpy(&int64_value, data->data_.data() + offset, sizeof(int64_t));
|
||||
entity_json[name] = int64_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::FLOAT: {
|
||||
float float_value;
|
||||
int64_t offset = sizeof(float) * i;
|
||||
memcpy(&float_value, data->data_.data() + offset, sizeof(float));
|
||||
entity_json[name] = float_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::DOUBLE: {
|
||||
double double_value;
|
||||
int64_t offset = sizeof(double) * i;
|
||||
memcpy(&double_value, data->data_.data() + offset, sizeof(double));
|
||||
entity_json[name] = double_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_BINARY: {
|
||||
std::vector<int8_t> binary_vector;
|
||||
auto vector_size = single_size * sizeof(int8_t) / sizeof(int8_t);
|
||||
binary_vector.resize(vector_size);
|
||||
int64_t offset = vector_size * i;
|
||||
memcpy(binary_vector.data(), data->data_.data() + offset, vector_size);
|
||||
entity_json[name] = binary_vector;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_FLOAT: {
|
||||
std::vector<float> float_vector;
|
||||
auto vector_size = single_size * sizeof(int8_t) / sizeof(float);
|
||||
float_vector.resize(vector_size);
|
||||
int64_t offset = vector_size * i;
|
||||
memcpy(float_vector.data(), data->data_.data() + offset, vector_size);
|
||||
entity_json[name] = float_vector;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
one_json["entity"] = entity_json;
|
||||
one_json["id"] = id_array[i];
|
||||
json_res.push_back(one_json);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
|
||||
Status
|
||||
WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlohmann::json& json_out) {
|
||||
|
@ -157,12 +234,14 @@ WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlo
|
|||
STATUS_CHECK(req_handler_.CountEntities(context_ptr_, collection_name, count));
|
||||
|
||||
json_out["collection_name"] = schema.collection_name_;
|
||||
json_out["dimension"] = schema.extra_params_[engine::PARAM_DIMENSION].get<int64_t>();
|
||||
json_out["segment_row_count"] = schema.extra_params_[engine::PARAM_SEGMENT_ROW_COUNT].get<int64_t>();
|
||||
json_out["metric_type"] = schema.extra_params_[engine::PARAM_INDEX_METRIC_TYPE].get<int64_t>();
|
||||
json_out["index_params"] = schema.extra_params_[engine::PARAM_INDEX_EXTRA_PARAMS].get<std::string>();
|
||||
json_out["count"] = count;
|
||||
|
||||
for (const auto& field : schema.fields_) {
|
||||
nlohmann::json field_json;
|
||||
field_json["field_name"] = field.first;
|
||||
field_json["field_type"] = field.second.field_type_;
|
||||
field_json["index_params"] = field.second.index_params_;
|
||||
field_json["extra_params"] = field.second.field_params_;
|
||||
json_out["field"].push_back(field_json);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -194,7 +273,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t
|
|||
|
||||
auto new_ids = std::vector<int64_t>(vector_ids.begin() + ids_begin, vector_ids.begin() + ids_end);
|
||||
nlohmann::json vectors_json;
|
||||
auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json);
|
||||
// auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json);
|
||||
|
||||
nlohmann::json result_json;
|
||||
if (vectors_json.empty()) {
|
||||
|
@ -204,7 +283,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t
|
|||
}
|
||||
json_out["count"] = vector_ids.size();
|
||||
|
||||
AddStatusToJson(json_out, status.code(), status.message());
|
||||
// AddStatusToJson(json_out, status.code(), status.message());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -406,287 +485,162 @@ WebRequestHandler::SetConfig(const nlohmann::json& json, std::string& result_str
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query) {
|
||||
WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query,
|
||||
std::string& field_name, query::QueryPtr& query_ptr) {
|
||||
auto status = Status::OK();
|
||||
if (json.contains("term")) {
|
||||
auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
auto term_json = json["term"];
|
||||
std::string field_name = term_json["field_name"];
|
||||
auto term_value_json = term_json["values"];
|
||||
if (!term_value_json.is_array()) {
|
||||
std::string msg = "Term json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
auto term_query = std::make_shared<query::TermQuery>();
|
||||
nlohmann::json json_obj = json["term"];
|
||||
JSON_NULL_CHECK(json_obj);
|
||||
JSON_OBJECT_CHECK(json_obj);
|
||||
term_query->json_obj = json_obj;
|
||||
nlohmann::json::iterator json_it = json_obj.begin();
|
||||
field_name = json_it.key();
|
||||
|
||||
leaf_query->term_query = term_query;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
} else if (json.contains("range")) {
|
||||
auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
auto range_query = std::make_shared<query::RangeQuery>();
|
||||
nlohmann::json json_obj = json["range"];
|
||||
JSON_NULL_CHECK(json_obj);
|
||||
JSON_OBJECT_CHECK(json_obj);
|
||||
range_query->json_obj = json_obj;
|
||||
nlohmann::json::iterator json_it = json_obj.begin();
|
||||
field_name = json_it.key();
|
||||
|
||||
leaf_query->range_query = range_query;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
} else if (json.contains("vector")) {
|
||||
auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
auto vector_json = json["vector"];
|
||||
JSON_NULL_CHECK(vector_json);
|
||||
|
||||
std::random_device dev;
|
||||
std::mt19937 rng(dev());
|
||||
std::uniform_int_distribution<std::mt19937::result_type> dist(0, 64);
|
||||
int64_t place_number = dist(rng);
|
||||
std::string placeholder = "placeholder" + std::to_string(place_number);
|
||||
leaf_query->vector_placeholder = placeholder;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
|
||||
auto vector_query = std::make_shared<query::VectorQuery>();
|
||||
json::iterator vector_param_it = vector_json.begin();
|
||||
if (vector_param_it != vector_json.end()) {
|
||||
const std::string& vector_name = vector_param_it.key();
|
||||
vector_query->field_name = vector_name;
|
||||
nlohmann::json param_json = vector_param_it.value();
|
||||
int64_t topk = param_json["topk"];
|
||||
status = server::ValidateSearchTopk(topk);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
vector_query->topk = topk;
|
||||
if (param_json.contains("metric_type")) {
|
||||
std::string metric_type = param_json["metric_type"];
|
||||
vector_query->metric_type = metric_type;
|
||||
query_ptr->metric_types.insert({vector_name, param_json["metric_type"]});
|
||||
}
|
||||
if (!vector_param_it.value()["params"].empty()) {
|
||||
vector_query->extra_params = vector_param_it.value()["params"];
|
||||
}
|
||||
engine::VectorsData vector_data;
|
||||
for (auto& vector_records : vector_param_it.value()["values"]) {
|
||||
// TODO: Binary vector???
|
||||
for (auto& data : vector_records) {
|
||||
vector_query->query_vector.float_data.emplace_back(data.get<float>());
|
||||
}
|
||||
}
|
||||
query_ptr->index_fields.insert(vector_name);
|
||||
}
|
||||
|
||||
// auto term_size = term_value_json.size();
|
||||
// auto term_query = std::make_shared<query::TermQuery>();
|
||||
// term_query->field_name = field_name;
|
||||
// term_query->field_value.resize(term_size * sizeof(int64_t));
|
||||
//
|
||||
// switch (field_type_.at(field_name)) {
|
||||
// case engine::DataType::INT8:
|
||||
// case engine::DataType::INT16:
|
||||
// case engine::DataType::INT32:
|
||||
// case engine::DataType::INT64: {
|
||||
// std::vector<int64_t> term_value(term_size, 0);
|
||||
// for (uint64_t i = 0; i < term_size; ++i) {
|
||||
// term_value[i] = term_value_json[i].get<int64_t>();
|
||||
// }
|
||||
// memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t));
|
||||
// break;
|
||||
// }
|
||||
// case engine::DataType::FLOAT:
|
||||
// case engine::DataType::DOUBLE: {
|
||||
// std::vector<double> term_value(term_size, 0);
|
||||
// for (uint64_t i = 0; i < term_size; ++i) {
|
||||
// term_value[i] = term_value_json[i].get<double>();
|
||||
// }
|
||||
// memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double));
|
||||
// break;
|
||||
// }
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// leaf_query->term_query = term_query;
|
||||
// query->AddLeafQuery(leaf_query);
|
||||
// } else if (json.contains("range")) {
|
||||
// auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
// auto range_query = std::make_shared<query::RangeQuery>();
|
||||
//
|
||||
// auto range_json = json["range"];
|
||||
// std::string field_name = range_json["field_name"];
|
||||
// range_query->field_name = field_name;
|
||||
//
|
||||
// auto range_value_json = range_json["values"];
|
||||
// if (range_value_json.contains("lt")) {
|
||||
// query::CompareExpr compare_expr;
|
||||
// compare_expr.compare_operator = query::CompareOperator::LT;
|
||||
// compare_expr.operand = range_value_json["lt"].get<std::string>();
|
||||
// range_query->compare_expr.emplace_back(compare_expr);
|
||||
// }
|
||||
// if (range_value_json.contains("lte")) {
|
||||
// query::CompareExpr compare_expr;
|
||||
// compare_expr.compare_operator = query::CompareOperator::LTE;
|
||||
// compare_expr.operand = range_value_json["lte"].get<std::string>();
|
||||
// range_query->compare_expr.emplace_back(compare_expr);
|
||||
// }
|
||||
// if (range_value_json.contains("eq")) {
|
||||
// query::CompareExpr compare_expr;
|
||||
// compare_expr.compare_operator = query::CompareOperator::EQ;
|
||||
// compare_expr.operand = range_value_json["eq"].get<std::string>();
|
||||
// range_query->compare_expr.emplace_back(compare_expr);
|
||||
// }
|
||||
// if (range_value_json.contains("ne")) {
|
||||
// query::CompareExpr compare_expr;
|
||||
// compare_expr.compare_operator = query::CompareOperator::NE;
|
||||
// compare_expr.operand = range_value_json["ne"].get<std::string>();
|
||||
// range_query->compare_expr.emplace_back(compare_expr);
|
||||
// }
|
||||
// if (range_value_json.contains("gt")) {
|
||||
// query::CompareExpr compare_expr;
|
||||
// compare_expr.compare_operator = query::CompareOperator::GT;
|
||||
// compare_expr.operand = range_value_json["gt"].get<std::string>();
|
||||
// range_query->compare_expr.emplace_back(compare_expr);
|
||||
// }
|
||||
// if (range_value_json.contains("gte")) {
|
||||
// query::CompareExpr compare_expr;
|
||||
// compare_expr.compare_operator = query::CompareOperator::GTE;
|
||||
// compare_expr.operand = range_value_json["gte"].get<std::string>();
|
||||
// range_query->compare_expr.emplace_back(compare_expr);
|
||||
// }
|
||||
//
|
||||
// leaf_query->range_query = range_query;
|
||||
// query->AddLeafQuery(leaf_query);
|
||||
// } else if (json.contains("vector")) {
|
||||
// auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
// auto vector_query = std::make_shared<query::VectorQuery>();
|
||||
//
|
||||
// auto vector_json = json["vector"];
|
||||
// std::string field_name = vector_json["field_name"];
|
||||
// vector_query->field_name = field_name;
|
||||
//
|
||||
// engine::VectorsData vectors;
|
||||
// // TODO(yukun): process binary vector
|
||||
// CopyRecordsFromJson(vector_json["values"], vectors, false);
|
||||
//
|
||||
// vector_query->query_vector.float_data = vectors.float_data_;
|
||||
// vector_query->query_vector.binary_data = vectors.binary_data_;
|
||||
//
|
||||
// vector_query->topk = vector_json["topk"].get<int64_t>();
|
||||
// vector_query->extra_params = vector_json["extra_params"];
|
||||
//
|
||||
// // TODO(yukun): remove hardcode here
|
||||
// std::string vector_placeholder = "placeholder_1";
|
||||
// query_ptr_->vectors.insert(std::make_pair(vector_placeholder, vector_query));
|
||||
// leaf_query->vector_placeholder = vector_placeholder;
|
||||
// query->AddLeafQuery(leaf_query);
|
||||
query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));
|
||||
|
||||
} else {
|
||||
return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"};
|
||||
}
|
||||
return Status::OK();
|
||||
return status;
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query) {
|
||||
if (query_json.contains("must")) {
|
||||
boolean_query->SetOccur(query::Occur::MUST);
|
||||
auto must_json = query_json["must"];
|
||||
if (!must_json.is_array()) {
|
||||
std::string msg = "Must json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
for (auto& json : must_json) {
|
||||
auto must_query = std::make_shared<query::BooleanQuery>();
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
ProcessBoolQueryJson(json, must_query);
|
||||
boolean_query->AddBooleanQuery(must_query);
|
||||
} else {
|
||||
ProcessLeafQueryJson(json, boolean_query);
|
||||
WebRequestHandler::ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query,
|
||||
query::QueryPtr& query_ptr) {
|
||||
auto status = Status::OK();
|
||||
if (query_json.empty()) {
|
||||
return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"};
|
||||
}
|
||||
for (auto& el : query_json.items()) {
|
||||
if (el.key() == "must") {
|
||||
boolean_query->SetOccur(query::Occur::MUST);
|
||||
auto must_json = el.value();
|
||||
if (!must_json.is_array()) {
|
||||
std::string msg = "Must json string is not an array";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (query_json.contains("should")) {
|
||||
boolean_query->SetOccur(query::Occur::SHOULD);
|
||||
auto should_json = query_json["should"];
|
||||
if (!should_json.is_array()) {
|
||||
std::string msg = "Should json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
for (auto& json : should_json) {
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
for (auto& json : must_json) {
|
||||
auto must_query = std::make_shared<query::BooleanQuery>();
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(json, must_query, query_ptr));
|
||||
boolean_query->AddBooleanQuery(must_query);
|
||||
} else {
|
||||
std::string field_name;
|
||||
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr));
|
||||
if (!field_name.empty()) {
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (el.key() == "should") {
|
||||
boolean_query->SetOccur(query::Occur::SHOULD);
|
||||
auto should_json = el.value();
|
||||
if (!should_json.is_array()) {
|
||||
std::string msg = "Should json string is not an array";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
|
||||
for (auto& json : should_json) {
|
||||
auto should_query = std::make_shared<query::BooleanQuery>();
|
||||
ProcessBoolQueryJson(json, should_query);
|
||||
boolean_query->AddBooleanQuery(should_query);
|
||||
} else {
|
||||
ProcessLeafQueryJson(json, boolean_query);
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(json, should_query, query_ptr));
|
||||
boolean_query->AddBooleanQuery(should_query);
|
||||
} else {
|
||||
std::string field_name;
|
||||
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr));
|
||||
if (!field_name.empty()) {
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (el.key() == "must_not") {
|
||||
boolean_query->SetOccur(query::Occur::MUST_NOT);
|
||||
auto should_json = el.value();
|
||||
if (!should_json.is_array()) {
|
||||
std::string msg = "Must_not json string is not an array";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (query_json.contains("must_not")) {
|
||||
boolean_query->SetOccur(query::Occur::MUST_NOT);
|
||||
auto should_json = query_json["must_not"];
|
||||
if (!should_json.is_array()) {
|
||||
std::string msg = "Must_not json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
for (auto& json : should_json) {
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
auto must_not_query = std::make_shared<query::BooleanQuery>();
|
||||
ProcessBoolQueryJson(json, must_not_query);
|
||||
boolean_query->AddBooleanQuery(must_not_query);
|
||||
} else {
|
||||
ProcessLeafQueryJson(json, boolean_query);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
std::string msg = "Must json string doesnot include right query";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ConvertRowToColumnJson(const std::vector<engine::AttrsData>& row_attrs, const std::vector<std::string>& field_names,
|
||||
const int64_t row_num, nlohmann::json& column_attrs_json) {
|
||||
// if (field_names.size() == 0) {
|
||||
// if (row_attrs.size() > 0) {
|
||||
// auto attr_it = row_attrs[0].attr_type_.begin();
|
||||
// for (; attr_it != row_attrs[0].attr_type_.end(); attr_it++) {
|
||||
// field_names.emplace_back(attr_it->first);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
for (uint64_t i = 0; i < field_names.size() - 1; i++) {
|
||||
std::vector<int64_t> int_data;
|
||||
std::vector<double> double_data;
|
||||
for (auto& attr : row_attrs) {
|
||||
int64_t int_value;
|
||||
double double_value;
|
||||
auto attr_data = attr.attr_data_.at(field_names[i]);
|
||||
switch (attr.attr_type_.at(field_names[i])) {
|
||||
case engine::DataType::INT8: {
|
||||
if (attr_data.size() == sizeof(int8_t)) {
|
||||
int_value = attr_data[0];
|
||||
int_data.emplace_back(int_value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case engine::DataType::INT16: {
|
||||
if (attr_data.size() == sizeof(int16_t)) {
|
||||
memcpy(&int_value, attr_data.data(), sizeof(int16_t));
|
||||
int_data.emplace_back(int_value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case engine::DataType::INT32: {
|
||||
if (attr_data.size() == sizeof(int32_t)) {
|
||||
memcpy(&int_value, attr_data.data(), sizeof(int32_t));
|
||||
int_data.emplace_back(int_value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case engine::DataType::INT64: {
|
||||
if (attr_data.size() == sizeof(int64_t)) {
|
||||
memcpy(&int_value, attr_data.data(), sizeof(int64_t));
|
||||
int_data.emplace_back(int_value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case engine::DataType::FLOAT: {
|
||||
if (attr_data.size() == sizeof(float)) {
|
||||
float float_value;
|
||||
memcpy(&float_value, attr_data.data(), sizeof(float));
|
||||
double_value = float_value;
|
||||
double_data.emplace_back(double_value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case engine::DataType::DOUBLE: {
|
||||
if (attr_data.size() == sizeof(double)) {
|
||||
memcpy(&double_value, attr_data.data(), sizeof(double));
|
||||
double_data.emplace_back(double_value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: { return; }
|
||||
}
|
||||
}
|
||||
if (int_data.size() > 0) {
|
||||
if (row_num == -1) {
|
||||
nlohmann::json int_data_json(int_data);
|
||||
column_attrs_json[field_names[i]] = int_data_json;
|
||||
} else {
|
||||
nlohmann::json topk_int_result;
|
||||
int64_t topk = int_data.size() / row_num;
|
||||
for (int64_t j = 0; j < row_num; j++) {
|
||||
std::vector<int64_t> one_int_result(topk);
|
||||
memcpy(one_int_result.data(), int_data.data() + j * topk, sizeof(int64_t) * topk);
|
||||
nlohmann::json one_int_result_json(one_int_result);
|
||||
std::string tag = "top" + std::to_string(j);
|
||||
topk_int_result[tag] = one_int_result_json;
|
||||
}
|
||||
column_attrs_json[field_names[i]] = topk_int_result;
|
||||
}
|
||||
} else if (double_data.size() > 0) {
|
||||
if (row_num == -1) {
|
||||
nlohmann::json double_data_json(double_data);
|
||||
column_attrs_json[field_names[i]] = double_data_json;
|
||||
} else {
|
||||
nlohmann::json topk_double_result;
|
||||
int64_t topk = int_data.size() / row_num;
|
||||
for (int64_t j = 0; j < row_num; j++) {
|
||||
std::vector<double> one_double_result(topk);
|
||||
memcpy(one_double_result.data(), double_data.data() + j * topk, sizeof(double) * topk);
|
||||
nlohmann::json one_double_result_json(one_double_result);
|
||||
std::string tag = "top" + std::to_string(j);
|
||||
topk_double_result[tag] = one_double_result_json;
|
||||
}
|
||||
column_attrs_json[field_names[i]] = topk_double_result;
|
||||
for (auto& json : should_json) {
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
auto must_not_query = std::make_shared<query::BooleanQuery>();
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(json, must_not_query, query_ptr));
|
||||
boolean_query->AddBooleanQuery(must_not_query);
|
||||
} else {
|
||||
std::string field_name;
|
||||
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr));
|
||||
if (!field_name.empty()) {
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::string msg = "BoolQuery json string does not include bool query";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -724,7 +678,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
|
|||
auto boolean_query = std::make_shared<query::BooleanQuery>();
|
||||
query_ptr_ = std::make_shared<query::Query>();
|
||||
|
||||
status = ProcessBoolQueryJson(boolean_query_json, boolean_query);
|
||||
status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr_);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -749,22 +703,75 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
auto step = result->result_ids_.size() / result->row_num_;
|
||||
nlohmann::json search_result_json;
|
||||
auto step = result->result_ids_.size() / result->row_num_; // topk
|
||||
auto field_data = result->data_chunk_->fixed_fields_;
|
||||
for (int64_t i = 0; i < result->row_num_; i++) {
|
||||
nlohmann::json raw_result_json;
|
||||
for (size_t j = 0; j < step; j++) {
|
||||
nlohmann::json one_result_json;
|
||||
one_result_json["id"] = std::to_string(result->result_ids_.at(i * step + j));
|
||||
one_result_json["distance"] = std::to_string(result->result_distances_.at(i * step + j));
|
||||
raw_result_json.emplace_back(one_result_json);
|
||||
nlohmann::json one_entity_json;
|
||||
for (const auto& field : field_mappings) {
|
||||
auto field_name = field.first->GetName();
|
||||
switch ((int64_t)field.first->GetFtype()) {
|
||||
case engine::DataType::INT32: {
|
||||
int32_t int32_value;
|
||||
int64_t offset = (i * step + j) * sizeof(int32_t);
|
||||
memcpy(&int32_value, field_data.at(field_name)->data_.data() + offset, sizeof(int32_t));
|
||||
one_entity_json[field_name] = int32_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::INT64: {
|
||||
int64_t int64_value;
|
||||
int64_t offset = (i * step + j) * sizeof(int64_t);
|
||||
memcpy(&int64_value, field_data.at(field_name)->data_.data() + offset, sizeof(int64_t));
|
||||
one_entity_json[field_name] = int64_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::FLOAT: {
|
||||
float float_value;
|
||||
int64_t offset = (i * step + j) * sizeof(float);
|
||||
memcpy(&float_value, field_data.at(field_name)->data_.data() + offset, sizeof(float));
|
||||
one_entity_json[field_name] = float_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::DOUBLE: {
|
||||
double double_value;
|
||||
int64_t offset = (i * step + j) * sizeof(double);
|
||||
memcpy(&double_value, field_data.at(field_name)->data_.data() + offset, sizeof(double));
|
||||
one_entity_json[field_name] = double_value;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_FLOAT: {
|
||||
std::vector<float> float_vector;
|
||||
auto dim =
|
||||
field_data.at(field_name)->data_.size() / (result->result_ids_.size() * sizeof(float));
|
||||
int64_t offset = (i * step + j) * dim * sizeof(float);
|
||||
float_vector.resize(dim);
|
||||
memcpy(float_vector.data(), field_data.at(field_name)->data_.data() + offset,
|
||||
dim * sizeof(float));
|
||||
one_entity_json[field_name] = float_vector;
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_BINARY: {
|
||||
std::vector<int8_t> binary_vector;
|
||||
auto dim = field_data.at(field_name)->data_.size() / (result->result_ids_.size());
|
||||
int64_t offset = (i * step + j) * dim;
|
||||
binary_vector.resize(dim);
|
||||
memcpy(binary_vector.data(), field_data.at(field_name)->data_.data() + offset,
|
||||
dim * sizeof(int8_t));
|
||||
one_entity_json[field_name] = binary_vector;
|
||||
break;
|
||||
}
|
||||
default: { return Status(SERVER_UNEXPECTED_ERROR, "Return field data type is wrong"); }
|
||||
}
|
||||
}
|
||||
one_result_json["entity"] = one_entity_json;
|
||||
raw_result_json.push_back(one_result_json);
|
||||
}
|
||||
search_result_json.emplace_back(raw_result_json);
|
||||
result_json.emplace_back(raw_result_json);
|
||||
}
|
||||
nlohmann::json attr_json;
|
||||
// ConvertRowToColumnJson(result->attrs_, query_ptr_->field_names, result->row_num_, attr_json);
|
||||
result_json["Entity"] = attr_json;
|
||||
result_json["result"] = search_result_json;
|
||||
result_str = result_json.dump();
|
||||
}
|
||||
|
||||
|
@ -774,7 +781,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
|
|||
Status
|
||||
WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json,
|
||||
std::string& result_str) {
|
||||
std::vector<int64_t> vector_ids;
|
||||
std::vector<int64_t> entity_ids;
|
||||
if (!json.contains("ids")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"delete\" must contains \"ids\"");
|
||||
}
|
||||
|
@ -788,10 +795,10 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman
|
|||
if (!ValidateStringIsNumber(id_str).ok()) {
|
||||
return Status(ILLEGAL_BODY, "Members in \"ids\" must be integer string");
|
||||
}
|
||||
vector_ids.emplace_back(std::stol(id_str));
|
||||
entity_ids.emplace_back(std::stol(id_str));
|
||||
}
|
||||
|
||||
auto status = req_handler_.DeleteEntityByID(context_ptr_, collection_name, vector_ids);
|
||||
auto status = req_handler_.DeleteEntityByID(context_ptr_, collection_name, entity_ids);
|
||||
|
||||
nlohmann::json result_json;
|
||||
AddStatusToJson(result_json, status.code(), status.message());
|
||||
|
@ -807,89 +814,23 @@ WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std:
|
|||
engine::DataChunkPtr data_chunk;
|
||||
engine::snapshot::FieldElementMappings field_mappings;
|
||||
|
||||
std::vector<engine::AttrsData> attr_batch;
|
||||
std::vector<engine::VectorsData> vector_batch;
|
||||
auto status = req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, valid_row, field_mappings,
|
||||
data_chunk);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
std::vector<uint8_t> id_array = data_chunk->fixed_fields_[engine::FIELD_UID]->data_;
|
||||
|
||||
for (const auto& it : field_mappings) {
|
||||
std::string name = it.first->GetName();
|
||||
uint64_t type = it.first->GetFtype();
|
||||
std::vector<uint8_t>& data = data_chunk->fixed_fields_[name]->data_;
|
||||
if (type == engine::DataType::VECTOR_BINARY) {
|
||||
engine::VectorsData vectors_data;
|
||||
memcpy(vectors_data.binary_data_.data(), data.data(), data.size());
|
||||
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
|
||||
vector_batch.emplace_back(vectors_data);
|
||||
} else if (type == engine::DataType::VECTOR_FLOAT) {
|
||||
engine::VectorsData vectors_data;
|
||||
memcpy(vectors_data.float_data_.data(), data.data(), data.size());
|
||||
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
|
||||
vector_batch.emplace_back(vectors_data);
|
||||
} else {
|
||||
engine::AttrsData attrs_data;
|
||||
attrs_data.attr_type_[name] = static_cast<engine::DataType>(type);
|
||||
attrs_data.attr_data_[name] = data;
|
||||
memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size());
|
||||
attr_batch.emplace_back(attrs_data);
|
||||
int64_t valid_size = 0;
|
||||
for (auto row : valid_row) {
|
||||
if (row) {
|
||||
valid_size++;
|
||||
}
|
||||
}
|
||||
|
||||
bool bin;
|
||||
status = IsBinaryCollection(collection_name, bin);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
nlohmann::json vectors_json, attrs_json;
|
||||
for (size_t i = 0; i < vector_batch.size(); i++) {
|
||||
nlohmann::json vector_json;
|
||||
if (bin) {
|
||||
vector_json["vector"] = vector_batch.at(i).binary_data_;
|
||||
} else {
|
||||
vector_json["vector"] = vector_batch.at(i).float_data_;
|
||||
}
|
||||
vector_json["id"] = std::to_string(ids[i]);
|
||||
vectors_json.push_back(vector_json);
|
||||
}
|
||||
ConvertRowToColumnJson(attr_batch, field_names, -1, attrs_json);
|
||||
json_out["vectors"] = vectors_json;
|
||||
json_out["attributes"] = attrs_json;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
|
||||
nlohmann::json& json_out) {
|
||||
std::vector<engine::VectorsData> vector_batch;
|
||||
auto status = Status::OK();
|
||||
// auto status = req_handler_.GetVectorsByID(context_ptr_, collection_name, ids, vector_batch);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
bool bin;
|
||||
status = IsBinaryCollection(collection_name, bin);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
nlohmann::json vectors_json;
|
||||
for (size_t i = 0; i < vector_batch.size(); i++) {
|
||||
nlohmann::json vector_json;
|
||||
if (bin) {
|
||||
vector_json["vector"] = vector_batch.at(i).binary_data_;
|
||||
} else {
|
||||
vector_json["vector"] = vector_batch.at(i).float_data_;
|
||||
}
|
||||
vector_json["id"] = std::to_string(ids[i]);
|
||||
json_out.push_back(vector_json);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> id_data = data_chunk->fixed_fields_[engine::FIELD_UID]->data_;
|
||||
std::vector<int64_t> id_array(valid_size);
|
||||
memcpy(id_array.data(), id_data.data(), valid_size * sizeof(int64_t));
|
||||
CopyData2Json(data_chunk, field_mappings, id_array, json_out);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1169,34 +1110,7 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
|
|||
* Collection {
|
||||
*/
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::CreateCollection(const CollectionRequestDto::ObjectWrapper& collection_schema) {
|
||||
if (nullptr == collection_schema->collection_name.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'collection_name\' is missing")
|
||||
}
|
||||
|
||||
if (nullptr == collection_schema->dimension.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'dimension\' is missing")
|
||||
}
|
||||
|
||||
if (nullptr == collection_schema->index_file_size.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_file_size\' is missing")
|
||||
}
|
||||
|
||||
if (nullptr == collection_schema->metric_type.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'metric_type\' is missing")
|
||||
}
|
||||
|
||||
auto status = Status::OK();
|
||||
// auto status = req_handler_.CreateCollection(
|
||||
// context_ptr_, collection_schema->collection_name->std_str(), collection_schema->dimension,
|
||||
// collection_schema->index_file_size,
|
||||
// static_cast<int64_t>(MetricNameMap.at(collection_schema->metric_type->std_str())));
|
||||
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& body) {
|
||||
WebRequestHandler::CreateCollection(const milvus::server::web::OString& body) {
|
||||
auto json_str = nlohmann::json::parse(body->c_str());
|
||||
std::string collection_name = json_str["collection_name"];
|
||||
|
||||
|
@ -1208,24 +1122,14 @@ WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& bo
|
|||
|
||||
field_schema.field_params_ = field["extra_params"];
|
||||
|
||||
const std::string& field_type = field["field_type"];
|
||||
if (field_type == "int8") {
|
||||
field_schema.field_type_ = engine::DataType::INT8;
|
||||
} else if (field_type == "int16") {
|
||||
field_schema.field_type_ = engine::DataType::INT16;
|
||||
} else if (field_type == "int32") {
|
||||
field_schema.field_type_ = engine::DataType::INT32;
|
||||
} else if (field_type == "int64") {
|
||||
field_schema.field_type_ = engine::DataType::INT64;
|
||||
} else if (field_type == "float") {
|
||||
field_schema.field_type_ = engine::DataType::FLOAT;
|
||||
} else if (field_type == "double") {
|
||||
field_schema.field_type_ = engine::DataType::DOUBLE;
|
||||
} else if (field_type == "vector") {
|
||||
} else {
|
||||
std::string field_type = field["field_type"];
|
||||
std::transform(field_type.begin(), field_type.end(), field_type.begin(), ::tolower);
|
||||
|
||||
if (str2type.find(field_type) == str2type.end()) {
|
||||
std::string msg = field_name + " has wrong field_type";
|
||||
RETURN_STATUS_DTO(BODY_PARSE_FAIL, msg.c_str());
|
||||
}
|
||||
field_schema.field_type_ = str2type.at(field_type);
|
||||
|
||||
fields[field_name] = field_schema;
|
||||
}
|
||||
|
@ -1336,18 +1240,15 @@ WebRequestHandler::DropCollection(const OString& collection_name) {
|
|||
*/
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::CreateIndex(const OString& collection_name, const OString& body) {
|
||||
WebRequestHandler::CreateIndex(const OString& collection_name, const OString& field_name, const OString& body) {
|
||||
try {
|
||||
auto request_json = nlohmann::json::parse(body->std_str());
|
||||
std::string field_name, index_name;
|
||||
if (!request_json.contains("index_type")) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required");
|
||||
}
|
||||
|
||||
auto status = Status::OK();
|
||||
// auto status =
|
||||
// req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), index,
|
||||
// request_json["params"]);
|
||||
auto status =
|
||||
req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), field_name->std_str(), "", request_json);
|
||||
ASSIGN_RETURN_STATUS_DTO(status);
|
||||
} catch (nlohmann::detail::parse_error& e) {
|
||||
RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what())
|
||||
|
@ -1359,10 +1260,8 @@ WebRequestHandler::CreateIndex(const OString& collection_name, const OString& bo
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::DropIndex(const OString& collection_name) {
|
||||
auto status = Status::OK();
|
||||
// auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str());
|
||||
|
||||
WebRequestHandler::DropIndex(const OString& collection_name, const OString& field_name) {
|
||||
auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str(), field_name->std_str(), "");
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
|
@ -1583,9 +1482,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
|
|||
std::string partition_name = body_json["partition_tag"];
|
||||
int32_t row_num = body_json["row_num"];
|
||||
|
||||
CollectionSchema collection_schema;
|
||||
std::unordered_map<std::string, engine::DataType> field_types;
|
||||
auto status = Status::OK();
|
||||
// auto status = req_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types);
|
||||
auto status = req_handler_.GetCollectionInfo(context_ptr_, collection_name->std_str(), collection_schema);
|
||||
for (const auto& field : collection_schema.fields_) {
|
||||
field_types.insert({field.first, field.second.field_type_});
|
||||
}
|
||||
|
||||
auto entities = body_json["entity"];
|
||||
if (!entities.is_array()) {
|
||||
|
@ -1621,15 +1523,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
|
|||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_FLOAT: {
|
||||
bool bin_flag;
|
||||
status = IsBinaryCollection(collection_name->c_str(), bin_flag);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
// engine::VectorsData vectors;
|
||||
// CopyRecordsFromJson(field_value, vectors, bin_flag);
|
||||
// vector_datas.insert(std::make_pair(field_name, vectors));
|
||||
CopyRecordsFromJson(field_value, temp_data, false);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_BINARY: {
|
||||
CopyRecordsFromJson(field_value, temp_data, true);
|
||||
break;
|
||||
}
|
||||
default: {}
|
||||
}
|
||||
|
@ -1702,47 +1601,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
|
||||
auto status = Status::OK();
|
||||
try {
|
||||
auto query_ids = query_params.get("ids");
|
||||
if (query_ids == nullptr || query_ids.get() == nullptr) {
|
||||
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param ids is required.");
|
||||
}
|
||||
|
||||
std::vector<std::string> ids;
|
||||
StringHelpFunctions::SplitStringByDelimeter(query_ids->c_str(), ",", ids);
|
||||
|
||||
std::vector<int64_t> vector_ids;
|
||||
for (auto& id : ids) {
|
||||
vector_ids.push_back(std::stol(id));
|
||||
}
|
||||
engine::VectorsData vectors;
|
||||
nlohmann::json vectors_json;
|
||||
status = GetVectorsByIDs(collection_name->std_str(), vector_ids, vectors_json);
|
||||
if (!status.ok()) {
|
||||
response = "NULL";
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
FloatJson json;
|
||||
json["code"] = (int64_t)status.code();
|
||||
json["message"] = status.message();
|
||||
if (vectors_json.empty()) {
|
||||
json["vectors"] = std::vector<int64_t>();
|
||||
} else {
|
||||
json["vectors"] = vectors_json;
|
||||
}
|
||||
response = json.dump().c_str();
|
||||
} catch (std::exception& e) {
|
||||
RETURN_STATUS_DTO(SERVER_UNEXPECTED_ERROR, e.what());
|
||||
}
|
||||
|
||||
ASSIGN_RETURN_STATUS_DTO(status);
|
||||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payload, OString& response) {
|
||||
WebRequestHandler::EntityOp(const OString& collection_name, const OString& payload, OString& response) {
|
||||
auto status = Status::OK();
|
||||
std::string result_str;
|
||||
|
||||
|
|
|
@ -85,7 +85,11 @@ class WebRequestHandler {
|
|||
IsBinaryCollection(const std::string& collection_name, bool& bin);
|
||||
|
||||
Status
|
||||
CopyRecordsFromJson(const nlohmann::json& json, engine::VectorsData& vectors, bool bin);
|
||||
CopyRecordsFromJson(const nlohmann::json& json, std::vector<uint8_t>& vectors_data, bool bin);
|
||||
|
||||
Status
|
||||
CopyData2Json(const engine::DataChunkPtr& data_chunk, const engine::snapshot::FieldElementMappings& field_mappings,
|
||||
const std::vector<int64_t>& id_array, nlohmann::json& json_res);
|
||||
|
||||
protected:
|
||||
Status
|
||||
|
@ -124,10 +128,12 @@ class WebRequestHandler {
|
|||
SetConfig(const nlohmann::json& json, std::string& result_str);
|
||||
|
||||
Status
|
||||
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query);
|
||||
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query, std::string& field_name,
|
||||
query::QueryPtr& query_ptr);
|
||||
|
||||
Status
|
||||
ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query);
|
||||
ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query,
|
||||
query::QueryPtr& query_ptr);
|
||||
|
||||
Status
|
||||
Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
|
||||
|
@ -135,9 +141,6 @@ class WebRequestHandler {
|
|||
Status
|
||||
DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
|
||||
|
||||
Status
|
||||
GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids, nlohmann::json& json_out);
|
||||
|
||||
Status
|
||||
GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
|
||||
std::vector<std::string>& field_names, nlohmann::json& json_out);
|
||||
|
@ -167,12 +170,10 @@ class WebRequestHandler {
|
|||
#endif
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
CreateCollection(const CollectionRequestDto::ObjectWrapper& table_schema);
|
||||
StatusDto::ObjectWrapper
|
||||
ShowCollections(const OQueryParams& query_params, OString& result);
|
||||
CreateCollection(const milvus::server::web::OString& body);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
CreateHybridCollection(const OString& body);
|
||||
ShowCollections(const OQueryParams& query_params, OString& result);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
GetCollection(const OString& collection_name, const OQueryParams& query_params, OString& result);
|
||||
|
@ -181,10 +182,10 @@ class WebRequestHandler {
|
|||
DropCollection(const OString& collection_name);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
CreateIndex(const OString& collection_name, const OString& body);
|
||||
CreateIndex(const OString& collection_name, const OString& field_name, const OString& body);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
DropIndex(const OString& collection_name);
|
||||
DropIndex(const OString& collection_name, const OString& field_name);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
CreatePartition(const OString& collection_name, const PartitionRequestDto::ObjectWrapper& param);
|
||||
|
@ -221,7 +222,7 @@ class WebRequestHandler {
|
|||
GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
VectorsOp(const OString& collection_name, const OString& payload, OString& response);
|
||||
EntityOp(const OString& collection_name, const OString& payload, OString& response);
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -27,7 +27,7 @@ const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str();
|
|||
constexpr int64_t COLLECTION_DIMENSION = 512;
|
||||
constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 1024;
|
||||
constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
|
||||
constexpr int64_t BATCH_ENTITY_COUNT = 4000;
|
||||
constexpr int64_t BATCH_ENTITY_COUNT = 10000;
|
||||
constexpr int64_t NQ = 5;
|
||||
constexpr int64_t TOP_K = 10;
|
||||
constexpr int64_t NPROBE = 32;
|
||||
|
|
Loading…
Reference in New Issue