fix too many copies (#2660)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/2677/head
shengjun.li 2020-06-24 09:46:48 +08:00 committed by GitHub
parent ed012f2980
commit 0fb88f733a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 18 deletions

View File

@ -40,14 +40,11 @@ GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
idx_config.device = gpu_id_;
int32_t nlist = config[IndexParams::nlist];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::gpu::GpuIndexIVFFlat device_index(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float*)p_data);
auto device_index =
new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
index_.reset(device_index1);
index_.reset(device_index);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");

View File

@ -38,11 +38,8 @@ GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
config[IndexParams::m], config[IndexParams::nbits],
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
index_.reset(device_index1);
index_.reset(device_index);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");

View File

@ -45,19 +45,14 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
delete device_index;
delete build_index;
device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
index_.reset(device_index);
res_ = gpu_res;
gpu_mode_ = 2;
} else {
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
}
delete build_index;
}
VecIndexPtr