diff --git a/core/src/server/ValidationUtil.cpp b/core/src/server/ValidationUtil.cpp index cd50505a44..329d18c3c9 100644 --- a/core/src/server/ValidationUtil.cpp +++ b/core/src/server/ValidationUtil.cpp @@ -390,6 +390,28 @@ ValidateIndexMetricType(const std::string& metric_type, const std::string& index return Status::OK(); } +Status +ValidateSearchMetricType(const std::string& metric_type, bool is_binary) { + if (is_binary) { + // binary + if (metric_type == knowhere::Metric::L2 || metric_type == knowhere::Metric::IP) { + std::string msg = "Cannot search binary entities with index metric type " + metric_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } else { + // float + if (metric_type == knowhere::Metric::HAMMING || metric_type == knowhere::Metric::JACCARD || + metric_type == knowhere::Metric::TANIMOTO) { + std::string msg = "Cannot search float entities with index metric type " + metric_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } + + return Status::OK(); +} + Status ValidateSearchTopk(int64_t top_k) { if (top_k <= 0 || top_k > QUERY_MAX_TOPK) { diff --git a/core/src/server/ValidationUtil.h b/core/src/server/ValidationUtil.h index b0c0706593..79705aa302 100644 --- a/core/src/server/ValidationUtil.h +++ b/core/src/server/ValidationUtil.h @@ -44,6 +44,9 @@ ValidateSegmentRowCount(int64_t segment_row_count); extern Status ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type); +extern Status +ValidateSearchMetricType(const std::string& metric_type, bool is_binary); + extern Status ValidateSearchTopk(int64_t top_k); diff --git a/core/src/server/delivery/request/SearchReq.cpp b/core/src/server/delivery/request/SearchReq.cpp index 01308d5d79..1632d4c39b 100644 --- a/core/src/server/delivery/request/SearchReq.cpp +++ b/core/src/server/delivery/request/SearchReq.cpp @@ -79,6 +79,12 @@ SearchReq::OnExecute() { if (field->GetFtype() == (int)engine::DataType::VECTOR_FLOAT || field->GetFtype() == (int)engine::DataType::VECTOR_BINARY) { dimension = field->GetParams()[engine::PARAM_DIMENSION]; + // validate search metric type and DataType match + bool is_binary = (field->GetFtype() == (int)engine::DataType::VECTOR_FLOAT) ? false : true; + if (query_ptr_->metric_types.find(field->GetName()) != query_ptr_->metric_types.end()) { + auto metric_type = query_ptr_->metric_types.at(field->GetName()); + STATUS_CHECK(ValidateSearchMetricType(metric_type, is_binary)); + } } } diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py index 8a8e542a5b..871dbe931c 100644 --- a/tests/milvus_python_test/test_index.py +++ b/tests/milvus_python_test/test_index.py @@ -561,8 +561,9 @@ class TestIndexBinary: nq = get_nq ids = connect.insert(binary_collection, binary_entities) connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq) + query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq, metric_type="JACCARD") search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD") + logging.getLogger().info(search_param) res = connect.search(binary_collection, query, search_params=search_param) assert len(res) == nq