Fix getEntity when no query_params and payload is provided in webserver (#3600)

* Fix getEntity when no query_params and payload is provided in webserver

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Change entities object to json::array

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3745/head
yukun 2020-09-05 10:39:20 +08:00 committed by shengjun.li
parent feb0cf3daf
commit b806440985
2 changed files with 50 additions and 11 deletions

View File

@ -368,6 +368,7 @@ WebRequestHandler::GetPageEntities(const std::string& collection_name, const std
real_offset = 0;
}
if (segment_ids.empty()) {
json_out["entities"] = json::array();
return Status::OK();
}
std::vector<std::string> field_names;
@ -923,6 +924,11 @@ WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std:
engine::DataChunkPtr data_chunk;
engine::snapshot::FieldElementMappings field_mappings;
if (ids.empty()) {
json_out["entities"] = {};
return Status::OK();
}
auto status = req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, valid_row, field_mappings,
data_chunk);
if (!status.ok()) {
@ -1761,6 +1767,9 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name
partition_tag = query_params.get("partition_tag")->std_str();
}
status = GetPageEntities(collection_name->std_str(), partition_tag, page_size, offset, json_out);
if (!status.ok()) {
json_out["entities"] = json::array();
}
AddStatusToJson(json_out, status.code(), status.message());
response = json_out.dump().c_str();
return status;
@ -1816,16 +1825,25 @@ WebRequestHandler::EntityOp(const OString& collection_name, const OQueryParams&
std::string result_str;
try {
nlohmann::json payload_json;
if (!payload->std_str().empty()) {
payload_json = nlohmann::json::parse(payload->std_str());
}
if (query_params.get("offset") || query_params.get("page_size") || query_params.get("ids")) {
status = GetEntity(collection_name, query_params, response);
ASSIGN_RETURN_STATUS_DTO(status);
} else {
nlohmann::json payload_json = nlohmann::json::parse(payload->std_str());
} else if (!payload_json.empty()) {
if (payload_json.contains("query")) {
status = Search(collection_name->c_str(), payload_json, result_str);
} else {
status = Status(ILLEGAL_BODY, "Unknown body");
status = Status(ILLEGAL_BODY, "Unknown payload");
}
} else {
OQueryParams self_params;
self_params.put("offset", "0");
self_params.put("page_size", "10");
status = GetEntity(collection_name, self_params, response);
ASSIGN_RETURN_STATUS_DTO(status)
}
} catch (nlohmann::detail::parse_error& e) {
std::string emsg = "json error: code=" + std::to_string(e.id) + ", reason=" + e.what();

View File

@ -295,8 +295,11 @@ class TestClient : public oatpp::web::client::ApiClient {
API_CALL("GET", "/collections/{collection_name}/entities", getEntityByID,
PATH(String, collection_name, "collection_name"), QUERY(String, ids))
API_CALL("GET", "/collections/{collection_name}/entities", search,
PATH(String, collection_name, "collection_name"), BODY_STRING(String, body))
API_CALL("GET", "/collections/{collection_name}/entities", search, PATH(String, collection_name, "collection_name"),
BODY_STRING(String, body))
API_CALL("GET", "/collections/{collection_name}/entities", getEntityWithNoParams,
PATH(String, collection_name, "collection_name"))
API_CALL("POST", "/collections/{collection_name}/entities", insert,
PATH(String, collection_name, "collection_name"), BODY_STRING(String, body))
@ -327,16 +330,16 @@ class WebControllerTest : public ::testing::Test {
fs.close();
milvus::ConfigMgr::GetInstance().Init();
// milvus::ConfigMgr::GetInstance().Set("general.meta_uri", "mock://:@:/");
// milvus::ConfigMgr::GetInstance().Set("storage.path", CONTROLLER_TEST_CONFIG_DIR);
// milvus::ConfigMgr::GetInstance().Set("network.http.enable", "true");
// milvus::ConfigMgr::GetInstance().Set("network.http.port", "20121");
// milvus::ConfigMgr::GetInstance().Set("general.meta_uri", "mock://:@:/");
// milvus::ConfigMgr::GetInstance().Set("storage.path", CONTROLLER_TEST_CONFIG_DIR);
// milvus::ConfigMgr::GetInstance().Set("network.http.enable", "true");
// milvus::ConfigMgr::GetInstance().Set("network.http.port", "20121");
auto& config = milvus::ConfigMgr::GetInstance();
// milvus::ConfigMgr::GetInstance().Init();
// milvus::ConfigMgr::GetInstance().Init();
config.LoadFile(config_path);
// milvus::ConfigMgr::GetInstance().Set("general.meta_uri", "mock://:@:/");
// milvus::ConfigMgr::GetInstance().Set("general.meta_uri", "mock://:@:/");
milvus::engine::snapshot::Snapshots::GetInstance().StartService();
@ -743,6 +746,24 @@ TEST_F(WebControllerTest, GET_PAGE_ENTITY) {
// ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
}
TEST_F(WebControllerTest, GET_ENTITY_WITH_NO_PARAMS) {
auto collection_name = "test_get_collection_test" + RandomName();
nlohmann::json mapping_json;
CreateCollection(client_ptr, connection_ptr, collection_name, mapping_json);
const int64_t dim = DIM;
const int64_t nb = 3;
nlohmann::json insert_json;
GenEntities(nb, dim, insert_json);
auto status = FlushCollection(client_ptr, connection_ptr, OString(collection_name.c_str()));
ASSERT_TRUE(status.ok());
auto response = client_ptr->getEntity(collection_name.c_str(), "", "", "", connection_ptr);
std::cout << response->readBodyToString()->std_str() << std::endl;
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
}
TEST_F(WebControllerTest, SYSTEM_INFO) {
std::string req = R"(
{