Fix http bug & add binary vectors support (#1073)

* refactoring(create_table done)

* refactoring

* refactor server delivery (insert done)

* refactoring server module (count_table done)

* server refactor done

* cmake pass

* refactor server module done.

* set grpc response status correctly

* format done.

* fix redefine ErrorMap()

* optimize insert reducing ids data copy

* optimize grpc request with reducing data copy

* clang format

* [skip ci] Refactor server module done. update changlog. prepare for PR

* remove explicit and change int32_t to int64_t

* add web server

* [skip ci] add license in web module

* modify header include & comment oatpp environment config

* add port configure & create table in handler

* modify web url

* simple url complation done & add swagger

* make sure web url

* web functionality done. debuging

* add web unittest

* web test pass

* add web server port

* add web server port in template

* update unittest cmake file

* change web server default port to 19121

* rename method in web module & unittest pass

* add search case in unittest for web module

* rename some variables

* fix bug

* unittest pass

* web prepare

* fix cmd bug(check server status)

* update changlog

* add web port validate & default set

* clang-format pass

* add web port test in unittest

* add CORS & redirect root to swagger ui

* add web status

* web table method func cascade test pass

* add config url in web module

* modify thirdparty cmake to avoid building oatpp test

* clang format

* update changlog

* add constants in web module

* reserve Config.cpp

* fix constants reference bug

* replace web server with async module

* modify component to support async

* format

* developing controller & add test clent into unittest

* add web port into demo/server_config

* modify thirdparty cmake to allow build test

* remove  unnecessary comment

* add endpoint info in controller

* finish web test(bug here)

* clang format

* add web test cpp to lint exclusions

* check null field in GetConfig

* add macro RETURN STATUS DTo

* fix cmake conflict

* fix crash when exit server

* remove surplus comments & add http param check

* add uri /docs to direct swagger

* format

* change cmd to system

* add default value & unittest in web module

* add macros to judge if GPU supported

* add macros in unit & add default in index dto & print error message when bind http port fail

* format (fix #788)

* fix cors bug (not completed)

* comment cors

* change web framework to simple api

* comments optimize

* change to simple API

* remove comments in controller.hpp

* remove EP_COMMON_CMAKE_ARGS in oatpp and oatpp-swagger

* add ep cmake args to sqlite

* clang-format

* change a format

* test pass

* change name to

* fix compiler issue(oatpp-swagger depend on oatpp)

* add & in start_server.h

* specify lib location with oatpp and oatpp-swagger

* add comments

* add swagger definition

* [skip ci] change http method options status code

* remove oatpp swagger(fix #970)

* remove comments

* check Start web behavior

* add default to cpu_cache_capacity

* remove swagger component.hpp & /docs url

* remove /docs info

* remove /docs in unittest

* remove space in test rpc

* remove repeate info in CHANGLOG

* change cache_insert_data default value as a constant

* [skip ci] Fix some broken links (#960)

* [skip ci] Fix broken link

* [skip ci] Fix broken link

* [skip ci] Fix broken link

* [skip ci] Fix broken links

* fix issue 373 (#964)

* fix issue 373

* Adjustment format

* Adjustment format

* Adjustment format

* change readme

* #966 update NOTICE.md (#967)

* remove comments

* check Start web behavior

* add default to cpu_cache_capacity

* remove swagger component.hpp & /docs url

* remove /docs info

* remove /docs in unittest

* remove space in test rpc

* remove repeate info in CHANGLOG

* change cache_insert_data default value as a constant

* adjust web port cofig place

* rename web_port variable

* change gpu resources invoke way to cmd()

* set advanced config name add DEFAULT

* change config setting to cmd

* modify ..

* optimize code

* assign TableDto' count default value 0 (fix #995)

* check if table exists when show partitions (fix #1028)

* check table exists when drop partition (fix #1029)

* check if partition name is legal (fix #1022)

* modify status code when partition tag is illegal

* update changlog

* add info to /system url

* add binary index and add bin uri & handler method(not completed)

* optimize http insert and search time(fix #1066) | add binary vectors support(fix #1067)

* fix test partition bug

* fix test bug when check insert records

* add binary vectors test

* add default for offset and page_size

* fix uinttest bug

* [skip ci] remove comments

* optimize web code for PR comments

* add new folder named utils

Co-authored-by: jielinxu <52057195+jielinxu@users.noreply.github.com>
Co-authored-by: JackLCL <53512883+JackLCL@users.noreply.github.com>
Co-authored-by: Cai Yudong <yudong.cai@zilliz.com>
pull/1085/head^2
BossZou 2020-01-18 10:05:49 +08:00 committed by Jin Hai
parent 4cd02b0976
commit 015f0352a6
18 changed files with 906 additions and 468 deletions

View File

@ -11,6 +11,13 @@ Please mark all change in change log and use the issue from GitHub
- \#805 - IVFTest.gpu_seal_test unittest failed
- \#831 - Judge branch error in CommonUtil.cpp
- \#977 - Server crash when create tables concurrently
- \#995 - table count set to 0 if no tables found
- \#1010 - improve error message when offset or page_size is equal 0
- \#1022 - check if partition name is legal
- \#1028 - check if table exists when show partitions
- \#1029 - check if table exists when try to delete partition
- \#1066 - optimize http insert and search speed
- \#1067 - Add binary vectors support in http server
## Feature
- \#216 - Add CLI to get server info

View File

@ -96,12 +96,14 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/handler web_handler_fi
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/component web_conponent_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/controller web_controller_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/dto web_dto_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/utils web_utils_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl web_impl_files)
set(web_server_files
${web_handler_files}
${web_conponent_files}
${web_controller_files}
${web_dto_files}
${web_utils_files}
${web_impl_files}
)

View File

@ -47,8 +47,19 @@ DropPartitionRequest::OnExecute() {
std::string table_name = table_name_;
std::string partition_name = partition_name_;
std::string partition_tag = tag_;
bool exists;
auto status = DBWrapper::DB()->HasTable(table_name, exists);
if (!status.ok()) {
return status;
}
if (!exists) {
return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
}
if (!partition_name.empty()) {
auto status = ValidationUtil::ValidateTableName(partition_name);
status = ValidationUtil::ValidateTableName(partition_name);
if (!status.ok()) {
return status;
}
@ -68,7 +79,7 @@ DropPartitionRequest::OnExecute() {
return DBWrapper::DB()->DropPartition(partition_name);
} else {
auto status = ValidationUtil::ValidateTableName(table_name);
status = ValidationUtil::ValidateTableName(table_name);
if (!status.ok()) {
return status;
}

View File

@ -48,6 +48,16 @@ ShowPartitionsRequest::OnExecute() {
return status;
}
bool exists = false;
status = DBWrapper::DB()->HasTable(table_name_, exists);
if (!status.ok()) {
return status;
}
if (!exists) {
return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
}
std::vector<engine::meta::TableSchema> schema_array;
status = DBWrapper::DB()->ShowPartitions(table_name_, schema_array);
if (!status.ok()) {

View File

@ -1,4 +1,3 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
@ -46,6 +45,9 @@ static const char* NAME_ENGINE_TYPE_IVFPQ = "IVFPQ";
static const char* NAME_METRIC_TYPE_L2 = "L2";
static const char* NAME_METRIC_TYPE_IP = "IP";
static const char* NAME_METRIC_TYPE_HAMMING = "HAMMING";
static const char* NAME_METRIC_TYPE_JACCARD = "JACCARD";
static const char* NAME_METRIC_TYPE_TANIMOTO = "TANIMOTO";
////////////////////////////////////////////////////

View File

@ -21,9 +21,9 @@
#include <unordered_map>
#include <oatpp/core/data/mapping/type/Object.hpp>
#include <oatpp/web/protocol/http/Http.hpp>
#include "db/engine/ExecutionEngine.h"
#include "server/web_impl/Constants.h"
namespace milvus {
@ -31,6 +31,8 @@ namespace server {
namespace web {
using OString = oatpp::data::mapping::type::String;
using OInt8 = oatpp::data::mapping::type::Int8;
using OInt16 = oatpp::data::mapping::type::Int16;
using OInt64 = oatpp::data::mapping::type::Int64;
using OFloat32 = oatpp::data::mapping::type::Float32;
template <class T>
@ -65,10 +67,11 @@ enum StatusCode : int {
ILLEGAL_METRIC_TYPE = 23,
OUT_OF_MEMORY = 24,
// HTTP status code
// HTTP error code
PATH_PARAM_LOSS = 31,
QUERY_PARAM_LOSS = 32,
BODY_FIELD_LOSS = 33,
ILLEGAL_QUERY_PARAM = 36,
};
static const std::unordered_map<engine::EngineType, std::string> IndexMap = {
@ -92,11 +95,17 @@ static const std::unordered_map<std::string, engine::EngineType> IndexNameMap =
static const std::unordered_map<engine::MetricType, std::string> MetricMap = {
{engine::MetricType::L2, NAME_METRIC_TYPE_L2},
{engine::MetricType::IP, NAME_METRIC_TYPE_IP},
{engine::MetricType::HAMMING, NAME_METRIC_TYPE_HAMMING},
{engine::MetricType::JACCARD, NAME_METRIC_TYPE_JACCARD},
{engine::MetricType::TANIMOTO, NAME_METRIC_TYPE_TANIMOTO},
};
static const std::unordered_map<std::string, engine::MetricType> MetricNameMap = {
{NAME_METRIC_TYPE_L2, engine::MetricType::L2},
{NAME_METRIC_TYPE_IP, engine::MetricType::IP},
{NAME_METRIC_TYPE_HAMMING, engine::MetricType::HAMMING},
{NAME_METRIC_TYPE_JACCARD, engine::MetricType::JACCARD},
{NAME_METRIC_TYPE_TANIMOTO, engine::MetricType::TANIMOTO},
};
} // namespace web

View File

@ -22,8 +22,6 @@
#include <string>
#include <thread>
#include <oatpp/network/server/Server.hpp>
#include "server/web_impl/component/AppComponent.hpp"
#include "utils/Status.h"

View File

@ -98,10 +98,12 @@ class WebController : public oatpp::web::server::api::ApiController {
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.GetDevices(devices_dto);
std::shared_ptr<OutgoingResponse> response;
if (0 == status_dto->code->getValue()) {
response = createDtoResponse(Status::CODE_200, devices_dto);
} else {
response = createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, devices_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
@ -127,11 +129,14 @@ class WebController : public oatpp::web::server::api::ApiController {
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.GetAdvancedConfig(config_dto);
std::shared_ptr<OutgoingResponse> response;
if (0 == status_dto->code->getValue()) {
response = createDtoResponse(Status::CODE_200, config_dto);
} else {
response = createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, config_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
@ -152,13 +157,17 @@ class WebController : public oatpp::web::server::api::ApiController {
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.SetAdvancedConfig(body);
std::shared_ptr<OutgoingResponse> response;
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_200, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
auto status_dto = handler.SetAdvancedConfig(body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
#ifdef MILVUS_GPU_VERSION
@ -182,13 +191,18 @@ class WebController : public oatpp::web::server::api::ApiController {
auto gpu_config_dto = GPUConfigDto::createShared();
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.GetGpuConfig(gpu_config_dto);
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_200, gpu_config_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.GetGpuConfig(gpu_config_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, gpu_config_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(SetGPUConfig) {
@ -207,11 +221,15 @@ class WebController : public oatpp::web::server::api::ApiController {
auto status_dto = handler.SetGpuConfig(body);
std::shared_ptr<OutgoingResponse> response;
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_200, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
#endif
@ -227,7 +245,7 @@ class WebController : public oatpp::web::server::api::ApiController {
info->addConsumes<TableRequestDto::ObjectWrapper>("application/json");
info->addResponse<TableFieldsDto::ObjectWrapper>(Status::CODE_201, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_201, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_400, "application/json");
}
@ -237,12 +255,17 @@ class WebController : public oatpp::web::server::api::ApiController {
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreateTable(body);
if (0 != status_dto->code) {
return createDtoResponse(Status::CODE_400, status_dto);
} else {
return createDtoResponse(Status::CODE_201, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(ShowTables) {
@ -257,17 +280,25 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(ShowTables)
ENDPOINT("GET", "/tables", ShowTables, QUERY(Int64, offset, "offset"), QUERY(Int64, page_size, "page_size")) {
ENDPOINT("GET", "/tables", ShowTables, REQUEST(
const std::shared_ptr<IncomingRequest>&, request)) {
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto response_dto = TableListFieldsDto::createShared();
auto status_dto = handler.ShowTables(offset, page_size, response_dto);
auto offset = request->getQueryParameter("offset", "0");
auto page_size = request->getQueryParameter("page_size", "10");
std::shared_ptr<OutgoingResponse> response;
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_200, response_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
auto status_dto = handler.ShowTables(offset, page_size, response_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, response_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ADD_CORS(TableOptions)
@ -296,21 +327,25 @@ class WebController : public oatpp::web::server::api::ApiController {
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto fields_dto = TableFieldsDto::createShared();
auto status_dto = handler.GetTable(table_name, query_params, fields_dto);
auto code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_200, fields_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, fields_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(DropTable) {
info->summary = "Drop table";
info->pathParams.add<String>("table_name");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_204, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_400, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_404, "application/json");
@ -321,15 +356,21 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT("DELETE", "/tables/{table_name}", DropTable, PATH(String, table_name)) {
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropTable(table_name);
auto code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_204, status_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ADD_CORS(IndexOptions)
@ -341,29 +382,34 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT_INFO(CreateIndex) {
info->summary = "Create index";
info->pathParams.add<String>("table_name");
info->addConsumes<IndexRequestDto::ObjectWrapper>("application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_201, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_400, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_404, "application/json");
}
ADD_CORS(CreateIndex)
ENDPOINT("POST",
"/tables/{table_name}/indexes",
CreateIndex,
PATH(String, table_name),
BODY_DTO(IndexRequestDto::ObjectWrapper, body)) {
ENDPOINT("POST", "/tables/{table_name}/indexes", CreateIndex,
PATH(String, table_name), BODY_DTO(IndexRequestDto::ObjectWrapper, body)) {
auto handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.CreateIndex(table_name, body);
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_201, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreateIndex(table_name, body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(GetIndex) {
@ -382,15 +428,21 @@ class WebController : public oatpp::web::server::api::ApiController {
auto index_dto = IndexDto::createShared();
auto handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.GetIndex(table_name, index_dto);
auto code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_200, index_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, index_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(DropIndex) {
@ -408,15 +460,21 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT("DELETE", "/tables/{table_name}/indexes", DropIndex, PATH(String, table_name)) {
auto handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropIndex(table_name);
auto code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_204, status_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ADD_CORS(PartitionsOptions)
@ -434,6 +492,7 @@ class WebController : public oatpp::web::server::api::ApiController {
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_201, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_400, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_404, "application/json");
}
ADD_CORS(CreatePartition)
@ -442,13 +501,21 @@ class WebController : public oatpp::web::server::api::ApiController {
CreatePartition, PATH(String, table_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
auto handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.CreatePartition(table_name, body);
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_201, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreatePartition(table_name, body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(ShowPartitions) {
@ -469,26 +536,29 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(ShowPartitions)
ENDPOINT("GET",
"/tables/{table_name}/partitions",
ShowPartitions,
PATH(String, table_name),
QUERY(Int64, offset, "offset"),
QUERY(Int64, page_size, "page_size")) {
auto status_dto = StatusDto::createShared();
ENDPOINT("GET", "/tables/{table_name}/partitions", ShowPartitions,
PATH(String, table_name), REQUEST(
const std::shared_ptr<IncomingRequest>&, request)) {
auto offset = request->getQueryParameter("offset", "0");
auto page_size = request->getQueryParameter("page_size", "10");
auto partition_list_dto = PartitionListDto::createShared();
auto handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
status_dto = handler.ShowPartitions(offset, page_size, table_name, partition_list_dto);
int64_t code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_200, partition_list_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.ShowPartitions(offset, page_size, table_name, partition_list_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, partition_list_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ADD_CORS(PartitionOptions)
@ -510,22 +580,31 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(DropPartition)
ENDPOINT("DELETE",
"/tables/{table_name}/partitions/{partition_tag}",
DropPartition,
PATH(String, table_name),
PATH(String, partition_tag)) {
ENDPOINT("DELETE", "/tables/{table_name}/partitions/{partition_tag}", DropPartition,
PATH(String, table_name), PATH(String, partition_tag)) {
auto handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropPartition(table_name, partition_tag);
auto code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_204, status_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ADD_CORS(VectorsOptions)
ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ENDPOINT_INFO(Insert) {
@ -540,32 +619,28 @@ class WebController : public oatpp::web::server::api::ApiController {
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_404, "application/json");
}
ADD_CORS(VectorsOptions)
ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_CORS(Insert)
ENDPOINT("POST",
"/tables/{table_name}/vectors",
Insert,
PATH(String, table_name),
BODY_DTO(InsertRequestDto::ObjectWrapper, body)) {
ENDPOINT("POST", "/tables/{table_name}/vectors", Insert,
PATH(String, table_name), BODY_DTO(InsertRequestDto::ObjectWrapper, body)) {
auto ids_dto = VectorIdsDto::createShared();
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.Insert(table_name, body, ids_dto);
int64_t code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_201, ids_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.Insert(table_name, body, ids_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, ids_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(Search) {
@ -582,23 +657,26 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(Search)
ENDPOINT("PUT",
"/tables/{table_name}/vectors",
Search,
PATH(String, table_name),
BODY_DTO(SearchRequestDto::ObjectWrapper, body)) {
ENDPOINT("PUT", "/tables/{table_name}/vectors", Search,
PATH(String, table_name), BODY_DTO(SearchRequestDto::ObjectWrapper, body)) {
auto results_dto = TopkResultsDto::createShared();
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.Search(table_name, body, results_dto);
int64_t code = status_dto->code->getValue();
if (0 == code) {
return createDtoResponse(Status::CODE_200, results_dto);
} else if (StatusCode::TABLE_NOT_EXISTS == code) {
return createDtoResponse(Status::CODE_404, status_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, results_dto);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
ENDPOINT_INFO(SystemMsg) {
@ -608,7 +686,6 @@ class WebController : public oatpp::web::server::api::ApiController {
info->addResponse<CommandDto::ObjectWrapper>(Status::CODE_200, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_400, "application/json");
info->addResponse<StatusDto::ObjectWrapper>(Status::CODE_404, "application/json");
}
ADD_CORS(SystemMsg)
@ -618,13 +695,18 @@ class WebController : public oatpp::web::server::api::ApiController {
WebRequestHandler handler = WebRequestHandler();
handler.RegisterRequestHandler(::milvus::server::RequestHandler());
auto status_dto = handler.Cmd(msg, cmd_dto);
if (0 == status_dto->code->getValue()) {
return createDtoResponse(Status::CODE_200, cmd_dto);
} else {
return createDtoResponse(Status::CODE_400, status_dto);
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.Cmd(msg, cmd_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, cmd_dto);
break;
default:
return createDtoResponse(Status::CODE_400, status_dto);
}
return response;
}
/**

View File

@ -59,14 +59,7 @@ class TableListFieldsDto : public OObject {
DTO_INIT(TableListFieldsDto, Object)
DTO_FIELD(List<TableFieldsDto::ObjectWrapper>::ObjectWrapper, tables);
DTO_FIELD(Int64, count);
};
class TablesResponseDto : public OObject {
DTO_INIT(TablesResponseDto, Object)
DTO_FIELD(TableListFieldsDto::ObjectWrapper, tables_fields);
DTO_FIELD(Int64, page_num);
DTO_FIELD(Int64, count) = 0;
};
#include OATPP_CODEGEN_END(DTO)

View File

@ -46,14 +46,15 @@ class SearchRequestDto : public OObject {
DTO_FIELD(List<String>::ObjectWrapper, tags);
DTO_FIELD(List<String>::ObjectWrapper, file_ids);
DTO_FIELD(List<List<Float32>::ObjectWrapper>::ObjectWrapper, records);
DTO_FIELD(List<List<Int64>::ObjectWrapper>::ObjectWrapper, records_bin);
};
class InsertRequestDto : public oatpp::data::mapping::type::Object {
DTO_INIT(InsertRequestDto, Object)
DTO_FIELD(String, tag) = VALUE_PARTITION_TAG_DEFAULT;
DTO_FIELD(List<List<Float32>::ObjectWrapper>::ObjectWrapper, records);
DTO_FIELD(List<List<Int64>::ObjectWrapper>::ObjectWrapper, records_bin);
DTO_FIELD(List<Int64>::ObjectWrapper, ids);
};
@ -66,17 +67,10 @@ class VectorIdsDto : public oatpp::data::mapping::type::Object {
class ResultDto : public oatpp::data::mapping::type::Object {
DTO_INIT(ResultDto, Object)
// DTO_FIELD(Int64, num);
DTO_FIELD(String, id);
DTO_FIELD(String, dit, "distance");
};
class RowResultsDto : public OObject {
DTO_INIT(RowResultsDto, Object)
// DTO_FIELD(List<ResultDto::ObjectWrapper>::ObjectWrapper, );
};
class TopkResultsDto : public OObject {
DTO_INIT(TopkResultsDto, Object);

View File

@ -17,19 +17,20 @@
#include "server/web_impl/handler/WebRequestHandler.h"
#include <boost/algorithm/string.hpp>
#include <cmath>
#include <ctime>
#include <string>
#include <vector>
#include "metrics/SystemInfo.h"
#include "utils/Log.h"
#include "server/Config.h"
#include "server/delivery/request/BaseRequest.h"
#include "server/web_impl/Constants.h"
#include "server/web_impl/Types.h"
#include "server/web_impl/dto/PartitionDto.hpp"
#include "server/web_impl/utils/Util.h"
#include "utils/StringHelpFunctions.h"
#include "utils/TimeRecorder.h"
namespace milvus {
namespace server {
@ -79,83 +80,9 @@ WebErrorMap(ErrorCode code) {
}
}
namespace {
Status
CopyRowRecords(const InsertRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) {
vectors.float_data_.clear();
vectors.binary_data_.clear();
vectors.id_array_.clear();
vectors.vector_count_ = param->records->count();
// step 1: copy vector data
if (nullptr == param->records.get()) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "");
}
size_t tal_size = 0;
for (int64_t i = 0; i < param->records->count(); i++) {
tal_size += param->records->get(i)->count();
}
std::vector<float>& datas = vectors.float_data_;
datas.resize(tal_size);
size_t index_offset = 0;
param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
row_item->forEach([&datas, &index_offset](const OFloat32& item) {
datas[index_offset] = item->getValue();
index_offset++;
});
});
// step 2: copy id array
if (nullptr == param->ids.get()) {
return Status(SERVER_ILLEGAL_VECTOR_ID, "");
}
for (int64_t i = 0; i < param->ids->count(); i++) {
vectors.id_array_.emplace_back(param->ids->get(i)->getValue());
}
return Status::OK();
}
Status
CopyRowRecords(const SearchRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) {
vectors.float_data_.clear();
vectors.binary_data_.clear();
vectors.id_array_.clear();
vectors.vector_count_ = param->records->count();
// step 1: copy vector data
if (nullptr == param->records.get()) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "");
}
size_t tal_size = 0;
for (int64_t i = 0; i < param->records->count(); i++) {
tal_size += param->records->get(i)->count();
}
std::vector<float>& datas = vectors.float_data_;
datas.resize(tal_size);
size_t index_offset = 0;
param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
row_item->forEach([&datas, &index_offset](const OFloat32& item) {
datas[index_offset] = item->getValue();
index_offset++;
});
});
return Status::OK();
}
} // namespace
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
Status
WebRequestHandler::GetTaleInfo(const std::shared_ptr<Context>& context, const std::string& table_name,
std::map<std::string, std::string>& table_info) {
WebRequestHandler::GetTableInfo(const std::string& table_name, TableFieldsDto::ObjectWrapper& table_fields) {
TableSchema schema;
auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema);
if (!status.ok()) {
@ -174,26 +101,28 @@ WebRequestHandler::GetTaleInfo(const std::shared_ptr<Context>& context, const st
return status;
}
table_info[KEY_TABLE_TABLE_NAME] = schema.table_name_;
table_info[KEY_TABLE_DIMENSION] = std::to_string(schema.dimension_);
table_info[KEY_TABLE_INDEX_METRIC_TYPE] = std::string(MetricMap.at(engine::MetricType(schema.metric_type_)));
table_info[KEY_TABLE_INDEX_FILE_SIZE] = std::to_string(schema.index_file_size_);
table_fields->table_name = schema.table_name_.c_str();
table_fields->dimension = schema.dimension_;
table_fields->index_file_size = schema.index_file_size_;
table_fields->index = IndexMap.at(engine::EngineType(index_param.index_type_)).c_str();
table_fields->nlist = index_param.nlist_;
table_fields->metric_type = MetricMap.at(engine::MetricType(schema.metric_type_)).c_str();
table_fields->count = count;
}
table_info[KEY_INDEX_INDEX_TYPE] = std::string(IndexMap.at(engine::EngineType(index_param.index_type_)));
table_info[KEY_INDEX_NLIST] = std::to_string(index_param.nlist_);
table_info[KEY_TABLE_COUNT] = std::to_string(count);
Status
WebRequestHandler::CommandLine(const std::string& cmd, std::string& reply) {
return request_handler_.Cmd(context_ptr_, cmd, reply);
}
/////////////////////////////////////////// Router methods ////////////////////////////////////////////
StatusDto::ObjectWrapper
WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) {
auto getgb = [](uint64_t x) -> uint64_t { return x / 1024 / 1024 / 1024; };
auto system_info = SystemInfo::GetInstance();
devices_dto->cpu = devices_dto->cpu->createShared();
devices_dto->cpu->memory = getgb(system_info.GetPhysicalMemory());
devices_dto->cpu->memory = system_info.GetPhysicalMemory() >> 30;
devices_dto->gpus = devices_dto->gpus->createShared();
@ -203,12 +132,12 @@ WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) {
std::vector<uint64_t> device_mems = system_info.GPUMemoryTotal();
if (count != device_mems.size()) {
ASSIGN_RETURN_STATUS_DTO(Status(UNEXPECTED_ERROR, "Can't obtain GPU info"));
RETURN_STATUS_DTO(UNEXPECTED_ERROR, "Can't obtain GPU info");
}
for (size_t i = 0; i < count; i++) {
auto device_dto = DeviceInfoDto::createShared();
device_dto->memory = getgb(device_mems.at(i));
device_dto->memory = device_mems.at(i) >> 30;
devices_dto->gpus->put("GPU" + OString(std::to_string(i).c_str()), device_dto);
}
@ -220,35 +149,39 @@ WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) {
StatusDto::ObjectWrapper
WebRequestHandler::GetAdvancedConfig(AdvancedConfigDto::ObjectWrapper& advanced_config) {
Config& config = Config::GetInstance();
std::string reply;
std::string cache_cmd_prefix = "get_config " + std::string(CONFIG_CACHE) + ".";
int64_t value;
auto status = config.GetCacheConfigCpuCacheCapacity(value);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
advanced_config->cpu_cache_capacity = value;
bool ok;
status = config.GetCacheConfigCacheInsertData(ok);
std::string cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CPU_CACHE_CAPACITY);
auto status = CommandLine(cache_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
advanced_config->cache_insert_data = ok;
advanced_config->cpu_cache_capacity = std::stol(reply);
status = config.GetEngineConfigUseBlasThreshold(value);
cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CACHE_INSERT_DATA);
CommandLine(cache_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
advanced_config->use_blas_threshold = value;
advanced_config->cache_insert_data = ("1" == reply || "true" == reply);
auto engine_cmd_prefix = "get_config " + std::string(CONFIG_ENGINE) + ".";
auto engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_USE_BLAS_THRESHOLD);
CommandLine(engine_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
advanced_config->use_blas_threshold = std::stol(reply);
#ifdef MILVUS_GPU_VERSION
status = config.GetEngineConfigGpuSearchThreshold(value);
engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_GPU_SEARCH_THRESHOLD);
CommandLine(engine_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
advanced_config->gpu_search_threshold = value;
advanced_config->gpu_search_threshold = std::stol(reply);
#endif
ASSIGN_RETURN_STATUS_DTO(status)
@ -256,44 +189,57 @@ WebRequestHandler::GetAdvancedConfig(AdvancedConfigDto::ObjectWrapper& advanced_
StatusDto::ObjectWrapper
WebRequestHandler::SetAdvancedConfig(const AdvancedConfigDto::ObjectWrapper& advanced_config) {
Config& config = Config::GetInstance();
if (nullptr == advanced_config->cpu_cache_capacity.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cpu_cache_capacity\' miss.");
}
auto status =
config.SetCacheConfigCpuCacheCapacity(std::to_string(advanced_config->cpu_cache_capacity->getValue()));
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
if (nullptr == advanced_config->cache_insert_data.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_insert_data\' miss.");
}
status = config.SetCacheConfigCacheInsertData(std::to_string(advanced_config->cache_insert_data->getValue()));
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
if (nullptr == advanced_config->use_blas_threshold.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'use_blas_threshold\' miss.");
}
status = config.SetEngineConfigUseBlasThreshold(std::to_string(advanced_config->use_blas_threshold->getValue()));
#ifdef MILVUS_GPU_VERSION
if (nullptr == advanced_config->gpu_search_threshold.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'gpu_search_threshold\' miss.");
}
#endif
std::string reply;
std::string cache_cmd_prefix = "set_config " + std::string(CONFIG_CACHE) + ".";
std::string cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CPU_CACHE_CAPACITY) + " " +
std::to_string(advanced_config->cpu_cache_capacity->getValue());
auto status = CommandLine(cache_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CACHE_INSERT_DATA) + " " +
std::to_string(advanced_config->cache_insert_data->getValue());
status = CommandLine(cache_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
auto engine_cmd_prefix = "set_config " + std::string(CONFIG_ENGINE) + ".";
auto engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_USE_BLAS_THRESHOLD) + " " +
std::to_string(advanced_config->use_blas_threshold->getValue());
status = CommandLine(engine_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
#ifdef MILVUS_GPU_VERSION
if (nullptr == advanced_config->gpu_search_threshold.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'gpu_search_threshold\' miss.");
}
status =
config.SetEngineConfigGpuSearchThreshold(std::to_string(advanced_config->gpu_search_threshold->getValue()));
engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_GPU_SEARCH_THRESHOLD) + " " +
std::to_string(advanced_config->gpu_search_threshold->getValue());
CommandLine(engine_cmd_string, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
#endif
ASSIGN_RETURN_STATUS_DTO(status)
@ -303,46 +249,52 @@ WebRequestHandler::SetAdvancedConfig(const AdvancedConfigDto::ObjectWrapper& adv
StatusDto::ObjectWrapper
WebRequestHandler::GetGpuConfig(GPUConfigDto::ObjectWrapper& gpu_config_dto) {
Config& config = Config::GetInstance();
std::string reply;
std::string gpu_cmd_prefix = "get_config " + std::string(CONFIG_GPU_RESOURCE) + ".";
bool enable;
auto status = config.GetGpuResourceConfigEnable(enable);
std::string gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_ENABLE);
auto status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
gpu_config_dto->enable = enable;
gpu_config_dto->enable = reply == "1" || reply == "true";
if (!enable) {
if (!gpu_config_dto->enable->getValue()) {
ASSIGN_RETURN_STATUS_DTO(Status::OK());
}
int64_t capacity;
status = config.GetGpuResourceConfigCacheCapacity(capacity);
gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_CACHE_CAPACITY);
status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
gpu_config_dto->cache_capacity = capacity;
gpu_config_dto->cache_capacity = std::stol(reply);
std::vector<int64_t> values;
status = config.GetGpuResourceConfigSearchResources(values);
gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_SEARCH_RESOURCES);
status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
std::vector<std::string> gpu_entry;
StringHelpFunctions::SplitStringByDelimeter(reply, ",", gpu_entry);
gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared();
for (auto& device_id : values) {
gpu_config_dto->search_resources->pushBack("GPU" + OString(std::to_string(device_id).c_str()));
for (auto& device_id : gpu_entry) {
gpu_config_dto->search_resources->pushBack(OString(device_id.c_str())->toUpperCase());
}
gpu_entry.clear();
values.clear();
status = config.GetGpuResourceConfigBuildIndexResources(values);
gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES);
status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
StringHelpFunctions::SplitStringByDelimeter(reply, ",", gpu_entry);
gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared();
for (auto& device_id : values) {
gpu_config_dto->build_index_resources->pushBack("GPU" + OString(std::to_string(device_id).c_str()));
for (auto& device_id : gpu_entry) {
gpu_config_dto->build_index_resources->pushBack(OString(device_id.c_str())->toUpperCase());
}
ASSIGN_RETURN_STATUS_DTO(Status::OK());
@ -354,33 +306,46 @@ WebRequestHandler::GetGpuConfig(GPUConfigDto::ObjectWrapper& gpu_config_dto) {
StatusDto::ObjectWrapper
WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dto) {
Config& config = Config::GetInstance();
// Step 1: Check config param
if (nullptr == gpu_config_dto->enable.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'enable\' miss")
}
auto status = config.SetGpuResourceConfigEnable(std::to_string(gpu_config_dto->enable->getValue()));
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
if (!gpu_config_dto->enable->getValue()) {
RETURN_STATUS_DTO(SUCCESS, "Set Gpu resources false");
}
if (nullptr == gpu_config_dto->cache_capacity.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_capacity\' miss")
}
status = config.SetGpuResourceConfigCacheCapacity(std::to_string(gpu_config_dto->cache_capacity->getValue()));
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
if (nullptr == gpu_config_dto->search_resources.get()) {
gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared();
gpu_config_dto->search_resources->pushBack("GPU0");
}
if (nullptr == gpu_config_dto->build_index_resources.get()) {
gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared();
gpu_config_dto->build_index_resources->pushBack("GPU0");
}
// Step 2: Set config
std::string reply;
std::string gpu_cmd_prefix = "set_config " + std::string(CONFIG_GPU_RESOURCE) + ".";
std::string gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_ENABLE) + " " +
std::to_string(gpu_config_dto->enable->getValue());
auto status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
if (!gpu_config_dto->enable->getValue()) {
RETURN_STATUS_DTO(SUCCESS, "Set Gpu resources to false");
}
gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_CACHE_CAPACITY) + " " +
std::to_string(gpu_config_dto->cache_capacity->getValue());
status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
std::vector<std::string> search_resources;
gpu_config_dto->search_resources->forEach(
[&search_resources](const OString& res) { search_resources.emplace_back(res->toLowerCase()->std_str()); });
@ -393,15 +358,13 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
if (len > 0) {
search_resources_value.erase(len - 1);
}
status = config.SetGpuResourceConfigSearchResources(search_resources_value);
gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_SEARCH_RESOURCES) + " " + search_resources_value;
status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
if (nullptr == gpu_config_dto->build_index_resources.get()) {
gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared();
gpu_config_dto->build_index_resources->pushBack("GPU0");
}
std::vector<std::string> build_resources;
gpu_config_dto->build_index_resources->forEach(
[&build_resources](const OString& res) { build_resources.emplace_back(res->toLowerCase()->std_str()); });
@ -415,7 +378,9 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
build_resources_value.erase(len - 1);
}
status = config.SetGpuResourceConfigBuildIndexResources(build_resources_value);
gpu_cmd_request =
gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES) + " " + build_resources_value;
status = CommandLine(gpu_cmd_request, reply);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status);
}
@ -461,74 +426,62 @@ WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query
RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'table_name\' is required!");
}
Status status = Status::OK();
// TODO: query string field `fields` npt used here
std::map<std::string, std::string> table_info;
status = GetTaleInfo(context_ptr_, table_name->std_str(), table_info);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
fields_dto->table_name = table_info[KEY_TABLE_TABLE_NAME].c_str();
fields_dto->dimension = std::stol(table_info[KEY_TABLE_DIMENSION]);
fields_dto->index = table_info[KEY_INDEX_INDEX_TYPE].c_str();
fields_dto->nlist = std::stol(table_info[KEY_INDEX_NLIST]);
fields_dto->metric_type = table_info[KEY_TABLE_INDEX_METRIC_TYPE].c_str();
fields_dto->index_file_size = std::stol(table_info[KEY_TABLE_INDEX_FILE_SIZE]);
fields_dto->count = std::stol(table_info[KEY_TABLE_COUNT]);
auto status = GetTableInfo(table_name->std_str(), fields_dto);
ASSIGN_RETURN_STATUS_DTO(status);
}
StatusDto::ObjectWrapper
WebRequestHandler::ShowTables(const OInt64& offset, const OInt64& page_size,
WebRequestHandler::ShowTables(const OString& offset, const OString& page_size,
TableListFieldsDto::ObjectWrapper& response_dto) {
if (nullptr == offset.get()) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'offset\' is required");
int64_t offset_value = 0;
int64_t page_size_value = 10;
if (nullptr != offset.get()) {
try {
offset_value = std::stol(offset->std_str());
} catch (const std::exception& e) {
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, "Query param \'offset\' is illegal, only type of \'int\' allowed");
}
}
if (nullptr == page_size.get()) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'page_size\' is required");
if (nullptr != page_size.get()) {
try {
page_size_value = std::stol(page_size->std_str());
} catch (const std::exception& e) {
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM,
"Query param \'page_size\' is illegal, only type of \'int\' allowed");
}
}
if (offset_value < 0 || page_size_value < 0) {
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, "Query param 'offset' or 'page_size' should equal or bigger than 0");
}
std::vector<std::string> tables;
Status status = Status::OK();
auto status = request_handler_.ShowTables(context_ptr_, tables);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
response_dto->tables = response_dto->tables->createShared();
if (offset < 0 || page_size < 0) {
ASSIGN_RETURN_STATUS_DTO(
Status(SERVER_UNEXPECTED_ERROR, "Query param 'offset' or 'page_size' should bigger than 0"));
} else {
status = request_handler_.ShowTables(context_ptr_, tables);
if (offset_value >= tables.size()) {
ASSIGN_RETURN_STATUS_DTO(Status::OK());
}
response_dto->count = tables.size();
int64_t size = page_size_value + offset_value > tables.size() ? tables.size() - offset_value : page_size_value;
for (int64_t i = offset_value; i < size + offset_value; i++) {
auto table_fields_dto = TableFieldsDto::createShared();
status = GetTableInfo(tables.at(i), table_fields_dto);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
break;
}
if (offset < tables.size()) {
int64_t size = (page_size->getValue() + offset->getValue() > tables.size()) ? tables.size() - offset
: page_size->getValue();
for (int64_t i = offset->getValue(); i < size + offset->getValue(); i++) {
std::map<std::string, std::string> table_info;
status = GetTaleInfo(context_ptr_, tables.at(i), table_info);
if (!status.ok()) {
break;
}
auto table_fields_dto = TableFieldsDto::createShared();
table_fields_dto->table_name = table_info[KEY_TABLE_TABLE_NAME].c_str();
table_fields_dto->dimension = std::stol(table_info[std::string(KEY_TABLE_DIMENSION)]);
table_fields_dto->index_file_size = std::stol(table_info[std::string(KEY_TABLE_INDEX_FILE_SIZE)]);
table_fields_dto->index = table_info[KEY_INDEX_INDEX_TYPE].c_str();
table_fields_dto->nlist = std::stol(table_info[KEY_INDEX_NLIST]);
table_fields_dto->metric_type = table_info[KEY_TABLE_INDEX_METRIC_TYPE].c_str();
table_fields_dto->count = std::stol(table_info[KEY_TABLE_COUNT]);
response_dto->tables->pushBack(table_fields_dto);
}
response_dto->count = tables.size();
}
response_dto->tables->pushBack(table_fields_dto);
}
ASSIGN_RETURN_STATUS_DTO(status)
@ -598,31 +551,50 @@ WebRequestHandler::CreatePartition(const OString& table_name, const PartitionReq
}
StatusDto::ObjectWrapper
WebRequestHandler::ShowPartitions(const OInt64& offset, const OInt64& page_size, const OString& table_name,
WebRequestHandler::ShowPartitions(const OString& offset, const OString& page_size, const OString& table_name,
PartitionListDto::ObjectWrapper& partition_list_dto) {
if (nullptr == offset.get()) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'offset\' is required!");
int64_t offset_value = 0;
int64_t page_size_value = 10;
if (nullptr != offset.get()) {
try {
offset_value = std::stol(offset->std_str());
} catch (const std::exception& e) {
std::string msg = "Query param \'offset\' is illegal. Reason: " + std::string(e.what());
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, msg.c_str());
}
}
if (nullptr == page_size.get()) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'page_size\' is required!");
if (nullptr != page_size.get()) {
try {
page_size_value = std::stol(page_size->std_str());
} catch (const std::exception& e) {
std::string msg = "Query param \'page_size\' is illegal. Reason: " + std::string(e.what());
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, msg.c_str());
}
}
if (offset_value < 0 || page_size_value < 0) {
ASSIGN_RETURN_STATUS_DTO(
Status(SERVER_UNEXPECTED_ERROR, "Query param 'offset' or 'page_size' should equal or bigger than 0"));
}
std::vector<PartitionParam> partitions;
auto status = request_handler_.ShowPartitions(context_ptr_, table_name->std_str(), partitions);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
if (status.ok()) {
partition_list_dto->partitions = partition_list_dto->partitions->createShared();
partition_list_dto->partitions = partition_list_dto->partitions->createShared();
if (offset->getValue() < partitions.size()) {
int64_t size = (offset->getValue() + page_size->getValue() > partitions.size()) ? partitions.size() - offset
: page_size->getValue();
for (int64_t i = offset->getValue(); i < size + offset->getValue(); i++) {
auto partition_dto = PartitionFieldsDto::createShared();
partition_dto->partition_name = partitions.at(i).partition_name_.c_str();
partition_dto->partition_tag = partitions.at(i).tag_.c_str();
partition_list_dto->partitions->pushBack(partition_dto);
}
if (offset_value < partitions.size()) {
int64_t size =
offset_value + page_size_value > partitions.size() ? partitions.size() - offset_value : page_size_value;
for (int64_t i = offset_value; i < size + offset_value; i++) {
auto partition_dto = PartitionFieldsDto::createShared();
partition_dto->partition_name = partitions.at(i).partition_name_.c_str();
partition_dto->partition_tag = partitions.at(i).tag_.c_str();
partition_list_dto->partitions->pushBack(partition_dto);
}
}
@ -637,15 +609,47 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& tag)
}
StatusDto::ObjectWrapper
WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& param,
WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& request,
VectorIdsDto::ObjectWrapper& ids_dto) {
engine::VectorsData vectors;
auto status = CopyRowRecords(param, vectors);
if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors")
TableSchema schema;
auto status = request_handler_.DescribeTable(context_ptr_, table_name->std_str(), schema);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, param->tag->std_str());
auto metric = engine::MetricType(schema.metric_type_);
engine::VectorsData vectors;
bool bin_flag = engine::MetricType::HAMMING == metric || engine::MetricType::JACCARD == metric ||
engine::MetricType::TANIMOTO == metric;
if (!bin_flag) {
if (nullptr == request->records.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors");
}
vectors.vector_count_ = request->records->count();
status = CopyRowRecords(request->records, vectors.float_data_);
} else {
if (nullptr == request->records_bin.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records_bin\' is required to fill vectors");
}
vectors.vector_count_ = request->records_bin->count();
status = CopyBinRowRecords(request->records_bin, vectors.binary_data_);
}
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
// step 2: copy id array
if (nullptr != request->ids.get()) {
auto& id_array = vectors.id_array_;
id_array.resize(request->ids->count());
size_t i = 0;
request->ids->forEach([&id_array, &i](const OInt64& item) { id_array[i++] = item->getValue(); });
}
status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, request->tag->std_str());
if (status.ok()) {
ids_dto->ids = ids_dto->ids->createShared();
@ -658,42 +662,58 @@ WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::Obj
}
StatusDto::ObjectWrapper
WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::ObjectWrapper& search_request,
WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::ObjectWrapper& request,
TopkResultsDto::ObjectWrapper& results_dto) {
if (nullptr == search_request->topk.get()) {
if (nullptr == request->topk.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'topk\' is required in request body")
}
int64_t topk_t = search_request->topk->getValue();
int64_t topk_t = request->topk->getValue();
if (nullptr == search_request->nprobe.get()) {
if (nullptr == request->nprobe.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'nprobe\' is required in request body")
}
int64_t nprobe_t = search_request->nprobe->getValue();
int64_t nprobe_t = request->nprobe->getValue();
std::vector<std::string> tag_list;
if (nullptr != request->tags.get()) {
request->tags->forEach([&tag_list](const OString& tag) { tag_list.emplace_back(tag->std_str()); });
}
std::vector<std::string> file_id_list;
if (nullptr != search_request->tags.get()) {
search_request->tags->forEach([&tag_list](const OString& tag) { tag_list.emplace_back(tag->std_str()); });
if (nullptr != request->file_ids.get()) {
request->file_ids->forEach([&file_id_list](const OString& id) { file_id_list.emplace_back(id->std_str()); });
}
if (nullptr != search_request->file_ids.get()) {
search_request->file_ids->forEach(
[&file_id_list](const OString& id) { file_id_list.emplace_back(id->std_str()); });
}
if (nullptr == search_request->records.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill query vectors")
TableSchema schema;
auto status = request_handler_.DescribeTable(context_ptr_, table_name->std_str(), schema);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
auto metric = engine::MetricType(schema.metric_type_);
bool bin_flag = engine::MetricType::HAMMING == metric || engine::MetricType::JACCARD == metric ||
engine::MetricType::TANIMOTO == metric;
engine::VectorsData vectors;
auto status = CopyRowRecords(search_request, vectors);
if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors")
if (!bin_flag) {
if (nullptr == request->records.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors");
}
vectors.vector_count_ = request->records->count();
status = CopyRowRecords(request->records, vectors.float_data_);
} else {
if (nullptr == request->records_bin.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records_bin\' is required to fill vectors");
}
vectors.vector_count_ = request->records_bin->count();
status = CopyBinRowRecords(request->records_bin, vectors.binary_data_);
}
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
std::vector<Range> range_list;
TopKQueryResult result;
auto context_ptr = GenContextPtr("Web Handler");
status = request_handler_.Search(context_ptr, table_name->std_str(), vectors, range_list, topk_t, nprobe_t,
@ -725,8 +745,14 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj
StatusDto::ObjectWrapper
WebRequestHandler::Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto) {
std::string info = cmd->std_str();
if ("info" == info) {
info = "get_system_info";
}
std::string reply_str;
auto status = request_handler_.Cmd(context_ptr_, cmd->std_str(), reply_str);
auto status = CommandLine(info, reply_str);
if (status.ok()) {
cmd_dto->reply = reply_str.c_str();

View File

@ -23,10 +23,9 @@
#include <utility>
#include <opentracing/mocktracer/tracer.h>
#include <oatpp/web/server/api/ApiController.hpp>
#include <oatpp/core/data/mapping/type/Object.hpp>
#include <oatpp/core/macro/codegen.hpp>
#include <oatpp/web/server/api/ApiController.hpp>
#include "server/web_impl/Types.h"
#include "server/web_impl/dto/CmdDto.hpp"
@ -37,6 +36,7 @@
#include "server/web_impl/dto/TableDto.hpp"
#include "server/web_impl/dto/VectorDto.hpp"
#include "db/Types.h"
#include "server/context/Context.h"
#include "server/delivery/RequestHandler.h"
#include "utils/Status.h"
@ -82,15 +82,18 @@ class WebRequestHandler {
return context_ptr;
}
protected:
Status
GetTableInfo(const std::string& table_name, TableFieldsDto::ObjectWrapper& table_fields);
Status
CommandLine(const std::string& cmd, std::string& reply);
public:
WebRequestHandler() {
context_ptr_ = GenContextPtr("Web Handler");
}
Status
GetTaleInfo(const std::shared_ptr<Context>& context, const std::string& table_name,
std::map<std::string, std::string>& table_info);
StatusDto::ObjectWrapper
GetDevices(DevicesDto::ObjectWrapper& devices);
@ -115,7 +118,7 @@ class WebRequestHandler {
GetTable(const OString& table_name, const OQueryParams& query_params, TableFieldsDto::ObjectWrapper& schema_dto);
StatusDto::ObjectWrapper
ShowTables(const OInt64& offset, const OInt64& page_size, TableListFieldsDto::ObjectWrapper& table_list_dto);
ShowTables(const OString& offset, const OString& page_size, TableListFieldsDto::ObjectWrapper& table_list_dto);
StatusDto::ObjectWrapper
DropTable(const OString& table_name);
@ -133,7 +136,7 @@ class WebRequestHandler {
CreatePartition(const OString& table_name, const PartitionRequestDto::ObjectWrapper& param);
StatusDto::ObjectWrapper
ShowPartitions(const OInt64& offset, const OInt64& page_size, const OString& table_name,
ShowPartitions(const OString& offset, const OString& page_size, const OString& table_name,
PartitionListDto::ObjectWrapper& partition_list_dto);
StatusDto::ObjectWrapper
@ -150,6 +153,7 @@ class WebRequestHandler {
StatusDto::ObjectWrapper
Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto);
public:
WebRequestHandler&
RegisterRequestHandler(const RequestHandler& handler) {
request_handler_ = handler;

View File

@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "server/web_impl/utils/Util.h"
namespace milvus {
namespace server {
namespace web {
Status
CopyRowRecords(const OList<OList<OFloat32>::ObjectWrapper>::ObjectWrapper& records, std::vector<float>& vectors) {
size_t tal_size = 0;
records->forEach([&tal_size](const OList<OFloat32>::ObjectWrapper& row_item) { tal_size += row_item->count(); });
vectors.resize(tal_size);
size_t index_offset = 0;
records->forEach([&vectors, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
row_item->forEach(
[&vectors, &index_offset](const OFloat32& item) { vectors[index_offset++] = item->getValue(); });
});
return Status::OK();
}
Status
CopyBinRowRecords(const OList<OList<OInt64>::ObjectWrapper>::ObjectWrapper& records, std::vector<uint8_t>& vectors) {
size_t tal_size = 0;
records->forEach([&tal_size](const OList<OInt64>::ObjectWrapper& item) { tal_size += item->count(); });
vectors.resize(tal_size);
size_t index_offset = 0;
bool oor = false;
records->forEach([&vectors, &index_offset, &oor](const OList<OInt64>::ObjectWrapper& row_item) {
row_item->forEach([&vectors, &index_offset, &oor](const OInt64& item) {
if (!oor) {
int64_t value = item->getValue();
if (0 > value || value > 255) {
oor = true;
} else {
vectors[index_offset++] = static_cast<uint8_t>(value);
}
}
});
});
return Status::OK();
}
} // namespace web
} // namespace server
} // namespace milvus

View File

@ -0,0 +1,38 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <vector>
#include "db/Types.h"
#include "server/web_impl/Types.h"
#include "utils/Status.h"
namespace milvus {
namespace server {
namespace web {
Status
CopyRowRecords(const OList<OList<OFloat32>::ObjectWrapper>::ObjectWrapper& records, std::vector<float>& vectors);
Status
CopyBinRowRecords(const OList<OList<OInt64>::ObjectWrapper>::ObjectWrapper& records, std::vector<uint8_t>& vectors);
} // namespace web
} // namespace server
} // namespace milvus

View File

@ -194,7 +194,33 @@ ValidationUtil::ValidatePartitionName(const std::string& partition_name) {
return Status(SERVER_INVALID_TABLE_NAME, msg);
}
return ValidateTableName(partition_name);
std::string invalid_msg = "Invalid partition name: " + partition_name + ". ";
// Table name size shouldn't exceed 16384.
if (partition_name.size() > TABLE_NAME_SIZE_LIMIT) {
std::string msg = invalid_msg + "The length of a partition name must be less than 255 characters.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_TABLE_NAME, msg);
}
// Table name first character should be underscore or character.
char first_char = partition_name[0];
if (first_char != '_' && std::isalpha(first_char) == 0) {
std::string msg = invalid_msg + "The first character of a partition name must be an underscore or letter.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_TABLE_NAME, msg);
}
int64_t table_name_size = partition_name.size();
for (int64_t i = 1; i < table_name_size; ++i) {
char name_char = partition_name[i];
if (name_char != '_' && std::isalnum(name_char) == 0) {
std::string msg = invalid_msg + "Partition name can only contain numbers, letters, and underscores.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_TABLE_NAME, msg);
}
}
return Status::OK();
}
Status
@ -207,7 +233,7 @@ ValidationUtil::ValidatePartitionTags(const std::vector<std::string>& partition_
if (valid_tag.empty()) {
std::string msg = "Invalid partition tag: " + valid_tag + ". " + "Partition tag should not be empty.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_NPROBE, msg);
return Status(SERVER_INVALID_TABLE_NAME, msg);
}
}

View File

@ -78,12 +78,14 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/handler web_handler_fi
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/component web_conponent_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/controller web_controller_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/dto web_dto_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/utils web_utils_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl web_impl_files)
set(web_server_files
${web_handler_files}
${web_conponent_files}
${web_controller_files}
${web_dto_files}
${web_utils_files}
${web_impl_files}
)

View File

@ -53,6 +53,8 @@
#include "server/DBWrapper.h"
#include "utils/CommonUtil.h"
#include "unittest/server/utils.h"
static const char* TABLE_NAME = "test_web";
static constexpr int64_t TABLE_DIM = 256;
static constexpr int64_t INDEX_FILE_SIZE = 1024;
@ -66,6 +68,7 @@ using OQueryParams = milvus::server::web::OQueryParams;
using OChunkedBuffer = oatpp::data::stream::ChunkedBuffer;
using OOutputStream = oatpp::data::stream::BufferOutputStream;
using OFloat32 = milvus::server::web::OFloat32;
using OInt64 = milvus::server::web::OInt64;
template<class T>
using OList = milvus::server::web::OList<T>;
@ -86,6 +89,19 @@ RandomRowRecordDto(int64_t dim) {
return row_record_dto;
}
OList<OInt64>::ObjectWrapper
RandomBinRowRecordDto(int64_t dim) {
auto row_record_dto = OList<OInt64>::createShared();
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 255);
for (size_t i = 0; i < dim / 8; i++) {
row_record_dto->pushBack(static_cast<int64_t>(u(e)));
}
return row_record_dto;
}
OList<OList<OFloat32>::ObjectWrapper>::ObjectWrapper
RandomRecordsDto(int64_t dim, int64_t num) {
auto records_dto = OList<OList<OFloat32>::ObjectWrapper>::createShared();
@ -96,6 +112,16 @@ RandomRecordsDto(int64_t dim, int64_t num) {
return records_dto;
}
OList<OList<OInt64>::ObjectWrapper>::ObjectWrapper
RandomBinRecordsDto(int64_t dim, int64_t num) {
auto records_dto = OList<OList<OInt64>::ObjectWrapper>::createShared();
for (size_t i = 0; i < num; i++) {
records_dto->pushBack(RandomBinRowRecordDto(dim));
}
return records_dto;
}
std::string
RandomName() {
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
@ -281,7 +307,7 @@ TEST_F(WebHandlerTest, INSERT_COUNT) {
ASSERT_EQ(0, status_dto->code->getValue());
ASSERT_EQ(1000, ids_dto->ids->count());
sleep(8);
sleep(2);
milvus::server::web::OQueryParams query_params;
query_params.put("fields", "num");
@ -344,7 +370,7 @@ TEST_F(WebHandlerTest, PARTITION) {
ASSERT_EQ(StatusCode::ILLEGAL_TABLE_NAME, status_dto->code->getValue());
auto partitions_dto = milvus::server::web::PartitionListDto::createShared();
status_dto = handler->ShowPartitions(0, 10, table_name, partitions_dto);
status_dto = handler->ShowPartitions("0", "10", table_name, partitions_dto);
ASSERT_EQ(1, partitions_dto->partitions->count());
status_dto = handler->DropPartition(table_name, "test");
@ -352,7 +378,7 @@ TEST_F(WebHandlerTest, PARTITION) {
// Show all partitions
partitions_dto = milvus::server::web::PartitionListDto::createShared();
status_dto = handler->ShowPartitions(0, 10, table_name, partitions_dto);
status_dto = handler->ShowPartitions("0", "10", table_name, partitions_dto);
}
TEST_F(WebHandlerTest, SEARCH) {
@ -400,7 +426,54 @@ TEST_F(WebHandlerTest, CMD) {
///////////////////////////////////////////////////////////////////////////////////////
namespace {
static const char* CONTROLLER_TEST_VALID_CONFIG_STR =
"# Default values are used when you make no changes to the following parameters.\n"
"\n"
"version: 0.1"
"\n"
"server_config:\n"
" address: 0.0.0.0 # milvus server ip address (IPv4)\n"
" port: 19530 # port range: 1025 ~ 65534\n"
" deploy_mode: single \n"
" time_zone: UTC+8\n"
"\n"
"db_config:\n"
" backend_url: sqlite://:@:/ \n"
"\n"
" insert_buffer_size: 4 # GB, maximum insert buffer size allowed\n"
" preload_table: \n"
"\n"
"storage_config:\n"
" primary_path: /tmp/milvus_web_controller_test # path used to store data and meta\n"
" secondary_path: # path used to store data only, split by semicolon\n"
"\n"
"metric_config:\n"
" enable_monitor: false # enable monitoring or not\n"
" address: 127.0.0.1\n"
" port: 8080 # port prometheus uses to fetch metrics\n"
"\n"
"cache_config:\n"
" cpu_cache_capacity: 4 # GB, CPU memory used for cache\n"
" cpu_cache_threshold: 0.85 \n"
" cache_insert_data: false # whether to load inserted data into cache\n"
"\n"
"engine_config:\n"
" use_blas_threshold: 20 \n"
"\n"
#ifdef MILVUS_GPU_VERSION
"gpu_resource_config:\n"
" enable: true # whether to enable GPU resources\n"
" cache_capacity: 4 # GB, size of GPU memory per card used for cache, must be a positive integer\n"
" search_resources: # define the GPU devices used for search computation, must be in format gpux\n"
" - gpu0\n"
" build_index_resources: # define the GPU devices used for index building, must be in format gpux\n"
" - gpu0\n"
#endif
"\n";
static const char* CONTROLLER_TEST_TABLE_NAME = "controller_unit_test";
static const char* CONTROLLER_TEST_CONFIG_DIR = "/tmp/milvus_web_controller_test/";
static const char* CONTROLLER_TEST_CONFIG_FILE = "config.yaml";
class TestClient : public oatpp::web::client::ApiClient {
public:
@ -445,11 +518,8 @@ class TestClient : public oatpp::web::client::ApiClient {
API_CALL("OPTIONS", "/tables/{table_name}/indexes", optionsIndexes, PATH(String, table_name, "table_name"))
API_CALL("POST",
"/tables/{table_name}/indexes",
createIndex,
PATH(String, table_name, "table_name"),
BODY_DTO(milvus::server::web::IndexRequestDto::ObjectWrapper, body))
API_CALL("POST", "/tables/{table_name}/indexes",createIndex,
PATH(String, table_name, "table_name"), BODY_DTO(milvus::server::web::IndexRequestDto::ObjectWrapper, body))
API_CALL("GET", "/tables/{table_name}/indexes", getIndex, PATH(String, table_name, "table_name"))
@ -505,6 +575,15 @@ class WebControllerTest : public testing::Test {
protected:
static void
SetUpTestCase() {
// Basic config
std::string config_path = std::string(CONTROLLER_TEST_CONFIG_DIR).append(CONTROLLER_TEST_CONFIG_FILE);
std::fstream fs(config_path.c_str(), std::ios_base::out);
fs << CONTROLLER_TEST_VALID_CONFIG_STR;
fs.close();
milvus::server::Config& config = milvus::server::Config::GetInstance();
config.LoadConfigFile(config_path);
auto res_mgr = milvus::scheduler::ResMgrInst::GetInstance();
res_mgr->Clear();
res_mgr->Add(milvus::scheduler::ResourceFactory::Create("disk", "DISK", 0, false));
@ -522,13 +601,8 @@ class WebControllerTest : public testing::Test {
milvus::engine::DBOptions opt;
milvus::server::Config::GetInstance().SetDBConfigBackendUrl("sqlite://:@:/");
boost::filesystem::remove_all("/tmp/milvus_web_controller_test");
milvus::server::Config::GetInstance().SetStorageConfigPrimaryPath("/tmp/milvus_web_controller_test");
milvus::server::Config::GetInstance().SetStorageConfigSecondaryPath("");
milvus::server::Config::GetInstance().SetDBConfigArchiveDiskThreshold("");
milvus::server::Config::GetInstance().SetDBConfigArchiveDaysThreshold("");
milvus::server::Config::GetInstance().SetCacheConfigCacheInsertData("");
milvus::server::Config::GetInstance().SetEngineConfigOmpThreadNum("");
boost::filesystem::remove_all(CONTROLLER_TEST_CONFIG_DIR);
milvus::server::Config::GetInstance().SetStorageConfigPrimaryPath(CONTROLLER_TEST_CONFIG_DIR);
milvus::server::DBWrapper::GetInstance().StartService();
@ -547,7 +621,7 @@ class WebControllerTest : public testing::Test {
milvus::scheduler::JobMgrInst::GetInstance()->Stop();
milvus::scheduler::ResMgrInst::GetInstance()->Stop();
milvus::scheduler::SchedInst::GetInstance()->Stop();
boost::filesystem::remove_all("/tmp/milvus_web_controller_test");
boost::filesystem::remove_all(CONTROLLER_TEST_CONFIG_DIR);
}
void
@ -638,16 +712,16 @@ TEST_F(WebControllerTest, CREATE_TABLE) {
auto table_dto = milvus::server::web::TableRequestDto::createShared();
auto response = client_ptr->createTable(table_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code) << result_dto->message->std_str();
auto error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code) << error_dto->message->std_str();
OString table_name = "web_test_create_table" + OString(RandomName().c_str());
table_dto->table_name = table_name;
response = client_ptr->createTable(table_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code) << result_dto->message->std_str();
error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code) << error_dto->message->std_str();
table_dto->dimension = 128;
table_dto->index_file_size = 10;
@ -655,6 +729,8 @@ TEST_F(WebControllerTest, CREATE_TABLE) {
response = client_ptr->createTable(table_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, result_dto->code->getValue()) << result_dto->message->std_str();
// invalid table name
table_dto->table_name = "9090&*&()";
@ -671,29 +747,32 @@ TEST_F(WebControllerTest, GET_TABLE) {
// fields value is 'num', test count table
params.put("fields", "num");
auto response = client_ptr->getTable(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::TableFieldsDto>(object_mapper.get());
response = client_ptr->getTable(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
ASSERT_EQ(table_name->std_str(), result_dto->table_name->std_str());
ASSERT_EQ(10, result_dto->dimension);
ASSERT_EQ("L2", result_dto->metric_type->std_str());
ASSERT_EQ(10, result_dto->index_file_size->getValue());
ASSERT_EQ("FLAT", result_dto->index->std_str());
// invalid table name
table_name = "57474dgdfhdfhdh dgd";
response = client_ptr->getTable(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
auto status_sto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::ILLEGAL_TABLE_NAME, status_sto->code->getValue());
table_name = "test_table_not_found_0000000001110101010020202030203030435";
table_name = "test_table_not_found_000000000111010101002020203020aaaaa3030435";
response = client_ptr->getTable(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_404.code, response->getStatusCode());
status_sto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
}
TEST_F(WebControllerTest, SHOW_TABLES) {
// test query table limit 1
auto response = client_ptr->showTables(1, 1, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::TableListFieldsDto>(object_mapper.get());
ASSERT_TRUE(result_dto->count->getValue() > 0);
// test query table empty
response = client_ptr->showTables(0, 0, conncetion_ptr);
@ -734,6 +813,24 @@ TEST_F(WebControllerTest, INSERT) {
ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode());
}
TEST_F(WebControllerTest, INSERT_BIN) {
auto table_name = "test_insert_bin_table_test" + OString(RandomName().c_str());
const int64_t dim = 64;
GenTable(table_name, dim, 100, "HAMMING");
auto insert_dto = milvus::server::web::InsertRequestDto::createShared();
insert_dto->ids = insert_dto->ids->createShared();
insert_dto->records_bin = RandomBinRecordsDto(dim, 20);
auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::VectorIdsDto>(object_mapper.get());
ASSERT_EQ(20, result_dto->ids->count());
response = client_ptr->dropTable(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode());
}
TEST_F(WebControllerTest, INSERT_IDS) {
auto table_name = "test_insert_table_test" + OString(RandomName().c_str());
const int64_t dim = 64;
@ -764,6 +861,9 @@ TEST_F(WebControllerTest, INDEX) {
auto index_dto = milvus::server::web::IndexRequestDto::createShared();
auto response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
auto create_index_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, create_index_dto->code);
// drop index
response = client_ptr->dropIndex(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode());
@ -795,7 +895,6 @@ TEST_F(WebControllerTest, INDEX) {
// invalid index type
index_dto->index_type = 100;
response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr);
ASSERT_NE(OStatus::CODE_201.code, response->getStatusCode());
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
// insert data and create index
@ -816,6 +915,9 @@ TEST_F(WebControllerTest, INDEX) {
// get index
response = client_ptr->getIndex(table_name, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
auto result_index_dto = response->readBodyToDto<milvus::server::web::IndexDto>(object_mapper.get());
ASSERT_EQ("FLAT", result_index_dto->index_type->std_str());
ASSERT_EQ(10, result_index_dto->nlist->getValue());
}
TEST_F(WebControllerTest, PARTITION) {
@ -825,23 +927,25 @@ TEST_F(WebControllerTest, PARTITION) {
auto par_param = milvus::server::web::PartitionRequestDto::createShared();
auto response = client_ptr->createPartition(table_name, par_param);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code);
auto error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code);
par_param->partition_name = "partition01" + OString(RandomName().c_str());
response = client_ptr->createPartition(table_name, par_param);
result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code);
par_param->partition_tag = "tag01";
response = client_ptr->createPartition(table_name, par_param);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
auto create_result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, create_result_dto->code);
// insert 200 vectors into table with tag = 'tag01'
OQueryParams query_params;
// add partition tag
auto insert_dto = milvus::server::web::InsertRequestDto::createShared();
// add partition tag
insert_dto->tag = OString("tag01");
insert_dto->ids = insert_dto->ids->createShared();
insert_dto->records = insert_dto->records->createShared();
@ -854,13 +958,17 @@ TEST_F(WebControllerTest, PARTITION) {
// Show all partitins
response = client_ptr->showPartitions(table_name, 0, 10, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::PartitionListDto>(object_mapper.get());
ASSERT_EQ(1, result_dto->partitions->count());
ASSERT_EQ("tag01", result_dto->partitions->get(0)->partition_tag->std_str());
ASSERT_EQ(par_param->partition_name->std_str(), result_dto->partitions->get(0)->partition_name->std_str());
response = client_ptr->dropPartition(table_name, "tag01", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode());
}
TEST_F(WebControllerTest, SEARCH) {
const OString table_name = "test_partition_table_test" + OString(RandomName().c_str());
const OString table_name = "test_search_table_test" + OString(RandomName().c_str());
GenTable(table_name, 64, 100, "L2");
// Insert 200 vectors into table
@ -869,6 +977,67 @@ TEST_F(WebControllerTest, SEARCH) {
insert_dto->ids = insert_dto->ids->createShared();
insert_dto->records = RandomRecordsDto(64, 200);// insert_dto->records->createShared();
auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
auto insert_result_dto = response->readBodyToDto<milvus::server::web::VectorIdsDto>(object_mapper.get());
ASSERT_EQ(200, insert_result_dto->ids->count());
sleep(4);
//Create partition and insert 200 vectors into it
auto par_param = milvus::server::web::PartitionRequestDto::createShared();
par_param->partition_name = "partition" + OString(RandomName().c_str());
par_param->partition_tag = "tag" + OString(RandomName().c_str());
response = client_ptr->createPartition(table_name, par_param);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode())
<< "Error: " << response->getStatusDescription()->std_str();
insert_dto->tag = par_param->partition_tag;
response = client_ptr->insert(table_name, insert_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
sleep(2);
// Test search
auto search_request_dto = milvus::server::web::SearchRequestDto::createShared();
response = client_ptr->search(table_name, search_request_dto, conncetion_ptr);
auto error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code);
search_request_dto->nprobe = 1;
response = client_ptr->search(table_name, search_request_dto, conncetion_ptr);
error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code);
search_request_dto->topk = 1;
response = client_ptr->search(table_name, search_request_dto, conncetion_ptr);
error_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code);
search_request_dto->records = RandomRecordsDto(64, 10);
response = client_ptr->search(table_name, search_request_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::TopkResultsDto>(object_mapper.get());
ASSERT_EQ(10, result_dto->num);
ASSERT_EQ(10, result_dto->results->count());
ASSERT_EQ(1, result_dto->results->get(0)->count());
// Test search with tags
search_request_dto->tags = search_request_dto->tags->createShared();
search_request_dto->tags->pushBack(par_param->partition_tag);
response = client_ptr->search(table_name, search_request_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
}
TEST_F(WebControllerTest, SEARCH_BIN) {
const OString table_name = "test_search_bin_table_test" + OString(RandomName().c_str());
GenTable(table_name, 64, 100, "HAMMING");
// Insert 200 vectors into table
OQueryParams query_params;
auto insert_dto = milvus::server::web::InsertRequestDto::createShared();
insert_dto->ids = insert_dto->ids->createShared();
insert_dto->records_bin = RandomBinRecordsDto(64, 200);
auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
@ -903,7 +1072,7 @@ TEST_F(WebControllerTest, SEARCH) {
result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code);
search_request_dto->records = RandomRecordsDto(64, 10);
search_request_dto->records_bin = RandomBinRecordsDto(64, 10);
response = client_ptr->search(table_name, search_request_dto, conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());

View File

@ -258,8 +258,8 @@ class TestShowBase:
'''
partition_name = gen_unique_str()
status, res = connect.show_partitions(partition_name)
assert status.OK()
assert len(res) == 0
assert not status.OK()
# assert len(res) == 0
def test_show_multi_partitions(self, connect, table):
'''
@ -428,4 +428,4 @@ class TestNameInvalid(object):
partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
status, res = connect.show_partitions(table_name)
assert not status.OK()
assert not status.OK()