Faiss hnsw upgrade (#3134)

* combine the hnsw's implemention of faiss and hnswlib

Signed-off-by: cmli <chengming.li@zilliz.com>

* transplant the datastructure of hnsw from hnswlib 2 faiss

Signed-off-by: cmli <chengming.li@zilliz.com>

* basic work finished, pass compile, to be tested

Signed-off-by: cmli <chengming.li@zilliz.com>

* rhnswflat, rhnswsq, rhnswpq pass ut

Signed-off-by: cmli <chengming.li@zilliz.com>

* remove AssertAnns of RHNSWPQ because PQ has accuracy loss

Signed-off-by: cmli <chengming.li@zilliz.com>


Co-authored-by: cmli <chengming.li@zilliz.com>
Co-authored-by: shengjun.li <shengjun.li@zilliz.com>
pull/3153/head^2
op-hunter 2020-08-06 11:36:41 +08:00 committed by GitHub
parent b662295d63
commit 7688f51343
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 3206 additions and 5 deletions

View File

@ -59,6 +59,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#2802 Add new index: IVFSQ8NR
- \#2834 Add C++ sdk support 4 hnsw_sq8nr
- \#2940 Add option to build.sh for cuda arch
- \#3132 Refine the implementation of hnsw in faiss and add support for hnsw-flat, hnsw-pq and hnsw-sq8 based on faiss
## Improvement
- \#2543 Remove secondary_path related code

View File

@ -63,6 +63,10 @@ set(vector_index_srcs
knowhere/index/IndexType.cpp
knowhere/index/vector_index/VecIndexFactory.cpp
knowhere/index/vector_index/IndexAnnoy.cpp
knowhere/index/vector_index/IndexRHNSW.cpp
knowhere/index/vector_index/IndexRHNSWFlat.cpp
knowhere/index/vector_index/IndexRHNSWSQ.cpp
knowhere/index/vector_index/IndexRHNSWPQ.cpp
)
set(vector_offset_index_srcs

View File

@ -34,6 +34,9 @@ const char* INDEX_SPTAG_KDT_RNT = "SPTAG_KDT_RNT";
const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT";
#endif
const char* INDEX_HNSW = "HNSW";
const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
const char* INDEX_ANNOY = "ANNOY";
const char* INDEX_HNSW_SQ8NM = "HNSW_SQ8NM";
} // namespace IndexEnum

View File

@ -37,6 +37,9 @@ enum class OldIndexType {
ANNOY,
FAISS_IVFSQ8NR,
HNSW_SQ8NM,
RHNSW_FLAT,
RHNSW_PQ,
RHNSW_SQ,
FAISS_BIN_IDMAP = 100,
FAISS_BIN_IVFLAT_CPU = 101,
};
@ -60,6 +63,9 @@ extern const char* INDEX_SPTAG_KDT_RNT;
extern const char* INDEX_SPTAG_BKT_RNT;
#endif
extern const char* INDEX_HNSW;
extern const char* INDEX_RHNSWFlat;
extern const char* INDEX_RHNSWPQ;
extern const char* INDEX_RHNSWSQ;
extern const char* INDEX_ANNOY;
extern const char* INDEX_HNSW_SQ8NM;
} // namespace IndexEnum

View File

@ -280,6 +280,80 @@ HNSWSQ8NRConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const In
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
std::vector<int64_t> resset;
int64_t dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);
CheckIntByValues(knowhere::IndexParams::m, resset);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,

View File

@ -109,5 +109,31 @@ class IVFSQ8NRConfAdapter : public IVFConfAdapter {
CheckTrain(Config& oricfg, const IndexMode mode) override;
};
class RHNSWFlatConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class RHNSWPQConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class RHNSWSQConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -51,6 +51,9 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(ANNOYConfAdapter, IndexEnum::INDEX_ANNOY, annoy_adapter);
REGISTER_CONF_ADAPTER(HNSWSQ8NRConfAdapter, IndexEnum::INDEX_HNSW_SQ8NM, hnswsq8nr_adapter);
REGISTER_CONF_ADAPTER(IVFSQ8NRConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8NR, ivfsq8nr_adapter);
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
}
} // namespace knowhere

View File

@ -53,5 +53,11 @@ FaissBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
SealImpl();
}
void
FaissBaseIndex::SealImpl() {
}
// FaissBaseIndex::~FaissBaseIndex() {}
//
} // namespace knowhere
} // namespace milvus

View File

@ -34,8 +34,7 @@ class FaissBaseIndex {
LoadImpl(const BinarySet&, const IndexType& type);
virtual void
SealImpl() { /* do nothing */
}
SealImpl();
public:
std::shared_ptr<faiss::Index> index_ = nullptr;

View File

@ -0,0 +1,148 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexRHNSW.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
BinarySet
IndexRHNSW::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
MemoryIOWriter writer;
writer.name = this->index_type() + "_Index";
faiss::write_index(index_.get(), &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
BinarySet res_set;
res_set.Append(writer.name, data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSW::Load(const BinarySet& index_binary) {
try {
MemoryIOReader reader;
reader.name = this->index_type() + "_Index";
auto binary = index_binary.GetByName(reader.name);
reader.total = (size_t)binary->size;
reader.data_ = binary->data.get();
auto idx = faiss::read_index(&reader);
index_.reset(idx);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of Train, please use IndexRHNSW(Flat/SQ/PQ) instead!");
}
void
IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr)
index_->add(rows, (float*)p_data);
}
DatasetPtr
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA(dataset_ptr)
size_t k = config[meta::TOPK].get<int64_t>();
size_t id_size = sizeof(int64_t) * k;
size_t dist_size = sizeof(float) * k;
auto p_id = (int64_t*)malloc(id_size * rows);
auto p_dist = (float*)malloc(dist_size * rows);
for (auto i = 0; i < k * rows; ++i) {
p_id[i] = -1;
p_dist[i] = -1;
}
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
real_index->hnsw.efSearch = (config[IndexParams::ef]);
real_index->search(rows, (float*)p_data, k, p_dist, p_id, blacklist);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
int64_t
IndexRHNSW::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
IndexRHNSW::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->d;
}
void
IndexRHNSW::UpdateIndexSize() {
KNOWHERE_THROW_MSG(
"IndexRHNSW has no implementation of UpdateIndexSize, please use IndexRHNSW(Flat/SQ/PQ) instead!");
}
/*
BinarySet
IndexRHNSW::SerializeImpl(const milvus::knowhere::IndexType &type) { return BinarySet(); }
void
IndexRHNSW::SealImpl() {}
void
IndexRHNSW::LoadImpl(const milvus::knowhere::BinarySet &, const milvus::knowhere::IndexType &type) {}
*/
void
IndexRHNSW::AddWithoutIds(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config) {
KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of AddWithoutIds, please use IndexRHNSW(Flat/SQ/PQ) instead!");
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,67 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include <faiss/index_io.h>
#include "faiss/IndexRHNSW.h"
namespace milvus {
namespace knowhere {
class IndexRHNSW : public VecIndex, public FaissBaseIndex {
public:
IndexRHNSW() : FaissBaseIndex(nullptr) {
index_type_ = IndexEnum::INVALID;
}
explicit IndexRHNSW(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INVALID;
}
BinarySet
Serialize(const Config& config) override;
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override;
void
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
int64_t
Count() override;
int64_t
Dim() override;
void
UpdateIndexSize() override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,107 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, milvus::knowhere::MetricType metric) {
faiss::MetricType mt =
metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT;
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWFlat(d, M, mt));
}
BinarySet
IndexRHNSWFlat::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
auto res_set = IndexRHNSW::Serialize(config);
MemoryIOWriter writer;
writer.name = this->index_type() + "_Data";
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWFlat*>(index_) failed during Serialize!");
}
auto storage_index = dynamic_cast<faiss::IndexFlat*>(real_idx->storage);
faiss::write_index(storage_index, &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWFlat::Load(const BinarySet& index_binary) {
try {
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = this->index_type() + "_Data";
auto binary = index_binary.GetByName(reader.name);
reader.total = (size_t)binary->size;
reader.data_ = binary->data.get();
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWFlat*>(index_) failed during Load!");
}
real_idx->storage = faiss::read_index(&reader);
real_idx->init_hnsw();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWFlat::Train(const DatasetPtr& dataset_ptr, const Config& config) {
try {
GET_TENSOR_DATA_DIM(dataset_ptr)
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto idx = new faiss::IndexRHNSWFlat(int(dim), config[IndexParams::M], metric_type);
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
index_ = std::shared_ptr<faiss::Index>(idx);
index_->train(rows, (float*)p_data);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWFlat::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get())->cal_size();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,51 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include "IndexRHNSW.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"
namespace milvus {
namespace knowhere {
class IndexRHNSWFlat : public IndexRHNSW {
public:
IndexRHNSWFlat() : IndexRHNSW() {
index_type_ = IndexEnum::INDEX_RHNSWFlat;
}
explicit IndexRHNSWFlat(std::shared_ptr<faiss::Index> index) : IndexRHNSW(std::move(index)) {
index_type_ = IndexEnum::INDEX_RHNSWFlat;
}
IndexRHNSWFlat(int d, int M, MetricType metric = Metric::L2);
BinarySet
Serialize(const Config& config = Config()) override;
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
UpdateIndexSize() override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,102 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M) {
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWPQ(d, pq_m, M));
}
BinarySet
IndexRHNSWPQ::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
auto res_set = IndexRHNSW::Serialize(config);
MemoryIOWriter writer;
writer.name = this->index_type() + "_Data";
auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during Serialize!");
}
faiss::write_index(real_idx->storage, &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWPQ::Load(const BinarySet& index_binary) {
try {
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = this->index_type() + "_Data";
auto binary = index_binary.GetByName(reader.name);
reader.total = (size_t)binary->size;
reader.data_ = binary->data.get();
auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during Load!");
}
real_idx->storage = faiss::read_index(&reader);
real_idx->init_hnsw();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
try {
GET_TENSOR_DATA_DIM(dataset_ptr)
auto idx = new faiss::IndexRHNSWPQ(int(dim), config[IndexParams::PQM], config[IndexParams::M]);
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
index_ = std::shared_ptr<faiss::Index>(idx);
index_->train(rows, (float*)p_data);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWPQ::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get())->cal_size();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,52 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include "IndexRHNSW.h"
#include "knowhere/common/Exception.h"
namespace milvus {
namespace knowhere {
class IndexRHNSWPQ : public IndexRHNSW {
public:
IndexRHNSWPQ() : IndexRHNSW() {
index_type_ = IndexEnum::INDEX_RHNSWPQ;
}
explicit IndexRHNSWPQ(std::shared_ptr<faiss::Index> index) : IndexRHNSW(std::move(index)) {
index_type_ = IndexEnum::INDEX_RHNSWPQ;
}
IndexRHNSWPQ(int d, int pq_m, int M);
BinarySet
Serialize(const Config& config = Config()) override;
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
UpdateIndexSize() override;
private:
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,107 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
IndexRHNSWSQ::IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, milvus::knowhere::MetricType metric) {
faiss::MetricType mt =
metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT;
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWSQ(d, qtype, M, mt));
}
BinarySet
IndexRHNSWSQ::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
auto res_set = IndexRHNSW::Serialize(config);
MemoryIOWriter writer;
writer.name = this->index_type() + "_Data";
auto real_idx = dynamic_cast<faiss::IndexRHNSWSQ*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWSQ*>(index_) failed during Serialize!");
}
faiss::write_index(real_idx->storage, &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWSQ::Load(const BinarySet& index_binary) {
try {
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = this->index_type() + "_Data";
auto binary = index_binary.GetByName(reader.name);
reader.total = (size_t)binary->size;
reader.data_ = binary->data.get();
auto real_idx = dynamic_cast<faiss::IndexRHNSWSQ*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWSQ*>(index_) failed during Load!");
}
real_idx->storage = faiss::read_index(&reader);
real_idx->init_hnsw();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
try {
GET_TENSOR_DATA_DIM(dataset_ptr)
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto idx =
new faiss::IndexRHNSWSQ(int(dim), faiss::QuantizerType::QT_8bit, config[IndexParams::M], metric_type);
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
index_ = std::shared_ptr<faiss::Index>(idx);
index_->train(rows, (float*)p_data);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWSQ::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = dynamic_cast<faiss::IndexRHNSWSQ*>(index_.get())->cal_size();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,52 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include "IndexRHNSW.h"
#include "knowhere/common/Exception.h"
namespace milvus {
namespace knowhere {
class IndexRHNSWSQ : public IndexRHNSW {
public:
IndexRHNSWSQ() : IndexRHNSW() {
index_type_ = IndexEnum::INDEX_RHNSWSQ;
}
explicit IndexRHNSWSQ(std::shared_ptr<faiss::Index> index) : IndexRHNSW(std::move(index)) {
index_type_ = IndexEnum::INDEX_RHNSWSQ;
}
IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, MetricType metric = Metric::L2);
BinarySet
Serialize(const Config& config = Config()) override;
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
UpdateIndexSize() override;
private:
};
} // namespace knowhere
} // namespace milvus

View File

@ -20,11 +20,15 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
#include "knowhere/index/vector_offset_index/IndexHNSW_NM.h"
#include "knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h"
#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h"
#include "knowhere/index/vector_offset_index/IndexIVF_NM.h"
#include "knowhere/index/vector_offset_index/IndexNSG_NM.h"
#ifdef MILVUS_SUPPORT_SPTAG
#include "knowhere/index/vector_index/IndexSPTAG.h"
#endif
@ -94,6 +98,12 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
return std::make_shared<knowhere::IVFSQNR_NM>();
} else if (type == IndexEnum::INDEX_HNSW_SQ8NM) {
return std::make_shared<knowhere::IndexHNSW_SQ8NM>();
} else if (type == IndexEnum::INDEX_RHNSWFlat) {
return std::make_shared<knowhere::IndexRHNSWFlat>();
} else if (type == IndexEnum::INDEX_RHNSWPQ) {
return std::make_shared<knowhere::IndexRHNSWPQ>();
} else if (type == IndexEnum::INDEX_RHNSWSQ) {
return std::make_shared<knowhere::IndexRHNSWSQ>();
} else {
return nullptr;
}

View File

@ -48,6 +48,9 @@ constexpr const char* ef = "ef";
// Annoy Params
constexpr const char* n_trees = "n_trees";
constexpr const char* search_k = "search_k";
// PQ Params
constexpr const char* PQM = "PQM";
} // namespace IndexParams
namespace Metric {

View File

@ -80,6 +80,7 @@ struct Index2Layer: Index {
void sa_encode (idx_t n, const float *x, uint8_t *bytes) const override;
void sa_decode (idx_t n, const uint8_t *bytes, float *x) const override;
size_t cal_size() { return sizeof(*this) + codes.size() * sizeof(uint8_t) + pq.cal_size(); }
};

View File

@ -85,6 +85,8 @@ struct IndexFlat: Index {
void sa_decode (idx_t n, const uint8_t *bytes,
float *x) const override;
size_t cal_size() { return xb.size() * sizeof(float); }
};

View File

@ -124,6 +124,8 @@ struct IndexPQ: Index {
void hamming_distance_table (idx_t n, const float *x,
int32_t *dis) const;
size_t cal_size() { return codes.size() * sizeof(uint8_t) + pq.cal_size(); }
};

View File

@ -0,0 +1,812 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexRHNSW.h>
#include <cstdlib>
#include <cassert>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <omp.h>
#include <unordered_set>
#include <queue>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <stdint.h>
#ifdef __SSE__
#endif
#include <faiss/utils/distances.h>
#include <faiss/utils/random.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/Index2Layer.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/FaissHook.h>
extern "C" {
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
n, FINTEGER *k, const float *alpha, const float *a,
FINTEGER *lda, const float *b, FINTEGER *
ldb, float *beta, float *c, FINTEGER *ldc);
}
namespace faiss {
using idx_t = Index::idx_t;
using MinimaxHeap = RHNSW::MinimaxHeap;
using storage_idx_t = RHNSW::storage_idx_t;
using NodeDistFarther = RHNSW::NodeDistFarther;
RHNSWStats rhnsw_stats;
/**************************************************************
* add / search blocks of descriptors
**************************************************************/
namespace {
/* Wrap the distance computer into one that negates the
distances. This makes supporting INNER_PRODUCE search easier */
struct NegativeDistanceComputer: DistanceComputer {
/// owned by this
DistanceComputer *basedis;
explicit NegativeDistanceComputer(DistanceComputer *basedis):
basedis(basedis)
{}
void set_query(const float *x) override {
basedis->set_query(x);
}
/// compute distance of vector i to current query
float operator () (idx_t i) override {
return -(*basedis)(i);
}
/// compute distance between two stored vectors
float symmetric_dis (idx_t i, idx_t j) override {
return -basedis->symmetric_dis(i, j);
}
virtual ~NegativeDistanceComputer ()
{
delete basedis;
}
};
DistanceComputer *storage_distance_computer(const Index *storage)
{
if (storage->metric_type == METRIC_INNER_PRODUCT) {
return new NegativeDistanceComputer(storage->get_distance_computer());
} else {
return storage->get_distance_computer();
}
}
void hnsw_add_vertices(IndexRHNSW &index_hnsw,
size_t n0,
size_t n, const float *x,
bool verbose,
bool preset_levels = false) {
size_t d = index_hnsw.d;
RHNSW & hnsw = index_hnsw.hnsw;
size_t ntotal = n0 + n;
double t0 = getmillisecs();
if (verbose) {
printf("hnsw_add_vertices: adding %ld elements on top of %ld "
"(preset_levels=%d)\n",
n, n0, int(preset_levels));
}
if (n == 0) {
return;
}
int max_level = hnsw.prepare_level_tab(n, preset_levels);
if (verbose) {
printf(" max_level = %d\n", max_level);
}
{ // perform add
auto tas = getmillisecs();
RandomGenerator rng2(789);
DistanceComputer *dis0 =
storage_distance_computer (index_hnsw.storage);
ScopeDeleter1<DistanceComputer> del0(dis0);
dis0->set_query(x);
hnsw.addPoint(*dis0, hnsw.levels[n0], n0);
#pragma omp parallel for
for (int i = 1; i < n; ++ i) {
DistanceComputer *dis =
storage_distance_computer (index_hnsw.storage);
ScopeDeleter1<DistanceComputer> del(dis);
dis->set_query(x + i * d);
hnsw.addPoint(*dis, hnsw.levels[n0 + i], i + n0);
}
}
if (verbose) {
printf("Done in %.3f ms\n", getmillisecs() - t0);
}
}
} // namespace
/**************************************************************
* IndexRHNSW implementation
**************************************************************/
IndexRHNSW::IndexRHNSW(int d, int M, MetricType metric):
Index(d, metric),
hnsw(M),
own_fields(false),
storage(nullptr),
reconstruct_from_neighbors(nullptr)
{}
IndexRHNSW::IndexRHNSW(Index *storage, int M):
Index(storage->d, storage->metric_type),
hnsw(M),
own_fields(false),
storage(storage),
reconstruct_from_neighbors(nullptr)
{}
IndexRHNSW::~IndexRHNSW() {
if (own_fields) {
delete storage;
}
}
void IndexRHNSW::init_hnsw() {
hnsw.init(ntotal);
}
void IndexRHNSW::train(idx_t n, const float* x)
{
FAISS_THROW_IF_NOT_MSG(storage,
"Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly");
// hnsw structure does not require training
storage->train (n, x);
is_trained = true;
}
void IndexRHNSW::search (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const
{
FAISS_THROW_IF_NOT_MSG(storage,
"Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly");
size_t nreorder = 0;
idx_t check_period = InterruptCallback::get_period_hint (
hnsw.max_level * d * hnsw.efSearch);
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);
#pragma omp parallel reduction(+ : nreorder)
{
DistanceComputer *dis = storage_distance_computer(storage);
ScopeDeleter1<DistanceComputer> del(dis);
#pragma omp for
for(idx_t i = i0; i < i1; i++) {
idx_t * idxi = labels + i * k;
float * simi = distances + i * k;
dis->set_query(x + i * d);
maxheap_heapify (k, simi, idxi);
hnsw.searchKnn(*dis, k, idxi, simi, bitset);
maxheap_reorder (k, simi, idxi);
if (reconstruct_from_neighbors &&
reconstruct_from_neighbors->k_reorder != 0) {
int k_reorder = reconstruct_from_neighbors->k_reorder;
if (k_reorder == -1 || k_reorder > k) k_reorder = k;
nreorder += reconstruct_from_neighbors->compute_distances(
k_reorder, idxi, x + i * d, simi);
// sort top k_reorder
maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder);
maxheap_reorder (k_reorder, simi, idxi);
}
}
}
InterruptCallback::check ();
}
if (metric_type == METRIC_INNER_PRODUCT) {
// we need to revert the negated distances
for (size_t i = 0; i < k * n; i++) {
distances[i] = -distances[i];
}
}
rhnsw_stats.nreorder += nreorder;
}
void IndexRHNSW::add(idx_t n, const float *x)
{
FAISS_THROW_IF_NOT_MSG(storage,
"Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly");
FAISS_THROW_IF_NOT(is_trained);
int n0 = ntotal;
storage->add(n, x);
ntotal = storage->ntotal;
hnsw_add_vertices (*this, n0, n, x, verbose,
hnsw.levels.size() == ntotal);
}
void IndexRHNSW::reset()
{
hnsw.reset();
storage->reset();
ntotal = 0;
}
void IndexRHNSW::reconstruct (idx_t key, float* recons) const
{
storage->reconstruct(key, recons);
}
size_t IndexRHNSW::cal_size() {
return hnsw.cal_size();
}
/**************************************************************
* ReconstructFromNeighbors implementation
**************************************************************/
ReconstructFromNeighbors2::ReconstructFromNeighbors2(
const IndexRHNSW & index, size_t k, size_t nsq):
index(index), k(k), nsq(nsq) {
M = index.hnsw.M << 1;
FAISS_ASSERT(k <= 256);
code_size = k == 1 ? 0 : nsq;
ntotal = 0;
d = index.d;
FAISS_ASSERT(d % nsq == 0);
dsub = d / nsq;
k_reorder = -1;
}
void ReconstructFromNeighbors2::reconstruct(storage_idx_t i, float *x, float *tmp) const
{
const RHNSW & hnsw = index.hnsw;
int *cur_links = hnsw.get_neighbor_link(i, 0);
int *cur_neighbors = cur_links + 1;
auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links);
if (k == 1 || nsq == 1) {
const float * beta;
if (k == 1) {
beta = codebook.data();
} else {
int idx = codes[i];
beta = codebook.data() + idx * (M + 1);
}
float w0 = beta[0]; // weight of image itself
index.storage->reconstruct(i, tmp);
for (int l = 0; l < d; l++)
x[l] = w0 * tmp[l];
for (auto j = 0; j < cur_neighbor_num; ++ j) {
storage_idx_t ji = cur_neighbors[j];
if (ji < 0) ji = i;
float w = beta[j + 1];
index.storage->reconstruct(ji, tmp);
for (int l = 0; l < d; l++)
x[l] += w * tmp[l];
}
} else if (nsq == 2) {
int idx0 = codes[2 * i];
int idx1 = codes[2 * i + 1];
const float *beta0 = codebook.data() + idx0 * (M + 1);
const float *beta1 = codebook.data() + (idx1 + k) * (M + 1);
index.storage->reconstruct(i, tmp);
float w0;
w0 = beta0[0];
for (int l = 0; l < dsub; l++)
x[l] = w0 * tmp[l];
w0 = beta1[0];
for (int l = dsub; l < d; l++)
x[l] = w0 * tmp[l];
for (auto j = 0; j < cur_neighbor_num; ++ j) {
storage_idx_t ji = cur_neighbors[j];
if (ji < 0) ji = i;
index.storage->reconstruct(ji, tmp);
float w;
w = beta0[j + 1];
for (int l = 0; l < dsub; l++)
x[l] += w * tmp[l];
w = beta1[j + 1];
for (int l = dsub; l < d; l++)
x[l] += w * tmp[l];
}
} else {
const float *betas[nsq];
{
const float *b = codebook.data();
const uint8_t *c = &codes[i * code_size];
for (int sq = 0; sq < nsq; sq++) {
betas[sq] = b + (*c++) * (M + 1);
b += (M + 1) * k;
}
}
index.storage->reconstruct(i, tmp);
{
int d0 = 0;
for (int sq = 0; sq < nsq; sq++) {
float w = *(betas[sq]++);
int d1 = d0 + dsub;
for (int l = d0; l < d1; l++) {
x[l] = w * tmp[l];
}
d0 = d1;
}
}
for (auto j = 0; j < cur_neighbor_num; ++ j) {
storage_idx_t ji = cur_neighbors[j];
if (ji < 0) ji = i;
index.storage->reconstruct(ji, tmp);
int d0 = 0;
for (int sq = 0; sq < nsq; sq++) {
float w = *(betas[sq]++);
int d1 = d0 + dsub;
for (int l = d0; l < d1; l++) {
x[l] += w * tmp[l];
}
d0 = d1;
}
}
}
}
void ReconstructFromNeighbors2::reconstruct_n(storage_idx_t n0,
storage_idx_t ni,
float *x) const
{
#pragma omp parallel
{
std::vector<float> tmp(index.d);
#pragma omp for
for (storage_idx_t i = 0; i < ni; i++) {
reconstruct(n0 + i, x + i * index.d, tmp.data());
}
}
}
size_t ReconstructFromNeighbors2::compute_distances(
size_t n, const idx_t *shortlist,
const float *query, float *distances) const
{
std::vector<float> tmp(2 * index.d);
size_t ncomp = 0;
for (int i = 0; i < n; i++) {
if (shortlist[i] < 0) break;
reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
ncomp++;
}
return ncomp;
}
void ReconstructFromNeighbors2::get_neighbor_table(storage_idx_t i, float *tmp1) const
{
const RHNSW & hnsw = index.hnsw;
int *cur_links = hnsw.get_neighbor_link(i, 0);
int *cur_neighbors = cur_links + 1;
auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links);
size_t d = index.d;
index.storage->reconstruct(i, tmp1);
for (auto j = 0; j < cur_neighbor_num; ++ j) {
storage_idx_t ji = cur_neighbors[j];
if (ji < 0) ji = i;
index.storage->reconstruct(ji, tmp1 + (j + 1) * d);
}
}
/// called by add_codes
void ReconstructFromNeighbors2::estimate_code(
const float *x, storage_idx_t i, uint8_t *code) const
{
// fill in tmp table with the neighbor values
float *tmp1 = new float[d * (M + 1) + (d * k)];
float *tmp2 = tmp1 + d * (M + 1);
ScopeDeleter<float> del(tmp1);
// collect coordinates of base
get_neighbor_table (i, tmp1);
for (size_t sq = 0; sq < nsq; sq++) {
int d0 = sq * dsub;
{
FINTEGER ki = k, di = d, m1 = M + 1;
FINTEGER dsubi = dsub;
float zero = 0, one = 1;
sgemm_ ("N", "N", &dsubi, &ki, &m1, &one,
tmp1 + d0, &di,
codebook.data() + sq * (m1 * k), &m1,
&zero, tmp2, &dsubi);
}
float min = HUGE_VAL;
int argmin = -1;
for (size_t j = 0; j < k; j++) {
float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
if (dis < min) {
min = dis;
argmin = j;
}
}
code[sq] = argmin;
}
}
void ReconstructFromNeighbors2::add_codes(size_t n, const float *x)
{
if (k == 1) { // nothing to encode
ntotal += n;
return;
}
codes.resize(codes.size() + code_size * n);
#pragma omp parallel for
for (int i = 0; i < n; i++) {
estimate_code(x + i * index.d, ntotal + i,
codes.data() + (ntotal + i) * code_size);
}
ntotal += n;
FAISS_ASSERT (codes.size() == ntotal * code_size);
}
/**************************************************************
* IndexRHNSWFlat implementation
**************************************************************/
IndexRHNSWFlat::IndexRHNSWFlat()
{
is_trained = true;
}
IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, MetricType metric):
IndexRHNSW(new IndexFlat(d, metric), M)
{
own_fields = true;
is_trained = true;
}
size_t IndexRHNSWFlat::cal_size() {
return IndexRHNSW::cal_size() + dynamic_cast<IndexFlat*>(storage)->cal_size();
}
/**************************************************************
* IndexRHNSWPQ implementation
**************************************************************/
IndexRHNSWPQ::IndexRHNSWPQ() {}
IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M):
IndexRHNSW(new IndexPQ(d, pq_m, 8), M)
{
own_fields = true;
is_trained = false;
}
void IndexRHNSWPQ::train(idx_t n, const float* x)
{
IndexRHNSW::train (n, x);
(dynamic_cast<IndexPQ*> (storage))->pq.compute_sdc_table();
}
size_t IndexRHNSWPQ::cal_size() {
return IndexRHNSW::cal_size() + dynamic_cast<IndexPQ*>(storage)->cal_size();
}
/**************************************************************
* IndexRHNSWSQ implementation
**************************************************************/
IndexRHNSWSQ::IndexRHNSWSQ(int d, QuantizerType qtype, int M,
MetricType metric):
IndexRHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
{
is_trained = false;
own_fields = true;
}
IndexRHNSWSQ::IndexRHNSWSQ() {}
size_t IndexRHNSWSQ::cal_size() {
return IndexRHNSW::cal_size() + dynamic_cast<IndexScalarQuantizer *>(storage)->cal_size();
}
/**************************************************************
* IndexRHNSW2Level implementation
**************************************************************/
IndexRHNSW2Level::IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M):
IndexRHNSW (new Index2Layer (quantizer, nlist, m_pq), M)
{
own_fields = true;
is_trained = false;
}
IndexRHNSW2Level::IndexRHNSW2Level() {}
size_t IndexRHNSW2Level::cal_size() {
return IndexRHNSW::cal_size() + dynamic_cast<Index2Layer *>(storage)->cal_size();
}
namespace {
// same as search_from_candidates but uses v
// visno -> is in result list
// visno + 1 -> in result list + in candidates
int search_from_candidates_2(const RHNSW & hnsw,
DistanceComputer & qdis, int k,
idx_t *I, float * D,
MinimaxHeap &candidates,
VisitedList &vt,
int level, int nres_in = 0)
{
int nres = nres_in;
int ndis = 0;
for (int i = 0; i < candidates.size(); i++) {
idx_t v1 = candidates.ids[i];
FAISS_ASSERT(v1 >= 0);
vt.mass[v1] = vt.curV + 1;
}
int nstep = 0;
while (candidates.size() > 0) {
float d0 = 0;
int v0 = candidates.pop_min(&d0);
int *cur_links = hnsw.get_neighbor_link(v0, level);
int *cur_neighbors = cur_links + 1;
auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links);
for (auto j = 0; j < cur_neighbor_num; ++ j) {
int v1 = cur_neighbors[j];
if (v1 < 0) break;
if (vt.mass[v1] == vt.curV + 1) {
// nothing to do
} else {
ndis++;
float d = qdis(v1);
candidates.push(v1, d);
// never seen before --> add to heap
if (vt.mass[v1] < vt.curV) {
if (nres < k) {
faiss::maxheap_push (++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_pop (nres--, D, I);
faiss::maxheap_push (++nres, D, I, d, v1);
}
}
vt.mass[v1] = vt.curV + 1;
}
}
nstep++;
if (nstep > hnsw.efSearch) {
break;
}
}
if (level == 0) {
#pragma omp critical
{
rhnsw_stats.n1 ++;
if (candidates.size() == 0)
rhnsw_stats.n2 ++;
}
}
return nres;
}
} // namespace
void IndexRHNSW2Level::search (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const
{
if (dynamic_cast<const Index2Layer*>(storage)) {
IndexRHNSW::search (n, x, k, distances, labels);
} else { // "mixed" search
const IndexIVFPQ *index_ivfpq =
dynamic_cast<const IndexIVFPQ*>(storage);
int nprobe = index_ivfpq->nprobe;
std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(),
coarse_assign.get());
index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(),
coarse_dis.get(), distances, labels,
false);
#pragma omp parallel
{
VisitedList vt (ntotal);
DistanceComputer *dis = storage_distance_computer(storage);
ScopeDeleter1<DistanceComputer> del(dis);
int candidates_size = hnsw.upper_beam;
MinimaxHeap candidates(candidates_size);
#pragma omp for
for(idx_t i = 0; i < n; i++) {
idx_t * idxi = labels + i * k;
float * simi = distances + i * k;
dis->set_query(x + i * d);
// mark all inverted list elements as visited
for (int j = 0; j < nprobe; j++) {
idx_t key = coarse_assign[j + i * nprobe];
if (key < 0) break;
size_t list_length = index_ivfpq->get_list_size (key);
const idx_t * ids = index_ivfpq->invlists->get_ids (key);
for (int jj = 0; jj < list_length; jj++) {
vt.set (ids[jj]);
}
}
candidates.clear();
// copy the upper_beam elements to candidates list
int search_policy = 2;
if (search_policy == 1) {
for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
if (idxi[j] < 0) break;
candidates.push (idxi[j], simi[j]);
// search_from_candidates adds them back
idxi[j] = -1;
simi[j] = HUGE_VAL;
}
// reorder from sorted to heap
maxheap_heapify (k, simi, idxi, simi, idxi, k);
// removed from RHNSW, but still available in HNSW
// hnsw.search_from_candidates(
// *dis, k, idxi, simi,
// candidates, vt, 0, k
// );
vt.advance();
} else if (search_policy == 2) {
for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
if (idxi[j] < 0) break;
candidates.push (idxi[j], simi[j]);
}
// reorder from sorted to heap
maxheap_heapify (k, simi, idxi, simi, idxi, k);
search_from_candidates_2 (
hnsw, *dis, k, idxi, simi,
candidates, vt, 0, k);
vt.advance ();
vt.advance ();
}
maxheap_reorder (k, simi, idxi);
}
}
}
}
void IndexRHNSW2Level::flip_to_ivf ()
{
Index2Layer *storage2l =
dynamic_cast<Index2Layer*>(storage);
FAISS_THROW_IF_NOT (storage2l);
IndexIVFPQ * index_ivfpq =
new IndexIVFPQ (storage2l->q1.quantizer,
d, storage2l->q1.nlist,
storage2l->pq.M, 8);
index_ivfpq->pq = storage2l->pq;
index_ivfpq->is_trained = storage2l->is_trained;
index_ivfpq->precompute_table();
index_ivfpq->own_fields = storage2l->q1.own_fields;
storage2l->transfer_to_IVFPQ(*index_ivfpq);
index_ivfpq->make_direct_map (true);
storage = index_ivfpq;
delete storage2l;
}
} // namespace faiss

View File

@ -0,0 +1,152 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <vector>
#include <faiss/impl/RHNSW.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexPQ.h>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/utils/utils.h>
//#include <faiss/IndexHNSW.h>
namespace faiss {
struct IndexRHNSW;
struct ReconstructFromNeighbors2 {
typedef Index::idx_t idx_t;
typedef RHNSW::storage_idx_t storage_idx_t;
const IndexRHNSW & index;
size_t M; // number of neighbors
size_t k; // number of codebook entries
size_t nsq; // number of subvectors
size_t code_size;
int k_reorder; // nb to reorder. -1 = all
std::vector<float> codebook; // size nsq * k * (M + 1)
std::vector<uint8_t> codes; // size ntotal * code_size
size_t ntotal;
size_t d, dsub; // derived values
explicit ReconstructFromNeighbors2(const IndexRHNSW& index,
size_t k=256, size_t nsq=1);
/// codes must be added in the correct order and the IndexRHNSW
/// must be populated and sorted
void add_codes(size_t n, const float *x);
size_t compute_distances(size_t n, const idx_t *shortlist,
const float *query, float *distances) const;
/// called by add_codes
void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
/// called by compute_distances
void reconstruct(storage_idx_t i, float *x, float *tmp) const;
void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
/// get the M+1 -by-d table for neighbor coordinates for vector i
void get_neighbor_table(storage_idx_t i, float *out) const;
};
/** The HNSW index is a normal random-access index with a HNSW
* link structure built on top */
struct IndexRHNSW : Index {
typedef RHNSW::storage_idx_t storage_idx_t;
// the link strcuture
RHNSW hnsw;
// the sequential storage
bool own_fields;
Index *storage;
ReconstructFromNeighbors2 *reconstruct_from_neighbors;
explicit IndexRHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
explicit IndexRHNSW (Index *storage, int M = 32);
~IndexRHNSW() override;
void add(idx_t n, const float *x) override;
/// Trains the storage if needed
void train(idx_t n, const float* x) override;
/// entry point for search
void search (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
void reconstruct(idx_t key, float* recons) const override;
void reset () override;
size_t cal_size();
void init_hnsw();
};
/** Flat index topped with with a HNSW structure to access elements
* more efficiently.
*/
struct IndexRHNSWFlat : IndexRHNSW {
IndexRHNSWFlat();
IndexRHNSWFlat(int d, int M, MetricType metric = METRIC_L2);
size_t cal_size();
};
/** PQ index topped with with a HNSW structure to access elements
* more efficiently.
*/
struct IndexRHNSWPQ : IndexRHNSW {
IndexRHNSWPQ();
IndexRHNSWPQ(int d, int pq_m, int M);
void train(idx_t n, const float* x) override;
size_t cal_size();
};
/** SQ index topped with with a HNSW structure to access elements
* more efficiently.
*/
struct IndexRHNSWSQ : IndexRHNSW {
IndexRHNSWSQ();
IndexRHNSWSQ(int d, QuantizerType qtype, int M, MetricType metric = METRIC_L2);
size_t cal_size();
};
/** 2-level code structure with fast random access
*/
struct IndexRHNSW2Level : IndexRHNSW {
IndexRHNSW2Level();
IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
void flip_to_ivf();
/// entry point for search
void search (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
size_t cal_size();
};
} // namespace faiss

View File

@ -78,6 +78,7 @@ struct IndexScalarQuantizer: Index {
void sa_decode (idx_t n, const uint8_t *bytes,
float *x) const override;
size_t cal_size() { return codes.size() * sizeof(uint8_t) + sizeof(size_t) + sq.cal_size(); }
};

View File

@ -1,3 +1,3 @@
#./configure CPUFLAGS='-mavx -mf16c -msse4 -mpopcnt' CXXFLAGS='-O0 -g -fPIC -m64 -Wno-sign-compare -Wall -Wextra' --prefix=$PWD --with-cuda-arch=-gencode=arch=compute_75,code=sm_75 --with-cuda=/usr/local/cuda
./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75'
./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O3 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3 -DNDEBUG' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75'
make install -j8

View File

@ -173,6 +173,7 @@ struct ProductQuantizer {
float_maxheap_array_t * res,
bool init_finalize_heap = true) const;
size_t cal_size() { return sizeof(*this) + centroids.size() * sizeof(float); }
};

View File

@ -0,0 +1,441 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/impl/RHNSW.h>
#include <string>
#include <faiss/impl/AuxIndexStructures.h>
namespace faiss {
/**************************************************************
* hnsw structure implementation
**************************************************************/
RHNSW::RHNSW(int M) : M(M), rng(12345) {
level_generator.seed(100);
max_level = -1;
entry_point = -1;
efSearch = 16;
efConstruction = 40;
upper_beam = 1;
level0_link_size = sizeof(int) * ((M << 1) | 1);
link_size = sizeof(int) * (M + 1);
level0_links = nullptr;
linkLists = nullptr;
level_constant = 1 / log(1.0 * M);
visited_list_pool = nullptr;
}
void RHNSW::init(int ntotal) {
level_generator.seed(100);
if (visited_list_pool) delete visited_list_pool;
visited_list_pool = new VisitedListPool(1, ntotal);
std::vector<std::mutex>(ntotal).swap(link_list_locks);
}
RHNSW::~RHNSW() {
free(level0_links);
for (auto i = 0; i < levels.size(); ++ i) {
if (levels[i])
free(linkLists[i]);
}
free(linkLists);
delete visited_list_pool;
}
void RHNSW::reset() {
max_level = -1;
entry_point = -1;
levels.clear();
free(level0_links);
for (auto i = 0; i < levels.size(); ++ i) {
if (levels[i])
free(linkLists[i]);
}
free(linkLists);
level0_links = nullptr;
linkLists = nullptr;
level_constant = 1 / log(1.0 * M);
}
int RHNSW::prepare_level_tab(size_t n, bool preset_levels)
{
size_t n0 = levels.size();
std::vector<int> level_stats(n);
if (preset_levels) {
FAISS_ASSERT (n0 + n == levels.size());
} else {
FAISS_ASSERT (n0 == levels.size());
for (int i = 0; i < n; i++) {
int pt_level = random_level(level_constant);
levels.push_back(pt_level);
}
}
char *level0_links_new = (char*)malloc((n0 + n) * level0_link_size);
if (level0_links_new == nullptr) {
throw std::runtime_error("No enough memory 4 level0_links!");
}
memset(level0_links_new, 0, (n0 + n) * level0_link_size);
if (level0_links) {
memcpy(level0_links_new, level0_links, n0 * level0_link_size);
free(level0_links);
}
level0_links = level0_links_new;
char **linkLists_new = (char **)malloc(sizeof(void*) * (n0 + n));
if (linkLists_new == nullptr) {
throw std::runtime_error("No enough memory 4 level0_links_new!");
}
if (linkLists) {
memcpy(linkLists_new, linkLists, n0 * sizeof(void*));
free(linkLists);
}
linkLists = linkLists_new;
int max_level = 0;
int debug_space = 0;
for (int i = 0; i < n; i++) {
int pt_level = levels[i + n0];
if (pt_level > max_level) max_level = pt_level;
if (pt_level) {
linkLists[n0 + i] = (char*) malloc(link_size * pt_level + 1);
if (linkLists[n0 + i] == nullptr) {
throw std::runtime_error("No enough memory 4 linkLists!");
}
memset(linkLists[n0 + i], 0, link_size * pt_level + 1);
}
if (max_level >= level_stats.size()) {
level_stats.resize(max_level + 1);
}
level_stats[pt_level] ++;
}
// printf("level stats:\n");
// for (int i = 0; i <= max_level; ++ i)
// printf("level %d: %d points\n", i, level_stats[i]);
// printf("\n");
std::vector<std::mutex>(n0 + n).swap(link_list_locks);
if (visited_list_pool) delete visited_list_pool;
visited_list_pool = new VisitedListPool(1, n0 + n);
return max_level;
}
/**************************************************************
* new implementation of hnsw ispired by hnswlib
* by cmli@zilliz July 30, 2020
**************************************************************/
using Node = faiss::RHNSW::Node;
using CompareByFirst = faiss::RHNSW::CompareByFirst;
void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) {
std::unique_lock<std::mutex> lock_el(link_list_locks[pt_id]);
std::unique_lock<std::mutex> temp_lock(global);
int maxlevel_copy = max_level;
if (pt_level <= maxlevel_copy)
temp_lock.unlock();
int currObj = entry_point;
int ep_copy = entry_point;
if (currObj != -1) {
if (pt_level < maxlevel_copy) {
float curdist = ptdis(currObj);
for (int lev = maxlevel_copy; lev > pt_level; lev --) {
bool changed = true;
while (changed) {
changed = false;
std::unique_lock<std::mutex> lk(link_list_locks[currObj]);
int *curObj_link = get_neighbor_link(currObj, lev);
auto curObj_nei_num = get_neighbors_num(curObj_link);
for (auto i = 1; i <= curObj_nei_num; ++ i) {
int cand = curObj_link[i];
if (cand < 0 || cand > levels.size())
throw std::runtime_error("cand error when addPoint");
float d = ptdis(cand);
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}
}
for (int lev = std::min(pt_level, maxlevel_copy); lev >= 0; -- lev) {
if (lev > maxlevel_copy || lev < 0)
throw std::runtime_error("Level error");
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_layer(ptdis, pt_id, currObj, lev);
currObj = top_candidates.top().second;
make_connection(ptdis, pt_id, top_candidates, lev);
}
} else {
entry_point = 0;
max_level = pt_level;
}
if (pt_level > maxlevel_copy) {
entry_point = pt_id;
max_level = pt_level;
}
}
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
RHNSW::search_layer(DistanceComputer& ptdis,
storage_idx_t pt_id,
storage_idx_t nearest,
int level) {
VisitedList *vl = visited_list_pool->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates;
std::priority_queue<Node, std::vector<Node>, CompareByFirst> candidate_set;
float d_nearest = ptdis(nearest);
float lb = d_nearest;
top_candidates.emplace(d_nearest, nearest);
candidate_set.emplace(-d_nearest, nearest);
visited_array[nearest] = visited_array_tag;
while (!candidate_set.empty()) {
Node currNode = candidate_set.top();
if ((-currNode.first) > lb)
break;
candidate_set.pop();
int cur_id = currNode.second;
std::unique_lock<std::mutex> lk(link_list_locks[cur_id]);
int *cur_link = get_neighbor_link(cur_id, level);
auto cur_neighbor_num = get_neighbors_num(cur_link);
for (auto i = 1; i <= cur_neighbor_num; ++ i) {
int candidate_id = cur_link[i];
if (visited_array[candidate_id] == visited_array_tag) continue;
visited_array[candidate_id] = visited_array_tag;
float dcand = ptdis(candidate_id);
if (top_candidates.size() < efConstruction || lb > dcand) {
candidate_set.emplace(-dcand, candidate_id);
top_candidates.emplace(dcand, candidate_id);
if (top_candidates.size() > efConstruction)
top_candidates.pop();
if (!top_candidates.empty())
lb = top_candidates.top().first;
}
}
}
visited_list_pool->releaseVisitedList(vl);
return top_candidates;
}
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
RHNSW::search_base_layer(DistanceComputer& ptdis,
storage_idx_t nearest,
storage_idx_t ef,
float d_nearest,
ConcurrentBitsetPtr bitset) const {
VisitedList *vl = visited_list_pool->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates;
std::priority_queue<Node, std::vector<Node>, CompareByFirst> candidate_set;
float lb;
if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(nearest))) {
lb = d_nearest;
top_candidates.emplace(d_nearest, nearest);
candidate_set.emplace(-d_nearest, nearest);
} else {
lb = std::numeric_limits<float>::max();
candidate_set.emplace(-lb, nearest);
}
visited_array[nearest] = visited_array_tag;
while (!candidate_set.empty()) {
Node currNode = candidate_set.top();
if ((-currNode.first) > lb)
break;
candidate_set.pop();
int cur_id = currNode.second;
int *cur_link = get_neighbor_link(cur_id, 0);
auto cur_neighbor_num = get_neighbors_num(cur_link);
for (auto i = 1; i <= cur_neighbor_num; ++ i) {
int candidate_id = cur_link[i];
if (visited_array[candidate_id] != visited_array_tag) {
visited_array[candidate_id] = visited_array_tag;
float dcand = ptdis(candidate_id);
if (top_candidates.size() < ef || lb > dcand) {
candidate_set.emplace(-dcand, candidate_id);
if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(candidate_id)))
top_candidates.emplace(dcand, candidate_id);
if (top_candidates.size() > ef)
top_candidates.pop();
if (!top_candidates.empty())
lb = top_candidates.top().first;
}
}
}
}
visited_list_pool->releaseVisitedList(vl);
return top_candidates;
}
void
RHNSW::make_connection(DistanceComputer& ptdis,
storage_idx_t pt_id,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
int level) {
int maxM = level ? M : M << 1;
int *selectedNeighbors = (int*)malloc(sizeof(int) * maxM);
int selectedNeighborsNum = 0;
prune_neighbors(ptdis, cand, maxM, selectedNeighbors, selectedNeighborsNum);
if (selectedNeighborsNum > maxM)
throw std::runtime_error("Wrong size of candidates returned by prune_neighbors!");
int *cur_link = get_neighbor_link(pt_id, level);
if (*cur_link)
throw std::runtime_error("The newly inserted element should have blank link");
set_neighbors_num(cur_link, selectedNeighborsNum);
for (auto i = 1; i <= selectedNeighborsNum; ++ i) {
if (cur_link[i])
throw std::runtime_error("Possible memory corruption.");
if (level > levels[selectedNeighbors[i - 1]])
throw std::runtime_error("Trying to make a link on a non-exisitent level.");
cur_link[i] = selectedNeighbors[i - 1];
}
for (auto i = 0; i < selectedNeighborsNum; ++ i) {
std::unique_lock<std::mutex> lk(link_list_locks[selectedNeighbors[i]]);
int *selected_link = get_neighbor_link(selectedNeighbors[i], level);
auto selected_neighbor_num = get_neighbors_num(selected_link);
if (selected_neighbor_num > maxM)
throw std::runtime_error("Bad value of selected_neighbor_num.");
if (selectedNeighbors[i] == pt_id)
throw std::runtime_error("Trying to connect an element to itself.");
if (level > levels[selectedNeighbors[i]])
throw std::runtime_error("Trying to make a link on a non-exisitent level.");
if (selected_neighbor_num < maxM) {
selected_link[selected_neighbor_num + 1] = pt_id;
set_neighbors_num(selected_link, selected_neighbor_num + 1);
} else {
double d_max = ptdis(selectedNeighbors[i]);
std::priority_queue<Node, std::vector<Node>, CompareByFirst> candi;
candi.emplace(d_max, pt_id);
for (auto j = 1; j <= selected_neighbor_num; ++ j)
candi.emplace(ptdis.symmetric_dis(selectedNeighbors[i], selected_link[j]), selected_link[j]);
int indx = 0;
prune_neighbors(ptdis, candi, maxM, selected_link + 1, indx);
set_neighbors_num(selected_link, indx);
}
}
free(selectedNeighbors);
}
void RHNSW::prune_neighbors(DistanceComputer& ptdis,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
const int maxM, int *ret, int &ret_len) {
if (cand.size() < maxM) {
while (!cand.empty()) {
ret[ret_len ++] = cand.top().second;
cand.pop();
}
return;
}
std::priority_queue<Node> closest;
while (!cand.empty()) {
closest.emplace(-cand.top().first, cand.top().second);
cand.pop();
}
while (closest.size()) {
if (ret_len >= maxM)
break;
Node curr = closest.top();
float dist_to_query = -curr.first;
closest.pop();
bool good = true;
for (auto i = 0; i < ret_len; ++ i) {
float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]);
if (cur_dist < dist_to_query) {
good = false;
break;
}
}
if (good) {
ret[ret_len ++] = curr.second;
}
}
}
void RHNSW::searchKnn(DistanceComputer& qdis, int k,
idx_t *I, float *D,
ConcurrentBitsetPtr bitset) const {
if (levels.size() == 0)
return;
int ep = entry_point;
float dist = qdis(ep);
for (auto i = max_level; i > 0; -- i) {
bool good = true;
while (good) {
good = false;
int *ep_link = get_neighbor_link(ep, i);
auto ep_neighbors_cnt = get_neighbors_num(ep_link);
for (auto j = 1; j <= ep_neighbors_cnt; ++ j) {
int cand = ep_link[j];
if (cand < 0 || cand > levels.size())
throw std::runtime_error("cand error");
float d = qdis(cand);
if (d < dist) {
dist = d;
ep = cand;
good = true;
}
}
}
}
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_base_layer(qdis, ep, std::max(efSearch, k), dist, bitset);
while (top_candidates.size() > k)
top_candidates.pop();
int i = 0;
while (!top_candidates.empty()) {
I[i] = top_candidates.top().second;
D[i] = top_candidates.top().first;
i ++;
top_candidates.pop();
}
}
size_t RHNSW::cal_size() {
size_t ret = 0;
ret += sizeof(*this);
ret += visited_list_pool->GetSize();
ret += link_list_locks.size() * sizeof(std::mutex);
ret += levels.size() * sizeof(int);
ret += levels.size() * level0_link_size;
ret += levels.size() * sizeof(void*);
for (auto i = 0; i < levels.size(); ++ i) {
ret += levels[i] ? link_size * levels[i] : 0;
}
return ret;
}
} // namespace faiss

View File

@ -0,0 +1,367 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <vector>
#include <mutex>
#include <unordered_set>
#include <queue>
#include <omp.h>
#include <faiss/Index.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/random.h>
#include <faiss/utils/Heap.h>
namespace faiss {
/** Implementation of the Hierarchical Navigable Small World
* datastructure.
*
* Efficient and robust approximate nearest neighbor search using
* Hierarchical Navigable Small World graphs
*
* Yu. A. Malkov, D. A. Yashunin, arXiv 2017
*
* This implmentation is heavily influenced by the hnswlib
* implementation by Yury Malkov and Leonid Boystov
* (https://github.com/searchivarius/nmslib/hnswlib)
*
* The HNSW object stores only the neighbor link structure, see
* IndexHNSW.h for the full index object.
*/
struct DistanceComputer; // from AuxIndexStructures
class VisitedListPool;
struct RHNSW {
/// internal storage of vectors (32 bits: this is expensive)
typedef int storage_idx_t;
/// Faiss results are 64-bit
typedef Index::idx_t idx_t;
typedef std::pair<float, storage_idx_t> Node;
/** Heap structure that allows fast
*/
struct MinimaxHeap {
int n;
int k;
int nvalid;
std::vector<storage_idx_t> ids;
std::vector<float> dis;
typedef faiss::CMax<float, storage_idx_t> HC;
explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {}
void push(storage_idx_t i, float v) {
if (k == n) {
if (v >= dis[0]) return;
faiss::heap_pop<HC> (k--, dis.data(), ids.data());
--nvalid;
}
faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
++nvalid;
}
float max() const {
return dis[0];
}
int size() const {
return nvalid;
}
void clear() {
nvalid = k = 0;
}
int pop_min(float *vmin_out = nullptr) {
assert(k > 0);
// returns min. This is an O(n) operation
int i = k - 1;
while (i >= 0) {
if (ids[i] != -1) break;
i--;
}
if (i == -1) return -1;
int imin = i;
float vmin = dis[i];
i--;
while(i >= 0) {
if (ids[i] != -1 && dis[i] < vmin) {
vmin = dis[i];
imin = i;
}
i--;
}
if (vmin_out) *vmin_out = vmin;
int ret = ids[imin];
ids[imin] = -1;
--nvalid;
return ret;
}
int count_below(float thresh) {
int n_below = 0;
for(int i = 0; i < k; i++) {
if (dis[i] < thresh) {
n_below++;
}
}
return n_below;
}
};
/// to sort pairs of (id, distance) from nearest to fathest or the reverse
struct NodeDistCloser {
float d;
int id;
NodeDistCloser(float d, int id): d(d), id(id) {}
bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; }
};
struct NodeDistFarther {
float d;
int id;
NodeDistFarther(float d, int id): d(d), id(id) {}
bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; }
};
struct CompareByFirst {
constexpr bool operator()(Node const &a,
Node const &b) const noexcept {
return a.first < b.first;
}
};
/// level of each vector (base level = 1), size = ntotal
std::vector<int> levels;
/// number of entry points in levels > 0.
int upper_beam;
/// entry point in the search structure (one of the points with maximum level
storage_idx_t entry_point;
faiss::RandomGenerator rng;
std::default_random_engine level_generator;
/// maximum level
int max_level;
int M;
char *level0_links;
char **linkLists;
size_t level0_link_size;
size_t link_size;
double level_constant;
VisitedListPool *visited_list_pool;
std::vector<std::mutex> link_list_locks;
std::mutex global;
/// expansion factor at construction time
int efConstruction;
/// expansion factor at search time
int efSearch;
/// range of entries in the neighbors table of vertex no at layer_no
storage_idx_t* get_neighbor_link(idx_t no, int layer_no) const {
return layer_no == 0 ? (int*)(level0_links + no * level0_link_size) : (int*)(linkLists[no] + (layer_no - 1) * link_size);
}
unsigned short int get_neighbors_num(int *p) const {
return *((unsigned short int*)p);
}
void set_neighbors_num(int *p, unsigned short int num) const {
*((unsigned short int*)(p)) = *((unsigned short int *)(&num));
}
/// only mandatory parameter: nb of neighbors
explicit RHNSW(int M = 32);
~RHNSW();
void init(int ntotal);
/// pick a random level for a new point, arg = 1/log(M)
int random_level(double arg) {
std::uniform_real_distribution<double> distribution(0.0, 1.0);
double r = -log(distribution(level_generator)) * arg;
return (int)r;
}
void reset();
int prepare_level_tab(size_t n, bool preset_levels = false);
// re-implementations inspired by hnswlib
/** add point pt_id on all levels <= pt_level and build the link
* structure for them. inspired by implementation of hnswlib */
void addPoint(DistanceComputer& ptdis, int pt_level, int pt_id);
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
search_layer (DistanceComputer& ptdis,
storage_idx_t pt_id,
storage_idx_t nearest,
int level);
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
search_base_layer (DistanceComputer& ptdis,
storage_idx_t nearest,
storage_idx_t ef,
float d_nearest,
ConcurrentBitsetPtr bitset = nullptr) const;
void make_connection(DistanceComputer& ptdis,
storage_idx_t pt_id,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
int level);
void prune_neighbors(DistanceComputer& ptdis,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
const int maxM, int *ret, int &ret_len);
/// search interface inspired by hnswlib
void searchKnn(DistanceComputer& qdis, int k,
idx_t *I, float *D,
ConcurrentBitsetPtr bitset = nullptr) const;
size_t cal_size();
};
/**************************************************************
* Auxiliary structures
**************************************************************/
typedef unsigned short int vl_type;
class VisitedList {
public:
vl_type curV;
vl_type *mass;
unsigned int numelements;
VisitedList(int numelements1) {
curV = -1;
numelements = numelements1;
mass = new vl_type[numelements];
}
void reset() {
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
curV++;
}
};
// keep compatibae with original version VisitedTable
/// set flog #no to true
void set(int no) {
mass[no] = curV;
}
/// get flag #no
bool get(int no) const {
return mass[no] == curV;
}
void advance() {
reset();
}
~VisitedList() { delete[] mass; }
};
///////////////////////////////////////////////////////////
//
// Class for multi-threaded pool-management of VisitedLists
//
/////////////////////////////////////////////////////////
class VisitedListPool {
std::deque<VisitedList *> pool;
std::mutex poolguard;
int numelements;
public:
VisitedListPool(int initmaxpools, int numelements1) {
numelements = numelements1;
for (int i = 0; i < initmaxpools; i++)
pool.push_front(new VisitedList(numelements));
}
VisitedList *getFreeVisitedList() {
VisitedList *rez;
{
std::unique_lock <std::mutex> lock(poolguard);
if (pool.size() > 0) {
rez = pool.front();
pool.pop_front();
} else {
rez = new VisitedList(numelements);
}
}
rez->reset();
return rez;
};
void releaseVisitedList(VisitedList *vl) {
std::unique_lock <std::mutex> lock(poolguard);
pool.push_front(vl);
};
~VisitedListPool() {
while (pool.size()) {
VisitedList *rez = pool.front();
pool.pop_front();
delete rez;
}
};
int64_t GetSize() {
auto visit_list_size = sizeof(VisitedList) + numelements * sizeof(vl_type);
auto pool_size = pool.size() * (sizeof(VisitedList *) + visit_list_size);
return pool_size + sizeof(*this);
}
};
struct RHNSWStats {
size_t n1, n2, n3;
size_t ndis;
size_t nreorder;
bool view;
RHNSWStats() {
reset();
}
void reset() {
n1 = n2 = n3 = 0;
ndis = 0;
nreorder = 0;
view = false;
}
};
// global var that collects them all
extern RHNSWStats rhnsw_stats;
} // namespace faiss

View File

@ -75,6 +75,7 @@ struct ScalarQuantizer {
(MetricType mt, const Index *quantizer, bool store_pairs,
bool by_residual=false) const;
size_t cal_size() { return sizeof(*this) + trained.size() * sizeof(float); }
};
template<class DCClass>

View File

@ -36,6 +36,7 @@
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/IndexSQHybrid.h>
#include <faiss/IndexHNSW.h>
#include <faiss/IndexRHNSW.h>
#include <faiss/IndexLattice.h>
#include <faiss/OnDiskInvertedLists.h>
@ -458,6 +459,29 @@ static void read_HNSW (HNSW *hnsw, IOReader *f) {
READ1 (hnsw->upper_beam);
}
static void read_RHNSW (RHNSW *rhnsw, IOReader *f) {
READ1 (rhnsw->entry_point);
READ1 (rhnsw->max_level);
READ1 (rhnsw->M);
READ1 (rhnsw->level0_link_size);
READ1 (rhnsw->link_size);
READ1 (rhnsw->level_constant);
READ1 (rhnsw->efConstruction);
READ1 (rhnsw->efSearch);
READVECTOR (rhnsw->levels);
auto ntotal = rhnsw->levels.size();
rhnsw->level0_links = (char*) malloc(ntotal * rhnsw->level0_link_size);
READANDCHECK( rhnsw->level0_links, ntotal * rhnsw->level0_link_size);
rhnsw->linkLists = (char**) malloc(ntotal * sizeof(void*));
for (auto i = 0; i < ntotal; ++ i) {
if (rhnsw->levels[i]) {
rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i] + 1);
READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
}
}
}
ProductQuantizer * read_ProductQuantizer (const char*fname) {
FileIOReader reader(fname);
return read_ProductQuantizer(&reader);
@ -624,6 +648,9 @@ Index *read_index (IOReader *f, int io_flags) {
if (h == fourcc ("IxPQ") || h == fourcc ("IxPo")) {
idxp->metric_type = METRIC_L2;
}
if (h == fourcc("IxPq")) {
idxp->pq.compute_sdc_table ();
}
idx = idxp;
} else if (h == fourcc ("IvFl") || h == fourcc("IvFL")) { // legacy
IndexIVFFlat * ivfl = new IndexIVFFlat ();
@ -800,6 +827,17 @@ Index *read_index (IOReader *f, int io_flags) {
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table ();
}
idx = idxhnsw;
} else if(h == fourcc("IRHf") || h == fourcc("IRHp") ||
h == fourcc("IRHs") || h == fourcc("IRH2")) {
IndexRHNSW *idxrhnsw = nullptr;
if (h == fourcc("IRHf")) idxrhnsw = new IndexRHNSWFlat ();
if (h == fourcc("IRHp")) idxrhnsw = new IndexRHNSWPQ ();
if (h == fourcc("IRHs")) idxrhnsw = new IndexRHNSWSQ ();
if (h == fourcc("IRH2")) idxrhnsw = new IndexRHNSW2Level ();
read_index_header (idxrhnsw, f);
read_RHNSW (&idxrhnsw->hnsw, f);
idxrhnsw->own_fields = true;
idx = idxrhnsw;
} else {
FAISS_THROW_FMT("Index type 0x%08x not supported\n", h);
idx = nullptr;

View File

@ -36,6 +36,7 @@
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/IndexSQHybrid.h>
#include <faiss/IndexHNSW.h>
#include <faiss/IndexRHNSW.h>
#include <faiss/IndexLattice.h>
#include <faiss/OnDiskInvertedLists.h>
@ -364,6 +365,24 @@ static void write_HNSW (const HNSW *hnsw, IOWriter *f) {
WRITE1 (hnsw->upper_beam);
}
static void write_RHNSW (const RHNSW *rhnsw, IOWriter *f) {
WRITE1 (rhnsw->entry_point);
WRITE1 (rhnsw->max_level);
WRITE1 (rhnsw->M);
WRITE1 (rhnsw->level0_link_size);
WRITE1 (rhnsw->link_size);
WRITE1 (rhnsw->level_constant);
WRITE1 (rhnsw->efConstruction);
WRITE1 (rhnsw->efSearch);
WRITEVECTOR (rhnsw->levels);
WRITEANDCHECK (rhnsw->level0_links, rhnsw->level0_link_size * rhnsw->levels.size());
for (auto i = 0; i < rhnsw->levels.size(); ++ i) {
if (rhnsw->levels[i])
WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
}
}
static void write_direct_map (const DirectMap *dm, IOWriter *f) {
char maintain_direct_map = (char)dm->type; // for backwards compatibility with bool
WRITE1 (maintain_direct_map);
@ -560,6 +579,18 @@ void write_index (const Index *idx, IOWriter *f) {
write_index_header (idxhnsw, f);
write_HNSW (&idxhnsw->hnsw, f);
write_index (idxhnsw->storage, f);
} else if (const IndexRHNSW * idxrhnsw =
dynamic_cast<const IndexRHNSW *>(idx)) {
uint32_t h =
dynamic_cast<const IndexRHNSWFlat*>(idx) ? fourcc("IRHf") :
dynamic_cast<const IndexRHNSWPQ*>(idx) ? fourcc("IRHp") :
dynamic_cast<const IndexRHNSWSQ*>(idx) ? fourcc("IRHs") :
dynamic_cast<const IndexRHNSW2Level*>(idx) ? fourcc("IRH2") :
0;
FAISS_THROW_IF_NOT (h != 0);
WRITE1 (h);
write_index_header (idxrhnsw, f);
write_RHNSW (&idxrhnsw->hnsw, f);
} else {
FAISS_THROW_MSG ("don't know how to serialize this type of index");
}

View File

@ -1224,4 +1224,4 @@ namespace hnswlib_nm {
}
};
}
}

View File

@ -219,6 +219,42 @@ endif ()
target_link_libraries(test_hnsw_sq8nm ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_hnsw_sq8nm DESTINATION unittest)
################################################################################
#<RHNSW_FLAT-TEST>
set(rhnsw_flat_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp
)
if (NOT TARGET test_rhnsw_flat)
add_executable(test_rhnsw_flat test_rhnsw_flat.cpp ${rhnsw_flat_srcs} ${util_srcs} ${faiss_srcs})
endif ()
target_link_libraries(test_rhnsw_flat ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_rhnsw_flat DESTINATION unittest)
################################################################################
#<RHNSW_PQ-TEST>
set(rhnsw_pq_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp
)
if (NOT TARGET test_rhnsw_pq)
add_executable(test_rhnsw_pq test_rhnsw_pq.cpp ${rhnsw_pq_srcs} ${util_srcs} ${faiss_srcs})
endif ()
target_link_libraries(test_rhnsw_pq ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_rhnsw_pq DESTINATION unittest)
################################################################################
#<RHNSW_SQ8-TEST>
set(rhnsw_sq8_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp
)
if (NOT TARGET test_rhnsw_sq8)
add_executable(test_rhnsw_sq8 test_rhnsw_sq8.cpp ${rhnsw_sq8_srcs} ${util_srcs} ${faiss_srcs})
endif ()
target_link_libraries(test_rhnsw_sq8 ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_rhnsw_sq8 DESTINATION unittest)
################################################################################
#<SPTAG-TEST>
if (MILVUS_SUPPORT_SPTAG)

