fix users' metric type (#3568)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3609/head
shengjun.li 2020-09-03 10:21:26 +08:00 committed by GitHub
parent 5787a2af4d
commit 273863f54d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 40 additions and 19 deletions

View File

@ -123,25 +123,8 @@ XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, Segmen
milvus::json json_params;
if (!file_->index_params_.empty()) {
json_params = milvus::json::parse(file_->index_params_);
if (json_params.contains(knowhere::Metric::TYPE) &&
(engine_type == EngineType::FAISS_BIN_IDMAP || engine_type == EngineType::FAISS_IDMAP))
ascending_reduce = json_params[knowhere::Metric::TYPE] != static_cast<int>(MetricType::IP);
}
// if (auto job = job_.lock()) {
// auto search_job = std::static_pointer_cast<scheduler::SearchJob>(job);
// query::GeneralQueryPtr general_query = search_job->general_query();
// if (general_query != nullptr) {
// std::unordered_map<std::string, engine::DataType> types;
// auto attr_type = search_job->attr_type();
// auto type_it = attr_type.begin();
// for (; type_it != attr_type.end(); type_it++) {
// types.insert(std::make_pair(type_it->first, (engine::DataType)(type_it->second)));
// }
// index_engine_ =
// EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
// (MetricType)file_->metric_type_, types, json_params);
// }
// }
index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
(MetricType)file_->metric_type_, json_params);
}
@ -195,8 +178,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
if (auto job = job_.lock()) {
auto search_job = std::static_pointer_cast<scheduler::SearchJob>(job);
search_job->SearchDone(file_->id_);
search_job->GetStatus() = s;
search_job->SearchDone(file_->id_);
}
return;
@ -249,6 +232,44 @@ XSearchTask::Execute() {
const milvus::json& extra_params = search_job->extra_params();
const engine::VectorsData& vectors = search_job->vectors();
auto engine_type = index_engine_->IndexEngineType();
if (engine_type == EngineType::FAISS_IDMAP || engine_type == EngineType::FAISS_BIN_IDMAP) {
// allow to assign a metric type in IDMAP and BIN_IDMAP
if (extra_params.contains(knowhere::Metric::TYPE)) {
auto metric_type = extra_params[knowhere::Metric::TYPE].get<int64_t>();
LOG_ENGINE_DEBUG_ << "User's metric type " << metric_type;
auto Illegal_Metric_Type = [&]() {
std::string msg = "Illegal metric type" + metric_type;
search_job->GetStatus() = Status(SERVER_INVALID_ARGUMENT, msg);
search_job->SearchDone(index_id_);
};
if (engine_type == EngineType::FAISS_IDMAP) {
if (metric_type == static_cast<int64_t>(MetricType::IP)) {
ascending_reduce = false;
} else if (metric_type == static_cast<int64_t>(MetricType::L2)) {
// do nothing
} else {
Illegal_Metric_Type();
return;
}
} else {
// FAISS_BIN_IDMAP
if (metric_type == static_cast<int64_t>(MetricType::HAMMING) ||
metric_type == static_cast<int64_t>(MetricType::JACCARD) ||
metric_type == static_cast<int64_t>(MetricType::TANIMOTO) ||
metric_type == static_cast<int64_t>(MetricType::SUBSTRUCTURE) ||
metric_type == static_cast<int64_t>(MetricType::SUPERSTRUCTURE)) {
// do nothing
} else {
Illegal_Metric_Type();
return;
}
}
}
}
output_ids.resize(topk * nq);
output_distance.resize(topk * nq);
std::string hdr =