Fix metric_type bug in search (#3155)

* Fix dsl test case nb bug

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix dsl test case bug

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add metric_type judge in search

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix test_search.py

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix test_db

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix search metric_type

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* ci retry

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

Co-authored-by: Wang Xiangyu <xy.wang@zilliz.com>
pull/3156/head^2
yukun 2020-08-06 19:08:42 +08:00 committed by GitHub
parent 38a3fe766d
commit cbba262442
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 0 deletions

View File

@ -269,6 +269,7 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context,
milvus::json conf = vector_param->extra_params; milvus::json conf = vector_param->extra_params;
conf[knowhere::meta::TOPK] = topk; conf[knowhere::meta::TOPK] = topk;
conf[knowhere::Metric::TYPE] = vector_param->metric_type;
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(vec_index->index_type()); auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(vec_index->index_type());
if (!adapter->CheckSearch(conf, vec_index->index_type(), vec_index->index_mode())) { if (!adapter->CheckSearch(conf, vec_index->index_type(), vec_index->index_mode())) {
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Illegal search params", "search", 0); LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Illegal search params", "search", 0);
@ -498,6 +499,7 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
std::unordered_map<std::string, DataType>& attr_type) { std::unordered_map<std::string, DataType>& attr_type) {
auto status = Status::OK(); auto status = Status::OK();
auto term_query_json = term_query->json_obj; auto term_query_json = term_query->json_obj;
JSON_NULL_CHECK(term_query_json);
auto term_it = term_query_json.begin(); auto term_it = term_query_json.begin();
if (term_it != term_query_json.end()) { if (term_it != term_query_json.end()) {
const std::string& field_name = term_it.key(); const std::string& field_name = term_it.key();
@ -578,6 +580,7 @@ ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map<std::string, Dat
auto status = Status::OK(); auto status = Status::OK();
auto range_query_json = range_query->json_obj; auto range_query_json = range_query->json_obj;
JSON_NULL_CHECK(range_query_json);
auto range_it = range_query_json.begin(); auto range_it = range_query_json.begin();
if (range_it != range_query_json.end()) { if (range_it != range_query_json.end()) {
const std::string& field_name = range_it.key(); const std::string& field_name = range_it.key();

View File

@ -79,6 +79,7 @@ struct VectorQuery {
milvus::json extra_params = {}; milvus::json extra_params = {};
int64_t topk; int64_t topk;
int64_t nq; int64_t nq;
std::string metric_type = "";
float boost; float boost;
VectorRecord query_vector; VectorRecord query_vector;
}; };

View File

@ -1568,6 +1568,7 @@ GrpcRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, query::Bool
auto leaf_query = std::make_shared<query::LeafQuery>(); auto leaf_query = std::make_shared<query::LeafQuery>();
auto term_query = std::make_shared<query::TermQuery>(); auto term_query = std::make_shared<query::TermQuery>();
nlohmann::json json_obj = json["term"]; nlohmann::json json_obj = json["term"];
JSON_NULL_CHECK(json_obj);
term_query->json_obj = json_obj; term_query->json_obj = json_obj;
nlohmann::json::iterator json_it = json_obj.begin(); nlohmann::json::iterator json_it = json_obj.begin();
field_name = json_it.key(); field_name = json_it.key();
@ -1578,6 +1579,7 @@ GrpcRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, query::Bool
auto leaf_query = std::make_shared<query::LeafQuery>(); auto leaf_query = std::make_shared<query::LeafQuery>();
auto range_query = std::make_shared<query::RangeQuery>(); auto range_query = std::make_shared<query::RangeQuery>();
nlohmann::json json_obj = json["range"]; nlohmann::json json_obj = json["range"];
JSON_NULL_CHECK(json_obj);
range_query->json_obj = json_obj; range_query->json_obj = json_obj;
nlohmann::json::iterator json_it = json_obj.begin(); nlohmann::json::iterator json_it = json_obj.begin();
field_name = json_it.key(); field_name = json_it.key();
@ -1587,9 +1589,12 @@ GrpcRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, query::Bool
} else if (json.contains("vector")) { } else if (json.contains("vector")) {
auto leaf_query = std::make_shared<query::LeafQuery>(); auto leaf_query = std::make_shared<query::LeafQuery>();
auto vector_json = json["vector"]; auto vector_json = json["vector"];
JSON_NULL_CHECK(vector_json);
leaf_query->vector_placeholder = vector_json.get<std::string>(); leaf_query->vector_placeholder = vector_json.get<std::string>();
query->AddLeafQuery(leaf_query); query->AddLeafQuery(leaf_query);
} else {
return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"};
} }
return status; return status;
} }
@ -1704,6 +1709,8 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
} }
vector_query->topk = topk; vector_query->topk = topk;
if (vector_json.contains("metric_type")) { if (vector_json.contains("metric_type")) {
std::string metric_type = vector_json["metric_type"];
vector_query->metric_type = metric_type;
query_ptr->metric_types.insert({field_name, vector_json["metric_type"]}); query_ptr->metric_types.insert({field_name, vector_json["metric_type"]});
} }
if (!vector_param_it.value()["params"].empty()) { if (!vector_param_it.value()["params"].empty()) {
@ -1722,6 +1729,7 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
} }
if (dsl_json.contains("bool")) { if (dsl_json.contains("bool")) {
auto boolean_query_json = dsl_json["bool"]; auto boolean_query_json = dsl_json["bool"];
JSON_NULL_CHECK(boolean_query_json);
status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr); status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr);
if (!status.ok()) { if (!status.ok()) {
return status; return status;

View File

@ -17,4 +17,11 @@ namespace milvus {
using json = nlohmann::json; using json = nlohmann::json;
#define JSON_NULL_CHECK(json) \
do { \
if (json.empty()) { \
return Status{SERVER_INVALID_ARGUMENT, "Json is null"}; \
} \
} while (false)
} // namespace milvus } // namespace milvus