diff --git a/CHANGELOG.md b/CHANGELOG.md index 17e90dd346..25c315ce3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,7 @@ Please mark all changes in change log and use the issue from GitHub - \#2802 Add new index: IVFSQ8NR - \#2834 Add C++ sdk support 4 hnsw_sq8nr - \#2940 Add option to build.sh for cuda arch +- \#3132 Refine the implementation of hnsw in faiss and add support for hnsw-flat, hnsw-pq and hnsw-sq8 based on faiss ## Improvement - \#2543 Remove secondary_path related code diff --git a/core/src/index/knowhere/CMakeLists.txt b/core/src/index/knowhere/CMakeLists.txt index 172f9ee078..c2c31ed7e7 100644 --- a/core/src/index/knowhere/CMakeLists.txt +++ b/core/src/index/knowhere/CMakeLists.txt @@ -63,6 +63,10 @@ set(vector_index_srcs knowhere/index/IndexType.cpp knowhere/index/vector_index/VecIndexFactory.cpp knowhere/index/vector_index/IndexAnnoy.cpp + knowhere/index/vector_index/IndexRHNSW.cpp + knowhere/index/vector_index/IndexRHNSWFlat.cpp + knowhere/index/vector_index/IndexRHNSWSQ.cpp + knowhere/index/vector_index/IndexRHNSWPQ.cpp ) set(vector_offset_index_srcs diff --git a/core/src/index/knowhere/knowhere/index/IndexType.cpp b/core/src/index/knowhere/knowhere/index/IndexType.cpp index 8eac2ae839..771729c2aa 100644 --- a/core/src/index/knowhere/knowhere/index/IndexType.cpp +++ b/core/src/index/knowhere/knowhere/index/IndexType.cpp @@ -34,6 +34,9 @@ const char* INDEX_SPTAG_KDT_RNT = "SPTAG_KDT_RNT"; const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT"; #endif const char* INDEX_HNSW = "HNSW"; +const char* INDEX_RHNSWFlat = "RHNSW_FLAT"; +const char* INDEX_RHNSWPQ = "RHNSW_PQ"; +const char* INDEX_RHNSWSQ = "RHNSW_SQ"; const char* INDEX_ANNOY = "ANNOY"; const char* INDEX_HNSW_SQ8NM = "HNSW_SQ8NM"; } // namespace IndexEnum diff --git a/core/src/index/knowhere/knowhere/index/IndexType.h b/core/src/index/knowhere/knowhere/index/IndexType.h index b5fd1b9186..58c9575c3c 100644 --- a/core/src/index/knowhere/knowhere/index/IndexType.h +++ b/core/src/index/knowhere/knowhere/index/IndexType.h @@ -37,6 +37,9 @@ enum class OldIndexType { ANNOY, FAISS_IVFSQ8NR, HNSW_SQ8NM, + RHNSW_FLAT, + RHNSW_PQ, + RHNSW_SQ, FAISS_BIN_IDMAP = 100, FAISS_BIN_IVFLAT_CPU = 101, }; @@ -60,6 +63,9 @@ extern const char* INDEX_SPTAG_KDT_RNT; extern const char* INDEX_SPTAG_BKT_RNT; #endif extern const char* INDEX_HNSW; +extern const char* INDEX_RHNSWFlat; +extern const char* INDEX_RHNSWPQ; +extern const char* INDEX_RHNSWSQ; extern const char* INDEX_ANNOY; extern const char* INDEX_HNSW_SQ8NM; } // namespace IndexEnum diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp index 8fd99bda43..ddafc486d4 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp @@ -280,6 +280,80 @@ HNSWSQ8NRConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const In return ConfAdapter::CheckSearch(oricfg, type, mode); } +bool +RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + std::vector resset; + int64_t dimension = oricfg[knowhere::meta::DIM].get(); + IVFPQConfAdapter::GetValidMList(dimension, resset); + + CheckIntByValues(knowhere::IndexParams::m, resset); + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + bool BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { static std::vector METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD, diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h index 7ac1454292..45d35d3503 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h @@ -109,5 +109,31 @@ class IVFSQ8NRConfAdapter : public IVFConfAdapter { CheckTrain(Config& oricfg, const IndexMode mode) override; }; +class RHNSWFlatConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class RHNSWPQConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class RHNSWSQConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp index 21abe6dc18..cbe7f20fef 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp @@ -51,6 +51,9 @@ AdapterMgr::RegisterAdapter() { REGISTER_CONF_ADAPTER(ANNOYConfAdapter, IndexEnum::INDEX_ANNOY, annoy_adapter); REGISTER_CONF_ADAPTER(HNSWSQ8NRConfAdapter, IndexEnum::INDEX_HNSW_SQ8NM, hnswsq8nr_adapter); REGISTER_CONF_ADAPTER(IVFSQ8NRConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8NR, ivfsq8nr_adapter); + REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter); + REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter); + REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter); } } // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp index 9969336648..2fdfe9fee3 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp @@ -53,5 +53,11 @@ FaissBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) { SealImpl(); } +void +FaissBaseIndex::SealImpl() { +} + +// FaissBaseIndex::~FaissBaseIndex() {} +// } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h index 53a9c3a307..70604ab74f 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h @@ -34,8 +34,7 @@ class FaissBaseIndex { LoadImpl(const BinarySet&, const IndexType& type); virtual void - SealImpl() { /* do nothing */ - } + SealImpl(); public: std::shared_ptr index_ = nullptr; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp new file mode 100644 index 0000000000..6508a34217 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSW.h" + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +BinarySet +IndexRHNSW::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + writer.name = this->index_type() + "_Index"; + faiss::write_index(index_.get(), &writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSW::Load(const BinarySet& index_binary) { + try { + MemoryIOReader reader; + reader.name = this->index_type() + "_Index"; + auto binary = index_binary.GetByName(reader.name); + + reader.total = (size_t)binary->size; + reader.data_ = binary->data.get(); + + auto idx = faiss::read_index(&reader); + index_.reset(idx); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) { + KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of Train, please use IndexRHNSW(Flat/SQ/PQ) instead!"); +} + +void +IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + GET_TENSOR_DATA(dataset_ptr) + + index_->add(rows, (float*)p_data); +} + +DatasetPtr +IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + GET_TENSOR_DATA(dataset_ptr) + + size_t k = config[meta::TOPK].get(); + size_t id_size = sizeof(int64_t) * k; + size_t dist_size = sizeof(float) * k; + auto p_id = (int64_t*)malloc(id_size * rows); + auto p_dist = (float*)malloc(dist_size * rows); + for (auto i = 0; i < k * rows; ++i) { + p_id[i] = -1; + p_dist[i] = -1; + } + + auto real_index = dynamic_cast(index_.get()); + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); + + real_index->hnsw.efSearch = (config[IndexParams::ef]); + real_index->search(rows, (float*)p_data, k, p_dist, p_id, blacklist); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +int64_t +IndexRHNSW::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IndexRHNSW::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +IndexRHNSW::UpdateIndexSize() { + KNOWHERE_THROW_MSG( + "IndexRHNSW has no implementation of UpdateIndexSize, please use IndexRHNSW(Flat/SQ/PQ) instead!"); +} + +/* +BinarySet +IndexRHNSW::SerializeImpl(const milvus::knowhere::IndexType &type) { return BinarySet(); } + +void +IndexRHNSW::SealImpl() {} + +void +IndexRHNSW::LoadImpl(const milvus::knowhere::BinarySet &, const milvus::knowhere::IndexType &type) {} +*/ + +void +IndexRHNSW::AddWithoutIds(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config) { + KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of AddWithoutIds, please use IndexRHNSW(Flat/SQ/PQ) instead!"); +} +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h new file mode 100644 index 0000000000..7c5a4a6eaf --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h @@ -0,0 +1,67 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +#include +#include "faiss/IndexRHNSW.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSW : public VecIndex, public FaissBaseIndex { + public: + IndexRHNSW() : FaissBaseIndex(nullptr) { + index_type_ = IndexEnum::INVALID; + } + + explicit IndexRHNSW(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { + index_type_ = IndexEnum::INVALID; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; +}; +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp new file mode 100644 index 0000000000..90eb0459a0 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSWFlat.h" + +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, milvus::knowhere::MetricType metric) { + faiss::MetricType mt = + metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT; + index_ = std::shared_ptr(new faiss::IndexRHNSWFlat(d, M, mt)); +} + +BinarySet +IndexRHNSWFlat::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + auto res_set = IndexRHNSW::Serialize(config); + MemoryIOWriter writer; + writer.name = this->index_type() + "_Data"; + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + } + auto storage_index = dynamic_cast(real_idx->storage); + faiss::write_index(storage_index, &writer); + std::shared_ptr data(writer.data_); + + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWFlat::Load(const BinarySet& index_binary) { + try { + IndexRHNSW::Load(index_binary); + MemoryIOReader reader; + reader.name = this->index_type() + "_Data"; + auto binary = index_binary.GetByName(reader.name); + + reader.total = (size_t)binary->size; + reader.data_ = binary->data.get(); + + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); + } + real_idx->storage = faiss::read_index(&reader); + real_idx->init_hnsw(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWFlat::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + + auto idx = new faiss::IndexRHNSWFlat(int(dim), config[IndexParams::M], metric_type); + idx->hnsw.efConstruction = config[IndexParams::efConstruction]; + index_ = std::shared_ptr(idx); + index_->train(rows, (float*)p_data); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWFlat::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = dynamic_cast(index_.get())->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.h new file mode 100644 index 0000000000..ff68fdbcfb --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "IndexRHNSW.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSWFlat : public IndexRHNSW { + public: + IndexRHNSWFlat() : IndexRHNSW() { + index_type_ = IndexEnum::INDEX_RHNSWFlat; + } + + explicit IndexRHNSWFlat(std::shared_ptr index) : IndexRHNSW(std::move(index)) { + index_type_ = IndexEnum::INDEX_RHNSWFlat; + } + + IndexRHNSWFlat(int d, int M, MetricType metric = Metric::L2); + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + UpdateIndexSize() override; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp new file mode 100644 index 0000000000..862ae5d2e9 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp @@ -0,0 +1,102 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSWPQ.h" + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M) { + index_ = std::shared_ptr(new faiss::IndexRHNSWPQ(d, pq_m, M)); +} + +BinarySet +IndexRHNSWPQ::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + auto res_set = IndexRHNSW::Serialize(config); + MemoryIOWriter writer; + writer.name = this->index_type() + "_Data"; + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + } + faiss::write_index(real_idx->storage, &writer); + std::shared_ptr data(writer.data_); + + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWPQ::Load(const BinarySet& index_binary) { + try { + IndexRHNSW::Load(index_binary); + MemoryIOReader reader; + reader.name = this->index_type() + "_Data"; + auto binary = index_binary.GetByName(reader.name); + + reader.total = (size_t)binary->size; + reader.data_ = binary->data.get(); + + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); + } + real_idx->storage = faiss::read_index(&reader); + real_idx->init_hnsw(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + + auto idx = new faiss::IndexRHNSWPQ(int(dim), config[IndexParams::PQM], config[IndexParams::M]); + idx->hnsw.efConstruction = config[IndexParams::efConstruction]; + index_ = std::shared_ptr(idx); + index_->train(rows, (float*)p_data); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWPQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = dynamic_cast(index_.get())->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.h new file mode 100644 index 0000000000..0c0e5199b9 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.h @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "IndexRHNSW.h" +#include "knowhere/common/Exception.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSWPQ : public IndexRHNSW { + public: + IndexRHNSWPQ() : IndexRHNSW() { + index_type_ = IndexEnum::INDEX_RHNSWPQ; + } + + explicit IndexRHNSWPQ(std::shared_ptr index) : IndexRHNSW(std::move(index)) { + index_type_ = IndexEnum::INDEX_RHNSWPQ; + } + + IndexRHNSWPQ(int d, int pq_m, int M); + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + UpdateIndexSize() override; + + private: +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp new file mode 100644 index 0000000000..785015a227 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSWSQ.h" + +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +IndexRHNSWSQ::IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, milvus::knowhere::MetricType metric) { + faiss::MetricType mt = + metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT; + index_ = std::shared_ptr(new faiss::IndexRHNSWSQ(d, qtype, M, mt)); +} + +BinarySet +IndexRHNSWSQ::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + auto res_set = IndexRHNSW::Serialize(config); + MemoryIOWriter writer; + writer.name = this->index_type() + "_Data"; + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + } + faiss::write_index(real_idx->storage, &writer); + std::shared_ptr data(writer.data_); + + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWSQ::Load(const BinarySet& index_binary) { + try { + IndexRHNSW::Load(index_binary); + MemoryIOReader reader; + reader.name = this->index_type() + "_Data"; + auto binary = index_binary.GetByName(reader.name); + + reader.total = (size_t)binary->size; + reader.data_ = binary->data.get(); + + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); + } + real_idx->storage = faiss::read_index(&reader); + real_idx->init_hnsw(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + + auto idx = + new faiss::IndexRHNSWSQ(int(dim), faiss::QuantizerType::QT_8bit, config[IndexParams::M], metric_type); + idx->hnsw.efConstruction = config[IndexParams::efConstruction]; + index_ = std::shared_ptr(idx); + index_->train(rows, (float*)p_data); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWSQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = dynamic_cast(index_.get())->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.h new file mode 100644 index 0000000000..18410e3e42 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.h @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "IndexRHNSW.h" +#include "knowhere/common/Exception.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSWSQ : public IndexRHNSW { + public: + IndexRHNSWSQ() : IndexRHNSW() { + index_type_ = IndexEnum::INDEX_RHNSWSQ; + } + + explicit IndexRHNSWSQ(std::shared_ptr index) : IndexRHNSW(std::move(index)) { + index_type_ = IndexEnum::INDEX_RHNSWSQ; + } + + IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, MetricType metric = Metric::L2); + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + UpdateIndexSize() override; + + private: +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp index aead15349a..73d087e785 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp @@ -20,11 +20,15 @@ #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexRHNSWFlat.h" +#include "knowhere/index/vector_index/IndexRHNSWPQ.h" +#include "knowhere/index/vector_index/IndexRHNSWSQ.h" #include "knowhere/index/vector_offset_index/IndexHNSW_NM.h" #include "knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h" #include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" #include "knowhere/index/vector_offset_index/IndexIVF_NM.h" #include "knowhere/index/vector_offset_index/IndexNSG_NM.h" + #ifdef MILVUS_SUPPORT_SPTAG #include "knowhere/index/vector_index/IndexSPTAG.h" #endif @@ -94,6 +98,12 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) { return std::make_shared(); } else if (type == IndexEnum::INDEX_HNSW_SQ8NM) { return std::make_shared(); + } else if (type == IndexEnum::INDEX_RHNSWFlat) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_RHNSWPQ) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_RHNSWSQ) { + return std::make_shared(); } else { return nullptr; } diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h index b37988d881..69c84f5279 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h @@ -48,6 +48,9 @@ constexpr const char* ef = "ef"; // Annoy Params constexpr const char* n_trees = "n_trees"; constexpr const char* search_k = "search_k"; + +// PQ Params +constexpr const char* PQM = "PQM"; } // namespace IndexParams namespace Metric { diff --git a/core/src/index/thirdparty/faiss/Index2Layer.h b/core/src/index/thirdparty/faiss/Index2Layer.h index 7062ff3690..b7d8ccd1fa 100644 --- a/core/src/index/thirdparty/faiss/Index2Layer.h +++ b/core/src/index/thirdparty/faiss/Index2Layer.h @@ -80,6 +80,7 @@ struct Index2Layer: Index { void sa_encode (idx_t n, const float *x, uint8_t *bytes) const override; void sa_decode (idx_t n, const uint8_t *bytes, float *x) const override; + size_t cal_size() { return sizeof(*this) + codes.size() * sizeof(uint8_t) + pq.cal_size(); } }; diff --git a/core/src/index/thirdparty/faiss/IndexFlat.h b/core/src/index/thirdparty/faiss/IndexFlat.h index a04d32a614..13f8829f5f 100644 --- a/core/src/index/thirdparty/faiss/IndexFlat.h +++ b/core/src/index/thirdparty/faiss/IndexFlat.h @@ -85,6 +85,8 @@ struct IndexFlat: Index { void sa_decode (idx_t n, const uint8_t *bytes, float *x) const override; + size_t cal_size() { return xb.size() * sizeof(float); } + }; diff --git a/core/src/index/thirdparty/faiss/IndexPQ.h b/core/src/index/thirdparty/faiss/IndexPQ.h index 25a643efe2..97ce84f11d 100644 --- a/core/src/index/thirdparty/faiss/IndexPQ.h +++ b/core/src/index/thirdparty/faiss/IndexPQ.h @@ -124,6 +124,8 @@ struct IndexPQ: Index { void hamming_distance_table (idx_t n, const float *x, int32_t *dis) const; + size_t cal_size() { return codes.size() * sizeof(uint8_t) + pq.cal_size(); } + }; diff --git a/core/src/index/thirdparty/faiss/IndexRHNSW.cpp b/core/src/index/thirdparty/faiss/IndexRHNSW.cpp new file mode 100644 index 0000000000..bd112e596d --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexRHNSW.cpp @@ -0,0 +1,812 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef __SSE__ +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +} + +namespace faiss { + +using idx_t = Index::idx_t; +using MinimaxHeap = RHNSW::MinimaxHeap; +using storage_idx_t = RHNSW::storage_idx_t; +using NodeDistFarther = RHNSW::NodeDistFarther; + +RHNSWStats rhnsw_stats; + +/************************************************************** + * add / search blocks of descriptors + **************************************************************/ + +namespace { + + +/* Wrap the distance computer into one that negates the + distances. This makes supporting INNER_PRODUCE search easier */ + +struct NegativeDistanceComputer: DistanceComputer { + + /// owned by this + DistanceComputer *basedis; + + explicit NegativeDistanceComputer(DistanceComputer *basedis): + basedis(basedis) + {} + + void set_query(const float *x) override { + basedis->set_query(x); + } + + /// compute distance of vector i to current query + float operator () (idx_t i) override { + return -(*basedis)(i); + } + + /// compute distance between two stored vectors + float symmetric_dis (idx_t i, idx_t j) override { + return -basedis->symmetric_dis(i, j); + } + + virtual ~NegativeDistanceComputer () + { + delete basedis; + } + +}; + +DistanceComputer *storage_distance_computer(const Index *storage) +{ + if (storage->metric_type == METRIC_INNER_PRODUCT) { + return new NegativeDistanceComputer(storage->get_distance_computer()); + } else { + return storage->get_distance_computer(); + } +} + +void hnsw_add_vertices(IndexRHNSW &index_hnsw, + size_t n0, + size_t n, const float *x, + bool verbose, + bool preset_levels = false) { + size_t d = index_hnsw.d; + RHNSW & hnsw = index_hnsw.hnsw; + size_t ntotal = n0 + n; + double t0 = getmillisecs(); + if (verbose) { + printf("hnsw_add_vertices: adding %ld elements on top of %ld " + "(preset_levels=%d)\n", + n, n0, int(preset_levels)); + } + + if (n == 0) { + return; + } + + int max_level = hnsw.prepare_level_tab(n, preset_levels); + + if (verbose) { + printf(" max_level = %d\n", max_level); + } + + + { // perform add + auto tas = getmillisecs(); + RandomGenerator rng2(789); + DistanceComputer *dis0 = + storage_distance_computer (index_hnsw.storage); + ScopeDeleter1 del0(dis0); + + dis0->set_query(x); + hnsw.addPoint(*dis0, hnsw.levels[n0], n0); + +#pragma omp parallel for + for (int i = 1; i < n; ++ i) { + DistanceComputer *dis = + storage_distance_computer (index_hnsw.storage); + ScopeDeleter1 del(dis); + dis->set_query(x + i * d); + hnsw.addPoint(*dis, hnsw.levels[n0 + i], i + n0); + } + } + if (verbose) { + printf("Done in %.3f ms\n", getmillisecs() - t0); + } + +} + +} // namespace + + + + +/************************************************************** + * IndexRHNSW implementation + **************************************************************/ + +IndexRHNSW::IndexRHNSW(int d, int M, MetricType metric): + Index(d, metric), + hnsw(M), + own_fields(false), + storage(nullptr), + reconstruct_from_neighbors(nullptr) +{} + +IndexRHNSW::IndexRHNSW(Index *storage, int M): + Index(storage->d, storage->metric_type), + hnsw(M), + own_fields(false), + storage(storage), + reconstruct_from_neighbors(nullptr) +{} + +IndexRHNSW::~IndexRHNSW() { + if (own_fields) { + delete storage; + } +} + +void IndexRHNSW::init_hnsw() { + hnsw.init(ntotal); +} + +void IndexRHNSW::train(idx_t n, const float* x) +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly"); + // hnsw structure does not require training + storage->train (n, x); + is_trained = true; +} + +void IndexRHNSW::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const + +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly"); + size_t nreorder = 0; + + idx_t check_period = InterruptCallback::get_period_hint ( + hnsw.max_level * d * hnsw.efSearch); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel reduction(+ : nreorder) + { + + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + +#pragma omp for + for(idx_t i = i0; i < i1; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + dis->set_query(x + i * d); + + maxheap_heapify (k, simi, idxi); + + hnsw.searchKnn(*dis, k, idxi, simi, bitset); + + maxheap_reorder (k, simi, idxi); + + if (reconstruct_from_neighbors && + reconstruct_from_neighbors->k_reorder != 0) { + int k_reorder = reconstruct_from_neighbors->k_reorder; + if (k_reorder == -1 || k_reorder > k) k_reorder = k; + + nreorder += reconstruct_from_neighbors->compute_distances( + k_reorder, idxi, x + i * d, simi); + + // sort top k_reorder + maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder); + maxheap_reorder (k_reorder, simi, idxi); + } + + } + + } + InterruptCallback::check (); + } + + if (metric_type == METRIC_INNER_PRODUCT) { + // we need to revert the negated distances + for (size_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } + + rhnsw_stats.nreorder += nreorder; +} + + +void IndexRHNSW::add(idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly"); + FAISS_THROW_IF_NOT(is_trained); + int n0 = ntotal; + storage->add(n, x); + ntotal = storage->ntotal; + + hnsw_add_vertices (*this, n0, n, x, verbose, + hnsw.levels.size() == ntotal); +} + +void IndexRHNSW::reset() +{ + hnsw.reset(); + storage->reset(); + ntotal = 0; +} + +void IndexRHNSW::reconstruct (idx_t key, float* recons) const +{ + storage->reconstruct(key, recons); +} + +size_t IndexRHNSW::cal_size() { + return hnsw.cal_size(); +} + +/************************************************************** + * ReconstructFromNeighbors implementation + **************************************************************/ + +ReconstructFromNeighbors2::ReconstructFromNeighbors2( + const IndexRHNSW & index, size_t k, size_t nsq): + index(index), k(k), nsq(nsq) { + M = index.hnsw.M << 1; + FAISS_ASSERT(k <= 256); + code_size = k == 1 ? 0 : nsq; + ntotal = 0; + d = index.d; + FAISS_ASSERT(d % nsq == 0); + dsub = d / nsq; + k_reorder = -1; +} + +void ReconstructFromNeighbors2::reconstruct(storage_idx_t i, float *x, float *tmp) const +{ + + + const RHNSW & hnsw = index.hnsw; + int *cur_links = hnsw.get_neighbor_link(i, 0); + int *cur_neighbors = cur_links + 1; + auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links); + + if (k == 1 || nsq == 1) { + const float * beta; + if (k == 1) { + beta = codebook.data(); + } else { + int idx = codes[i]; + beta = codebook.data() + idx * (M + 1); + } + + float w0 = beta[0]; // weight of image itself + index.storage->reconstruct(i, tmp); + + for (int l = 0; l < d; l++) + x[l] = w0 * tmp[l]; + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + float w = beta[j + 1]; + index.storage->reconstruct(ji, tmp); + for (int l = 0; l < d; l++) + x[l] += w * tmp[l]; + } + } else if (nsq == 2) { + int idx0 = codes[2 * i]; + int idx1 = codes[2 * i + 1]; + + const float *beta0 = codebook.data() + idx0 * (M + 1); + const float *beta1 = codebook.data() + (idx1 + k) * (M + 1); + + index.storage->reconstruct(i, tmp); + + float w0; + + w0 = beta0[0]; + for (int l = 0; l < dsub; l++) + x[l] = w0 * tmp[l]; + + w0 = beta1[0]; + for (int l = dsub; l < d; l++) + x[l] = w0 * tmp[l]; + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + index.storage->reconstruct(ji, tmp); + float w; + w = beta0[j + 1]; + for (int l = 0; l < dsub; l++) + x[l] += w * tmp[l]; + + w = beta1[j + 1]; + for (int l = dsub; l < d; l++) + x[l] += w * tmp[l]; + } + } else { + const float *betas[nsq]; + { + const float *b = codebook.data(); + const uint8_t *c = &codes[i * code_size]; + for (int sq = 0; sq < nsq; sq++) { + betas[sq] = b + (*c++) * (M + 1); + b += (M + 1) * k; + } + } + + index.storage->reconstruct(i, tmp); + { + int d0 = 0; + for (int sq = 0; sq < nsq; sq++) { + float w = *(betas[sq]++); + int d1 = d0 + dsub; + for (int l = d0; l < d1; l++) { + x[l] = w * tmp[l]; + } + d0 = d1; + } + } + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + + index.storage->reconstruct(ji, tmp); + int d0 = 0; + for (int sq = 0; sq < nsq; sq++) { + float w = *(betas[sq]++); + int d1 = d0 + dsub; + for (int l = d0; l < d1; l++) { + x[l] += w * tmp[l]; + } + d0 = d1; + } + } + } +} + +void ReconstructFromNeighbors2::reconstruct_n(storage_idx_t n0, + storage_idx_t ni, + float *x) const +{ +#pragma omp parallel + { + std::vector tmp(index.d); +#pragma omp for + for (storage_idx_t i = 0; i < ni; i++) { + reconstruct(n0 + i, x + i * index.d, tmp.data()); + } + } +} + +size_t ReconstructFromNeighbors2::compute_distances( + size_t n, const idx_t *shortlist, + const float *query, float *distances) const +{ + std::vector tmp(2 * index.d); + size_t ncomp = 0; + for (int i = 0; i < n; i++) { + if (shortlist[i] < 0) break; + reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d); + distances[i] = fvec_L2sqr(query, tmp.data(), index.d); + ncomp++; + } + return ncomp; +} + +void ReconstructFromNeighbors2::get_neighbor_table(storage_idx_t i, float *tmp1) const +{ + const RHNSW & hnsw = index.hnsw; + int *cur_links = hnsw.get_neighbor_link(i, 0); + int *cur_neighbors = cur_links + 1; + auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links); + size_t d = index.d; + + index.storage->reconstruct(i, tmp1); + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + index.storage->reconstruct(ji, tmp1 + (j + 1) * d); + } + +} + + +/// called by add_codes +void ReconstructFromNeighbors2::estimate_code( + const float *x, storage_idx_t i, uint8_t *code) const +{ + + // fill in tmp table with the neighbor values + float *tmp1 = new float[d * (M + 1) + (d * k)]; + float *tmp2 = tmp1 + d * (M + 1); + ScopeDeleter del(tmp1); + + // collect coordinates of base + get_neighbor_table (i, tmp1); + + for (size_t sq = 0; sq < nsq; sq++) { + int d0 = sq * dsub; + + { + FINTEGER ki = k, di = d, m1 = M + 1; + FINTEGER dsubi = dsub; + float zero = 0, one = 1; + + sgemm_ ("N", "N", &dsubi, &ki, &m1, &one, + tmp1 + d0, &di, + codebook.data() + sq * (m1 * k), &m1, + &zero, tmp2, &dsubi); + } + + float min = HUGE_VAL; + int argmin = -1; + for (size_t j = 0; j < k; j++) { + float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub); + if (dis < min) { + min = dis; + argmin = j; + } + } + code[sq] = argmin; + } + +} + +void ReconstructFromNeighbors2::add_codes(size_t n, const float *x) +{ + if (k == 1) { // nothing to encode + ntotal += n; + return; + } + codes.resize(codes.size() + code_size * n); +#pragma omp parallel for + for (int i = 0; i < n; i++) { + estimate_code(x + i * index.d, ntotal + i, + codes.data() + (ntotal + i) * code_size); + } + ntotal += n; + FAISS_ASSERT (codes.size() == ntotal * code_size); +} + + +/************************************************************** + * IndexRHNSWFlat implementation + **************************************************************/ + + +IndexRHNSWFlat::IndexRHNSWFlat() +{ + is_trained = true; +} + +IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, MetricType metric): + IndexRHNSW(new IndexFlat(d, metric), M) +{ + own_fields = true; + is_trained = true; +} + +size_t IndexRHNSWFlat::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +/************************************************************** + * IndexRHNSWPQ implementation + **************************************************************/ + + +IndexRHNSWPQ::IndexRHNSWPQ() {} + +IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M): + IndexRHNSW(new IndexPQ(d, pq_m, 8), M) +{ + own_fields = true; + is_trained = false; +} + +void IndexRHNSWPQ::train(idx_t n, const float* x) +{ + IndexRHNSW::train (n, x); + (dynamic_cast (storage))->pq.compute_sdc_table(); +} + +size_t IndexRHNSWPQ::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +/************************************************************** + * IndexRHNSWSQ implementation + **************************************************************/ + + +IndexRHNSWSQ::IndexRHNSWSQ(int d, QuantizerType qtype, int M, + MetricType metric): + IndexRHNSW (new IndexScalarQuantizer (d, qtype, metric), M) +{ + is_trained = false; + own_fields = true; +} + +IndexRHNSWSQ::IndexRHNSWSQ() {} + +size_t IndexRHNSWSQ::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +/************************************************************** + * IndexRHNSW2Level implementation + **************************************************************/ + + +IndexRHNSW2Level::IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M): + IndexRHNSW (new Index2Layer (quantizer, nlist, m_pq), M) +{ + own_fields = true; + is_trained = false; +} + +IndexRHNSW2Level::IndexRHNSW2Level() {} + +size_t IndexRHNSW2Level::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +namespace { + + +// same as search_from_candidates but uses v +// visno -> is in result list +// visno + 1 -> in result list + in candidates +int search_from_candidates_2(const RHNSW & hnsw, + DistanceComputer & qdis, int k, + idx_t *I, float * D, + MinimaxHeap &candidates, + VisitedList &vt, + int level, int nres_in = 0) +{ + int nres = nres_in; + int ndis = 0; + for (int i = 0; i < candidates.size(); i++) { + idx_t v1 = candidates.ids[i]; + FAISS_ASSERT(v1 >= 0); + vt.mass[v1] = vt.curV + 1; + } + + int nstep = 0; + + while (candidates.size() > 0) { + float d0 = 0; + int v0 = candidates.pop_min(&d0); + + int *cur_links = hnsw.get_neighbor_link(v0, level); + int *cur_neighbors = cur_links + 1; + auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links); + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + int v1 = cur_neighbors[j]; + if (v1 < 0) break; + if (vt.mass[v1] == vt.curV + 1) { + // nothing to do + } else { + ndis++; + float d = qdis(v1); + candidates.push(v1, d); + + // never seen before --> add to heap + if (vt.mass[v1] < vt.curV) { + if (nres < k) { + faiss::maxheap_push (++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_pop (nres--, D, I); + faiss::maxheap_push (++nres, D, I, d, v1); + } + } + vt.mass[v1] = vt.curV + 1; + } + } + + nstep++; + if (nstep > hnsw.efSearch) { + break; + } + } + + if (level == 0) { +#pragma omp critical + { + rhnsw_stats.n1 ++; + if (candidates.size() == 0) + rhnsw_stats.n2 ++; + } + } + + + return nres; +} + + +} // namespace + +void IndexRHNSW2Level::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const +{ + if (dynamic_cast(storage)) { + IndexRHNSW::search (n, x, k, distances, labels); + + } else { // "mixed" search + + const IndexIVFPQ *index_ivfpq = + dynamic_cast(storage); + + int nprobe = index_ivfpq->nprobe; + + std::unique_ptr coarse_assign(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(), + coarse_assign.get()); + + index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(), + coarse_dis.get(), distances, labels, + false); + +#pragma omp parallel + { + VisitedList vt (ntotal); + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + + int candidates_size = hnsw.upper_beam; + MinimaxHeap candidates(candidates_size); + +#pragma omp for + for(idx_t i = 0; i < n; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + dis->set_query(x + i * d); + + // mark all inverted list elements as visited + + for (int j = 0; j < nprobe; j++) { + idx_t key = coarse_assign[j + i * nprobe]; + if (key < 0) break; + size_t list_length = index_ivfpq->get_list_size (key); + const idx_t * ids = index_ivfpq->invlists->get_ids (key); + + for (int jj = 0; jj < list_length; jj++) { + vt.set (ids[jj]); + } + } + + candidates.clear(); + // copy the upper_beam elements to candidates list + + int search_policy = 2; + + if (search_policy == 1) { + + for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) { + if (idxi[j] < 0) break; + candidates.push (idxi[j], simi[j]); + // search_from_candidates adds them back + idxi[j] = -1; + simi[j] = HUGE_VAL; + } + + // reorder from sorted to heap + maxheap_heapify (k, simi, idxi, simi, idxi, k); + + // removed from RHNSW, but still available in HNSW +// hnsw.search_from_candidates( +// *dis, k, idxi, simi, +// candidates, vt, 0, k +// ); + + vt.advance(); + + } else if (search_policy == 2) { + + for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) { + if (idxi[j] < 0) break; + candidates.push (idxi[j], simi[j]); + } + + // reorder from sorted to heap + maxheap_heapify (k, simi, idxi, simi, idxi, k); + + search_from_candidates_2 ( + hnsw, *dis, k, idxi, simi, + candidates, vt, 0, k); + vt.advance (); + vt.advance (); + + } + + maxheap_reorder (k, simi, idxi); + } + } + } + + +} + + +void IndexRHNSW2Level::flip_to_ivf () +{ + Index2Layer *storage2l = + dynamic_cast(storage); + + FAISS_THROW_IF_NOT (storage2l); + + IndexIVFPQ * index_ivfpq = + new IndexIVFPQ (storage2l->q1.quantizer, + d, storage2l->q1.nlist, + storage2l->pq.M, 8); + index_ivfpq->pq = storage2l->pq; + index_ivfpq->is_trained = storage2l->is_trained; + index_ivfpq->precompute_table(); + index_ivfpq->own_fields = storage2l->q1.own_fields; + storage2l->transfer_to_IVFPQ(*index_ivfpq); + index_ivfpq->make_direct_map (true); + + storage = index_ivfpq; + delete storage2l; + +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexRHNSW.h b/core/src/index/thirdparty/faiss/IndexRHNSW.h new file mode 100644 index 0000000000..f2641a076c --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexRHNSW.h @@ -0,0 +1,152 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include + +#include +#include +#include +#include +#include +//#include + + +namespace faiss { + +struct IndexRHNSW; + +struct ReconstructFromNeighbors2 { + typedef Index::idx_t idx_t; + typedef RHNSW::storage_idx_t storage_idx_t; + + const IndexRHNSW & index; + size_t M; // number of neighbors + size_t k; // number of codebook entries + size_t nsq; // number of subvectors + size_t code_size; + int k_reorder; // nb to reorder. -1 = all + + std::vector codebook; // size nsq * k * (M + 1) + + std::vector codes; // size ntotal * code_size + size_t ntotal; + size_t d, dsub; // derived values + + explicit ReconstructFromNeighbors2(const IndexRHNSW& index, + size_t k=256, size_t nsq=1); + + /// codes must be added in the correct order and the IndexRHNSW + /// must be populated and sorted + void add_codes(size_t n, const float *x); + + size_t compute_distances(size_t n, const idx_t *shortlist, + const float *query, float *distances) const; + + /// called by add_codes + void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const; + + /// called by compute_distances + void reconstruct(storage_idx_t i, float *x, float *tmp) const; + + void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const; + + /// get the M+1 -by-d table for neighbor coordinates for vector i + void get_neighbor_table(storage_idx_t i, float *out) const; + +}; + +/** The HNSW index is a normal random-access index with a HNSW + * link structure built on top */ + +struct IndexRHNSW : Index { + + typedef RHNSW::storage_idx_t storage_idx_t; + + // the link strcuture + RHNSW hnsw; + + // the sequential storage + bool own_fields; + Index *storage; + + ReconstructFromNeighbors2 *reconstruct_from_neighbors; + + explicit IndexRHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2); + explicit IndexRHNSW (Index *storage, int M = 32); + + ~IndexRHNSW() override; + + void add(idx_t n, const float *x) override; + + /// Trains the storage if needed + void train(idx_t n, const float* x) override; + + /// entry point for search + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, float* recons) const override; + + void reset () override; + + size_t cal_size(); + + void init_hnsw(); +}; + + +/** Flat index topped with with a HNSW structure to access elements + * more efficiently. + */ + +struct IndexRHNSWFlat : IndexRHNSW { + IndexRHNSWFlat(); + IndexRHNSWFlat(int d, int M, MetricType metric = METRIC_L2); + size_t cal_size(); +}; + +/** PQ index topped with with a HNSW structure to access elements + * more efficiently. + */ +struct IndexRHNSWPQ : IndexRHNSW { + IndexRHNSWPQ(); + IndexRHNSWPQ(int d, int pq_m, int M); + void train(idx_t n, const float* x) override; + size_t cal_size(); +}; + +/** SQ index topped with with a HNSW structure to access elements + * more efficiently. + */ +struct IndexRHNSWSQ : IndexRHNSW { + IndexRHNSWSQ(); + IndexRHNSWSQ(int d, QuantizerType qtype, int M, MetricType metric = METRIC_L2); + size_t cal_size(); +}; + +/** 2-level code structure with fast random access + */ +struct IndexRHNSW2Level : IndexRHNSW { + IndexRHNSW2Level(); + IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M); + + void flip_to_ivf(); + + /// entry point for search + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + size_t cal_size(); +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h index 03b9abf37d..4313a5b37e 100644 --- a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h +++ b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h @@ -78,6 +78,7 @@ struct IndexScalarQuantizer: Index { void sa_decode (idx_t n, const uint8_t *bytes, float *x) const override; + size_t cal_size() { return codes.size() * sizeof(uint8_t) + sizeof(size_t) + sq.cal_size(); } }; diff --git a/core/src/index/thirdparty/faiss/build.sh b/core/src/index/thirdparty/faiss/build.sh index ea6b4c0c7d..a58a6e6134 100755 --- a/core/src/index/thirdparty/faiss/build.sh +++ b/core/src/index/thirdparty/faiss/build.sh @@ -1,3 +1,3 @@ #./configure CPUFLAGS='-mavx -mf16c -msse4 -mpopcnt' CXXFLAGS='-O0 -g -fPIC -m64 -Wno-sign-compare -Wall -Wextra' --prefix=$PWD --with-cuda-arch=-gencode=arch=compute_75,code=sm_75 --with-cuda=/usr/local/cuda -./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75' +./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O3 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3 -DNDEBUG' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75' make install -j8 diff --git a/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h index c900d9c9d4..6364be4eae 100644 --- a/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h +++ b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h @@ -173,6 +173,7 @@ struct ProductQuantizer { float_maxheap_array_t * res, bool init_finalize_heap = true) const; + size_t cal_size() { return sizeof(*this) + centroids.size() * sizeof(float); } }; diff --git a/core/src/index/thirdparty/faiss/impl/RHNSW.cpp b/core/src/index/thirdparty/faiss/impl/RHNSW.cpp new file mode 100644 index 0000000000..4928910baa --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/RHNSW.cpp @@ -0,0 +1,441 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include + +namespace faiss { + + +/************************************************************** + * hnsw structure implementation + **************************************************************/ + +RHNSW::RHNSW(int M) : M(M), rng(12345) { + level_generator.seed(100); + max_level = -1; + entry_point = -1; + efSearch = 16; + efConstruction = 40; + upper_beam = 1; + level0_link_size = sizeof(int) * ((M << 1) | 1); + link_size = sizeof(int) * (M + 1); + level0_links = nullptr; + linkLists = nullptr; + level_constant = 1 / log(1.0 * M); + visited_list_pool = nullptr; +} + +void RHNSW::init(int ntotal) { + level_generator.seed(100); + if (visited_list_pool) delete visited_list_pool; + visited_list_pool = new VisitedListPool(1, ntotal); + std::vector(ntotal).swap(link_list_locks); +} + +RHNSW::~RHNSW() { + free(level0_links); + for (auto i = 0; i < levels.size(); ++ i) { + if (levels[i]) + free(linkLists[i]); + } + free(linkLists); + delete visited_list_pool; +} + +void RHNSW::reset() { + max_level = -1; + entry_point = -1; + levels.clear(); + free(level0_links); + for (auto i = 0; i < levels.size(); ++ i) { + if (levels[i]) + free(linkLists[i]); + } + free(linkLists); + level0_links = nullptr; + linkLists = nullptr; + level_constant = 1 / log(1.0 * M); +} + +int RHNSW::prepare_level_tab(size_t n, bool preset_levels) +{ + size_t n0 = levels.size(); + + std::vector level_stats(n); + if (preset_levels) { + FAISS_ASSERT (n0 + n == levels.size()); + } else { + FAISS_ASSERT (n0 == levels.size()); + for (int i = 0; i < n; i++) { + int pt_level = random_level(level_constant); + levels.push_back(pt_level); + } + } + + char *level0_links_new = (char*)malloc((n0 + n) * level0_link_size); + if (level0_links_new == nullptr) { + throw std::runtime_error("No enough memory 4 level0_links!"); + } + memset(level0_links_new, 0, (n0 + n) * level0_link_size); + if (level0_links) { + memcpy(level0_links_new, level0_links, n0 * level0_link_size); + free(level0_links); + } + level0_links = level0_links_new; + + char **linkLists_new = (char **)malloc(sizeof(void*) * (n0 + n)); + if (linkLists_new == nullptr) { + throw std::runtime_error("No enough memory 4 level0_links_new!"); + } + if (linkLists) { + memcpy(linkLists_new, linkLists, n0 * sizeof(void*)); + free(linkLists); + } + linkLists = linkLists_new; + + int max_level = 0; + int debug_space = 0; + for (int i = 0; i < n; i++) { + int pt_level = levels[i + n0]; + if (pt_level > max_level) max_level = pt_level; + if (pt_level) { + linkLists[n0 + i] = (char*) malloc(link_size * pt_level + 1); + if (linkLists[n0 + i] == nullptr) { + throw std::runtime_error("No enough memory 4 linkLists!"); + } + memset(linkLists[n0 + i], 0, link_size * pt_level + 1); + } + if (max_level >= level_stats.size()) { + level_stats.resize(max_level + 1); + } + level_stats[pt_level] ++; + } + +// printf("level stats:\n"); +// for (int i = 0; i <= max_level; ++ i) +// printf("level %d: %d points\n", i, level_stats[i]); +// printf("\n"); + std::vector(n0 + n).swap(link_list_locks); + if (visited_list_pool) delete visited_list_pool; + visited_list_pool = new VisitedListPool(1, n0 + n); + + return max_level; +} + + +/************************************************************** + * new implementation of hnsw ispired by hnswlib + * by cmli@zilliz July 30, 2020 + **************************************************************/ +using Node = faiss::RHNSW::Node; +using CompareByFirst = faiss::RHNSW::CompareByFirst; +void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) { + + std::unique_lock lock_el(link_list_locks[pt_id]); + std::unique_lock temp_lock(global); + int maxlevel_copy = max_level; + if (pt_level <= maxlevel_copy) + temp_lock.unlock(); + int currObj = entry_point; + int ep_copy = entry_point; + + if (currObj != -1) { + if (pt_level < maxlevel_copy) { + float curdist = ptdis(currObj); + for (int lev = maxlevel_copy; lev > pt_level; lev --) { + bool changed = true; + while (changed) { + changed = false; + std::unique_lock lk(link_list_locks[currObj]); + int *curObj_link = get_neighbor_link(currObj, lev); + auto curObj_nei_num = get_neighbors_num(curObj_link); + for (auto i = 1; i <= curObj_nei_num; ++ i) { + int cand = curObj_link[i]; + if (cand < 0 || cand > levels.size()) + throw std::runtime_error("cand error when addPoint"); + float d = ptdis(cand); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + for (int lev = std::min(pt_level, maxlevel_copy); lev >= 0; -- lev) { + if (lev > maxlevel_copy || lev < 0) + throw std::runtime_error("Level error"); + + std::priority_queue, CompareByFirst> top_candidates = search_layer(ptdis, pt_id, currObj, lev); + currObj = top_candidates.top().second; + make_connection(ptdis, pt_id, top_candidates, lev); + } + } else { + entry_point = 0; + max_level = pt_level; + } + + if (pt_level > maxlevel_copy) { + entry_point = pt_id; + max_level = pt_level; + } + +} + +std::priority_queue, CompareByFirst> +RHNSW::search_layer(DistanceComputer& ptdis, + storage_idx_t pt_id, + storage_idx_t nearest, + int level) { + VisitedList *vl = visited_list_pool->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, CompareByFirst> top_candidates; + std::priority_queue, CompareByFirst> candidate_set; + + float d_nearest = ptdis(nearest); + float lb = d_nearest; + top_candidates.emplace(d_nearest, nearest); + candidate_set.emplace(-d_nearest, nearest); + visited_array[nearest] = visited_array_tag; + + while (!candidate_set.empty()) { + Node currNode = candidate_set.top(); + if ((-currNode.first) > lb) + break; + candidate_set.pop(); + int cur_id = currNode.second; + std::unique_lock lk(link_list_locks[cur_id]); + int *cur_link = get_neighbor_link(cur_id, level); + auto cur_neighbor_num = get_neighbors_num(cur_link); + + for (auto i = 1; i <= cur_neighbor_num; ++ i) { + int candidate_id = cur_link[i]; + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + float dcand = ptdis(candidate_id); + if (top_candidates.size() < efConstruction || lb > dcand) { + candidate_set.emplace(-dcand, candidate_id); + top_candidates.emplace(dcand, candidate_id); + if (top_candidates.size() > efConstruction) + top_candidates.pop(); + if (!top_candidates.empty()) + lb = top_candidates.top().first; + } + } + } + visited_list_pool->releaseVisitedList(vl); + return top_candidates; +} + +std::priority_queue, CompareByFirst> +RHNSW::search_base_layer(DistanceComputer& ptdis, + storage_idx_t nearest, + storage_idx_t ef, + float d_nearest, + ConcurrentBitsetPtr bitset) const { + VisitedList *vl = visited_list_pool->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, CompareByFirst> top_candidates; + std::priority_queue, CompareByFirst> candidate_set; + + float lb; + if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(nearest))) { + lb = d_nearest; + top_candidates.emplace(d_nearest, nearest); + candidate_set.emplace(-d_nearest, nearest); + } else { + lb = std::numeric_limits::max(); + candidate_set.emplace(-lb, nearest); + } + visited_array[nearest] = visited_array_tag; + + while (!candidate_set.empty()) { + Node currNode = candidate_set.top(); + if ((-currNode.first) > lb) + break; + candidate_set.pop(); + int cur_id = currNode.second; + int *cur_link = get_neighbor_link(cur_id, 0); + auto cur_neighbor_num = get_neighbors_num(cur_link); + for (auto i = 1; i <= cur_neighbor_num; ++ i) { + int candidate_id = cur_link[i]; + if (visited_array[candidate_id] != visited_array_tag) { + visited_array[candidate_id] = visited_array_tag; + float dcand = ptdis(candidate_id); + if (top_candidates.size() < ef || lb > dcand) { + candidate_set.emplace(-dcand, candidate_id); + if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(candidate_id))) + top_candidates.emplace(dcand, candidate_id); + if (top_candidates.size() > ef) + top_candidates.pop(); + if (!top_candidates.empty()) + lb = top_candidates.top().first; + } + } + } + } + visited_list_pool->releaseVisitedList(vl); + return top_candidates; +} + +void +RHNSW::make_connection(DistanceComputer& ptdis, + storage_idx_t pt_id, + std::priority_queue, CompareByFirst> &cand, + int level) { + int maxM = level ? M : M << 1; + int *selectedNeighbors = (int*)malloc(sizeof(int) * maxM); + int selectedNeighborsNum = 0; + prune_neighbors(ptdis, cand, maxM, selectedNeighbors, selectedNeighborsNum); + if (selectedNeighborsNum > maxM) + throw std::runtime_error("Wrong size of candidates returned by prune_neighbors!"); + + int *cur_link = get_neighbor_link(pt_id, level); + if (*cur_link) + throw std::runtime_error("The newly inserted element should have blank link"); + + set_neighbors_num(cur_link, selectedNeighborsNum); + for (auto i = 1; i <= selectedNeighborsNum; ++ i) { + if (cur_link[i]) + throw std::runtime_error("Possible memory corruption."); + if (level > levels[selectedNeighbors[i - 1]]) + throw std::runtime_error("Trying to make a link on a non-exisitent level."); + cur_link[i] = selectedNeighbors[i - 1]; + } + + for (auto i = 0; i < selectedNeighborsNum; ++ i) { + std::unique_lock lk(link_list_locks[selectedNeighbors[i]]); + + int *selected_link = get_neighbor_link(selectedNeighbors[i], level); + auto selected_neighbor_num = get_neighbors_num(selected_link); + if (selected_neighbor_num > maxM) + throw std::runtime_error("Bad value of selected_neighbor_num."); + if (selectedNeighbors[i] == pt_id) + throw std::runtime_error("Trying to connect an element to itself."); + if (level > levels[selectedNeighbors[i]]) + throw std::runtime_error("Trying to make a link on a non-exisitent level."); + if (selected_neighbor_num < maxM) { + selected_link[selected_neighbor_num + 1] = pt_id; + set_neighbors_num(selected_link, selected_neighbor_num + 1); + } else { + double d_max = ptdis(selectedNeighbors[i]); + std::priority_queue, CompareByFirst> candi; + candi.emplace(d_max, pt_id); + for (auto j = 1; j <= selected_neighbor_num; ++ j) + candi.emplace(ptdis.symmetric_dis(selectedNeighbors[i], selected_link[j]), selected_link[j]); + int indx = 0; + prune_neighbors(ptdis, candi, maxM, selected_link + 1, indx); + set_neighbors_num(selected_link, indx); + } + } + + free(selectedNeighbors); +} + +void RHNSW::prune_neighbors(DistanceComputer& ptdis, + std::priority_queue, CompareByFirst> &cand, + const int maxM, int *ret, int &ret_len) { + if (cand.size() < maxM) { + while (!cand.empty()) { + ret[ret_len ++] = cand.top().second; + cand.pop(); + } + return; + } + std::priority_queue closest; + + while (!cand.empty()) { + closest.emplace(-cand.top().first, cand.top().second); + cand.pop(); + } + + while (closest.size()) { + if (ret_len >= maxM) + break; + Node curr = closest.top(); + float dist_to_query = -curr.first; + closest.pop(); + bool good = true; + for (auto i = 0; i < ret_len; ++ i) { + float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]); + if (cur_dist < dist_to_query) { + good = false; + break; + } + } + if (good) { + ret[ret_len ++] = curr.second; + } + } +} + +void RHNSW::searchKnn(DistanceComputer& qdis, int k, + idx_t *I, float *D, + ConcurrentBitsetPtr bitset) const { + if (levels.size() == 0) + return; + int ep = entry_point; + float dist = qdis(ep); + + for (auto i = max_level; i > 0; -- i) { + bool good = true; + while (good) { + good = false; + int *ep_link = get_neighbor_link(ep, i); + auto ep_neighbors_cnt = get_neighbors_num(ep_link); + for (auto j = 1; j <= ep_neighbors_cnt; ++ j) { + int cand = ep_link[j]; + if (cand < 0 || cand > levels.size()) + throw std::runtime_error("cand error"); + float d = qdis(cand); + if (d < dist) { + dist = d; + ep = cand; + good = true; + } + } + } + } + std::priority_queue, CompareByFirst> top_candidates = search_base_layer(qdis, ep, std::max(efSearch, k), dist, bitset); + while (top_candidates.size() > k) + top_candidates.pop(); + int i = 0; + while (!top_candidates.empty()) { + I[i] = top_candidates.top().second; + D[i] = top_candidates.top().first; + i ++; + top_candidates.pop(); + } +} + +size_t RHNSW::cal_size() { + size_t ret = 0; + ret += sizeof(*this); + ret += visited_list_pool->GetSize(); + ret += link_list_locks.size() * sizeof(std::mutex); + ret += levels.size() * sizeof(int); + ret += levels.size() * level0_link_size; + ret += levels.size() * sizeof(void*); + for (auto i = 0; i < levels.size(); ++ i) { + ret += levels[i] ? link_size * levels[i] : 0; + } + return ret; +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/RHNSW.h b/core/src/index/thirdparty/faiss/impl/RHNSW.h new file mode 100644 index 0000000000..40ac9d68ef --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/RHNSW.h @@ -0,0 +1,367 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + + +namespace faiss { + + +/** Implementation of the Hierarchical Navigable Small World + * datastructure. + * + * Efficient and robust approximate nearest neighbor search using + * Hierarchical Navigable Small World graphs + * + * Yu. A. Malkov, D. A. Yashunin, arXiv 2017 + * + * This implmentation is heavily influenced by the hnswlib + * implementation by Yury Malkov and Leonid Boystov + * (https://github.com/searchivarius/nmslib/hnswlib) + * + * The HNSW object stores only the neighbor link structure, see + * IndexHNSW.h for the full index object. + */ + + +struct DistanceComputer; // from AuxIndexStructures +class VisitedListPool; + +struct RHNSW { + /// internal storage of vectors (32 bits: this is expensive) + typedef int storage_idx_t; + + /// Faiss results are 64-bit + typedef Index::idx_t idx_t; + + typedef std::pair Node; + + /** Heap structure that allows fast + */ + struct MinimaxHeap { + int n; + int k; + int nvalid; + + std::vector ids; + std::vector dis; + typedef faiss::CMax HC; + + explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {} + + void push(storage_idx_t i, float v) { + if (k == n) { + if (v >= dis[0]) return; + faiss::heap_pop (k--, dis.data(), ids.data()); + --nvalid; + } + faiss::heap_push (++k, dis.data(), ids.data(), v, i); + ++nvalid; + } + + float max() const { + return dis[0]; + } + + int size() const { + return nvalid; + } + + void clear() { + nvalid = k = 0; + } + + int pop_min(float *vmin_out = nullptr) { + assert(k > 0); + // returns min. This is an O(n) operation + int i = k - 1; + while (i >= 0) { + if (ids[i] != -1) break; + i--; + } + if (i == -1) return -1; + int imin = i; + float vmin = dis[i]; + i--; + while(i >= 0) { + if (ids[i] != -1 && dis[i] < vmin) { + vmin = dis[i]; + imin = i; + } + i--; + } + if (vmin_out) *vmin_out = vmin; + int ret = ids[imin]; + ids[imin] = -1; + --nvalid; + + return ret; + } + + int count_below(float thresh) { + int n_below = 0; + for(int i = 0; i < k; i++) { + if (dis[i] < thresh) { + n_below++; + } + } + + return n_below; + } + }; + + /// to sort pairs of (id, distance) from nearest to fathest or the reverse + struct NodeDistCloser { + float d; + int id; + NodeDistCloser(float d, int id): d(d), id(id) {} + bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; } + }; + + struct NodeDistFarther { + float d; + int id; + NodeDistFarther(float d, int id): d(d), id(id) {} + bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; } + }; + + struct CompareByFirst { + constexpr bool operator()(Node const &a, + Node const &b) const noexcept { + return a.first < b.first; + } + }; + + + /// level of each vector (base level = 1), size = ntotal + std::vector levels; + + /// number of entry points in levels > 0. + int upper_beam; + + /// entry point in the search structure (one of the points with maximum level + storage_idx_t entry_point; + + faiss::RandomGenerator rng; + std::default_random_engine level_generator; + + /// maximum level + int max_level; + int M; + char *level0_links; + char **linkLists; + size_t level0_link_size; + size_t link_size; + double level_constant; + VisitedListPool *visited_list_pool; + std::vector link_list_locks; + std::mutex global; + + /// expansion factor at construction time + int efConstruction; + + /// expansion factor at search time + int efSearch; + + /// range of entries in the neighbors table of vertex no at layer_no + storage_idx_t* get_neighbor_link(idx_t no, int layer_no) const { + return layer_no == 0 ? (int*)(level0_links + no * level0_link_size) : (int*)(linkLists[no] + (layer_no - 1) * link_size); + } + unsigned short int get_neighbors_num(int *p) const { + return *((unsigned short int*)p); + } + void set_neighbors_num(int *p, unsigned short int num) const { + *((unsigned short int*)(p)) = *((unsigned short int *)(&num)); + } + + /// only mandatory parameter: nb of neighbors + explicit RHNSW(int M = 32); + ~RHNSW(); + + void init(int ntotal); + /// pick a random level for a new point, arg = 1/log(M) + int random_level(double arg) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator)) * arg; + return (int)r; + } + + void reset(); + + int prepare_level_tab(size_t n, bool preset_levels = false); + + // re-implementations inspired by hnswlib + /** add point pt_id on all levels <= pt_level and build the link + * structure for them. inspired by implementation of hnswlib */ + void addPoint(DistanceComputer& ptdis, int pt_level, int pt_id); + + std::priority_queue, CompareByFirst> + search_layer (DistanceComputer& ptdis, + storage_idx_t pt_id, + storage_idx_t nearest, + int level); + + std::priority_queue, CompareByFirst> + search_base_layer (DistanceComputer& ptdis, + storage_idx_t nearest, + storage_idx_t ef, + float d_nearest, + ConcurrentBitsetPtr bitset = nullptr) const; + + void make_connection(DistanceComputer& ptdis, + storage_idx_t pt_id, + std::priority_queue, CompareByFirst> &cand, + int level); + + void prune_neighbors(DistanceComputer& ptdis, + std::priority_queue, CompareByFirst> &cand, + const int maxM, int *ret, int &ret_len); + + /// search interface inspired by hnswlib + void searchKnn(DistanceComputer& qdis, int k, + idx_t *I, float *D, + ConcurrentBitsetPtr bitset = nullptr) const; + + size_t cal_size(); + +}; + + +/************************************************************** + * Auxiliary structures + **************************************************************/ + +typedef unsigned short int vl_type; + +class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; + + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } + + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); + curV++; + } + }; + + // keep compatibae with original version VisitedTable + /// set flog #no to true + void set(int no) { + mass[no] = curV; + } + + /// get flag #no + bool get(int no) const { + return mass[no] == curV; + } + + void advance() { + reset(); + } + + ~VisitedList() { delete[] mass; } +}; + +/////////////////////////////////////////////////////////// +// +// Class for multi-threaded pool-management of VisitedLists +// +///////////////////////////////////////////////////////// + +class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; + + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } + + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { + std::unique_lock lock(poolguard); + if (pool.size() > 0) { + rez = pool.front(); + pool.pop_front(); + } else { + rez = new VisitedList(numelements); + } + } + rez->reset(); + return rez; + }; + + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + }; + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + }; + + int64_t GetSize() { + auto visit_list_size = sizeof(VisitedList) + numelements * sizeof(vl_type); + auto pool_size = pool.size() * (sizeof(VisitedList *) + visit_list_size); + return pool_size + sizeof(*this); + } +}; + +struct RHNSWStats { + size_t n1, n2, n3; + size_t ndis; + size_t nreorder; + bool view; + + RHNSWStats() { + reset(); + } + + void reset() { + n1 = n2 = n3 = 0; + ndis = 0; + nreorder = 0; + view = false; + } +}; + +// global var that collects them all +extern RHNSWStats rhnsw_stats; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h index cb447c603b..3dfa72333d 100644 --- a/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h @@ -75,6 +75,7 @@ struct ScalarQuantizer { (MetricType mt, const Index *quantizer, bool store_pairs, bool by_residual=false) const; + size_t cal_size() { return sizeof(*this) + trained.size() * sizeof(float); } }; template diff --git a/core/src/index/thirdparty/faiss/impl/index_read.cpp b/core/src/index/thirdparty/faiss/impl/index_read.cpp index 24556606ee..9a3936a715 100644 --- a/core/src/index/thirdparty/faiss/impl/index_read.cpp +++ b/core/src/index/thirdparty/faiss/impl/index_read.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -458,6 +459,29 @@ static void read_HNSW (HNSW *hnsw, IOReader *f) { READ1 (hnsw->upper_beam); } +static void read_RHNSW (RHNSW *rhnsw, IOReader *f) { + READ1 (rhnsw->entry_point); + READ1 (rhnsw->max_level); + READ1 (rhnsw->M); + READ1 (rhnsw->level0_link_size); + READ1 (rhnsw->link_size); + READ1 (rhnsw->level_constant); + READ1 (rhnsw->efConstruction); + READ1 (rhnsw->efSearch); + + READVECTOR (rhnsw->levels); + auto ntotal = rhnsw->levels.size(); + rhnsw->level0_links = (char*) malloc(ntotal * rhnsw->level0_link_size); + READANDCHECK( rhnsw->level0_links, ntotal * rhnsw->level0_link_size); + rhnsw->linkLists = (char**) malloc(ntotal * sizeof(void*)); + for (auto i = 0; i < ntotal; ++ i) { + if (rhnsw->levels[i]) { + rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i] + 1); + READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1); + } + } +} + ProductQuantizer * read_ProductQuantizer (const char*fname) { FileIOReader reader(fname); return read_ProductQuantizer(&reader); @@ -624,6 +648,9 @@ Index *read_index (IOReader *f, int io_flags) { if (h == fourcc ("IxPQ") || h == fourcc ("IxPo")) { idxp->metric_type = METRIC_L2; } + if (h == fourcc("IxPq")) { + idxp->pq.compute_sdc_table (); + } idx = idxp; } else if (h == fourcc ("IvFl") || h == fourcc("IvFL")) { // legacy IndexIVFFlat * ivfl = new IndexIVFFlat (); @@ -800,6 +827,17 @@ Index *read_index (IOReader *f, int io_flags) { dynamic_cast(idxhnsw->storage)->pq.compute_sdc_table (); } idx = idxhnsw; + } else if(h == fourcc("IRHf") || h == fourcc("IRHp") || + h == fourcc("IRHs") || h == fourcc("IRH2")) { + IndexRHNSW *idxrhnsw = nullptr; + if (h == fourcc("IRHf")) idxrhnsw = new IndexRHNSWFlat (); + if (h == fourcc("IRHp")) idxrhnsw = new IndexRHNSWPQ (); + if (h == fourcc("IRHs")) idxrhnsw = new IndexRHNSWSQ (); + if (h == fourcc("IRH2")) idxrhnsw = new IndexRHNSW2Level (); + read_index_header (idxrhnsw, f); + read_RHNSW (&idxrhnsw->hnsw, f); + idxrhnsw->own_fields = true; + idx = idxrhnsw; } else { FAISS_THROW_FMT("Index type 0x%08x not supported\n", h); idx = nullptr; diff --git a/core/src/index/thirdparty/faiss/impl/index_write.cpp b/core/src/index/thirdparty/faiss/impl/index_write.cpp index ef7720a273..9bbaa4e8bc 100644 --- a/core/src/index/thirdparty/faiss/impl/index_write.cpp +++ b/core/src/index/thirdparty/faiss/impl/index_write.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -364,6 +365,24 @@ static void write_HNSW (const HNSW *hnsw, IOWriter *f) { WRITE1 (hnsw->upper_beam); } +static void write_RHNSW (const RHNSW *rhnsw, IOWriter *f) { + WRITE1 (rhnsw->entry_point); + WRITE1 (rhnsw->max_level); + WRITE1 (rhnsw->M); + WRITE1 (rhnsw->level0_link_size); + WRITE1 (rhnsw->link_size); + WRITE1 (rhnsw->level_constant); + WRITE1 (rhnsw->efConstruction); + WRITE1 (rhnsw->efSearch); + + WRITEVECTOR (rhnsw->levels); + WRITEANDCHECK (rhnsw->level0_links, rhnsw->level0_link_size * rhnsw->levels.size()); + for (auto i = 0; i < rhnsw->levels.size(); ++ i) { + if (rhnsw->levels[i]) + WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1); + } +} + static void write_direct_map (const DirectMap *dm, IOWriter *f) { char maintain_direct_map = (char)dm->type; // for backwards compatibility with bool WRITE1 (maintain_direct_map); @@ -560,6 +579,18 @@ void write_index (const Index *idx, IOWriter *f) { write_index_header (idxhnsw, f); write_HNSW (&idxhnsw->hnsw, f); write_index (idxhnsw->storage, f); + } else if (const IndexRHNSW * idxrhnsw = + dynamic_cast(idx)) { + uint32_t h = + dynamic_cast(idx) ? fourcc("IRHf") : + dynamic_cast(idx) ? fourcc("IRHp") : + dynamic_cast(idx) ? fourcc("IRHs") : + dynamic_cast(idx) ? fourcc("IRH2") : + 0; + FAISS_THROW_IF_NOT (h != 0); + WRITE1 (h); + write_index_header (idxrhnsw, f); + write_RHNSW (&idxrhnsw->hnsw, f); } else { FAISS_THROW_MSG ("don't know how to serialize this type of index"); } diff --git a/core/src/index/thirdparty/hnswlib/hnswalg_nm.h b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h index ffdd1985a2..a39563d2b2 100644 --- a/core/src/index/thirdparty/hnswlib/hnswalg_nm.h +++ b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h @@ -1224,4 +1224,4 @@ namespace hnswlib_nm { } }; -} \ No newline at end of file +} diff --git a/core/src/index/unittest/CMakeLists.txt b/core/src/index/unittest/CMakeLists.txt index 72e2f3e67a..c8dcfa986e 100644 --- a/core/src/index/unittest/CMakeLists.txt +++ b/core/src/index/unittest/CMakeLists.txt @@ -219,6 +219,42 @@ endif () target_link_libraries(test_hnsw_sq8nm ${depend_libs} ${unittest_libs} ${basic_libs}) install(TARGETS test_hnsw_sq8nm DESTINATION unittest) +################################################################################ +# +set(rhnsw_flat_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp + ) +if (NOT TARGET test_rhnsw_flat) + add_executable(test_rhnsw_flat test_rhnsw_flat.cpp ${rhnsw_flat_srcs} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_rhnsw_flat ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_rhnsw_flat DESTINATION unittest) + +################################################################################ +# +set(rhnsw_pq_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp + ) +if (NOT TARGET test_rhnsw_pq) + add_executable(test_rhnsw_pq test_rhnsw_pq.cpp ${rhnsw_pq_srcs} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_rhnsw_pq ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_rhnsw_pq DESTINATION unittest) + +################################################################################ +# +set(rhnsw_sq8_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp + ) +if (NOT TARGET test_rhnsw_sq8) + add_executable(test_rhnsw_sq8 test_rhnsw_sq8.cpp ${rhnsw_sq8_srcs} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_rhnsw_sq8 ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_rhnsw_sq8 DESTINATION unittest) + ################################################################################ # if (MILVUS_SUPPORT_SPTAG) diff --git a/core/src/index/unittest/test_rhnsw_flat.cpp b/core/src/index/unittest/test_rhnsw_flat.cpp new file mode 100644 index 0000000000..51f4859983 --- /dev/null +++ b/core/src/index/unittest/test_rhnsw_flat.cpp @@ -0,0 +1,158 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class RHNSWFlatTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWFlatTest, Values("RHNSWFlat")); + +TEST_P(RHNSWFlatTest, HNSW_basic) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + auto tmp_index = std::make_shared(); + + tmp_index->Load(bs); + + auto result2 = tmp_index->Query(query_dataset, conf); + AssertAnns(result2, nq, k); +} + +TEST_P(RHNSWFlatTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(RHNSWFlatTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + std::string index_type = index_->index_type(); + std::string idx_name = index_type + "_Index"; + std::string dat_name = index_type + "_Data"; + if (binaryset.binary_map_.find(idx_name) == binaryset.binary_map_.end()) { + std::cout << "no idx!" << std::endl; + } + if (binaryset.binary_map_.find(dat_name) == binaryset.binary_map_.end()) { + std::cout << "no dat!" << std::endl; + } + auto bin_idx = binaryset.GetByName(idx_name); + auto bin_dat = binaryset.GetByName(dat_name); + + std::string filename_idx = "/tmp/RHNSWFlat_test_serialize_idx.bin"; + std::string filename_dat = "/tmp/RHNSWFlat_test_serialize_dat.bin"; + auto load_idx = new uint8_t[bin_idx->size]; + auto load_dat = new uint8_t[bin_dat->size]; + serialize(filename_idx, bin_idx, load_idx); + serialize(filename_dat, bin_dat, load_dat); + + binaryset.clear(); + auto new_idx = std::make_shared(); + std::shared_ptr dat(load_dat); + std::shared_ptr idx(load_idx); + binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); + binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size); + + new_idx->Load(binaryset); + EXPECT_EQ(new_idx->Count(), nb); + EXPECT_EQ(new_idx->Dim(), dim); + auto result = new_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/core/src/index/unittest/test_rhnsw_pq.cpp b/core/src/index/unittest/test_rhnsw_pq.cpp new file mode 100644 index 0000000000..5aa55a366e --- /dev/null +++ b/core/src/index/unittest/test_rhnsw_pq.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class RHNSWPQTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::IndexParams::PQM, 8}}; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWPQTest, Values("RHNSWPQ")); + +TEST_P(RHNSWPQTest, HNSW_basic) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + auto result1 = index_->Query(query_dataset, conf); + // AssertAnns(result1, nq, k); + + auto tmp_index = std::make_shared(); + + tmp_index->Load(bs); + + auto result2 = tmp_index->Query(query_dataset, conf); + // AssertAnns(result2, nq, k); +} + +TEST_P(RHNSWPQTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); + // AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + // AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(RHNSWPQTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index"); + auto bin_dat = binaryset.GetByName(index_->index_type() + "_Data"); + + std::string filename_idx = "/tmp/RHNSWPQ_test_serialize_idx.bin"; + std::string filename_dat = "/tmp/RHNSWPQ_test_serialize_dat.bin"; + auto load_idx = new uint8_t[bin_idx->size]; + auto load_dat = new uint8_t[bin_dat->size]; + serialize(filename_idx, bin_idx, load_idx); + serialize(filename_dat, bin_dat, load_dat); + + binaryset.clear(); + auto new_idx = std::make_shared(); + std::shared_ptr dat(load_dat); + std::shared_ptr idx(load_idx); + binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); + binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size); + + new_idx->Load(binaryset); + EXPECT_EQ(new_idx->Count(), nb); + EXPECT_EQ(new_idx->Dim(), dim); + auto result = new_idx->Query(query_dataset, conf); + // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/core/src/index/unittest/test_rhnsw_sq8.cpp b/core/src/index/unittest/test_rhnsw_sq8.cpp new file mode 100644 index 0000000000..1053d78764 --- /dev/null +++ b/core/src/index/unittest/test_rhnsw_sq8.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class RHNSWSQ8Test : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + // Generate(2, 10, 2); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWSQ8Test, Values("RHNSWSQ8")); + +TEST_P(RHNSWSQ8Test, HNSW_basic) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + auto tmp_index = std::make_shared(); + + tmp_index->Load(bs); + + auto result2 = tmp_index->Query(query_dataset, conf); + AssertAnns(result2, nq, k); +} + +TEST_P(RHNSWSQ8Test, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(RHNSWSQ8Test, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index"); + auto bin_dat = binaryset.GetByName(index_->index_type() + "_Data"); + + std::string filename_idx = "/tmp/RHNSWSQ_test_serialize_idx.bin"; + std::string filename_dat = "/tmp/RHNSWSQ_test_serialize_dat.bin"; + auto load_idx = new uint8_t[bin_idx->size]; + auto load_dat = new uint8_t[bin_dat->size]; + serialize(filename_idx, bin_idx, load_idx); + serialize(filename_dat, bin_dat, load_dat); + + binaryset.clear(); + auto new_idx = std::make_shared(); + std::shared_ptr dat(load_dat); + std::shared_ptr idx(load_idx); + binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); + binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size); + + new_idx->Load(binaryset); + EXPECT_EQ(new_idx->Count(), nb); + EXPECT_EQ(new_idx->Dim(), dim); + auto result = new_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/core/src/server/ValidationUtil.cpp b/core/src/server/ValidationUtil.cpp index 8516dbfe05..094b976af2 100644 --- a/core/src/server/ValidationUtil.cpp +++ b/core/src/server/ValidationUtil.cpp @@ -279,7 +279,9 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s if (!status.ok()) { return status; } - } else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) { + } else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_HNSW_SQ8NM || + index_type == knowhere::IndexEnum::INDEX_RHNSWPQ || index_type == knowhere::IndexEnum::INDEX_RHNSWSQ || + index_type == knowhere::IndexEnum::INDEX_RHNSWFlat) { auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64); if (!status.ok()) { return status; @@ -288,6 +290,38 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s if (!status.ok()) { return status; } + + if (index_type == knowhere::IndexEnum::INDEX_RHNSWPQ) { + status = CheckParameterExistence(index_params, knowhere::IndexParams::PQM); + if (!status.ok()) { + return status; + } + + // special check for 'PQM' parameter + std::vector resset; + milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset); + int64_t pqm_value = index_params[knowhere::IndexParams::PQM]; + if (resset.empty()) { + std::string msg = "Invalid collection dimension, unable to get reasonable values for 'PQM'"; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_DIMENSION, msg); + } + + auto iter = std::find(std::begin(resset), std::end(resset), pqm_value); + if (iter == std::end(resset)) { + std::string msg = + "Invalid " + std::string(knowhere::IndexParams::PQM) + ", must be one of the following values: "; + for (size_t i = 0; i < resset.size(); i++) { + if (i != 0) { + msg += ","; + } + msg += std::to_string(resset[i]); + } + + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } } else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) { auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024); if (!status.ok()) { diff --git a/core/src/server/web_impl/Constants.cpp b/core/src/server/web_impl/Constants.cpp index 36bba0f790..a21459403a 100644 --- a/core/src/server/web_impl/Constants.cpp +++ b/core/src/server/web_impl/Constants.cpp @@ -25,6 +25,9 @@ const char* NAME_ENGINE_TYPE_HNSW = "HNSW"; const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY"; const char* NAME_ENGINE_TYPE_IVFSQ8NR = "IVFSQ8NR"; const char* NAME_ENGINE_TYPE_HNSWSQ8NM = "HNSWSQ8NM"; +const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSWFLAT"; +const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSWPQ"; +const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSWSQ8"; const char* NAME_METRIC_TYPE_L2 = "L2"; const char* NAME_METRIC_TYPE_IP = "IP"; diff --git a/core/src/server/web_impl/Constants.h b/core/src/server/web_impl/Constants.h index f070d9298b..f19c14a54b 100644 --- a/core/src/server/web_impl/Constants.h +++ b/core/src/server/web_impl/Constants.h @@ -28,6 +28,9 @@ extern const char* NAME_ENGINE_TYPE_IVFPQ; extern const char* NAME_ENGINE_TYPE_HNSW; extern const char* NAME_ENGINE_TYPE_HNSW_SQ8NM; extern const char* NAME_ENGINE_TYPE_ANNOY; +extern const char* NAME_ENGINE_TYPE_RHNSWFLAT; +extern const char* NAME_ENGINE_TYPE_RHNSWPQ; +extern const char* NAME_ENGINE_TYPE_RHNSWSQ; extern const char* NAME_METRIC_TYPE_L2; extern const char* NAME_METRIC_TYPE_IP;