mirror of https://github.com/milvus-io/milvus.git
parent
5787a2af4d
commit
273863f54d
|
@ -123,25 +123,8 @@ XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, Segmen
|
|||
milvus::json json_params;
|
||||
if (!file_->index_params_.empty()) {
|
||||
json_params = milvus::json::parse(file_->index_params_);
|
||||
if (json_params.contains(knowhere::Metric::TYPE) &&
|
||||
(engine_type == EngineType::FAISS_BIN_IDMAP || engine_type == EngineType::FAISS_IDMAP))
|
||||
ascending_reduce = json_params[knowhere::Metric::TYPE] != static_cast<int>(MetricType::IP);
|
||||
}
|
||||
// if (auto job = job_.lock()) {
|
||||
// auto search_job = std::static_pointer_cast<scheduler::SearchJob>(job);
|
||||
// query::GeneralQueryPtr general_query = search_job->general_query();
|
||||
// if (general_query != nullptr) {
|
||||
// std::unordered_map<std::string, engine::DataType> types;
|
||||
// auto attr_type = search_job->attr_type();
|
||||
// auto type_it = attr_type.begin();
|
||||
// for (; type_it != attr_type.end(); type_it++) {
|
||||
// types.insert(std::make_pair(type_it->first, (engine::DataType)(type_it->second)));
|
||||
// }
|
||||
// index_engine_ =
|
||||
// EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
|
||||
// (MetricType)file_->metric_type_, types, json_params);
|
||||
// }
|
||||
// }
|
||||
|
||||
index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
|
||||
(MetricType)file_->metric_type_, json_params);
|
||||
}
|
||||
|
@ -195,8 +178,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
|
|||
|
||||
if (auto job = job_.lock()) {
|
||||
auto search_job = std::static_pointer_cast<scheduler::SearchJob>(job);
|
||||
search_job->SearchDone(file_->id_);
|
||||
search_job->GetStatus() = s;
|
||||
search_job->SearchDone(file_->id_);
|
||||
}
|
||||
|
||||
return;
|
||||
|
@ -249,6 +232,44 @@ XSearchTask::Execute() {
|
|||
const milvus::json& extra_params = search_job->extra_params();
|
||||
const engine::VectorsData& vectors = search_job->vectors();
|
||||
|
||||
auto engine_type = index_engine_->IndexEngineType();
|
||||
if (engine_type == EngineType::FAISS_IDMAP || engine_type == EngineType::FAISS_BIN_IDMAP) {
|
||||
// allow to assign a metric type in IDMAP and BIN_IDMAP
|
||||
if (extra_params.contains(knowhere::Metric::TYPE)) {
|
||||
auto metric_type = extra_params[knowhere::Metric::TYPE].get<int64_t>();
|
||||
LOG_ENGINE_DEBUG_ << "User's metric type " << metric_type;
|
||||
|
||||
auto Illegal_Metric_Type = [&]() {
|
||||
std::string msg = "Illegal metric type" + metric_type;
|
||||
search_job->GetStatus() = Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
search_job->SearchDone(index_id_);
|
||||
};
|
||||
|
||||
if (engine_type == EngineType::FAISS_IDMAP) {
|
||||
if (metric_type == static_cast<int64_t>(MetricType::IP)) {
|
||||
ascending_reduce = false;
|
||||
} else if (metric_type == static_cast<int64_t>(MetricType::L2)) {
|
||||
// do nothing
|
||||
} else {
|
||||
Illegal_Metric_Type();
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// FAISS_BIN_IDMAP
|
||||
if (metric_type == static_cast<int64_t>(MetricType::HAMMING) ||
|
||||
metric_type == static_cast<int64_t>(MetricType::JACCARD) ||
|
||||
metric_type == static_cast<int64_t>(MetricType::TANIMOTO) ||
|
||||
metric_type == static_cast<int64_t>(MetricType::SUBSTRUCTURE) ||
|
||||
metric_type == static_cast<int64_t>(MetricType::SUPERSTRUCTURE)) {
|
||||
// do nothing
|
||||
} else {
|
||||
Illegal_Metric_Type();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output_ids.resize(topk * nq);
|
||||
output_distance.resize(topk * nq);
|
||||
std::string hdr =
|
||||
|
|
Loading…
Reference in New Issue