mirror of https://github.com/milvus-io/milvus.git
parent
ed012f2980
commit
0fb88f733a
|
@ -40,14 +40,11 @@ GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||||
idx_config.device = gpu_id_;
|
idx_config.device = gpu_id_;
|
||||||
int32_t nlist = config[IndexParams::nlist];
|
int32_t nlist = config[IndexParams::nlist];
|
||||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||||
faiss::gpu::GpuIndexIVFFlat device_index(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
auto device_index =
|
||||||
device_index.train(rows, (float*)p_data);
|
new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
||||||
|
device_index->train(rows, (float*)p_data);
|
||||||
|
|
||||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
index_.reset(device_index);
|
||||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
|
|
||||||
|
|
||||||
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
|
||||||
index_.reset(device_index1);
|
|
||||||
res_ = gpu_res;
|
res_ = gpu_res;
|
||||||
} else {
|
} else {
|
||||||
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");
|
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");
|
||||||
|
|
|
@ -38,11 +38,8 @@ GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||||
config[IndexParams::m], config[IndexParams::nbits],
|
config[IndexParams::m], config[IndexParams::nbits],
|
||||||
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
||||||
device_index->train(rows, (float*)p_data);
|
device_index->train(rows, (float*)p_data);
|
||||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
|
||||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
|
||||||
|
|
||||||
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
index_.reset(device_index);
|
||||||
index_.reset(device_index1);
|
|
||||||
res_ = gpu_res;
|
res_ = gpu_res;
|
||||||
} else {
|
} else {
|
||||||
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");
|
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");
|
||||||
|
|
|
@ -45,19 +45,14 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
|
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
|
||||||
device_index->train(rows, (float*)p_data);
|
device_index->train(rows, (float*)p_data);
|
||||||
|
|
||||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
|
||||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
|
||||||
|
|
||||||
delete device_index;
|
|
||||||
delete build_index;
|
|
||||||
|
|
||||||
device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
|
||||||
index_.reset(device_index);
|
index_.reset(device_index);
|
||||||
res_ = gpu_res;
|
res_ = gpu_res;
|
||||||
gpu_mode_ = 2;
|
gpu_mode_ = 2;
|
||||||
} else {
|
} else {
|
||||||
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
|
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
delete build_index;
|
||||||
}
|
}
|
||||||
|
|
||||||
VecIndexPtr
|
VecIndexPtr
|
||||||
|
|
Loading…
Reference in New Issue