index created by std::make_shared (#5628)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/5658/head
shengjun.li 2021-06-05 16:34:51 +08:00 committed by GitHub
parent 9e7e1dcba4
commit 71c6eaeddb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 27 deletions

View File

@ -41,11 +41,9 @@ GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
idx_config.device = static_cast<int32_t>(gpu_id_);
int32_t nlist = config[IndexParams::nlist];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto device_index =
new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index->train(rows, (float*)p_data);
index_.reset(device_index);
index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>(gpu_res->faiss_res.get(), dim, nlist, metric_type,
idx_config);
index_->train(rows, (float*)p_data);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");

View File

@ -38,10 +38,9 @@ GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
int32_t m = config[IndexParams::m];
int32_t nbits = config[IndexParams::nbits];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto device_index =
new faiss::gpu::GpuIndexIVFPQ(gpu_res->faiss_res.get(), dim, nlist, m, nbits, metric_type, idx_config);
device_index->train(rows, (float*)p_data);
index_.reset(device_index);
index_ = std::make_shared<faiss::gpu::GpuIndexIVFPQ>(gpu_res->faiss_res.get(), dim, nlist, m, nbits,
metric_type, idx_config);
index_->train(rows, (float*)p_data);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");

View File

@ -36,10 +36,9 @@ GPUIVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
idx_config.device = static_cast<int32_t>(gpu_id_);
int32_t nlist = config[IndexParams::nlist];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto device_index = new faiss::gpu::GpuIndexIVFScalarQuantizer(
index_ = std::make_shared<faiss::gpu::GpuIndexIVFScalarQuantizer>(
gpu_res->faiss_res.get(), dim, nlist, faiss::QuantizerType::QT_8bit, metric_type, true, idx_config);
device_index->train(rows, (float*)p_data);
index_.reset(device_index);
index_->train(rows, (float*)p_data);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");

View File

@ -12,8 +12,7 @@
#include <faiss/IndexSQHybrid.h>
#include <faiss/gpu/GpuCloner.h>
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/index_factory.h>
#include <faiss/gpu/GpuIndexIVFSQHybrid.h>
#include <fiu-local.h>
#include <string>
#include <utility>
@ -34,28 +33,22 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GETTENSOR(dataset_ptr)
gpu_id_ = config[knowhere::meta::DEVICEID];
std::stringstream index_type;
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ8Hybrid";
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (gpu_res != nullptr) {
ResScope rs(gpu_res, gpu_id_, true);
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float*)p_data);
index_.reset(device_index);
faiss::gpu::GpuIndexIVFSQHybridConfig idx_config;
idx_config.device = static_cast<int32_t>(gpu_id_);
int32_t nlist = config[IndexParams::nlist];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_ = std::make_shared<faiss::gpu::GpuIndexIVFSQHybrid>(
gpu_res->faiss_res.get(), dim, nlist, faiss::QuantizerType::QT_8bit, metric_type, true, idx_config);
index_->train(rows, reinterpret_cast<const float*>(p_data));
res_ = gpu_res;
gpu_mode_ = 2;
index_mode_ = IndexMode::MODE_GPU;
} else {
delete build_index;
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
}
delete build_index;
}
VecIndexPtr