mirror of https://github.com/milvus-io/milvus.git
* add check Signed-off-by: sahuang <xiaohai.xu@zilliz.com>pull/3367/head
parent
1938d6b769
commit
b437639b7a
|
@ -390,6 +390,28 @@ ValidateIndexMetricType(const std::string& metric_type, const std::string& index
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ValidateSearchMetricType(const std::string& metric_type, bool is_binary) {
|
||||
if (is_binary) {
|
||||
// binary
|
||||
if (metric_type == knowhere::Metric::L2 || metric_type == knowhere::Metric::IP) {
|
||||
std::string msg = "Cannot search binary entities with index metric type " + metric_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
} else {
|
||||
// float
|
||||
if (metric_type == knowhere::Metric::HAMMING || metric_type == knowhere::Metric::JACCARD ||
|
||||
metric_type == knowhere::Metric::TANIMOTO) {
|
||||
std::string msg = "Cannot search float entities with index metric type " + metric_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ValidateSearchTopk(int64_t top_k) {
|
||||
if (top_k <= 0 || top_k > QUERY_MAX_TOPK) {
|
||||
|
|
|
@ -44,6 +44,9 @@ ValidateSegmentRowCount(int64_t segment_row_count);
|
|||
extern Status
|
||||
ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type);
|
||||
|
||||
extern Status
|
||||
ValidateSearchMetricType(const std::string& metric_type, bool is_binary);
|
||||
|
||||
extern Status
|
||||
ValidateSearchTopk(int64_t top_k);
|
||||
|
||||
|
|
|
@ -79,6 +79,12 @@ SearchReq::OnExecute() {
|
|||
if (field->GetFtype() == (int)engine::DataType::VECTOR_FLOAT ||
|
||||
field->GetFtype() == (int)engine::DataType::VECTOR_BINARY) {
|
||||
dimension = field->GetParams()[engine::PARAM_DIMENSION];
|
||||
// validate search metric type and DataType match
|
||||
bool is_binary = (field->GetFtype() == (int)engine::DataType::VECTOR_FLOAT) ? false : true;
|
||||
if (query_ptr_->metric_types.find(field->GetName()) != query_ptr_->metric_types.end()) {
|
||||
auto metric_type = query_ptr_->metric_types.at(field->GetName());
|
||||
STATUS_CHECK(ValidateSearchMetricType(metric_type, is_binary));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -561,8 +561,9 @@ class TestIndexBinary:
|
|||
nq = get_nq
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
|
||||
query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq, metric_type="JACCARD")
|
||||
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
|
||||
logging.getLogger().info(search_param)
|
||||
res = connect.search(binary_collection, query, search_params=search_param)
|
||||
assert len(res) == nq
|
||||
|
||||
|
|
Loading…
Reference in New Issue