mirror of https://github.com/milvus-io/milvus.git
Migrate knowhere to segcore
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
f3649f0419
commit
e33d0a797c
internal/core
build-support
src/index
knowhere
knowhere
common
index
structured_index
vector_index
thirdparty
annoy/src
faiss
|
@ -130,7 +130,7 @@ if __name__ == "__main__":
|
|||
print(file=sys.stderr)
|
||||
diff_out = []
|
||||
for diff_str in diff:
|
||||
diff_out.append(diff_str.encode('raw_unicode_escape'))
|
||||
diff_out.append(diff_str.encode('raw_unicode_escape').decode('ascii'))
|
||||
sys.stderr.writelines(diff_out)
|
||||
except Exception:
|
||||
error = True
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "index/archive/KnowhereResource.h"
|
||||
#include "index/archive/KnowhereConfig.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#endif
|
||||
|
@ -24,17 +24,11 @@
|
|||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/IndexHNSW.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "scheduler/Utils.h"
|
||||
#include "utils/ConfigUtils.h"
|
||||
#include "utils/Error.h"
|
||||
#include "utils/Log.h"
|
||||
#include "value/config/ServerConfig.h"
|
||||
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
|
@ -43,8 +37,7 @@ namespace engine {
|
|||
constexpr int64_t M_BYTE = 1024 * 1024;
|
||||
|
||||
Status
|
||||
KnowhereResource::Initialize() {
|
||||
auto simd_type = config.engine.simd_type();
|
||||
KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
||||
if (simd_type == SimdType::AVX512) {
|
||||
faiss::faiss_use_avx512 = true;
|
||||
faiss::faiss_use_avx2 = false;
|
||||
|
@ -62,19 +55,29 @@ KnowhereResource::Initialize() {
|
|||
faiss::faiss_use_avx2 = true;
|
||||
faiss::faiss_use_sse = true;
|
||||
}
|
||||
|
||||
std::string cpu_flag;
|
||||
if (faiss::hook_init(cpu_flag)) {
|
||||
std::cout << "FAISS hook " << cpu_flag << std::endl;
|
||||
LOG_ENGINE_DEBUG_ << "FAISS hook " << cpu_flag;
|
||||
} else {
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// init faiss global variable
|
||||
int64_t use_blas_threshold = config.engine.use_blas_threshold();
|
||||
faiss::distance_compute_blas_threshold = use_blas_threshold;
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
|
||||
}
|
||||
|
||||
int64_t clustering_type = config.engine.clustering_type();
|
||||
void
|
||||
KnowhereConfig::SetBlasThreshold(const int64_t use_blas_threshold) {
|
||||
faiss::distance_compute_blas_threshold = static_cast<int>(use_blas_threshold);
|
||||
}
|
||||
|
||||
void
|
||||
KnowhereConfig::SetEarlyStopThreshold(const double early_stop_threshold) {
|
||||
faiss::early_stop_threshold = early_stop_threshold;
|
||||
}
|
||||
|
||||
void
|
||||
KnowhereConfig::SetClusteringType(const ClusteringType clustering_type) {
|
||||
switch (clustering_type) {
|
||||
case ClusteringType::K_MEANS:
|
||||
default:
|
||||
|
@ -84,63 +87,38 @@ KnowhereResource::Initialize() {
|
|||
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
bool enable_gpu = config.gpu.enable();
|
||||
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
|
||||
if (enable_gpu) {
|
||||
struct GpuResourceSetting {
|
||||
int64_t pinned_memory = 256 * M_BYTE;
|
||||
int64_t temp_memory = 256 * M_BYTE;
|
||||
int64_t resource_num = 2;
|
||||
};
|
||||
using GpuResourcesArray = std::map<int64_t, GpuResourceSetting>;
|
||||
GpuResourcesArray gpu_resources;
|
||||
|
||||
// get build index gpu resource
|
||||
std::vector<int64_t> build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices());
|
||||
|
||||
for (auto gpu_id : build_index_gpus) {
|
||||
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
|
||||
}
|
||||
|
||||
// get search gpu resource
|
||||
std::vector<int64_t> search_gpus = ParseGPUDevices(config.gpu.search_devices());
|
||||
|
||||
for (auto& gpu_id : search_gpus) {
|
||||
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
|
||||
}
|
||||
|
||||
// init gpu resources
|
||||
for (auto& gpu_resource : gpu_resources) {
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(
|
||||
gpu_resource.first, gpu_resource.second.pinned_memory, gpu_resource.second.temp_memory,
|
||||
gpu_resource.second.resource_num);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
void
|
||||
KnowhereConfig::SetStatisticsLevel(const int64_t stat_level) {
|
||||
milvus::knowhere::STATISTICS_LEVEL = stat_level;
|
||||
faiss::STATISTICS_LEVEL = stat_level;
|
||||
}
|
||||
|
||||
void
|
||||
KnowhereConfig::SetLogHandler() {
|
||||
faiss::LOG_ERROR_ = &knowhere::log_error_;
|
||||
faiss::LOG_WARNING_ = &knowhere::log_warning_;
|
||||
// faiss::LOG_DEBUG_ = &knowhere::log_debug_;
|
||||
NGT_LOG_ERROR_ = &knowhere::log_error_;
|
||||
NGT_LOG_WARNING_ = &knowhere::log_warning_;
|
||||
// NGT_LOG_DEBUG_ = &knowhere::log_debug_;
|
||||
|
||||
auto stat_level = config.engine.statistics_level();
|
||||
milvus::knowhere::STATISTICS_LEVEL = stat_level;
|
||||
faiss::STATISTICS_LEVEL = stat_level;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
KnowhereResource::Finalize() {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
|
||||
#endif
|
||||
return Status::OK();
|
||||
void
|
||||
KnowhereConfig::InitGPUResource(const std::vector<int64_t>& gpu_ids) {
|
||||
for (auto id : gpu_ids) {
|
||||
// device_id, pinned_memory, temp_memory, resource_num
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(id, 256 * M_BYTE, 256 * M_BYTE, 2);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
KnowhereConfig::FreeGPUResource() {
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "utils/Status.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
class KnowhereConfig {
|
||||
public:
|
||||
/**
|
||||
* set SIMD type
|
||||
*/
|
||||
enum SimdType {
|
||||
AUTO = 1, // enable all and depend on the system
|
||||
SSE, // only enable SSE
|
||||
AVX2, // only enable AVX2
|
||||
AVX512, // only enable AVX512
|
||||
};
|
||||
|
||||
static Status
|
||||
SetSimdType(const SimdType simd_type);
|
||||
|
||||
/**
|
||||
* Set openblas threshold
|
||||
* if nq < use_blas_threshold, calculated by omp
|
||||
* else, calculated by openblas
|
||||
*/
|
||||
static void
|
||||
SetBlasThreshold(const int64_t use_blas_threshold);
|
||||
|
||||
/**
|
||||
* set Clustering early stop [0, 100]
|
||||
* It is to reduce the number of iterations of K-means.
|
||||
* Between each two iterations, if the optimization rate < early_stop_threshold, stop
|
||||
* And if early_stop_threshold = 0, won't early stop
|
||||
*/
|
||||
static void
|
||||
SetEarlyStopThreshold(const double early_stop_threshold);
|
||||
|
||||
/**
|
||||
* set Clustering type
|
||||
*/
|
||||
enum ClusteringType {
|
||||
K_MEANS, // k-means (default)
|
||||
K_MEANS_PLUS_PLUS, // k-means++
|
||||
};
|
||||
|
||||
static void
|
||||
SetClusteringType(const ClusteringType clustering_type);
|
||||
|
||||
/**
|
||||
* set Statistics Level [0, 3]
|
||||
*/
|
||||
static void
|
||||
SetStatisticsLevel(const int64_t stat_level);
|
||||
|
||||
// todo: add log level?
|
||||
/**
|
||||
* set Log handler
|
||||
*/
|
||||
static void
|
||||
SetLogHandler();
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
// todo: move to ohter file?
|
||||
/**
|
||||
* init GPU Resource
|
||||
*/
|
||||
static void
|
||||
InitGPUResource(const std::vector<int64_t>& gpu_ids);
|
||||
|
||||
/**
|
||||
* free GPU Resource
|
||||
*/
|
||||
static void
|
||||
FreeGPUResource();
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
|
@ -1,29 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils/Status.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
class KnowhereResource {
|
||||
public:
|
||||
static Status
|
||||
Initialize();
|
||||
|
||||
static Status
|
||||
Finalize();
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
|
@ -90,7 +90,7 @@ define_option(MILVUS_CUDA_ARCH "Build with CUDA arch" "DEFAULT")
|
|||
|
||||
#----------------------------------------------------------------------
|
||||
set_option_category("Test and benchmark")
|
||||
unset(KNOWHERE_BUILD_TESTS CACHE)
|
||||
|
||||
if (BUILD_UNIT_TEST)
|
||||
define_option(KNOWHERE_BUILD_TESTS "Build the KNOWHERE googletest unit tests" ON)
|
||||
else ()
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
set(KNOWHERE_THIRDPARTY_DEPENDENCIES
|
||||
Arrow
|
||||
FAISS
|
||||
GTest
|
||||
OpenBLAS
|
||||
MKL
|
||||
)
|
||||
|
|
|
@ -57,6 +57,7 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||
knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
knowhere/index/vector_index/helpers/DynamicResultSet.cpp
|
||||
knowhere/index/vector_index/impl/nsg/Distance.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSG.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp
|
||||
|
@ -71,6 +72,7 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/IndexIVF.cpp
|
||||
knowhere/index/vector_index/IndexIVFPQ.cpp
|
||||
knowhere/index/vector_index/IndexIVFSQ.cpp
|
||||
knowhere/index/vector_index/IndexIVFHNSW.cpp
|
||||
knowhere/index/vector_index/IndexAnnoy.cpp
|
||||
knowhere/index/vector_index/IndexRHNSW.cpp
|
||||
knowhere/index/vector_index/IndexHNSW.cpp
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
|
||||
#include "Log.h"
|
||||
|
@ -22,19 +23,13 @@ KnowhereException::KnowhereException(std::string msg) : msg_(std::move(msg)) {
|
|||
}
|
||||
|
||||
KnowhereException::KnowhereException(const std::string& m, const char* funcName, const char* file, int line) {
|
||||
std::string filename;
|
||||
try {
|
||||
size_t pos;
|
||||
std::string file_path(file);
|
||||
pos = file_path.find_last_of('/');
|
||||
filename = file_path.substr(pos + 1);
|
||||
} catch (std::exception& e) {
|
||||
LOG_KNOWHERE_ERROR_ << e.what();
|
||||
const char* filename = funcName;
|
||||
while (auto tmp = strchr(filename, '/')) {
|
||||
filename = tmp + 1;
|
||||
}
|
||||
|
||||
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, filename.c_str(), line, m.c_str());
|
||||
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, filename, line, m.c_str());
|
||||
msg_.resize(size + 1);
|
||||
snprintf(&msg_[0], msg_.size(), "Error in %s at %s:%d: %s", funcName, filename.c_str(), line, m.c_str());
|
||||
snprintf(msg_.data(), m.size(), "Error in %s at %s:%d: %s", funcName, filename, line, m.c_str());
|
||||
}
|
||||
|
||||
const char*
|
||||
|
|
|
@ -40,10 +40,9 @@ Slice(const std::string& prefix,
|
|||
for (int64_t i = 0; i < data_src->size; ++slice_num) {
|
||||
int64_t ri = std::min(i + slice_len, data_src->size);
|
||||
auto size = static_cast<size_t>(ri - i);
|
||||
auto slice_i = reinterpret_cast<uint8_t*>(malloc(size));
|
||||
memcpy(slice_i, data_src->data.get() + i, size);
|
||||
std::shared_ptr<uint8_t[]> slice_i_sp(slice_i, std::default_delete<uint8_t[]>());
|
||||
binarySet.Append(prefix + "_" + std::to_string(slice_num), slice_i_sp, ri - i);
|
||||
auto slice_i = std::shared_ptr<uint8_t[]>(new uint8_t[size]);
|
||||
memcpy(slice_i.get(), data_src->data.get() + i, size);
|
||||
binarySet.Append(prefix + "_" + std::to_string(slice_num), slice_i, ri - i);
|
||||
i = ri;
|
||||
}
|
||||
ret[NAME] = prefix;
|
||||
|
@ -65,15 +64,14 @@ Assemble(BinarySet& binarySet) {
|
|||
std::string prefix = item[NAME];
|
||||
int slice_num = item[SLICE_NUM];
|
||||
auto total_len = static_cast<size_t>(item[TOTAL_LEN]);
|
||||
auto p_data = reinterpret_cast<uint8_t*>(malloc(total_len));
|
||||
auto p_data = std::shared_ptr<uint8_t[]>(new uint8_t[total_len]);
|
||||
int64_t pos = 0;
|
||||
for (auto i = 0; i < slice_num; ++i) {
|
||||
auto slice_i_sp = binarySet.Erase(prefix + "_" + std::to_string(i));
|
||||
memcpy(p_data + pos, slice_i_sp->data.get(), static_cast<size_t>(slice_i_sp->size));
|
||||
memcpy(p_data.get() + pos, slice_i_sp->data.get(), static_cast<size_t>(slice_i_sp->size));
|
||||
pos += slice_i_sp->size;
|
||||
}
|
||||
std::shared_ptr<uint8_t[]> integral_data(p_data, std::default_delete<uint8_t[]>());
|
||||
binarySet.Append(prefix, integral_data, total_len);
|
||||
binarySet.Append(prefix, p_data, total_len);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT";
|
|||
const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
|
||||
const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";
|
||||
const char* INDEX_FAISS_IVFSQ8H = "IVF_SQ8_HYBRID";
|
||||
const char* INDEX_FAISS_IVFHNSW = "IVF_HNSW";
|
||||
const char* INDEX_FAISS_BIN_IDMAP = "BIN_FLAT";
|
||||
const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT";
|
||||
const char* INDEX_NSG = "NSG";
|
||||
|
|
|
@ -54,6 +54,7 @@ extern const char* INDEX_FAISS_IVFFLAT;
|
|||
extern const char* INDEX_FAISS_IVFPQ;
|
||||
extern const char* INDEX_FAISS_IVFSQ8;
|
||||
extern const char* INDEX_FAISS_IVFSQ8H;
|
||||
extern const char* INDEX_FAISS_IVFHNSW;
|
||||
extern const char* INDEX_FAISS_BIN_IDMAP;
|
||||
extern const char* INDEX_FAISS_BIN_IVFFLAT;
|
||||
extern const char* INDEX_NSG;
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "knowhere/knowhere/common/Log.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/structured_index/StructuredIndexFlat.h"
|
||||
|
||||
namespace milvus {
|
||||
|
@ -34,6 +34,8 @@ StructuredIndexFlat<T>::~StructuredIndexFlat() {
|
|||
template <typename T>
|
||||
BinarySet
|
||||
StructuredIndexFlat<T>::Serialize(const milvus::knowhere::Config& config) {
|
||||
// TODO
|
||||
return BinarySet();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -194,6 +194,10 @@ IVFPQConfAdapter::CheckGPUPQParams(int64_t dimension, int64_t m, int64_t nbits)
|
|||
static const std::vector<int64_t> support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
|
||||
static const std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
|
||||
|
||||
if (!CheckCPUPQParams(dimension, m)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t sub_dim = dimension / m;
|
||||
return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) !=
|
||||
support_subquantizer.end()) &&
|
||||
|
@ -207,6 +211,34 @@ IVFPQConfAdapter::CheckCPUPQParams(int64_t dimension, int64_t m) {
|
|||
return (dimension % m == 0);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFHNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
// HNSW param check
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
// IVF param check
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
|
||||
// auto tune params
|
||||
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
|
||||
auto nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(rows, nlist);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFHNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
// HNSW param check
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
|
||||
|
||||
// IVF param check
|
||||
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
const int64_t MIN_KNNG = 5;
|
||||
|
|
|
@ -61,6 +61,15 @@ class IVFPQConfAdapter : public IVFConfAdapter {
|
|||
CheckCPUPQParams(int64_t dimension, int64_t m);
|
||||
};
|
||||
|
||||
class IVFHNSWConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class NSGConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
bool
|
||||
|
|
|
@ -41,6 +41,7 @@ AdapterMgr::RegisterAdapter() {
|
|||
REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFHNSWConfAdapter, IndexEnum::INDEX_FAISS_IVFHNSW, ivfhnsw_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIVFConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter);
|
||||
|
|
|
@ -109,7 +109,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ class IndexAnnoy : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -42,7 +42,7 @@ BinaryIDMAP::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -65,6 +65,43 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const fa
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
DynamicResultSegment
|
||||
BinaryIDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset,
|
||||
const milvus::knowhere::Config& config,
|
||||
const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset)
|
||||
if (rows != 1) {
|
||||
KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
|
||||
}
|
||||
|
||||
auto default_type = index_->metric_type;
|
||||
if (config.contains(Metric::TYPE)) {
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
}
|
||||
std::vector<faiss::RangeSearchPartialResult*> res;
|
||||
DynamicResultSegment result;
|
||||
float radius = 0.0;
|
||||
if (index_->metric_type != faiss::MetricType::METRIC_Substructure &&
|
||||
index_->metric_type != faiss::METRIC_Superstructure) {
|
||||
radius = config[IndexParams::range_search_radius].get<float>();
|
||||
}
|
||||
auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
|
||||
? config[IndexParams::range_search_buffer_size].get<size_t>()
|
||||
: 16384;
|
||||
auto real_idx = dynamic_cast<faiss::IndexBinaryFlat*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexBinaryFlat type!");
|
||||
}
|
||||
real_idx->range_search(rows, reinterpret_cast<const uint8_t*>(p_data), radius, res, buffer_size, bitset);
|
||||
ExchangeDataset(result, res);
|
||||
MapUids(result);
|
||||
index_->metric_type = default_type;
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Count() {
|
||||
if (!index_) {
|
||||
|
@ -120,7 +157,7 @@ BinaryIDMAP::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
// assign the metric type
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -44,7 +45,10 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView bitset) override;
|
||||
|
||||
DynamicResultSegment
|
||||
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset);
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -68,7 +72,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset);
|
||||
const faiss::BitsetView bitset);
|
||||
};
|
||||
|
||||
using BinaryIDMAPPtr = std::shared_ptr<BinaryIDMAP>;
|
||||
|
|
|
@ -51,21 +51,32 @@ BinaryIVF::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
int64_t* p_id = nullptr;
|
||||
float* p_dist = nullptr;
|
||||
auto release_when_exception = [&]() {
|
||||
if (p_id != nullptr) {
|
||||
free(p_id);
|
||||
}
|
||||
if (p_dist != nullptr) {
|
||||
free(p_dist);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
auto k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
auto p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
|
||||
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
|
||||
MapOffsetToUid(p_id, static_cast<size_t>(elems));
|
||||
|
@ -77,8 +88,10 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
|
|||
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
release_when_exception();
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
release_when_exception();
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
@ -176,7 +189,7 @@ BinaryIVF::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
|
|
|
@ -49,7 +49,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -77,7 +77,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset);
|
||||
const faiss::BitsetView bitset);
|
||||
};
|
||||
|
||||
using BinaryIVFIndexPtr = std::shared_ptr<BinaryIVF>;
|
||||
|
|
|
@ -138,7 +138,7 @@ IndexHNSW::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ class IndexHNSW : public VecIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -76,7 +76,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -98,6 +98,42 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::B
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
DynamicResultSegment
|
||||
IDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset,
|
||||
const milvus::knowhere::Config& config,
|
||||
const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset)
|
||||
if (rows != 1) {
|
||||
KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
|
||||
}
|
||||
|
||||
auto default_type = index_->metric_type;
|
||||
if (config.contains(Metric::TYPE)) {
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
}
|
||||
std::vector<faiss::RangeSearchPartialResult*> res;
|
||||
DynamicResultSegment result;
|
||||
auto radius = config[IndexParams::range_search_radius].get<float>();
|
||||
auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
|
||||
? config[IndexParams::range_search_buffer_size].get<size_t>()
|
||||
: 16384;
|
||||
auto real_idx = dynamic_cast<faiss::IndexFlat*>(index_.get());
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexFlat type!");
|
||||
}
|
||||
if (index_->metric_type == faiss::MetricType::METRIC_L2) {
|
||||
radius *= radius;
|
||||
}
|
||||
real_idx->range_search(rows, reinterpret_cast<const float*>(p_data), radius, res, buffer_size, bitset);
|
||||
ExchangeDataset(result, res);
|
||||
MapUids(result);
|
||||
index_->metric_type = default_type;
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IDMAP::Count() {
|
||||
if (!index_) {
|
||||
|
@ -149,7 +185,7 @@ IDMAP::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
// assign the metric type
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, data, k, distances, labels, bitset);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -43,7 +44,10 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;
|
||||
|
||||
DynamicResultSegment
|
||||
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset);
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -64,7 +68,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
|
||||
protected:
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView&);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView);
|
||||
};
|
||||
|
||||
using IDMAPPtr = std::shared_ptr<IDMAP>;
|
||||
|
|
|
@ -62,7 +62,6 @@ IVF::Serialize(const Config& config) {
|
|||
void
|
||||
IVF::Load(const BinarySet& binary_set) {
|
||||
Assemble(const_cast<BinarySet&>(binary_set));
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
LoadImpl(binary_set, index_type_);
|
||||
|
||||
if (IndexMode() == IndexMode::MODE_CPU && STATISTICS_LEVEL >= 3) {
|
||||
|
@ -95,13 +94,24 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
int64_t* p_id = nullptr;
|
||||
float* p_dist = nullptr;
|
||||
auto release_when_exception = [&]() {
|
||||
if (p_id != nullptr) {
|
||||
free(p_id);
|
||||
}
|
||||
if (p_dist != nullptr) {
|
||||
free(p_dist);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
|
||||
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
|
||||
|
@ -110,8 +120,8 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::Bit
|
|||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
auto p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
|
||||
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
|
||||
MapOffsetToUid(p_id, static_cast<size_t>(elems));
|
||||
|
@ -121,8 +131,10 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::Bit
|
|||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
release_when_exception();
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
release_when_exception();
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
@ -310,7 +322,7 @@ IVF::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
|
||||
|
|
|
@ -49,7 +49,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -90,7 +90,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView&);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView);
|
||||
|
||||
void
|
||||
SealImpl() override;
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexIVFFlat.h>
|
||||
#include <faiss/clone_index.h>
|
||||
#include <faiss/index_io.h>
|
||||
|
||||
#include "faiss/IndexRHNSW.h"
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFHNSW.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
IVFHNSW::Serialize(const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
// Serialize IVF index and HNSW data
|
||||
auto res_set = SerializeImpl(index_type_);
|
||||
auto index = dynamic_cast<faiss::IndexIVFFlat*>(index_.get());
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index->quantizer);
|
||||
if (real_idx == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Quantizer index is not a faiss::IndexRHNSWFlat");
|
||||
}
|
||||
|
||||
MemoryIOWriter writer;
|
||||
faiss::write_index(real_idx->storage, &writer);
|
||||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
res_set.Append("HNSW_STORAGE", data, writer.rp);
|
||||
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFHNSW::Load(const BinarySet& binary_set) {
|
||||
try {
|
||||
// Load IVF index and HNSW data
|
||||
Assemble(const_cast<BinarySet&>(binary_set));
|
||||
LoadImpl(binary_set, index_type_);
|
||||
|
||||
auto index = dynamic_cast<faiss::IndexIVFFlat*>(index_.get());
|
||||
MemoryIOReader reader;
|
||||
auto binary = binary_set.GetByName("HNSW_STORAGE");
|
||||
reader.total = static_cast<size_t>(binary->size);
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index->quantizer);
|
||||
real_idx->storage = faiss::read_index(&reader);
|
||||
real_idx->init_hnsw();
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto coarse_quantizer = new faiss::IndexRHNSWFlat(dim, config[IndexParams::M], metric_type);
|
||||
coarse_quantizer->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
metric_type);
|
||||
index->own_fields = true;
|
||||
index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
index_ = index;
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
IVFHNSW::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("IVFHNSW::CopyCpuToGpu not supported.");
|
||||
}
|
||||
|
||||
void
|
||||
IVFHNSW::UpdateIndexSize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
auto ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
|
||||
auto nb = ivf_index->invlists->compute_ntotal();
|
||||
auto code_size = ivf_index->code_size;
|
||||
auto hnsw_quantizer = dynamic_cast<faiss::IndexRHNSWFlat*>(ivf_index->quantizer);
|
||||
// ivf codes, ivf ids and hnsw_flat quantizer
|
||||
index_size_ = nb * code_size + nb * sizeof(int64_t) + hnsw_quantizer->cal_size();
|
||||
}
|
||||
|
||||
void
|
||||
IVFHNSW::QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
|
||||
if (params->nprobe > 1 && n <= 4) {
|
||||
ivf_index->parallel_mode = 1;
|
||||
} else {
|
||||
ivf_index->parallel_mode = 0;
|
||||
}
|
||||
// Update HNSW quantizer search param
|
||||
auto hnsw_quantizer = dynamic_cast<faiss::IndexRHNSWFlat*>(ivf_index->quantizer);
|
||||
hnsw_quantizer->hnsw.efSearch = config[IndexParams::ef].get<int64_t>();
|
||||
ivf_index->search(n, data, k, distances, labels, bitset);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IVFHNSW : public IVF {
|
||||
public:
|
||||
IVFHNSW() : IVF() {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFHNSW;
|
||||
}
|
||||
|
||||
explicit IVFHNSW(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFHNSW;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config&) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet&) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr&, const Config&) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&) override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
protected:
|
||||
void
|
||||
QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView bitset) override;
|
||||
};
|
||||
|
||||
using IVFHNSWPtr = std::shared_ptr<IVFHNSW>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -19,7 +19,6 @@
|
|||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexScalarQuantizer.h>
|
||||
#include <faiss/clone_index.h>
|
||||
#include <faiss/index_factory.h>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
|
|
|
@ -122,7 +122,7 @@ IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
#endif
|
||||
|
||||
DatasetPtr
|
||||
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ class IndexNGT : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -99,7 +99,7 @@ IndexRHNSW::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -87,9 +87,10 @@ IndexRHNSWFlat::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
try {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
int32_t efConstruction = config[IndexParams::efConstruction];
|
||||
|
||||
auto idx = new faiss::IndexRHNSWFlat(int(dim), config[IndexParams::M], metric_type);
|
||||
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
idx->hnsw.efConstruction = efConstruction;
|
||||
index_ = std::shared_ptr<faiss::Index>(idx);
|
||||
index_->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
} catch (std::exception& e) {
|
||||
|
|
|
@ -84,9 +84,10 @@ void
|
|||
IndexRHNSWPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
try {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
int32_t efConstruction = config[IndexParams::efConstruction];
|
||||
|
||||
auto idx = new faiss::IndexRHNSWPQ(int(dim), config[IndexParams::PQM], config[IndexParams::M]);
|
||||
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
idx->hnsw.efConstruction = efConstruction;
|
||||
index_ = std::shared_ptr<faiss::Index>(idx);
|
||||
index_->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
} catch (std::exception& e) {
|
||||
|
|
|
@ -88,10 +88,11 @@ IndexRHNSWSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
try {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
int32_t efConstruction = config[IndexParams::efConstruction];
|
||||
|
||||
auto idx =
|
||||
new faiss::IndexRHNSWSQ(int(dim), faiss::QuantizerType::QT_8bit, config[IndexParams::M], metric_type);
|
||||
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
|
||||
idx->hnsw.efConstruction = efConstruction;
|
||||
index_ = std::shared_ptr<faiss::Index>(idx);
|
||||
index_->train(rows, static_cast<const float*>(p_data));
|
||||
} catch (std::exception& e) {
|
||||
|
|
|
@ -180,7 +180,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
SetParameters(config);
|
||||
|
||||
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);
|
||||
|
|
|
@ -47,7 +47,7 @@ class CPUSPTAGRNG : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -19,10 +19,10 @@
|
|||
#include <vector>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/Statistics.h"
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
|
|
@ -147,7 +147,7 @@ class Statistics {
|
|||
}
|
||||
|
||||
void
|
||||
update_filter_percentage(const faiss::BitsetView& bitset) {
|
||||
update_filter_percentage(const faiss::BitsetView bitset) {
|
||||
double fps = !bitset.empty() ? static_cast<double>(bitset.count_1()) / bitset.size() : 0.0;
|
||||
filter_stat[static_cast<int>(fps * 100) / 5] += 1;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "knowhere/index/Index.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/Statistics.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -46,7 +47,7 @@ class VecIndex : public Index {
|
|||
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual DatasetPtr
|
||||
Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset) = 0;
|
||||
Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset) = 0;
|
||||
|
||||
virtual int64_t
|
||||
Dim() = 0;
|
||||
|
@ -94,6 +95,23 @@ class VecIndex : public Index {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
MapUids(DynamicResultSegment& milvus_dataset) {
|
||||
if (uids_) {
|
||||
for (auto& mrspr : milvus_dataset) {
|
||||
for (auto j = 0; j < mrspr->buffers.size(); ++j) {
|
||||
auto buf = mrspr->buffers[j];
|
||||
auto len = j + 1 == mrspr->buffers.size() ? mrspr->wp : mrspr->buffer_size;
|
||||
for (auto i = 0; i < len; ++i) {
|
||||
if (buf.ids[i] >= 0) {
|
||||
buf.ids[i] = uids_->at(buf.ids[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
UidsSize() {
|
||||
return uids_ ? uids_->size() * sizeof(IDType) : 0;
|
||||
|
|
|
@ -103,7 +103,7 @@ GPUIDMAP::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
|
||||
// assign the metric type
|
||||
|
|
|
@ -52,8 +52,7 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView) override;
|
||||
};
|
||||
|
||||
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
|
||||
|
|
|
@ -140,7 +140,7 @@ GPUIVF::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
|
||||
if (device_index) {
|
||||
|
|
|
@ -51,8 +51,7 @@ class GPUIVF : public IVF, public GPUIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView) override;
|
||||
};
|
||||
|
||||
using GPUIVFPtr = std::shared_ptr<GPUIVF>;
|
||||
|
|
|
@ -28,16 +28,13 @@ void
|
|||
GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (gpu_res != nullptr) {
|
||||
ResScope rs(gpu_res, gpu_id_, true);
|
||||
auto device_index =
|
||||
new faiss::gpu::GpuIndexIVFPQ(gpu_res->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
config[IndexParams::m], config[IndexParams::nbits],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
||||
auto device_index = new faiss::gpu::GpuIndexIVFPQ(
|
||||
gpu_res->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m],
|
||||
config[IndexParams::nbits], GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
device_index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexScalarQuantizer.h>
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/index_factory.h>
|
||||
#include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -30,24 +30,13 @@ void
|
|||
GPUIVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
// std::stringstream index_type;
|
||||
// index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
// << "SQ" << config[IndexParams::nbits];
|
||||
// faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
// auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
|
||||
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
auto build_index = new faiss::IndexIVFScalarQuantizer(
|
||||
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit, metric_type);
|
||||
|
||||
auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (gpu_res != nullptr) {
|
||||
ResScope rs(gpu_res, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
|
||||
auto device_index = new faiss::gpu::GpuIndexIVFScalarQuantizer(
|
||||
gpu_res->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit,
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
device_index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
|
||||
index_.reset(device_index);
|
||||
res_ = gpu_res;
|
||||
} else {
|
||||
|
|
|
@ -50,6 +50,7 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
res_ = gpu_res;
|
||||
gpu_mode_ = 2;
|
||||
} else {
|
||||
delete build_index;
|
||||
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
|
||||
}
|
||||
|
||||
|
@ -131,16 +132,15 @@ IVFSQHybrid::LoadData(const FaissIVFQuantizerPtr& quantizer_ptr, const Config& c
|
|||
option.allInGpu = true;
|
||||
|
||||
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(quantizer_ptr);
|
||||
if (ivf_quantizer == nullptr) {
|
||||
if (ivf_quantizer == nullptr)
|
||||
KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer");
|
||||
}
|
||||
|
||||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = ivf_quantizer->quantizer;
|
||||
index_composition->mode = 2; // only 2
|
||||
faiss::IndexComposition index_composition;
|
||||
index_composition.index = index_.get();
|
||||
index_composition.quantizer = ivf_quantizer->quantizer;
|
||||
index_composition.mode = 2; // only copy data
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, &index_composition, &option);
|
||||
std::shared_ptr<faiss::Index> new_idx;
|
||||
new_idx.reset(gpu_index);
|
||||
auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id, res);
|
||||
|
@ -159,17 +159,17 @@ IVFSQHybrid::LoadQuantizer(const Config& config) {
|
|||
faiss::gpu::GpuClonerOptions option;
|
||||
option.allInGpu = true;
|
||||
|
||||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = nullptr;
|
||||
index_composition->mode = 1; // only 1
|
||||
faiss::IndexComposition index_composition;
|
||||
index_composition.index = index_.get();
|
||||
index_composition.quantizer = nullptr;
|
||||
index_composition.mode = 1; // only copy quantizer
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, &index_composition, &option);
|
||||
delete gpu_index;
|
||||
|
||||
auto q = std::make_shared<FaissIVFQuantizer>();
|
||||
|
||||
auto& q_ptr = index_composition->quantizer;
|
||||
auto q_ptr = index_composition.quantizer;
|
||||
q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float);
|
||||
q->quantizer = q_ptr;
|
||||
q->gpu_id = gpu_id;
|
||||
|
@ -246,7 +246,7 @@ IVFSQHybrid::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
if (gpu_mode_ == 2) {
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset);
|
||||
// index_->search(n, (float*)data, k, distances, labels);
|
||||
|
|
|
@ -88,8 +88,7 @@ class IVFSQHybrid : public GPUIVFSQ {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView) override;
|
||||
|
||||
protected:
|
||||
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "faiss/impl/AuxIndexStructures.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
/***********************************************************************
|
||||
* DynamicResultSet
|
||||
***********************************************************************/
|
||||
|
||||
void
|
||||
DynamicResultSet::AlloctionImpl() {
|
||||
if (count <= 0) {
|
||||
KNOWHERE_THROW_MSG("DynamicResultSet::do_alloction failed because of count <= 0");
|
||||
}
|
||||
labels = std::shared_ptr<idx_t[]>(new idx_t[count], std::default_delete<idx_t[]>());
|
||||
distances = std::shared_ptr<float[]>(new float[count], std::default_delete<float[]>());
|
||||
// labels = std::make_shared<idx_t []>(new idx_t[count], std::default_delete<idx_t[]>());
|
||||
// distances = std::make_shared<float []>(new float[count], std::default_delete<float[]>());
|
||||
}
|
||||
|
||||
void
|
||||
DynamicResultSet::SortImpl(ResultSetPostProcessType postProcessType) {
|
||||
if (postProcessType == ResultSetPostProcessType::SortAsc) {
|
||||
quick_sort<true>(0, count);
|
||||
} else if (postProcessType == ResultSetPostProcessType::SortDesc) {
|
||||
quick_sort<false>(0, count);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("invalid sort type!");
|
||||
}
|
||||
}
|
||||
|
||||
template <bool asc>
|
||||
void
|
||||
DynamicResultSet::quick_sort(size_t lp, size_t rp) {
|
||||
auto len = rp - lp;
|
||||
if (len <= 1) {
|
||||
return;
|
||||
}
|
||||
auto pvot = lp + (len >> 1);
|
||||
size_t low = lp;
|
||||
size_t high = rp - 1;
|
||||
auto pids = labels.get();
|
||||
auto pdis = distances.get();
|
||||
std::swap(pdis[pvot], pdis[high]);
|
||||
std::swap(pids[pvot], pids[high]);
|
||||
if (asc) {
|
||||
while (low < high) {
|
||||
while (low < high && pdis[low] <= pdis[high]) {
|
||||
low++;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
high--;
|
||||
while (low < high && pdis[high] >= pdis[low]) {
|
||||
high--;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
low++;
|
||||
}
|
||||
} else {
|
||||
while (low < high) {
|
||||
while (low < high && pdis[low] >= pdis[high]) {
|
||||
low++;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
high--;
|
||||
while (low < high && pdis[high] <= pdis[low]) {
|
||||
high--;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
low++;
|
||||
}
|
||||
}
|
||||
quick_sort<asc>(lp, low);
|
||||
quick_sort<asc>(low, rp);
|
||||
}
|
||||
|
||||
DynamicResultSet
|
||||
DynamicResultCollector::Merge(size_t limit, ResultSetPostProcessType postProcessType) {
|
||||
if (limit <= 0) {
|
||||
KNOWHERE_THROW_MSG("limit must > 0!");
|
||||
}
|
||||
DynamicResultSet ret;
|
||||
auto seg_num = seg_results.size();
|
||||
std::vector<size_t> boundaries(seg_num + 1, 0);
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < seg_num; ++i) {
|
||||
for (auto& pseg : seg_results[i]) {
|
||||
boundaries[i] += (pseg->buffer_size * pseg->buffers.size() - pseg->buffer_size + pseg->wp);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0, ofs = 0; i <= seg_num; ++i) {
|
||||
auto bn = boundaries[i];
|
||||
boundaries[i] = ofs;
|
||||
ofs += bn;
|
||||
// boundaries[i] += boundaries[i - 1];
|
||||
}
|
||||
ret.count = boundaries[seg_num] <= limit ? boundaries[seg_num] : limit;
|
||||
ret.AlloctionImpl();
|
||||
|
||||
// abandon redundancy answers randomly
|
||||
// abandon strategy: keep the top limit sequentially
|
||||
int32_t pos = 1;
|
||||
for (int i = 1; i < boundaries.size(); ++i) {
|
||||
if (boundaries[i] >= ret.count) {
|
||||
pos = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
pos--; // last segment id
|
||||
// full copy
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < pos; ++i) {
|
||||
for (auto& pseg : seg_results[i]) {
|
||||
auto len = pseg->buffers.size() * pseg->buffer_size - pseg->buffer_size + pseg->wp;
|
||||
pseg->copy_range(0, len, ret.labels.get() + boundaries[i], ret.distances.get() + boundaries[i]);
|
||||
boundaries[i] += len;
|
||||
}
|
||||
}
|
||||
// partial copy
|
||||
auto last_len = ret.count - boundaries[pos];
|
||||
for (auto& pseg : seg_results[pos]) {
|
||||
auto len = pseg->buffers.size() * pseg->buffer_size - pseg->buffer_size + pseg->wp;
|
||||
auto ncopy = last_len > len ? len : last_len;
|
||||
pseg->copy_range(0, ncopy, ret.labels.get() + boundaries[pos], ret.distances.get() + boundaries[pos]);
|
||||
boundaries[pos] += ncopy;
|
||||
last_len -= ncopy;
|
||||
if (last_len <= 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (postProcessType != ResultSetPostProcessType::None) {
|
||||
ret.SortImpl(postProcessType);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
DynamicResultCollector::Append(milvus::knowhere::DynamicResultSegment&& seg_result) {
|
||||
seg_results.emplace_back(std::move(seg_result));
|
||||
}
|
||||
|
||||
void
|
||||
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset) {
|
||||
for (auto& prspr : faiss_dataset) {
|
||||
auto mrspr = std::make_shared<DynamicResultFragment>(prspr->res->buffer_size);
|
||||
mrspr->wp = prspr->wp;
|
||||
mrspr->buffers.resize(prspr->buffers.size());
|
||||
for (auto i = 0; i < prspr->buffers.size(); ++i) {
|
||||
mrspr->buffers[i].ids = prspr->buffers[i].ids;
|
||||
mrspr->buffers[i].dis = prspr->buffers[i].dis;
|
||||
prspr->buffers[i].ids = nullptr;
|
||||
prspr->buffers[i].dis = nullptr;
|
||||
}
|
||||
delete prspr->res;
|
||||
delete prspr;
|
||||
milvus_dataset.push_back(mrspr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,85 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "faiss/impl/AuxIndexStructures.h"
|
||||
#include "knowhere/common/Typedef.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
enum class ResultSetPostProcessType { None = 0, SortDesc, SortAsc };
|
||||
using idx_t = int64_t;
|
||||
|
||||
/*
|
||||
* Class: Dynamic result set (merged results)
|
||||
*/
|
||||
struct DynamicResultSet {
|
||||
std::shared_ptr<idx_t[]> labels; /// result for query i is labels[lims[i]:lims[i + 1]]
|
||||
std::shared_ptr<float[]> distances; /// corresponding distances, not sorted
|
||||
size_t count; /// size of the result buffer's size, when reaches this size, auto start a new buffer
|
||||
|
||||
void
|
||||
AlloctionImpl();
|
||||
|
||||
void
|
||||
SortImpl(ResultSetPostProcessType postProcessType = ResultSetPostProcessType::SortAsc);
|
||||
|
||||
private:
|
||||
template <bool asc>
|
||||
void
|
||||
quick_sort(size_t lp, size_t rp);
|
||||
};
|
||||
|
||||
// BufferPool (inner classes)
|
||||
typedef faiss::BufferList DynamicResultFragment;
|
||||
typedef std::shared_ptr<DynamicResultFragment> DynamicResultFragmentPtr;
|
||||
typedef std::vector<DynamicResultFragmentPtr> DynamicResultSegment;
|
||||
|
||||
/*
|
||||
* Class: Dynamic result collector
|
||||
* Example:
|
||||
DynamicResultCollector collector;
|
||||
for (auto &seg: segments) {
|
||||
auto seg_rst = seg.QueryByDistance(xxx);
|
||||
collector.append(seg_rst);
|
||||
}
|
||||
auto rst = collector.merge();
|
||||
*/
|
||||
struct DynamicResultCollector {
|
||||
public:
|
||||
/*
|
||||
* Merge the results of segments
|
||||
* Notes: Now, we apply limit before sort.
|
||||
* It can be updated if you don't expect the policy.
|
||||
*/
|
||||
DynamicResultSet
|
||||
Merge(size_t limit = 10000, ResultSetPostProcessType postProcessType = ResultSetPostProcessType::None);
|
||||
|
||||
/*
|
||||
* Collect the results of segments
|
||||
*/
|
||||
void
|
||||
Append(DynamicResultSegment&& seg_result);
|
||||
|
||||
private:
|
||||
std::vector<DynamicResultSegment> seg_results; /// unmerged results of every segments
|
||||
};
|
||||
|
||||
void
|
||||
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include <faiss/gpu/StandardGpuResources.h>
|
||||
|
||||
#include "src/utils/BlockingQueue.h"
|
||||
#include "utils/BlockingQueue.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
|
|
@ -28,6 +28,10 @@ constexpr const char* DEVICEID = "gpu_id";
|
|||
}; // namespace meta
|
||||
|
||||
namespace IndexParams {
|
||||
// Range Search Params
|
||||
constexpr const char* range_search_radius = "range_search_radius";
|
||||
constexpr const char* range_search_buffer_size = "range_search_buffer_size";
|
||||
|
||||
// IVF Params
|
||||
constexpr const char* nprobe = "nprobe";
|
||||
constexpr const char* nlist = "nlist";
|
||||
|
|
|
@ -127,6 +127,7 @@ NsgIndex::InitNavigationPoint(float* data) {
|
|||
//
|
||||
// float r1 = distance_->Compare(center, ori_data_ + navigation_point * dimension, dimension);
|
||||
// assert(r1 == resset[0].distance);
|
||||
delete[] center;
|
||||
}
|
||||
|
||||
// Specify Link
|
||||
|
@ -871,7 +872,7 @@ NsgIndex::Search(const float* query,
|
|||
float* dist,
|
||||
int64_t* ids,
|
||||
SearchParams& params,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
std::vector<std::vector<Neighbor>> resset(nq);
|
||||
|
||||
TimeRecorder rc("NsgIndex::search", 1);
|
||||
|
|
|
@ -91,7 +91,7 @@ class NsgIndex {
|
|||
float* dist,
|
||||
int64_t* ids,
|
||||
SearchParams& params,
|
||||
const faiss::BitsetView& bitset = nullptr);
|
||||
const faiss::BitsetView bitset);
|
||||
|
||||
int64_t
|
||||
GetSize();
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include <faiss/IndexIVFFlat.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/clone_index.h>
|
||||
#include <faiss/index_factory.h>
|
||||
#include <faiss/index_io.h>
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
|
@ -138,13 +137,24 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
int64_t* p_id = nullptr;
|
||||
float* p_dist = nullptr;
|
||||
auto release_when_exception = [&]() {
|
||||
if (p_id != nullptr) {
|
||||
free(p_id);
|
||||
}
|
||||
if (p_dist != nullptr) {
|
||||
free(p_dist);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
fiu_do_on("IVF_NM.Search.throw_std_exception", throw std::exception());
|
||||
fiu_do_on("IVF_NM.Search.throw_faiss_exception", throw faiss::FaissException(""));
|
||||
|
@ -164,8 +174,10 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::
|
|||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
release_when_exception();
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
release_when_exception();
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
@ -312,7 +324,7 @@ IVF_NM::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
|
|
|
@ -50,7 +50,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView bitset) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -91,7 +91,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
|
|||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView bitset);
|
||||
|
||||
void
|
||||
SealImpl() override;
|
||||
|
|
|
@ -76,7 +76,7 @@ NSG_NM::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ class NSG_NM : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -122,7 +122,7 @@ GPUIVF_NM::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::BitsetView& bitset) {
|
||||
const faiss::BitsetView bitset) {
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
fiu_do_on("GPUIVF_NM.search_impl.invald_index", device_index = nullptr);
|
||||
if (device_index) {
|
||||
|
|
|
@ -51,8 +51,7 @@ class GPUIVF_NM : public IVF, public GPUIndex {
|
|||
SerializeImpl(const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView) override;
|
||||
|
||||
protected:
|
||||
uint8_t* arranged_data;
|
||||
|
|
|
@ -599,7 +599,7 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
|
|||
}
|
||||
|
||||
// for milvus
|
||||
void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView& bitset)
|
||||
void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView bitset)
|
||||
{
|
||||
if (sc.explorationCoefficient == 0.0)
|
||||
{
|
||||
|
|
|
@ -698,7 +698,7 @@ namespace NGT {
|
|||
|
||||
void search(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
// for milvus
|
||||
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView&bitset);
|
||||
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView bitset);
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
template <typename COMPARATOR, typename CHECK_LIST> void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
|
|
|
@ -655,7 +655,7 @@ public:
|
|||
virtual void linearSearch(NGT::SearchContainer & sc) { getIndex().linearSearch(sc); }
|
||||
virtual void linearSearch(NGT::SearchQuery & sc) { getIndex().linearSearch(sc); }
|
||||
// for milvus
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset) { getIndex().search(sc, bitset); }
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView bitset) { getIndex().search(sc, bitset); }
|
||||
virtual void search(NGT::SearchContainer & sc) { getIndex().search(sc); }
|
||||
virtual void search(NGT::SearchQuery & sc) { getIndex().search(sc); }
|
||||
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds) { getIndex().search(sc, seeds); }
|
||||
|
@ -1058,7 +1058,7 @@ public:
|
|||
}
|
||||
|
||||
// for milvus
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset)
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView bitset)
|
||||
{
|
||||
sc.distanceComputationCount = 0;
|
||||
sc.visitCount = 0;
|
||||
|
@ -1586,7 +1586,7 @@ protected:
|
|||
}
|
||||
|
||||
// for milvus
|
||||
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView&bitset)
|
||||
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView bitset)
|
||||
{
|
||||
if (sc.size == 0)
|
||||
{
|
||||
|
@ -2147,7 +2147,7 @@ public:
|
|||
|
||||
// for milvus
|
||||
void
|
||||
getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::BitsetView& bitset) {
|
||||
getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::BitsetView bitset) {
|
||||
DVPTree::SearchContainer tso(sc.object);
|
||||
tso.mode = DVPTree::SearchContainer::SearchLeaf;
|
||||
tso.radius = 0.0;
|
||||
|
@ -2204,7 +2204,7 @@ public:
|
|||
}
|
||||
|
||||
// for milvus
|
||||
void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset)
|
||||
void search(NGT::SearchContainer & sc, const faiss::BitsetView bitset)
|
||||
{
|
||||
sc.distanceComputationCount = 0;
|
||||
sc.visitCount = 0;
|
||||
|
|
|
@ -263,7 +263,7 @@ namespace NGT {
|
|||
|
||||
// for milvus
|
||||
void
|
||||
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::BitsetView& bitset) {
|
||||
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::BitsetView bitset) {
|
||||
LeafNode& ln = *(LeafNode*)getNode(nid);
|
||||
rl.clear();
|
||||
ObjectDistance r;
|
||||
|
|
|
@ -120,8 +120,8 @@ inline void set_error_from_string(char **error, const char* msg) {
|
|||
#include <intrin.h>
|
||||
#elif defined(__GNUC__)
|
||||
#include <x86intrin.h>
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <faiss/utils/BitsetView.h>
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include "faiss/utils/BitsetView.h"
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@ -839,9 +839,9 @@ class AnnoyIndexInterface {
|
|||
virtual bool load_index(void* index_data, const int64_t& index_size, char** error = nullptr) = 0;
|
||||
virtual T get_distance(S i, S j) const = 0;
|
||||
virtual void get_nns_by_item(S item, size_t n, int64_t search_k, vector<S>* result, vector<T>* distances,
|
||||
const faiss::BitsetView& bitset = nullptr) const = 0;
|
||||
const faiss::BitsetView bitset = nullptr) const = 0;
|
||||
virtual void get_nns_by_vector(const T* w, size_t n, int64_t search_k, vector<S>* result, vector<T>* distances,
|
||||
const faiss::BitsetView& bitset = nullptr) const = 0;
|
||||
const faiss::BitsetView bitset = nullptr) const = 0;
|
||||
virtual S get_n_items() const = 0;
|
||||
virtual S get_dim() const = 0;
|
||||
virtual S get_n_trees() const = 0;
|
||||
|
@ -1178,14 +1178,14 @@ public:
|
|||
}
|
||||
|
||||
void get_nns_by_item(S item, size_t n, int64_t search_k, vector<S>* result, vector<T>* distances,
|
||||
const faiss::BitsetView& bitset) const {
|
||||
const faiss::BitsetView bitset) const {
|
||||
// TODO: handle OOB
|
||||
const Node* m = _get(item);
|
||||
_get_all_nns(m->v, n, search_k, result, distances, bitset);
|
||||
}
|
||||
|
||||
void get_nns_by_vector(const T* w, size_t n, int64_t search_k, vector<S>* result, vector<T>* distances,
|
||||
const faiss::BitsetView& bitset) const {
|
||||
const faiss::BitsetView bitset) const {
|
||||
_get_all_nns(w, n, search_k, result, distances, bitset);
|
||||
}
|
||||
|
||||
|
@ -1335,7 +1335,7 @@ protected:
|
|||
}
|
||||
|
||||
void _get_all_nns(const T* v, size_t n, int64_t search_k, vector<S>* result, vector<T>* distances,
|
||||
const faiss::BitsetView& bitset) const {
|
||||
const faiss::BitsetView bitset) const {
|
||||
Node* v_node = (Node *)alloca(_s);
|
||||
D::template zero_value<Node>(v_node);
|
||||
memcpy(v_node->v, v, sizeof(T) * _f);
|
||||
|
|
|
@ -262,6 +262,7 @@ int split_clusters (size_t d, size_t k, size_t n,
|
|||
};
|
||||
|
||||
ClusteringType clustering_type = ClusteringType::K_MEANS;
|
||||
double early_stop_threshold = 0.0;
|
||||
|
||||
void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
|
||||
size_t n_input_centroids, size_t d, size_t k,
|
||||
|
@ -532,6 +533,7 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|||
// k-means iterations
|
||||
|
||||
float err = 0;
|
||||
float prev_objective = 0;
|
||||
for (int i = 0; i < niter; i++) {
|
||||
double t0s = getmillisecs();
|
||||
|
||||
|
@ -600,6 +602,14 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|||
}
|
||||
|
||||
index.add (k, centroids.data());
|
||||
|
||||
// Early stop strategy
|
||||
float diff = (prev_objective == 0) ? std::numeric_limits<float>::max() : (prev_objective - stats.obj) / prev_objective;
|
||||
prev_objective = stats.obj;
|
||||
if (diff < early_stop_threshold / 100.) {
|
||||
break;
|
||||
}
|
||||
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
|
||||
|
|
|
@ -25,9 +25,12 @@ enum ClusteringType
|
|||
K_MEANS_TWO,
|
||||
};
|
||||
|
||||
//The default algorithm use the K_MEANS
|
||||
// The default algorithm use the K_MEANS
|
||||
extern ClusteringType clustering_type;
|
||||
|
||||
// K-Means Early Stop Threshold; defaults to 0.0
|
||||
extern double early_stop_threshold;
|
||||
|
||||
|
||||
/** Class for the clustering parameters. Can be passed to the
|
||||
* constructor of the Clustering object.
|
||||
|
|
|
@ -31,7 +31,7 @@ void Index::train(idx_t /*n*/, const float* /*x*/) {
|
|||
|
||||
void Index::range_search (idx_t , const float *, float,
|
||||
RangeSearchResult *,
|
||||
const BitsetView&) const
|
||||
const BitsetView) const
|
||||
{
|
||||
FAISS_THROW_MSG ("range search not implemented");
|
||||
}
|
||||
|
@ -62,12 +62,12 @@ void Index::add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xid
|
|||
}
|
||||
|
||||
#if 0
|
||||
void Index::get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView& bitset) {
|
||||
void Index::get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView bitset) {
|
||||
FAISS_THROW_MSG ("get_vector_by_id not implemented for this type of index");
|
||||
}
|
||||
|
||||
void Index::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels,
|
||||
const BitsetView& bitset) {
|
||||
const BitsetView bitset) {
|
||||
FAISS_THROW_MSG ("search_by_id not implemented for this type of index");
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -129,7 +129,7 @@ struct Index {
|
|||
*/
|
||||
virtual void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const = 0;
|
||||
const BitsetView bitset = nullptr) const = 0;
|
||||
|
||||
#if 0
|
||||
/** query n raw vectors from the index by ids.
|
||||
|
@ -141,7 +141,7 @@ struct Index {
|
|||
* @param x output raw vectors, size n * d
|
||||
* @param bitset flags to check the validity of vectors
|
||||
*/
|
||||
virtual void get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView& bitset = nullptr);
|
||||
virtual void get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView bitset = nullptr);
|
||||
|
||||
/** query n vectors of dimension d to the index by ids.
|
||||
*
|
||||
|
@ -154,7 +154,7 @@ struct Index {
|
|||
* @param bitset flags to check the validity of vectors
|
||||
*/
|
||||
virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr);
|
||||
const BitsetView bitset = nullptr);
|
||||
#endif
|
||||
|
||||
/** query n vectors of dimension d to the index.
|
||||
|
@ -169,7 +169,7 @@ struct Index {
|
|||
*/
|
||||
virtual void range_search (idx_t n, const float *x, float radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const;
|
||||
const BitsetView bitset = nullptr) const;
|
||||
|
||||
/** return the indexes of the k vectors closest to the query x.
|
||||
*
|
||||
|
|
|
@ -166,7 +166,7 @@ void Index2Layer::search(
|
|||
idx_t /*k*/,
|
||||
float* /*distances*/,
|
||||
idx_t* /*labels*/,
|
||||
const BitsetView&) const {
|
||||
const BitsetView) const {
|
||||
FAISS_THROW_MSG("not implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ struct Index2Layer: Index {
|
|||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ void IndexBinary::train(idx_t, const uint8_t *) {
|
|||
|
||||
void IndexBinary::range_search(idx_t, const uint8_t *, int,
|
||||
RangeSearchResult *,
|
||||
const BitsetView&) const {
|
||||
const BitsetView) const {
|
||||
FAISS_THROW_MSG("range search not implemented");
|
||||
}
|
||||
|
||||
|
@ -37,12 +37,12 @@ void IndexBinary::add_with_ids(idx_t, const uint8_t *, const idx_t *) {
|
|||
}
|
||||
|
||||
#if 0
|
||||
void IndexBinary::get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, const BitsetView& bitset) {
|
||||
void IndexBinary::get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, const BitsetView bitset) {
|
||||
FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index");
|
||||
}
|
||||
|
||||
void IndexBinary::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) {
|
||||
const BitsetView bitset) {
|
||||
FAISS_THROW_MSG("search_by_id not implemented for this type of index");
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -97,7 +97,7 @@ struct IndexBinary {
|
|||
*/
|
||||
virtual void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const = 0;
|
||||
const BitsetView bitset = nullptr) const = 0;
|
||||
|
||||
#if 0
|
||||
/** Query n raw vectors from the index by ids.
|
||||
|
@ -109,7 +109,7 @@ struct IndexBinary {
|
|||
* @param x output raw vectors, size n * d
|
||||
* @param bitset flags to check the validity of vectors
|
||||
*/
|
||||
virtual void get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, const BitsetView& bitset = nullptr);
|
||||
virtual void get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, const BitsetView bitset = nullptr);
|
||||
|
||||
/** query n vectors of dimension d to the index by ids.
|
||||
*
|
||||
|
@ -122,7 +122,7 @@ struct IndexBinary {
|
|||
* @param bitset flags to check the validity of vectors
|
||||
*/
|
||||
virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr);
|
||||
const BitsetView bitset = nullptr);
|
||||
#endif
|
||||
|
||||
/** Query n vectors of dimension d to the index.
|
||||
|
@ -141,7 +141,7 @@ struct IndexBinary {
|
|||
*/
|
||||
virtual void range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const;
|
||||
const BitsetView bitset = nullptr) const;
|
||||
|
||||
/** Return the indexes of the k vectors closest to the query x.
|
||||
*
|
||||
|
|
|
@ -40,61 +40,35 @@ void IndexBinaryFlat::reset() {
|
|||
|
||||
void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const {
|
||||
const idx_t block_size = query_batch_size;
|
||||
const BitsetView bitset) const {
|
||||
|
||||
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
|
||||
float *D = reinterpret_cast<float*>(distances);
|
||||
for (idx_t s = 0; s < n; s += block_size) {
|
||||
idx_t nn = block_size;
|
||||
if (s + block_size > n) {
|
||||
nn = n - s;
|
||||
}
|
||||
float_maxheap_array_t res = {
|
||||
size_t(n), size_t(k), labels, D
|
||||
};
|
||||
binary_distance_knn_hc(METRIC_Jaccard, &res, x, xb.data(), ntotal, code_size, bitset);
|
||||
|
||||
// We see the distances and labels as heaps.
|
||||
float_maxheap_array_t res = {
|
||||
size_t(nn), size_t(k), labels + s * k, D + s * k
|
||||
};
|
||||
|
||||
binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size,
|
||||
/* ordered = */ true, bitset);
|
||||
|
||||
}
|
||||
if (metric_type == METRIC_Tanimoto) {
|
||||
for (int i = 0; i < k * n; i++) {
|
||||
D[i] = -log2(1-D[i]);
|
||||
D[i] = Jaccard_2_Tanimoto(D[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} else if (metric_type == METRIC_Hamming) {
|
||||
int_maxheap_array_t res = {
|
||||
size_t(n), size_t(k), labels, distances
|
||||
};
|
||||
binary_distance_knn_hc(METRIC_Hamming, &res, x, xb.data(), ntotal, code_size, bitset);
|
||||
|
||||
} else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
|
||||
float *D = reinterpret_cast<float*>(distances);
|
||||
for (idx_t s = 0; s < n; s += block_size) {
|
||||
idx_t nn = block_size;
|
||||
if (s + block_size > n) {
|
||||
nn = n - s;
|
||||
}
|
||||
|
||||
// only match ids will be chosed, not to use heap
|
||||
binary_distence_knn_mc(metric_type, x + s * code_size, xb.data(), nn, ntotal, k, code_size,
|
||||
D + s * k, labels + s * k, bitset);
|
||||
}
|
||||
// only matched ids will be chosen, not to use heap
|
||||
binary_distance_knn_mc(metric_type, x, xb.data(), n, ntotal, k, code_size,
|
||||
D, labels, bitset);
|
||||
} else {
|
||||
for (idx_t s = 0; s < n; s += block_size) {
|
||||
idx_t nn = block_size;
|
||||
if (s + block_size > n) {
|
||||
nn = n - s;
|
||||
}
|
||||
if (use_heap) {
|
||||
// We see the distances and labels as heaps.
|
||||
int_maxheap_array_t res = {
|
||||
size_t(nn), size_t(k), labels + s * k, distances + s * k
|
||||
};
|
||||
|
||||
hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
|
||||
/* ordered = */ true, bitset);
|
||||
} else {
|
||||
hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
|
||||
distances + s * k, labels + s * k, bitset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,9 +98,43 @@ void IndexBinaryFlat::reconstruct(idx_t key, uint8_t *recons) const {
|
|||
|
||||
void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
|
||||
FAISS_THROW_MSG("This interface is abandoned yet.");
|
||||
}
|
||||
|
||||
void IndexBinaryFlat::range_search(faiss::IndexBinary::idx_t n,
|
||||
const uint8_t* x,
|
||||
float radius,
|
||||
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||
size_t buffer_size,
|
||||
const faiss::BitsetView bitset)
|
||||
{
|
||||
switch (metric_type) {
|
||||
case METRIC_Jaccard: {
|
||||
binary_range_search<CMax<float, int64_t>, float>(METRIC_Jaccard, x, xb.data(), n, ntotal, radius, code_size, result, buffer_size, bitset);
|
||||
break;
|
||||
}
|
||||
case METRIC_Tanimoto: {
|
||||
binary_range_search<CMax<float, int64_t>, float>(METRIC_Tanimoto, x, xb.data(), n, ntotal, radius, code_size, result, buffer_size, bitset);
|
||||
break;
|
||||
}
|
||||
case METRIC_Hamming: {
|
||||
binary_range_search<CMax<int, int64_t>, int>(METRIC_Hamming, x, xb.data(), n, ntotal, static_cast<int>(radius), code_size, result, buffer_size, bitset);
|
||||
break;
|
||||
}
|
||||
case METRIC_Superstructure: {
|
||||
binary_range_search<CMin<bool, int64_t>, bool>(METRIC_Superstructure, x, xb.data(), n, ntotal, false, code_size, result, buffer_size, bitset);
|
||||
break;
|
||||
}
|
||||
case METRIC_Substructure: {
|
||||
binary_range_search<CMin<bool, int64_t>, bool>(METRIC_Substructure, x, xb.data(), n, ntotal, false, code_size, result, buffer_size, bitset);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
//hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result, buffer_size, bitset);
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include <faiss/IndexBinary.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
@ -39,11 +40,16 @@ struct IndexBinaryFlat : IndexBinary {
|
|||
|
||||
void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void range_search(idx_t n, const uint8_t *x, float radius,
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView bitset = nullptr); // const override
|
||||
|
||||
void reconstruct(idx_t key, uint8_t *recons) const override;
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ void IndexBinaryFromFloat::reset() {
|
|||
|
||||
void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const {
|
||||
const BitsetView bitset) const {
|
||||
constexpr idx_t bs = 32768;
|
||||
std::unique_ptr<float[]> xf(new float[bs * d]);
|
||||
std::unique_ptr<float[]> df(new float[bs * k]);
|
||||
|
|
|
@ -42,7 +42,7 @@ struct IndexBinaryFromFloat : IndexBinary {
|
|||
|
||||
void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void train(idx_t n, const uint8_t *x) override;
|
||||
};
|
||||
|
|
|
@ -197,7 +197,7 @@ void IndexBinaryHNSW::train(idx_t n, const uint8_t *x)
|
|||
|
||||
void IndexBinaryHNSW::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
#pragma omp parallel
|
||||
{
|
||||
|
@ -260,12 +260,12 @@ struct FlatHammingDis : DistanceComputer {
|
|||
|
||||
float operator () (idx_t i) override {
|
||||
ndis++;
|
||||
return hc.hamming(b + i * code_size);
|
||||
return hc.compute(b + i * code_size);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return HammingComputerDefault(b + j * code_size, code_size)
|
||||
.hamming(b + i * code_size);
|
||||
.compute(b + i * code_size);
|
||||
}
|
||||
|
||||
|
||||
|
@ -312,11 +312,7 @@ DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
|
|||
case 64:
|
||||
return new FlatHammingDis<HammingComputer64>(*flat_storage);
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
return new FlatHammingDis<HammingComputerM8>(*flat_storage);
|
||||
} else if (code_size % 4 == 0) {
|
||||
return new FlatHammingDis<HammingComputerM4>(*flat_storage);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
|
||||
|
|
|
@ -46,7 +46,7 @@ struct IndexBinaryHNSW : IndexBinary {
|
|||
/// entry point for search
|
||||
void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void reconstruct(idx_t key, uint8_t* recons) const override;
|
||||
|
||||
|
|
|
@ -173,7 +173,7 @@ search_single_query_template(const IndexBinaryHash & index, const uint8_t *q,
|
|||
} else {
|
||||
const uint8_t *codes = il.vecs.data();
|
||||
for (size_t i = 0; i < nv; i++) {
|
||||
int dis = hc.hamming (codes);
|
||||
int dis = hc.compute (codes);
|
||||
res.add(dis, il.ids[i]);
|
||||
codes += code_size;
|
||||
}
|
||||
|
@ -196,12 +196,8 @@ search_single_query(const IndexBinaryHash & index, const uint8_t *q,
|
|||
case 16: HC(HammingComputer16); break;
|
||||
case 20: HC(HammingComputer20); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (index.code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -213,7 +209,7 @@ search_single_query(const IndexBinaryHash & index, const uint8_t *q,
|
|||
|
||||
void IndexBinaryHash::range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
|
||||
size_t nlist = 0, ndis = 0, n0 = 0;
|
||||
|
@ -241,7 +237,7 @@ void IndexBinaryHash::range_search(idx_t n, const uint8_t *x, int radius,
|
|||
|
||||
void IndexBinaryHash::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
|
||||
using HeapForL2 = CMax<int32_t, idx_t>;
|
||||
|
@ -364,7 +360,7 @@ void verify_shortlist(
|
|||
const uint8_t *codes = index.xb.data();
|
||||
|
||||
for (auto i: shortlist) {
|
||||
int dis = hc.hamming (codes + i * code_size);
|
||||
int dis = hc.compute (codes + i * code_size);
|
||||
res.add(dis, i);
|
||||
}
|
||||
}
|
||||
|
@ -416,12 +412,7 @@ search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
|
|||
case 16: HC(HammingComputer16); break;
|
||||
case 20: HC(HammingComputer20); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (index.code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -430,7 +421,7 @@ search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
|
|||
|
||||
void IndexBinaryMultiHash::range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
|
||||
size_t nlist = 0, ndis = 0, n0 = 0;
|
||||
|
@ -458,7 +449,7 @@ void IndexBinaryMultiHash::range_search(idx_t n, const uint8_t *x, int radius,
|
|||
|
||||
void IndexBinaryMultiHash::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
|
||||
using HeapForL2 = CMax<int32_t, idx_t>;
|
||||
|
|
|
@ -52,11 +52,11 @@ struct IndexBinaryHash : IndexBinary {
|
|||
|
||||
void range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void display() const;
|
||||
size_t hashtable_size() const;
|
||||
|
@ -105,11 +105,11 @@ struct IndexBinaryMultiHash: IndexBinary {
|
|||
|
||||
void range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
size_t hashtable_size() const;
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/jaccard-inl.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
@ -147,7 +148,7 @@ void IndexBinaryIVF::set_direct_map_type (DirectMap::Type type)
|
|||
|
||||
void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const {
|
||||
const BitsetView bitset) const {
|
||||
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
||||
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
||||
|
||||
|
@ -172,7 +173,7 @@ void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k,
|
|||
}
|
||||
|
||||
#if 0
|
||||
void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, const BitsetView& bitset) {
|
||||
void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, const BitsetView bitset) {
|
||||
make_direct_map(true);
|
||||
|
||||
/* only get vector by 1 id */
|
||||
|
@ -185,7 +186,7 @@ void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, con
|
|||
}
|
||||
|
||||
void IndexBinaryIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset) {
|
||||
const BitsetView bitset) {
|
||||
make_direct_map(true);
|
||||
|
||||
auto x = new uint8_t[n * d];
|
||||
|
@ -376,7 +377,7 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
}
|
||||
|
||||
uint32_t distance_to_code (const uint8_t *code) const override {
|
||||
return hc.hamming (code);
|
||||
return hc.compute (code);
|
||||
}
|
||||
|
||||
size_t scan_codes (size_t n,
|
||||
|
@ -384,14 +385,14 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
const idx_t *ids,
|
||||
int32_t *simi, idx_t *idxi,
|
||||
size_t k,
|
||||
const BitsetView& bitset) const override
|
||||
const BitsetView bitset) const override
|
||||
{
|
||||
using C = CMax<int32_t, idx_t>;
|
||||
|
||||
size_t nup = 0;
|
||||
for (size_t j = 0; j < n; j++) {
|
||||
if (!bitset || !bitset.test(ids[j])) {
|
||||
uint32_t dis = hc.hamming (codes);
|
||||
uint32_t dis = hc.compute (codes);
|
||||
if (dis < simi[0]) {
|
||||
idx_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
heap_swap_top<C> (k, simi, idxi, dis, id);
|
||||
|
@ -411,7 +412,7 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
{
|
||||
size_t nup = 0;
|
||||
for (size_t j = 0; j < n; j++) {
|
||||
uint32_t dis = hc.hamming (codes);
|
||||
uint32_t dis = hc.compute (codes);
|
||||
if (dis < radius) {
|
||||
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
||||
result.add (dis, id);
|
||||
|
@ -447,7 +448,7 @@ struct IVFBinaryScannerJaccard: BinaryInvertedListScanner {
|
|||
const idx_t *ids,
|
||||
int32_t *simi, idx_t *idxi,
|
||||
size_t k,
|
||||
const BitsetView& bitset = nullptr) const override
|
||||
const BitsetView bitset = nullptr) const override
|
||||
{
|
||||
using C = CMax<float, idx_t>;
|
||||
float* psimi = (float*)simi;
|
||||
|
@ -486,14 +487,7 @@ BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
|
|||
case 20: HC(HammingComputer20);
|
||||
case 32: HC(HammingComputer32);
|
||||
case 64: HC(HammingComputer64);
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else if (code_size % 4 == 0) {
|
||||
HC(HammingComputerM4);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -527,7 +521,7 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
|
|||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
IndexIVFStats &index_ivf_stats,
|
||||
const BitsetView& bitset = nullptr)
|
||||
const BitsetView bitset = nullptr)
|
||||
{
|
||||
long nprobe = params ? params->nprobe : ivf.nprobe;
|
||||
long max_codes = params ? params->max_codes : ivf.max_codes;
|
||||
|
@ -624,7 +618,7 @@ void search_knn_binary_dis_heap(const IndexBinaryIVF& ivf,
|
|||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
IndexIVFStats &index_ivf_stats,
|
||||
const BitsetView& bitset = nullptr)
|
||||
const BitsetView bitset = nullptr)
|
||||
{
|
||||
long nprobe = params ? params->nprobe : ivf.nprobe;
|
||||
long max_codes = params ? params->max_codes : ivf.max_codes;
|
||||
|
@ -711,7 +705,7 @@ void search_knn_hamming_count(const IndexBinaryIVF& ivf,
|
|||
idx_t *labels,
|
||||
const IVFSearchParameters *params,
|
||||
IndexIVFStats &index_ivf_stats,
|
||||
const BitsetView& bitset = nullptr) {
|
||||
const BitsetView bitset = nullptr) {
|
||||
const int nBuckets = ivf.d + 1;
|
||||
std::vector<int> all_counters(nx * nBuckets, 0);
|
||||
std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
|
||||
|
@ -809,7 +803,7 @@ void search_knn_hamming_count_1 (
|
|||
idx_t *labels,
|
||||
const IVFSearchParameters *params,
|
||||
IndexIVFStats &index_ivf_stats,
|
||||
const BitsetView& bitset = nullptr) {
|
||||
const BitsetView bitset = nullptr) {
|
||||
switch (ivf.code_size) {
|
||||
#define HANDLE_CS(cs) \
|
||||
case cs: \
|
||||
|
@ -824,16 +818,8 @@ void search_knn_hamming_count_1 (
|
|||
HANDLE_CS(64);
|
||||
#undef HANDLE_CS
|
||||
default:
|
||||
if (ivf.code_size % 8 == 0) {
|
||||
search_knn_hamming_count<HammingComputerM8, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, index_ivf_stats, bitset);
|
||||
} else if (ivf.code_size % 4 == 0) {
|
||||
search_knn_hamming_count<HammingComputerM4, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, index_ivf_stats, bitset);
|
||||
} else {
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, index_ivf_stats, bitset);
|
||||
}
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, index_ivf_stats, bitset);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -870,7 +856,7 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
|
|||
int32_t *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
const BitsetView& bitset
|
||||
const BitsetView bitset
|
||||
) const {
|
||||
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
|
||||
if (use_heap) {
|
||||
|
@ -882,7 +868,7 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
|
|||
params, index_ivf_stats, bitset);
|
||||
if (metric_type == METRIC_Tanimoto) {
|
||||
for (int i = 0; i < k * n; i++) {
|
||||
D[i] = -log2(1-D[i]);
|
||||
D[i] = Jaccard_2_Tanimoto(D[i]);
|
||||
}
|
||||
}
|
||||
memcpy(distances, D, sizeof(float) * n * k);
|
||||
|
@ -913,7 +899,7 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
|
|||
void IndexBinaryIVF::range_search(
|
||||
idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *res,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
||||
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
||||
|
|
|
@ -108,7 +108,7 @@ struct IndexBinaryIVF : IndexBinary {
|
|||
int32_t *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
const BitsetView& bitset = nullptr
|
||||
const BitsetView bitset = nullptr
|
||||
) const;
|
||||
|
||||
virtual BinaryInvertedListScanner *get_InvertedListScanner (
|
||||
|
@ -116,20 +116,20 @@ struct IndexBinaryIVF : IndexBinary {
|
|||
|
||||
/** assign the vectors, then call search_preassign */
|
||||
void search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels, const BitsetView& bitset = nullptr) const override;
|
||||
int32_t *distances, idx_t *labels, const BitsetView bitset = nullptr) const override;
|
||||
|
||||
|
||||
#if 0
|
||||
/** get raw vectors by ids */
|
||||
void get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, const BitsetView& bitset = nullptr) override;
|
||||
void get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, const BitsetView bitset = nullptr) override;
|
||||
|
||||
void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) override;
|
||||
const BitsetView bitset = nullptr) override;
|
||||
#endif
|
||||
|
||||
void range_search(idx_t n, const uint8_t *x, int radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void reconstruct(idx_t key, uint8_t *recons) const override;
|
||||
|
||||
|
@ -230,7 +230,7 @@ struct BinaryInvertedListScanner {
|
|||
const idx_t *ids,
|
||||
int32_t *distances, idx_t *labels,
|
||||
size_t k,
|
||||
const BitsetView& bitset = nullptr) const = 0;
|
||||
const BitsetView bitset = nullptr) const = 0;
|
||||
|
||||
virtual void scan_codes_range (size_t n,
|
||||
const uint8_t *codes,
|
||||
|
|
|
@ -41,7 +41,7 @@ void IndexFlat::reset() {
|
|||
|
||||
void IndexFlat::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
// we see the distances and labels as heaps
|
||||
|
||||
|
@ -91,21 +91,30 @@ void IndexFlat::assign(idx_t n, const float * x, idx_t * labels, float* distance
|
|||
|
||||
void IndexFlat::range_search (idx_t n, const float *x, float radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
switch (metric_type) {
|
||||
case METRIC_INNER_PRODUCT:
|
||||
range_search_inner_product (x, xb.data(), d, n, ntotal,
|
||||
radius, result);
|
||||
break;
|
||||
case METRIC_L2:
|
||||
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
|
||||
break;
|
||||
default:
|
||||
FAISS_THROW_MSG("metric type not supported");
|
||||
}
|
||||
FAISS_THROW_MSG("This interface is abandoned yet.");
|
||||
}
|
||||
|
||||
void IndexFlat::range_search(faiss::Index::idx_t n,
|
||||
const float* x,
|
||||
float radius,
|
||||
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||
size_t buffer_size,
|
||||
const faiss::BitsetView bitset) {
|
||||
|
||||
switch (metric_type) {
|
||||
case METRIC_INNER_PRODUCT:
|
||||
range_search_inner_product (x, xb.data(), d, n, ntotal,
|
||||
radius, result, buffer_size, bitset);
|
||||
break;
|
||||
case METRIC_L2:
|
||||
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result, buffer_size, bitset);
|
||||
break;
|
||||
default:
|
||||
FAISS_THROW_MSG("metric type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
void IndexFlat::compute_distance_subset (
|
||||
idx_t n,
|
||||
|
@ -271,7 +280,7 @@ void IndexFlatL2BaseShift::search (
|
|||
idx_t k,
|
||||
float *distances,
|
||||
idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
FAISS_THROW_IF_NOT (shift.size() == ntotal);
|
||||
|
||||
|
@ -355,7 +364,7 @@ static void reorder_2_heaps (
|
|||
void IndexRefineFlat::search (
|
||||
idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
FAISS_THROW_IF_NOT (is_trained);
|
||||
idx_t k_base = idx_t (k * k_factor);
|
||||
|
@ -449,7 +458,7 @@ void IndexFlat1D::search (
|
|||
idx_t k,
|
||||
float *distances,
|
||||
idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal,
|
||||
"Call update_permutation before search");
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
@ -35,7 +36,7 @@ struct IndexFlat: Index {
|
|||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void assign (
|
||||
idx_t n,
|
||||
|
@ -48,7 +49,15 @@ struct IndexFlat: Index {
|
|||
const float* x,
|
||||
float radius,
|
||||
RangeSearchResult* result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void range_search(
|
||||
idx_t n,
|
||||
const float* x,
|
||||
float radius,
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView bitset = nullptr); // const override
|
||||
|
||||
void reconstruct(idx_t key, float* recons) const override;
|
||||
|
||||
|
@ -115,7 +124,7 @@ struct IndexFlatL2BaseShift: IndexFlatL2 {
|
|||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
};
|
||||
|
||||
|
||||
|
@ -151,7 +160,7 @@ struct IndexRefineFlat: Index {
|
|||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
~IndexRefineFlat() override;
|
||||
};
|
||||
|
@ -180,7 +189,7 @@ struct IndexFlat1D:IndexFlatL2 {
|
|||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -277,18 +277,18 @@ IndexHNSW::~IndexHNSW() {
|
|||
void IndexHNSW::train(idx_t n, const float* x)
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG(storage,
|
||||
"Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
|
||||
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
|
||||
// hnsw structure does not require training
|
||||
storage->train (n, x);
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
void IndexHNSW::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels, const BitsetView& bitset) const
|
||||
float *distances, idx_t *labels, const BitsetView bitset) const
|
||||
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG(storage,
|
||||
"Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
|
||||
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
|
||||
size_t nreorder = 0;
|
||||
|
||||
idx_t check_period = InterruptCallback::get_period_hint (
|
||||
|
@ -348,7 +348,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
|
|||
void IndexHNSW::add(idx_t n, const float *x)
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG(storage,
|
||||
"Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
|
||||
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
|
||||
FAISS_THROW_IF_NOT(is_trained);
|
||||
int n0 = ntotal;
|
||||
storage->add(n, x);
|
||||
|
@ -1013,7 +1013,7 @@ int search_from_candidates_2(const HNSW & hnsw,
|
|||
} // namespace
|
||||
|
||||
void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels, const BitsetView& bitset) const
|
||||
float *distances, idx_t *labels, const BitsetView bitset) const
|
||||
{
|
||||
if (dynamic_cast<const Index2Layer*>(storage)) {
|
||||
IndexHNSW::search (n, x, k, distances, labels);
|
||||
|
|
|
@ -92,7 +92,7 @@ struct IndexHNSW : Index {
|
|||
/// entry point for search
|
||||
void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void reconstruct(idx_t key, float* recons) const override;
|
||||
|
||||
|
@ -164,7 +164,7 @@ struct IndexHNSW2Level : IndexHNSW {
|
|||
/// entry point for search
|
||||
void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -318,7 +318,7 @@ void IndexIVF::set_direct_map_type (DirectMap::Type type)
|
|||
|
||||
void IndexIVF::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
||||
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
||||
|
@ -346,7 +346,7 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
|
|||
void IndexIVF::search_without_codes (idx_t n, const float *x,
|
||||
const uint8_t *arranged_codes, std::vector<size_t> prefix_sum,
|
||||
bool is_sq8, idx_t k, float *distances, idx_t *labels,
|
||||
const BitsetView& bitset)
|
||||
const BitsetView bitset)
|
||||
{
|
||||
|
||||
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
||||
|
@ -374,7 +374,7 @@ void IndexIVF::search_without_codes (idx_t n, const float *x,
|
|||
}
|
||||
|
||||
#if 0
|
||||
void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView& bitset) {
|
||||
void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView bitset) {
|
||||
make_direct_map(true);
|
||||
|
||||
/* only get vector by 1 id */
|
||||
|
@ -387,7 +387,7 @@ void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, const Bits
|
|||
}
|
||||
|
||||
void IndexIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels,
|
||||
const BitsetView& bitset) {
|
||||
const BitsetView bitset) {
|
||||
make_direct_map(true);
|
||||
|
||||
auto x = new float[n * d];
|
||||
|
@ -406,7 +406,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
long nprobe = params ? params->nprobe : this->nprobe;
|
||||
long max_codes = params ? params->max_codes : this->max_codes;
|
||||
|
@ -462,7 +462,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|||
// set porperly) and storing results in simi and idxi
|
||||
auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
|
||||
float *simi, idx_t *idxi,
|
||||
const BitsetView& bitset) {
|
||||
const BitsetView bitset) {
|
||||
|
||||
if (key < 0) {
|
||||
// not enough centroids for multiprobe
|
||||
|
@ -612,7 +612,7 @@ void IndexIVF::search_preassigned_without_codes (idx_t n, const float *x,
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
const BitsetView& bitset)
|
||||
const BitsetView bitset)
|
||||
{
|
||||
long nprobe = params ? params->nprobe : this->nprobe;
|
||||
long max_codes = params ? params->max_codes : this->max_codes;
|
||||
|
@ -667,7 +667,7 @@ void IndexIVF::search_preassigned_without_codes (idx_t n, const float *x,
|
|||
// single list scan using the current scanner (with query
|
||||
// set porperly) and storing results in simi and idxi
|
||||
auto scan_one_list = [&] (idx_t key, float coarse_dis_i, const uint8_t *arranged_codes,
|
||||
float *simi, idx_t *idxi, const BitsetView& bitset) {
|
||||
float *simi, idx_t *idxi, const BitsetView bitset) {
|
||||
|
||||
if (key < 0) {
|
||||
// not enough centroids for multiprobe
|
||||
|
@ -812,7 +812,7 @@ void IndexIVF::search_preassigned_without_codes (idx_t n, const float *x,
|
|||
|
||||
void IndexIVF::range_search (idx_t nx, const float *x, float radius,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
std::unique_ptr<idx_t[]> keys (new idx_t[nx * nprobe]);
|
||||
std::unique_ptr<float []> coarse_dis (new float[nx * nprobe]);
|
||||
|
@ -834,7 +834,7 @@ void IndexIVF::range_search_preassigned (
|
|||
idx_t nx, const float *x, float radius,
|
||||
const idx_t *keys, const float *coarse_dis,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
|
||||
size_t nlistv = 0, ndis = 0;
|
||||
|
@ -1250,7 +1250,7 @@ void InvertedListScanner::scan_codes_range (size_t ,
|
|||
const idx_t *,
|
||||
float ,
|
||||
RangeQueryResult &,
|
||||
const BitsetView&) const
|
||||
const BitsetView) const
|
||||
{
|
||||
FAISS_THROW_MSG ("scan_codes_range not implemented");
|
||||
}
|
||||
|
|
|
@ -208,7 +208,7 @@ struct IndexIVF: Index, Level1Quantizer {
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
const BitsetView& bitset = nullptr
|
||||
const BitsetView bitset = nullptr
|
||||
) const;
|
||||
|
||||
/** Similar to search_preassigned, but does not store codes **/
|
||||
|
@ -221,36 +221,36 @@ struct IndexIVF: Index, Level1Quantizer {
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params = nullptr,
|
||||
const BitsetView& bitset = nullptr);
|
||||
const BitsetView bitset = nullptr);
|
||||
|
||||
/** assign the vectors, then call search_preassign */
|
||||
void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
|
||||
/** Similar to search, but does not store codes **/
|
||||
void search_without_codes (idx_t n, const float *x,
|
||||
const uint8_t *arranged_codes, std::vector<size_t> prefix_sum,
|
||||
bool is_sq8, idx_t k, float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr);
|
||||
const BitsetView bitset = nullptr);
|
||||
|
||||
#if 0
|
||||
/** get raw vectors by ids */
|
||||
void get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView& bitset = nullptr) override;
|
||||
void get_vector_by_id (idx_t n, const idx_t *xid, float *x, const BitsetView bitset = nullptr) override;
|
||||
|
||||
void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels,
|
||||
const BitsetView& bitset = nullptr) override;
|
||||
const BitsetView bitset = nullptr) override;
|
||||
#endif
|
||||
|
||||
void range_search (idx_t n, const float* x, float radius,
|
||||
RangeSearchResult* result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void range_search_preassigned(idx_t nx, const float *x, float radius,
|
||||
const idx_t *keys, const float *coarse_dis,
|
||||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const;
|
||||
const BitsetView bitset = nullptr) const;
|
||||
|
||||
/// get a scanner for this index (store_pairs means ignore labels)
|
||||
virtual InvertedListScanner *get_InvertedListScanner (
|
||||
|
@ -413,7 +413,7 @@ struct InvertedListScanner {
|
|||
const idx_t *ids,
|
||||
float *distances, idx_t *labels,
|
||||
size_t k,
|
||||
const BitsetView& bitset = nullptr) const = 0;
|
||||
const BitsetView bitset = nullptr) const = 0;
|
||||
|
||||
/** scan a set of codes, compute distances to current query and
|
||||
* update results if distances are below radius
|
||||
|
@ -424,7 +424,7 @@ struct InvertedListScanner {
|
|||
const idx_t *ids,
|
||||
float radius,
|
||||
RangeQueryResult &result,
|
||||
const BitsetView& bitset = nullptr) const;
|
||||
const BitsetView bitset = nullptr) const;
|
||||
|
||||
virtual ~InvertedListScanner () {}
|
||||
|
||||
|
|
|
@ -184,7 +184,7 @@ struct IVFFlatScanner: InvertedListScanner {
|
|||
const idx_t *ids,
|
||||
float *simi, idx_t *idxi,
|
||||
size_t k,
|
||||
const BitsetView& bitset) const override
|
||||
const BitsetView bitset) const override
|
||||
{
|
||||
const float *list_vecs = (const float*)codes;
|
||||
size_t nup = 0;
|
||||
|
@ -208,7 +208,7 @@ struct IVFFlatScanner: InvertedListScanner {
|
|||
const idx_t *ids,
|
||||
float radius,
|
||||
RangeQueryResult & res,
|
||||
const BitsetView& bitset = nullptr) const override
|
||||
const BitsetView bitset = nullptr) const override
|
||||
{
|
||||
const float *list_vecs = (const float*)codes;
|
||||
for (size_t j = 0; j < list_size; j++) {
|
||||
|
@ -354,7 +354,7 @@ void IndexIVFFlatDedup::search_preassigned (
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG (
|
||||
!store_pairs, "store_pairs not supported in IVFDedup");
|
||||
|
@ -483,7 +483,7 @@ void IndexIVFFlatDedup::range_search(
|
|||
const float* ,
|
||||
float ,
|
||||
RangeSearchResult* ,
|
||||
const BitsetView&) const
|
||||
const BitsetView) const
|
||||
{
|
||||
FAISS_THROW_MSG ("not implemented");
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ struct IndexIVFFlatDedup: IndexIVFFlat {
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
const BitsetView& bitset = nullptr
|
||||
const BitsetView bitset = nullptr
|
||||
) const override;
|
||||
|
||||
size_t remove_ids(const IDSelector& sel) override;
|
||||
|
@ -92,7 +92,7 @@ struct IndexIVFFlatDedup: IndexIVFFlat {
|
|||
const float* x,
|
||||
float radius,
|
||||
RangeSearchResult* result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
/// not implemented
|
||||
void update_vectors (int nv, const idx_t *idx, const float *v) override;
|
||||
|
|
|
@ -801,7 +801,7 @@ struct KnnSearchResults {
|
|||
|
||||
size_t nup;
|
||||
|
||||
inline void add (idx_t j, float dis, const BitsetView& bitset = nullptr) {
|
||||
inline void add (idx_t j, float dis, const BitsetView bitset = nullptr) {
|
||||
if (C::cmp (heap_sim[0], dis)) {
|
||||
idx_t id = ids ? ids[j] : lo_build (key, j);
|
||||
if (!bitset.empty() && bitset.test((faiss::ConcurrentBitset::id_type_t)id))
|
||||
|
@ -822,7 +822,7 @@ struct RangeSearchResults {
|
|||
float radius;
|
||||
RangeQueryResult & rres;
|
||||
|
||||
inline void add (idx_t j, float dis, const faiss::BitsetView& bitset = nullptr) {
|
||||
inline void add (idx_t j, float dis, const faiss::BitsetView bitset = nullptr) {
|
||||
if (C::cmp (radius, dis)) {
|
||||
idx_t id = ids ? ids[j] : lo_build (key, j);
|
||||
rres.add (dis, id);
|
||||
|
@ -872,7 +872,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
template<class SearchResultType>
|
||||
void scan_list_with_table (size_t ncode, const uint8_t *codes,
|
||||
SearchResultType & res,
|
||||
const BitsetView& bitset = nullptr) const
|
||||
const BitsetView bitset = nullptr) const
|
||||
{
|
||||
for (size_t j = 0; j < ncode; j++) {
|
||||
PQDecoder decoder(codes, pq.nbits);
|
||||
|
@ -895,7 +895,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
template<class SearchResultType>
|
||||
void scan_list_with_pointer (size_t ncode, const uint8_t *codes,
|
||||
SearchResultType & res,
|
||||
const faiss::BitsetView& bitset = nullptr) const
|
||||
const faiss::BitsetView bitset = nullptr) const
|
||||
{
|
||||
for (size_t j = 0; j < ncode; j++) {
|
||||
PQDecoder decoder(codes, pq.nbits);
|
||||
|
@ -918,7 +918,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
template<class SearchResultType>
|
||||
void scan_on_the_fly_dist (size_t ncode, const uint8_t *codes,
|
||||
SearchResultType &res,
|
||||
const faiss::BitsetView& bitset = nullptr) const
|
||||
const faiss::BitsetView bitset = nullptr) const
|
||||
{
|
||||
const float *dvec;
|
||||
float dis0 = 0;
|
||||
|
@ -958,7 +958,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
void scan_list_polysemous_hc (
|
||||
size_t ncode, const uint8_t *codes,
|
||||
SearchResultType & res,
|
||||
const faiss::BitsetView& bitset = nullptr) const
|
||||
const faiss::BitsetView bitset = nullptr) const
|
||||
{
|
||||
int ht = ivfpq.polysemous_ht;
|
||||
size_t n_hamming_pass = 0, nup = 0;
|
||||
|
@ -969,7 +969,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
|
||||
for (size_t j = 0; j < ncode; j++) {
|
||||
const uint8_t *b_code = codes;
|
||||
int hd = hc.hamming (b_code);
|
||||
int hd = hc.compute (b_code);
|
||||
if (hd < ht) {
|
||||
n_hamming_pass ++;
|
||||
PQDecoder decoder(codes, pq.nbits);
|
||||
|
@ -996,7 +996,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
void scan_list_polysemous (
|
||||
size_t ncode, const uint8_t *codes,
|
||||
SearchResultType &res,
|
||||
const faiss::BitsetView& bitset = nullptr) const
|
||||
const faiss::BitsetView bitset = nullptr) const
|
||||
{
|
||||
switch (pq.code_size) {
|
||||
#define HANDLE_CODE_SIZE(cs) \
|
||||
|
@ -1013,14 +1013,9 @@ struct IVFPQScannerT: QueryTables {
|
|||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
if (pq.code_size % 8 == 0)
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerM8, SearchResultType>
|
||||
(ncode, codes, res, bitset);
|
||||
else
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerM4, SearchResultType>
|
||||
(ncode, codes, res, bitset);
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerDefault, SearchResultType>
|
||||
(ncode, codes, res, bitset);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -1075,7 +1070,7 @@ struct IVFPQScanner:
|
|||
const idx_t *ids,
|
||||
float *heap_sim, idx_t *heap_ids,
|
||||
size_t k,
|
||||
const faiss::BitsetView& bitset) const override
|
||||
const faiss::BitsetView bitset) const override
|
||||
{
|
||||
KnnSearchResults<C> res = {
|
||||
/* key */ this->key,
|
||||
|
@ -1106,7 +1101,7 @@ struct IVFPQScanner:
|
|||
const idx_t *ids,
|
||||
float radius,
|
||||
RangeQueryResult & rres,
|
||||
const faiss::BitsetView& bitset = nullptr) const override
|
||||
const faiss::BitsetView bitset = nullptr) const override
|
||||
{
|
||||
RangeSearchResults<C> res = {
|
||||
/* key */ this->key,
|
||||
|
|
|
@ -101,7 +101,7 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params,
|
||||
const BitsetView& bitset
|
||||
const BitsetView bitset
|
||||
) const
|
||||
{
|
||||
uint64_t t0;
|
||||
|
|
|
@ -56,7 +56,7 @@ struct IndexIVFPQR: IndexIVFPQ {
|
|||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
const BitsetView& bitset = nullptr
|
||||
const BitsetView bitset = nullptr
|
||||
) const override;
|
||||
|
||||
IndexIVFPQR();
|
||||
|
|
|
@ -254,7 +254,7 @@ struct IVFScanner: InvertedListScanner {
|
|||
}
|
||||
|
||||
float distance_to_code (const uint8_t *code) const final {
|
||||
return hc.hamming (code);
|
||||
return hc.compute (code);
|
||||
}
|
||||
|
||||
size_t scan_codes (size_t list_size,
|
||||
|
@ -262,12 +262,12 @@ struct IVFScanner: InvertedListScanner {
|
|||
const idx_t *ids,
|
||||
float *simi, idx_t *idxi,
|
||||
size_t k,
|
||||
const BitsetView& bitset) const override
|
||||
const BitsetView bitset) const override
|
||||
{
|
||||
size_t nup = 0;
|
||||
for (size_t j = 0; j < list_size; j++) {
|
||||
if (!bitset || !bitset.test(ids[j])) {
|
||||
float dis = hc.hamming (codes);
|
||||
float dis = hc.compute (codes);
|
||||
|
||||
if (dis < simi [0]) {
|
||||
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
|
@ -285,10 +285,10 @@ struct IVFScanner: InvertedListScanner {
|
|||
const idx_t *ids,
|
||||
float radius,
|
||||
RangeQueryResult & res,
|
||||
const BitsetView& bitset = nullptr) const override
|
||||
const BitsetView bitset = nullptr) const override
|
||||
{
|
||||
for (size_t j = 0; j < list_size; j++) {
|
||||
float dis = hc.hamming (codes);
|
||||
float dis = hc.compute (codes);
|
||||
if (dis < radius) {
|
||||
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
||||
res.add (dis, id);
|
||||
|
@ -317,10 +317,8 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
|
|||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
return new IVFScanner<HammingComputerM8>(this, store_pairs);
|
||||
} else if (code_size % 4 == 0) {
|
||||
return new IVFScanner<HammingComputerM4>(this, store_pairs);
|
||||
if (code_size % 4 == 0) {
|
||||
return new IVFScanner<HammingComputerDefault>(this, store_pairs);
|
||||
} else {
|
||||
FAISS_THROW_MSG("not supported");
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
@ -131,7 +133,7 @@ void IndexLSH::search (
|
|||
idx_t k,
|
||||
float *distances,
|
||||
idx_t *labels,
|
||||
const BitsetView& bitset) const
|
||||
const BitsetView bitset) const
|
||||
{
|
||||
FAISS_THROW_IF_NOT (is_trained);
|
||||
const float *xt = apply_preprocess (n, x);
|
||||
|
@ -147,9 +149,8 @@ void IndexLSH::search (
|
|||
|
||||
int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
|
||||
|
||||
hammings_knn_hc (&res, qcodes, codes.data(),
|
||||
ntotal, bytes_per_vec, true);
|
||||
|
||||
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)&qcodes, codes.data(), ntotal,
|
||||
bytes_per_vec, bitset);
|
||||
|
||||
// convert distances to floats
|
||||
for (int i = 0; i < k * n; i++)
|
||||
|
|
|
@ -58,7 +58,7 @@ struct IndexLSH:Index {
|
|||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
const BitsetView bitset = nullptr) const override;
|
||||
|
||||
void reset() override;
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ void IndexLattice::add(idx_t , const float* )
|
|||
|
||||
|
||||
void IndexLattice::search(idx_t , const float* , idx_t ,
|
||||
float* , idx_t* , const BitsetView&) const
|
||||
float* , idx_t* , const BitsetView) const
|
||||
{
|
||||
FAISS_THROW_MSG("not implemented");
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue