mirror of https://github.com/milvus-io/milvus.git
Caiyd refactor knowhere (#1687)
* add new knowhere Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * build pass Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com> * update Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * update Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * rename algo Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * update... Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com> * add archive Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * add new APIs: GetVectorById/SearchVectorById/SetBlacklist/GetBlacklist * update unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * all unittest pass Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * test_binary pass Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * move knowhere into namespace milvus Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update hnsw Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update ConfAdapterMgr Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update ExecutionEngineImpl Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * move SetBlacklist/GetBlacklist to VecIndex Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update VectorAdapter and rename SearchById to QueryById Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update interface in ExecutionEngineImpl Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * milvus build pass Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * milvus IDMAP sdk_simple pass Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix test_server Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix test_schedule Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix CPU version Milvus build issue Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update BinaryIVF BuildAll Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update VecIndexFactory Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update ConfAdapter 
Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix clang-format Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix clang-format Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update changelog Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix knowhere unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix SPTAG unittest * fix clang-format Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix CPU version unittest build issue Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix db_test Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix test_engine Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix test_delete Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix CPU version build issue Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * change BinarySet key back Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * change IndexType to string, and add compatible API for 0.7.0 Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix IndexHNSW build warning Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix clang-format Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix test_cache Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix write_index error Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * code clean Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> Co-authored-by: Nicky <nicky.xj.lin@gmail.com> Co-authored-by: xiaojun.lin <xiaojun.lin@zilliz.com>pull/1707/head
parent
72ad100a90
commit
588ef95d76
|
@ -17,6 +17,7 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#1660 IVF PQ CPU support deleted vectors searching
|
||||
|
||||
## Improvement
|
||||
- \#342 Knowhere and Wrapper refactor
|
||||
- \#1537 Optimize raw vector and uids read/write
|
||||
- \#1546 Move Config.cpp to config directory
|
||||
- \#1547 Rename storage/file to storage/disk and rename classes
|
||||
|
|
|
@ -116,7 +116,7 @@ set(storage_files
|
|||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/utils utils_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper wrapper_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/index/archive wrapper_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/tracing tracing_files)
|
||||
|
||||
|
|
|
@ -22,21 +22,31 @@
|
|||
#include "cache/GpuCacheMgr.h"
|
||||
#include "config/Config.h"
|
||||
#include "db/Utils.h"
|
||||
#include "index/archive/VecIndex.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/index/vector_index/ConfAdapter.h"
|
||||
#include "knowhere/index/vector_index/ConfAdapterMgr.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndexFactory.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/gpu/GPUIndex.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h"
|
||||
#include "knowhere/index/vector_index/gpu/Quantizer.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#endif
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#include "metrics/Metrics.h"
|
||||
#include "scheduler/Utils.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include "utils/Exception.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/Status.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "utils/ValidationUtil.h"
|
||||
#include "wrapper/BinVecImpl.h"
|
||||
#include "wrapper/ConfAdapter.h"
|
||||
#include "wrapper/ConfAdapterMgr.h"
|
||||
#include "wrapper/VecImpl.h"
|
||||
#include "wrapper/VecIndex.h"
|
||||
|
||||
//#define ON_SEARCH
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
|
@ -74,12 +84,13 @@ MappingMetricType(MetricType metric_type, milvus::json& conf) {
|
|||
}
|
||||
|
||||
bool
|
||||
IsBinaryIndexType(IndexType type) {
|
||||
return type == IndexType::FAISS_BIN_IDMAP || type == IndexType::FAISS_BIN_IVFLAT_CPU;
|
||||
IsBinaryIndexType(knowhere::IndexType type) {
|
||||
return type == knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP || type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
class CachedQuantizer : public cache::DataObj {
|
||||
public:
|
||||
explicit CachedQuantizer(knowhere::QuantizerPtr data) : data_(std::move(data)) {
|
||||
|
@ -98,6 +109,7 @@ class CachedQuantizer : public cache::DataObj {
|
|||
private:
|
||||
knowhere::QuantizerPtr data_;
|
||||
};
|
||||
#endif
|
||||
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
|
||||
MetricType metric_type, const milvus::json& index_params)
|
||||
|
@ -118,24 +130,22 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
|
|||
conf[knowhere::meta::DIM] = dimension;
|
||||
MappingMetricType(metric_type, conf);
|
||||
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
if (!adapter->CheckTrain(conf)) {
|
||||
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type());
|
||||
if (!adapter->CheckTrain(conf, index_->index_mode())) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
|
||||
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
|
||||
if (auto bf_index = std::dynamic_pointer_cast<BFIndex>(index_)) {
|
||||
ec = bf_index->Build(conf);
|
||||
} else if (auto bf_bin_index = std::dynamic_pointer_cast<BinBFIndex>(index_)) {
|
||||
ec = bf_bin_index->Build(conf);
|
||||
}
|
||||
if (ec != KNOWHERE_SUCCESS) {
|
||||
throw Exception(DB_ERROR, "Build index error");
|
||||
fiu_do_on("ExecutionEngineImpl.throw_exception", throw Exception(DB_ERROR, ""));
|
||||
if (auto bf_index = std::dynamic_pointer_cast<knowhere::IDMAP>(index_)) {
|
||||
bf_index->Train(knowhere::DatasetPtr(), conf);
|
||||
} else if (auto bf_bin_index = std::dynamic_pointer_cast<knowhere::BinaryIDMAP>(index_)) {
|
||||
bf_bin_index->Train(knowhere::DatasetPtr(), conf);
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type,
|
||||
MetricType metric_type, const milvus::json& index_params)
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(knowhere::VecIndexPtr index, const std::string& location,
|
||||
EngineType index_type, MetricType metric_type,
|
||||
const milvus::json& index_params)
|
||||
: index_(std::move(index)),
|
||||
location_(location),
|
||||
index_type_(index_type),
|
||||
|
@ -143,105 +153,93 @@ ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string& l
|
|||
index_params_(index_params) {
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
knowhere::VecIndexPtr
|
||||
ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
|
||||
knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance();
|
||||
knowhere::IndexMode mode = knowhere::IndexMode::MODE_CPU;
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
server::Config& config = server::Config::GetInstance();
|
||||
bool gpu_resource_enable = true;
|
||||
config.GetGpuResourceConfigEnable(gpu_resource_enable);
|
||||
fiu_do_on("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled", gpu_resource_enable = false);
|
||||
if (gpu_resource_enable) {
|
||||
mode = knowhere::IndexMode::MODE_GPU;
|
||||
}
|
||||
#endif
|
||||
|
||||
fiu_do_on("ExecutionEngineImpl.CreatetVecIndex.invalid_type", type = EngineType::INVALID);
|
||||
std::shared_ptr<VecIndex> index;
|
||||
fiu_do_on("ExecutionEngineImpl.CreateVecIndex.invalid_type", type = EngineType::INVALID);
|
||||
knowhere::VecIndexPtr index = nullptr;
|
||||
switch (type) {
|
||||
case EngineType::FAISS_IDMAP: {
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IDMAP);
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_IVFFLAT: {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (gpu_resource_enable)
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_MIX);
|
||||
else
|
||||
#endif
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_CPU);
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_PQ: {
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_IVFSQ8: {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (gpu_resource_enable)
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_MIX);
|
||||
else
|
||||
#endif
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_CPU);
|
||||
break;
|
||||
}
|
||||
case EngineType::NSG_MIX: {
|
||||
index = GetVecIndexFactory(IndexType::NSG_MIX);
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, mode);
|
||||
break;
|
||||
}
|
||||
#ifdef CUSTOMIZATION
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
case EngineType::FAISS_IVFSQ8H: {
|
||||
if (gpu_resource_enable) {
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_HYBRID);
|
||||
} else {
|
||||
throw Exception(DB_ERROR, "No GPU resources for IVFSQ8H");
|
||||
}
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, mode);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
case EngineType::FAISS_PQ: {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (gpu_resource_enable)
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFPQ_MIX);
|
||||
else
|
||||
#endif
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFPQ_CPU);
|
||||
break;
|
||||
}
|
||||
case EngineType::SPTAG_KDT: {
|
||||
index = GetVecIndexFactory(IndexType::SPTAG_KDT_RNT_CPU);
|
||||
break;
|
||||
}
|
||||
case EngineType::SPTAG_BKT: {
|
||||
index = GetVecIndexFactory(IndexType::SPTAG_BKT_RNT_CPU);
|
||||
break;
|
||||
}
|
||||
case EngineType::HNSW: {
|
||||
index = GetVecIndexFactory(IndexType::HNSW);
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_BIN_IDMAP: {
|
||||
index = GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_BIN_IVFFLAT: {
|
||||
index = GetVecIndexFactory(IndexType::FAISS_BIN_IVFLAT_CPU);
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::NSG_MIX: {
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_NSG, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::SPTAG_KDT: {
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_SPTAG_KDT_RNT, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::SPTAG_BKT: {
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_SPTAG_BKT_RNT, mode);
|
||||
break;
|
||||
}
|
||||
case EngineType::HNSW: {
|
||||
index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_HNSW, mode);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
ENGINE_LOG_ERROR << "Unsupported index type";
|
||||
ENGINE_LOG_ERROR << "Unsupported index type " << (int)type;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (index == nullptr) {
|
||||
std::string err_msg = "Invalid index type " + std::to_string((int)type) + " mod " + std::to_string((int)mode);
|
||||
ENGINE_LOG_ERROR << err_msg;
|
||||
throw Exception(DB_ERROR, err_msg);
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
void
|
||||
ExecutionEngineImpl::HybridLoad() const {
|
||||
if (index_type_ != EngineType::FAISS_IVFSQ8H) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (index_->GetType() == IndexType::FAISS_IDMAP) {
|
||||
ENGINE_LOG_WARNING << "HybridLoad with type FAISS_IDMAP, ignore";
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
auto hybrid_index = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_);
|
||||
if (hybrid_index == nullptr) {
|
||||
ENGINE_LOG_WARNING << "HybridLoad only support with IVFSQHybrid";
|
||||
return;
|
||||
}
|
||||
|
||||
const std::string key = location_ + ".quantizer";
|
||||
|
||||
server::Config& config = server::Config::GetInstance();
|
||||
|
@ -267,7 +265,7 @@ ExecutionEngineImpl::HybridLoad() const {
|
|||
}
|
||||
|
||||
if (device_id != NOT_FOUND) {
|
||||
index_->SetQuantizer(quantizer);
|
||||
hybrid_index->SetQuantizer(quantizer);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -286,12 +284,12 @@ ExecutionEngineImpl::HybridLoad() const {
|
|||
auto best_device_id = gpus[best_index];
|
||||
|
||||
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, best_device_id}, {"mode", 1}};
|
||||
auto quantizer = index_->LoadQuantizer(quantizer_conf);
|
||||
auto quantizer = hybrid_index->LoadQuantizer(quantizer_conf);
|
||||
ENGINE_LOG_DEBUG << "Quantizer params: " << quantizer_conf.dump();
|
||||
if (quantizer == nullptr) {
|
||||
ENGINE_LOG_ERROR << "quantizer is nullptr";
|
||||
}
|
||||
index_->SetQuantizer(quantizer);
|
||||
hybrid_index->SetQuantizer(quantizer);
|
||||
auto cache_quantizer = std::make_shared<CachedQuantizer>(quantizer);
|
||||
cache::GpuCacheMgr::GetInstance(best_device_id)->InsertItem(key, cache_quantizer);
|
||||
}
|
||||
|
@ -300,25 +298,27 @@ ExecutionEngineImpl::HybridLoad() const {
|
|||
|
||||
void
|
||||
ExecutionEngineImpl::HybridUnset() const {
|
||||
if (index_type_ != EngineType::FAISS_IVFSQ8H) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
auto hybrid_index = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_);
|
||||
if (hybrid_index == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (index_->GetType() == IndexType::FAISS_IDMAP) {
|
||||
return;
|
||||
}
|
||||
index_->UnsetQuantizer();
|
||||
hybrid_index->UnsetQuantizer();
|
||||
#endif
|
||||
}
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::AddWithIds(int64_t n, const float* xdata, const int64_t* xids) {
|
||||
auto status = index_->Add(n, xdata, xids);
|
||||
return status;
|
||||
auto dataset = knowhere::GenDatasetWithIds(n, index_->Dim(), xdata, xids);
|
||||
index_->Add(dataset, knowhere::Config());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) {
|
||||
auto status = index_->Add(n, xdata, xids);
|
||||
return status;
|
||||
auto dataset = knowhere::GenDatasetWithIds(n, index_->Dim(), xdata, xids);
|
||||
index_->Add(dataset, knowhere::Config());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
size_t
|
||||
|
@ -336,7 +336,7 @@ ExecutionEngineImpl::Dimension() const {
|
|||
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, return dimension " << dim_;
|
||||
return dim_;
|
||||
}
|
||||
return index_->Dimension();
|
||||
return index_->Dim();
|
||||
}
|
||||
|
||||
size_t
|
||||
|
@ -369,21 +369,25 @@ Status
|
|||
ExecutionEngineImpl::Load(bool to_cache) {
|
||||
// TODO(zhiru): refactor
|
||||
|
||||
index_ = std::static_pointer_cast<VecIndex>(cache::CpuCacheMgr::GetInstance()->GetIndex(location_));
|
||||
index_ = std::static_pointer_cast<knowhere::VecIndex>(cache::CpuCacheMgr::GetInstance()->GetIndex(location_));
|
||||
bool already_in_cache = (index_ != nullptr);
|
||||
if (!already_in_cache) {
|
||||
std::string segment_dir;
|
||||
utils::GetParentPath(location_, segment_dir);
|
||||
auto segment_reader_ptr = std::make_shared<segment::SegmentReader>(segment_dir);
|
||||
knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance();
|
||||
|
||||
if (utils::IsRawIndexType((int32_t)index_type_)) {
|
||||
index_ = index_type_ == EngineType::FAISS_IDMAP ? GetVecIndexFactory(IndexType::FAISS_IDMAP)
|
||||
: GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
|
||||
if (index_type_ == EngineType::FAISS_IDMAP) {
|
||||
index_ = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP);
|
||||
} else {
|
||||
index_ = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP);
|
||||
}
|
||||
milvus::json conf{{knowhere::meta::DEVICEID, gpu_num_}, {knowhere::meta::DIM, dim_}};
|
||||
MappingMetricType(metric_type_, conf);
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type());
|
||||
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
|
||||
if (!adapter->CheckTrain(conf)) {
|
||||
if (!adapter->CheckTrain(conf, index_->index_mode())) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
|
||||
|
@ -413,26 +417,21 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
}
|
||||
}
|
||||
|
||||
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
|
||||
if (index_type_ == EngineType::FAISS_IDMAP) {
|
||||
auto bf_index = std::static_pointer_cast<knowhere::IDMAP>(index_);
|
||||
std::vector<float> float_vectors;
|
||||
float_vectors.resize(vectors_data.size() / sizeof(float));
|
||||
memcpy(float_vectors.data(), vectors_data.data(), vectors_data.size());
|
||||
ec = std::static_pointer_cast<BFIndex>(index_)->Build(conf);
|
||||
if (ec != KNOWHERE_SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
status = std::static_pointer_cast<BFIndex>(index_)->AddWithoutIds(vectors->GetCount(),
|
||||
float_vectors.data(), Config());
|
||||
status = std::static_pointer_cast<BFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
|
||||
bf_index->Train(knowhere::DatasetPtr(), conf);
|
||||
auto dataset = knowhere::GenDataset(vectors->GetCount(), this->dim_, float_vectors.data());
|
||||
bf_index->AddWithoutIds(dataset, conf);
|
||||
bf_index->SetBlacklist(concurrent_bitset_ptr);
|
||||
} else if (index_type_ == EngineType::FAISS_BIN_IDMAP) {
|
||||
ec = std::static_pointer_cast<BinBFIndex>(index_)->Build(conf);
|
||||
if (ec != KNOWHERE_SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
status = std::static_pointer_cast<BinBFIndex>(index_)->AddWithoutIds(vectors->GetCount(),
|
||||
vectors_data.data(), Config());
|
||||
status = std::static_pointer_cast<BinBFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
|
||||
auto bin_bf_index = std::static_pointer_cast<knowhere::BinaryIDMAP>(index_);
|
||||
bin_bf_index->Train(knowhere::DatasetPtr(), conf);
|
||||
auto dataset = knowhere::GenDataset(vectors->GetCount(), this->dim_, vectors_data.data());
|
||||
bin_bf_index->AddWithoutIds(dataset, conf);
|
||||
bin_bf_index->SetBlacklist(concurrent_bitset_ptr);
|
||||
}
|
||||
|
||||
int64_t index_size = vectors->Size(); // vector data size + vector ids size
|
||||
|
@ -444,7 +443,6 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
}
|
||||
|
||||
ENGINE_LOG_DEBUG << "Finished loading raw data from segment " << segment_dir;
|
||||
|
||||
} else {
|
||||
try {
|
||||
// size_t physical_size = PhysicalSize();
|
||||
|
@ -554,7 +552,8 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
|
|||
#endif
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
auto index = std::static_pointer_cast<VecIndex>(cache::GpuCacheMgr::GetInstance(device_id)->GetIndex(location_));
|
||||
auto data_obj_ptr = cache::GpuCacheMgr::GetInstance(device_id)->GetIndex(location_);
|
||||
auto index = std::static_pointer_cast<knowhere::VecIndex>(data_obj_ptr);
|
||||
bool already_in_cache = (index != nullptr);
|
||||
if (already_in_cache) {
|
||||
index_ = index;
|
||||
|
@ -565,7 +564,7 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
|
|||
}
|
||||
|
||||
try {
|
||||
index_ = index_->CopyToGpu(device_id);
|
||||
index_ = knowhere::cloner::CopyCpuToGpu(index_, device_id, knowhere::Config());
|
||||
ENGINE_LOG_DEBUG << "CPU to GPU" << device_id;
|
||||
} catch (std::exception& e) {
|
||||
ENGINE_LOG_ERROR << e.what();
|
||||
|
@ -587,7 +586,7 @@ ExecutionEngineImpl::CopyToIndexFileToGpu(uint64_t device_id) {
|
|||
// the ToIndexData is only a placeholder, cpu-copy-to-gpu action is performed in
|
||||
if (index_) {
|
||||
gpu_num_ = device_id;
|
||||
auto to_index_data = std::make_shared<ToIndexData>(index_->Size());
|
||||
auto to_index_data = std::make_shared<knowhere::ToIndexData>(index_->Size());
|
||||
cache::DataObjPtr obj = std::static_pointer_cast<cache::DataObj>(to_index_data);
|
||||
milvus::cache::GpuCacheMgr::GetInstance(device_id)->InsertItem(location_ + "_placeholder", obj);
|
||||
}
|
||||
|
@ -597,7 +596,8 @@ ExecutionEngineImpl::CopyToIndexFileToGpu(uint64_t device_id) {
|
|||
|
||||
Status
|
||||
ExecutionEngineImpl::CopyToCpu() {
|
||||
auto index = std::static_pointer_cast<VecIndex>(cache::CpuCacheMgr::GetInstance()->GetIndex(location_));
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
auto index = std::static_pointer_cast<knowhere::VecIndex>(cache::CpuCacheMgr::GetInstance()->GetIndex(location_));
|
||||
bool already_in_cache = (index != nullptr);
|
||||
if (already_in_cache) {
|
||||
index_ = index;
|
||||
|
@ -608,7 +608,7 @@ ExecutionEngineImpl::CopyToCpu() {
|
|||
}
|
||||
|
||||
try {
|
||||
index_ = index_->CopyToCpu();
|
||||
index_ = knowhere::cloner::CopyGpuToCpu(index_, knowhere::Config());
|
||||
ENGINE_LOG_DEBUG << "GPU to CPU";
|
||||
} catch (std::exception& e) {
|
||||
ENGINE_LOG_ERROR << e.what();
|
||||
|
@ -620,14 +620,18 @@ ExecutionEngineImpl::CopyToCpu() {
|
|||
Cache();
|
||||
}
|
||||
return Status::OK();
|
||||
#else
|
||||
ENGINE_LOG_ERROR << "Calling ExecutionEngineImpl::CopyToCpu when using CPU version";
|
||||
return Status(DB_ERROR, "Calling ExecutionEngineImpl::CopyToCpu when using CPU version");
|
||||
#endif
|
||||
}
|
||||
|
||||
ExecutionEnginePtr
|
||||
ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_type) {
|
||||
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
|
||||
|
||||
auto from_index = std::dynamic_pointer_cast<BFIndex>(index_);
|
||||
auto bin_from_index = std::dynamic_pointer_cast<BinBFIndex>(index_);
|
||||
auto from_index = std::dynamic_pointer_cast<knowhere::IDMAP>(index_);
|
||||
auto bin_from_index = std::dynamic_pointer_cast<knowhere::BinaryIDMAP>(index_);
|
||||
if (from_index == nullptr && bin_from_index == nullptr) {
|
||||
ENGINE_LOG_ERROR << "ExecutionEngineImpl: from_index is null, failed to build index";
|
||||
return nullptr;
|
||||
|
@ -644,24 +648,36 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
|
|||
conf[knowhere::meta::DEVICEID] = gpu_num_;
|
||||
MappingMetricType(metric_type_, conf);
|
||||
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType());
|
||||
if (!adapter->CheckTrain(conf)) {
|
||||
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(to_index->index_type());
|
||||
if (!adapter->CheckTrain(conf, to_index->index_mode())) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
ENGINE_LOG_DEBUG << "Index config: " << conf.dump();
|
||||
|
||||
auto status = Status::OK();
|
||||
std::vector<segment::doc_id_t> uids;
|
||||
faiss::ConcurrentBitsetPtr blacklist;
|
||||
if (from_index) {
|
||||
status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
|
||||
auto dataset =
|
||||
knowhere::GenDatasetWithIds(Count(), Dimension(), from_index->GetRawVectors(), from_index->GetRawIds());
|
||||
to_index->BuildAll(dataset, conf);
|
||||
uids = from_index->GetUids();
|
||||
from_index->GetBlacklist(blacklist);
|
||||
} else if (bin_from_index) {
|
||||
status = to_index->BuildAll(Count(), bin_from_index->GetRawVectors(), bin_from_index->GetRawIds(), conf);
|
||||
auto dataset = knowhere::GenDatasetWithIds(Count(), Dimension(), bin_from_index->GetRawVectors(),
|
||||
bin_from_index->GetRawIds());
|
||||
to_index->BuildAll(dataset, conf);
|
||||
uids = bin_from_index->GetUids();
|
||||
bin_from_index->GetBlacklist(blacklist);
|
||||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
/* for GPU index, need copy back to CPU */
|
||||
if (to_index->index_mode() == knowhere::IndexMode::MODE_GPU) {
|
||||
auto device_index = std::dynamic_pointer_cast<knowhere::GPUIndex>(to_index);
|
||||
to_index = device_index->CopyGpuToCpu(conf);
|
||||
}
|
||||
#endif
|
||||
|
||||
to_index->SetUids(uids);
|
||||
ENGINE_LOG_DEBUG << "Set " << to_index->GetUids().size() << "uids for " << location;
|
||||
if (blacklist != nullptr) {
|
||||
|
@ -669,23 +685,31 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
|
|||
ENGINE_LOG_DEBUG << "Set blacklist for index " << location;
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
throw Exception(DB_ERROR, status.message());
|
||||
}
|
||||
|
||||
ENGINE_LOG_DEBUG << "Finish build index: " << location;
|
||||
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, index_params_);
|
||||
}
|
||||
|
||||
// map offsets to ids
|
||||
void
|
||||
MapUids(const std::vector<segment::doc_id_t>& uids, int64_t* labels, size_t num) {
|
||||
MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector<milvus::segment::doc_id_t>& uids, int64_t nq,
|
||||
int64_t k, float* distances, int64_t* labels) {
|
||||
int64_t* res_ids = dataset->Get<int64_t*>(knowhere::meta::IDS);
|
||||
float* res_dist = dataset->Get<float*>(knowhere::meta::DISTANCE);
|
||||
|
||||
memcpy(distances, res_dist, sizeof(float) * nq * k);
|
||||
|
||||
/* map offsets to ids */
|
||||
int64_t num = nq * k;
|
||||
for (int64_t i = 0; i < num; ++i) {
|
||||
int64_t& offset = labels[i];
|
||||
int64_t offset = res_ids[i];
|
||||
if (offset != -1) {
|
||||
offset = uids[offset];
|
||||
labels[i] = uids[offset];
|
||||
} else {
|
||||
labels[i] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
free(res_ids);
|
||||
free(res_dist);
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -752,9 +776,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvu
|
|||
|
||||
milvus::json conf = extra_params;
|
||||
conf[knowhere::meta::TOPK] = k;
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type());
|
||||
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
|
||||
if (!adapter->CheckSearch(conf, index_->GetType())) {
|
||||
if (!adapter->CheckSearch(conf, index_->index_type(), index_->index_mode())) {
|
||||
throw Exception(DB_ERROR, "Illegal search params");
|
||||
}
|
||||
|
||||
|
@ -762,24 +786,20 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvu
|
|||
HybridLoad();
|
||||
}
|
||||
|
||||
rc.RecordSection("search prepare");
|
||||
auto status = index_->Search(n, data, distances, labels, conf);
|
||||
rc.RecordSection("search done");
|
||||
rc.RecordSection("query prepare");
|
||||
auto dataset = knowhere::GenDataset(n, index_->Dim(), data);
|
||||
auto result = index_->Query(dataset, conf);
|
||||
rc.RecordSection("query done");
|
||||
|
||||
// map offsets to ids
|
||||
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
|
||||
MapUids(index_->GetUids(), labels, n * k);
|
||||
|
||||
MapAndCopyResult(result, index_->GetUids(), n, k, distances, labels);
|
||||
rc.RecordSection("map uids " + std::to_string(n * k));
|
||||
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ENGINE_LOG_ERROR << "Search error:" << status.message();
|
||||
}
|
||||
return status;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -794,9 +814,9 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil
|
|||
|
||||
milvus::json conf = extra_params;
|
||||
conf[knowhere::meta::TOPK] = k;
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type());
|
||||
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
|
||||
if (!adapter->CheckSearch(conf, index_->GetType())) {
|
||||
if (!adapter->CheckSearch(conf, index_->index_type(), index_->index_mode())) {
|
||||
throw Exception(DB_ERROR, "Illegal search params");
|
||||
}
|
||||
|
||||
|
@ -804,24 +824,20 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil
|
|||
HybridLoad();
|
||||
}
|
||||
|
||||
rc.RecordSection("search prepare");
|
||||
auto status = index_->Search(n, data, distances, labels, conf);
|
||||
rc.RecordSection("search done");
|
||||
rc.RecordSection("query prepare");
|
||||
auto dataset = knowhere::GenDataset(n, index_->Dim(), data);
|
||||
auto result = index_->Query(dataset, conf);
|
||||
rc.RecordSection("query done");
|
||||
|
||||
// map offsets to ids
|
||||
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
|
||||
MapUids(index_->GetUids(), labels, n * k);
|
||||
|
||||
MapAndCopyResult(result, index_->GetUids(), n, k, distances, labels);
|
||||
rc.RecordSection("map uids " + std::to_string(n * k));
|
||||
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ENGINE_LOG_ERROR << "Search error:" << status.message();
|
||||
}
|
||||
return status;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -836,9 +852,9 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
|
|||
|
||||
milvus::json conf = extra_params;
|
||||
conf[knowhere::meta::TOPK] = k;
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type());
|
||||
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
|
||||
if (!adapter->CheckSearch(conf, index_->GetType())) {
|
||||
if (!adapter->CheckSearch(conf, index_->index_type(), index_->index_mode())) {
|
||||
throw Exception(DB_ERROR, "Illegal search params");
|
||||
}
|
||||
|
||||
|
@ -887,15 +903,13 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
|
|||
|
||||
rc.RecordSection("get offset");
|
||||
|
||||
auto status = Status::OK();
|
||||
if (!offsets.empty()) {
|
||||
status = index_->SearchById(offsets.size(), offsets.data(), distances, labels, conf);
|
||||
rc.RecordSection("search done");
|
||||
auto dataset = knowhere::GenDatasetWithIds(offsets.size(), index_->Dim(), nullptr, offsets.data());
|
||||
auto result = index_->QueryById(dataset, conf);
|
||||
rc.RecordSection("query by id done");
|
||||
|
||||
// map offsets to ids
|
||||
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
|
||||
MapUids(uids, labels, offsets.size() * k);
|
||||
|
||||
MapAndCopyResult(result, uids, offsets.size(), k, distances, labels);
|
||||
rc.RecordSection("map uids " + std::to_string(offsets.size() * k));
|
||||
}
|
||||
|
||||
|
@ -903,10 +917,7 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
|
|||
HybridUnset();
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ENGINE_LOG_ERROR << "Search error:" << status.message();
|
||||
}
|
||||
return status;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -922,16 +933,16 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, float* vector, bool hybrid
|
|||
|
||||
// Only one id for now
|
||||
std::vector<int64_t> ids{id};
|
||||
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
|
||||
auto dataset = knowhere::GenDatasetWithIds(1, index_->Dim(), nullptr, ids.data());
|
||||
auto result = index_->GetVectorById(dataset, knowhere::Config());
|
||||
float* res_vec = (float*)(result->Get<void*>(knowhere::meta::TENSOR));
|
||||
memcpy(vector, res_vec, sizeof(float) * 1 * index_->Dim());
|
||||
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ENGINE_LOG_ERROR << "Search error:" << status.message();
|
||||
}
|
||||
return status;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -949,16 +960,16 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, uint8_t* vector, bool hybr
|
|||
|
||||
// Only one id for now
|
||||
std::vector<int64_t> ids{id};
|
||||
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
|
||||
auto dataset = knowhere::GenDatasetWithIds(1, index_->Dim(), nullptr, ids.data());
|
||||
auto result = index_->GetVectorById(dataset, knowhere::Config());
|
||||
uint8_t* res_vec = (uint8_t*)(result->Get<void*>(knowhere::meta::TENSOR));
|
||||
memcpy(vector, res_vec, sizeof(uint8_t) * 1 * index_->Dim());
|
||||
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ENGINE_LOG_ERROR << "Search error:" << status.message();
|
||||
}
|
||||
return status;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "ExecutionEngine.h"
|
||||
#include "wrapper/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
@ -29,8 +29,8 @@ class ExecutionEngineImpl : public ExecutionEngine {
|
|||
ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
|
||||
const milvus::json& index_params);
|
||||
|
||||
ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type, MetricType metric_type,
|
||||
const milvus::json& index_params);
|
||||
ExecutionEngineImpl(knowhere::VecIndexPtr index, const std::string& location, EngineType index_type,
|
||||
MetricType metric_type, const milvus::json& index_params);
|
||||
|
||||
Status
|
||||
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override;
|
||||
|
@ -108,10 +108,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
|
|||
}
|
||||
|
||||
private:
|
||||
VecIndexPtr
|
||||
knowhere::VecIndexPtr
|
||||
CreatetVecIndex(EngineType type);
|
||||
|
||||
VecIndexPtr
|
||||
knowhere::VecIndexPtr
|
||||
Load(const std::string& location);
|
||||
|
||||
void
|
||||
|
@ -121,7 +121,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
|
|||
HybridUnset() const;
|
||||
|
||||
protected:
|
||||
VecIndexPtr index_ = nullptr;
|
||||
knowhere::VecIndexPtr index_ = nullptr;
|
||||
EngineType index_type_;
|
||||
MetricType metric_type_;
|
||||
|
||||
|
|
|
@ -9,8 +9,6 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "db/insert/MemTable.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
|
@ -20,9 +18,10 @@
|
|||
#include "cache/CpuCacheMgr.h"
|
||||
#include "db/OngoingFileChecker.h"
|
||||
#include "db/Utils.h"
|
||||
#include "db/insert/MemTable.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "segment/SegmentReader.h"
|
||||
#include "utils/Log.h"
|
||||
#include "wrapper/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
@ -243,11 +242,11 @@ MemTable::ApplyDeletes() {
|
|||
}
|
||||
|
||||
// Get all index that contains blacklist in cache
|
||||
std::vector<VecIndexPtr> indexes;
|
||||
std::vector<knowhere::VecIndexPtr> indexes;
|
||||
std::vector<faiss::ConcurrentBitsetPtr> blacklists;
|
||||
for (auto& file : segment_files) {
|
||||
auto index =
|
||||
std::static_pointer_cast<VecIndex>(cache::CpuCacheMgr::GetInstance()->GetIndex(file.location_));
|
||||
auto data_obj_ptr = cache::CpuCacheMgr::GetInstance()->GetIndex(file.location_);
|
||||
auto index = std::static_pointer_cast<knowhere::VecIndex>(data_obj_ptr);
|
||||
faiss::ConcurrentBitsetPtr blacklist = nullptr;
|
||||
if (index != nullptr) {
|
||||
index->GetBlacklist(blacklist);
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "wrapper/KnowhereResource.h"
|
||||
#include "index/archive/KnowhereResource.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#endif
|
|
@ -0,0 +1,172 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "config/Config.h"
|
||||
#include "index/archive/VecIndex.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndexFactory.h"
|
||||
#include "storage/disk/DiskIOReader.h"
|
||||
#include "storage/disk/DiskIOWriter.h"
|
||||
#include "storage/s3/S3IOReader.h"
|
||||
#include "storage/s3/S3IOWriter.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
// Create a CPU-mode index of the requested type and fill it from the given binaries.
//
// type          index type name handed to the VecIndexFactory
// index_binary  serialized index content passed to Load()
// size          value recorded on the index through set_size()
//
// Returns the loaded index, or a null pointer when the factory yields no index for `type`.
knowhere::VecIndexPtr
LoadVecIndex(const knowhere::IndexType& type, const knowhere::BinarySet& index_binary, int64_t size) {
    auto loaded = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(type, knowhere::IndexMode::MODE_CPU);
    if (loaded != nullptr) {
        loaded->Load(index_binary);
        loaded->set_size(size);
    }
    // Still null here when the factory could not create the index.
    return loaded;
}
|
||||
|
||||
// Read a serialized vector index from storage and rebuild it as a CPU-mode index.
//
// Layout consumed below:
//   int32_t  legacy index type
//   repeated until the end of the file:
//     size_t   meta_length, followed by meta_length bytes (name of the binary)
//     size_t   bin_length,  followed by bin_length bytes  (payload of the binary)
//
// location  path/key handed to the IOReader; S3 or local disk is chosen from the storage config
//
// Returns the loaded index, or nullptr when the reader reports a zero length or the index type
// cannot be created.
//
// NOTE(review): the length prefixes are trusted as-is; a truncated or corrupt file is not detected
// here. The early return on an empty file also leaves the reader open, as before — confirm the
// reader's destructor releases it.
knowhere::VecIndexPtr
read_index(const std::string& location) {
    milvus::TimeRecorder recorder("read_index");
    knowhere::BinarySet load_data_list;

    // Pick the storage backend from the server configuration.
    bool s3_enable = false;
    milvus::server::Config& config = milvus::server::Config::GetInstance();
    config.GetStorageConfigS3Enable(s3_enable);

    std::shared_ptr<milvus::storage::IOReader> reader_ptr;
    if (s3_enable) {
        reader_ptr = std::make_shared<milvus::storage::S3IOReader>();
    } else {
        reader_ptr = std::make_shared<milvus::storage::DiskIOReader>();
    }

    recorder.RecordSection("Start");
    reader_ptr->open(location);

    size_t length = reader_ptr->length();
    if (length == 0) {
        return nullptr;
    }

    // rp tracks the absolute read position; the reader is re-seeked after every read.
    size_t rp = 0;
    reader_ptr->seekg(0);

    int32_t current_type = 0;
    reader_ptr->read(&current_type, sizeof(current_type));
    rp += sizeof(current_type);
    reader_ptr->seekg(rp);

    while (rp < length) {
        size_t meta_length = 0;
        reader_ptr->read(&meta_length, sizeof(meta_length));
        rp += sizeof(meta_length);
        reader_ptr->seekg(rp);

        // The string owns the name bytes, so nothing leaks if a later call throws.
        std::string meta(meta_length, '\0');
        reader_ptr->read(&meta[0], meta_length);
        rp += meta_length;
        reader_ptr->seekg(rp);

        size_t bin_length = 0;
        reader_ptr->read(&bin_length, sizeof(bin_length));
        rp += sizeof(bin_length);
        reader_ptr->seekg(rp);

        // The payload is an array allocation, so the shared_ptr needs the array deleter;
        // the default deleter would call `delete` on memory obtained from `new[]`.
        std::shared_ptr<uint8_t> binptr(new uint8_t[bin_length], std::default_delete<uint8_t[]>());
        reader_ptr->read(binptr.get(), bin_length);
        rp += bin_length;
        reader_ptr->seekg(rp);

        load_data_list.Append(meta, binptr, bin_length);
    }
    reader_ptr->close();

    double span = recorder.RecordSection("End");
    double rate = length * 1000000.0 / span / 1024 / 1024;
    STORAGE_LOG_DEBUG << "read_index(" << location << ") rate " << rate << "MB/s";

    return LoadVecIndex(knowhere::OldIndexTypeToStr(current_type), load_data_list, length);
}
|
||||
|
||||
// Serialize `index` and write it to `location` through the configured storage backend.
//
// Layout written below:
//   int32_t  legacy index type
//   for every entry of the serialized binary set:
//     size_t   meta_length, followed by the entry name
//     size_t   binary_length, followed by the entry payload
// Both length prefixes are size_t so that they have the width read_index reads back.
//
// index     index to persist; a null pointer is rejected with KNOWHERE_ERROR
// location  path/key handed to the IOWriter; S3 or local disk is chosen from the storage config
//
// Returns Status::OK() on success, KNOWHERE_UNEXPECTED_ERROR for a KnowhereException,
// KNOWHERE_NO_SPACE when the exception text reports a full device, KNOWHERE_ERROR otherwise.
milvus::Status
write_index(knowhere::VecIndexPtr index, const std::string& location) {
    if (index == nullptr) {
        return milvus::Status(milvus::KNOWHERE_ERROR, "Cannot write a null index");
    }

    try {
        milvus::TimeRecorder recorder("write_index");

        auto binaryset = index->Serialize(knowhere::Config());
        int32_t index_type = knowhere::StrToOldIndexType(index->index_type());

        // Pick the storage backend from the server configuration.
        bool s3_enable = false;
        milvus::server::Config& config = milvus::server::Config::GetInstance();
        config.GetStorageConfigS3Enable(s3_enable);

        std::shared_ptr<milvus::storage::IOWriter> writer_ptr;
        if (s3_enable) {
            writer_ptr = std::make_shared<milvus::storage::S3IOWriter>();
        } else {
            writer_ptr = std::make_shared<milvus::storage::DiskIOWriter>();
        }

        recorder.RecordSection("Start");
        writer_ptr->open(location);

        writer_ptr->write(&index_type, sizeof(index_type));

        for (auto& iter : binaryset.binary_map_) {
            const std::string& meta = iter.first;
            size_t meta_length = meta.length();
            writer_ptr->write(&meta_length, sizeof(meta_length));
            // The writer takes a mutable pointer; the name bytes are only read from.
            writer_ptr->write(const_cast<char*>(meta.data()), meta_length);

            const auto& binary = iter.second;
            size_t binary_length = static_cast<size_t>(binary->size);
            writer_ptr->write(&binary_length, sizeof(binary_length));
            writer_ptr->write(binary->data.get(), binary_length);
        }
        writer_ptr->close();

        double span = recorder.RecordSection("End");
        double rate = writer_ptr->length() * 1000000.0 / span / 1024 / 1024;
        STORAGE_LOG_DEBUG << "write_index(" << location << ") rate " << rate << "MB/s";
    } catch (const knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return milvus::Status(milvus::KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (const std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        std::string estring(e.what());
        // A full disk is reported with its own status code.
        if (estring.find("No space left on device") != estring.npos) {
            WRAPPER_LOG_ERROR << "No space left on the device";
            return milvus::Status(milvus::KNOWHERE_NO_SPACE, "No space left on the device");
        } else {
            return milvus::Status(milvus::KNOWHERE_ERROR, e.what());
        }
    }
    return milvus::Status::OK();
}
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
|
@ -0,0 +1,37 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "cache/DataObj.h"
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "utils/Status.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
// Serialize `index` and write it to `location`; returns a Status describing the outcome.
extern milvus::Status
|
||||
write_index(knowhere::VecIndexPtr index, const std::string& location);
|
||||
|
||||
// Read the index stored at `location`; the result may be a null pointer when nothing could be loaded.
extern knowhere::VecIndexPtr
|
||||
read_index(const std::string& location);
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
|
@ -21,31 +21,35 @@ if (NOT TARGET SPTAGLibStatic)
|
|||
endif ()
|
||||
|
||||
set(external_srcs
|
||||
knowhere/adapter/SptagAdapter.cpp
|
||||
knowhere/common/Exception.cpp
|
||||
knowhere/common/Timer.cpp
|
||||
)
|
||||
|
||||
set(index_srcs
|
||||
knowhere/index/preprocessor/Normalize.cpp
|
||||
knowhere/index/vector_index/IndexSPTAG.cpp
|
||||
knowhere/index/vector_index/IndexIDMAP.cpp
|
||||
knowhere/index/vector_index/IndexIVF.cpp
|
||||
knowhere/index/vector_index/IndexBinaryIVF.cpp
|
||||
knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
|
||||
knowhere/index/vector_index/IndexBinaryIDMAP.cpp
|
||||
knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
|
||||
knowhere/index/vector_index/IndexNSG.cpp
|
||||
knowhere/index/vector_index/IndexHNSW.cpp
|
||||
knowhere/index/vector_index/nsg/NSG.cpp
|
||||
knowhere/index/vector_index/nsg/NSGIO.cpp
|
||||
knowhere/index/vector_index/nsg/NSGHelper.cpp
|
||||
knowhere/index/vector_index/nsg/Distance.cpp
|
||||
knowhere/index/vector_index/IndexIVFSQ.cpp
|
||||
knowhere/index/vector_index/IndexIVFPQ.cpp
|
||||
knowhere/index/vector_index/FaissBaseIndex.cpp
|
||||
knowhere/index/vector_index/adapter/SptagAdapter.cpp
|
||||
knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||
knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
|
||||
knowhere/index/vector_index/impl/nsg/Distance.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSG.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSGIO.cpp
|
||||
knowhere/index/vector_index/ConfAdapter.cpp
|
||||
knowhere/index/vector_index/ConfAdapterMgr.cpp
|
||||
knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
|
||||
knowhere/index/vector_index/FaissBaseIndex.cpp
|
||||
knowhere/index/vector_index/IndexBinaryIDMAP.cpp
|
||||
knowhere/index/vector_index/IndexBinaryIVF.cpp
|
||||
knowhere/index/vector_index/IndexHNSW.cpp
|
||||
knowhere/index/vector_index/IndexIDMAP.cpp
|
||||
knowhere/index/vector_index/IndexIVF.cpp
|
||||
knowhere/index/vector_index/IndexIVFPQ.cpp
|
||||
knowhere/index/vector_index/IndexIVFSQ.cpp
|
||||
knowhere/index/vector_index/IndexNSG.cpp
|
||||
knowhere/index/vector_index/IndexSPTAG.cpp
|
||||
knowhere/index/vector_index/IndexType.cpp
|
||||
knowhere/index/vector_index/VecIndexFactory.cpp
|
||||
)
|
||||
|
||||
set(depend_libs
|
||||
|
@ -82,13 +86,13 @@ if (KNOWHERE_GPU_VERSION)
|
|||
)
|
||||
|
||||
set(index_srcs ${index_srcs}
|
||||
knowhere/index/vector_index/IndexGPUIVF.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIVF.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp
|
||||
knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp
|
||||
knowhere/index/vector_index/helpers/Cloner.cpp
|
||||
knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp
|
||||
knowhere/index/vector_index/IndexGPUIVFSQ.cpp
|
||||
knowhere/index/vector_index/IndexIVFSQHybrid.cpp
|
||||
knowhere/index/vector_index/IndexGPUIVFPQ.cpp
|
||||
knowhere/index/vector_index/IndexGPUIDMAP.cpp
|
||||
)
|
||||
|
||||
endif ()
|
||||
|
@ -110,7 +114,7 @@ set(INDEX_INCLUDE_DIRS
|
|||
${INDEX_SOURCE_DIR}/knowhere
|
||||
${INDEX_SOURCE_DIR}/thirdparty
|
||||
${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService
|
||||
${ARROW_INCLUDE_DIR}
|
||||
# ${ARROW_INCLUDE_DIR}
|
||||
${FAISS_INCLUDE_DIR}
|
||||
${OPENBLAS_INCLUDE_DIR}
|
||||
${LAPACK_INCLUDE_DIR}
|
||||
|
|
|
@ -7,24 +7,21 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Id.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
struct Binary {
|
||||
ID id;
|
||||
std::shared_ptr<uint8_t> data;
|
||||
int64_t size = 0;
|
||||
};
|
||||
|
@ -76,3 +73,4 @@ class BinarySet {
|
|||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
|
||||
#include "src/utils/Json.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
using Config = milvus::json;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,74 +7,21 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <any>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <typeindex>
|
||||
#include <utility>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
struct BaseValue;
|
||||
using BasePtr = std::unique_ptr<BaseValue>;
|
||||
struct BaseValue {
|
||||
virtual ~BaseValue() = default;
|
||||
|
||||
// virtual BasePtr
|
||||
// Clone() const = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AnyValue : public BaseValue {
|
||||
T data_;
|
||||
|
||||
template <typename U>
|
||||
explicit AnyValue(U&& value) : data_(std::forward<U>(value)) {
|
||||
}
|
||||
|
||||
// BasePtr
|
||||
// Clone() const {
|
||||
// return BasePtr(data_);
|
||||
// }
|
||||
};
|
||||
|
||||
struct Value {
|
||||
std::type_index type_;
|
||||
BasePtr data_;
|
||||
|
||||
template <typename U,
|
||||
class = typename std::enable_if<!std::is_same<typename std::decay<U>::type, Value>::value, U>::type>
|
||||
explicit Value(U&& value)
|
||||
: data_(new AnyValue<typename std::decay<U>::type>(std::forward<U>(value))),
|
||||
type_(std::type_index(typeid(typename std::decay<U>::type))) {
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
bool
|
||||
Is() const {
|
||||
return type_ == std::type_index(typeid(U));
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
U&
|
||||
AnyCast() {
|
||||
if (!Is<U>()) {
|
||||
std::stringstream ss;
|
||||
ss << "Can't cast t " << type_.name() << " to " << typeid(U).name();
|
||||
throw std::logic_error(ss.str());
|
||||
}
|
||||
|
||||
auto derived = dynamic_cast<AnyValue<U>*>(data_.get());
|
||||
return derived->data_;
|
||||
}
|
||||
};
|
||||
using Value = std::any;
|
||||
using ValuePtr = std::shared_ptr<Value>;
|
||||
|
||||
class Dataset {
|
||||
|
@ -85,18 +32,16 @@ class Dataset {
|
|||
void
|
||||
Set(const std::string& k, T&& v) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto value = std::make_shared<Value>(std::forward<T>(v));
|
||||
data_[k] = value;
|
||||
data_[k] = std::make_shared<Value>(std::forward<T>(v));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T
|
||||
Get(const std::string& k) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto finder = data_.find(k);
|
||||
if (finder != data_.end()) {
|
||||
return finder->second->AnyCast<T>();
|
||||
} else {
|
||||
try {
|
||||
return std::any_cast<T>(*(data_.at(k)));
|
||||
} catch (...) {
|
||||
throw std::logic_error("Can't find this key");
|
||||
}
|
||||
}
|
||||
|
@ -113,3 +58,4 @@ class Dataset {
|
|||
using DatasetPtr = std::shared_ptr<Dataset>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
#include "Log.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
KnowhereException::KnowhereException(const std::string& msg) : msg(msg) {
|
||||
|
@ -41,3 +42,4 @@ KnowhereException::what() const noexcept {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class KnowhereException : public std::exception {
|
||||
|
@ -45,3 +46,4 @@ class KnowhereException : public std::exception {
|
|||
} while (false)
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -1,50 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
// Identifier stored as a fixed array of five 32-bit words.
class ID {
|
||||
public:
|
||||
// Presumably the size of content_ in bytes (5 * sizeof(int32_t) == 20) — TODO confirm with users of kIDSize.
constexpr static int64_t kIDSize = 20;
|
||||
|
||||
public:
|
||||
// Read-only pointer to the first word of the identifier.
const int32_t*
|
||||
data() const {
|
||||
return content_;
|
||||
}
|
||||
|
||||
// Writable pointer to the first word of the identifier.
int32_t*
|
||||
mutable_data() {
|
||||
return content_;
|
||||
}
|
||||
|
||||
// Declared here only; the validity rule lives in the definition, which is not visible in this file.
bool
|
||||
IsValid() const;
|
||||
|
||||
// Declared here only; the textual format is decided by the definition.
std::string
|
||||
ToString() const;
|
||||
|
||||
// Equality comparison against another ID (defined elsewhere).
bool
|
||||
operator==(const ID& that) const;
|
||||
|
||||
// Ordering comparison against another ID (defined elsewhere).
bool
|
||||
operator<(const ID& that) const;
|
||||
|
||||
protected:
|
||||
// Identifier storage; value-initialized, so every word starts at zero.
int32_t content_[5] = {};
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
|
@ -7,12 +7,13 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "easyloggingpp/easylogging++.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
#define KNOWHERE_DOMAIN_NAME "[KNOWHERE] "
|
||||
|
@ -26,3 +27,4 @@ namespace knowhere {
|
|||
#define KNOWHERE_LOG_FATAL LOG(FATAL) << KNOWHERE_DOMAIN_NAME
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <iostream> // TODO(linxj): using Log instead
|
||||
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
TimeRecorder::TimeRecorder(const std::string& header, int64_t log_level) : header_(header), log_level_(log_level) {
|
||||
|
@ -81,3 +82,4 @@ TimeRecorder::ElapseFromBegin(const std::string& msg) {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class TimeRecorder {
|
||||
|
@ -45,3 +46,4 @@ class TimeRecorder {
|
|||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
|
@ -7,25 +8,22 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
namespace knowhere {
|
||||
|
||||
class WrapperException : public std::exception {
|
||||
public:
|
||||
explicit WrapperException(const std::string& msg);
|
||||
using MetricType = std::string;
|
||||
// using IndexType = std::string;
|
||||
using IDType = int64_t;
|
||||
using FloatType = float;
|
||||
using BinaryType = uint8_t;
|
||||
using GraphType = std::vector<std::vector<IDType>>;
|
||||
|
||||
const char*
|
||||
what() const noexcept override;
|
||||
|
||||
const std::string msg;
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,56 +7,57 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "IndexModel.h"
|
||||
#include "IndexType.h"
|
||||
#include "cache/DataObj.h"
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/preprocessor/Preprocessor.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class Index {
|
||||
class Index : public milvus::cache::DataObj {
|
||||
public:
|
||||
virtual BinarySet
|
||||
Serialize() = 0;
|
||||
Serialize(const Config& config = Config()) = 0;
|
||||
|
||||
virtual void
|
||||
Load(const BinarySet& index_binary) = 0;
|
||||
Load(const BinarySet&) = 0;
|
||||
|
||||
// @throw
|
||||
virtual DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
public:
|
||||
IndexType
|
||||
idx_type() const {
|
||||
return idx_type_;
|
||||
int64_t
|
||||
Size() override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
void
|
||||
set_idx_type(IndexType idx_type) {
|
||||
idx_type_ = idx_type;
|
||||
set_size(const int64_t& size) {
|
||||
size_ = size;
|
||||
}
|
||||
|
||||
virtual void
|
||||
set_preprocessor(PreprocessorPtr preprocessor) {
|
||||
}
|
||||
|
||||
virtual void
|
||||
set_index_model(IndexModelPtr model) {
|
||||
}
|
||||
|
||||
private:
|
||||
IndexType idx_type_;
|
||||
protected:
|
||||
int64_t size_ = -1;
|
||||
};
|
||||
|
||||
using IndexPtr = std::shared_ptr<Index>;
|
||||
|
||||
// todo: remove from knowhere
|
||||
class ToIndexData : public milvus::cache::DataObj {
|
||||
public:
|
||||
explicit ToIndexData(int64_t size) : size_(size) {
|
||||
}
|
||||
|
||||
int64_t
|
||||
Size() override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
enum class IndexType {
|
||||
kUnknown = 0,
|
||||
kVecIdxBegin = 100,
|
||||
kVecIVFFlat = kVecIdxBegin,
|
||||
kVecIdxEnd,
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
|
@ -1,51 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
//
|
||||
//#include "knowhere/index/vector_index/definitions.h"
|
||||
//#include "knowhere/common/config.h"
|
||||
#include "knowhere/index/preprocessor/Normalize.h"
|
||||
//
|
||||
//
|
||||
//
|
||||
// namespace knowhere {
|
||||
//
|
||||
// DatasetPtr
|
||||
// NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
|
||||
// // TODO: wrap dataset->tensor
|
||||
// auto tensor = dataset->tensor()[0];
|
||||
// auto p_data = (float *)tensor->raw_mutable_data();
|
||||
// auto dimension = tensor->shape()[1];
|
||||
// auto rows = tensor->shape()[0];
|
||||
//
|
||||
//#pragma omp parallel for
|
||||
// for (auto i = 0; i < rows; ++i) {
|
||||
// Normalize(&(p_data[i * dimension]), dimension);
|
||||
// }
|
||||
//}
|
||||
//
|
||||
// void
|
||||
// NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
|
||||
// double vector_length = 0;
|
||||
// for (auto j = 0; j < dimension; j++) {
|
||||
// double val = arr[j];
|
||||
// vector_length += val * val;
|
||||
// }
|
||||
// vector_length = std::sqrt(vector_length);
|
||||
// if (vector_length < 1e-6) {
|
||||
// auto val = (float) (1.0 / std::sqrt((double) dimension));
|
||||
// for (int j = 0; j < dimension; j++) arr[j] = val;
|
||||
// } else {
|
||||
// for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length);
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//} // namespace knowhere
|
||||
//
|
|
@ -1,37 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
//
|
||||
//#pragma once
|
||||
//
|
||||
//#include <memory>
|
||||
//#include "preprocessor.h"
|
||||
//
|
||||
//
|
||||
//
|
||||
// namespace knowhere {
|
||||
//
|
||||
// class NormalizePreprocessor : public Preprocessor {
|
||||
// public:
|
||||
// DatasetPtr
|
||||
// Preprocess(const DatasetPtr &input) override;
|
||||
//
|
||||
// private:
|
||||
//
|
||||
// void
|
||||
// Normalize(float *arr, int64_t dimension);
|
||||
//};
|
||||
//
|
||||
//
|
||||
// using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
|
||||
//
|
||||
//
|
||||
//} // namespace knowhere
|
||||
//
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include "knowhere/common/Dataset.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class Preprocessor {
|
||||
|
@ -26,3 +27,4 @@ class Preprocessor {
|
|||
using PreprocessorPtr = std::shared_ptr<Preprocessor>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -9,21 +9,17 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "wrapper/ConfAdapter.h"
|
||||
#include "knowhere/index/vector_index/ConfAdapter.h"
|
||||
|
||||
#include <fiu-local.h>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "WrapperException.h"
|
||||
#include "config/Config.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
namespace knowhere {
|
||||
|
||||
#if CUDA_VERSION > 9000
|
||||
#define GPU_MAX_NRPOBE 2048
|
||||
|
@ -70,19 +66,16 @@ namespace engine {
|
|||
}
|
||||
|
||||
bool
|
||||
ConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
ConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
|
||||
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
ConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -100,7 +93,7 @@ MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist)
|
|||
}
|
||||
|
||||
bool
|
||||
IVFConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MAX_NLIST = 999999;
|
||||
static int64_t MIN_NLIST = 1;
|
||||
|
||||
|
@ -111,42 +104,43 @@ IVFConfAdapter::CheckTrain(milvus::json& oricfg) {
|
|||
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
// auto tune params
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(),
|
||||
oricfg[knowhere::IndexParams::nlist].get<int64_t>(), 16384);
|
||||
int64_t nq = oricfg[knowhere::meta::ROWS].get<int64_t>();
|
||||
int64_t nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist, 16384);
|
||||
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg);
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MIN_NPROBE = 1;
|
||||
static int64_t MAX_NPROBE = 999999; // todo(linxj): [1, nlist]
|
||||
|
||||
if (type == IndexType::FAISS_IVFPQ_GPU || type == IndexType::FAISS_IVFSQ8_GPU ||
|
||||
type == IndexType::FAISS_IVFSQ8_HYBRID || type == IndexType::FAISS_IVFFLAT_GPU) {
|
||||
if (mode == IndexMode::MODE_GPU) {
|
||||
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, GPU_MAX_NRPOBE);
|
||||
} else {
|
||||
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE);
|
||||
}
|
||||
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type);
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFSQConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t DEFAULT_NBITS = 8;
|
||||
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
|
||||
|
||||
return IVFConfAdapter::CheckTrain(oricfg);
|
||||
return IVFConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFPQConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t DEFAULT_NBITS = 8;
|
||||
static int64_t MAX_NLIST = 999999;
|
||||
static int64_t MIN_NLIST = 1;
|
||||
|
@ -155,17 +149,11 @@ IVFPQConfAdapter::CheckTrain(milvus::json& oricfg) {
|
|||
|
||||
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
Status s;
|
||||
bool enable_gpu = false;
|
||||
server::Config& config = server::Config::GetInstance();
|
||||
s = config.GetGpuResourceConfigEnable(enable_gpu);
|
||||
if (s.ok()) {
|
||||
if (mode == IndexMode::MODE_GPU) {
|
||||
CheckStrByValues(knowhere::Metric::TYPE, GPU_METRICS);
|
||||
} else {
|
||||
CheckStrByValues(knowhere::Metric::TYPE, CPU_METRICS);
|
||||
}
|
||||
#endif
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
|
@ -214,7 +202,7 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
|
|||
}
|
||||
|
||||
bool
|
||||
NSGConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_KNNG = 5;
|
||||
static int64_t MAX_KNNG = 300;
|
||||
static int64_t MIN_SEARCH_LENGTH = 10;
|
||||
|
@ -242,17 +230,17 @@ NSGConfAdapter::CheckTrain(milvus::json& oricfg) {
|
|||
}
|
||||
|
||||
bool
|
||||
NSGConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MIN_SEARCH_LENGTH = 1;
|
||||
static int64_t MAX_SEARCH_LENGTH = 300;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type);
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
HNSWConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 100;
|
||||
static int64_t MAX_EFCONSTRUCTION = 500;
|
||||
static int64_t MIN_M = 5;
|
||||
|
@ -262,20 +250,20 @@ HNSWConfAdapter::CheckTrain(milvus::json& oricfg) {
|
|||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg);
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
HNSWConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type);
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
BinIDMAPConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
|
||||
knowhere::Metric::TANIMOTO, knowhere::Metric::SUBSTRUCTURE,
|
||||
knowhere::Metric::SUPERSTRUCTURE};
|
||||
|
@ -287,7 +275,7 @@ BinIDMAPConfAdapter::CheckTrain(milvus::json& oricfg) {
|
|||
}
|
||||
|
||||
bool
|
||||
BinIVFConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
BinIVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
|
||||
knowhere::Metric::TANIMOTO};
|
||||
static int64_t MAX_NLIST = 999999;
|
||||
|
@ -308,5 +296,6 @@ BinIVFConfAdapter::CheckTrain(milvus::json& oricfg) {
|
|||
|
||||
return true;
|
||||
}
|
||||
} // namespace engine
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -14,49 +14,41 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "VecIndex.h"
|
||||
#include "utils/Json.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
namespace knowhere {
|
||||
|
||||
class ConfAdapter {
|
||||
public:
|
||||
virtual bool
|
||||
CheckTrain(milvus::json& oricfg);
|
||||
CheckTrain(Config& oricfg, const IndexMode mode);
|
||||
|
||||
virtual bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type);
|
||||
|
||||
// todo(linxj): refactor in next release.
|
||||
//
|
||||
// virtual bool
|
||||
// CheckTrain(milvus::json&, IndexMode&) = 0;
|
||||
//
|
||||
// virtual bool
|
||||
// CheckSearch(milvus::json&, const IndexType&, IndexMode&) = 0;
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode);
|
||||
};
|
||||
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
|
||||
|
||||
class IVFConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class IVFSQConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class IVFPQConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
static void
|
||||
GetValidMList(int64_t dimension, std::vector<int64_t>& resset);
|
||||
|
@ -65,32 +57,32 @@ class IVFPQConfAdapter : public IVFConfAdapter {
|
|||
class NSGConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class BinIDMAPConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class BinIVFConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class HNSWConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/ConfAdapterMgr.h"
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
/**
 * Returns a newly constructed config adapter for the given index type.
 *
 * Registers the built-in adapters on first use (when init_ is still false).
 * NOTE(review): the lazy initialisation is not synchronised here; confirm
 * callers serialise the first call or register adapters up front.
 *
 * @param type index type whose adapter is requested
 * @return adapter produced by the factory stored in table_ for `type`
 * @throw the exception raised by KNOWHERE_THROW_MSG when no adapter is registered for `type`
 */
ConfAdapterPtr
AdapterMgr::GetAdapter(const IndexType type) {
    if (!init_) {
        RegisterAdapter();
    }

    // Look the factory up explicitly rather than wrapping table_.at(type)() in
    // catch (...): the old form also swallowed any exception thrown while the
    // factory was constructing the adapter and misreported it as "not found".
    const auto entry = table_.find(type);
    if (entry == table_.end()) {
        KNOWHERE_THROW_MSG("Can not find this type of confadapter");
    }

    return entry->second();
}
|
||||
|
||||
#define REGISTER_CONF_ADAPTER(T, TYPE, NAME) static AdapterMgr::register_t<T> reg_##NAME##_(TYPE)
|
||||
|
||||
void
|
||||
AdapterMgr::RegisterAdapter() {
|
||||
init_ = true;
|
||||
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_FAISS_IDMAP, idmap_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexEnum::INDEX_FAISS_IVFFLAT, ivf_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter);
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_KDT_RNT, sptag_kdt_adapter);
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_BKT_RNT, sptag_bkt_adapter);
|
||||
REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexEnum::INDEX_HNSW, hnsw_adapter);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -12,21 +12,21 @@
|
|||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ConfAdapter.h"
|
||||
#include "VecIndex.h"
|
||||
#include "knowhere/index/vector_index/ConfAdapter.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
namespace knowhere {
|
||||
|
||||
class AdapterMgr {
|
||||
public:
|
||||
template <typename T>
|
||||
struct register_t {
|
||||
explicit register_t(const IndexType& key) {
|
||||
AdapterMgr::GetInstance().table_.emplace(key, [] { return std::make_shared<T>(); });
|
||||
explicit register_t(const IndexType type) {
|
||||
AdapterMgr::GetInstance().table_[type] = ([] { return std::make_shared<T>(); });
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -37,15 +37,15 @@ class AdapterMgr {
|
|||
}
|
||||
|
||||
ConfAdapterPtr
|
||||
GetAdapter(const IndexType& indexType);
|
||||
GetAdapter(const IndexType indexType);
|
||||
|
||||
void
|
||||
RegisterAdapter();
|
||||
|
||||
protected:
|
||||
bool init_ = false;
|
||||
std::map<IndexType, std::function<ConfAdapterPtr()> > table_;
|
||||
std::unordered_map<IndexType, std::function<ConfAdapterPtr()>> table_;
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -11,31 +11,24 @@
|
|||
|
||||
#include <faiss/index_io.h>
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
FaissBaseBinaryIndex::FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index) : index_(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
FaissBaseBinaryIndex::SerializeImpl() {
|
||||
FaissBaseBinaryIndex::SerializeImpl(const IndexType& type) {
|
||||
try {
|
||||
faiss::IndexBinary* index = index_.get();
|
||||
|
||||
// SealImpl();
|
||||
|
||||
MemoryIOWriter writer;
|
||||
faiss::write_index_binary(index, &writer);
|
||||
auto data = std::make_shared<uint8_t>();
|
||||
data.reset(writer.data_);
|
||||
|
||||
BinarySet res_set;
|
||||
// TODO(linxj): use virtual func Name() instead of raw string.
|
||||
res_set.Append("BinaryIVF", data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
|
@ -44,7 +37,7 @@ FaissBaseBinaryIndex::SerializeImpl() {
|
|||
}
|
||||
|
||||
void
|
||||
FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary) {
|
||||
FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary, const IndexType& type) {
|
||||
auto binary = index_binary.GetByName("BinaryIVF");
|
||||
|
||||
MemoryIOReader reader;
|
||||
|
@ -52,8 +45,8 @@ FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary) {
|
|||
reader.data_ = binary->data.get();
|
||||
|
||||
faiss::IndexBinary* index = faiss::read_index_binary(&reader);
|
||||
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -12,26 +12,31 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include <faiss/IndexBinary.h>
|
||||
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class FaissBaseBinaryIndex {
|
||||
protected:
|
||||
explicit FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index);
|
||||
explicit FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index) : index_(std::move(index)) {
|
||||
}
|
||||
|
||||
virtual BinarySet
|
||||
SerializeImpl();
|
||||
SerializeImpl(const IndexType& type);
|
||||
|
||||
virtual void
|
||||
LoadImpl(const BinarySet& index_binary);
|
||||
LoadImpl(const BinarySet& index_binary, const IndexType& type);
|
||||
|
||||
public:
|
||||
std::shared_ptr<faiss::IndexBinary> index_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,31 +7,25 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/index_io.h>
|
||||
#include <fiu-local.h>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
FaissBaseIndex::FaissBaseIndex(std::shared_ptr<faiss::Index> index) : index_(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
FaissBaseIndex::SerializeImpl() {
|
||||
FaissBaseIndex::SerializeImpl(const IndexType& type) {
|
||||
try {
|
||||
fiu_do_on("FaissBaseIndex.SerializeImpl.throw_exception", throw std::exception());
|
||||
faiss::Index* index = index_.get();
|
||||
|
||||
// SealImpl();
|
||||
|
||||
MemoryIOWriter writer;
|
||||
faiss::write_index(index, &writer);
|
||||
auto data = std::make_shared<uint8_t>();
|
||||
|
@ -47,32 +41,18 @@ FaissBaseIndex::SerializeImpl() {
|
|||
}
|
||||
|
||||
void
|
||||
FaissBaseIndex::LoadImpl(const BinarySet& index_binary) {
|
||||
auto binary = index_binary.GetByName("IVF");
|
||||
FaissBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
||||
auto binary = binary_set.GetByName("IVF");
|
||||
|
||||
MemoryIOReader reader;
|
||||
reader.total = binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
faiss::Index* index = faiss::read_index(&reader);
|
||||
|
||||
index_.reset(index);
|
||||
|
||||
SealImpl();
|
||||
}
|
||||
|
||||
void
|
||||
FaissBaseIndex::SealImpl() {
|
||||
#ifdef CUSTOMIZATION
|
||||
faiss::Index* index = index_.get();
|
||||
auto idx = dynamic_cast<faiss::IndexIVF*>(index);
|
||||
if (idx != nullptr) {
|
||||
// To be deleted
|
||||
KNOWHERE_LOG_DEBUG << "Test before to_readonly:"
|
||||
<< " IVF READONLY " << std::boolalpha << idx->is_readonly();
|
||||
idx->to_readonly();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,33 +7,39 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include <faiss/Index.h>
|
||||
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class FaissBaseIndex {
|
||||
protected:
|
||||
explicit FaissBaseIndex(std::shared_ptr<faiss::Index> index);
|
||||
explicit FaissBaseIndex(std::shared_ptr<faiss::Index> index) : index_(std::move(index)) {
|
||||
}
|
||||
|
||||
virtual BinarySet
|
||||
SerializeImpl();
|
||||
SerializeImpl(const IndexType& type);
|
||||
|
||||
virtual void
|
||||
LoadImpl(const BinarySet& index_binary);
|
||||
LoadImpl(const BinarySet&, const IndexType& type);
|
||||
|
||||
virtual void
|
||||
SealImpl();
|
||||
SealImpl() { /* do nothing */
|
||||
}
|
||||
|
||||
public:
|
||||
std::shared_ptr<faiss::Index> index_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -17,41 +17,43 @@
|
|||
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
BinaryIDMAP::Serialize() {
|
||||
BinaryIDMAP::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
return SerializeImpl(index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
LoadImpl(index_binary, index_type_);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GETBINARYTENSOR(dataset)
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config());
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
@ -70,44 +72,65 @@ BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& cfg) {
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
/**
 * Top-k search where the queries are given as ids (meta::IDS) rather than as
 * raw vectors; the lookup is delegated to index_->search_by_id.
 *
 * @param dataset_ptr dataset providing meta::ROWS (query count) and meta::IDS (query ids)
 * @param config      must provide meta::TOPK
 * @return dataset holding meta::IDS (int64_t*) and meta::DISTANCE (float*),
 *         each with rows * k entries allocated with malloc
 * @throw via KNOWHERE_THROW_MSG when the index is not initialised or a result
 *        buffer cannot be allocated
 */
DatasetPtr
BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }

    auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
    auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);

    int64_t k = config[meta::TOPK].get<int64_t>();
    auto elems = rows * k;
    size_t p_id_size = sizeof(int64_t) * elems;
    size_t p_dist_size = sizeof(float) * elems;

    // Keep the malloc'ed buffers in owners with a free() deleter until they are
    // handed to the result dataset, so nothing leaks if the search throws.
    std::unique_ptr<int64_t, decltype(&free)> id_buf(static_cast<int64_t*>(malloc(p_id_size)), &free);
    std::unique_ptr<float, decltype(&free)> dist_buf(static_cast<float*>(malloc(p_dist_size)), &free);
    // malloc(0) may legitimately return nullptr, so only treat null as failure
    // when a non-empty buffer was requested.
    if (elems > 0 && (!id_buf || !dist_buf)) {
        KNOWHERE_THROW_MSG("failed to allocate query result buffers");
    }

    // The binary index writes its distances as int32_t into the distance buffer.
    auto* pdistances = reinterpret_cast<int32_t*>(dist_buf.get());
    index_->search_by_id(rows, p_data, k, pdistances, id_buf.get(), bitset_);

    if (index_->metric_type == faiss::METRIC_Hamming) {
        // For Hamming the int32_t values are converted to float in a fresh
        // buffer; the raw buffer is released when dist_buf is re-seated below.
        std::unique_ptr<float, decltype(&free)> float_buf(static_cast<float*>(malloc(p_dist_size)), &free);
        if (elems > 0 && !float_buf) {
            KNOWHERE_THROW_MSG("failed to allocate query result buffers");
        }
        for (int64_t i = 0; i < elems; ++i) {
            float_buf.get()[i] = static_cast<float>(pdistances[i]);
        }
        dist_buf.reset(float_buf.release());
    }

    auto ret_ds = std::make_shared<Dataset>();
    // The raw buffers are handed over to the result dataset here, as in the
    // original code.
    ret_ds->Set(meta::IDS, id_buf.release());
    ret_ds->Set(meta::DISTANCE, dist_buf.release());
    return ret_ds;
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETBINARYTENSOR(dataset)
|
||||
GETTENSORWITHIDS(dataset_ptr)
|
||||
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
index_->add_with_ids(rows, (uint8_t*)p_data, p_ids);
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Train(const Config& config) {
|
||||
const char* type = "BFlat";
|
||||
auto index = faiss::index_binary_factory(config[meta::DIM].get<int64_t>(), type,
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
const char* desc = "BFlat";
|
||||
int64_t dim = config[meta::DIM].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto index = faiss::index_binary_factory(dim, desc, metric_type);
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
const uint8_t*
|
||||
BinaryIDMAP::GetRawVectors() {
|
||||
try {
|
||||
|
@ -130,18 +153,13 @@ BinaryIDMAP::GetRawIds() {
|
|||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::AddWithoutId(const DatasetPtr& dataset, const Config& config) {
|
||||
BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETBINARYTENSOR(dataset)
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
std::vector<int64_t> new_ids(rows);
|
||||
for (int i = 0; i < rows; ++i) {
|
||||
|
@ -152,15 +170,15 @@ BinaryIDMAP::AddWithoutId(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
||||
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
// GETBINARYTENSOR(dataset)
|
||||
// auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset->Get<int64_t>(meta::DIM);
|
||||
// GETBINARYTENSOR(dataset_ptr)
|
||||
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
|
||||
|
||||
size_t p_x_size = sizeof(uint8_t) * elems;
|
||||
auto p_x = (uint8_t*)malloc(p_x_size);
|
||||
|
@ -172,51 +190,12 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
auto dim = dataset->Get<int64_t>(meta::DIM);
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
auto* pdistances = (int32_t*)p_dist;
|
||||
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::SetBlacklist(faiss::ConcurrentBitsetPtr list) {
|
||||
bitset_ = std::move(list);
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::GetBlacklist(faiss::ConcurrentBitsetPtr& list) {
|
||||
list = bitset_;
|
||||
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -11,81 +11,84 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "FaissBaseBinaryIndex.h"
|
||||
#include "VectorIndex.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class BinaryIDMAP : public VectorIndex, public FaissBaseBinaryIndex {
|
||||
class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
||||
public:
|
||||
BinaryIDMAP() : FaissBaseBinaryIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_BIN_IDMAP;
|
||||
}
|
||||
|
||||
explicit BinaryIDMAP(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_BIN_IDMAP;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
AddWithoutId(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
void
|
||||
Train(const Config& config);
|
||||
DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
Count() override {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
Dim() override {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
int64_t
|
||||
Size() override {
|
||||
if (size_ != -1) {
|
||||
return size_;
|
||||
}
|
||||
return Count() * Dim() * sizeof(uint8_t);
|
||||
}
|
||||
|
||||
const uint8_t*
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
virtual const uint8_t*
|
||||
GetRawVectors();
|
||||
|
||||
const int64_t*
|
||||
virtual const int64_t*
|
||||
GetRawIds();
|
||||
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
DatasetPtr
|
||||
SearchById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr list);
|
||||
|
||||
void
|
||||
GetBlacklist(faiss::ConcurrentBitsetPtr& list);
|
||||
|
||||
protected:
|
||||
virtual void
|
||||
search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
||||
private:
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
};
|
||||
|
||||
using BinaryIDMAPPtr = std::shared_ptr<BinaryIDMAP>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -17,47 +17,49 @@
|
|||
#include <chrono>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
using stdclock = std::chrono::high_resolution_clock;
|
||||
|
||||
BinarySet
|
||||
BinaryIVF::Serialize() {
|
||||
BinaryIVF::Serialize(const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
return SerializeImpl(index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
LoadImpl(index_binary, index_type_);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GETBINARYTENSOR(dataset)
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
try {
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
@ -81,80 +83,73 @@ BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& cfg) {
|
||||
auto params = GenParams(cfg);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
stdclock::time_point before = stdclock::now();
|
||||
|
||||
// todo: remove static cast (zhiru)
|
||||
static_cast<faiss::IndexBinary*>(index_.get())->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
BinaryIVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->max_codes = config["max_code"];
|
||||
return params;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
BinaryIVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETBINARYTENSOR(dataset)
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
faiss::IndexBinary* coarse_quantizer =
|
||||
new faiss::IndexBinaryFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, config[IndexParams::nlist],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
index->train(rows, (uint8_t*)p_data);
|
||||
index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
|
||||
index_ = index;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIVF::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIVF::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("not support yet");
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
||||
BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
// GETBINARYTENSOR(dataset)
|
||||
// auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset->Get<int64_t>(meta::DIM);
|
||||
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
try {
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
int32_t* pdistances = (int32_t*)p_dist;
|
||||
index_->search_by_id(rows, p_data, k, pdistances, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSORWITHIDS(dataset_ptr)
|
||||
|
||||
int64_t nlist = config[IndexParams::nlist];
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, metric_type);
|
||||
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, nlist, metric_type);
|
||||
index->train(rows, (uint8_t*)p_data);
|
||||
index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
|
||||
index_ = index;
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
// GETBINARYTENSOR(dataset_ptr)
|
||||
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
|
||||
|
||||
try {
|
||||
size_t p_x_size = sizeof(uint8_t) * elems;
|
||||
|
@ -172,57 +167,34 @@ BinaryIVF::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
try {
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
int32_t* pdistances = (int32_t*)p_dist;
|
||||
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
BinaryIVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->max_codes = config["max_code"];
|
||||
return params;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::SetBlacklist(faiss::ConcurrentBitsetPtr list) {
|
||||
bitset_ = std::move(list);
|
||||
}
|
||||
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
stdclock::time_point before = stdclock::now();
|
||||
|
||||
void
|
||||
BinaryIVF::GetBlacklist(faiss::ConcurrentBitsetPtr& list) {
|
||||
list = bitset_;
|
||||
// todo: remove static cast (zhiru)
|
||||
static_cast<faiss::IndexBinary*>(index_.get())->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -16,71 +16,88 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include "FaissBaseBinaryIndex.h"
|
||||
#include "VectorIndex.h"
|
||||
#include "faiss/IndexIVF.h"
|
||||
#include <faiss/IndexIVF.h>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class BinaryIVF : public VectorIndex, public FaissBaseBinaryIndex {
|
||||
class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
||||
public:
|
||||
BinaryIVF() : FaissBaseBinaryIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
|
||||
}
|
||||
|
||||
explicit BinaryIVF(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
Train(dataset_ptr, config);
|
||||
}
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
KNOWHERE_THROW_MSG("not support yet");
|
||||
}
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override {
|
||||
KNOWHERE_THROW_MSG("AddWithoutIds is not supported");
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
Count() override {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
Dim() override {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Size() override {
|
||||
if (size_ != -1) {
|
||||
return size_;
|
||||
}
|
||||
return Count() * Dim() * sizeof(uint8_t);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
DatasetPtr
|
||||
SearchById(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr list);
|
||||
|
||||
void
|
||||
GetBlacklist(faiss::ConcurrentBitsetPtr& list);
|
||||
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config);
|
||||
|
||||
protected:
|
||||
virtual std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config& config);
|
||||
|
||||
virtual void
|
||||
search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
||||
private:
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
};
|
||||
|
||||
using BinaryIVFIndexPtr = std::shared_ptr<BinaryIVF>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIndex {
|
||||
public:
|
||||
explicit GPUIndex(const int& device_id) : gpu_id_(device_id) {
|
||||
}
|
||||
|
||||
GPUIndex(const int& device_id, const ResPtr& resource) : gpu_id_(device_id), res_(resource) {
|
||||
}
|
||||
|
||||
virtual VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) = 0;
|
||||
|
||||
virtual VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t& device_id, const Config& config) = 0;
|
||||
|
||||
void
|
||||
SetGpuDevice(const int& gpu_id);
|
||||
|
||||
const int64_t&
|
||||
GetGpuDevice();
|
||||
|
||||
protected:
|
||||
int64_t gpu_id_;
|
||||
ResWPtr res_;
|
||||
};
|
||||
|
||||
class GPUIVF : public IVF, public GPUIndex {
|
||||
public:
|
||||
explicit GPUIVF(const int& device_id) : IVF(), GPUIndex(device_id) {
|
||||
}
|
||||
|
||||
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
|
||||
: IVF(std::move(index)), GPUIndex(device_id, resource) {
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
set_index_model(IndexModelPtr model) override;
|
||||
|
||||
// DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
// VectorIndexPtr
|
||||
// Clone() final;
|
||||
|
||||
protected:
|
||||
void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
|
||||
|
||||
BinarySet
|
||||
SerializeImpl() override;
|
||||
|
||||
void
|
||||
LoadImpl(const BinarySet& index_binary) override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
|
@ -20,11 +20,12 @@
|
|||
#include "hnswlib/hnswalg.h"
|
||||
#include "hnswlib/space_ip.h"
|
||||
#include "hnswlib/space_l2.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
void
|
||||
|
@ -36,7 +37,7 @@ normalize_vector(float* data, float* norm_array, size_t dim) {
|
|||
}
|
||||
|
||||
BinarySet
|
||||
IndexHNSW::Serialize() {
|
||||
IndexHNSW::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -74,12 +75,62 @@ IndexHNSW::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
hnswlib::SpaceInterface<float>* space;
|
||||
if (config[Metric::TYPE] == Metric::L2) {
|
||||
space = new hnswlib::L2Space(dim);
|
||||
} else if (config[Metric::TYPE] == Metric::IP) {
|
||||
space = new hnswlib::InnerProductSpace(dim);
|
||||
normalize = true;
|
||||
}
|
||||
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
|
||||
config[IndexParams::efConstruction].get<int64_t>());
|
||||
}
|
||||
|
||||
void
|
||||
IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
GETTENSORWITHIDS(dataset_ptr)
|
||||
|
||||
// if (normalize) {
|
||||
// std::vector<float> ep_norm_vector(Dim());
|
||||
// normalize_vector((float*)(p_data), ep_norm_vector.data(), Dim());
|
||||
// index_->addPoint((void*)(ep_norm_vector.data()), p_ids[0]);
|
||||
// #pragma omp parallel for
|
||||
// for (int i = 1; i < rows; ++i) {
|
||||
// std::vector<float> norm_vector(Dim());
|
||||
// normalize_vector((float*)(p_data + Dim() * i), norm_vector.data(), Dim());
|
||||
// index_->addPoint((void*)(norm_vector.data()), p_ids[i]);
|
||||
// }
|
||||
// } else {
|
||||
// index_->addPoint((void*)(p_data), p_ids[0]);
|
||||
// #pragma omp parallel for
|
||||
// for (int i = 1; i < rows; ++i) {
|
||||
// index_->addPoint((void*)(p_data + Dim() * i), p_ids[i]);
|
||||
// }
|
||||
// }
|
||||
|
||||
index_->addPoint(p_data, p_ids[0]);
|
||||
#pragma omp parallel for
|
||||
for (int i = 1; i < rows; ++i) {
|
||||
index_->addPoint(((float*)p_data + Dim() * i), p_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
GETTENSOR(dataset)
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
size_t id_size = sizeof(int64_t) * config[meta::TOPK].get<int64_t>();
|
||||
size_t dist_size = sizeof(float) * config[meta::TOPK].get<int64_t>();
|
||||
|
@ -93,11 +144,11 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
std::vector<P> ret;
|
||||
const float* single_query = p_data + i * Dimension();
|
||||
const float* single_query = (float*)p_data + i * Dim();
|
||||
|
||||
// if (normalize) {
|
||||
// std::vector<float> norm_vector(Dimension());
|
||||
// normalize_vector((float*)(single_query), norm_vector.data(), Dimension());
|
||||
// std::vector<float> norm_vector(Dim());
|
||||
// normalize_vector((float*)(single_query), norm_vector.data(), Dim());
|
||||
// ret = index_->searchKnn((float*)(norm_vector.data()), config[meta::TOPK].get<int64_t>(), compare);
|
||||
// } else {
|
||||
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
|
||||
|
@ -130,64 +181,6 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
IndexHNSW::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
|
||||
hnswlib::SpaceInterface<float>* space;
|
||||
if (config[Metric::TYPE] == Metric::L2) {
|
||||
space = new hnswlib::L2Space(dim);
|
||||
} else if (config[Metric::TYPE] == Metric::IP) {
|
||||
space = new hnswlib::InnerProductSpace(dim);
|
||||
normalize = true;
|
||||
}
|
||||
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
|
||||
config[IndexParams::efConstruction].get<int64_t>());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void
|
||||
IndexHNSW::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
GETTENSOR(dataset)
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
// if (normalize) {
|
||||
// std::vector<float> ep_norm_vector(Dimension());
|
||||
// normalize_vector((float*)(p_data), ep_norm_vector.data(), Dimension());
|
||||
// index_->addPoint((void*)(ep_norm_vector.data()), p_ids[0]);
|
||||
// #pragma omp parallel for
|
||||
// for (int i = 1; i < rows; ++i) {
|
||||
// std::vector<float> norm_vector(Dimension());
|
||||
// normalize_vector((float*)(p_data + Dimension() * i), norm_vector.data(), Dimension());
|
||||
// index_->addPoint((void*)(norm_vector.data()), p_ids[i]);
|
||||
// }
|
||||
// } else {
|
||||
// index_->addPoint((void*)(p_data), p_ids[0]);
|
||||
// #pragma omp parallel for
|
||||
// for (int i = 1; i < rows; ++i) {
|
||||
// index_->addPoint((void*)(p_data + Dimension() * i), p_ids[i]);
|
||||
// }
|
||||
// }
|
||||
|
||||
index_->addPoint((void*)(p_data), p_ids[0]);
|
||||
#pragma omp parallel for
|
||||
for (int i = 1; i < rows; ++i) {
|
||||
index_->addPoint((void*)(p_data + Dimension() * i), p_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexHNSW::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexHNSW::Count() {
|
||||
if (!index_) {
|
||||
|
@ -197,7 +190,7 @@ IndexHNSW::Count() {
|
|||
}
|
||||
|
||||
int64_t
|
||||
IndexHNSW::Dimension() {
|
||||
IndexHNSW::Dim() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -205,3 +198,4 @@ IndexHNSW::Dimension() {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -16,44 +16,43 @@
|
|||
|
||||
#include "hnswlib/hnswlib.h"
|
||||
|
||||
#include "knowhere/index/vector_index/VectorIndex.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexHNSW : public VectorIndex {
|
||||
class IndexHNSW : public VecIndex {
|
||||
public:
|
||||
IndexHNSW() {
|
||||
index_type_ = IndexEnum::INDEX_HNSW;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override {
|
||||
KNOWHERE_THROW_MSG("Incremental index is not supported");
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
// void
|
||||
// set_preprocessor(PreprocessorPtr preprocessor) override;
|
||||
//
|
||||
// void
|
||||
// set_index_model(IndexModelPtr model) override;
|
||||
//
|
||||
// PreprocessorPtr
|
||||
// BuildPreprocessor(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
Dim() override;
|
||||
|
||||
private:
|
||||
bool normalize = false;
|
||||
|
@ -62,3 +61,4 @@ class IndexHNSW : public VectorIndex {
|
|||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
|
@ -15,95 +17,72 @@
|
|||
#include <faiss/clone_index.h>
|
||||
#include <faiss/index_factory.h>
|
||||
#include <faiss/index_io.h>
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
#include "knowhere/index/vector_index/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
IDMAP::Serialize() {
|
||||
IDMAP::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
return SerializeImpl(index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::Load(const BinarySet& index_binary) {
|
||||
IDMAP::Load(const BinarySet& binary_set) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GETTENSOR(dataset)
|
||||
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
LoadImpl(binary_set, index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
index_->search(n, (float*)data, k, distances, labels, bitset_);
|
||||
IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
const char* desc = "IDMap,Flat";
|
||||
int64_t dim = config[meta::DIM].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto index = faiss::index_factory(dim, desc, metric_type);
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
IDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSOR(dataset)
|
||||
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
GETTENSORWITHIDS(dataset_ptr)
|
||||
index_->add_with_ids(rows, (float*)p_data, p_ids);
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::AddWithoutId(const DatasetPtr& dataset, const Config& config) {
|
||||
IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const float*>(meta::TENSOR);
|
||||
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const void*>(meta::TENSOR);
|
||||
|
||||
// TODO: caiyd need check
|
||||
std::vector<int64_t> new_ids(rows);
|
||||
for (int i = 0; i < rows; ++i) {
|
||||
new_ids[i] = i;
|
||||
|
@ -112,14 +91,71 @@ IDMAP::AddWithoutId(const DatasetPtr& dataset, const Config& config) {
|
|||
index_->add_with_ids(rows, (float*)p_data, new_ids.data());
|
||||
}
|
||||
|
||||
int64_t
|
||||
IDMAP::Count() {
|
||||
return index_->ntotal;
|
||||
DatasetPtr
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, Config());
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IDMAP::Dimension() {
|
||||
return index_->d;
|
||||
DatasetPtr
|
||||
IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
// GETTENSOR(dataset)
|
||||
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
// todo: enable search by id (zhiru)
|
||||
// auto blacklist = dataset_ptr->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
// index_->searchById(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, blacklist);
|
||||
index_->search_by_id(rows, p_data, k, p_dist, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
IDMAP::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
|
||||
|
||||
std::shared_ptr<faiss::Index> device_index;
|
||||
device_index.reset(gpu_index);
|
||||
return std::make_shared<GPUIDMAP>(device_index, device_id, res);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
|
||||
}
|
||||
#else
|
||||
KNOWHERE_THROW_MSG("Calling IDMAP::CopyCpuToGpu when we are using CPU version");
|
||||
#endif
|
||||
}
|
||||
|
||||
const float*
|
||||
|
@ -143,57 +179,15 @@ IDMAP::GetRawIds() {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::Train(const Config& config) {
|
||||
const char* type = "IDMap,Flat";
|
||||
auto index = faiss::index_factory(config[meta::DIM].get<int64_t>(), type,
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// IDMAP::Clone() {
|
||||
// std::lock_guard<std::mutex> lk(mutex_);
|
||||
//
|
||||
// auto clone_index = faiss::clone_index(index_.get());
|
||||
// std::shared_ptr<faiss::Index> new_index;
|
||||
// new_index.reset(clone_index);
|
||||
// return std::make_shared<IDMAP>(new_index);
|
||||
//}
|
||||
|
||||
VectorIndexPtr
|
||||
IDMAP::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
|
||||
|
||||
std::shared_ptr<faiss::Index> device_index;
|
||||
device_index.reset(gpu_index);
|
||||
return std::make_shared<GPUIDMAP>(device_index, device_id, res);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
|
||||
}
|
||||
#else
|
||||
KNOWHERE_THROW_MSG("Calling IDMAP::CopyCpuToGpu when we are using CPU version");
|
||||
#endif
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
||||
IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
// GETTENSOR(dataset)
|
||||
// auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset->Get<int64_t>(meta::DIM);
|
||||
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
|
||||
|
||||
size_t p_x_size = sizeof(float) * elems;
|
||||
auto p_x = (float*)malloc(p_x_size);
|
||||
|
@ -205,40 +199,10 @@ IDMAP::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
// GETTENSOR(dataset)
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
// todo: enable search by id (zhiru)
|
||||
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
// index_->searchById(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, blacklist);
|
||||
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::SetBlacklist(faiss::ConcurrentBitsetPtr list) {
|
||||
bitset_ = std::move(list);
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::GetBlacklist(faiss::ConcurrentBitsetPtr& list) {
|
||||
list = bitset_;
|
||||
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
index_->search(n, (float*)data, k, distances, labels, bitset_);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,58 +7,65 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "IndexIVF.h"
|
||||
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IDMAP : public VectorIndex, public FaissBaseIndex {
|
||||
class IDMAP : public VecIndex, public FaissBaseIndex {
|
||||
public:
|
||||
IDMAP() : FaissBaseIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IDMAP;
|
||||
}
|
||||
|
||||
explicit IDMAP(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IDMAP;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
Load(const BinarySet&) override;
|
||||
|
||||
void
|
||||
Train(const Config& config);
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
// VectorIndexPtr
|
||||
// Clone() override;
|
||||
Count() override {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
Dim() override {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
AddWithoutId(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config);
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&);
|
||||
|
||||
virtual const float*
|
||||
GetRawVectors();
|
||||
|
@ -66,29 +73,15 @@ class IDMAP : public VectorIndex, public FaissBaseIndex {
|
|||
virtual const int64_t*
|
||||
GetRawIds();
|
||||
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
DatasetPtr
|
||||
SearchById(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr list);
|
||||
|
||||
void
|
||||
GetBlacklist(faiss::ConcurrentBitsetPtr& list);
|
||||
|
||||
protected:
|
||||
virtual void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
||||
private:
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
};
|
||||
|
||||
using IDMAPPtr = std::shared_ptr<IDMAP>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/IVFlib.h>
|
||||
|
@ -30,91 +30,92 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#endif
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
using stdclock = std::chrono::high_resolution_clock;
|
||||
|
||||
IndexModelPtr
|
||||
IVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
// TODO(linxj): override here. train return model or not.
|
||||
return std::make_shared<IVFIndexModel>(index);
|
||||
}
|
||||
|
||||
void
|
||||
IVF::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
BinarySet
|
||||
IVF::Serialize(const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSOR(dataset)
|
||||
return SerializeImpl(index_type_);
|
||||
}
|
||||
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
void
|
||||
IVF::Load(const BinarySet& binary_set) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(binary_set, index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
|
||||
int64_t nlist = config[IndexParams::nlist].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type);
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
index_.reset(faiss::clone_index(index.get()));
|
||||
}
|
||||
|
||||
void
|
||||
IVF::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSORWITHIDS(dataset_ptr)
|
||||
index_->add_with_ids(rows, (float*)p_data, p_ids);
|
||||
}
|
||||
|
||||
void
|
||||
IVF::AddWithoutIds(const DatasetPtr& dataset, const Config& config) {
|
||||
IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSOR(dataset)
|
||||
|
||||
GETTENSOR(dataset_ptr)
|
||||
index_->add(rows, (float*)p_data);
|
||||
}
|
||||
|
||||
BinarySet
|
||||
IVF::Serialize() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void
|
||||
IVF::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
try {
|
||||
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
|
||||
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config);
|
||||
|
||||
// std::stringstream ss_res_id, ss_res_dist;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
|
@ -140,166 +141,18 @@ IVF::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVF::set_index_model(IndexModelPtr model) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto rel_model = std::static_pointer_cast<IVFIndexModel>(model);
|
||||
|
||||
// Deep copy here.
|
||||
index_.reset(faiss::clone_index(rel_model->index_.get()));
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->max_codes = config["max_codes"];
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IVF::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IVF::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void
|
||||
IVF::GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& config) {
|
||||
int64_t K = k + 1;
|
||||
auto ntotal = Count();
|
||||
|
||||
size_t dim = config[meta::DIM];
|
||||
auto batch_size = 1000;
|
||||
auto tail_batch_size = ntotal % batch_size;
|
||||
auto batch_search_count = ntotal / batch_size;
|
||||
auto total_search_count = tail_batch_size == 0 ? batch_search_count : batch_search_count + 1;
|
||||
|
||||
std::vector<float> res_dis(K * batch_size);
|
||||
graph.resize(ntotal);
|
||||
Graph res_vec(total_search_count);
|
||||
for (int i = 0; i < total_search_count; ++i) {
|
||||
auto b_size = (i == (total_search_count - 1)) && tail_batch_size != 0 ? tail_batch_size : batch_size;
|
||||
|
||||
auto& res = res_vec[i];
|
||||
res.resize(K * b_size);
|
||||
|
||||
auto xq = data + batch_size * dim * i;
|
||||
search_impl(b_size, (float*)xq, K, res_dis.data(), res.data(), config);
|
||||
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto& node = graph[batch_size * i + j];
|
||||
node.resize(k);
|
||||
auto start_pos = j * K + 1;
|
||||
for (int m = 0, cursor = start_pos; m < k && cursor < start_pos + k; ++m, ++cursor) {
|
||||
node[m] = res[cursor];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
auto params = GenParams(cfg);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
stdclock::time_point before = stdclock::now();
|
||||
ivf_index->search(n, (float*)data, k, distances, labels, bitset_);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
IVF::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
|
||||
|
||||
std::shared_ptr<faiss::Index> device_index;
|
||||
device_index.reset(gpu_index);
|
||||
return std::make_shared<GPUIVF>(device_index, device_id, res);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
|
||||
}
|
||||
|
||||
#else
|
||||
KNOWHERE_THROW_MSG("Calling IVF::CopyCpuToGpu when we are using CPU version");
|
||||
#endif
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// IVF::Clone() {
|
||||
// std::lock_guard<std::mutex> lk(mutex_);
|
||||
//
|
||||
// auto clone_index = faiss::clone_index(index_.get());
|
||||
// std::shared_ptr<faiss::Index> new_index;
|
||||
// new_index.reset(clone_index);
|
||||
// return Clone_impl(new_index);
|
||||
//}
|
||||
//
|
||||
// VectorIndexPtr
|
||||
// IVF::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
|
||||
// return std::make_shared<IVF>(index);
|
||||
//}
|
||||
|
||||
void
|
||||
IVF::Seal() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
SealImpl();
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
||||
IVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset->Get<int64_t>(meta::DIM);
|
||||
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
try {
|
||||
size_t p_x_size = sizeof(float) * elems;
|
||||
auto p_x = (float*)malloc(p_x_size);
|
||||
|
||||
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
|
||||
index_ivf->get_vector_by_id(1, p_data, p_x, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::TENSOR, p_x);
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
try {
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
|
@ -307,9 +160,9 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
// todo: enable search by id (zhiru)
|
||||
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
// auto blacklist = dataset_ptr->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
|
||||
index_ivf->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
|
||||
index_ivf->search_by_id(rows, p_data, k, p_dist, p_id, bitset_);
|
||||
|
||||
// std::stringstream ss_res_id, ss_res_dist;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
|
@ -335,37 +188,129 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVF::SetBlacklist(faiss::ConcurrentBitsetPtr list) {
|
||||
bitset_ = std::move(list);
|
||||
}
|
||||
|
||||
void
|
||||
IVF::GetBlacklist(faiss::ConcurrentBitsetPtr& list) {
|
||||
list = bitset_;
|
||||
}
|
||||
|
||||
IVFIndexModel::IVFIndexModel(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
IVFIndexModel::Serialize() {
|
||||
DatasetPtr
|
||||
IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("indexmodel not initialize or trained");
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
|
||||
|
||||
try {
|
||||
size_t p_x_size = sizeof(float) * elems;
|
||||
auto p_x = (float*)malloc(p_x_size);
|
||||
|
||||
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
|
||||
index_ivf->get_vector_by_id(1, p_data, p_x, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::TENSOR, p_x);
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void
|
||||
IVFIndexModel::Load(const BinarySet& binary_set) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(binary_set);
|
||||
IVF::Seal() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
SealImpl();
|
||||
}
|
||||
|
||||
// Build a GPU copy of the wrapped faiss index on the given device.
//
// @param device_id  GPU device the index is copied to.
// @param config     Unused by this implementation.
// @return           A GPUIVF wrapping the device-side index.
//
// Raises via KNOWHERE_THROW_MSG when no GPU resource is available for
// `device_id`, or unconditionally when built without MILVUS_GPU_VERSION.
VecIndexPtr
IVF::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
    auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(device_id);
    if (!gpu_res) {
        KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
    }

    // Keep the scope object alive for the duration of the copy.
    ResScope scope(gpu_res, device_id, false);

    // Take ownership of the raw pointer returned by faiss right away.
    std::shared_ptr<faiss::Index> device_index(
        faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), device_id, index_.get()));
    return std::make_shared<GPUIVF>(device_index, device_id, gpu_res);
#else
    KNOWHERE_THROW_MSG("Calling IVF::CopyCpuToGpu when we are using CPU version");
#endif
}
|
||||
|
||||
void
|
||||
IVFIndexModel::SealImpl() {
|
||||
// do nothing
|
||||
IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) {
    // Builds a k-nearest-neighbour graph over the vectors in `data`.
    //
    // data   - Count() vectors of config[meta::DIM] floats each, laid out contiguously.
    // k      - number of neighbours stored per node.
    // graph  - output; resized to Count() nodes, each holding k neighbour ids.
    // config - forwarded unchanged to QueryImpl().
    //
    // k + 1 results are requested per query and the first one is skipped
    // (presumably because each vector is expected to match itself first).
    const int64_t K = k + 1;
    const int64_t ntotal = Count();
    const size_t dim = config[meta::DIM];
    const int64_t batch_size = 1000;

    graph.resize(ntotal);

    // One distance buffer and one id buffer are reused across batches;
    // only the current batch's results are ever read.
    std::vector<float> res_dis(K * batch_size);
    std::vector<int64_t> res_ids;
    res_ids.reserve(K * batch_size);

    // 64-bit offsets throughout so large indexes cannot overflow an int.
    for (int64_t offset = 0; offset < ntotal; offset += batch_size) {
        const int64_t remaining = ntotal - offset;
        const int64_t b_size = remaining < batch_size ? remaining : batch_size;

        // Zero-filled per batch, matching a freshly resized buffer.
        res_ids.assign(K * b_size, 0);

        const float* xq = data + static_cast<size_t>(offset) * dim;
        QueryImpl(b_size, xq, K, res_dis.data(), res_ids.data(), config);

        for (int64_t j = 0; j < b_size; ++j) {
            auto& node = graph[offset + j];
            node.resize(k);

            // Skip the first hit of row j and keep the following k ids.
            const int64_t* neighbours = res_ids.data() + j * K + 1;
            for (int64_t m = 0; m < k; ++m) {
                node[m] = neighbours[m];
            }
        }
    }
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->max_codes = config["max_codes"];
|
||||
return params;
|
||||
}
|
||||
|
||||
void
|
||||
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
stdclock::time_point before = stdclock::now();
|
||||
ivf_index->search(n, (float*)data, k, distances, labels, bitset_);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
}
|
||||
|
||||
void
|
||||
IVF::SealImpl() {
|
||||
#ifdef CUSTOMIZATION
|
||||
faiss::Index* index = index_.get();
|
||||
auto idx = dynamic_cast<faiss::IndexIVF*>(index);
|
||||
if (idx != nullptr) {
|
||||
// To be deleted
|
||||
KNOWHERE_LOG_DEBUG << "Test before to_readonly: IVF READONLY " << std::boolalpha << idx->is_readonly();
|
||||
idx->to_readonly();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
|
@ -16,111 +16,75 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "FaissBaseIndex.h"
|
||||
#include "VectorIndex.h"
|
||||
#include "faiss/IndexIVF.h"
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include <faiss/IndexIVF.h>
|
||||
|
||||
#include "knowhere/common/Typedef.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
using Graph = std::vector<std::vector<int64_t>>;
|
||||
|
||||
class IVF : public VectorIndex, public FaissBaseIndex {
|
||||
class IVF : public VecIndex, public FaissBaseIndex {
|
||||
public:
|
||||
IVF() : FaissBaseIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
}
|
||||
|
||||
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// Clone() override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
BinarySet
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
set_index_model(IndexModelPtr model) override;
|
||||
Load(const BinarySet&) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr& dataset, const Config& config);
|
||||
Add(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& config);
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
Count() override {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
virtual VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config);
|
||||
Dim() override {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
DatasetPtr
|
||||
SearchById(const DatasetPtr& dataset, const Config& config) override;
|
||||
virtual void
|
||||
Seal();
|
||||
|
||||
void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr list);
|
||||
virtual VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&);
|
||||
|
||||
void
|
||||
GetBlacklist(faiss::ConcurrentBitsetPtr& list);
|
||||
virtual void
|
||||
GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config);
|
||||
|
||||
protected:
|
||||
virtual std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config& config);
|
||||
|
||||
// virtual VectorIndexPtr
|
||||
// Clone_impl(const std::shared_ptr<faiss::Index>& index);
|
||||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
||||
private:
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
};
|
||||
|
||||
using IVFIndexPtr = std::shared_ptr<IVF>;
|
||||
|
||||
class GPUIVF;
|
||||
class IVFIndexModel : public IndexModel, public FaissBaseIndex {
|
||||
friend IVF;
|
||||
friend GPUIVF;
|
||||
|
||||
public:
|
||||
explicit IVFIndexModel(std::shared_ptr<faiss::Index> index);
|
||||
|
||||
IVFIndexModel() : FaissBaseIndex(nullptr) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& binary) override;
|
||||
|
||||
protected:
|
||||
void
|
||||
SealImpl() override;
|
||||
|
||||
|
@ -128,6 +92,7 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex {
|
|||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
using IVFIndexModelPtr = std::shared_ptr<IVFIndexModel>;
|
||||
using IVFPtr = std::shared_ptr<IVF>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,31 +7,33 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/clone_index.h>
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
|
||||
#endif
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr
|
||||
IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
void
|
||||
IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
|
@ -39,27 +41,11 @@ IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
|||
config[IndexParams::nbits].get<int64_t>());
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
return std::make_shared<IVFIndexModel>(index);
|
||||
index_.reset(faiss::clone_index(index.get()));
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->scan_table_threshold = config["scan_table_threhold"]
|
||||
// params->polysemous_ht = config["polysemous_ht"]
|
||||
// params->max_codes = config["max_codes"]
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// IVFPQ::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
|
||||
// return std::make_shared<IVFPQ>(index);
|
||||
//}
|
||||
|
||||
VectorIndexPtr
|
||||
IVFPQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
VecIndexPtr
|
||||
IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
|
@ -76,4 +62,16 @@ IVFPQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
|||
#endif
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->scan_table_threshold = config["scan_table_threhold"]
|
||||
// params->polysemous_ht = config["polysemous_ht"]
|
||||
// params->max_codes = config["max_codes"]
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,36 +7,40 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IVFPQ : public IVF {
|
||||
public:
|
||||
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
IVFPQ() : IVF() {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
|
||||
}
|
||||
|
||||
IVFPQ() = default;
|
||||
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&) override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config& config) override;
|
||||
|
||||
// VectorIndexPtr
|
||||
// Clone_impl(const std::shared_ptr<faiss::Index>& index) override;
|
||||
};
|
||||
|
||||
using IVFPQPtr = std::shared_ptr<IVFPQ>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,31 +7,33 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#endif
|
||||
#include <faiss/clone_index.h>
|
||||
#include <faiss/index_factory.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr
|
||||
IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
void
|
||||
IVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
|
@ -40,20 +42,12 @@ IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
|||
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
build_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> ret_index;
|
||||
ret_index.reset(build_index);
|
||||
return std::make_shared<IVFIndexModel>(ret_index);
|
||||
index_.reset(faiss::clone_index(build_index));
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// IVFSQ::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
|
||||
// return std::make_shared<IVFSQ>(index);
|
||||
//}
|
||||
|
||||
VectorIndexPtr
|
||||
IVFSQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
VecIndexPtr
|
||||
IVFSQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
|
||||
|
@ -65,10 +59,10 @@ IVFSQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
|||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
|
||||
}
|
||||
|
||||
#else
|
||||
KNOWHERE_THROW_MSG("Calling IVFSQ::CopyCpuToGpu when we are using CPU version");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,33 +7,36 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IVFSQ : public IVF {
|
||||
public:
|
||||
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
IVFSQ() : IVF() {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
|
||||
}
|
||||
|
||||
IVFSQ() = default;
|
||||
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
protected:
|
||||
// VectorIndexPtr
|
||||
// Clone_impl(const std::shared_ptr<faiss::Index>& index) override;
|
||||
VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&) override;
|
||||
};
|
||||
|
||||
using IVFSQPtr = std::shared_ptr<IVFSQ>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,43 +7,42 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNSG.h"
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
#include "knowhere/index/vector_index/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
|
||||
#endif
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <fiu-local.h>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSGIO.h"
|
||||
#include "knowhere/index/vector_index/IndexNSG.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSGIO.h"
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
NSG::Serialize() {
|
||||
NSG::Serialize(const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
fiu_do_on("NSG.Serialize.throw_exception", throw std::exception());
|
||||
algo::NsgIndex* index = index_.get();
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
impl::NsgIndex* index = index_.get();
|
||||
|
||||
MemoryIOWriter writer;
|
||||
algo::write_index(index, writer);
|
||||
impl::write_index(index, writer);
|
||||
auto data = std::make_shared<uint8_t>();
|
||||
data.reset(writer.data_);
|
||||
|
||||
|
@ -59,13 +58,14 @@ void
|
|||
NSG::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
fiu_do_on("NSG.Load.throw_exception", throw std::exception());
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto binary = index_binary.GetByName("NSG");
|
||||
|
||||
MemoryIOReader reader;
|
||||
reader.total = binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
auto index = algo::read_index(reader);
|
||||
auto index = impl::read_index(reader);
|
||||
index_.reset(index);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -73,73 +73,75 @@ NSG::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
NSG::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
GETTENSOR(dataset_ptr)
|
||||
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
try {
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
algo::SearchParams s_params;
|
||||
s_params.search_length = config[IndexParams::search_length];
|
||||
index_->Search((float*)p_data, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id, s_params);
|
||||
impl::SearchParams s_params;
|
||||
s_params.search_length = config[IndexParams::search_length];
|
||||
s_params.k = config[meta::TOPK];
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
index_->Search((float*)p_data, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id, s_params);
|
||||
}
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
NSG::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
void
|
||||
NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto idmap = std::make_shared<IDMAP>();
|
||||
idmap->Train(config);
|
||||
idmap->AddWithoutId(dataset, config);
|
||||
Graph knng;
|
||||
idmap->Train(dataset_ptr, config);
|
||||
idmap->AddWithoutIds(dataset_ptr, config);
|
||||
impl::Graph knng;
|
||||
const float* raw_data = idmap->GetRawVectors();
|
||||
const int64_t device_id = config[knowhere::meta::DEVICEID].get<int64_t>();
|
||||
const int64_t k = config[IndexParams::knng].get<int64_t>();
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (config[knowhere::meta::DEVICEID].get<int64_t>() == -1) {
|
||||
if (device_id == -1) {
|
||||
auto preprocess_index = std::make_shared<IVF>();
|
||||
auto model = preprocess_index->Train(dataset, config);
|
||||
preprocess_index->set_index_model(model);
|
||||
preprocess_index->AddWithoutIds(dataset, config);
|
||||
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
|
||||
preprocess_index->Train(dataset_ptr, config);
|
||||
preprocess_index->AddWithoutIds(dataset_ptr, config);
|
||||
preprocess_index->GenGraph(raw_data, k, knng, config);
|
||||
} else {
|
||||
auto gpu_idx = cloner::CopyCpuToGpu(idmap, config[knowhere::meta::DEVICEID].get<int64_t>(), config);
|
||||
auto gpu_idx = cloner::CopyCpuToGpu(idmap, device_id, config);
|
||||
auto gpu_idmap = std::dynamic_pointer_cast<GPUIDMAP>(gpu_idx);
|
||||
gpu_idmap->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
|
||||
gpu_idmap->GenGraph(raw_data, k, knng, config);
|
||||
}
|
||||
#else
|
||||
auto preprocess_index = std::make_shared<IVF>();
|
||||
auto model = preprocess_index->Train(dataset, config);
|
||||
preprocess_index->set_index_model(model);
|
||||
preprocess_index->AddWithoutIds(dataset, config);
|
||||
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
|
||||
preprocess_index->Train(dataset_ptr, config);
|
||||
preprocess_index->AddWithoutIds(dataset_ptr, config);
|
||||
preprocess_index->GenGraph(raw_data, k, knng, config);
|
||||
#endif
|
||||
|
||||
algo::BuildParams b_params;
|
||||
impl::BuildParams b_params;
|
||||
b_params.candidate_pool_size = config[IndexParams::candidate];
|
||||
b_params.out_degree = config[IndexParams::out_degree];
|
||||
b_params.search_length = config[IndexParams::search_length];
|
||||
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
auto p_ids = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
GETTENSOR(dataset)
|
||||
index_ = std::make_shared<algo::NsgIndex>(dim, rows);
|
||||
GETTENSOR(dataset_ptr)
|
||||
index_ = std::make_shared<impl::NsgIndex>(dim, rows);
|
||||
index_->SetKnnGraph(knng);
|
||||
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
|
||||
return nullptr; // TODO(linxj): support serialize
|
||||
}
|
||||
|
||||
void
|
||||
NSG::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
int64_t
|
||||
|
@ -148,18 +150,9 @@ NSG::Count() {
|
|||
}
|
||||
|
||||
int64_t
|
||||
NSG::Dimension() {
|
||||
NSG::Dim() {
|
||||
return index_->dimension;
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// NSG::Clone() {
|
||||
// KNOWHERE_THROW_MSG("not support");
|
||||
//}
|
||||
|
||||
void
|
||||
NSG::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,52 +7,73 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "VectorIndex.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
class NsgIndex;
|
||||
}
|
||||
|
||||
class NSG : public VectorIndex {
|
||||
class NSG : public VecIndex {
|
||||
public:
|
||||
explicit NSG(const int64_t& gpu_num) : gpu_(gpu_num) {
|
||||
explicit NSG(const int64_t& gpu_num = -1) : gpu_(gpu_num) {
|
||||
if (gpu_ >= 0) {
|
||||
index_mode_ = IndexMode::MODE_GPU;
|
||||
}
|
||||
index_type_ = IndexEnum::INDEX_NSG;
|
||||
}
|
||||
|
||||
NSG() = default;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
Load(const BinarySet&) override;
|
||||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
Train(dataset_ptr, config);
|
||||
}
|
||||
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr&, const Config&) override {
|
||||
KNOWHERE_THROW_MSG("Incremental index is not supported");
|
||||
}
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override {
|
||||
KNOWHERE_THROW_MSG("Addwithoutids is not supported");
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
// VectorIndexPtr
|
||||
// Clone() override;
|
||||
void
|
||||
Seal() override;
|
||||
Dim() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<algo::NsgIndex> index_;
|
||||
std::mutex mutex_;
|
||||
int64_t gpu_;
|
||||
std::shared_ptr<impl::NsgIndex> index_;
|
||||
};
|
||||
|
||||
// Shared-ownership handle to an NSG index. The previous spelling,
// `std::shared_ptr<NSG>()`, named a function type rather than the pointer type.
using NSGIndexPtr = std::shared_ptr<NSG>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -19,29 +19,29 @@
|
|||
|
||||
#undef mkdir
|
||||
|
||||
#include "knowhere/adapter/SptagAdapter.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexSPTAG.h"
|
||||
#include "knowhere/index/vector_index/helpers/Definitions.h"
|
||||
#include "knowhere/index/vector_index/adapter/SptagAdapter.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
CPUSPTAGRNG::CPUSPTAGRNG(const std::string& IndexType) {
|
||||
if (IndexType == "KDT") {
|
||||
index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT, SPTAG::VectorValueType::Float);
|
||||
index_ptr_->SetParameter("DistCalcMethod", "L2");
|
||||
index_type_ = SPTAG::IndexAlgoType::KDT;
|
||||
index_type_ = IndexEnum::INDEX_SPTAG_KDT_RNT;
|
||||
} else {
|
||||
index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::BKT, SPTAG::VectorValueType::Float);
|
||||
index_ptr_->SetParameter("DistCalcMethod", "L2");
|
||||
index_type_ = SPTAG::IndexAlgoType::BKT;
|
||||
index_type_ = IndexEnum::INDEX_SPTAG_BKT_RNT;
|
||||
}
|
||||
}
|
||||
|
||||
BinarySet
|
||||
CPUSPTAGRNG::Serialize() {
|
||||
CPUSPTAGRNG::Serialize(const Config& config) {
|
||||
std::string index_config;
|
||||
std::vector<SPTAG::ByteArray> index_blobs;
|
||||
|
||||
|
@ -72,15 +72,15 @@ CPUSPTAGRNG::Serialize() {
|
|||
metadata1.reset(static_cast<uint8_t*>(index_blobs[4].Data()));
|
||||
auto metadata2 = std::make_shared<uint8_t>();
|
||||
metadata2.reset(static_cast<uint8_t*>(index_blobs[5].Data()));
|
||||
auto config = std::make_shared<uint8_t>();
|
||||
config.reset(static_cast<uint8_t*>((void*)cstr));
|
||||
auto x_cfg = std::make_shared<uint8_t>();
|
||||
x_cfg.reset(static_cast<uint8_t*>((void*)cstr));
|
||||
|
||||
binary_set.Append("samples", sample, index_blobs[0].Length());
|
||||
binary_set.Append("tree", tree, index_blobs[1].Length());
|
||||
binary_set.Append("deleteid", deleteid, index_blobs[3].Length());
|
||||
binary_set.Append("metadata1", metadata1, index_blobs[4].Length());
|
||||
binary_set.Append("metadata2", metadata2, index_blobs[5].Length());
|
||||
binary_set.Append("config", config, length);
|
||||
binary_set.Append("config", x_cfg, length);
|
||||
binary_set.Append("graph", graph, index_blobs[2].Length());
|
||||
|
||||
return binary_set;
|
||||
|
@ -115,16 +115,11 @@ CPUSPTAGRNG::Load(const BinarySet& binary_set) {
|
|||
index_ptr_->LoadIndex(index_config, index_blobs);
|
||||
}
|
||||
|
||||
// PreprocessorPtr
|
||||
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
|
||||
// return std::make_shared<NormalizePreprocessor>();
|
||||
//}
|
||||
|
||||
IndexModelPtr
|
||||
void
|
||||
CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) {
|
||||
SetParameters(train_config);
|
||||
|
||||
DatasetPtr dataset = origin; // TODO(linxj): copy or reference?
|
||||
DatasetPtr dataset = origin;
|
||||
|
||||
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// && preprocessor_) {
|
||||
|
@ -134,24 +129,6 @@ CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) {
|
|||
auto vectorset = ConvertToVectorSet(dataset);
|
||||
auto metaset = ConvertToMetadataSet(dataset);
|
||||
index_ptr_->BuildIndex(vectorset, metaset);
|
||||
|
||||
// TODO: return IndexModelPtr
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void
|
||||
CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) {
|
||||
// SetParameters(add_config);
|
||||
// DatasetPtr dataset = origin->Clone();
|
||||
//
|
||||
// // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// // && preprocessor_) {
|
||||
// // preprocessor_->Preprocess(dataset);
|
||||
// //}
|
||||
//
|
||||
// auto vectorset = ConvertToVectorSet(dataset);
|
||||
// auto metaset = ConvertToMetadataSet(dataset);
|
||||
// index_ptr_->AddIndex(vectorset, metaset);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -159,7 +136,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
|
|||
#define Assign(param_name, str_name) \
|
||||
index_ptr_->SetParameter(str_name, std::to_string(build_cfg[param_name].get<int64_t>()))
|
||||
|
||||
if (index_type_ == SPTAG::IndexAlgoType::KDT) {
|
||||
if (index_type_ == IndexEnum::INDEX_SPTAG_KDT_RNT) {
|
||||
auto build_cfg = SPTAGParameterMgr::GetInstance().GetKDTParameters();
|
||||
|
||||
Assign("kdtnumber", "KDTNumber");
|
||||
|
@ -204,17 +181,17 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
SetParameters(config);
|
||||
|
||||
auto p_data = dataset->Get<const float*>(meta::TENSOR);
|
||||
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);
|
||||
for (auto i = 0; i < 10; ++i) {
|
||||
for (auto j = 0; j < 10; ++j) {
|
||||
std::cout << p_data[i * 10 + j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::vector<SPTAG::QueryResult> query_results = ConvertToQueryResult(dataset, config);
|
||||
std::vector<SPTAG::QueryResult> query_results = ConvertToQueryResult(dataset_ptr, config);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < query_results.size(); ++i) {
|
||||
|
@ -232,28 +209,24 @@ CPUSPTAGRNG::Count() {
|
|||
}
|
||||
|
||||
int64_t
|
||||
CPUSPTAGRNG::Dimension() {
|
||||
CPUSPTAGRNG::Dim() {
|
||||
return index_ptr_->GetFeatureDim();
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// CPUSPTAGRNG::Clone() {
|
||||
// KNOWHERE_THROW_MSG("not support");
|
||||
//}
|
||||
// void
|
||||
// CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) {
|
||||
// SetParameters(add_config);
|
||||
// DatasetPtr dataset = origin->Clone();
|
||||
|
||||
void
|
||||
CPUSPTAGRNG::Seal() {
|
||||
return; // do nothing
|
||||
}
|
||||
// // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// // && preprocessor_) {
|
||||
// // preprocessor_->Preprocess(dataset);
|
||||
// //}
|
||||
|
||||
BinarySet
|
||||
CPUSPTAGRNGIndexModel::Serialize() {
|
||||
// KNOWHERE_THROW_MSG("not support"); // not support
|
||||
}
|
||||
|
||||
void
|
||||
CPUSPTAGRNGIndexModel::Load(const BinarySet& binary) {
|
||||
// KNOWHERE_THROW_MSG("not support"); // not support
|
||||
}
|
||||
// auto vectorset = ConvertToVectorSet(dataset);
|
||||
// auto metaset = ConvertToMetadataSet(dataset);
|
||||
// index_ptr_->AddIndex(vectorset, metaset);
|
||||
// }
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -17,71 +17,58 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "VectorIndex.h"
|
||||
#include "knowhere/index/IndexModel.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class CPUSPTAGRNG : public VectorIndex {
|
||||
class CPUSPTAGRNG : public VecIndex {
|
||||
public:
|
||||
explicit CPUSPTAGRNG(const std::string& IndexType);
|
||||
|
||||
public:
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
// VectorIndexPtr
|
||||
// Clone() override;
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_array) override;
|
||||
|
||||
public:
|
||||
// PreprocessorPtr
|
||||
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
Train(dataset_ptr, config);
|
||||
}
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr&, const Config&) override {
|
||||
KNOWHERE_THROW_MSG("Incremental index is not supported");
|
||||
}
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override {
|
||||
KNOWHERE_THROW_MSG("Incremental index is not supported");
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
Dim() override;
|
||||
|
||||
private:
|
||||
void
|
||||
SetParameters(const Config& config);
|
||||
|
||||
private:
|
||||
PreprocessorPtr preprocessor_;
|
||||
std::shared_ptr<SPTAG::VectorIndex> index_ptr_;
|
||||
SPTAG::IndexAlgoType index_type_;
|
||||
};
|
||||
|
||||
using CPUSPTAGRNGPtr = std::shared_ptr<CPUSPTAGRNG>;
|
||||
|
||||
class CPUSPTAGRNGIndexModel : public IndexModel {
|
||||
public:
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& binary) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SPTAG::VectorIndex> index_;
|
||||
};
|
||||
|
||||
using CPUSPTAGRNGIndexModelPtr = std::shared_ptr<CPUSPTAGRNGIndexModel>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include <unordered_map>
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
/* for compatibility with 0.7.0 */
|
||||
static std::unordered_map<int32_t, std::string> old_index_type_str_map = {
|
||||
{(int32_t)OldIndexType::INVALID, "INVALID"},
|
||||
{(int32_t)OldIndexType::FAISS_IDMAP, IndexEnum::INDEX_FAISS_IDMAP},
|
||||
{(int32_t)OldIndexType::FAISS_IVFFLAT_CPU, IndexEnum::INDEX_FAISS_IVFFLAT},
|
||||
{(int32_t)OldIndexType::FAISS_IVFFLAT_GPU, IndexEnum::INDEX_FAISS_IVFFLAT},
|
||||
{(int32_t)OldIndexType::FAISS_IVFFLAT_MIX, IndexEnum::INDEX_FAISS_IVFFLAT},
|
||||
{(int32_t)OldIndexType::FAISS_IVFPQ_CPU, IndexEnum::INDEX_FAISS_IVFPQ},
|
||||
{(int32_t)OldIndexType::FAISS_IVFPQ_GPU, IndexEnum::INDEX_FAISS_IVFPQ},
|
||||
{(int32_t)OldIndexType::FAISS_IVFPQ_MIX, IndexEnum::INDEX_FAISS_IVFPQ},
|
||||
{(int32_t)OldIndexType::FAISS_IVFSQ8_MIX, IndexEnum::INDEX_FAISS_IVFSQ8},
|
||||
{(int32_t)OldIndexType::FAISS_IVFSQ8_CPU, IndexEnum::INDEX_FAISS_IVFSQ8},
|
||||
{(int32_t)OldIndexType::FAISS_IVFSQ8_GPU, IndexEnum::INDEX_FAISS_IVFSQ8},
|
||||
{(int32_t)OldIndexType::FAISS_IVFSQ8_HYBRID, IndexEnum::INDEX_FAISS_IVFSQ8H},
|
||||
{(int32_t)OldIndexType::NSG_MIX, IndexEnum::INDEX_NSG},
|
||||
{(int32_t)OldIndexType::SPTAG_KDT_RNT_CPU, IndexEnum::INDEX_SPTAG_KDT_RNT},
|
||||
{(int32_t)OldIndexType::SPTAG_BKT_RNT_CPU, IndexEnum::INDEX_SPTAG_BKT_RNT},
|
||||
{(int32_t)OldIndexType::HNSW, IndexEnum::INDEX_HNSW},
|
||||
{(int32_t)OldIndexType::FAISS_BIN_IDMAP, IndexEnum::INDEX_FAISS_BIN_IDMAP},
|
||||
{(int32_t)OldIndexType::FAISS_BIN_IVFLAT_CPU, IndexEnum::INDEX_FAISS_BIN_IVFFLAT},
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, int32_t> str_old_index_type_map = {
|
||||
{"", (int32_t)OldIndexType::INVALID},
|
||||
{IndexEnum::INDEX_FAISS_IDMAP, (int32_t)OldIndexType::FAISS_IDMAP},
|
||||
{IndexEnum::INDEX_FAISS_IVFFLAT, (int32_t)OldIndexType::FAISS_IVFFLAT_CPU},
|
||||
{IndexEnum::INDEX_FAISS_IVFPQ, (int32_t)OldIndexType::FAISS_IVFPQ_CPU},
|
||||
{IndexEnum::INDEX_FAISS_IVFSQ8, (int32_t)OldIndexType::FAISS_IVFSQ8_CPU},
|
||||
{IndexEnum::INDEX_FAISS_IVFSQ8H, (int32_t)OldIndexType::FAISS_IVFSQ8_HYBRID},
|
||||
{IndexEnum::INDEX_NSG, (int32_t)OldIndexType::NSG_MIX},
|
||||
{IndexEnum::INDEX_SPTAG_KDT_RNT, (int32_t)OldIndexType::SPTAG_KDT_RNT_CPU},
|
||||
{IndexEnum::INDEX_SPTAG_BKT_RNT, (int32_t)OldIndexType::SPTAG_BKT_RNT_CPU},
|
||||
{IndexEnum::INDEX_HNSW, (int32_t)OldIndexType::HNSW},
|
||||
{IndexEnum::INDEX_FAISS_BIN_IDMAP, (int32_t)OldIndexType::FAISS_BIN_IDMAP},
|
||||
{IndexEnum::INDEX_FAISS_BIN_IVFFLAT, (int32_t)OldIndexType::FAISS_BIN_IVFLAT_CPU},
|
||||
};
|
||||
|
||||
std::string
|
||||
OldIndexTypeToStr(const int32_t type) {
|
||||
try {
|
||||
return old_index_type_str_map.at(type);
|
||||
} catch (...) {
|
||||
KNOWHERE_THROW_MSG("Invalid index type " + std::to_string(type));
|
||||
}
|
||||
}
|
||||
|
||||
int32_t
|
||||
StrToOldIndexType(const std::string& str) {
|
||||
try {
|
||||
return str_old_index_type_map.at(str);
|
||||
} catch (...) {
|
||||
KNOWHERE_THROW_MSG("Invalid index str " + str);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
/* used in 0.7.0 */
|
||||
// Numeric index type identifiers as used in 0.7.0.
// Every value is spelled out so that inserting or reordering an enumerator
// cannot silently renumber the ones after it.
enum class OldIndexType {
    INVALID = 0,
    FAISS_IDMAP = 1,
    FAISS_IVFFLAT_CPU = 2,
    FAISS_IVFFLAT_GPU = 3,
    FAISS_IVFFLAT_MIX = 4,  // build on gpu and search on cpu
    FAISS_IVFPQ_CPU = 5,
    FAISS_IVFPQ_GPU = 6,
    SPTAG_KDT_RNT_CPU = 7,
    FAISS_IVFSQ8_MIX = 8,
    FAISS_IVFSQ8_CPU = 9,
    FAISS_IVFSQ8_GPU = 10,
    FAISS_IVFSQ8_HYBRID = 11,  // only support build on gpu.
    NSG_MIX = 12,
    FAISS_IVFPQ_MIX = 13,
    SPTAG_BKT_RNT_CPU = 14,
    HNSW = 15,
    FAISS_BIN_IDMAP = 100,
    FAISS_BIN_IVFLAT_CPU = 101,
};
|
||||
|
||||
using IndexType = std::string;
|
||||
|
||||
/* used in 0.8.0 */
|
||||
// String index-type names (used in 0.8.0), grouped by index family.
namespace IndexEnum {
// Placeholder for "no index type".
constexpr const char* INVALID = "";

// FAISS float-vector indexes.
constexpr const char* INDEX_FAISS_IDMAP = "IDMAP";
constexpr const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT";
constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";
constexpr const char* INDEX_FAISS_IVFSQ8H = "IVF_SQ8_HYBRID";

// FAISS binary-vector indexes.
constexpr const char* INDEX_FAISS_BIN_IDMAP = "BIN_IDMAP";
constexpr const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT";

// Graph / tree based indexes.
constexpr const char* INDEX_NSG = "NSG";
constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_SPTAG_KDT_RNT = "SPTAG_KDT_RNT";
constexpr const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT";
}  // namespace IndexEnum
|
||||
|
||||
// Selects the CPU or the GPU flavour of an index.
enum class IndexMode {
    MODE_CPU = 0,
    MODE_GPU = 1,
};
|
||||
|
||||
// Convert a 0.7.0 numeric index type into its index-name string.
std::string
OldIndexTypeToStr(int32_t type);

// Convert an index-name string into the 0.7.0 numeric index type.
int32_t
StrToOldIndexType(const std::string& str);
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,121 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/common/Typedef.h"
|
||||
#include "knowhere/index/Index.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "segment/Types.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class VecIndex : public Index {
|
||||
public:
|
||||
virtual void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
Train(dataset_ptr, config);
|
||||
Add(dataset_ptr, config);
|
||||
}
|
||||
|
||||
virtual void
|
||||
Train(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual void
|
||||
Add(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual void
|
||||
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual DatasetPtr
|
||||
Query(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// virtual DatasetPtr
|
||||
// QueryByRange(const DatasetPtr&, const Config&) = 0;
|
||||
//
|
||||
// virtual MetricType
|
||||
// metric_type() = 0;
|
||||
|
||||
virtual int64_t
|
||||
Dim() = 0;
|
||||
|
||||
virtual int64_t
|
||||
Count() = 0;
|
||||
|
||||
virtual IndexType
|
||||
index_type() const {
|
||||
return index_type_;
|
||||
}
|
||||
|
||||
virtual IndexMode
|
||||
index_mode() const {
|
||||
return index_mode_;
|
||||
}
|
||||
|
||||
virtual DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual void
|
||||
GetBlacklist(faiss::ConcurrentBitsetPtr& bitset_ptr) {
|
||||
bitset_ptr = bitset_;
|
||||
}
|
||||
|
||||
virtual void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr bitset_ptr) {
|
||||
bitset_ = std::move(bitset_ptr);
|
||||
}
|
||||
|
||||
virtual const std::vector<milvus::segment::doc_id_t>&
|
||||
GetUids() const {
|
||||
return uids_;
|
||||
}
|
||||
|
||||
virtual void
|
||||
SetUids(std::vector<milvus::segment::doc_id_t>& uids) {
|
||||
uids_.clear();
|
||||
uids_.swap(uids);
|
||||
}
|
||||
|
||||
int64_t
|
||||
Size() override {
|
||||
if (size_ != -1) {
|
||||
return size_;
|
||||
}
|
||||
return Count() * Dim() * sizeof(FloatType);
|
||||
}
|
||||
|
||||
protected:
|
||||
IndexType index_type_ = "";
|
||||
IndexMode index_mode_ = IndexMode::MODE_CPU;
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
|
||||
private:
|
||||
std::vector<milvus::segment::doc_id_t> uids_;
|
||||
};
|
||||
|
||||
using VecIndexPtr = std::shared_ptr<VecIndex>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "knowhere/index/vector_index/VecIndexFactory.h"
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexHNSW.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexNSG.h"
|
||||
#include "knowhere/index/vector_index/IndexSPTAG.h"
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <cuda.h>
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
// Create an empty index object for the given index-name string.
//
// @param type  one of the IndexEnum names
// @param mode  only consulted in MILVUS_GPU_VERSION builds, where MODE_GPU
//              selects the GPU class for IVF_FLAT / IVF_PQ / IVF_SQ8
// @return      the new index, or nullptr when `type` is not handled
//              (IVF_SQ8_HYBRID is only handled in MILVUS_GPU_VERSION builds)
VecIndexPtr
VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
    auto gpu_device = -1;  // TODO: remove hardcode here, get from invoker

    if (type == IndexEnum::INDEX_FAISS_IDMAP) {
        return std::make_shared<knowhere::IDMAP>();
    }

    if (type == IndexEnum::INDEX_FAISS_IVFFLAT) {
#ifdef MILVUS_GPU_VERSION
        if (mode == IndexMode::MODE_GPU) {
            return std::make_shared<knowhere::GPUIVF>(gpu_device);
        }
#endif
        return std::make_shared<knowhere::IVF>();
    }

    if (type == IndexEnum::INDEX_FAISS_IVFPQ) {
#ifdef MILVUS_GPU_VERSION
        if (mode == IndexMode::MODE_GPU) {
            return std::make_shared<knowhere::GPUIVFPQ>(gpu_device);
        }
#endif
        return std::make_shared<knowhere::IVFPQ>();
    }

    if (type == IndexEnum::INDEX_FAISS_IVFSQ8) {
#ifdef MILVUS_GPU_VERSION
        if (mode == IndexMode::MODE_GPU) {
            return std::make_shared<knowhere::GPUIVFSQ>(gpu_device);
        }
#endif
        return std::make_shared<knowhere::IVFSQ>();
    }

#ifdef MILVUS_GPU_VERSION
    if (type == IndexEnum::INDEX_FAISS_IVFSQ8H) {
        return std::make_shared<knowhere::IVFSQHybrid>(gpu_device);
    }
#endif

    if (type == IndexEnum::INDEX_FAISS_BIN_IDMAP) {
        return std::make_shared<knowhere::BinaryIDMAP>();
    }

    if (type == IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
        return std::make_shared<knowhere::BinaryIVF>();
    }

    if (type == IndexEnum::INDEX_NSG) {
        return std::make_shared<knowhere::NSG>(-1);
    }

    if (type == IndexEnum::INDEX_SPTAG_KDT_RNT) {
        return std::make_shared<knowhere::CPUSPTAGRNG>("KDT");
    }

    if (type == IndexEnum::INDEX_SPTAG_BKT_RNT) {
        return std::make_shared<knowhere::CPUSPTAGRNG>("BKT");
    }

    if (type == IndexEnum::INDEX_HNSW) {
        return std::make_shared<knowhere::IndexHNSW>();
    }

    // Unknown index name.
    return nullptr;
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,24 +7,35 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexModel {
|
||||
public:
|
||||
virtual BinarySet
|
||||
Serialize() = 0;
|
||||
class VecIndexFactory {
|
||||
private:
|
||||
VecIndexFactory() = default;
|
||||
VecIndexFactory(const VecIndexFactory&) = delete;
|
||||
VecIndexFactory
|
||||
operator=(const VecIndexFactory&) = delete;
|
||||
|
||||
virtual void
|
||||
Load(const BinarySet& binary) = 0;
|
||||
public:
|
||||
static VecIndexFactory&
|
||||
GetInstance() {
|
||||
static VecIndexFactory inst;
|
||||
return inst;
|
||||
}
|
||||
|
||||
knowhere::VecIndexPtr
|
||||
CreateVecIndex(const IndexType& type, const IndexMode mode = IndexMode::MODE_CPU);
|
||||
};
|
||||
|
||||
using IndexModelPtr = std::shared_ptr<IndexModel>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -1,82 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/Index.h"
|
||||
#include "knowhere/index/preprocessor/Preprocessor.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#include "segment/Types.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
class VectorIndex;
|
||||
using VectorIndexPtr = std::shared_ptr<VectorIndex>;
|
||||
|
||||
class VectorIndex : public Index {
|
||||
public:
|
||||
virtual PreprocessorPtr
|
||||
BuildPreprocessor(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual DatasetPtr
|
||||
SearchById(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual void
|
||||
Add(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual void
|
||||
Seal() = 0;
|
||||
|
||||
// TODO(linxj): Deprecated
|
||||
// virtual VectorIndexPtr
|
||||
// Clone() = 0;
|
||||
|
||||
virtual int64_t
|
||||
Count() = 0;
|
||||
|
||||
virtual int64_t
|
||||
Dimension() = 0;
|
||||
|
||||
virtual const std::vector<milvus::segment::doc_id_t>&
|
||||
GetUids() const {
|
||||
return uids_;
|
||||
}
|
||||
|
||||
virtual void
|
||||
SetUids(std::vector<milvus::segment::doc_id_t>& uids) {
|
||||
uids_.clear();
|
||||
uids_.swap(uids);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<milvus::segment::doc_id_t> uids_;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
|
@ -9,16 +9,16 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/adapter/SptagAdapter.h"
|
||||
|
||||
#include "VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/adapter/SptagAdapter.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
std::shared_ptr<SPTAG::MetadataSet>
|
||||
ConvertToMetadataSet(const DatasetPtr& dataset) {
|
||||
auto elems = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
ConvertToMetadataSet(const DatasetPtr& dataset_ptr) {
|
||||
auto elems = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
auto p_offset = (int64_t*)malloc(sizeof(int64_t) * (elems + 1));
|
||||
for (auto i = 0; i <= elems; ++i) p_offset[i] = i * 8;
|
||||
|
@ -31,8 +31,8 @@ ConvertToMetadataSet(const DatasetPtr& dataset) {
|
|||
}
|
||||
|
||||
std::shared_ptr<SPTAG::VectorSet>
|
||||
ConvertToVectorSet(const DatasetPtr& dataset) {
|
||||
GETTENSOR(dataset);
|
||||
ConvertToVectorSet(const DatasetPtr& dataset_ptr) {
|
||||
GETTENSOR(dataset_ptr);
|
||||
size_t num_bytes = rows * dim * sizeof(float);
|
||||
SPTAG::ByteArray byte_array((uint8_t*)p_data, num_bytes, false);
|
||||
|
||||
|
@ -41,13 +41,13 @@ ConvertToVectorSet(const DatasetPtr& dataset) {
|
|||
}
|
||||
|
||||
std::vector<SPTAG::QueryResult>
|
||||
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset);
|
||||
ConvertToQueryResult(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr);
|
||||
|
||||
std::vector<SPTAG::QueryResult> query_results(rows,
|
||||
SPTAG::QueryResult(nullptr, config[meta::TOPK].get<int64_t>(), true));
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
std::vector<SPTAG::QueryResult> query_results(rows, SPTAG::QueryResult(nullptr, k, true));
|
||||
for (auto i = 0; i < rows; ++i) {
|
||||
query_results[i].SetTarget(&p_data[i * dim]);
|
||||
query_results[i].SetTarget((float*)p_data + i * dim);
|
||||
}
|
||||
|
||||
return query_results;
|
||||
|
@ -81,3 +81,4 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -18,18 +18,20 @@
|
|||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
std::shared_ptr<SPTAG::VectorSet>
|
||||
ConvertToVectorSet(const DatasetPtr& dataset);
|
||||
ConvertToVectorSet(const DatasetPtr& dataset_ptr);
|
||||
|
||||
std::shared_ptr<SPTAG::MetadataSet>
|
||||
ConvertToMetadataSet(const DatasetPtr& dataset);
|
||||
ConvertToMetadataSet(const DatasetPtr& dataset_ptr);
|
||||
|
||||
std::vector<SPTAG::QueryResult>
|
||||
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config);
|
||||
ConvertToQueryResult(const DatasetPtr& dataset_ptr, const Config& config);
|
||||
|
||||
DatasetPtr
|
||||
ConvertToDataset(std::vector<SPTAG::QueryResult> query_results);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -9,22 +9,33 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
|
||||
#include <string>
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
#define GETTENSOR(dataset) \
|
||||
auto dim = dataset->Get<int64_t>(meta::DIM); \
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS); \
|
||||
auto p_data = dataset->Get<const float*>(meta::TENSOR);
|
||||
// Build a Dataset describing `nb` vectors of dimension `dim` with their ids.
//
// @param nb   stored under meta::ROWS
// @param dim  stored under meta::DIM
// @param xb   vector data pointer, stored under meta::TENSOR
// @param ids  id array pointer, stored under meta::IDS
// @return     the newly created dataset
//
// The pointers are handed to Dataset::Set as given.
// NOTE(review): presumably the buffers are not copied, so they must outlive
// the dataset -- confirm against Dataset.
DatasetPtr
GenDatasetWithIds(const int64_t nb, const int64_t dim, const void* xb, const int64_t* ids) {
    auto dataset = std::make_shared<Dataset>();
    dataset->Set(meta::ROWS, nb);
    dataset->Set(meta::DIM, dim);
    dataset->Set(meta::TENSOR, xb);
    dataset->Set(meta::IDS, ids);
    return dataset;
}
|
||||
|
||||
#define GETBINARYTENSOR(dataset) \
|
||||
auto dim = dataset->Get<int64_t>(meta::DIM); \
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS); \
|
||||
auto p_data = dataset->Get<const uint8_t*>(meta::TENSOR);
|
||||
// Build a Dataset describing `nb` vectors of dimension `dim`, without ids.
//
// @param nb   stored under meta::ROWS
// @param dim  stored under meta::DIM
// @param xb   vector data pointer, stored under meta::TENSOR
// @return     the newly created dataset
//
// The pointer is handed to Dataset::Set as given.
// NOTE(review): presumably the buffer is not copied, so it must outlive the
// dataset -- confirm against Dataset.
DatasetPtr
GenDataset(const int64_t nb, const int64_t dim, const void* xb) {
    auto dataset = std::make_shared<Dataset>();
    dataset->Set(meta::ROWS, nb);
    dataset->Set(meta::DIM, dim);
    dataset->Set(meta::TENSOR, xb);
    return dataset;
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
// Declares three locals in the calling scope, read from a dataset pointer:
//   dim    - value stored under meta::DIM
//   rows   - value stored under meta::ROWS
//   p_data - untyped pointer stored under meta::TENSOR
// The macro expands to several statements and ends with its own ';', so it is
// used without a trailing semicolon and must not be the sole body of an
// unbraced if/for. The argument is parenthesized so that any pointer-valued
// expression (not only a plain identifier) expands correctly.
#define GETTENSOR(dataset_ptr)                                \
    int64_t dim = (dataset_ptr)->Get<int64_t>(meta::DIM);     \
    int64_t rows = (dataset_ptr)->Get<int64_t>(meta::ROWS);   \
    const void* p_data = (dataset_ptr)->Get<const void*>(meta::TENSOR);
|
||||
|
||||
// Same as GETTENSOR, plus a fourth local:
//   p_ids  - int64_t id array pointer stored under meta::IDS
// Declares dim, rows, p_data and p_ids in the calling scope. The macro expands
// to several statements and ends with its own ';'. The argument is
// parenthesized so that any pointer-valued expression expands correctly.
#define GETTENSORWITHIDS(dataset_ptr)                                      \
    int64_t dim = (dataset_ptr)->Get<int64_t>(meta::DIM);                  \
    int64_t rows = (dataset_ptr)->Get<int64_t>(meta::ROWS);                \
    const void* p_data = (dataset_ptr)->Get<const void*>(meta::TENSOR);    \
    const int64_t* p_ids = (dataset_ptr)->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
extern DatasetPtr
|
||||
GenDatasetWithIds(const int64_t nb, const int64_t dim, const void* xb, const int64_t* ids);
|
||||
|
||||
extern DatasetPtr
|
||||
GenDataset(const int64_t nb, const int64_t dim, const void* xb);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIndex {
|
||||
public:
|
||||
explicit GPUIndex(const int& device_id) : gpu_id_(device_id) {
|
||||
}
|
||||
|
||||
GPUIndex(const int& device_id, const ResPtr& resource) : gpu_id_(device_id), res_(resource) {
|
||||
}
|
||||
|
||||
virtual VecIndexPtr
|
||||
CopyGpuToCpu(const Config&) = 0;
|
||||
|
||||
virtual VecIndexPtr
|
||||
CopyGpuToGpu(const int64_t, const Config&) = 0;
|
||||
|
||||
void
|
||||
SetGpuDevice(const int& gpu_id) {
|
||||
gpu_id_ = gpu_id;
|
||||
}
|
||||
|
||||
const int64_t
|
||||
GetGpuDevice() {
|
||||
return gpu_id_;
|
||||
}
|
||||
|
||||
protected:
|
||||
int64_t gpu_id_;
|
||||
ResWPtr res_;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,31 +7,28 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexGPUIDMAP.h"
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/MetaIndexes.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
|
||||
#endif
|
||||
#include <fiu-local.h>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
VectorIndexPtr
|
||||
VecIndexPtr
|
||||
GPUIDMAP::CopyGpuToCpu(const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
|
@ -43,19 +40,8 @@ GPUIDMAP::CopyGpuToCpu(const Config& config) {
|
|||
return std::make_shared<IDMAP>(new_index);
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// GPUIDMAP::Clone() {
|
||||
// auto cpu_idx = CopyGpuToCpu(Config());
|
||||
//
|
||||
// if (auto idmap = std::dynamic_pointer_cast<IDMAP>(cpu_idx)) {
|
||||
// return idmap->CopyCpuToGpu(gpu_id_, Config());
|
||||
// } else {
|
||||
// KNOWHERE_THROW_MSG("IndexType not Support GpuClone");
|
||||
// }
|
||||
//}
|
||||
|
||||
BinarySet
|
||||
GPUIDMAP::SerializeImpl() {
|
||||
GPUIDMAP::SerializeImpl(const IndexType& type) {
|
||||
try {
|
||||
fiu_do_on("GPUIDMP.SerializeImpl.throw_exception", throw std::exception());
|
||||
MemoryIOWriter writer;
|
||||
|
@ -79,7 +65,7 @@ GPUIDMAP::SerializeImpl() {
|
|||
}
|
||||
|
||||
void
|
||||
GPUIDMAP::LoadImpl(const BinarySet& index_binary) {
|
||||
GPUIDMAP::LoadImpl(const BinarySet& index_binary, const IndexType& type) {
|
||||
auto binary = index_binary.GetByName("IVF");
|
||||
MemoryIOReader reader;
|
||||
{
|
||||
|
@ -101,30 +87,30 @@ GPUIDMAP::LoadImpl(const BinarySet& index_binary) {
|
|||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
GPUIDMAP::CopyGpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
VecIndexPtr
|
||||
GPUIDMAP::CopyGpuToGpu(const int64_t device_id, const Config& config) {
|
||||
auto cpu_index = CopyGpuToCpu(config);
|
||||
return std::static_pointer_cast<IDMAP>(cpu_index)->CopyCpuToGpu(device_id, config);
|
||||
}
|
||||
|
||||
float*
|
||||
const float*
|
||||
GPUIDMAP::GetRawVectors() {
|
||||
KNOWHERE_THROW_MSG("Not support");
|
||||
}
|
||||
|
||||
int64_t*
|
||||
const int64_t*
|
||||
GPUIDMAP::GetRawIds() {
|
||||
KNOWHERE_THROW_MSG("Not support");
|
||||
}
|
||||
|
||||
void
|
||||
GPUIDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
index_->search(n, (float*)data, k, distances, labels);
|
||||
}
|
||||
|
||||
void
|
||||
GPUIDMAP::GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& config) {
|
||||
GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) {
|
||||
int64_t K = k + 1;
|
||||
auto ntotal = Count();
|
||||
|
||||
|
@ -144,7 +130,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t& k, Graph& graph, const Conf
|
|||
res.resize(K * b_size);
|
||||
|
||||
auto xq = data + batch_size * dim * i;
|
||||
search_impl(b_size, (float*)xq, K, res_dis.data(), res.data(), config);
|
||||
QueryImpl(b_size, (float*)xq, K, res_dis.data(), res.data(), config);
|
||||
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto& node = graph[batch_size * i + j];
|
||||
|
@ -158,3 +144,4 @@ GPUIDMAP::GenGraph(const float* data, const int64_t& k, Graph& graph, const Conf
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,55 +7,56 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "IndexGPUIVF.h"
|
||||
#include "IndexIDMAP.h"
|
||||
#include "IndexIVF.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/gpu/GPUIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
using Graph = std::vector<std::vector<int64_t>>;
|
||||
|
||||
class GPUIDMAP : public IDMAP, public GPUIndex {
|
||||
public:
|
||||
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& res)
|
||||
: IDMAP(std::move(index)), GPUIndex(device_id, res) {
|
||||
index_mode_ = IndexMode::MODE_GPU;
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
VecIndexPtr
|
||||
CopyGpuToCpu(const Config&) override;
|
||||
|
||||
float*
|
||||
VecIndexPtr
|
||||
CopyGpuToGpu(const int64_t, const Config&) override;
|
||||
|
||||
const float*
|
||||
GetRawVectors() override;
|
||||
|
||||
int64_t*
|
||||
const int64_t*
|
||||
GetRawIds() override;
|
||||
|
||||
// VectorIndexPtr
|
||||
// Clone() override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
void
|
||||
GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& config);
|
||||
GenGraph(const float*, const int64_t, GraphType&, const Config&);
|
||||
|
||||
protected:
|
||||
void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
|
||||
|
||||
BinarySet
|
||||
SerializeImpl() override;
|
||||
SerializeImpl(const IndexType&) override;
|
||||
|
||||
void
|
||||
LoadImpl(const BinarySet& index_binary) override;
|
||||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
};
|
||||
|
||||
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,66 +7,87 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
#include <faiss/gpu/GpuIndexIVFFlat.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr
|
||||
GPUIVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
void
|
||||
GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (gpu_res != nullptr) {
|
||||
ResScope rs(gpu_res, gpu_id_, true);
|
||||
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
|
||||
idx_config.device = gpu_id_;
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, config[IndexParams::nlist],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()), idx_config);
|
||||
int32_t nlist = config[IndexParams::nlist];
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
||||
device_index.train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
|
||||
|
||||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
||||
index_.reset(device_index1);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GPUIVF::set_index_model(IndexModelPtr model) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto host_index = std::static_pointer_cast<IVFIndexModel>(model);
|
||||
if (auto gpures = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
|
||||
ResScope rs(gpures, gpu_id_, false);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpures->faiss_res.get(), gpu_id_, host_index->index_.get());
|
||||
index_.reset(device_index);
|
||||
res_ = gpures;
|
||||
GPUIVF::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (auto spt = res_.lock()) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
IVF::Add(dataset_ptr, config);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("load index model error, can't get gpu_resource");
|
||||
KNOWHERE_THROW_MSG("Add IVF can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
// Produces an IVF index from this GPU index while holding mutex_.
// If index_ is a faiss::gpu::GpuIndexIVF it is converted with
// faiss::gpu::index_gpu_to_cpu and the result is wrapped in a new IVF;
// otherwise the existing index_ is wrapped directly (shared, not cloned).
// `config` is not consulted.
VecIndexPtr
GPUIVF::CopyGpuToCpu(const Config& config) {
    std::lock_guard<std::mutex> guard(mutex_);

    auto gpu_ivf = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
    if (!gpu_ivf) {
        // Not a GPU IVF index: hand out a wrapper around the same faiss index.
        return std::make_shared<IVF>(index_);
    }

    // Take ownership of the raw pointer returned by faiss right away.
    std::shared_ptr<faiss::Index> cpu_index(faiss::gpu::index_gpu_to_cpu(index_.get()));
    return std::make_shared<IVF>(cpu_index);
}
|
||||
|
||||
// Copies this index to GPU `device_id` by going through the host:
// first CopyGpuToCpu(config), then IVF::CopyCpuToGpu(device_id, config) on the
// resulting index.
VecIndexPtr
GPUIVF::CopyGpuToGpu(const int64_t device_id, const Config& config) {
    auto cpu_copy = CopyGpuToCpu(config);
    auto cpu_ivf = std::static_pointer_cast<IVF>(cpu_copy);
    return cpu_ivf->CopyCpuToGpu(device_id, config);
}
|
||||
|
||||
BinarySet
|
||||
GPUIVF::SerializeImpl() {
|
||||
GPUIVF::SerializeImpl(const IndexType& type) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -94,8 +115,8 @@ GPUIVF::SerializeImpl() {
|
|||
}
|
||||
|
||||
void
|
||||
GPUIVF::LoadImpl(const BinarySet& index_binary) {
|
||||
auto binary = index_binary.GetByName("IVF");
|
||||
GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
||||
auto binary = binary_set.GetByName("IVF");
|
||||
MemoryIOReader reader;
|
||||
{
|
||||
reader.total = binary->size;
|
||||
|
@ -117,7 +138,7 @@ GPUIVF::LoadImpl(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
void
|
||||
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
|
@ -131,52 +152,5 @@ GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, i
|
|||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
GPUIVF::CopyGpuToCpu(const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
if (auto device_idx = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
|
||||
faiss::Index* device_index = index_.get();
|
||||
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
|
||||
|
||||
std::shared_ptr<faiss::Index> new_index;
|
||||
new_index.reset(host_index);
|
||||
return std::make_shared<IVF>(new_index);
|
||||
} else {
|
||||
return std::make_shared<IVF>(index_);
|
||||
}
|
||||
}
|
||||
|
||||
// VectorIndexPtr
|
||||
// GPUIVF::Clone() {
|
||||
// auto cpu_idx = CopyGpuToCpu(Config());
|
||||
// return knowhere::cloner::CopyCpuToGpu(cpu_idx, gpu_id_, Config());
|
||||
//}
|
||||
|
||||
VectorIndexPtr
|
||||
GPUIVF::CopyGpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
auto host_index = CopyGpuToCpu(config);
|
||||
return std::static_pointer_cast<IVF>(host_index)->CopyCpuToGpu(device_id, config);
|
||||
}
|
||||
|
||||
void
|
||||
GPUIVF::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
if (auto spt = res_.lock()) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
IVF::Add(dataset, config);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Add IVF can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GPUIndex::SetGpuDevice(const int& gpu_id) {
|
||||
gpu_id_ = gpu_id;
|
||||
}
|
||||
|
||||
const int64_t&
|
||||
GPUIndex::GetGpuDevice() {
|
||||
return gpu_id_;
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/GPUIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIVF : public IVF, public GPUIndex {
|
||||
public:
|
||||
explicit GPUIVF(const int& device_id) : IVF(), GPUIndex(device_id) {
|
||||
index_mode_ = IndexMode::MODE_GPU;
|
||||
}
|
||||
|
||||
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t device_id, ResPtr& res)
|
||||
: IVF(std::move(index)), GPUIndex(device_id, res) {
|
||||
index_mode_ = IndexMode::MODE_GPU;
|
||||
}
|
||||
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr&, const Config&) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyGpuToCpu(const Config&) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyGpuToGpu(const int64_t, const Config&) override;
|
||||
|
||||
protected:
|
||||
BinarySet
|
||||
SerializeImpl(const IndexType&) override;
|
||||
|
||||
void
|
||||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
};
|
||||
|
||||
using GPUIVFPtr = std::shared_ptr<GPUIVF>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,44 +7,60 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuIndexIVFPQ.h>
|
||||
#include <faiss/index_factory.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr
|
||||
GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
void
|
||||
GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto device_index = new faiss::gpu::GpuIndexIVFPQ(
|
||||
temp_resource->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m],
|
||||
config[IndexParams::nbits],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
||||
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (gpu_res != nullptr) {
|
||||
ResScope rs(gpu_res, gpu_id_, true);
|
||||
auto device_index =
|
||||
new faiss::gpu::GpuIndexIVFPQ(gpu_res->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
config[IndexParams::m], config[IndexParams::nbits],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
||||
device_index->train(rows, (float*)p_data);
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
|
||||
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
||||
index_.reset(device_index1);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
// Converts the wrapped faiss index with faiss::gpu::index_gpu_to_cpu while
// holding mutex_, and returns the result wrapped in a new IVFPQ.
// `config` is not consulted.
VecIndexPtr
GPUIVFPQ::CopyGpuToCpu(const Config& config) {
    std::lock_guard<std::mutex> guard(mutex_);

    // Take ownership of the raw pointer returned by faiss right away.
    std::shared_ptr<faiss::Index> cpu_index(faiss::gpu::index_gpu_to_cpu(index_.get()));
    return std::make_shared<IVFPQ>(cpu_index);
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GPUIVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
|
@ -56,16 +72,5 @@ GPUIVFPQ::GenParams(const Config& config) {
|
|||
return params;
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
GPUIVFPQ::CopyGpuToCpu(const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
faiss::Index* device_index = index_.get();
|
||||
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
|
||||
|
||||
std::shared_ptr<faiss::Index> new_index;
|
||||
new_index.reset(host_index);
|
||||
return std::make_shared<IVFPQ>(new_index);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,37 +7,41 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIVFPQ : public GPUIVF {
|
||||
public:
|
||||
explicit GPUIVFPQ(const int& device_id) : GPUIVF(device_id) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
|
||||
}
|
||||
|
||||
GPUIVFPQ(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
|
||||
: GPUIVF(std::move(index), device_id, resource) {
|
||||
GPUIVFPQ(std::shared_ptr<faiss::Index> index, const int64_t device_id, ResPtr& res)
|
||||
: GPUIVF(std::move(index), device_id, res) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
public:
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
VecIndexPtr
|
||||
CopyGpuToCpu(const Config&) override;
|
||||
|
||||
protected:
|
||||
// TODO(linxj): remove GenParams.
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config& config) override;
|
||||
};
|
||||
|
||||
using GPUIVFPQPtr = std::shared_ptr<GPUIVFPQ>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/index_factory.h>
|
||||
|
@ -15,43 +15,40 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr
|
||||
GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
void
|
||||
GPUIVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
<< "SQ" << config[IndexParams::nbits];
|
||||
auto build_index =
|
||||
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
|
||||
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (gpu_res != nullptr) {
|
||||
ResScope rs(gpu_res, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
|
||||
device_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
|
||||
delete device_index;
|
||||
delete build_index;
|
||||
|
||||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
VecIndexPtr
|
||||
GPUIVFSQ::CopyGpuToCpu(const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
|
@ -64,3 +61,4 @@ GPUIVFSQ::CopyGpuToCpu(const Config& config) {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,31 +7,37 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIVFSQ : public GPUIVF {
|
||||
public:
|
||||
explicit GPUIVFSQ(const int& device_id) : GPUIVF(device_id) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
|
||||
}
|
||||
|
||||
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
|
||||
: GPUIVF(std::move(index), device_id, resource) {
|
||||
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t device_id, ResPtr& res)
|
||||
: GPUIVF(std::move(index), device_id, res) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
VecIndexPtr
|
||||
CopyGpuToCpu(const Config&) override;
|
||||
};
|
||||
|
||||
using GPUIVFSQPtr = std::shared_ptr<GPUIVFSQ>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -8,12 +8,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
|
@ -22,16 +17,20 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
#ifdef CUSTOMIZATION
|
||||
|
||||
// std::mutex g_mutex;
|
||||
|
||||
IndexModelPtr
|
||||
IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
GETTENSOR(dataset)
|
||||
void
|
||||
IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GETTENSOR(dataset_ptr)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
std::stringstream index_type;
|
||||
|
@ -40,10 +39,10 @@ IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
|
|||
auto build_index =
|
||||
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
|
||||
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (gpu_res != nullptr) {
|
||||
ResScope rs(gpu_res, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
|
||||
device_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
|
@ -52,15 +51,18 @@ IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
|
|||
delete device_index;
|
||||
delete build_index;
|
||||
|
||||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
gpu_mode_ = 2;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
VecIndexPtr
|
||||
IVFSQHybrid::CopyGpuToCpu(const Config& config) {
|
||||
if (gpu_mode == 0) {
|
||||
if (gpu_mode_ == 0) {
|
||||
return std::make_shared<IVFSQHybrid>(index_);
|
||||
}
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
@ -80,8 +82,8 @@ IVFSQHybrid::CopyGpuToCpu(const Config& config) {
|
|||
return std::make_shared<IVFSQHybrid>(new_index);
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
IVFSQHybrid::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
VecIndexPtr
|
||||
IVFSQHybrid::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
|
@ -98,135 +100,8 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::LoadImpl(const BinarySet& index_binary) {
|
||||
FaissBaseIndex::LoadImpl(index_binary); // load on cpu
|
||||
auto* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->backup_quantizer();
|
||||
gpu_mode = 0;
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& cfg) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
// static int64_t search_count;
|
||||
// ++search_count;
|
||||
|
||||
if (gpu_mode == 2) {
|
||||
GPUIVF::search_impl(n, data, k, distances, labels, cfg);
|
||||
// index_->search(n, (float*)data, k, distances, labels);
|
||||
} else if (gpu_mode == 1) { // hybrid
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(quantizer_gpu_id_)) {
|
||||
ResScope rs(res, quantizer_gpu_id_, true);
|
||||
IVF::search_impl(n, data, k, distances, labels, cfg);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(quantizer_gpu_id_) + "resource");
|
||||
}
|
||||
} else if (gpu_mode == 0) {
|
||||
IVF::search_impl(n, data, k, distances, labels, cfg);
|
||||
}
|
||||
}
|
||||
|
||||
QuantizerPtr
|
||||
IVFSQHybrid::LoadQuantizer(const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
option.allInGpu = true;
|
||||
|
||||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = nullptr;
|
||||
index_composition->mode = 1; // only 1
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
delete gpu_index;
|
||||
|
||||
auto q = std::make_shared<FaissIVFQuantizer>();
|
||||
|
||||
auto& q_ptr = index_composition->quantizer;
|
||||
q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float);
|
||||
q->quantizer = q_ptr;
|
||||
q->gpu_id = gpu_id;
|
||||
res_ = res;
|
||||
gpu_mode = 1;
|
||||
return q;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(q);
|
||||
if (ivf_quantizer == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Quantizer type error");
|
||||
}
|
||||
|
||||
faiss::IndexIVF* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
|
||||
faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer);
|
||||
if (is_gpu_flat_index == nullptr) {
|
||||
// delete ivf_index->quantizer;
|
||||
ivf_index->quantizer = ivf_quantizer->quantizer;
|
||||
}
|
||||
quantizer_gpu_id_ = ivf_quantizer->gpu_id;
|
||||
gpu_mode = 1;
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::UnsetQuantizer() {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
auto* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
if (ivf_index == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Index type error");
|
||||
}
|
||||
|
||||
ivf_index->quantizer = nullptr;
|
||||
quantizer_gpu_id_ = -1;
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
int64_t gpu_id = config[knowhere::meta::DEVICEID];
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
option.allInGpu = true;
|
||||
|
||||
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(q);
|
||||
if (ivf_quantizer == nullptr)
|
||||
KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer");
|
||||
|
||||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = ivf_quantizer->quantizer;
|
||||
index_composition->mode = 2; // only 2
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
std::shared_ptr<faiss::Index> new_idx;
|
||||
new_idx.reset(gpu_index);
|
||||
auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id, res);
|
||||
return sq_idx;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<VectorIndexPtr, QuantizerPtr>
|
||||
IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t& device_id, const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
std::pair<VecIndexPtr, QuantizerPtr>
|
||||
IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
|
@ -253,29 +128,102 @@ IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t& device_id, const Config& c
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::set_index_model(IndexModelPtr model) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
VecIndexPtr
|
||||
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config& config) {
|
||||
int64_t gpu_id = config[knowhere::meta::DEVICEID];
|
||||
|
||||
auto host_index = std::static_pointer_cast<IVFIndexModel>(model);
|
||||
if (auto gpures = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
|
||||
ResScope rs(gpures, gpu_id_, false);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpures->faiss_res.get(), gpu_id_, host_index->index_.get());
|
||||
index_.reset(device_index);
|
||||
res_ = gpures;
|
||||
gpu_mode = 2;
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
option.allInGpu = true;
|
||||
|
||||
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(quantizer_ptr);
|
||||
if (ivf_quantizer == nullptr)
|
||||
KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer");
|
||||
|
||||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = ivf_quantizer->quantizer;
|
||||
index_composition->mode = 2; // only 2
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
std::shared_ptr<faiss::Index> new_idx;
|
||||
new_idx.reset(gpu_index);
|
||||
auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id, res);
|
||||
return sq_idx;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("load index model error, can't get gpu_resource");
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
}
|
||||
}
|
||||
|
||||
QuantizerPtr
|
||||
IVFSQHybrid::LoadQuantizer(const Config& config) {
|
||||
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
option.allInGpu = true;
|
||||
|
||||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = nullptr;
|
||||
index_composition->mode = 1; // only 1
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
delete gpu_index;
|
||||
|
||||
auto q = std::make_shared<FaissIVFQuantizer>();
|
||||
|
||||
auto& q_ptr = index_composition->quantizer;
|
||||
q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float);
|
||||
q->quantizer = q_ptr;
|
||||
q->gpu_id = gpu_id;
|
||||
res_ = res;
|
||||
gpu_mode_ = 1;
|
||||
return q;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::SetQuantizer(const QuantizerPtr& quantizer_ptr) {
|
||||
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(quantizer_ptr);
|
||||
if (ivf_quantizer == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Quantizer type error");
|
||||
}
|
||||
|
||||
faiss::IndexIVF* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
|
||||
faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer);
|
||||
if (is_gpu_flat_index == nullptr) {
|
||||
// delete ivf_index->quantizer;
|
||||
ivf_index->quantizer = ivf_quantizer->quantizer;
|
||||
}
|
||||
quantizer_gpu_id_ = ivf_quantizer->gpu_id;
|
||||
gpu_mode_ = 1;
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::UnsetQuantizer() {
|
||||
auto* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
if (ivf_index == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Index type error");
|
||||
}
|
||||
|
||||
ivf_index->quantizer = nullptr;
|
||||
quantizer_gpu_id_ = -1;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
IVFSQHybrid::SerializeImpl() {
|
||||
IVFSQHybrid::SerializeImpl(const IndexType& type) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
fiu_do_on("IVFSQHybrid.SerializeImpl.zero_gpu_mode", gpu_mode = 0);
|
||||
if (gpu_mode == 0) {
|
||||
|
||||
fiu_do_on("IVFSQHybrid.SerializeImpl.zero_gpu_mode", gpu_mode_ = 0);
|
||||
if (gpu_mode_ == 0) {
|
||||
MemoryIOWriter writer;
|
||||
faiss::write_index(index_.get(), &writer);
|
||||
|
||||
|
@ -286,13 +234,39 @@ IVFSQHybrid::SerializeImpl() {
|
|||
res_set.Append("IVF", data, writer.rp);
|
||||
|
||||
return res_set;
|
||||
} else if (gpu_mode == 2) {
|
||||
return GPUIVF::SerializeImpl();
|
||||
} else if (gpu_mode_ == 2) {
|
||||
return GPUIVF::SerializeImpl(type);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Can't serialize IVFSQ8Hybrid");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
||||
FaissBaseIndex::LoadImpl(binary_set, index_type_); // load on cpu
|
||||
auto* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->backup_quantizer();
|
||||
gpu_mode_ = 0;
|
||||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
if (gpu_mode_ == 2) {
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
// index_->search(n, (float*)data, k, distances, labels);
|
||||
} else if (gpu_mode_ == 1) { // hybrid
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(quantizer_gpu_id_)) {
|
||||
ResScope rs(res, quantizer_gpu_id_, true);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(quantizer_gpu_id_) + "resource");
|
||||
}
|
||||
} else if (gpu_mode_ == 0) {
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
}
|
||||
}
|
||||
|
||||
FaissIVFQuantizer::~FaissIVFQuantizer() {
|
||||
if (quantizer != nullptr) {
|
||||
delete quantizer;
|
||||
|
@ -302,4 +276,6 @@ FaissIVFQuantizer::~FaissIVFQuantizer() {
|
|||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
|
@ -17,9 +17,10 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexGPUIVFSQ.h"
|
||||
#include "Quantizer.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/Quantizer.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
#ifdef CUSTOMIZATION
|
||||
|
@ -34,22 +35,37 @@ using FaissIVFQuantizerPtr = std::shared_ptr<FaissIVFQuantizer>;
|
|||
class IVFSQHybrid : public GPUIVFSQ {
|
||||
public:
|
||||
explicit IVFSQHybrid(const int& device_id) : GPUIVFSQ(device_id) {
|
||||
gpu_mode = 0;
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8H;
|
||||
gpu_mode_ = 0;
|
||||
}
|
||||
|
||||
explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index) : GPUIVFSQ(-1) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8H;
|
||||
index_ = index;
|
||||
gpu_mode = 0;
|
||||
gpu_mode_ = 0;
|
||||
}
|
||||
|
||||
explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
|
||||
: GPUIVFSQ(index, device_id, resource) {
|
||||
gpu_mode = 2;
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8H;
|
||||
gpu_mode_ = 2;
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
set_index_model(IndexModelPtr model) override;
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyGpuToCpu(const Config&) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&) override;
|
||||
|
||||
std::pair<VecIndexPtr, QuantizerPtr>
|
||||
CopyCpuToGpuWithQuantizer(const int64_t, const Config&);
|
||||
|
||||
VecIndexPtr
|
||||
LoadData(const knowhere::QuantizerPtr&, const Config&);
|
||||
|
||||
QuantizerPtr
|
||||
LoadQuantizer(const Config& conf);
|
||||
|
@ -60,35 +76,24 @@ class IVFSQHybrid : public GPUIVFSQ {
|
|||
void
|
||||
UnsetQuantizer();
|
||||
|
||||
VectorIndexPtr
|
||||
LoadData(const knowhere::QuantizerPtr& q, const Config& conf);
|
||||
|
||||
std::pair<VectorIndexPtr, QuantizerPtr>
|
||||
CopyCpuToGpuWithQuantizer(const int64_t& device_id, const Config& config);
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
protected:
|
||||
BinarySet
|
||||
SerializeImpl();
|
||||
|
||||
protected:
|
||||
void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
|
||||
SerializeImpl(const IndexType&) override;
|
||||
|
||||
void
|
||||
LoadImpl(const BinarySet& index_binary) override;
|
||||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
|
||||
protected:
|
||||
int64_t gpu_mode = 0; // 0,1,2
|
||||
int64_t gpu_mode_ = 0; // 0,1,2
|
||||
int64_t quantizer_gpu_id_ = -1;
|
||||
};
|
||||
|
||||
using IVFSQHybridPtr = std::shared_ptr<IVFSQHybrid>;
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,13 +7,14 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "knowhere/common/Config.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
struct Quantizer {
|
||||
|
@ -29,3 +30,4 @@ using QuantizerPtr = std::shared_ptr<Quantizer>;
|
|||
// using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,24 +7,26 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
|
||||
#include "knowhere/index/vector_index/gpu/GPUIndex.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace cloner {
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const VectorIndexPtr& index, const Config& config) {
|
||||
VecIndexPtr
|
||||
CopyGpuToCpu(const VecIndexPtr& index, const Config& config) {
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIndex>(index)) {
|
||||
VectorIndexPtr result = device_index->CopyGpuToCpu(config);
|
||||
VecIndexPtr result = device_index->CopyGpuToCpu(config);
|
||||
auto uids = index->GetUids();
|
||||
result->SetUids(uids);
|
||||
return result;
|
||||
|
@ -33,9 +35,9 @@ CopyGpuToCpu(const VectorIndexPtr& index, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config& config) {
|
||||
VectorIndexPtr result;
|
||||
VecIndexPtr
|
||||
CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& config) {
|
||||
VecIndexPtr result;
|
||||
auto uids = index->GetUids();
|
||||
#ifdef CUSTOMIZATION
|
||||
if (auto device_index = std::dynamic_pointer_cast<IVFSQHybrid>(index)) {
|
||||
|
@ -60,7 +62,7 @@ CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config
|
|||
} else if (auto cpu_index = std::dynamic_pointer_cast<IDMAP>(index)) {
|
||||
result = cpu_index->CopyCpuToGpu(device_id, config);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("this index type not support tranfer to gpu");
|
||||
KNOWHERE_THROW_MSG("this index type not support transfer to gpu");
|
||||
}
|
||||
|
||||
result->SetUids(uids);
|
||||
|
@ -69,3 +71,4 @@ CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config
|
|||
|
||||
} // namespace cloner
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,21 +7,22 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "knowhere/index/vector_index/VectorIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace cloner {
|
||||
|
||||
// TODO(linxj): rename CopyToGpu
|
||||
extern VectorIndexPtr
|
||||
CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config& config);
|
||||
extern VecIndexPtr
|
||||
CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& config);
|
||||
|
||||
extern VectorIndexPtr
|
||||
CopyGpuToCpu(const VectorIndexPtr& index, const Config& config);
|
||||
extern VecIndexPtr
|
||||
CopyGpuToCpu(const VecIndexPtr& index, const Config& config);
|
||||
|
||||
} // namespace cloner
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace knowhere {
|
||||
namespace definition {
|
||||
|
||||
#define META_ROWS ("rows")
|
||||
#define META_DIM ("dimension")
|
||||
#define META_K ("k")
|
||||
|
||||
} // namespace definition
|
||||
} // namespace knowhere
|
|
@ -7,13 +7,14 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
#include <fiu-local.h>
|
||||
#include <utility>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
FaissGpuResourceMgr&
|
||||
|
@ -120,3 +121,4 @@ FaissGpuResourceMgr::Dump() {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
|
@ -18,8 +18,9 @@
|
|||
|
||||
#include <faiss/gpu/StandardGpuResources.h>
|
||||
|
||||
#include "utils/BlockingQueue.h"
|
||||
#include "src/utils/BlockingQueue.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
struct Resource {
|
||||
|
@ -125,3 +126,4 @@ class ResScope {
|
|||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,12 +7,13 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
// TODO(linxj): Get From Config File
|
||||
|
@ -27,6 +28,7 @@ MemoryIOWriter::operator()(const void* ptr, size_t size, size_t nitems) {
|
|||
rp = size * nitems;
|
||||
data_ = new uint8_t[total];
|
||||
memcpy((void*)(data_), ptr, rp);
|
||||
return nitems;
|
||||
}
|
||||
|
||||
if (total_need > total) {
|
||||
|
@ -59,3 +61,4 @@ MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,12 +7,13 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/impl/io.h>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
struct MemoryIOWriter : public faiss::IOWriter {
|
||||
|
@ -46,3 +47,4 @@ struct MemoryIOReader : public faiss::IOReader {
|
|||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <faiss/Index.h>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
faiss::MetricType
|
||||
|
@ -44,3 +45,4 @@ GetMetricType(const std::string& type) {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <faiss/Index.h>
|
||||
#include <string>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
namespace meta {
|
||||
|
@ -60,3 +61,4 @@ extern faiss::MetricType
|
|||
GetMetricType(const std::string& type);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
const Config&
|
||||
|
@ -65,3 +66,4 @@ SPTAGParameterMgr::SPTAGParameterMgr() {
|
|||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
#include "IndexParameter.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class SPTAGParameterMgr {
|
||||
public:
|
||||
const Config&
|
||||
|
@ -50,3 +52,4 @@ class SPTAGParameterMgr {
|
|||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "knowhere/index/vector_index/nsg/Distance.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/Distance.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
float
|
||||
DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
|
||||
|
@ -235,5 +236,6 @@ DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
|
|||
// return faiss::fvec_inner_product(a,b,size);
|
||||
//}
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,12 +7,13 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
struct Distance {
|
||||
virtual float
|
||||
|
@ -29,5 +30,6 @@ struct DistanceIP : public Distance {
|
|||
Compare(const float* a, const float* b, unsigned size) const override;
|
||||
};
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,31 +7,38 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSGHelper.h"
|
||||
|
||||
//#include <gperftools/profiler.h>
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSGHelper.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
unsigned int seed = 100;
|
||||
|
||||
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric)
|
||||
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric)
|
||||
: dimension(dimension), ntotal(n), metric_type(metric) {
|
||||
distance_ = new DistanceL2; // hardcode here
|
||||
// switch (metric) {
|
||||
// case METRICTYPE::L2:
|
||||
// break;
|
||||
// case METRICTYPE::IP:
|
||||
// distance_ = new DistanceIP;
|
||||
// break;
|
||||
// }
|
||||
distance_ = new DistanceL2;
|
||||
}
|
||||
|
||||
NsgIndex::~NsgIndex() {
|
||||
|
@ -53,7 +60,6 @@ NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const
|
|||
candidate_pool_size = parameters.candidate_pool_size;
|
||||
|
||||
TimeRecorder rc("NSG", 1);
|
||||
|
||||
InitNavigationPoint();
|
||||
rc.RecordSection("init");
|
||||
|
||||
|
@ -62,17 +68,17 @@ NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const
|
|||
|
||||
CheckConnectivity();
|
||||
rc.RecordSection("Connect");
|
||||
rc.ElapseFromBegin("finish");
|
||||
|
||||
is_trained = true;
|
||||
|
||||
int total_degree = 0;
|
||||
for (size_t i = 0; i < ntotal; ++i) {
|
||||
total_degree += nsg[i].size();
|
||||
}
|
||||
|
||||
KNOWHERE_LOG_DEBUG << "Graph physical size: " << total_degree * sizeof(node_t) / 1024 / 1024 << "m";
|
||||
KNOWHERE_LOG_DEBUG << "Average degree: " << total_degree / ntotal;
|
||||
|
||||
is_trained = true;
|
||||
|
||||
// Debug code
|
||||
// for (size_t i = 0; i < ntotal; i++) {
|
||||
// auto& x = nsg[i];
|
||||
|
@ -125,7 +131,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::v
|
|||
size_t buffer_size = search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
// TODO: throw exception here.
|
||||
KNOWHERE_THROW_MSG("Build Error, search_length > ntotal");
|
||||
}
|
||||
|
||||
resset.resize(search_length);
|
||||
|
@ -226,7 +232,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::v
|
|||
size_t buffer_size = search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
// TODO: throw exception here.
|
||||
KNOWHERE_THROW_MSG("Build Error, search_length > ntotal");
|
||||
}
|
||||
|
||||
// std::vector<node_t> init_ids;
|
||||
|
@ -322,7 +328,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph&
|
|||
size_t buffer_size = params ? params->search_length : search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
// TODO: throw exception here.
|
||||
KNOWHERE_THROW_MSG("Build Error, search_length > ntotal");
|
||||
}
|
||||
|
||||
// std::vector<node_t> init_ids;
|
||||
|
@ -445,6 +451,7 @@ NsgIndex::Link() {
|
|||
// std::cout << "id: " << fullset[k].id << ", dis: " << fullset[k].distance << std::endl;
|
||||
// }
|
||||
}
|
||||
knng.clear();
|
||||
|
||||
// Debug code
|
||||
// for (size_t i = 0; i < ntotal; i++)
|
||||
|
@ -457,14 +464,11 @@ NsgIndex::Link() {
|
|||
// std::cout << std::endl;
|
||||
// }
|
||||
|
||||
knng.clear();
|
||||
|
||||
std::vector<std::mutex> mutex_vec(ntotal);
|
||||
#pragma omp for schedule(dynamic, 100)
|
||||
for (unsigned n = 0; n < ntotal; ++n) {
|
||||
InterInsert(n, mutex_vec, cut_graph_dist);
|
||||
}
|
||||
|
||||
delete[] cut_graph_dist;
|
||||
}
|
||||
|
||||
|
@ -527,7 +531,7 @@ NsgIndex::InterInsert(unsigned n, std::vector<std::mutex>& mutex_vec, float* cut
|
|||
if (nsn_dist_pool[j] == -1)
|
||||
break;
|
||||
|
||||
// 保证至少有一条边能连回来
|
||||
// At least one edge can be connected back
|
||||
if (n == nsn_id_pool[j]) {
|
||||
duplicate = true;
|
||||
break;
|
||||
|
@ -693,41 +697,139 @@ NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root
|
|||
nsg[root].push_back(id);
|
||||
}
|
||||
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params) {
|
||||
size_t buffer_size = params ? params->search_length : search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
KNOWHERE_THROW_MSG("Search Error, search_length > ntotal");
|
||||
}
|
||||
|
||||
std::vector<Neighbor> resset(buffer_size);
|
||||
std::vector<node_t> init_ids(buffer_size);
|
||||
boost::dynamic_bitset<> has_calculated_dist{ntotal, 0};
|
||||
|
||||
{
|
||||
/*
|
||||
* copy navigation-point neighbor, pick random node if less than buffer size
|
||||
*/
|
||||
size_t count = 0;
|
||||
|
||||
// Get all neighbors
|
||||
for (size_t i = 0; i < init_ids.size() && i < nsg[navigation_point].size(); ++i) {
|
||||
init_ids[i] = nsg[navigation_point][i];
|
||||
has_calculated_dist[init_ids[i]] = true;
|
||||
++count;
|
||||
}
|
||||
while (count < buffer_size) {
|
||||
node_t id = rand_r(&seed) % ntotal;
|
||||
if (has_calculated_dist[id])
|
||||
continue; // duplicate id
|
||||
init_ids[count] = id;
|
||||
++count;
|
||||
has_calculated_dist[id] = true;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// init resset and sort by distance
|
||||
for (size_t i = 0; i < init_ids.size(); ++i) {
|
||||
node_t id = init_ids[i];
|
||||
|
||||
if (id >= static_cast<node_t>(ntotal)) {
|
||||
KNOWHERE_THROW_MSG("Search Error, id > ntotal");
|
||||
}
|
||||
|
||||
float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
|
||||
resset[i] = Neighbor(id, dist, false);
|
||||
}
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
|
||||
// search nearest neighbor
|
||||
size_t cursor = 0;
|
||||
while (cursor < buffer_size) {
|
||||
size_t nearest_updated_pos = buffer_size;
|
||||
|
||||
if (!resset[cursor].has_explored) {
|
||||
resset[cursor].has_explored = true;
|
||||
|
||||
node_t start_pos = resset[cursor].id;
|
||||
auto& wait_for_search_node_vec = nsg[start_pos];
|
||||
for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
|
||||
node_t id = wait_for_search_node_vec[i];
|
||||
if (has_calculated_dist[id])
|
||||
continue;
|
||||
has_calculated_dist[id] = true;
|
||||
|
||||
float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
|
||||
|
||||
if (dist >= resset[buffer_size - 1].distance)
|
||||
continue;
|
||||
|
||||
//// difference from other GetNeighbors
|
||||
Neighbor nn(id, dist, false);
|
||||
///////////////////////////////////////
|
||||
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos)
|
||||
nearest_updated_pos = pos;
|
||||
|
||||
//>> Debug code
|
||||
/////
|
||||
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
|
||||
// nearest_updated_pos << std::endl;
|
||||
/////
|
||||
|
||||
// trick: avoid search query search_length < init_ids.size() ...
|
||||
if (buffer_size + 1 < resset.size())
|
||||
++buffer_size;
|
||||
}
|
||||
}
|
||||
if (cursor >= nearest_updated_pos) {
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else {
|
||||
++cursor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ((resset.size() - params->k) >= 0) {
|
||||
for (size_t i = 0; i < params->k; ++i) {
|
||||
I[i] = resset[i].id;
|
||||
D[i] = resset[i].distance;
|
||||
}
|
||||
} else {
|
||||
size_t i = 0;
|
||||
for (; i < resset.size(); ++i) {
|
||||
I[i] = resset[i].id;
|
||||
D[i] = resset[i].distance;
|
||||
}
|
||||
for (; i < params->k; ++i) {
|
||||
I[i] = -1;
|
||||
D[i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
|
||||
int64_t* ids, SearchParams& params) {
|
||||
std::vector<std::vector<Neighbor>> resset(nq);
|
||||
// if (k >= 45) {
|
||||
// params.search_length = k;
|
||||
// }
|
||||
|
||||
TimeRecorder rc("nsgsearch", 1);
|
||||
|
||||
TimeRecorder rc("NsgIndex::search", 1);
|
||||
if (nq == 1) {
|
||||
GetNeighbors(query, resset[0], nsg, &params);
|
||||
GetNeighbors(query, ids, dist, &params);
|
||||
} else {
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < nq; ++i) {
|
||||
const float* single_query = query + i * dim;
|
||||
GetNeighbors(single_query, resset[i], nsg, &params);
|
||||
GetNeighbors(single_query, ids + i * k, dist + i * k, &params);
|
||||
}
|
||||
}
|
||||
rc.RecordSection("search");
|
||||
for (unsigned int i = 0; i < nq; ++i) {
|
||||
int64_t var = resset[i].size() - k;
|
||||
if (var >= 0) {
|
||||
for (unsigned int j = 0; j < k; ++j) {
|
||||
ids[i * k + j] = ids_[resset[i][j].id];
|
||||
dist[i * k + j] = resset[i][j].distance;
|
||||
}
|
||||
} else {
|
||||
for (unsigned int j = 0; j < resset[i].size(); ++j) {
|
||||
ids[i * k + j] = ids_[resset[i][j].id];
|
||||
dist[i * k + j] = resset[i][j].distance;
|
||||
}
|
||||
for (unsigned int j = resset[i].size(); j < k; ++j) {
|
||||
ids[i * k + j] = -1;
|
||||
dist[i * k + j] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
rc.RecordSection("merge");
|
||||
rc.ElapseFromBegin("seach finish");
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -735,5 +837,6 @@ NsgIndex::SetKnnGraph(Graph& g) {
|
|||
knng = std::move(g);
|
||||
}
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,7 +7,7 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
|
@ -22,8 +22,9 @@
|
|||
#include "Neighbor.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
using node_t = int64_t;
|
||||
|
||||
|
@ -35,6 +36,7 @@ struct BuildParams {
|
|||
|
||||
// Per-call search settings passed to NsgIndex::Search and forwarded to NsgIndex::GetNeighbors.
struct SearchParams {
|
||||
// Size of the candidate pool GetNeighbors allocates and explores; it throws when this
// exceeds the number of indexed vectors (ntotal).
size_t search_length;
|
||||
// Number of result slots GetNeighbors writes per query into its id/distance output arrays.
size_t k;
|
||||
};
|
||||
|
||||
using Graph = std::vector<std::vector<node_t>>;
|
||||
|
@ -43,7 +45,7 @@ class NsgIndex {
|
|||
public:
|
||||
size_t dimension;
|
||||
size_t ntotal; // total number of indexed vectors
|
||||
std::string metric_type; // todo(linxj) IP
|
||||
std::string metric_type; // L2 | IP
|
||||
Distance* distance_;
|
||||
|
||||
float* ori_data_;
|
||||
|
@ -63,7 +65,7 @@ class NsgIndex {
|
|||
size_t out_degree;
|
||||
|
||||
public:
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric = "L2");
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = "L2");
|
||||
|
||||
NsgIndex() = default;
|
||||
|
||||
|
@ -105,10 +107,14 @@ class NsgIndex {
|
|||
void
|
||||
GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset);
|
||||
|
||||
// search and navigation-point
|
||||
// navigation-point
|
||||
void
|
||||
GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph& graph, SearchParams* param = nullptr);
|
||||
|
||||
// used by search
|
||||
void
|
||||
GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params);
|
||||
|
||||
void
|
||||
Link();
|
||||
|
||||
|
@ -131,5 +137,6 @@ class NsgIndex {
|
|||
FindUnconnectedNode(boost::dynamic_bitset<>& flags, int64_t& root);
|
||||
};
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "knowhere/index/vector_index/nsg/NSGHelper.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSGHelper.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
// TODO: impl search && insert && return insert pos. why not just find and swap?
|
||||
int
|
||||
|
@ -61,5 +62,6 @@ InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) {
|
|||
return right;
|
||||
}
|
||||
|
||||
}; // namespace algo
|
||||
}; // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,17 +7,19 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Neighbor.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
extern int
|
||||
InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn);
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "knowhere/index/vector_index/nsg/NSGIO.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSGIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
void
|
||||
write_index(NsgIndex* index, MemoryIOWriter& writer) {
|
||||
|
@ -59,5 +60,6 @@ read_index(MemoryIOReader& reader) {
|
|||
return index;
|
||||
}
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "NSG.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
extern void
|
||||
write_index(NsgIndex* index, MemoryIOWriter& writer);
|
||||
|
@ -23,5 +24,6 @@ write_index(NsgIndex* index, MemoryIOWriter& writer);
|
|||
extern NsgIndex*
|
||||
read_index(MemoryIOReader& reader);
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
namespace impl {
|
||||
|
||||
using node_t = int64_t;
|
||||
|
||||
|
@ -38,19 +39,8 @@ struct Neighbor {
|
|||
}
|
||||
};
|
||||
|
||||
// struct SimpleNeighbor {
|
||||
// node_t id; // offset of node in origin data
|
||||
// float distance;
|
||||
//
|
||||
// SimpleNeighbor() = default;
|
||||
// explicit SimpleNeighbor(node_t id, float distance) : id{id}, distance{distance}{}
|
||||
//
|
||||
// inline bool operator<(const Neighbor &other) const {
|
||||
// return distance < other.distance;
|
||||
// }
|
||||
//};
|
||||
|
||||
typedef std::lock_guard<std::mutex> LockGuard;
|
||||
|
||||
} // namespace algo
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
File diff suppressed because it is too large
Load Diff
|
@ -30,6 +30,7 @@ set(util_srcs
|
|||
${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexType.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp
|
||||
${INDEX_SOURCE_DIR}/unittest/utils.cpp
|
||||
|
@ -52,23 +53,23 @@ endif ()
|
|||
|
||||
#<IVF-TEST>
|
||||
set(ivf_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVF.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVF.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp
|
||||
)
|
||||
if (KNOWHERE_GPU_VERSION)
|
||||
set(ivf_srcs ${ivf_srcs}
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIDMAP.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIVF.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp
|
||||
)
|
||||
endif ()
|
||||
if (NOT TARGET test_ivf)
|
||||
|
@ -96,8 +97,7 @@ target_link_libraries(test_binaryidmap ${depend_libs} ${unittest_libs} ${basic_l
|
|||
|
||||
#<SPTAG-TEST>
|
||||
set(sptag_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/SptagAdapter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/preprocessor/Normalize.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
|
||||
)
|
||||
|
|
|
@ -15,12 +15,14 @@
|
|||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h"
|
||||
#endif
|
||||
|
||||
int DEVICEID = 0;
|
||||
|
@ -32,35 +34,37 @@ constexpr int64_t PINMEM = 1024 * 1024 * 200;
|
|||
constexpr int64_t TEMPMEM = 1024 * 1024 * 300;
|
||||
constexpr int64_t RESNUM = 2;
|
||||
|
||||
knowhere::IVFIndexPtr
|
||||
IndexFactory(const std::string& type) {
|
||||
if (type == "IVF") {
|
||||
return std::make_shared<knowhere::IVF>();
|
||||
} else if (type == "IVFPQ") {
|
||||
return std::make_shared<knowhere::IVFPQ>();
|
||||
} else if (type == "IVFSQ") {
|
||||
return std::make_shared<knowhere::IVFSQ>();
|
||||
milvus::knowhere::IVFPtr
|
||||
IndexFactory(const milvus::knowhere::IndexType& type, const milvus::knowhere::IndexMode mode) {
|
||||
if (mode == milvus::knowhere::IndexMode::MODE_CPU) {
|
||||
if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
|
||||
return std::make_shared<milvus::knowhere::IVF>();
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
|
||||
return std::make_shared<milvus::knowhere::IVFPQ>();
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) {
|
||||
return std::make_shared<milvus::knowhere::IVFSQ>();
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) {
|
||||
std::cout << "IVFSQ8H does not support MODE_CPU" << std::endl;
|
||||
} else {
|
||||
std::cout << "Invalid IndexType " << type << std::endl;
|
||||
}
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
} else if (type == "GPUIVF") {
|
||||
return std::make_shared<knowhere::GPUIVF>(DEVICEID);
|
||||
} else if (type == "GPUIVFPQ") {
|
||||
return std::make_shared<knowhere::GPUIVFPQ>(DEVICEID);
|
||||
} else if (type == "GPUIVFSQ") {
|
||||
return std::make_shared<knowhere::GPUIVFSQ>(DEVICEID);
|
||||
#ifdef CUSTOMIZATION
|
||||
} else if (type == "IVFSQHybrid") {
|
||||
return std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
|
||||
#endif
|
||||
} else {
|
||||
if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
|
||||
return std::make_shared<milvus::knowhere::GPUIVF>(DEVICEID);
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
|
||||
return std::make_shared<milvus::knowhere::GPUIVFPQ>(DEVICEID);
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) {
|
||||
return std::make_shared<milvus::knowhere::GPUIVFSQ>(DEVICEID);
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) {
|
||||
return std::make_shared<milvus::knowhere::IVFSQHybrid>(DEVICEID);
|
||||
} else {
|
||||
std::cout << "Invalid IndexType " << type << std::endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
enum class ParameterType {
|
||||
ivf,
|
||||
ivfpq,
|
||||
ivfsq,
|
||||
};
|
||||
|
||||
class ParamGenerator {
|
||||
public:
|
||||
static ParamGenerator&
|
||||
|
@ -69,35 +73,41 @@ class ParamGenerator {
|
|||
return instance;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
Gen(const ParameterType& type) {
|
||||
if (type == ParameterType::ivf) {
|
||||
return knowhere::Config{
|
||||
{knowhere::meta::DIM, DIM},
|
||||
{knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, DEVICEID},
|
||||
milvus::knowhere::Config
|
||||
Gen(const milvus::knowhere::IndexType& type) {
|
||||
if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
|
||||
return milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, DIM},
|
||||
{milvus::knowhere::meta::TOPK, K},
|
||||
{milvus::knowhere::IndexParams::nlist, 100},
|
||||
{milvus::knowhere::IndexParams::nprobe, 4},
|
||||
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{milvus::knowhere::meta::DEVICEID, DEVICEID},
|
||||
};
|
||||
} else if (type == ParameterType::ivfpq) {
|
||||
return knowhere::Config{
|
||||
{knowhere::meta::DIM, DIM},
|
||||
{knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::IndexParams::m, 4},
|
||||
{knowhere::IndexParams::nbits, 8},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, DEVICEID},
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
|
||||
return milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, DIM},
|
||||
{milvus::knowhere::meta::TOPK, K},
|
||||
{milvus::knowhere::IndexParams::nlist, 100},
|
||||
{milvus::knowhere::IndexParams::nprobe, 4},
|
||||
{milvus::knowhere::IndexParams::m, 4},
|
||||
{milvus::knowhere::IndexParams::nbits, 8},
|
||||
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{milvus::knowhere::meta::DEVICEID, DEVICEID},
|
||||
};
|
||||
} else if (type == ParameterType::ivfsq) {
|
||||
return knowhere::Config{
|
||||
{knowhere::meta::DIM, DIM}, {knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, DEVICEID},
|
||||
} else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 ||
|
||||
type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) {
|
||||
return milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, DIM},
|
||||
{milvus::knowhere::meta::TOPK, K},
|
||||
{milvus::knowhere::IndexParams::nlist, 100},
|
||||
{milvus::knowhere::IndexParams::nprobe, 4},
|
||||
{milvus::knowhere::IndexParams::nbits, 8},
|
||||
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{milvus::knowhere::meta::DEVICEID, DEVICEID},
|
||||
};
|
||||
} else {
|
||||
std::cout << "Invalid index type " << type << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -109,14 +119,14 @@ class TestGpuIndexBase : public ::testing::Test {
|
|||
void
|
||||
SetUp() override {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM);
|
||||
milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM);
|
||||
#endif
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().Free();
|
||||
milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
|
|
@ -26,14 +26,14 @@ class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<std::string>
|
|||
void
|
||||
SetUp() override {
|
||||
Init_with_binary_default();
|
||||
index_ = std::make_shared<knowhere::BinaryIDMAP>();
|
||||
index_ = std::make_shared<milvus::knowhere::BinaryIDMAP>();
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override{};
|
||||
|
||||
protected:
|
||||
knowhere::BinaryIDMAPPtr index_ = nullptr;
|
||||
milvus::knowhere::BinaryIDMAPPtr index_ = nullptr;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
|
||||
|
@ -43,26 +43,26 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
|||
ASSERT_TRUE(!xb.empty());
|
||||
|
||||
std::string MetricType = GetParam();
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, k},
|
||||
{knowhere::Metric::TYPE, MetricType},
|
||||
milvus::knowhere::Config conf{
|
||||
{milvus::knowhere::meta::DIM, dim},
|
||||
{milvus::knowhere::meta::TOPK, k},
|
||||
{milvus::knowhere::Metric::TYPE, MetricType},
|
||||
};
|
||||
|
||||
index_->Train(conf);
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
|
||||
ASSERT_TRUE(index_->GetRawIds() != nullptr);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
auto binaryset = index_->Serialize();
|
||||
auto new_index = std::make_shared<knowhere::BinaryIDMAP>();
|
||||
auto new_index = std::make_shared<milvus::knowhere::BinaryIDMAP>();
|
||||
new_index->Load(binaryset);
|
||||
auto result2 = index_->Search(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
|
||||
|
@ -72,7 +72,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
|||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result3 = index_->Search(query_dataset, conf);
|
||||
auto result3 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
|
||||
// auto result4 = index_->SearchById(id_dataset, conf);
|
||||
|
@ -80,7 +80,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
|||
}
|
||||
|
||||
TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
|
||||
auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
|
||||
|
@ -89,21 +89,21 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
|
|||
};
|
||||
|
||||
std::string MetricType = GetParam();
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, k},
|
||||
{knowhere::Metric::TYPE, MetricType},
|
||||
milvus::knowhere::Config conf{
|
||||
{milvus::knowhere::meta::DIM, dim},
|
||||
{milvus::knowhere::meta::TOPK, k},
|
||||
{milvus::knowhere::Metric::TYPE, MetricType},
|
||||
};
|
||||
|
||||
{
|
||||
// serialize index
|
||||
index_->Train(conf);
|
||||
index_->Add(base_dataset, knowhere::Config());
|
||||
auto re_result = index_->Search(query_dataset, conf);
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, milvus::knowhere::Config());
|
||||
auto re_result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(re_result, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto binaryset = index_->Serialize();
|
||||
auto bin = binaryset.GetByName("BinaryIVF");
|
||||
|
||||
|
@ -118,8 +118,8 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
|
|||
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
}
|
||||
|
|
|
@ -10,16 +10,13 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
|
||||
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
|
||||
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "unittest/Helper.h"
|
||||
#include "unittest/utils.h"
|
||||
|
||||
|
@ -37,12 +34,12 @@ class BinaryIVFTest : public BinaryDataGen, public TestWithParam<std::string> {
|
|||
// nq = 1000;
|
||||
// k = 1000;
|
||||
// Generate(DIM, NB, NQ);
|
||||
index_ = std::make_shared<knowhere::BinaryIVF>();
|
||||
index_ = std::make_shared<milvus::knowhere::BinaryIVF>();
|
||||
|
||||
knowhere::Config temp_conf{
|
||||
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k},
|
||||
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 10},
|
||||
{knowhere::Metric::TYPE, MetricType},
|
||||
milvus::knowhere::Config temp_conf{
|
||||
{milvus::knowhere::meta::DIM, dim}, {milvus::knowhere::meta::TOPK, k},
|
||||
{milvus::knowhere::IndexParams::nlist, 100}, {milvus::knowhere::IndexParams::nprobe, 10},
|
||||
{milvus::knowhere::Metric::TYPE, MetricType},
|
||||
};
|
||||
conf = temp_conf;
|
||||
}
|
||||
|
@ -53,8 +50,8 @@ class BinaryIVFTest : public BinaryDataGen, public TestWithParam<std::string> {
|
|||
|
||||
protected:
|
||||
std::string index_type;
|
||||
knowhere::Config conf;
|
||||
knowhere::BinaryIVFIndexPtr index_ = nullptr;
|
||||
milvus::knowhere::Config conf;
|
||||
milvus::knowhere::BinaryIVFIndexPtr index_ = nullptr;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
|
||||
|
@ -70,10 +67,10 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
|||
// index_->set_index_model(model);
|
||||
// index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
|
@ -82,10 +79,10 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
|||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result2 = index_->Search(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
|
||||
auto result3 = index_->SearchById(id_dataset, conf);
|
||||
auto result3 = index_->QueryById(id_dataset, conf);
|
||||
AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
|
||||
// auto result4 = index_->GetVectorById(xid_dataset, conf);
|
||||
|
@ -93,7 +90,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
|||
}
|
||||
|
||||
TEST_P(BinaryIVFTest, binaryivf_serialize) {
|
||||
auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
|
||||
|
@ -120,8 +117,8 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
|
|||
//
|
||||
// index_->set_index_model(model);
|
||||
// index_->Add(base_dataset, conf);
|
||||
// auto result = index_->Search(query_dataset, conf);
|
||||
// AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// auto result = index_->Query(query_dataset, conf);
|
||||
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// }
|
||||
|
||||
{
|
||||
|
@ -143,9 +140,9 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
|
|||
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
}
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue