fix CPU version bug

pull/608/head
fishpenguin 2019-11-29 15:57:42 +08:00
parent fd304cf4b4
commit 6fcd2a13da
1 changed files with 16 additions and 16 deletions

View File

@ -116,29 +116,29 @@ NSG::Train(const DatasetPtr& dataset, const Config& config) {
}
// TODO(linxj): dev IndexFactory, support more IndexType
bool use_gpu = false;
#ifdef MILVUS_GPU_VERSION
use_gpu = true;
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(build_cfg->gpu_id);
if (temp_resource == nullptr)
use_gpu = false;
#endif
Graph knng;
if (use_gpu) {
auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
} else {
#ifdef MILVUS_GPU_VERSION
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(build_cfg->gpu_id);
if (temp_resource == nullptr) {
auto preprocess_index = std::make_shared<IVF>();
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
} else {
auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
}
#else
auto preprocess_index = std::make_shared<IVF>();
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
#endif
algo::BuildParams b_params;
b_params.candidate_pool_size = build_cfg->candidate_pool_size;