View File

@ -0,0 +1,158 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <knowhere/index/vector_index/IndexRHNSWFlat.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <random>
#include "knowhere/common/Exception.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class RHNSWFlatTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10
index_ = std::make_shared<milvus::knowhere::IndexRHNSWFlat>();
conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
}
protected:
milvus::knowhere::Config conf;
std::shared_ptr<milvus::knowhere::IndexRHNSWFlat> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWFlatTest, Values("RHNSWFlat"));
TEST_P(RHNSWFlatTest, HNSW_basic) {
assert(!xb.empty());
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result1 = index_->Query(query_dataset, conf);
AssertAnns(result1, nq, k);
// Serialize and Load before Query
milvus::knowhere::BinarySet bs = index_->Serialize();
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWFlat>();
tmp_index->Load(bs);
auto result2 = tmp_index->Query(query_dataset, conf);
AssertAnns(result2, nq, k);
}
TEST_P(RHNSWFlatTest, HNSW_delete) {
assert(!xb.empty());
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++i) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
* delete result checked by eyes
auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
std::cout << std::endl;
for (int i = 0; i < nq; ++ i) {
std::cout << "ids1: ";
for (int j = 0; j < k; ++ j) {
std::cout << *(ids1 + i * k + j) << " ";
}
std::cout << "ids2: ";
for (int j = 0; j < k; ++ j) {
std::cout << *(ids2 + i * k + j) << " ";
}
std::cout << std::endl;
for (int j = 0; j < std::min(5, k>>1); ++ j) {
ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
}
}
*/
}
TEST_P(RHNSWFlatTest, HNSW_serialize) {
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
{
FileIOWriter writer(filename);
writer(static_cast<void*>(bin->data.get()), bin->size);
}
FileIOReader reader(filename);
reader(ret, bin->size);
};
{
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
auto binaryset = index_->Serialize();
std::string index_type = index_->index_type();
std::string idx_name = index_type + "_Index";
std::string dat_name = index_type + "_Data";
if (binaryset.binary_map_.find(idx_name) == binaryset.binary_map_.end()) {
std::cout << "no idx!" << std::endl;
}
if (binaryset.binary_map_.find(dat_name) == binaryset.binary_map_.end()) {
std::cout << "no dat!" << std::endl;
}
auto bin_idx = binaryset.GetByName(idx_name);
auto bin_dat = binaryset.GetByName(dat_name);
std::string filename_idx = "/tmp/RHNSWFlat_test_serialize_idx.bin";
std::string filename_dat = "/tmp/RHNSWFlat_test_serialize_dat.bin";
auto load_idx = new uint8_t[bin_idx->size];
auto load_dat = new uint8_t[bin_dat->size];
serialize(filename_idx, bin_idx, load_idx);
serialize(filename_dat, bin_dat, load_dat);
binaryset.clear();
auto new_idx = std::make_shared<milvus::knowhere::IndexRHNSWFlat>();
std::shared_ptr<uint8_t[]> dat(load_dat);
std::shared_ptr<uint8_t[]> idx(load_idx);
binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size);
binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size);
new_idx->Load(binaryset);
EXPECT_EQ(new_idx->Count(), nb);
EXPECT_EQ(new_idx->Dim(), dim);
auto result = new_idx->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -0,0 +1,148 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <knowhere/index/vector_index/IndexRHNSWPQ.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <random>
#include "knowhere/common/Exception.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class RHNSWPQTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10
index_ = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::IndexParams::PQM, 8}};
}
protected:
milvus::knowhere::Config conf;
std::shared_ptr<milvus::knowhere::IndexRHNSWPQ> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWPQTest, Values("RHNSWPQ"));
TEST_P(RHNSWPQTest, HNSW_basic) {
assert(!xb.empty());
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
// Serialize and Load before Query
milvus::knowhere::BinarySet bs = index_->Serialize();
auto result1 = index_->Query(query_dataset, conf);
// AssertAnns(result1, nq, k);
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
tmp_index->Load(bs);
auto result2 = tmp_index->Query(query_dataset, conf);
// AssertAnns(result2, nq, k);
}
TEST_P(RHNSWPQTest, HNSW_delete) {
assert(!xb.empty());
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++i) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
// AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
* delete result checked by eyes
auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
std::cout << std::endl;
for (int i = 0; i < nq; ++ i) {
std::cout << "ids1: ";
for (int j = 0; j < k; ++ j) {
std::cout << *(ids1 + i * k + j) << " ";
}
std::cout << "ids2: ";
for (int j = 0; j < k; ++ j) {
std::cout << *(ids2 + i * k + j) << " ";
}
std::cout << std::endl;
for (int j = 0; j < std::min(5, k>>1); ++ j) {
ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
}
}
*/
}
TEST_P(RHNSWPQTest, HNSW_serialize) {
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
{
FileIOWriter writer(filename);
writer(static_cast<void*>(bin->data.get()), bin->size);
}
FileIOReader reader(filename);
reader(ret, bin->size);
};
{
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
auto binaryset = index_->Serialize();
auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index");
auto bin_dat = binaryset.GetByName(index_->index_type() + "_Data");
std::string filename_idx = "/tmp/RHNSWPQ_test_serialize_idx.bin";
std::string filename_dat = "/tmp/RHNSWPQ_test_serialize_dat.bin";
auto load_idx = new uint8_t[bin_idx->size];
auto load_dat = new uint8_t[bin_dat->size];
serialize(filename_idx, bin_idx, load_idx);
serialize(filename_dat, bin_dat, load_dat);
binaryset.clear();
auto new_idx = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
std::shared_ptr<uint8_t[]> dat(load_dat);
std::shared_ptr<uint8_t[]> idx(load_idx);
binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size);
binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size);
new_idx->Load(binaryset);
EXPECT_EQ(new_idx->Count(), nb);
EXPECT_EQ(new_idx->Dim(), dim);
auto result = new_idx->Query(query_dataset, conf);
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -0,0 +1,149 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <knowhere/index/vector_index/IndexRHNSWSQ.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <random>
#include "knowhere/common/Exception.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class RHNSWSQ8Test : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10
// Generate(2, 10, 2); // dim = 64, nb = 10000, nq = 10
index_ = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
}
protected:
milvus::knowhere::Config conf;
std::shared_ptr<milvus::knowhere::IndexRHNSWSQ> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWSQ8Test, Values("RHNSWSQ8"));
TEST_P(RHNSWSQ8Test, HNSW_basic) {
assert(!xb.empty());
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
// Serialize and Load before Query
milvus::knowhere::BinarySet bs = index_->Serialize();
auto result1 = index_->Query(query_dataset, conf);
AssertAnns(result1, nq, k);
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
tmp_index->Load(bs);
auto result2 = tmp_index->Query(query_dataset, conf);
AssertAnns(result2, nq, k);
}
TEST_P(RHNSWSQ8Test, HNSW_delete) {
assert(!xb.empty());
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++i) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
* delete result checked by eyes
auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
std::cout << std::endl;
for (int i = 0; i < nq; ++ i) {
std::cout << "ids1: ";
for (int j = 0; j < k; ++ j) {
std::cout << *(ids1 + i * k + j) << " ";
}
std::cout << "ids2: ";
for (int j = 0; j < k; ++ j) {
std::cout << *(ids2 + i * k + j) << " ";
}
std::cout << std::endl;
for (int j = 0; j < std::min(5, k>>1); ++ j) {
ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
}
}
*/
}
TEST_P(RHNSWSQ8Test, HNSW_serialize) {
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
{
FileIOWriter writer(filename);
writer(static_cast<void*>(bin->data.get()), bin->size);
}
FileIOReader reader(filename);
reader(ret, bin->size);
};
{
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
auto binaryset = index_->Serialize();
auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index");
auto bin_dat = binaryset.GetByName(index_->index_type() + "_Data");
std::string filename_idx = "/tmp/RHNSWSQ_test_serialize_idx.bin";
std::string filename_dat = "/tmp/RHNSWSQ_test_serialize_dat.bin";
auto load_idx = new uint8_t[bin_idx->size];
auto load_dat = new uint8_t[bin_dat->size];
serialize(filename_idx, bin_idx, load_idx);
serialize(filename_dat, bin_dat, load_dat);
binaryset.clear();
auto new_idx = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
std::shared_ptr<uint8_t[]> dat(load_dat);
std::shared_ptr<uint8_t[]> idx(load_idx);
binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size);
binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size);
new_idx->Load(binaryset);
EXPECT_EQ(new_idx->Count(), nb);
EXPECT_EQ(new_idx->Dim(), dim);
auto result = new_idx->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -279,7 +279,9 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
if (!status.ok()) {
return status;
}
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) {
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_HNSW_SQ8NM ||
index_type == knowhere::IndexEnum::INDEX_RHNSWPQ || index_type == knowhere::IndexEnum::INDEX_RHNSWSQ ||
index_type == knowhere::IndexEnum::INDEX_RHNSWFlat) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64);
if (!status.ok()) {
return status;
@ -288,6 +290,38 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
if (!status.ok()) {
return status;
}
if (index_type == knowhere::IndexEnum::INDEX_RHNSWPQ) {
status = CheckParameterExistence(index_params, knowhere::IndexParams::PQM);
if (!status.ok()) {
return status;
}
// special check for 'PQM' parameter
std::vector<int64_t> resset;
milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset);
int64_t pqm_value = index_params[knowhere::IndexParams::PQM];
if (resset.empty()) {
std::string msg = "Invalid collection dimension, unable to get reasonable values for 'PQM'";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_DIMENSION, msg);
}
auto iter = std::find(std::begin(resset), std::end(resset), pqm_value);
if (iter == std::end(resset)) {
std::string msg =
"Invalid " + std::string(knowhere::IndexParams::PQM) + ", must be one of the following values: ";
for (size_t i = 0; i < resset.size(); i++) {
if (i != 0) {
msg += ",";
}
msg += std::to_string(resset[i]);
}
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
}
} else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024);
if (!status.ok()) {

View File

@ -25,6 +25,9 @@ const char* NAME_ENGINE_TYPE_HNSW = "HNSW";
const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY";
const char* NAME_ENGINE_TYPE_IVFSQ8NR = "IVFSQ8NR";
const char* NAME_ENGINE_TYPE_HNSWSQ8NM = "HNSWSQ8NM";
const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSWFLAT";
const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSWPQ";
const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSWSQ8";
const char* NAME_METRIC_TYPE_L2 = "L2";
const char* NAME_METRIC_TYPE_IP = "IP";

View File

@ -28,6 +28,9 @@ extern const char* NAME_ENGINE_TYPE_IVFPQ;
extern const char* NAME_ENGINE_TYPE_HNSW;
extern const char* NAME_ENGINE_TYPE_HNSW_SQ8NM;
extern const char* NAME_ENGINE_TYPE_ANNOY;
extern const char* NAME_ENGINE_TYPE_RHNSWFLAT;
extern const char* NAME_ENGINE_TYPE_RHNSWPQ;
extern const char* NAME_ENGINE_TYPE_RHNSWSQ;
extern const char* NAME_METRIC_TYPE_L2;
extern const char* NAME_METRIC_TYPE_IP;