From 12dd738e46e53bc3d36b542e263e48e0a4afd722 Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Wed, 11 Sep 2019 20:16:31 +0800 Subject: [PATCH] MS-544 fix Former-commit-id: 2d9d4afb74ce2522b8cf5bf29ade982d3d6722ad --- .../src/knowhere/index/vector_index/gpu_ivf.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp b/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp index c1498c3305..e7159a7d44 100644 --- a/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp +++ b/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp @@ -26,17 +26,17 @@ namespace knowhere { IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) { auto nlist = config["nlist"].as(); - auto gpu_device = config.get_with_default("gpu_id", gpu_id_); + gpu_id_ = config.get_with_default("gpu_id", gpu_id_); auto metric_type = config["metric_type"].as_string() == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; GETTENSOR(dataset) - auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device); + auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); if (temp_resource != nullptr) { - ResScope rs(gpu_device, temp_resource); + ResScope rs(gpu_id_, temp_resource); faiss::gpu::GpuIndexIVFFlatConfig idx_config; - idx_config.device = gpu_device; + idx_config.device = gpu_id_; faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config); device_index.train(rows, (float *) p_data); @@ -204,7 +204,7 @@ VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) { IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) { auto nlist = config["nlist"].as(); auto nbits = config["nbits"].as(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16 - auto gpu_num = config.get_with_default("gpu_id", gpu_id_); + gpu_id_ = config.get_with_default("gpu_id", gpu_id_); auto metric_type = config["metric_type"].as_string() == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; @@ -214,10 +214,10 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) { index_type << "IVF" << nlist << "," << "SQ" << nbits; auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type); - auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_num); + auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); if (temp_resource != nullptr) { - ResScope rs(gpu_num, temp_resource ); - auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_num, build_index); + ResScope rs(gpu_id_, temp_resource ); + auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index); device_index->train(rows, (float *) p_data); std::shared_ptr host_index = nullptr;