From cbba262442a42cc22ae24e00842d901b58321c97 Mon Sep 17 00:00:00 2001 From: yukun Date: Thu, 6 Aug 2020 19:08:42 +0800 Subject: [PATCH] Fix metric_type bug in search (#3155) * Fix dsl test case nb bug Signed-off-by: fishpenguin * Fix dsl test case bug Signed-off-by: fishpenguin * Add metric_type judge in search Signed-off-by: fishpenguin * Fix test_search.py Signed-off-by: fishpenguin * Fix test_db Signed-off-by: fishpenguin * Fix search metric_type Signed-off-by: fishpenguin * ci retry Signed-off-by: fishpenguin Co-authored-by: Wang Xiangyu --- core/src/db/engine/ExecutionEngineImpl.cpp | 3 +++ core/src/query/GeneralQuery.h | 1 + core/src/server/grpc_impl/GrpcRequestHandler.cpp | 8 ++++++++ core/src/utils/Json.h | 7 +++++++ 4 files changed, 19 insertions(+) diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 9262594bc3..0d7e5d3cba 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -269,6 +269,7 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context, milvus::json conf = vector_param->extra_params; conf[knowhere::meta::TOPK] = topk; + conf[knowhere::Metric::TYPE] = vector_param->metric_type; auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(vec_index->index_type()); if (!adapter->CheckSearch(conf, vec_index->index_type(), vec_index->index_mode())) { LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Illegal search params", "search", 0); @@ -498,6 +499,7 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::unordered_map& attr_type) { auto status = Status::OK(); auto term_query_json = term_query->json_obj; + JSON_NULL_CHECK(term_query_json); auto term_it = term_query_json.begin(); if (term_it != term_query_json.end()) { const std::string& field_name = term_it.key(); @@ -578,6 +580,7 @@ ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_mapjson_obj; + JSON_NULL_CHECK(range_query_json); auto range_it = range_query_json.begin(); if (range_it != range_query_json.end()) { const std::string& field_name = range_it.key(); diff --git a/core/src/query/GeneralQuery.h b/core/src/query/GeneralQuery.h index a3196fbe8a..eae08a2e14 100644 --- a/core/src/query/GeneralQuery.h +++ b/core/src/query/GeneralQuery.h @@ -79,6 +79,7 @@ struct VectorQuery { milvus::json extra_params = {}; int64_t topk; int64_t nq; + std::string metric_type = ""; float boost; VectorRecord query_vector; }; diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index eb6653bd4b..2edacb9c64 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -1568,6 +1568,7 @@ GrpcRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, query::Bool auto leaf_query = std::make_shared(); auto term_query = std::make_shared(); nlohmann::json json_obj = json["term"]; + JSON_NULL_CHECK(json_obj); term_query->json_obj = json_obj; nlohmann::json::iterator json_it = json_obj.begin(); field_name = json_it.key(); @@ -1578,6 +1579,7 @@ GrpcRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, query::Bool auto leaf_query = std::make_shared(); auto range_query = std::make_shared(); nlohmann::json json_obj = json["range"]; + JSON_NULL_CHECK(json_obj); range_query->json_obj = json_obj; nlohmann::json::iterator json_it = json_obj.begin(); field_name = json_it.key(); @@ -1587,9 +1589,12 @@ GrpcRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, query::Bool } else if (json.contains("vector")) { auto leaf_query = std::make_shared(); auto vector_json = json["vector"]; + JSON_NULL_CHECK(vector_json); leaf_query->vector_placeholder = vector_json.get(); query->AddLeafQuery(leaf_query); + } else { + return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"}; } return status; } @@ -1704,6 +1709,8 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery( } vector_query->topk = topk; if (vector_json.contains("metric_type")) { + std::string metric_type = vector_json["metric_type"]; + vector_query->metric_type = metric_type; query_ptr->metric_types.insert({field_name, vector_json["metric_type"]}); } if (!vector_param_it.value()["params"].empty()) { @@ -1722,6 +1729,7 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery( } if (dsl_json.contains("bool")) { auto boolean_query_json = dsl_json["bool"]; + JSON_NULL_CHECK(boolean_query_json); status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr); if (!status.ok()) { return status; diff --git a/core/src/utils/Json.h b/core/src/utils/Json.h index 95a2c70dcb..95bd0754cf 100644 --- a/core/src/utils/Json.h +++ b/core/src/utils/Json.h @@ -17,4 +17,11 @@ namespace milvus { using json = nlohmann::json; +#define JSON_NULL_CHECK(json) \ + do { \ + if (json.empty()) { \ + return Status{SERVER_INVALID_ARGUMENT, "Json is null"}; \ + } \ + } while (false) + } // namespace milvus