mirror of https://github.com/milvus-io/milvus.git
parent
ed012f2980
commit
0fb88f733a
|
@ -40,14 +40,11 @@ GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
idx_config.device = gpu_id_;
|
||||
int32_t nlist = config[IndexParams::nlist];
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
||||
device_index.train(rows, (float*)p_data);
|
||||
auto device_index =
|
||||
new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
||||
device_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
|
||||
|
||||
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
||||
index_.reset(device_index1);
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");
|
||||
|
|
|
@ -38,11 +38,8 @@ GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
config[IndexParams::m], config[IndexParams::nbits],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
||||
device_index->train(rows, (float*)p_data);
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
|
||||
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
||||
index_.reset(device_index1);
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");
|
||||
|
|
|
@ -45,19 +45,14 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
|
||||
device_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
|
||||
delete device_index;
|
||||
delete build_index;
|
||||
|
||||
device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
gpu_mode_ = 2;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
|
||||
}
|
||||
|
||||
delete build_index;
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
|
|
Loading…
Reference in New Issue