mirror of https://github.com/milvus-io/milvus.git
Faiss hnsw upgrade (#3134)
* combine the hnsw's implemention of faiss and hnswlib Signed-off-by: cmli <chengming.li@zilliz.com> * transplant the datastructure of hnsw from hnswlib 2 faiss Signed-off-by: cmli <chengming.li@zilliz.com> * basic work finished, pass compile, to be tested Signed-off-by: cmli <chengming.li@zilliz.com> * rhnswflat, rhnswsq, rhnswpq pass ut Signed-off-by: cmli <chengming.li@zilliz.com> * remove AssertAnns of RHNSWPQ because PQ has accuracy loss Signed-off-by: cmli <chengming.li@zilliz.com> Co-authored-by: cmli <chengming.li@zilliz.com> Co-authored-by: shengjun.li <shengjun.li@zilliz.com>pull/3153/head^2
parent
b662295d63
commit
7688f51343
|
@ -59,6 +59,7 @@ Please mark all changes in change log and use the issue from GitHub
|
|||
- \#2802 Add new index: IVFSQ8NR
|
||||
- \#2834 Add C++ sdk support 4 hnsw_sq8nr
|
||||
- \#2940 Add option to build.sh for cuda arch
|
||||
- \#3132 Refine the implementation of hnsw in faiss and add support for hnsw-flat, hnsw-pq and hnsw-sq8 based on faiss
|
||||
|
||||
## Improvement
|
||||
- \#2543 Remove secondary_path related code
|
||||
|
|
|
@ -63,6 +63,10 @@ set(vector_index_srcs
|
|||
knowhere/index/IndexType.cpp
|
||||
knowhere/index/vector_index/VecIndexFactory.cpp
|
||||
knowhere/index/vector_index/IndexAnnoy.cpp
|
||||
knowhere/index/vector_index/IndexRHNSW.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWFlat.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWSQ.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWPQ.cpp
|
||||
)
|
||||
|
||||
set(vector_offset_index_srcs
|
||||
|
|
|
@ -34,6 +34,9 @@ const char* INDEX_SPTAG_KDT_RNT = "SPTAG_KDT_RNT";
|
|||
const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT";
|
||||
#endif
|
||||
const char* INDEX_HNSW = "HNSW";
|
||||
const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
|
||||
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
|
||||
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
|
||||
const char* INDEX_ANNOY = "ANNOY";
|
||||
const char* INDEX_HNSW_SQ8NM = "HNSW_SQ8NM";
|
||||
} // namespace IndexEnum
|
||||
|
|
|
@ -37,6 +37,9 @@ enum class OldIndexType {
|
|||
ANNOY,
|
||||
FAISS_IVFSQ8NR,
|
||||
HNSW_SQ8NM,
|
||||
RHNSW_FLAT,
|
||||
RHNSW_PQ,
|
||||
RHNSW_SQ,
|
||||
FAISS_BIN_IDMAP = 100,
|
||||
FAISS_BIN_IVFLAT_CPU = 101,
|
||||
};
|
||||
|
@ -60,6 +63,9 @@ extern const char* INDEX_SPTAG_KDT_RNT;
|
|||
extern const char* INDEX_SPTAG_BKT_RNT;
|
||||
#endif
|
||||
extern const char* INDEX_HNSW;
|
||||
extern const char* INDEX_RHNSWFlat;
|
||||
extern const char* INDEX_RHNSWPQ;
|
||||
extern const char* INDEX_RHNSWSQ;
|
||||
extern const char* INDEX_ANNOY;
|
||||
extern const char* INDEX_HNSW_SQ8NM;
|
||||
} // namespace IndexEnum
|
||||
|
|
|
@ -280,6 +280,80 @@ HNSWSQ8NRConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const In
|
|||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 8;
|
||||
static int64_t MAX_EFCONSTRUCTION = 512;
|
||||
static int64_t MIN_M = 4;
|
||||
static int64_t MAX_M = 64;
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 8;
|
||||
static int64_t MAX_EFCONSTRUCTION = 512;
|
||||
static int64_t MIN_M = 4;
|
||||
static int64_t MAX_M = 64;
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
|
||||
std::vector<int64_t> resset;
|
||||
int64_t dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
|
||||
IVFPQConfAdapter::GetValidMList(dimension, resset);
|
||||
|
||||
CheckIntByValues(knowhere::IndexParams::m, resset);
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 8;
|
||||
static int64_t MAX_EFCONSTRUCTION = 512;
|
||||
static int64_t MIN_M = 4;
|
||||
static int64_t MAX_M = 64;
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
|
||||
|
|
|
@ -109,5 +109,31 @@ class IVFSQ8NRConfAdapter : public IVFConfAdapter {
|
|||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class RHNSWFlatConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class RHNSWPQConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class RHNSWSQConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -51,6 +51,9 @@ AdapterMgr::RegisterAdapter() {
|
|||
REGISTER_CONF_ADAPTER(ANNOYConfAdapter, IndexEnum::INDEX_ANNOY, annoy_adapter);
|
||||
REGISTER_CONF_ADAPTER(HNSWSQ8NRConfAdapter, IndexEnum::INDEX_HNSW_SQ8NM, hnswsq8nr_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFSQ8NRConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8NR, ivfsq8nr_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -53,5 +53,11 @@ FaissBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
|||
SealImpl();
|
||||
}
|
||||
|
||||
void
|
||||
FaissBaseIndex::SealImpl() {
|
||||
}
|
||||
|
||||
// FaissBaseIndex::~FaissBaseIndex() {}
|
||||
//
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -34,8 +34,7 @@ class FaissBaseIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType& type);
|
||||
|
||||
virtual void
|
||||
SealImpl() { /* do nothing */
|
||||
}
|
||||
SealImpl();
|
||||
|
||||
public:
|
||||
std::shared_ptr<faiss::Index> index_ = nullptr;
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexRHNSW.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "faiss/BuilderSuspend.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
IndexRHNSW::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
MemoryIOWriter writer;
|
||||
writer.name = this->index_type() + "_Index";
|
||||
faiss::write_index(index_.get(), &writer);
|
||||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
BinarySet res_set;
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSW::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
MemoryIOReader reader;
|
||||
reader.name = this->index_type() + "_Index";
|
||||
auto binary = index_binary.GetByName(reader.name);
|
||||
|
||||
reader.total = (size_t)binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
auto idx = faiss::read_index(&reader);
|
||||
index_.reset(idx);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of Train, please use IndexRHNSW(Flat/SQ/PQ) instead!");
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
index_->add(rows, (float*)p_data);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
size_t k = config[meta::TOPK].get<int64_t>();
|
||||
size_t id_size = sizeof(int64_t) * k;
|
||||
size_t dist_size = sizeof(float) * k;
|
||||
auto p_id = (int64_t*)malloc(id_size * rows);
|
||||
auto p_dist = (float*)malloc(dist_size * rows);
|
||||
for (auto i = 0; i < k * rows; ++i) {
|
||||
p_id[i] = -1;
|
||||
p_dist[i] = -1;
|
||||
}
|
||||
|
||||
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
real_index->hnsw.efSearch = (config[IndexParams::ef]);
|
||||
real_index->search(rows, (float*)p_data, k, p_dist, p_id, blacklist);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexRHNSW::Count() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexRHNSW::Dim() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSW::UpdateIndexSize() {
|
||||
KNOWHERE_THROW_MSG(
|
||||
"IndexRHNSW has no implementation of UpdateIndexSize, please use IndexRHNSW(Flat/SQ/PQ) instead!");
|
||||
}
|
||||
|
||||
/*
|
||||
BinarySet
|
||||
IndexRHNSW::SerializeImpl(const milvus::knowhere::IndexType &type) { return BinarySet(); }
|
||||
|
||||
void
|
||||
IndexRHNSW::SealImpl() {}
|
||||
|
||||
void
|
||||
IndexRHNSW::LoadImpl(const milvus::knowhere::BinarySet &, const milvus::knowhere::IndexType &type) {}
|
||||
*/
|
||||
|
||||
void
|
||||
IndexRHNSW::AddWithoutIds(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config) {
|
||||
KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of AddWithoutIds, please use IndexRHNSW(Flat/SQ/PQ) instead!");
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
|
||||
#include <faiss/index_io.h>
|
||||
#include "faiss/IndexRHNSW.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexRHNSW : public VecIndex, public FaissBaseIndex {
|
||||
public:
|
||||
IndexRHNSW() : FaissBaseIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INVALID;
|
||||
}
|
||||
|
||||
explicit IndexRHNSW(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INVALID;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dim() override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
};
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "faiss/BuilderSuspend.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, milvus::knowhere::MetricType metric) {
|
||||
faiss::MetricType mt =
|
||||
metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT;
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWFlat(d, M, mt));
|
||||
}
|
||||
|
||||
BinarySet
|
||||
IndexRHNSWFlat::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
auto res_set = IndexRHNSW::Serialize(config);
|
||||
MemoryIOWriter writer;
|
||||
writer.name = this->index_type() + "_Data";
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWFlat*>(index_) failed during Serialize!");
|
||||
}
|
||||
auto storage_index = dynamic_cast<faiss::IndexFlat*>(real_idx->storage);
|
||||
faiss::write_index(storage_index, &writer);
|
||||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWFlat::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
IndexRHNSW::Load(index_binary);
|
||||
MemoryIOReader reader;
|
||||
reader.name = this->index_type() + "_Data";
|
||||
auto binary = index_binary.GetByName(reader.name);
|
||||
|
||||
reader.total = (size_t)binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWFlat*>(index_) failed during Load!");
|
||||
}
|
||||
real_idx->storage = faiss::read_index(&reader);
|
||||
real_idx->init_hnsw();
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWFlat::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
try {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
||||
auto idx = new faiss::IndexRHNSWFlat(int(dim), config[IndexParams::M], metric_type);
|
||||
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
index_ = std::shared_ptr<faiss::Index>(idx);
|
||||
index_->train(rows, (float*)p_data);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWFlat::UpdateIndexSize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
index_size_ = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get())->cal_size();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexRHNSW.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexRHNSWFlat : public IndexRHNSW {
|
||||
public:
|
||||
IndexRHNSWFlat() : IndexRHNSW() {
|
||||
index_type_ = IndexEnum::INDEX_RHNSWFlat;
|
||||
}
|
||||
|
||||
explicit IndexRHNSWFlat(std::shared_ptr<faiss::Index> index) : IndexRHNSW(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_RHNSWFlat;
|
||||
}
|
||||
|
||||
IndexRHNSWFlat(int d, int M, MetricType metric = Metric::L2);
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "faiss/BuilderSuspend.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M) {
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWPQ(d, pq_m, M));
|
||||
}
|
||||
|
||||
BinarySet
|
||||
IndexRHNSWPQ::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
auto res_set = IndexRHNSW::Serialize(config);
|
||||
MemoryIOWriter writer;
|
||||
writer.name = this->index_type() + "_Data";
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during Serialize!");
|
||||
}
|
||||
faiss::write_index(real_idx->storage, &writer);
|
||||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWPQ::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
IndexRHNSW::Load(index_binary);
|
||||
MemoryIOReader reader;
|
||||
reader.name = this->index_type() + "_Data";
|
||||
auto binary = index_binary.GetByName(reader.name);
|
||||
|
||||
reader.total = (size_t)binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during Load!");
|
||||
}
|
||||
real_idx->storage = faiss::read_index(&reader);
|
||||
real_idx->init_hnsw();
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
try {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
|
||||
auto idx = new faiss::IndexRHNSWPQ(int(dim), config[IndexParams::PQM], config[IndexParams::M]);
|
||||
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
index_ = std::shared_ptr<faiss::Index>(idx);
|
||||
index_->train(rows, (float*)p_data);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWPQ::UpdateIndexSize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
index_size_ = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get())->cal_size();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexRHNSW.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexRHNSWPQ : public IndexRHNSW {
|
||||
public:
|
||||
IndexRHNSWPQ() : IndexRHNSW() {
|
||||
index_type_ = IndexEnum::INDEX_RHNSWPQ;
|
||||
}
|
||||
|
||||
explicit IndexRHNSWPQ(std::shared_ptr<faiss::Index> index) : IndexRHNSW(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_RHNSWPQ;
|
||||
}
|
||||
|
||||
IndexRHNSWPQ(int d, int pq_m, int M);
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "faiss/BuilderSuspend.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
IndexRHNSWSQ::IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, milvus::knowhere::MetricType metric) {
|
||||
faiss::MetricType mt =
|
||||
metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT;
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWSQ(d, qtype, M, mt));
|
||||
}
|
||||
|
||||
BinarySet
|
||||
IndexRHNSWSQ::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
auto res_set = IndexRHNSW::Serialize(config);
|
||||
MemoryIOWriter writer;
|
||||
writer.name = this->index_type() + "_Data";
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWSQ*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWSQ*>(index_) failed during Serialize!");
|
||||
}
|
||||
faiss::write_index(real_idx->storage, &writer);
|
||||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWSQ::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
IndexRHNSW::Load(index_binary);
|
||||
MemoryIOReader reader;
|
||||
reader.name = this->index_type() + "_Data";
|
||||
auto binary = index_binary.GetByName(reader.name);
|
||||
|
||||
reader.total = (size_t)binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWSQ*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWSQ*>(index_) failed during Load!");
|
||||
}
|
||||
real_idx->storage = faiss::read_index(&reader);
|
||||
real_idx->init_hnsw();
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
try {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
||||
auto idx =
|
||||
new faiss::IndexRHNSWSQ(int(dim), faiss::QuantizerType::QT_8bit, config[IndexParams::M], metric_type);
|
||||
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
index_ = std::shared_ptr<faiss::Index>(idx);
|
||||
index_->train(rows, (float*)p_data);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSWSQ::UpdateIndexSize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
index_size_ = dynamic_cast<faiss::IndexRHNSWSQ*>(index_.get())->cal_size();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexRHNSW.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexRHNSWSQ : public IndexRHNSW {
|
||||
public:
|
||||
IndexRHNSWSQ() : IndexRHNSW() {
|
||||
index_type_ = IndexEnum::INDEX_RHNSWSQ;
|
||||
}
|
||||
|
||||
explicit IndexRHNSWSQ(std::shared_ptr<faiss::Index> index) : IndexRHNSW(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_RHNSWSQ;
|
||||
}
|
||||
|
||||
IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, MetricType metric = Metric::L2);
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config = Config()) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -20,11 +20,15 @@
|
|||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
|
||||
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
|
||||
#include "knowhere/index/vector_offset_index/IndexHNSW_NM.h"
|
||||
#include "knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h"
|
||||
#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h"
|
||||
#include "knowhere/index/vector_offset_index/IndexIVF_NM.h"
|
||||
#include "knowhere/index/vector_offset_index/IndexNSG_NM.h"
|
||||
|
||||
#ifdef MILVUS_SUPPORT_SPTAG
|
||||
#include "knowhere/index/vector_index/IndexSPTAG.h"
|
||||
#endif
|
||||
|
@ -94,6 +98,12 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
|
|||
return std::make_shared<knowhere::IVFSQNR_NM>();
|
||||
} else if (type == IndexEnum::INDEX_HNSW_SQ8NM) {
|
||||
return std::make_shared<knowhere::IndexHNSW_SQ8NM>();
|
||||
} else if (type == IndexEnum::INDEX_RHNSWFlat) {
|
||||
return std::make_shared<knowhere::IndexRHNSWFlat>();
|
||||
} else if (type == IndexEnum::INDEX_RHNSWPQ) {
|
||||
return std::make_shared<knowhere::IndexRHNSWPQ>();
|
||||
} else if (type == IndexEnum::INDEX_RHNSWSQ) {
|
||||
return std::make_shared<knowhere::IndexRHNSWSQ>();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -48,6 +48,9 @@ constexpr const char* ef = "ef";
|
|||
// Annoy Params
|
||||
constexpr const char* n_trees = "n_trees";
|
||||
constexpr const char* search_k = "search_k";
|
||||
|
||||
// PQ Params
|
||||
constexpr const char* PQM = "PQM";
|
||||
} // namespace IndexParams
|
||||
|
||||
namespace Metric {
|
||||
|
|
|
@ -80,6 +80,7 @@ struct Index2Layer: Index {
|
|||
void sa_encode (idx_t n, const float *x, uint8_t *bytes) const override;
|
||||
void sa_decode (idx_t n, const uint8_t *bytes, float *x) const override;
|
||||
|
||||
size_t cal_size() { return sizeof(*this) + codes.size() * sizeof(uint8_t) + pq.cal_size(); }
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -85,6 +85,8 @@ struct IndexFlat: Index {
|
|||
void sa_decode (idx_t n, const uint8_t *bytes,
|
||||
float *x) const override;
|
||||
|
||||
size_t cal_size() { return xb.size() * sizeof(float); }
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -124,6 +124,8 @@ struct IndexPQ: Index {
|
|||
void hamming_distance_table (idx_t n, const float *x,
|
||||
int32_t *dis) const;
|
||||
|
||||
size_t cal_size() { return codes.size() * sizeof(uint8_t) + pq.cal_size(); }
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,812 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/IndexRHNSW.h>
|
||||
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <omp.h>
|
||||
|
||||
#include <unordered_set>
|
||||
#include <queue>
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __SSE__
|
||||
#endif
|
||||
|
||||
#include <faiss/utils/distances.h>
|
||||
#include <faiss/utils/random.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/Index2Layer.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/FaissHook.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
||||
|
||||
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
||||
n, FINTEGER *k, const float *alpha, const float *a,
|
||||
FINTEGER *lda, const float *b, FINTEGER *
|
||||
ldb, float *beta, float *c, FINTEGER *ldc);
|
||||
|
||||
}
|
||||
|
||||
namespace faiss {
|
||||
|
||||
using idx_t = Index::idx_t;
|
||||
using MinimaxHeap = RHNSW::MinimaxHeap;
|
||||
using storage_idx_t = RHNSW::storage_idx_t;
|
||||
using NodeDistFarther = RHNSW::NodeDistFarther;
|
||||
|
||||
RHNSWStats rhnsw_stats;
|
||||
|
||||
/**************************************************************
|
||||
* add / search blocks of descriptors
|
||||
**************************************************************/
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
/* Wrap the distance computer into one that negates the
|
||||
distances. This makes supporting INNER_PRODUCE search easier */
|
||||
|
||||
struct NegativeDistanceComputer: DistanceComputer {
|
||||
|
||||
/// owned by this
|
||||
DistanceComputer *basedis;
|
||||
|
||||
explicit NegativeDistanceComputer(DistanceComputer *basedis):
|
||||
basedis(basedis)
|
||||
{}
|
||||
|
||||
void set_query(const float *x) override {
|
||||
basedis->set_query(x);
|
||||
}
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
float operator () (idx_t i) override {
|
||||
return -(*basedis)(i);
|
||||
}
|
||||
|
||||
/// compute distance between two stored vectors
|
||||
float symmetric_dis (idx_t i, idx_t j) override {
|
||||
return -basedis->symmetric_dis(i, j);
|
||||
}
|
||||
|
||||
virtual ~NegativeDistanceComputer ()
|
||||
{
|
||||
delete basedis;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
DistanceComputer *storage_distance_computer(const Index *storage)
|
||||
{
|
||||
if (storage->metric_type == METRIC_INNER_PRODUCT) {
|
||||
return new NegativeDistanceComputer(storage->get_distance_computer());
|
||||
} else {
|
||||
return storage->get_distance_computer();
|
||||
}
|
||||
}
|
||||
|
||||
void hnsw_add_vertices(IndexRHNSW &index_hnsw,
|
||||
size_t n0,
|
||||
size_t n, const float *x,
|
||||
bool verbose,
|
||||
bool preset_levels = false) {
|
||||
size_t d = index_hnsw.d;
|
||||
RHNSW & hnsw = index_hnsw.hnsw;
|
||||
size_t ntotal = n0 + n;
|
||||
double t0 = getmillisecs();
|
||||
if (verbose) {
|
||||
printf("hnsw_add_vertices: adding %ld elements on top of %ld "
|
||||
"(preset_levels=%d)\n",
|
||||
n, n0, int(preset_levels));
|
||||
}
|
||||
|
||||
if (n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int max_level = hnsw.prepare_level_tab(n, preset_levels);
|
||||
|
||||
if (verbose) {
|
||||
printf(" max_level = %d\n", max_level);
|
||||
}
|
||||
|
||||
|
||||
{ // perform add
|
||||
auto tas = getmillisecs();
|
||||
RandomGenerator rng2(789);
|
||||
DistanceComputer *dis0 =
|
||||
storage_distance_computer (index_hnsw.storage);
|
||||
ScopeDeleter1<DistanceComputer> del0(dis0);
|
||||
|
||||
dis0->set_query(x);
|
||||
hnsw.addPoint(*dis0, hnsw.levels[n0], n0);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 1; i < n; ++ i) {
|
||||
DistanceComputer *dis =
|
||||
storage_distance_computer (index_hnsw.storage);
|
||||
ScopeDeleter1<DistanceComputer> del(dis);
|
||||
dis->set_query(x + i * d);
|
||||
hnsw.addPoint(*dis, hnsw.levels[n0 + i], i + n0);
|
||||
}
|
||||
}
|
||||
if (verbose) {
|
||||
printf("Done in %.3f ms\n", getmillisecs() - t0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
|
||||
|
||||
/**************************************************************
|
||||
* IndexRHNSW implementation
|
||||
**************************************************************/
|
||||
|
||||
IndexRHNSW::IndexRHNSW(int d, int M, MetricType metric):
|
||||
Index(d, metric),
|
||||
hnsw(M),
|
||||
own_fields(false),
|
||||
storage(nullptr),
|
||||
reconstruct_from_neighbors(nullptr)
|
||||
{}
|
||||
|
||||
IndexRHNSW::IndexRHNSW(Index *storage, int M):
|
||||
Index(storage->d, storage->metric_type),
|
||||
hnsw(M),
|
||||
own_fields(false),
|
||||
storage(storage),
|
||||
reconstruct_from_neighbors(nullptr)
|
||||
{}
|
||||
|
||||
IndexRHNSW::~IndexRHNSW() {
|
||||
if (own_fields) {
|
||||
delete storage;
|
||||
}
|
||||
}
|
||||
|
||||
void IndexRHNSW::init_hnsw() {
|
||||
hnsw.init(ntotal);
|
||||
}
|
||||
|
||||
void IndexRHNSW::train(idx_t n, const float* x)
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG(storage,
|
||||
"Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly");
|
||||
// hnsw structure does not require training
|
||||
storage->train (n, x);
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
void IndexRHNSW::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const
|
||||
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG(storage,
|
||||
"Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly");
|
||||
size_t nreorder = 0;
|
||||
|
||||
idx_t check_period = InterruptCallback::get_period_hint (
|
||||
hnsw.max_level * d * hnsw.efSearch);
|
||||
|
||||
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
|
||||
idx_t i1 = std::min(i0 + check_period, n);
|
||||
|
||||
#pragma omp parallel reduction(+ : nreorder)
|
||||
{
|
||||
|
||||
DistanceComputer *dis = storage_distance_computer(storage);
|
||||
ScopeDeleter1<DistanceComputer> del(dis);
|
||||
|
||||
#pragma omp for
|
||||
for(idx_t i = i0; i < i1; i++) {
|
||||
idx_t * idxi = labels + i * k;
|
||||
float * simi = distances + i * k;
|
||||
dis->set_query(x + i * d);
|
||||
|
||||
maxheap_heapify (k, simi, idxi);
|
||||
|
||||
hnsw.searchKnn(*dis, k, idxi, simi, bitset);
|
||||
|
||||
maxheap_reorder (k, simi, idxi);
|
||||
|
||||
if (reconstruct_from_neighbors &&
|
||||
reconstruct_from_neighbors->k_reorder != 0) {
|
||||
int k_reorder = reconstruct_from_neighbors->k_reorder;
|
||||
if (k_reorder == -1 || k_reorder > k) k_reorder = k;
|
||||
|
||||
nreorder += reconstruct_from_neighbors->compute_distances(
|
||||
k_reorder, idxi, x + i * d, simi);
|
||||
|
||||
// sort top k_reorder
|
||||
maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder);
|
||||
maxheap_reorder (k_reorder, simi, idxi);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
|
||||
if (metric_type == METRIC_INNER_PRODUCT) {
|
||||
// we need to revert the negated distances
|
||||
for (size_t i = 0; i < k * n; i++) {
|
||||
distances[i] = -distances[i];
|
||||
}
|
||||
}
|
||||
|
||||
rhnsw_stats.nreorder += nreorder;
|
||||
}
|
||||
|
||||
|
||||
void IndexRHNSW::add(idx_t n, const float *x)
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG(storage,
|
||||
"Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly");
|
||||
FAISS_THROW_IF_NOT(is_trained);
|
||||
int n0 = ntotal;
|
||||
storage->add(n, x);
|
||||
ntotal = storage->ntotal;
|
||||
|
||||
hnsw_add_vertices (*this, n0, n, x, verbose,
|
||||
hnsw.levels.size() == ntotal);
|
||||
}
|
||||
|
||||
void IndexRHNSW::reset()
|
||||
{
|
||||
hnsw.reset();
|
||||
storage->reset();
|
||||
ntotal = 0;
|
||||
}
|
||||
|
||||
void IndexRHNSW::reconstruct (idx_t key, float* recons) const
|
||||
{
|
||||
storage->reconstruct(key, recons);
|
||||
}
|
||||
|
||||
size_t IndexRHNSW::cal_size() {
|
||||
return hnsw.cal_size();
|
||||
}
|
||||
|
||||
/**************************************************************
|
||||
* ReconstructFromNeighbors implementation
|
||||
**************************************************************/
|
||||
|
||||
ReconstructFromNeighbors2::ReconstructFromNeighbors2(
|
||||
const IndexRHNSW & index, size_t k, size_t nsq):
|
||||
index(index), k(k), nsq(nsq) {
|
||||
M = index.hnsw.M << 1;
|
||||
FAISS_ASSERT(k <= 256);
|
||||
code_size = k == 1 ? 0 : nsq;
|
||||
ntotal = 0;
|
||||
d = index.d;
|
||||
FAISS_ASSERT(d % nsq == 0);
|
||||
dsub = d / nsq;
|
||||
k_reorder = -1;
|
||||
}
|
||||
|
||||
void ReconstructFromNeighbors2::reconstruct(storage_idx_t i, float *x, float *tmp) const
|
||||
{
|
||||
|
||||
|
||||
const RHNSW & hnsw = index.hnsw;
|
||||
int *cur_links = hnsw.get_neighbor_link(i, 0);
|
||||
int *cur_neighbors = cur_links + 1;
|
||||
auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links);
|
||||
|
||||
if (k == 1 || nsq == 1) {
|
||||
const float * beta;
|
||||
if (k == 1) {
|
||||
beta = codebook.data();
|
||||
} else {
|
||||
int idx = codes[i];
|
||||
beta = codebook.data() + idx * (M + 1);
|
||||
}
|
||||
|
||||
float w0 = beta[0]; // weight of image itself
|
||||
index.storage->reconstruct(i, tmp);
|
||||
|
||||
for (int l = 0; l < d; l++)
|
||||
x[l] = w0 * tmp[l];
|
||||
|
||||
for (auto j = 0; j < cur_neighbor_num; ++ j) {
|
||||
|
||||
storage_idx_t ji = cur_neighbors[j];
|
||||
if (ji < 0) ji = i;
|
||||
float w = beta[j + 1];
|
||||
index.storage->reconstruct(ji, tmp);
|
||||
for (int l = 0; l < d; l++)
|
||||
x[l] += w * tmp[l];
|
||||
}
|
||||
} else if (nsq == 2) {
|
||||
int idx0 = codes[2 * i];
|
||||
int idx1 = codes[2 * i + 1];
|
||||
|
||||
const float *beta0 = codebook.data() + idx0 * (M + 1);
|
||||
const float *beta1 = codebook.data() + (idx1 + k) * (M + 1);
|
||||
|
||||
index.storage->reconstruct(i, tmp);
|
||||
|
||||
float w0;
|
||||
|
||||
w0 = beta0[0];
|
||||
for (int l = 0; l < dsub; l++)
|
||||
x[l] = w0 * tmp[l];
|
||||
|
||||
w0 = beta1[0];
|
||||
for (int l = dsub; l < d; l++)
|
||||
x[l] = w0 * tmp[l];
|
||||
|
||||
for (auto j = 0; j < cur_neighbor_num; ++ j) {
|
||||
storage_idx_t ji = cur_neighbors[j];
|
||||
if (ji < 0) ji = i;
|
||||
index.storage->reconstruct(ji, tmp);
|
||||
float w;
|
||||
w = beta0[j + 1];
|
||||
for (int l = 0; l < dsub; l++)
|
||||
x[l] += w * tmp[l];
|
||||
|
||||
w = beta1[j + 1];
|
||||
for (int l = dsub; l < d; l++)
|
||||
x[l] += w * tmp[l];
|
||||
}
|
||||
} else {
|
||||
const float *betas[nsq];
|
||||
{
|
||||
const float *b = codebook.data();
|
||||
const uint8_t *c = &codes[i * code_size];
|
||||
for (int sq = 0; sq < nsq; sq++) {
|
||||
betas[sq] = b + (*c++) * (M + 1);
|
||||
b += (M + 1) * k;
|
||||
}
|
||||
}
|
||||
|
||||
index.storage->reconstruct(i, tmp);
|
||||
{
|
||||
int d0 = 0;
|
||||
for (int sq = 0; sq < nsq; sq++) {
|
||||
float w = *(betas[sq]++);
|
||||
int d1 = d0 + dsub;
|
||||
for (int l = d0; l < d1; l++) {
|
||||
x[l] = w * tmp[l];
|
||||
}
|
||||
d0 = d1;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto j = 0; j < cur_neighbor_num; ++ j) {
|
||||
storage_idx_t ji = cur_neighbors[j];
|
||||
if (ji < 0) ji = i;
|
||||
|
||||
index.storage->reconstruct(ji, tmp);
|
||||
int d0 = 0;
|
||||
for (int sq = 0; sq < nsq; sq++) {
|
||||
float w = *(betas[sq]++);
|
||||
int d1 = d0 + dsub;
|
||||
for (int l = d0; l < d1; l++) {
|
||||
x[l] += w * tmp[l];
|
||||
}
|
||||
d0 = d1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ReconstructFromNeighbors2::reconstruct_n(storage_idx_t n0,
|
||||
storage_idx_t ni,
|
||||
float *x) const
|
||||
{
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> tmp(index.d);
|
||||
#pragma omp for
|
||||
for (storage_idx_t i = 0; i < ni; i++) {
|
||||
reconstruct(n0 + i, x + i * index.d, tmp.data());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t ReconstructFromNeighbors2::compute_distances(
|
||||
size_t n, const idx_t *shortlist,
|
||||
const float *query, float *distances) const
|
||||
{
|
||||
std::vector<float> tmp(2 * index.d);
|
||||
size_t ncomp = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
if (shortlist[i] < 0) break;
|
||||
reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
|
||||
distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
|
||||
ncomp++;
|
||||
}
|
||||
return ncomp;
|
||||
}
|
||||
|
||||
void ReconstructFromNeighbors2::get_neighbor_table(storage_idx_t i, float *tmp1) const
|
||||
{
|
||||
const RHNSW & hnsw = index.hnsw;
|
||||
int *cur_links = hnsw.get_neighbor_link(i, 0);
|
||||
int *cur_neighbors = cur_links + 1;
|
||||
auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links);
|
||||
size_t d = index.d;
|
||||
|
||||
index.storage->reconstruct(i, tmp1);
|
||||
|
||||
for (auto j = 0; j < cur_neighbor_num; ++ j) {
|
||||
storage_idx_t ji = cur_neighbors[j];
|
||||
if (ji < 0) ji = i;
|
||||
index.storage->reconstruct(ji, tmp1 + (j + 1) * d);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// called by add_codes
|
||||
void ReconstructFromNeighbors2::estimate_code(
|
||||
const float *x, storage_idx_t i, uint8_t *code) const
|
||||
{
|
||||
|
||||
// fill in tmp table with the neighbor values
|
||||
float *tmp1 = new float[d * (M + 1) + (d * k)];
|
||||
float *tmp2 = tmp1 + d * (M + 1);
|
||||
ScopeDeleter<float> del(tmp1);
|
||||
|
||||
// collect coordinates of base
|
||||
get_neighbor_table (i, tmp1);
|
||||
|
||||
for (size_t sq = 0; sq < nsq; sq++) {
|
||||
int d0 = sq * dsub;
|
||||
|
||||
{
|
||||
FINTEGER ki = k, di = d, m1 = M + 1;
|
||||
FINTEGER dsubi = dsub;
|
||||
float zero = 0, one = 1;
|
||||
|
||||
sgemm_ ("N", "N", &dsubi, &ki, &m1, &one,
|
||||
tmp1 + d0, &di,
|
||||
codebook.data() + sq * (m1 * k), &m1,
|
||||
&zero, tmp2, &dsubi);
|
||||
}
|
||||
|
||||
float min = HUGE_VAL;
|
||||
int argmin = -1;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
|
||||
if (dis < min) {
|
||||
min = dis;
|
||||
argmin = j;
|
||||
}
|
||||
}
|
||||
code[sq] = argmin;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void ReconstructFromNeighbors2::add_codes(size_t n, const float *x)
|
||||
{
|
||||
if (k == 1) { // nothing to encode
|
||||
ntotal += n;
|
||||
return;
|
||||
}
|
||||
codes.resize(codes.size() + code_size * n);
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < n; i++) {
|
||||
estimate_code(x + i * index.d, ntotal + i,
|
||||
codes.data() + (ntotal + i) * code_size);
|
||||
}
|
||||
ntotal += n;
|
||||
FAISS_ASSERT (codes.size() == ntotal * code_size);
|
||||
}
|
||||
|
||||
|
||||
/**************************************************************
|
||||
* IndexRHNSWFlat implementation
|
||||
**************************************************************/
|
||||
|
||||
|
||||
IndexRHNSWFlat::IndexRHNSWFlat()
|
||||
{
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, MetricType metric):
|
||||
IndexRHNSW(new IndexFlat(d, metric), M)
|
||||
{
|
||||
own_fields = true;
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
size_t IndexRHNSWFlat::cal_size() {
|
||||
return IndexRHNSW::cal_size() + dynamic_cast<IndexFlat*>(storage)->cal_size();
|
||||
}
|
||||
|
||||
/**************************************************************
|
||||
* IndexRHNSWPQ implementation
|
||||
**************************************************************/
|
||||
|
||||
|
||||
IndexRHNSWPQ::IndexRHNSWPQ() {}
|
||||
|
||||
IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M):
|
||||
IndexRHNSW(new IndexPQ(d, pq_m, 8), M)
|
||||
{
|
||||
own_fields = true;
|
||||
is_trained = false;
|
||||
}
|
||||
|
||||
void IndexRHNSWPQ::train(idx_t n, const float* x)
|
||||
{
|
||||
IndexRHNSW::train (n, x);
|
||||
(dynamic_cast<IndexPQ*> (storage))->pq.compute_sdc_table();
|
||||
}
|
||||
|
||||
size_t IndexRHNSWPQ::cal_size() {
|
||||
return IndexRHNSW::cal_size() + dynamic_cast<IndexPQ*>(storage)->cal_size();
|
||||
}
|
||||
|
||||
/**************************************************************
|
||||
* IndexRHNSWSQ implementation
|
||||
**************************************************************/
|
||||
|
||||
|
||||
IndexRHNSWSQ::IndexRHNSWSQ(int d, QuantizerType qtype, int M,
|
||||
MetricType metric):
|
||||
IndexRHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
|
||||
{
|
||||
is_trained = false;
|
||||
own_fields = true;
|
||||
}
|
||||
|
||||
IndexRHNSWSQ::IndexRHNSWSQ() {}
|
||||
|
||||
size_t IndexRHNSWSQ::cal_size() {
|
||||
return IndexRHNSW::cal_size() + dynamic_cast<IndexScalarQuantizer *>(storage)->cal_size();
|
||||
}
|
||||
|
||||
/**************************************************************
|
||||
* IndexRHNSW2Level implementation
|
||||
**************************************************************/
|
||||
|
||||
|
||||
IndexRHNSW2Level::IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M):
|
||||
IndexRHNSW (new Index2Layer (quantizer, nlist, m_pq), M)
|
||||
{
|
||||
own_fields = true;
|
||||
is_trained = false;
|
||||
}
|
||||
|
||||
IndexRHNSW2Level::IndexRHNSW2Level() {}
|
||||
|
||||
size_t IndexRHNSW2Level::cal_size() {
|
||||
return IndexRHNSW::cal_size() + dynamic_cast<Index2Layer *>(storage)->cal_size();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
// same as search_from_candidates but uses v
|
||||
// visno -> is in result list
|
||||
// visno + 1 -> in result list + in candidates
|
||||
int search_from_candidates_2(const RHNSW & hnsw,
|
||||
DistanceComputer & qdis, int k,
|
||||
idx_t *I, float * D,
|
||||
MinimaxHeap &candidates,
|
||||
VisitedList &vt,
|
||||
int level, int nres_in = 0)
|
||||
{
|
||||
int nres = nres_in;
|
||||
int ndis = 0;
|
||||
for (int i = 0; i < candidates.size(); i++) {
|
||||
idx_t v1 = candidates.ids[i];
|
||||
FAISS_ASSERT(v1 >= 0);
|
||||
vt.mass[v1] = vt.curV + 1;
|
||||
}
|
||||
|
||||
int nstep = 0;
|
||||
|
||||
while (candidates.size() > 0) {
|
||||
float d0 = 0;
|
||||
int v0 = candidates.pop_min(&d0);
|
||||
|
||||
int *cur_links = hnsw.get_neighbor_link(v0, level);
|
||||
int *cur_neighbors = cur_links + 1;
|
||||
auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links);
|
||||
|
||||
for (auto j = 0; j < cur_neighbor_num; ++ j) {
|
||||
int v1 = cur_neighbors[j];
|
||||
if (v1 < 0) break;
|
||||
if (vt.mass[v1] == vt.curV + 1) {
|
||||
// nothing to do
|
||||
} else {
|
||||
ndis++;
|
||||
float d = qdis(v1);
|
||||
candidates.push(v1, d);
|
||||
|
||||
// never seen before --> add to heap
|
||||
if (vt.mass[v1] < vt.curV) {
|
||||
if (nres < k) {
|
||||
faiss::maxheap_push (++nres, D, I, d, v1);
|
||||
} else if (d < D[0]) {
|
||||
faiss::maxheap_pop (nres--, D, I);
|
||||
faiss::maxheap_push (++nres, D, I, d, v1);
|
||||
}
|
||||
}
|
||||
vt.mass[v1] = vt.curV + 1;
|
||||
}
|
||||
}
|
||||
|
||||
nstep++;
|
||||
if (nstep > hnsw.efSearch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (level == 0) {
|
||||
#pragma omp critical
|
||||
{
|
||||
rhnsw_stats.n1 ++;
|
||||
if (candidates.size() == 0)
|
||||
rhnsw_stats.n2 ++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nres;
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
void IndexRHNSW2Level::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const
|
||||
{
|
||||
if (dynamic_cast<const Index2Layer*>(storage)) {
|
||||
IndexRHNSW::search (n, x, k, distances, labels);
|
||||
|
||||
} else { // "mixed" search
|
||||
|
||||
const IndexIVFPQ *index_ivfpq =
|
||||
dynamic_cast<const IndexIVFPQ*>(storage);
|
||||
|
||||
int nprobe = index_ivfpq->nprobe;
|
||||
|
||||
std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
|
||||
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
||||
|
||||
index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(),
|
||||
coarse_assign.get());
|
||||
|
||||
index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(),
|
||||
coarse_dis.get(), distances, labels,
|
||||
false);
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
VisitedList vt (ntotal);
|
||||
DistanceComputer *dis = storage_distance_computer(storage);
|
||||
ScopeDeleter1<DistanceComputer> del(dis);
|
||||
|
||||
int candidates_size = hnsw.upper_beam;
|
||||
MinimaxHeap candidates(candidates_size);
|
||||
|
||||
#pragma omp for
|
||||
for(idx_t i = 0; i < n; i++) {
|
||||
idx_t * idxi = labels + i * k;
|
||||
float * simi = distances + i * k;
|
||||
dis->set_query(x + i * d);
|
||||
|
||||
// mark all inverted list elements as visited
|
||||
|
||||
for (int j = 0; j < nprobe; j++) {
|
||||
idx_t key = coarse_assign[j + i * nprobe];
|
||||
if (key < 0) break;
|
||||
size_t list_length = index_ivfpq->get_list_size (key);
|
||||
const idx_t * ids = index_ivfpq->invlists->get_ids (key);
|
||||
|
||||
for (int jj = 0; jj < list_length; jj++) {
|
||||
vt.set (ids[jj]);
|
||||
}
|
||||
}
|
||||
|
||||
candidates.clear();
|
||||
// copy the upper_beam elements to candidates list
|
||||
|
||||
int search_policy = 2;
|
||||
|
||||
if (search_policy == 1) {
|
||||
|
||||
for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
|
||||
if (idxi[j] < 0) break;
|
||||
candidates.push (idxi[j], simi[j]);
|
||||
// search_from_candidates adds them back
|
||||
idxi[j] = -1;
|
||||
simi[j] = HUGE_VAL;
|
||||
}
|
||||
|
||||
// reorder from sorted to heap
|
||||
maxheap_heapify (k, simi, idxi, simi, idxi, k);
|
||||
|
||||
// removed from RHNSW, but still available in HNSW
|
||||
// hnsw.search_from_candidates(
|
||||
// *dis, k, idxi, simi,
|
||||
// candidates, vt, 0, k
|
||||
// );
|
||||
|
||||
vt.advance();
|
||||
|
||||
} else if (search_policy == 2) {
|
||||
|
||||
for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
|
||||
if (idxi[j] < 0) break;
|
||||
candidates.push (idxi[j], simi[j]);
|
||||
}
|
||||
|
||||
// reorder from sorted to heap
|
||||
maxheap_heapify (k, simi, idxi, simi, idxi, k);
|
||||
|
||||
search_from_candidates_2 (
|
||||
hnsw, *dis, k, idxi, simi,
|
||||
candidates, vt, 0, k);
|
||||
vt.advance ();
|
||||
vt.advance ();
|
||||
|
||||
}
|
||||
|
||||
maxheap_reorder (k, simi, idxi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
void IndexRHNSW2Level::flip_to_ivf ()
|
||||
{
|
||||
Index2Layer *storage2l =
|
||||
dynamic_cast<Index2Layer*>(storage);
|
||||
|
||||
FAISS_THROW_IF_NOT (storage2l);
|
||||
|
||||
IndexIVFPQ * index_ivfpq =
|
||||
new IndexIVFPQ (storage2l->q1.quantizer,
|
||||
d, storage2l->q1.nlist,
|
||||
storage2l->pq.M, 8);
|
||||
index_ivfpq->pq = storage2l->pq;
|
||||
index_ivfpq->is_trained = storage2l->is_trained;
|
||||
index_ivfpq->precompute_table();
|
||||
index_ivfpq->own_fields = storage2l->q1.own_fields;
|
||||
storage2l->transfer_to_IVFPQ(*index_ivfpq);
|
||||
index_ivfpq->make_direct_map (true);
|
||||
|
||||
storage = index_ivfpq;
|
||||
delete storage2l;
|
||||
|
||||
}
|
||||
|
||||
|
||||
} // namespace faiss
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/impl/RHNSW.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexPQ.h>
|
||||
#include <faiss/IndexScalarQuantizer.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
//#include <faiss/IndexHNSW.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
||||
struct IndexRHNSW;
|
||||
|
||||
struct ReconstructFromNeighbors2 {
|
||||
typedef Index::idx_t idx_t;
|
||||
typedef RHNSW::storage_idx_t storage_idx_t;
|
||||
|
||||
const IndexRHNSW & index;
|
||||
size_t M; // number of neighbors
|
||||
size_t k; // number of codebook entries
|
||||
size_t nsq; // number of subvectors
|
||||
size_t code_size;
|
||||
int k_reorder; // nb to reorder. -1 = all
|
||||
|
||||
std::vector<float> codebook; // size nsq * k * (M + 1)
|
||||
|
||||
std::vector<uint8_t> codes; // size ntotal * code_size
|
||||
size_t ntotal;
|
||||
size_t d, dsub; // derived values
|
||||
|
||||
explicit ReconstructFromNeighbors2(const IndexRHNSW& index,
|
||||
size_t k=256, size_t nsq=1);
|
||||
|
||||
/// codes must be added in the correct order and the IndexRHNSW
|
||||
/// must be populated and sorted
|
||||
void add_codes(size_t n, const float *x);
|
||||
|
||||
size_t compute_distances(size_t n, const idx_t *shortlist,
|
||||
const float *query, float *distances) const;
|
||||
|
||||
/// called by add_codes
|
||||
void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
|
||||
|
||||
/// called by compute_distances
|
||||
void reconstruct(storage_idx_t i, float *x, float *tmp) const;
|
||||
|
||||
void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
|
||||
|
||||
/// get the M+1 -by-d table for neighbor coordinates for vector i
|
||||
void get_neighbor_table(storage_idx_t i, float *out) const;
|
||||
|
||||
};
|
||||
|
||||
/** The HNSW index is a normal random-access index with a HNSW
|
||||
* link structure built on top */
|
||||
|
||||
struct IndexRHNSW : Index {
|
||||
|
||||
typedef RHNSW::storage_idx_t storage_idx_t;
|
||||
|
||||
// the link strcuture
|
||||
RHNSW hnsw;
|
||||
|
||||
// the sequential storage
|
||||
bool own_fields;
|
||||
Index *storage;
|
||||
|
||||
ReconstructFromNeighbors2 *reconstruct_from_neighbors;
|
||||
|
||||
explicit IndexRHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
|
||||
explicit IndexRHNSW (Index *storage, int M = 32);
|
||||
|
||||
~IndexRHNSW() override;
|
||||
|
||||
void add(idx_t n, const float *x) override;
|
||||
|
||||
/// Trains the storage if needed
|
||||
void train(idx_t n, const float* x) override;
|
||||
|
||||
/// entry point for search
|
||||
void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
ConcurrentBitsetPtr bitset = nullptr) const override;
|
||||
|
||||
void reconstruct(idx_t key, float* recons) const override;
|
||||
|
||||
void reset () override;
|
||||
|
||||
size_t cal_size();
|
||||
|
||||
void init_hnsw();
|
||||
};
|
||||
|
||||
|
||||
/** Flat index topped with with a HNSW structure to access elements
|
||||
* more efficiently.
|
||||
*/
|
||||
|
||||
struct IndexRHNSWFlat : IndexRHNSW {
|
||||
IndexRHNSWFlat();
|
||||
IndexRHNSWFlat(int d, int M, MetricType metric = METRIC_L2);
|
||||
size_t cal_size();
|
||||
};
|
||||
|
||||
/** PQ index topped with with a HNSW structure to access elements
|
||||
* more efficiently.
|
||||
*/
|
||||
struct IndexRHNSWPQ : IndexRHNSW {
|
||||
IndexRHNSWPQ();
|
||||
IndexRHNSWPQ(int d, int pq_m, int M);
|
||||
void train(idx_t n, const float* x) override;
|
||||
size_t cal_size();
|
||||
};
|
||||
|
||||
/** SQ index topped with with a HNSW structure to access elements
|
||||
* more efficiently.
|
||||
*/
|
||||
struct IndexRHNSWSQ : IndexRHNSW {
|
||||
IndexRHNSWSQ();
|
||||
IndexRHNSWSQ(int d, QuantizerType qtype, int M, MetricType metric = METRIC_L2);
|
||||
size_t cal_size();
|
||||
};
|
||||
|
||||
/** 2-level code structure with fast random access
|
||||
*/
|
||||
struct IndexRHNSW2Level : IndexRHNSW {
|
||||
IndexRHNSW2Level();
|
||||
IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
|
||||
|
||||
void flip_to_ivf();
|
||||
|
||||
/// entry point for search
|
||||
void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
ConcurrentBitsetPtr bitset = nullptr) const override;
|
||||
size_t cal_size();
|
||||
};
|
||||
|
||||
|
||||
} // namespace faiss
|
|
@ -78,6 +78,7 @@ struct IndexScalarQuantizer: Index {
|
|||
void sa_decode (idx_t n, const uint8_t *bytes,
|
||||
float *x) const override;
|
||||
|
||||
size_t cal_size() { return codes.size() * sizeof(uint8_t) + sizeof(size_t) + sq.cal_size(); }
|
||||
|
||||
};
|
||||
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
#./configure CPUFLAGS='-mavx -mf16c -msse4 -mpopcnt' CXXFLAGS='-O0 -g -fPIC -m64 -Wno-sign-compare -Wall -Wextra' --prefix=$PWD --with-cuda-arch=-gencode=arch=compute_75,code=sm_75 --with-cuda=/usr/local/cuda
|
||||
./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75'
|
||||
./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O3 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3 -DNDEBUG' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75'
|
||||
make install -j8
|
||||
|
|
|
@ -173,6 +173,7 @@ struct ProductQuantizer {
|
|||
float_maxheap_array_t * res,
|
||||
bool init_finalize_heap = true) const;
|
||||
|
||||
size_t cal_size() { return sizeof(*this) + centroids.size() * sizeof(float); }
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,441 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/impl/RHNSW.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
||||
/**************************************************************
|
||||
* hnsw structure implementation
|
||||
**************************************************************/
|
||||
|
||||
RHNSW::RHNSW(int M) : M(M), rng(12345) {
|
||||
level_generator.seed(100);
|
||||
max_level = -1;
|
||||
entry_point = -1;
|
||||
efSearch = 16;
|
||||
efConstruction = 40;
|
||||
upper_beam = 1;
|
||||
level0_link_size = sizeof(int) * ((M << 1) | 1);
|
||||
link_size = sizeof(int) * (M + 1);
|
||||
level0_links = nullptr;
|
||||
linkLists = nullptr;
|
||||
level_constant = 1 / log(1.0 * M);
|
||||
visited_list_pool = nullptr;
|
||||
}
|
||||
|
||||
void RHNSW::init(int ntotal) {
|
||||
level_generator.seed(100);
|
||||
if (visited_list_pool) delete visited_list_pool;
|
||||
visited_list_pool = new VisitedListPool(1, ntotal);
|
||||
std::vector<std::mutex>(ntotal).swap(link_list_locks);
|
||||
}
|
||||
|
||||
RHNSW::~RHNSW() {
|
||||
free(level0_links);
|
||||
for (auto i = 0; i < levels.size(); ++ i) {
|
||||
if (levels[i])
|
||||
free(linkLists[i]);
|
||||
}
|
||||
free(linkLists);
|
||||
delete visited_list_pool;
|
||||
}
|
||||
|
||||
void RHNSW::reset() {
|
||||
max_level = -1;
|
||||
entry_point = -1;
|
||||
levels.clear();
|
||||
free(level0_links);
|
||||
for (auto i = 0; i < levels.size(); ++ i) {
|
||||
if (levels[i])
|
||||
free(linkLists[i]);
|
||||
}
|
||||
free(linkLists);
|
||||
level0_links = nullptr;
|
||||
linkLists = nullptr;
|
||||
level_constant = 1 / log(1.0 * M);
|
||||
}
|
||||
|
||||
int RHNSW::prepare_level_tab(size_t n, bool preset_levels)
|
||||
{
|
||||
size_t n0 = levels.size();
|
||||
|
||||
std::vector<int> level_stats(n);
|
||||
if (preset_levels) {
|
||||
FAISS_ASSERT (n0 + n == levels.size());
|
||||
} else {
|
||||
FAISS_ASSERT (n0 == levels.size());
|
||||
for (int i = 0; i < n; i++) {
|
||||
int pt_level = random_level(level_constant);
|
||||
levels.push_back(pt_level);
|
||||
}
|
||||
}
|
||||
|
||||
char *level0_links_new = (char*)malloc((n0 + n) * level0_link_size);
|
||||
if (level0_links_new == nullptr) {
|
||||
throw std::runtime_error("No enough memory 4 level0_links!");
|
||||
}
|
||||
memset(level0_links_new, 0, (n0 + n) * level0_link_size);
|
||||
if (level0_links) {
|
||||
memcpy(level0_links_new, level0_links, n0 * level0_link_size);
|
||||
free(level0_links);
|
||||
}
|
||||
level0_links = level0_links_new;
|
||||
|
||||
char **linkLists_new = (char **)malloc(sizeof(void*) * (n0 + n));
|
||||
if (linkLists_new == nullptr) {
|
||||
throw std::runtime_error("No enough memory 4 level0_links_new!");
|
||||
}
|
||||
if (linkLists) {
|
||||
memcpy(linkLists_new, linkLists, n0 * sizeof(void*));
|
||||
free(linkLists);
|
||||
}
|
||||
linkLists = linkLists_new;
|
||||
|
||||
int max_level = 0;
|
||||
int debug_space = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
int pt_level = levels[i + n0];
|
||||
if (pt_level > max_level) max_level = pt_level;
|
||||
if (pt_level) {
|
||||
linkLists[n0 + i] = (char*) malloc(link_size * pt_level + 1);
|
||||
if (linkLists[n0 + i] == nullptr) {
|
||||
throw std::runtime_error("No enough memory 4 linkLists!");
|
||||
}
|
||||
memset(linkLists[n0 + i], 0, link_size * pt_level + 1);
|
||||
}
|
||||
if (max_level >= level_stats.size()) {
|
||||
level_stats.resize(max_level + 1);
|
||||
}
|
||||
level_stats[pt_level] ++;
|
||||
}
|
||||
|
||||
// printf("level stats:\n");
|
||||
// for (int i = 0; i <= max_level; ++ i)
|
||||
// printf("level %d: %d points\n", i, level_stats[i]);
|
||||
// printf("\n");
|
||||
std::vector<std::mutex>(n0 + n).swap(link_list_locks);
|
||||
if (visited_list_pool) delete visited_list_pool;
|
||||
visited_list_pool = new VisitedListPool(1, n0 + n);
|
||||
|
||||
return max_level;
|
||||
}
|
||||
|
||||
|
||||
/**************************************************************
|
||||
* new implementation of hnsw ispired by hnswlib
|
||||
* by cmli@zilliz July 30, 2020
|
||||
**************************************************************/
|
||||
using Node = faiss::RHNSW::Node;
|
||||
using CompareByFirst = faiss::RHNSW::CompareByFirst;
|
||||
void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) {
|
||||
|
||||
std::unique_lock<std::mutex> lock_el(link_list_locks[pt_id]);
|
||||
std::unique_lock<std::mutex> temp_lock(global);
|
||||
int maxlevel_copy = max_level;
|
||||
if (pt_level <= maxlevel_copy)
|
||||
temp_lock.unlock();
|
||||
int currObj = entry_point;
|
||||
int ep_copy = entry_point;
|
||||
|
||||
if (currObj != -1) {
|
||||
if (pt_level < maxlevel_copy) {
|
||||
float curdist = ptdis(currObj);
|
||||
for (int lev = maxlevel_copy; lev > pt_level; lev --) {
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
std::unique_lock<std::mutex> lk(link_list_locks[currObj]);
|
||||
int *curObj_link = get_neighbor_link(currObj, lev);
|
||||
auto curObj_nei_num = get_neighbors_num(curObj_link);
|
||||
for (auto i = 1; i <= curObj_nei_num; ++ i) {
|
||||
int cand = curObj_link[i];
|
||||
if (cand < 0 || cand > levels.size())
|
||||
throw std::runtime_error("cand error when addPoint");
|
||||
float d = ptdis(cand);
|
||||
if (d < curdist) {
|
||||
curdist = d;
|
||||
currObj = cand;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int lev = std::min(pt_level, maxlevel_copy); lev >= 0; -- lev) {
|
||||
if (lev > maxlevel_copy || lev < 0)
|
||||
throw std::runtime_error("Level error");
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_layer(ptdis, pt_id, currObj, lev);
|
||||
currObj = top_candidates.top().second;
|
||||
make_connection(ptdis, pt_id, top_candidates, lev);
|
||||
}
|
||||
} else {
|
||||
entry_point = 0;
|
||||
max_level = pt_level;
|
||||
}
|
||||
|
||||
if (pt_level > maxlevel_copy) {
|
||||
entry_point = pt_id;
|
||||
max_level = pt_level;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
|
||||
RHNSW::search_layer(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
storage_idx_t nearest,
|
||||
int level) {
|
||||
VisitedList *vl = visited_list_pool->getFreeVisitedList();
|
||||
vl_type *visited_array = vl->mass;
|
||||
vl_type visited_array_tag = vl->curV;
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates;
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> candidate_set;
|
||||
|
||||
float d_nearest = ptdis(nearest);
|
||||
float lb = d_nearest;
|
||||
top_candidates.emplace(d_nearest, nearest);
|
||||
candidate_set.emplace(-d_nearest, nearest);
|
||||
visited_array[nearest] = visited_array_tag;
|
||||
|
||||
while (!candidate_set.empty()) {
|
||||
Node currNode = candidate_set.top();
|
||||
if ((-currNode.first) > lb)
|
||||
break;
|
||||
candidate_set.pop();
|
||||
int cur_id = currNode.second;
|
||||
std::unique_lock<std::mutex> lk(link_list_locks[cur_id]);
|
||||
int *cur_link = get_neighbor_link(cur_id, level);
|
||||
auto cur_neighbor_num = get_neighbors_num(cur_link);
|
||||
|
||||
for (auto i = 1; i <= cur_neighbor_num; ++ i) {
|
||||
int candidate_id = cur_link[i];
|
||||
if (visited_array[candidate_id] == visited_array_tag) continue;
|
||||
visited_array[candidate_id] = visited_array_tag;
|
||||
float dcand = ptdis(candidate_id);
|
||||
if (top_candidates.size() < efConstruction || lb > dcand) {
|
||||
candidate_set.emplace(-dcand, candidate_id);
|
||||
top_candidates.emplace(dcand, candidate_id);
|
||||
if (top_candidates.size() > efConstruction)
|
||||
top_candidates.pop();
|
||||
if (!top_candidates.empty())
|
||||
lb = top_candidates.top().first;
|
||||
}
|
||||
}
|
||||
}
|
||||
visited_list_pool->releaseVisitedList(vl);
|
||||
return top_candidates;
|
||||
}
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
|
||||
RHNSW::search_base_layer(DistanceComputer& ptdis,
|
||||
storage_idx_t nearest,
|
||||
storage_idx_t ef,
|
||||
float d_nearest,
|
||||
ConcurrentBitsetPtr bitset) const {
|
||||
VisitedList *vl = visited_list_pool->getFreeVisitedList();
|
||||
vl_type *visited_array = vl->mass;
|
||||
vl_type visited_array_tag = vl->curV;
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates;
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> candidate_set;
|
||||
|
||||
float lb;
|
||||
if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(nearest))) {
|
||||
lb = d_nearest;
|
||||
top_candidates.emplace(d_nearest, nearest);
|
||||
candidate_set.emplace(-d_nearest, nearest);
|
||||
} else {
|
||||
lb = std::numeric_limits<float>::max();
|
||||
candidate_set.emplace(-lb, nearest);
|
||||
}
|
||||
visited_array[nearest] = visited_array_tag;
|
||||
|
||||
while (!candidate_set.empty()) {
|
||||
Node currNode = candidate_set.top();
|
||||
if ((-currNode.first) > lb)
|
||||
break;
|
||||
candidate_set.pop();
|
||||
int cur_id = currNode.second;
|
||||
int *cur_link = get_neighbor_link(cur_id, 0);
|
||||
auto cur_neighbor_num = get_neighbors_num(cur_link);
|
||||
for (auto i = 1; i <= cur_neighbor_num; ++ i) {
|
||||
int candidate_id = cur_link[i];
|
||||
if (visited_array[candidate_id] != visited_array_tag) {
|
||||
visited_array[candidate_id] = visited_array_tag;
|
||||
float dcand = ptdis(candidate_id);
|
||||
if (top_candidates.size() < ef || lb > dcand) {
|
||||
candidate_set.emplace(-dcand, candidate_id);
|
||||
if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(candidate_id)))
|
||||
top_candidates.emplace(dcand, candidate_id);
|
||||
if (top_candidates.size() > ef)
|
||||
top_candidates.pop();
|
||||
if (!top_candidates.empty())
|
||||
lb = top_candidates.top().first;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
visited_list_pool->releaseVisitedList(vl);
|
||||
return top_candidates;
|
||||
}
|
||||
|
||||
void
|
||||
RHNSW::make_connection(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
int level) {
|
||||
int maxM = level ? M : M << 1;
|
||||
int *selectedNeighbors = (int*)malloc(sizeof(int) * maxM);
|
||||
int selectedNeighborsNum = 0;
|
||||
prune_neighbors(ptdis, cand, maxM, selectedNeighbors, selectedNeighborsNum);
|
||||
if (selectedNeighborsNum > maxM)
|
||||
throw std::runtime_error("Wrong size of candidates returned by prune_neighbors!");
|
||||
|
||||
int *cur_link = get_neighbor_link(pt_id, level);
|
||||
if (*cur_link)
|
||||
throw std::runtime_error("The newly inserted element should have blank link");
|
||||
|
||||
set_neighbors_num(cur_link, selectedNeighborsNum);
|
||||
for (auto i = 1; i <= selectedNeighborsNum; ++ i) {
|
||||
if (cur_link[i])
|
||||
throw std::runtime_error("Possible memory corruption.");
|
||||
if (level > levels[selectedNeighbors[i - 1]])
|
||||
throw std::runtime_error("Trying to make a link on a non-exisitent level.");
|
||||
cur_link[i] = selectedNeighbors[i - 1];
|
||||
}
|
||||
|
||||
for (auto i = 0; i < selectedNeighborsNum; ++ i) {
|
||||
std::unique_lock<std::mutex> lk(link_list_locks[selectedNeighbors[i]]);
|
||||
|
||||
int *selected_link = get_neighbor_link(selectedNeighbors[i], level);
|
||||
auto selected_neighbor_num = get_neighbors_num(selected_link);
|
||||
if (selected_neighbor_num > maxM)
|
||||
throw std::runtime_error("Bad value of selected_neighbor_num.");
|
||||
if (selectedNeighbors[i] == pt_id)
|
||||
throw std::runtime_error("Trying to connect an element to itself.");
|
||||
if (level > levels[selectedNeighbors[i]])
|
||||
throw std::runtime_error("Trying to make a link on a non-exisitent level.");
|
||||
if (selected_neighbor_num < maxM) {
|
||||
selected_link[selected_neighbor_num + 1] = pt_id;
|
||||
set_neighbors_num(selected_link, selected_neighbor_num + 1);
|
||||
} else {
|
||||
double d_max = ptdis(selectedNeighbors[i]);
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> candi;
|
||||
candi.emplace(d_max, pt_id);
|
||||
for (auto j = 1; j <= selected_neighbor_num; ++ j)
|
||||
candi.emplace(ptdis.symmetric_dis(selectedNeighbors[i], selected_link[j]), selected_link[j]);
|
||||
int indx = 0;
|
||||
prune_neighbors(ptdis, candi, maxM, selected_link + 1, indx);
|
||||
set_neighbors_num(selected_link, indx);
|
||||
}
|
||||
}
|
||||
|
||||
free(selectedNeighbors);
|
||||
}
|
||||
|
||||
void RHNSW::prune_neighbors(DistanceComputer& ptdis,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
const int maxM, int *ret, int &ret_len) {
|
||||
if (cand.size() < maxM) {
|
||||
while (!cand.empty()) {
|
||||
ret[ret_len ++] = cand.top().second;
|
||||
cand.pop();
|
||||
}
|
||||
return;
|
||||
}
|
||||
std::priority_queue<Node> closest;
|
||||
|
||||
while (!cand.empty()) {
|
||||
closest.emplace(-cand.top().first, cand.top().second);
|
||||
cand.pop();
|
||||
}
|
||||
|
||||
while (closest.size()) {
|
||||
if (ret_len >= maxM)
|
||||
break;
|
||||
Node curr = closest.top();
|
||||
float dist_to_query = -curr.first;
|
||||
closest.pop();
|
||||
bool good = true;
|
||||
for (auto i = 0; i < ret_len; ++ i) {
|
||||
float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]);
|
||||
if (cur_dist < dist_to_query) {
|
||||
good = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (good) {
|
||||
ret[ret_len ++] = curr.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RHNSW::searchKnn(DistanceComputer& qdis, int k,
|
||||
idx_t *I, float *D,
|
||||
ConcurrentBitsetPtr bitset) const {
|
||||
if (levels.size() == 0)
|
||||
return;
|
||||
int ep = entry_point;
|
||||
float dist = qdis(ep);
|
||||
|
||||
for (auto i = max_level; i > 0; -- i) {
|
||||
bool good = true;
|
||||
while (good) {
|
||||
good = false;
|
||||
int *ep_link = get_neighbor_link(ep, i);
|
||||
auto ep_neighbors_cnt = get_neighbors_num(ep_link);
|
||||
for (auto j = 1; j <= ep_neighbors_cnt; ++ j) {
|
||||
int cand = ep_link[j];
|
||||
if (cand < 0 || cand > levels.size())
|
||||
throw std::runtime_error("cand error");
|
||||
float d = qdis(cand);
|
||||
if (d < dist) {
|
||||
dist = d;
|
||||
ep = cand;
|
||||
good = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_base_layer(qdis, ep, std::max(efSearch, k), dist, bitset);
|
||||
while (top_candidates.size() > k)
|
||||
top_candidates.pop();
|
||||
int i = 0;
|
||||
while (!top_candidates.empty()) {
|
||||
I[i] = top_candidates.top().second;
|
||||
D[i] = top_candidates.top().first;
|
||||
i ++;
|
||||
top_candidates.pop();
|
||||
}
|
||||
}
|
||||
|
||||
size_t RHNSW::cal_size() {
|
||||
size_t ret = 0;
|
||||
ret += sizeof(*this);
|
||||
ret += visited_list_pool->GetSize();
|
||||
ret += link_list_locks.size() * sizeof(std::mutex);
|
||||
ret += levels.size() * sizeof(int);
|
||||
ret += levels.size() * level0_link_size;
|
||||
ret += levels.size() * sizeof(void*);
|
||||
for (auto i = 0; i < levels.size(); ++ i) {
|
||||
ret += levels[i] ? link_size * levels[i] : 0;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace faiss
|
|
@ -0,0 +1,367 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
#include <queue>
|
||||
|
||||
#include <omp.h>
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/random.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
||||
/** Implementation of the Hierarchical Navigable Small World
|
||||
* datastructure.
|
||||
*
|
||||
* Efficient and robust approximate nearest neighbor search using
|
||||
* Hierarchical Navigable Small World graphs
|
||||
*
|
||||
* Yu. A. Malkov, D. A. Yashunin, arXiv 2017
|
||||
*
|
||||
* This implmentation is heavily influenced by the hnswlib
|
||||
* implementation by Yury Malkov and Leonid Boystov
|
||||
* (https://github.com/searchivarius/nmslib/hnswlib)
|
||||
*
|
||||
* The HNSW object stores only the neighbor link structure, see
|
||||
* IndexHNSW.h for the full index object.
|
||||
*/
|
||||
|
||||
|
||||
struct DistanceComputer; // from AuxIndexStructures
|
||||
class VisitedListPool;
|
||||
|
||||
struct RHNSW {
|
||||
/// internal storage of vectors (32 bits: this is expensive)
|
||||
typedef int storage_idx_t;
|
||||
|
||||
/// Faiss results are 64-bit
|
||||
typedef Index::idx_t idx_t;
|
||||
|
||||
typedef std::pair<float, storage_idx_t> Node;
|
||||
|
||||
/** Heap structure that allows fast
|
||||
*/
|
||||
struct MinimaxHeap {
|
||||
int n;
|
||||
int k;
|
||||
int nvalid;
|
||||
|
||||
std::vector<storage_idx_t> ids;
|
||||
std::vector<float> dis;
|
||||
typedef faiss::CMax<float, storage_idx_t> HC;
|
||||
|
||||
explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {}
|
||||
|
||||
void push(storage_idx_t i, float v) {
|
||||
if (k == n) {
|
||||
if (v >= dis[0]) return;
|
||||
faiss::heap_pop<HC> (k--, dis.data(), ids.data());
|
||||
--nvalid;
|
||||
}
|
||||
faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
|
||||
++nvalid;
|
||||
}
|
||||
|
||||
float max() const {
|
||||
return dis[0];
|
||||
}
|
||||
|
||||
int size() const {
|
||||
return nvalid;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
nvalid = k = 0;
|
||||
}
|
||||
|
||||
int pop_min(float *vmin_out = nullptr) {
|
||||
assert(k > 0);
|
||||
// returns min. This is an O(n) operation
|
||||
int i = k - 1;
|
||||
while (i >= 0) {
|
||||
if (ids[i] != -1) break;
|
||||
i--;
|
||||
}
|
||||
if (i == -1) return -1;
|
||||
int imin = i;
|
||||
float vmin = dis[i];
|
||||
i--;
|
||||
while(i >= 0) {
|
||||
if (ids[i] != -1 && dis[i] < vmin) {
|
||||
vmin = dis[i];
|
||||
imin = i;
|
||||
}
|
||||
i--;
|
||||
}
|
||||
if (vmin_out) *vmin_out = vmin;
|
||||
int ret = ids[imin];
|
||||
ids[imin] = -1;
|
||||
--nvalid;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int count_below(float thresh) {
|
||||
int n_below = 0;
|
||||
for(int i = 0; i < k; i++) {
|
||||
if (dis[i] < thresh) {
|
||||
n_below++;
|
||||
}
|
||||
}
|
||||
|
||||
return n_below;
|
||||
}
|
||||
};
|
||||
|
||||
/// to sort pairs of (id, distance) from nearest to fathest or the reverse
|
||||
struct NodeDistCloser {
|
||||
float d;
|
||||
int id;
|
||||
NodeDistCloser(float d, int id): d(d), id(id) {}
|
||||
bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; }
|
||||
};
|
||||
|
||||
struct NodeDistFarther {
|
||||
float d;
|
||||
int id;
|
||||
NodeDistFarther(float d, int id): d(d), id(id) {}
|
||||
bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; }
|
||||
};
|
||||
|
||||
struct CompareByFirst {
|
||||
constexpr bool operator()(Node const &a,
|
||||
Node const &b) const noexcept {
|
||||
return a.first < b.first;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// level of each vector (base level = 1), size = ntotal
|
||||
std::vector<int> levels;
|
||||
|
||||
/// number of entry points in levels > 0.
|
||||
int upper_beam;
|
||||
|
||||
/// entry point in the search structure (one of the points with maximum level
|
||||
storage_idx_t entry_point;
|
||||
|
||||
faiss::RandomGenerator rng;
|
||||
std::default_random_engine level_generator;
|
||||
|
||||
/// maximum level
|
||||
int max_level;
|
||||
int M;
|
||||
char *level0_links;
|
||||
char **linkLists;
|
||||
size_t level0_link_size;
|
||||
size_t link_size;
|
||||
double level_constant;
|
||||
VisitedListPool *visited_list_pool;
|
||||
std::vector<std::mutex> link_list_locks;
|
||||
std::mutex global;
|
||||
|
||||
/// expansion factor at construction time
|
||||
int efConstruction;
|
||||
|
||||
/// expansion factor at search time
|
||||
int efSearch;
|
||||
|
||||
/// range of entries in the neighbors table of vertex no at layer_no
|
||||
storage_idx_t* get_neighbor_link(idx_t no, int layer_no) const {
|
||||
return layer_no == 0 ? (int*)(level0_links + no * level0_link_size) : (int*)(linkLists[no] + (layer_no - 1) * link_size);
|
||||
}
|
||||
unsigned short int get_neighbors_num(int *p) const {
|
||||
return *((unsigned short int*)p);
|
||||
}
|
||||
void set_neighbors_num(int *p, unsigned short int num) const {
|
||||
*((unsigned short int*)(p)) = *((unsigned short int *)(&num));
|
||||
}
|
||||
|
||||
/// only mandatory parameter: nb of neighbors
|
||||
explicit RHNSW(int M = 32);
|
||||
~RHNSW();
|
||||
|
||||
void init(int ntotal);
|
||||
/// pick a random level for a new point, arg = 1/log(M)
|
||||
int random_level(double arg) {
|
||||
std::uniform_real_distribution<double> distribution(0.0, 1.0);
|
||||
double r = -log(distribution(level_generator)) * arg;
|
||||
return (int)r;
|
||||
}
|
||||
|
||||
void reset();
|
||||
|
||||
int prepare_level_tab(size_t n, bool preset_levels = false);
|
||||
|
||||
// re-implementations inspired by hnswlib
|
||||
/** add point pt_id on all levels <= pt_level and build the link
|
||||
* structure for them. inspired by implementation of hnswlib */
|
||||
void addPoint(DistanceComputer& ptdis, int pt_level, int pt_id);
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
|
||||
search_layer (DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
storage_idx_t nearest,
|
||||
int level);
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
|
||||
search_base_layer (DistanceComputer& ptdis,
|
||||
storage_idx_t nearest,
|
||||
storage_idx_t ef,
|
||||
float d_nearest,
|
||||
ConcurrentBitsetPtr bitset = nullptr) const;
|
||||
|
||||
void make_connection(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
int level);
|
||||
|
||||
void prune_neighbors(DistanceComputer& ptdis,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
const int maxM, int *ret, int &ret_len);
|
||||
|
||||
/// search interface inspired by hnswlib
|
||||
void searchKnn(DistanceComputer& qdis, int k,
|
||||
idx_t *I, float *D,
|
||||
ConcurrentBitsetPtr bitset = nullptr) const;
|
||||
|
||||
size_t cal_size();
|
||||
|
||||
};
|
||||
|
||||
|
||||
/**************************************************************
|
||||
* Auxiliary structures
|
||||
**************************************************************/
|
||||
|
||||
typedef unsigned short int vl_type;
|
||||
|
||||
class VisitedList {
|
||||
public:
|
||||
vl_type curV;
|
||||
vl_type *mass;
|
||||
unsigned int numelements;
|
||||
|
||||
VisitedList(int numelements1) {
|
||||
curV = -1;
|
||||
numelements = numelements1;
|
||||
mass = new vl_type[numelements];
|
||||
}
|
||||
|
||||
void reset() {
|
||||
curV++;
|
||||
if (curV == 0) {
|
||||
memset(mass, 0, sizeof(vl_type) * numelements);
|
||||
curV++;
|
||||
}
|
||||
};
|
||||
|
||||
// keep compatibae with original version VisitedTable
|
||||
/// set flog #no to true
|
||||
void set(int no) {
|
||||
mass[no] = curV;
|
||||
}
|
||||
|
||||
/// get flag #no
|
||||
bool get(int no) const {
|
||||
return mass[no] == curV;
|
||||
}
|
||||
|
||||
void advance() {
|
||||
reset();
|
||||
}
|
||||
|
||||
~VisitedList() { delete[] mass; }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
//
|
||||
// Class for multi-threaded pool-management of VisitedLists
|
||||
//
|
||||
/////////////////////////////////////////////////////////
|
||||
|
||||
class VisitedListPool {
|
||||
std::deque<VisitedList *> pool;
|
||||
std::mutex poolguard;
|
||||
int numelements;
|
||||
|
||||
public:
|
||||
VisitedListPool(int initmaxpools, int numelements1) {
|
||||
numelements = numelements1;
|
||||
for (int i = 0; i < initmaxpools; i++)
|
||||
pool.push_front(new VisitedList(numelements));
|
||||
}
|
||||
|
||||
VisitedList *getFreeVisitedList() {
|
||||
VisitedList *rez;
|
||||
{
|
||||
std::unique_lock <std::mutex> lock(poolguard);
|
||||
if (pool.size() > 0) {
|
||||
rez = pool.front();
|
||||
pool.pop_front();
|
||||
} else {
|
||||
rez = new VisitedList(numelements);
|
||||
}
|
||||
}
|
||||
rez->reset();
|
||||
return rez;
|
||||
};
|
||||
|
||||
void releaseVisitedList(VisitedList *vl) {
|
||||
std::unique_lock <std::mutex> lock(poolguard);
|
||||
pool.push_front(vl);
|
||||
};
|
||||
|
||||
~VisitedListPool() {
|
||||
while (pool.size()) {
|
||||
VisitedList *rez = pool.front();
|
||||
pool.pop_front();
|
||||
delete rez;
|
||||
}
|
||||
};
|
||||
|
||||
int64_t GetSize() {
|
||||
auto visit_list_size = sizeof(VisitedList) + numelements * sizeof(vl_type);
|
||||
auto pool_size = pool.size() * (sizeof(VisitedList *) + visit_list_size);
|
||||
return pool_size + sizeof(*this);
|
||||
}
|
||||
};
|
||||
|
||||
struct RHNSWStats {
|
||||
size_t n1, n2, n3;
|
||||
size_t ndis;
|
||||
size_t nreorder;
|
||||
bool view;
|
||||
|
||||
RHNSWStats() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
n1 = n2 = n3 = 0;
|
||||
ndis = 0;
|
||||
nreorder = 0;
|
||||
view = false;
|
||||
}
|
||||
};
|
||||
|
||||
// global var that collects them all
|
||||
extern RHNSWStats rhnsw_stats;
|
||||
|
||||
|
||||
} // namespace faiss
|
|
@ -75,6 +75,7 @@ struct ScalarQuantizer {
|
|||
(MetricType mt, const Index *quantizer, bool store_pairs,
|
||||
bool by_residual=false) const;
|
||||
|
||||
size_t cal_size() { return sizeof(*this) + trained.size() * sizeof(float); }
|
||||
};
|
||||
|
||||
template<class DCClass>
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include <faiss/IndexScalarQuantizer.h>
|
||||
#include <faiss/IndexSQHybrid.h>
|
||||
#include <faiss/IndexHNSW.h>
|
||||
#include <faiss/IndexRHNSW.h>
|
||||
#include <faiss/IndexLattice.h>
|
||||
|
||||
#include <faiss/OnDiskInvertedLists.h>
|
||||
|
@ -458,6 +459,29 @@ static void read_HNSW (HNSW *hnsw, IOReader *f) {
|
|||
READ1 (hnsw->upper_beam);
|
||||
}
|
||||
|
||||
static void read_RHNSW (RHNSW *rhnsw, IOReader *f) {
|
||||
READ1 (rhnsw->entry_point);
|
||||
READ1 (rhnsw->max_level);
|
||||
READ1 (rhnsw->M);
|
||||
READ1 (rhnsw->level0_link_size);
|
||||
READ1 (rhnsw->link_size);
|
||||
READ1 (rhnsw->level_constant);
|
||||
READ1 (rhnsw->efConstruction);
|
||||
READ1 (rhnsw->efSearch);
|
||||
|
||||
READVECTOR (rhnsw->levels);
|
||||
auto ntotal = rhnsw->levels.size();
|
||||
rhnsw->level0_links = (char*) malloc(ntotal * rhnsw->level0_link_size);
|
||||
READANDCHECK( rhnsw->level0_links, ntotal * rhnsw->level0_link_size);
|
||||
rhnsw->linkLists = (char**) malloc(ntotal * sizeof(void*));
|
||||
for (auto i = 0; i < ntotal; ++ i) {
|
||||
if (rhnsw->levels[i]) {
|
||||
rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i] + 1);
|
||||
READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProductQuantizer * read_ProductQuantizer (const char*fname) {
|
||||
FileIOReader reader(fname);
|
||||
return read_ProductQuantizer(&reader);
|
||||
|
@ -624,6 +648,9 @@ Index *read_index (IOReader *f, int io_flags) {
|
|||
if (h == fourcc ("IxPQ") || h == fourcc ("IxPo")) {
|
||||
idxp->metric_type = METRIC_L2;
|
||||
}
|
||||
if (h == fourcc("IxPq")) {
|
||||
idxp->pq.compute_sdc_table ();
|
||||
}
|
||||
idx = idxp;
|
||||
} else if (h == fourcc ("IvFl") || h == fourcc("IvFL")) { // legacy
|
||||
IndexIVFFlat * ivfl = new IndexIVFFlat ();
|
||||
|
@ -800,6 +827,17 @@ Index *read_index (IOReader *f, int io_flags) {
|
|||
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table ();
|
||||
}
|
||||
idx = idxhnsw;
|
||||
} else if(h == fourcc("IRHf") || h == fourcc("IRHp") ||
|
||||
h == fourcc("IRHs") || h == fourcc("IRH2")) {
|
||||
IndexRHNSW *idxrhnsw = nullptr;
|
||||
if (h == fourcc("IRHf")) idxrhnsw = new IndexRHNSWFlat ();
|
||||
if (h == fourcc("IRHp")) idxrhnsw = new IndexRHNSWPQ ();
|
||||
if (h == fourcc("IRHs")) idxrhnsw = new IndexRHNSWSQ ();
|
||||
if (h == fourcc("IRH2")) idxrhnsw = new IndexRHNSW2Level ();
|
||||
read_index_header (idxrhnsw, f);
|
||||
read_RHNSW (&idxrhnsw->hnsw, f);
|
||||
idxrhnsw->own_fields = true;
|
||||
idx = idxrhnsw;
|
||||
} else {
|
||||
FAISS_THROW_FMT("Index type 0x%08x not supported\n", h);
|
||||
idx = nullptr;
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include <faiss/IndexScalarQuantizer.h>
|
||||
#include <faiss/IndexSQHybrid.h>
|
||||
#include <faiss/IndexHNSW.h>
|
||||
#include <faiss/IndexRHNSW.h>
|
||||
#include <faiss/IndexLattice.h>
|
||||
|
||||
#include <faiss/OnDiskInvertedLists.h>
|
||||
|
@ -364,6 +365,24 @@ static void write_HNSW (const HNSW *hnsw, IOWriter *f) {
|
|||
WRITE1 (hnsw->upper_beam);
|
||||
}
|
||||
|
||||
static void write_RHNSW (const RHNSW *rhnsw, IOWriter *f) {
|
||||
WRITE1 (rhnsw->entry_point);
|
||||
WRITE1 (rhnsw->max_level);
|
||||
WRITE1 (rhnsw->M);
|
||||
WRITE1 (rhnsw->level0_link_size);
|
||||
WRITE1 (rhnsw->link_size);
|
||||
WRITE1 (rhnsw->level_constant);
|
||||
WRITE1 (rhnsw->efConstruction);
|
||||
WRITE1 (rhnsw->efSearch);
|
||||
|
||||
WRITEVECTOR (rhnsw->levels);
|
||||
WRITEANDCHECK (rhnsw->level0_links, rhnsw->level0_link_size * rhnsw->levels.size());
|
||||
for (auto i = 0; i < rhnsw->levels.size(); ++ i) {
|
||||
if (rhnsw->levels[i])
|
||||
WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
|
||||
}
|
||||
}
|
||||
|
||||
static void write_direct_map (const DirectMap *dm, IOWriter *f) {
|
||||
char maintain_direct_map = (char)dm->type; // for backwards compatibility with bool
|
||||
WRITE1 (maintain_direct_map);
|
||||
|
@ -560,6 +579,18 @@ void write_index (const Index *idx, IOWriter *f) {
|
|||
write_index_header (idxhnsw, f);
|
||||
write_HNSW (&idxhnsw->hnsw, f);
|
||||
write_index (idxhnsw->storage, f);
|
||||
} else if (const IndexRHNSW * idxrhnsw =
|
||||
dynamic_cast<const IndexRHNSW *>(idx)) {
|
||||
uint32_t h =
|
||||
dynamic_cast<const IndexRHNSWFlat*>(idx) ? fourcc("IRHf") :
|
||||
dynamic_cast<const IndexRHNSWPQ*>(idx) ? fourcc("IRHp") :
|
||||
dynamic_cast<const IndexRHNSWSQ*>(idx) ? fourcc("IRHs") :
|
||||
dynamic_cast<const IndexRHNSW2Level*>(idx) ? fourcc("IRH2") :
|
||||
0;
|
||||
FAISS_THROW_IF_NOT (h != 0);
|
||||
WRITE1 (h);
|
||||
write_index_header (idxrhnsw, f);
|
||||
write_RHNSW (&idxrhnsw->hnsw, f);
|
||||
} else {
|
||||
FAISS_THROW_MSG ("don't know how to serialize this type of index");
|
||||
}
|
||||
|
|
|
@ -1224,4 +1224,4 @@ namespace hnswlib_nm {
|
|||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -219,6 +219,42 @@ endif ()
|
|||
target_link_libraries(test_hnsw_sq8nm ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
install(TARGETS test_hnsw_sq8nm DESTINATION unittest)
|
||||
|
||||
################################################################################
|
||||
#<RHNSW_FLAT-TEST>
|
||||
set(rhnsw_flat_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp
|
||||
)
|
||||
if (NOT TARGET test_rhnsw_flat)
|
||||
add_executable(test_rhnsw_flat test_rhnsw_flat.cpp ${rhnsw_flat_srcs} ${util_srcs} ${faiss_srcs})
|
||||
endif ()
|
||||
target_link_libraries(test_rhnsw_flat ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
install(TARGETS test_rhnsw_flat DESTINATION unittest)
|
||||
|
||||
################################################################################
|
||||
#<RHNSW_PQ-TEST>
|
||||
set(rhnsw_pq_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp
|
||||
)
|
||||
if (NOT TARGET test_rhnsw_pq)
|
||||
add_executable(test_rhnsw_pq test_rhnsw_pq.cpp ${rhnsw_pq_srcs} ${util_srcs} ${faiss_srcs})
|
||||
endif ()
|
||||
target_link_libraries(test_rhnsw_pq ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
install(TARGETS test_rhnsw_pq DESTINATION unittest)
|
||||
|
||||
################################################################################
|
||||
#<RHNSW_SQ8-TEST>
|
||||
set(rhnsw_sq8_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp
|
||||
)
|
||||
if (NOT TARGET test_rhnsw_sq8)
|
||||
add_executable(test_rhnsw_sq8 test_rhnsw_sq8.cpp ${rhnsw_sq8_srcs} ${util_srcs} ${faiss_srcs})
|
||||
endif ()
|
||||
target_link_libraries(test_rhnsw_sq8 ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
install(TARGETS test_rhnsw_sq8 DESTINATION unittest)
|
||||
|
||||
################################################################################
|
||||
#<SPTAG-TEST>
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <knowhere/index/vector_index/IndexRHNSWFlat.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "unittest/utils.h"
|
||||
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class RHNSWFlatTest : public DataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
IndexType = GetParam();
|
||||
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
|
||||
Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10
|
||||
index_ = std::make_shared<milvus::knowhere::IndexRHNSWFlat>();
|
||||
conf = milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10},
|
||||
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
|
||||
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
};
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::Config conf;
|
||||
std::shared_ptr<milvus::knowhere::IndexRHNSWFlat> index_ = nullptr;
|
||||
std::string IndexType;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWFlatTest, Values("RHNSWFlat"));
|
||||
|
||||
TEST_P(RHNSWFlatTest, HNSW_basic) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result1, nq, k);
|
||||
|
||||
// Serialize and Load before Query
|
||||
milvus::knowhere::BinarySet bs = index_->Serialize();
|
||||
|
||||
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWFlat>();
|
||||
|
||||
tmp_index->Load(bs);
|
||||
|
||||
auto result2 = tmp_index->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k);
|
||||
}
|
||||
|
||||
TEST_P(RHNSWFlatTest, HNSW_delete) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
bitset->set(i);
|
||||
}
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result1, nq, k);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
|
||||
/*
|
||||
* delete result checked by eyes
|
||||
auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
std::cout << std::endl;
|
||||
for (int i = 0; i < nq; ++ i) {
|
||||
std::cout << "ids1: ";
|
||||
for (int j = 0; j < k; ++ j) {
|
||||
std::cout << *(ids1 + i * k + j) << " ";
|
||||
}
|
||||
std::cout << "ids2: ";
|
||||
for (int j = 0; j < k; ++ j) {
|
||||
std::cout << *(ids2 + i * k + j) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (int j = 0; j < std::min(5, k>>1); ++ j) {
|
||||
ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
TEST_P(RHNSWFlatTest, HNSW_serialize) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
{
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
}
|
||||
|
||||
FileIOReader reader(filename);
|
||||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
{
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
auto binaryset = index_->Serialize();
|
||||
std::string index_type = index_->index_type();
|
||||
std::string idx_name = index_type + "_Index";
|
||||
std::string dat_name = index_type + "_Data";
|
||||
if (binaryset.binary_map_.find(idx_name) == binaryset.binary_map_.end()) {
|
||||
std::cout << "no idx!" << std::endl;
|
||||
}
|
||||
if (binaryset.binary_map_.find(dat_name) == binaryset.binary_map_.end()) {
|
||||
std::cout << "no dat!" << std::endl;
|
||||
}
|
||||
auto bin_idx = binaryset.GetByName(idx_name);
|
||||
auto bin_dat = binaryset.GetByName(dat_name);
|
||||
|
||||
std::string filename_idx = "/tmp/RHNSWFlat_test_serialize_idx.bin";
|
||||
std::string filename_dat = "/tmp/RHNSWFlat_test_serialize_dat.bin";
|
||||
auto load_idx = new uint8_t[bin_idx->size];
|
||||
auto load_dat = new uint8_t[bin_dat->size];
|
||||
serialize(filename_idx, bin_idx, load_idx);
|
||||
serialize(filename_dat, bin_dat, load_dat);
|
||||
|
||||
binaryset.clear();
|
||||
auto new_idx = std::make_shared<milvus::knowhere::IndexRHNSWFlat>();
|
||||
std::shared_ptr<uint8_t[]> dat(load_dat);
|
||||
std::shared_ptr<uint8_t[]> idx(load_idx);
|
||||
binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size);
|
||||
binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size);
|
||||
|
||||
new_idx->Load(binaryset);
|
||||
EXPECT_EQ(new_idx->Count(), nb);
|
||||
EXPECT_EQ(new_idx->Dim(), dim);
|
||||
auto result = new_idx->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,148 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <knowhere/index/vector_index/IndexRHNSWPQ.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "unittest/utils.h"
|
||||
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class RHNSWPQTest : public DataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
IndexType = GetParam();
|
||||
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
|
||||
Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10
|
||||
index_ = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
|
||||
conf = milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10},
|
||||
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
|
||||
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{milvus::knowhere::IndexParams::PQM, 8}};
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::Config conf;
|
||||
std::shared_ptr<milvus::knowhere::IndexRHNSWPQ> index_ = nullptr;
|
||||
std::string IndexType;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWPQTest, Values("RHNSWPQ"));
|
||||
|
||||
TEST_P(RHNSWPQTest, HNSW_basic) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
// Serialize and Load before Query
|
||||
milvus::knowhere::BinarySet bs = index_->Serialize();
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
// AssertAnns(result1, nq, k);
|
||||
|
||||
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
|
||||
|
||||
tmp_index->Load(bs);
|
||||
|
||||
auto result2 = tmp_index->Query(query_dataset, conf);
|
||||
// AssertAnns(result2, nq, k);
|
||||
}
|
||||
|
||||
TEST_P(RHNSWPQTest, HNSW_delete) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
bitset->set(i);
|
||||
}
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
// AssertAnns(result1, nq, k);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
|
||||
/*
|
||||
* delete result checked by eyes
|
||||
auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
std::cout << std::endl;
|
||||
for (int i = 0; i < nq; ++ i) {
|
||||
std::cout << "ids1: ";
|
||||
for (int j = 0; j < k; ++ j) {
|
||||
std::cout << *(ids1 + i * k + j) << " ";
|
||||
}
|
||||
std::cout << "ids2: ";
|
||||
for (int j = 0; j < k; ++ j) {
|
||||
std::cout << *(ids2 + i * k + j) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (int j = 0; j < std::min(5, k>>1); ++ j) {
|
||||
ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
TEST_P(RHNSWPQTest, HNSW_serialize) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
{
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
}
|
||||
|
||||
FileIOReader reader(filename);
|
||||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
{
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
auto binaryset = index_->Serialize();
|
||||
auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index");
|
||||
auto bin_dat = binaryset.GetByName(index_->index_type() + "_Data");
|
||||
|
||||
std::string filename_idx = "/tmp/RHNSWPQ_test_serialize_idx.bin";
|
||||
std::string filename_dat = "/tmp/RHNSWPQ_test_serialize_dat.bin";
|
||||
auto load_idx = new uint8_t[bin_idx->size];
|
||||
auto load_dat = new uint8_t[bin_dat->size];
|
||||
serialize(filename_idx, bin_idx, load_idx);
|
||||
serialize(filename_dat, bin_dat, load_dat);
|
||||
|
||||
binaryset.clear();
|
||||
auto new_idx = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
|
||||
std::shared_ptr<uint8_t[]> dat(load_dat);
|
||||
std::shared_ptr<uint8_t[]> idx(load_idx);
|
||||
binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size);
|
||||
binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size);
|
||||
|
||||
new_idx->Load(binaryset);
|
||||
EXPECT_EQ(new_idx->Count(), nb);
|
||||
EXPECT_EQ(new_idx->Dim(), dim);
|
||||
auto result = new_idx->Query(query_dataset, conf);
|
||||
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <knowhere/index/vector_index/IndexRHNSWSQ.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "unittest/utils.h"
|
||||
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class RHNSWSQ8Test : public DataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
IndexType = GetParam();
|
||||
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
|
||||
Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10
|
||||
// Generate(2, 10, 2); // dim = 64, nb = 10000, nq = 10
|
||||
index_ = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
|
||||
conf = milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10},
|
||||
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
|
||||
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
};
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::Config conf;
|
||||
std::shared_ptr<milvus::knowhere::IndexRHNSWSQ> index_ = nullptr;
|
||||
std::string IndexType;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWSQ8Test, Values("RHNSWSQ8"));
|
||||
|
||||
TEST_P(RHNSWSQ8Test, HNSW_basic) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
// Serialize and Load before Query
|
||||
milvus::knowhere::BinarySet bs = index_->Serialize();
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result1, nq, k);
|
||||
|
||||
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
|
||||
|
||||
tmp_index->Load(bs);
|
||||
|
||||
auto result2 = tmp_index->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k);
|
||||
}
|
||||
|
||||
TEST_P(RHNSWSQ8Test, HNSW_delete) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
bitset->set(i);
|
||||
}
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result1, nq, k);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
|
||||
/*
|
||||
* delete result checked by eyes
|
||||
auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
std::cout << std::endl;
|
||||
for (int i = 0; i < nq; ++ i) {
|
||||
std::cout << "ids1: ";
|
||||
for (int j = 0; j < k; ++ j) {
|
||||
std::cout << *(ids1 + i * k + j) << " ";
|
||||
}
|
||||
std::cout << "ids2: ";
|
||||
for (int j = 0; j < k; ++ j) {
|
||||
std::cout << *(ids2 + i * k + j) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (int j = 0; j < std::min(5, k>>1); ++ j) {
|
||||
ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
TEST_P(RHNSWSQ8Test, HNSW_serialize) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
{
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
}
|
||||
|
||||
FileIOReader reader(filename);
|
||||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
{
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
auto binaryset = index_->Serialize();
|
||||
auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index");
|
||||
auto bin_dat = binaryset.GetByName(index_->index_type() + "_Data");
|
||||
|
||||
std::string filename_idx = "/tmp/RHNSWSQ_test_serialize_idx.bin";
|
||||
std::string filename_dat = "/tmp/RHNSWSQ_test_serialize_dat.bin";
|
||||
auto load_idx = new uint8_t[bin_idx->size];
|
||||
auto load_dat = new uint8_t[bin_dat->size];
|
||||
serialize(filename_idx, bin_idx, load_idx);
|
||||
serialize(filename_dat, bin_dat, load_dat);
|
||||
|
||||
binaryset.clear();
|
||||
auto new_idx = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
|
||||
std::shared_ptr<uint8_t[]> dat(load_dat);
|
||||
std::shared_ptr<uint8_t[]> idx(load_idx);
|
||||
binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size);
|
||||
binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size);
|
||||
|
||||
new_idx->Load(binaryset);
|
||||
EXPECT_EQ(new_idx->Count(), nb);
|
||||
EXPECT_EQ(new_idx->Dim(), dim);
|
||||
auto result = new_idx->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
}
|
||||
}
|
|
@ -279,7 +279,9 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
|
|||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) {
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_HNSW_SQ8NM ||
|
||||
index_type == knowhere::IndexEnum::INDEX_RHNSWPQ || index_type == knowhere::IndexEnum::INDEX_RHNSWSQ ||
|
||||
index_type == knowhere::IndexEnum::INDEX_RHNSWFlat) {
|
||||
auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
|
@ -288,6 +290,38 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
|
|||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
if (index_type == knowhere::IndexEnum::INDEX_RHNSWPQ) {
|
||||
status = CheckParameterExistence(index_params, knowhere::IndexParams::PQM);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// special check for 'PQM' parameter
|
||||
std::vector<int64_t> resset;
|
||||
milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset);
|
||||
int64_t pqm_value = index_params[knowhere::IndexParams::PQM];
|
||||
if (resset.empty()) {
|
||||
std::string msg = "Invalid collection dimension, unable to get reasonable values for 'PQM'";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_COLLECTION_DIMENSION, msg);
|
||||
}
|
||||
|
||||
auto iter = std::find(std::begin(resset), std::end(resset), pqm_value);
|
||||
if (iter == std::end(resset)) {
|
||||
std::string msg =
|
||||
"Invalid " + std::string(knowhere::IndexParams::PQM) + ", must be one of the following values: ";
|
||||
for (size_t i = 0; i < resset.size(); i++) {
|
||||
if (i != 0) {
|
||||
msg += ",";
|
||||
}
|
||||
msg += std::to_string(resset[i]);
|
||||
}
|
||||
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
}
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
|
||||
auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024);
|
||||
if (!status.ok()) {
|
||||
|
|
|
@ -25,6 +25,9 @@ const char* NAME_ENGINE_TYPE_HNSW = "HNSW";
|
|||
const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY";
|
||||
const char* NAME_ENGINE_TYPE_IVFSQ8NR = "IVFSQ8NR";
|
||||
const char* NAME_ENGINE_TYPE_HNSWSQ8NM = "HNSWSQ8NM";
|
||||
const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSWFLAT";
|
||||
const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSWPQ";
|
||||
const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSWSQ8";
|
||||
|
||||
const char* NAME_METRIC_TYPE_L2 = "L2";
|
||||
const char* NAME_METRIC_TYPE_IP = "IP";
|
||||
|
|
|
@ -28,6 +28,9 @@ extern const char* NAME_ENGINE_TYPE_IVFPQ;
|
|||
extern const char* NAME_ENGINE_TYPE_HNSW;
|
||||
extern const char* NAME_ENGINE_TYPE_HNSW_SQ8NM;
|
||||
extern const char* NAME_ENGINE_TYPE_ANNOY;
|
||||
extern const char* NAME_ENGINE_TYPE_RHNSWFLAT;
|
||||
extern const char* NAME_ENGINE_TYPE_RHNSWPQ;
|
||||
extern const char* NAME_ENGINE_TYPE_RHNSWSQ;
|
||||
|
||||
extern const char* NAME_METRIC_TYPE_L2;
|
||||
extern const char* NAME_METRIC_TYPE_IP;
|
||||
|
|
Loading…
Reference in New Issue