diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 07d95fe950..152e59c4c6 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -130,7 +130,7 @@ Status DBImpl::PreloadTable(const std::string &table_id) { for(auto &day_files : files) { for (auto &file : day_files.second) { - ExecutionEnginePtr engine = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_); + ExecutionEnginePtr engine = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_, (MetricType)file.metric_type_, file.nlist_); if(engine == nullptr) { ENGINE_LOG_ERROR << "Invalid engine type"; return Status::Error("Invalid engine type"); @@ -411,7 +411,8 @@ Status DBImpl::MergeFiles(const std::string& table_id, const meta::DateT& date, //step 2: merge files ExecutionEnginePtr index = - EngineFactory::Build(table_file.dimension_, table_file.location_, (EngineType)table_file.engine_type_); + EngineFactory::Build(table_file.dimension_, table_file.location_, (EngineType)table_file.engine_type_, + (MetricType)table_file.metric_type_, table_file.nlist_); meta::TableFilesSchema updated; long index_size = 0; @@ -613,7 +614,9 @@ Status DBImpl::DropIndex(const std::string& table_id) { } Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { - ExecutionEnginePtr to_index = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_); + ExecutionEnginePtr to_index = + EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_, + (MetricType)file.metric_type_, file.nlist_); if(to_index == nullptr) { ENGINE_LOG_ERROR << "Invalid engine type"; return Status::Error("Invalid engine type"); diff --git a/cpp/src/db/engine/EngineFactory.cpp b/cpp/src/db/engine/EngineFactory.cpp index d09e9f8b97..a326d6a2c6 100644 --- a/cpp/src/db/engine/EngineFactory.cpp +++ b/cpp/src/db/engine/EngineFactory.cpp @@ -4,7 +4,6 @@ * Proprietary and confidential. ******************************************************************************/ #include "EngineFactory.h" -//#include "FaissExecutionEngine.h" #include "ExecutionEngineImpl.h" #include "db/Log.h" @@ -12,61 +11,25 @@ namespace zilliz { namespace milvus { namespace engine { -#if 0 ExecutionEnginePtr EngineFactory::Build(uint16_t dimension, const std::string &location, - EngineType type) { + EngineType index_type, + MetricType metric_type, + int32_t nlist) { - ExecutionEnginePtr execution_engine_ptr; - - switch (type) { - case EngineType::FAISS_IDMAP: { - execution_engine_ptr = - ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, BUILD_INDEX_TYPE_IDMAP, "IDMap,Flat")); - break; - } - - case EngineType::FAISS_IVFFLAT_GPU: { - execution_engine_ptr = - ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, BUILD_INDEX_TYPE_IVF, "IDMap,Flat")); - break; - } - - case EngineType::FAISS_IVFSQ8: { - execution_engine_ptr = - ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, BUILD_INDEX_TYPE_IVFSQ8, "IDMap,Flat")); - break; - } - - default: { - ENGINE_LOG_ERROR << "Unsupported engine type"; - return nullptr; - } - } - - execution_engine_ptr->Init(); - return execution_engine_ptr; -} -#else -ExecutionEnginePtr -EngineFactory::Build(uint16_t dimension, - const std::string &location, - EngineType type) { - - if(type == EngineType::INVALID) { + if(index_type == EngineType::INVALID) { ENGINE_LOG_ERROR << "Unsupported engine type"; return nullptr; } - ENGINE_LOG_DEBUG << "EngineFactory EngineTypee: " << int(type); + ENGINE_LOG_DEBUG << "EngineFactory EngineTypee: " << (int)index_type; ExecutionEnginePtr execution_engine_ptr = - std::make_shared(dimension, location, type); + std::make_shared(dimension, location, index_type, metric_type, nlist); execution_engine_ptr->Init(); return execution_engine_ptr; } -#endif } } diff --git a/cpp/src/db/engine/EngineFactory.h b/cpp/src/db/engine/EngineFactory.h index d8c35468da..7f2047af9b 100644 --- a/cpp/src/db/engine/EngineFactory.h +++ b/cpp/src/db/engine/EngineFactory.h @@ -16,7 +16,9 @@ class EngineFactory { public: static ExecutionEnginePtr Build(uint16_t dimension, const std::string& location, - EngineType type); + EngineType index_type, + MetricType metric_type, + int32_t nlist); }; } diff --git a/cpp/src/db/engine/ExecutionEngine.h b/cpp/src/db/engine/ExecutionEngine.h index 0f2cf42b22..e6b832db0d 100644 --- a/cpp/src/db/engine/ExecutionEngine.h +++ b/cpp/src/db/engine/ExecutionEngine.h @@ -65,6 +65,10 @@ public: virtual Status Cache() = 0; virtual Status Init() = 0; + + virtual EngineType IndexEngineType() const = 0; + + virtual MetricType IndexMetricType() const = 0; }; using ExecutionEnginePtr = std::shared_ptr; diff --git a/cpp/src/db/engine/ExecutionEngineImpl.cpp b/cpp/src/db/engine/ExecutionEngineImpl.cpp index dd38369832..a7188d5b4e 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.cpp +++ b/cpp/src/db/engine/ExecutionEngineImpl.cpp @@ -5,7 +5,6 @@ ******************************************************************************/ #include -#include "src/server/ServerConfig.h" #include "src/metrics/Metrics.h" #include "db/Log.h" #include "utils/CommonUtil.h" @@ -22,26 +21,23 @@ namespace zilliz { namespace milvus { namespace engine { -namespace { -std::string GetMetricType() { - server::ServerConfig &config = server::ServerConfig::GetInstance(); - server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE); - return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2"); -} -} - ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string &location, - EngineType type) - : location_(location), dim(dimension), build_type(type) { - current_type = EngineType::FAISS_IDMAP; + EngineType index_type, + MetricType metric_type, + int32_t nlist) + : location_(location), + dim_(dimension), + index_type_(index_type), + metric_type_(metric_type), + nlist_(nlist) { index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); if (!index_) throw Exception("Create Empty VecIndex"); Config build_cfg; build_cfg["dim"] = dimension; - build_cfg["metric_type"] = GetMetricType(); + build_cfg["metric_type"] = (metric_type_ == MetricType::IP) ? "IP" : "L2"; AutoGenParams(index_->GetType(), 0, build_cfg); auto ec = std::static_pointer_cast(index_)->Build(build_cfg); if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); } @@ -49,9 +45,14 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string &location, - EngineType type) - : index_(std::move(index)), location_(location), build_type(type) { - current_type = type; + EngineType index_type, + MetricType metric_type, + int32_t nlist) + : index_(std::move(index)), + location_(location), + index_type_(index_type), + metric_type_(metric_type), + nlist_(nlist) { } VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) { @@ -204,15 +205,15 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) { ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_; auto from_index = std::dynamic_pointer_cast(index_); - auto to_index = CreatetVecIndex(build_type); + auto to_index = CreatetVecIndex(index_type_); if (!to_index) { throw Exception("Create Empty VecIndex"); } Config build_cfg; build_cfg["dim"] = Dimension(); - build_cfg["metric_type"] = GetMetricType(); - build_cfg["gpu_id"] = gpu_num; + build_cfg["metric_type"] = (metric_type_ == MetricType::IP) ? "IP" : "L2"; + build_cfg["gpu_id"] = gpu_num_; build_cfg["nlist"] = nlist_; AutoGenParams(to_index->GetType(), Count(), build_cfg); @@ -222,7 +223,7 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) { build_cfg); if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); } - return std::make_shared(to_index, location, build_type); + return std::make_shared(to_index, location, index_type_, metric_type_, nlist_); } Status ExecutionEngineImpl::Search(long n, @@ -251,16 +252,7 @@ Status ExecutionEngineImpl::Init() { using namespace zilliz::milvus::server; ServerConfig &config = ServerConfig::GetInstance(); ConfigNode server_config = config.GetConfig(CONFIG_SERVER); - gpu_num = server_config.GetInt32Value("gpu_index", 0); - - switch (build_type) { - case EngineType::FAISS_IVFSQ8: - case EngineType::FAISS_IVFFLAT: { - ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE); - nlist_ = engine_config.GetInt32Value(CONFIG_NLIST, 16384); - break; - } - } + gpu_num_ = server_config.GetInt32Value("gpu_index", 0); return Status::OK(); } diff --git a/cpp/src/db/engine/ExecutionEngineImpl.h b/cpp/src/db/engine/ExecutionEngineImpl.h index 948719310c..16f4707c6a 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.h +++ b/cpp/src/db/engine/ExecutionEngineImpl.h @@ -22,11 +22,15 @@ public: ExecutionEngineImpl(uint16_t dimension, const std::string &location, - EngineType type); + EngineType index_type, + MetricType metric_type, + int32_t nlist); ExecutionEngineImpl(VecIndexPtr index, const std::string &location, - EngineType type); + EngineType index_type, + MetricType metric_type, + int32_t nlist); Status AddWithIds(long n, const float *xdata, const long *xids) override; @@ -61,6 +65,10 @@ public: Status Init() override; + EngineType IndexEngineType() const override { return index_type_; } + + MetricType IndexMetricType() const override { return metric_type_; } + private: VecIndexPtr CreatetVecIndex(EngineType type); @@ -68,14 +76,14 @@ private: protected: VecIndexPtr index_ = nullptr; - EngineType build_type; - EngineType current_type; + EngineType index_type_; + MetricType metric_type_; - int64_t dim; + int64_t dim_; std::string location_; - size_t nlist_ = 0; - int64_t gpu_num = 0; + int32_t nlist_ = 0; + int64_t gpu_num_ = 0; }; diff --git a/cpp/src/db/insert/MemTableFile.cpp b/cpp/src/db/insert/MemTableFile.cpp index 672bd50b00..f8f79c8618 100644 --- a/cpp/src/db/insert/MemTableFile.cpp +++ b/cpp/src/db/insert/MemTableFile.cpp @@ -23,7 +23,9 @@ MemTableFile::MemTableFile(const std::string &table_id, if (status.ok()) { execution_engine_ = EngineFactory::Build(table_file_schema_.dimension_, table_file_schema_.location_, - (EngineType) table_file_schema_.engine_type_); + (EngineType) table_file_schema_.engine_type_, + (MetricType)table_file_schema_.metric_type_, + table_file_schema_.nlist_); } } diff --git a/cpp/src/db/meta/MetaTypes.h b/cpp/src/db/meta/MetaTypes.h index 852a416c88..b0c3376593 100644 --- a/cpp/src/db/meta/MetaTypes.h +++ b/cpp/src/db/meta/MetaTypes.h @@ -17,6 +17,11 @@ namespace milvus { namespace engine { namespace meta { +constexpr int32_t DEFAULT_ENGINE_TYPE = (int)EngineType::FAISS_IDMAP; +constexpr int32_t DEFAULT_NLIST = 16384; +constexpr int32_t DEFAULT_INDEX_FILE_SIZE = 1024*ONE_MB; +constexpr int32_t DEFAULT_METRIC_TYPE = (int)MetricType::L2; + typedef int DateT; const DateT EmptyDate = -1; typedef std::vector DatesT; @@ -32,10 +37,10 @@ struct TableSchema { int32_t state_ = (int)NORMAL; uint16_t dimension_ = 0; int64_t created_on_ = 0; - int32_t engine_type_ = (int)EngineType::FAISS_IDMAP; - int32_t nlist_ = 16384; - int32_t index_file_size_ = 1024*ONE_MB; - int32_t metric_type_ = (int)MetricType::L2; + int32_t engine_type_ = DEFAULT_ENGINE_TYPE; + int32_t nlist_ = DEFAULT_NLIST; + int32_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE; + int32_t metric_type_ = DEFAULT_METRIC_TYPE; }; // TableSchema struct TableFileSchema { @@ -52,7 +57,6 @@ struct TableFileSchema { size_t id_ = 0; std::string table_id_; - int32_t engine_type_ = (int)EngineType::FAISS_IDMAP; std::string file_id_; int32_t file_type_ = NEW; size_t file_size_ = 0; @@ -62,6 +66,9 @@ struct TableFileSchema { std::string location_; int64_t updated_time_ = 0; int64_t created_on_ = 0; + int32_t engine_type_ = DEFAULT_ENGINE_TYPE; + int32_t nlist_ = DEFAULT_NLIST; //not persist to meta + int32_t metric_type_ = DEFAULT_METRIC_TYPE; //not persist to meta }; // TableFileSchema typedef std::vector TableFilesSchema; diff --git a/cpp/src/db/meta/MySQLMetaImpl.cpp b/cpp/src/db/meta/MySQLMetaImpl.cpp index a243630a57..954c498f7f 100644 --- a/cpp/src/db/meta/MySQLMetaImpl.cpp +++ b/cpp/src/db/meta/MySQLMetaImpl.cpp @@ -747,7 +747,7 @@ Status MySQLMetaImpl::AllTables(std::vector &table_schema_array) { } Query allTablesQuery = connectionPtr->query(); - allTablesQuery << "SELECT id, table_id, dimension, engine_type " << + allTablesQuery << "SELECT id, table_id, dimension, engine_type, nlist, index_file_size, metric_type " << "FROM Tables " << "WHERE state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; @@ -769,6 +769,12 @@ Status MySQLMetaImpl::AllTables(std::vector &table_schema_array) { table_schema.engine_type_ = resRow["engine_type"]; + table_schema.nlist_ = resRow["nlist"]; + + table_schema.index_file_size_ = resRow["index_file_size"]; + + table_schema.metric_type_ = resRow["metric_type"]; + table_schema_array.emplace_back(table_schema); } } catch (const BadQuery &er) { @@ -805,6 +811,8 @@ Status MySQLMetaImpl::CreateTableFile(TableFileSchema &file_schema) { file_schema.created_on_ = utils::GetMicroSecTimeStamp(); file_schema.updated_time_ = file_schema.created_on_; file_schema.engine_type_ = table_schema.engine_type_; + file_schema.nlist_ = table_schema.nlist_; + file_schema.metric_type_ = table_schema.metric_type_; utils::GetTableFilePath(options_, file_schema); std::string id = "NULL"; //auto-increment @@ -918,6 +926,8 @@ Status MySQLMetaImpl::FilesToIndex(TableFilesSchema &files) { groups[table_file.table_id_] = table_schema; } + table_file.metric_type_ = groups[table_file.table_id_].metric_type_; + table_file.nlist_ = groups[table_file.table_id_].nlist_; table_file.dimension_ = groups[table_file.table_id_].dimension_; utils::GetTableFilePath(options_, table_file); @@ -1010,6 +1020,10 @@ Status MySQLMetaImpl::FilesToSearch(const std::string &table_id, table_file.engine_type_ = resRow["engine_type"]; + table_file.metric_type_ = table_schema.metric_type_; + + table_file.nlist_ = table_schema.nlist_; + std::string file_id; resRow["file_id"].to_string(file_id); table_file.file_id_ = file_id; @@ -1118,6 +1132,10 @@ Status MySQLMetaImpl::FilesToSearch(const std::string &table_id, table_file.engine_type_ = resRow["engine_type"]; + table_file.metric_type_ = table_schema.metric_type_; + + table_file.nlist_ = table_schema.nlist_; + std::string file_id; resRow["file_id"].to_string(file_id); table_file.file_id_ = file_id; @@ -1214,6 +1232,10 @@ Status MySQLMetaImpl::FilesToMerge(const std::string &table_id, table_file.engine_type_ = resRow["engine_type"]; + table_file.metric_type_ = table_schema.metric_type_; + + table_file.nlist_ = table_schema.nlist_; + table_file.created_on_ = resRow["created_on"]; table_file.dimension_ = table_schema.dimension_; @@ -1293,6 +1315,10 @@ Status MySQLMetaImpl::GetTableFiles(const std::string &table_id, file_schema.engine_type_ = resRow["engine_type"]; + file_schema.metric_type_ = table_schema.metric_type_; + + file_schema.nlist_ = table_schema.nlist_; + std::string file_id; resRow["file_id"].to_string(file_id); file_schema.file_id_ = file_id; diff --git a/cpp/src/db/meta/SqliteMetaImpl.cpp b/cpp/src/db/meta/SqliteMetaImpl.cpp index f93e421698..b4859473ef 100644 --- a/cpp/src/db/meta/SqliteMetaImpl.cpp +++ b/cpp/src/db/meta/SqliteMetaImpl.cpp @@ -218,22 +218,15 @@ Status SqliteMetaImpl::DeleteTable(const std::string& table_id) { std::lock_guard meta_lock(meta_mutex_); //soft delete table - auto tables = ConnectorPtr->select(columns(&TableSchema::id_, - &TableSchema::dimension_, - &TableSchema::engine_type_, - &TableSchema::created_on_), - where(c(&TableSchema::table_id_) == table_id)); - for (auto &table : tables) { - TableSchema table_schema; - table_schema.table_id_ = table_id; - table_schema.state_ = (int)TableSchema::TO_DELETE; - table_schema.id_ = std::get<0>(table); - table_schema.dimension_ = std::get<1>(table); - table_schema.engine_type_ = std::get<2>(table); - table_schema.created_on_ = std::get<3>(table); + ConnectorPtr->update_all( + set( + c(&TableSchema::state_) = (int) TableSchema::TO_DELETE + ), + where( + c(&TableSchema::table_id_) == table_id and + c(&TableSchema::state_) != (int) TableSchema::TO_DELETE + )); - ConnectorPtr->update(table_schema); - } } catch (std::exception &e) { return HandleException("Encounter exception when delete table", e); } @@ -493,16 +486,24 @@ Status SqliteMetaImpl::AllTables(std::vector& table_schema_array) { MetricCollector metric; auto selected = ConnectorPtr->select(columns(&TableSchema::id_, - &TableSchema::table_id_, - &TableSchema::dimension_, - &TableSchema::engine_type_), + &TableSchema::table_id_, + &TableSchema::dimension_, + &TableSchema::created_on_, + &TableSchema::engine_type_, + &TableSchema::nlist_, + &TableSchema::index_file_size_, + &TableSchema::metric_type_), where(c(&TableSchema::state_) != (int)TableSchema::TO_DELETE)); for (auto &table : selected) { TableSchema schema; schema.id_ = std::get<0>(table); schema.table_id_ = std::get<1>(table); - schema.dimension_ = std::get<2>(table); - schema.engine_type_ = std::get<3>(table); + schema.created_on_ = std::get<2>(table); + schema.dimension_ = std::get<3>(table); + schema.engine_type_ = std::get<4>(table); + schema.nlist_ = std::get<5>(table); + schema.index_file_size_ = std::get<6>(table); + schema.metric_type_ = std::get<7>(table); table_schema_array.emplace_back(schema); } @@ -535,6 +536,8 @@ Status SqliteMetaImpl::CreateTableFile(TableFileSchema &file_schema) { file_schema.created_on_ = utils::GetMicroSecTimeStamp(); file_schema.updated_time_ = file_schema.created_on_; file_schema.engine_type_ = table_schema.engine_type_; + file_schema.nlist_ = table_schema.nlist_; + file_schema.metric_type_ = table_schema.metric_type_; //multi-threads call sqlite update may get exception('bad logic', etc), so we add a lock here std::lock_guard meta_lock(meta_mutex_); @@ -594,6 +597,8 @@ Status SqliteMetaImpl::FilesToIndex(TableFilesSchema &files) { } groups[table_file.table_id_] = table_schema; } + table_file.metric_type_ = groups[table_file.table_id_].metric_type_; + table_file.nlist_ = groups[table_file.table_id_].nlist_; table_file.dimension_ = groups[table_file.table_id_].dimension_; files.push_back(table_file); } @@ -644,6 +649,8 @@ Status SqliteMetaImpl::FilesToSearch(const std::string &table_id, table_file.row_count_ = std::get<5>(file); table_file.date_ = std::get<6>(file); table_file.engine_type_ = std::get<7>(file); + table_file.metric_type_ = table_schema.metric_type_; + table_file.nlist_ = table_schema.nlist_; table_file.dimension_ = table_schema.dimension_; utils::GetTableFilePath(options_, table_file); auto dateItr = files.find(table_file.date_); @@ -685,6 +692,8 @@ Status SqliteMetaImpl::FilesToSearch(const std::string &table_id, table_file.row_count_ = std::get<5>(file); table_file.date_ = std::get<6>(file); table_file.engine_type_ = std::get<7>(file); + table_file.metric_type_ = table_schema.metric_type_; + table_file.nlist_ = table_schema.nlist_; table_file.dimension_ = table_schema.dimension_; utils::GetTableFilePath(options_, table_file); auto dateItr = files.find(table_file.date_); @@ -762,6 +771,8 @@ Status SqliteMetaImpl::FilesToSearch(const std::string &table_id, table_file.date_ = std::get<6>(file); table_file.engine_type_ = std::get<7>(file); table_file.dimension_ = table_schema.dimension_; + table_file.metric_type_ = table_schema.metric_type_; + table_file.nlist_ = table_schema.nlist_; utils::GetTableFilePath(options_, table_file); auto dateItr = files.find(table_file.date_); if (dateItr == files.end()) { @@ -820,6 +831,8 @@ Status SqliteMetaImpl::FilesToMerge(const std::string &table_id, table_file.date_ = std::get<6>(file); table_file.created_on_ = std::get<7>(file); table_file.dimension_ = table_schema.dimension_; + table_file.metric_type_ = table_schema.metric_type_; + table_file.nlist_ = table_schema.nlist_; utils::GetTableFilePath(options_, table_file); auto dateItr = files.find(table_file.date_); if (dateItr == files.end()) { @@ -868,8 +881,11 @@ Status SqliteMetaImpl::GetTableFiles(const std::string& table_id, file_schema.row_count_ = std::get<4>(file); file_schema.date_ = std::get<5>(file); file_schema.engine_type_ = std::get<6>(file); + file_schema.metric_type_ = table_schema.metric_type_; + file_schema.nlist_ = table_schema.nlist_; file_schema.created_on_ = std::get<7>(file); file_schema.dimension_ = table_schema.dimension_; + utils::GetTableFilePath(options_, file_schema); table_files.emplace_back(file_schema); diff --git a/cpp/src/db/scheduler/task/IndexLoadTask.cpp b/cpp/src/db/scheduler/task/IndexLoadTask.cpp index 4b242f230d..561bf07f13 100644 --- a/cpp/src/db/scheduler/task/IndexLoadTask.cpp +++ b/cpp/src/db/scheduler/task/IndexLoadTask.cpp @@ -45,7 +45,9 @@ std::shared_ptr IndexLoadTask::Execute() { //step 1: load index ExecutionEnginePtr index_ptr = EngineFactory::Build(file_->dimension_, file_->location_, - (EngineType)file_->engine_type_); + (EngineType)file_->engine_type_, + (MetricType)file_->metric_type_, + file_->nlist_); try { index_ptr->Load(); @@ -75,7 +77,7 @@ std::shared_ptr IndexLoadTask::Execute() { //step 2: return search task for later execution SearchTaskPtr task_ptr = std::make_shared(); task_ptr->index_id_ = file_->id_; - task_ptr->index_type_ = file_->file_type_; + task_ptr->file_type_ = file_->file_type_; task_ptr->index_engine_ = index_ptr; task_ptr->search_contexts_.swap(search_contexts_); return std::static_pointer_cast(task_ptr); diff --git a/cpp/src/db/scheduler/task/SearchTask.cpp b/cpp/src/db/scheduler/task/SearchTask.cpp index fd9d679d5e..4e7c0f4611 100644 --- a/cpp/src/db/scheduler/task/SearchTask.cpp +++ b/cpp/src/db/scheduler/task/SearchTask.cpp @@ -76,20 +76,10 @@ void CollectDurationMetrics(int index_type, double total_time) { } } -std::string GetMetricType() { - server::ServerConfig &config = server::ServerConfig::GetInstance(); - server::ConfigNode& engine_config = config.GetConfig(server::CONFIG_ENGINE); - return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2"); -} - } SearchTask::SearchTask() : IScheduleTask(ScheduleTaskType::kSearch) { - std::string metric_type = GetMetricType(); - if(metric_type != "L2") { - metric_l2 = false; - } } std::shared_ptr SearchTask::Execute() { @@ -104,6 +94,8 @@ std::shared_ptr SearchTask::Execute() { auto start_time = METRICS_NOW_TIME; + bool metric_l2 = (index_engine_->IndexMetricType() == MetricType::L2); + std::vector output_ids; std::vector output_distence; for(auto& context : search_contexts_) { @@ -147,7 +139,7 @@ std::shared_ptr SearchTask::Execute() { auto end_time = METRICS_NOW_TIME; auto total_time = METRICS_MICROSECONDS(start_time, end_time); - CollectDurationMetrics(index_type_, total_time); + CollectDurationMetrics(file_type_, total_time); rc.ElapseFromBegin("totally cost"); diff --git a/cpp/src/db/scheduler/task/SearchTask.h b/cpp/src/db/scheduler/task/SearchTask.h index 034b53d4dc..6010046446 100644 --- a/cpp/src/db/scheduler/task/SearchTask.h +++ b/cpp/src/db/scheduler/task/SearchTask.h @@ -37,10 +37,9 @@ public: public: size_t index_id_ = 0; - int index_type_ = 0; //for metrics + int file_type_ = 0; //for metrics ExecutionEnginePtr index_engine_; std::vector search_contexts_; - bool metric_l2 = true; }; using SearchTaskPtr = std::shared_ptr; diff --git a/cpp/unittest/db/mem_test.cpp b/cpp/unittest/db/mem_test.cpp index fb4796a34a..77a83abc4e 100644 --- a/cpp/unittest/db/mem_test.cpp +++ b/cpp/unittest/db/mem_test.cpp @@ -65,9 +65,13 @@ TEST_F(NewMemManagerTest, VECTOR_SOURCE_TEST) { engine::VectorSource source(n, vectors.data()); size_t num_vectors_added; - engine::ExecutionEnginePtr execution_engine_ = engine::EngineFactory::Build(table_file_schema.dimension_, - table_file_schema.location_, - (engine::EngineType) table_file_schema.engine_type_); + engine::ExecutionEnginePtr execution_engine_ = + engine::EngineFactory::Build(table_file_schema.dimension_, + table_file_schema.location_, + (engine::EngineType) table_file_schema.engine_type_, + (engine::MetricType)table_file_schema.metric_type_, + table_schema.nlist_); + engine::IDNumbers vector_ids; status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added, vector_ids); ASSERT_TRUE(status.ok());