add NGT index (#3555)

* add NGT index

Signed-off-by: fenglv <fenglv15@mails.ucas.ac.cn>
pull/3676/head
flynn 2020-09-05 14:09:17 +08:00 committed by GitHub
parent 16169bc7a4
commit f70d766475
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
71 changed files with 26247 additions and 1 deletions

View File

@ -23,3 +23,4 @@
| hnswlib | [Apache 2.0](https://github.com/nmslib/hnswlib/blob/master/LICENSE) |
| annoy | [Apache 2.0](https://github.com/spotify/annoy/blob/master/LICENSE) |
| crc32c | [BSD 3-Clause](https://github.com/google/crc32c/blob/master/LICENSE) |
| NGT | [Apache 2.0](https://github.com/yahoojapan/NGT/blob/master/LICENSE) |

View File

@ -653,3 +653,5 @@ if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep)
include_directories(SYSTEM "${FAISS_INCLUDE_DIR}")
link_directories(SYSTEM ${FAISS_PREFIX}/lib/)
endif ()
add_subdirectory(thirdparty/NGT)

View File

@ -13,6 +13,7 @@
include_directories(${INDEX_SOURCE_DIR}/knowhere)
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib)
if (MILVUS_SUPPORT_SPTAG)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
@ -68,6 +69,9 @@ set(vector_index_srcs
knowhere/index/vector_index/IndexRHNSWFlat.cpp
knowhere/index/vector_index/IndexRHNSWSQ.cpp
knowhere/index/vector_index/IndexRHNSWPQ.cpp
knowhere/index/vector_index/IndexNGT.cpp
knowhere/index/vector_index/IndexNGTPANNG.cpp
knowhere/index/vector_index/IndexNGTONNG.cpp
)
set(vector_offset_index_srcs
@ -91,6 +95,7 @@ set(depend_libs
gfortran
pthread
fiu
ngt
)
if (MILVUS_SUPPORT_SPTAG)

View File

@ -37,6 +37,8 @@ const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
const char* INDEX_ANNOY = "ANNOY";
const char* INDEX_NGTPANNG = "NGT_PANNG";
const char* INDEX_NGTONNG = "NGT_ONNG";
} // namespace IndexEnum
} // namespace knowhere

View File

@ -64,6 +64,8 @@ extern const char* INDEX_RHNSWFlat;
extern const char* INDEX_RHNSWPQ;
extern const char* INDEX_RHNSWSQ;
extern const char* INDEX_ANNOY;
extern const char* INDEX_NGTPANNG;
extern const char* INDEX_NGTONNG;
} // namespace IndexEnum
enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 };

View File

@ -389,5 +389,37 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
return true;
}
bool
NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
return true;
}
bool
NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
} // namespace knowhere
} // namespace milvus

View File

@ -126,5 +126,24 @@ class RHNSWSQConfAdapter : public ConfAdapter {
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class NGTPANNGConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class NGTONNGConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -53,6 +53,8 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
REGISTER_CONF_ADAPTER(NGTPANNGConfAdapter, IndexEnum::INDEX_NGTPANNG, ngtpanng_adapter);
REGISTER_CONF_ADAPTER(NGTONNGConfAdapter, IndexEnum::INDEX_NGTONNG, ngtonng_adapter);
}
} // namespace knowhere

View File

@ -0,0 +1,203 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNGT.h"
#include <omp.h>
#include <sstream>
#include <string>
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace milvus {
namespace knowhere {
BinarySet
IndexNGT::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
std::stringstream obj, grp, prf, tre;
index_->saveIndex(obj, grp, prf, tre);
auto obj_str = obj.str();
auto grp_str = grp.str();
auto prf_str = prf.str();
auto tre_str = tre.str();
uint64_t obj_size = obj_str.size();
uint64_t grp_size = grp_str.size();
uint64_t prf_size = prf_str.size();
uint64_t tre_size = tre_str.size();
std::shared_ptr<uint8_t[]> obj_data(new uint8_t[obj_size]);
memcpy(obj_data.get(), obj_str.data(), obj_size);
std::shared_ptr<uint8_t[]> grp_data(new uint8_t[grp_size]);
memcpy(grp_data.get(), grp_str.data(), grp_size);
std::shared_ptr<uint8_t[]> prf_data(new uint8_t[prf_size]);
memcpy(prf_data.get(), prf_str.data(), prf_size);
std::shared_ptr<uint8_t[]> tre_data(new uint8_t[tre_size]);
memcpy(tre_data.get(), tre_str.data(), tre_size);
BinarySet res_set;
res_set.Append("ngt_obj_data", obj_data, obj_size);
res_set.Append("ngt_grp_data", grp_data, grp_size);
res_set.Append("ngt_prf_data", prf_data, prf_size);
res_set.Append("ngt_tre_data", tre_data, tre_size);
return res_set;
}
void
IndexNGT::Load(const BinarySet& index_binary) {
auto obj_data = index_binary.GetByName("ngt_obj_data");
std::string obj_str(reinterpret_cast<char*>(obj_data->data.get()), obj_data->size);
auto grp_data = index_binary.GetByName("ngt_grp_data");
std::string grp_str(reinterpret_cast<char*>(grp_data->data.get()), grp_data->size);
auto prf_data = index_binary.GetByName("ngt_prf_data");
std::string prf_str(reinterpret_cast<char*>(prf_data->data.get()), prf_data->size);
auto tre_data = index_binary.GetByName("ngt_tre_data");
std::string tre_str(reinterpret_cast<char*>(tre_data->data.get()), tre_data->size);
std::stringstream obj(obj_str);
std::stringstream grp(grp_str);
std::stringstream prf(prf_str);
std::stringstream tre(tre_str);
index_ = std::shared_ptr<NGT::Index>(NGT::Index::loadIndex(obj, grp, prf, tre));
}
void
IndexNGT::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("IndexNGT has no implementation of BuildAll, please use IndexNGT(PANNG/ONNG) instead!");
}
#if 0
void
IndexNGT::Train(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("IndexNGT has no implementation of Train, please use IndexNGT(PANNG/ONNG) instead!");
GET_TENSOR_DATA_DIM(dataset_ptr);
NGT::Property prop;
prop.setDefaultForCreateIndex();
prop.dimension = dim;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2)
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
else if (metric_type == Metric::HAMMING)
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
else if (metric_type == Metric::JACCARD)
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
else
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
index_ =
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
}
void
IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr);
index_->append(reinterpret_cast<const float*>(p_data), rows);
}
#endif
DatasetPtr
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr);
size_t k = config[meta::TOPK].get<int64_t>();
size_t id_size = sizeof(int64_t) * k;
size_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
NGT::Command::SearchParameter sp;
sp.size = k;
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
const float* single_query = reinterpret_cast<float*>(const_cast<void*>(p_data)) + i * Dim();
NGT::Object* object = index_->allocateObject(single_query, Dim());
NGT::SearchContainer sc(*object);
double epsilon = sp.beginOfEpsilon;
NGT::ObjectDistances res;
sc.setResults(&res);
sc.setSize(sp.size);
sc.setRadius(sp.radius);
if (sp.accuracy > 0.0) {
sc.setExpectedAccuracy(sp.accuracy);
} else {
sc.setEpsilon(epsilon);
}
sc.setEdgeSize(sp.edgeSize);
try {
index_->search(sc, blacklist);
} catch (NGT::Exception& err) {
KNOWHERE_THROW_MSG("Query failed");
}
auto local_id = p_id + i * k;
auto local_dist = p_dist + i * k;
int64_t res_num = res.size();
for (int64_t idx = 0; idx < res_num; ++idx) {
*(local_id + idx) = res[idx].id - 1;
*(local_dist + idx) = res[idx].distance;
}
while (res_num < static_cast<int64_t>(k)) {
*(local_id + res_num) = -1;
*(local_dist + res_num) = 1.0 / 0.0;
}
index_->deleteObject(object);
}
auto res_ds = std::make_shared<Dataset>();
res_ds->Set(meta::IDS, p_id);
res_ds->Set(meta::DISTANCE, p_dist);
return res_ds;
}
int64_t
IndexNGT::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->getNumberOfVectors();
}
int64_t
IndexNGT::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->getDimension();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,70 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <NGT/lib/NGT/Command.h>
#include <NGT/lib/NGT/Common.h>
#include <NGT/lib/NGT/Index.h>
#include <knowhere/common/Exception.h>
#include <knowhere/index/IndexType.h>
#include <knowhere/index/vector_index/VecIndex.h>
#include <memory>
namespace milvus {
namespace knowhere {
class IndexNGT : public VecIndex {
public:
IndexNGT() {
index_type_ = IndexEnum::INVALID;
}
BinarySet
Serialize(const Config& config) override;
void
Load(const BinarySet& index_binary) override;
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
}
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
}
void
AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("Incremental index is not supported");
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
int64_t
Count() override;
int64_t
Dim() override;
protected:
std::shared_ptr<NGT::Index> index_ = nullptr;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,69 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNGTONNG.h"
#include "NGT/lib/NGT/GraphOptimizer.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include <memory>
namespace milvus {
namespace knowhere {
void
IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr);
NGT::Property prop;
prop.setDefaultForCreateIndex();
prop.dimension = dim;
prop.edgeSizeForCreation = 20;
prop.insertionRadiusCoefficient = 1.0;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
} else if (metric_type == Metric::HAMMING) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
} else if (metric_type == Metric::JACCARD) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
index_ =
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
// reconstruct graph
NGT::GraphOptimizer graphOptimizer(true);
size_t number_of_outgoing_edges = 5;
size_t number_of_incoming_edges = 30;
size_t number_of_queries = 1000;
size_t number_of_res = 20;
graphOptimizer.shortcutReduction = true;
graphOptimizer.searchParameterOptimization = true;
graphOptimizer.prefetchParameterOptimization = false;
graphOptimizer.accuracyTableGeneration = false;
graphOptimizer.margin = 0.2;
graphOptimizer.gtEpsilon = 0.1;
graphOptimizer.set(number_of_outgoing_edges, number_of_incoming_edges, number_of_queries, number_of_res);
graphOptimizer.execute(*index_);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,30 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "knowhere/index/vector_index/IndexNGT.h"
namespace milvus {
namespace knowhere {
class IndexNGTONNG : public IndexNGT {
public:
IndexNGTONNG() {
index_type_ = IndexEnum::INDEX_NGTONNG;
}
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,94 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include <memory>
namespace milvus {
namespace knowhere {
void
IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr);
NGT::Property prop;
prop.setDefaultForCreateIndex();
prop.dimension = dim;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
} else if (metric_type == Metric::HAMMING) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
} else if (metric_type == Metric::JACCARD) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
index_ =
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
size_t force_removed_edge_size = 60;
size_t selective_removed_edge_size = 30;
// prune
auto& graph = dynamic_cast<NGT::GraphIndex&>(index_->getIndex());
for (size_t id = 1; id < graph.repository.size(); id++) {
try {
NGT::GraphNode& node = *graph.getNode(id);
if (node.size() >= force_removed_edge_size) {
node.resize(force_removed_edge_size);
}
if (node.size() >= selective_removed_edge_size) {
size_t rank = 0;
for (auto i = node.begin(); i != node.end(); ++rank) {
if (rank >= selective_removed_edge_size) {
bool found = false;
for (size_t t1 = 0; t1 < node.size() && found == false; ++t1) {
if (t1 >= selective_removed_edge_size) {
break;
}
if (rank == t1) {
continue;
}
NGT::GraphNode& node2 = *graph.getNode(node[t1].id);
for (size_t t2 = 0; t2 < node2.size(); ++t2) {
if (t2 >= selective_removed_edge_size) {
break;
}
if (node2[t2].id == (*i).id) {
found = true;
break;
}
} // for
} // for
if (found) {
// remove
i = node.erase(i);
continue;
}
}
i++;
} // for
}
} catch (NGT::Exception& err) {
std::cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,30 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "knowhere/index/vector_index/IndexNGT.h"
namespace milvus {
namespace knowhere {
class IndexNGTPANNG : public IndexNGT {
public:
IndexNGTPANNG() {
index_type_ = IndexEnum::INDEX_NGTPANNG;
}
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -21,6 +21,8 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexNGTONNG.h"
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
@ -99,6 +101,10 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
return std::make_shared<knowhere::IndexRHNSWPQ>();
} else if (type == IndexEnum::INDEX_RHNSWSQ) {
return std::make_shared<knowhere::IndexRHNSWSQ>();
} else if (type == IndexEnum::INDEX_NGTPANNG) {
return std::make_shared<knowhere::IndexNGTPANNG>();
} else if (type == IndexEnum::INDEX_NGTONNG) {
return std::make_shared<knowhere::IndexNGTONNG>();
} else {
return nullptr;
}

View File

@ -0,0 +1,79 @@
if(APPLE)
cmake_minimum_required(VERSION 3.0)
else()
cmake_minimum_required(VERSION 2.8)
endif()
project(ngt)
file(STRINGS "VERSION" ngt_VERSION)
message(STATUS "VERSION: ${ngt_VERSION}")
string(REGEX MATCH "^[0-9]+" ngt_VERSION_MAJOR ${ngt_VERSION})
set(ngt_VERSION ${ngt_VERSION})
set(ngt_SOVERSION ${ngt_VERSION_MAJOR})
if (NOT CMAKE_BUILD_TYPE)
set (CMAKE_BUILD_TYPE "Release")
endif (NOT CMAKE_BUILD_TYPE)
string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER)
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
message(STATUS "CMAKE_BUILD_TYPE_LOWER: ${CMAKE_BUILD_TYPE_LOWER}")
if(${UNIX})
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
if(CMAKE_VERSION VERSION_LESS 3.1)
set(BASE_OPTIONS "-Wall -std=gnu++0x -lrt")
if(${NGT_AVX_DISABLED})
message(STATUS "AVX will not be used to compute distances.")
endif()
if(${NGT_OPENMP_DISABLED})
message(STATUS "OpenMP is disabled.")
else()
set(BASE_OPTIONS "${BASE_OPTIONS} -fopenmp")
endif()
set(CMAKE_CXX_FLAGS_DEBUG "-g ${BASE_OPTIONS}")
if(${NGT_MARCH_NATIVE_DISABLED})
message(STATUS "Compile option -march=native is disabled.")
set(CMAKE_CXX_FLAGS_RELEASE "-O2 ${BASE_OPTIONS}")
else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native ${BASE_OPTIONS}")
endif()
else()
if (CMAKE_BUILD_TYPE_LOWER STREQUAL "release")
set(CMAKE_CXX_FLAGS_RELEASE "")
if(${NGT_MARCH_NATIVE_DISABLED})
message(STATUS "Compile option -march=native is disabled.")
add_compile_options(-O2 -DNDEBUG)
else()
add_compile_options(-Ofast -march=native -DNDEBUG)
endif()
endif()
add_compile_options(-Wall)
if(${NGT_AVX_DISABLED})
message(STATUS "AVX will not be used to compute distances.")
endif()
if(${NGT_OPENMP_DISABLED})
message(STATUS "OpenMP is disabled.")
else()
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "8.1.0")
message(FATAL_ERROR "Insufficient AppleClang version")
endif()
cmake_minimum_required(VERSION 3.16)
endif()
find_package(OpenMP REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
set(CMAKE_CXX_STANDARD 11) # for std::unordered_set, std::unique_ptr
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(Threads REQUIRED)
endif()
add_subdirectory("${PROJECT_SOURCE_DIR}/lib")
endif( ${UNIX} )

202
core/src/index/thirdparty/NGT/LICENSE vendored Normal file
View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

1
core/src/index/thirdparty/NGT/VERSION vendored Normal file
View File

@ -0,0 +1 @@
1.12.0

View File

@ -0,0 +1,3 @@
if( ${UNIX} )
add_subdirectory(${PROJECT_SOURCE_DIR}/lib/NGT)
endif()

View File

@ -0,0 +1,89 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "ArrayFile.h"
#include <iostream>
#include <assert.h>
class ItemID {
public:
void serialize(std::ostream &os, NGT::ObjectSpace *ospace = 0) {
os.write((char*)&value, sizeof(value));
}
void deserialize(std::istream &is, NGT::ObjectSpace *ospace = 0) {
is.read((char*)&value, sizeof(value));
}
static size_t getSerializedDataSize() {
return sizeof(uint64_t);
}
uint64_t value;
};
void
sampleForUsage() {
{
ArrayFile<ItemID> itemIDFile;
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
itemIDFile.open("test.data");
ItemID itemID;
size_t id;
id = 1;
itemID.value = 4910002490100;
itemIDFile.put(id, itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490100);
id = 2;
itemID.value = 4910002490101;
itemIDFile.put(id, itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490101);
itemID.value = 4910002490102;
id = itemIDFile.insert(itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490102);
itemIDFile.close();
}
{
ArrayFile<ItemID> itemIDFile;
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
itemIDFile.open("test.data");
ItemID itemID;
size_t id;
id = 10;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490100);
id = 20;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490101);
}
}

View File

@ -0,0 +1,220 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <fstream>
#include <string>
#include <cstddef>
#include <stdint.h>
#include <iostream>
#include <stdexcept>
#include <cerrno>
#include <cstring>
namespace NGT {
class ObjectSpace;
};
template <class TYPE>
class ArrayFile {
private:
struct FileHeadStruct {
size_t recordSize;
uint64_t extraData; // reserve
};
struct RecordStruct {
bool deleteFlag;
uint64_t extraData; // reserve
};
bool _isOpen;
std::fstream _stream;
FileHeadStruct _fileHead;
bool _readFileHead();
pthread_mutex_t _mutex;
public:
ArrayFile();
~ArrayFile();
bool create(const std::string &file, size_t recordSize);
bool open(const std::string &file);
void close();
size_t insert(TYPE &data, NGT::ObjectSpace *objectSpace = 0);
void put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
bool get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
void remove(const size_t id);
bool isOpen() const;
size_t size();
size_t getRecordSize() { return _fileHead.recordSize; }
};
// constructor
template <class TYPE>
ArrayFile<TYPE>::ArrayFile()
: _isOpen(false), _mutex((pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER){
if(pthread_mutex_init(&_mutex, NULL) < 0) throw std::runtime_error("pthread init error.");
}
// destructor
template <class TYPE>
ArrayFile<TYPE>::~ArrayFile() {
pthread_mutex_destroy(&_mutex);
close();
}
template <class TYPE>
bool ArrayFile<TYPE>::create(const std::string &file, size_t recordSize) {
std::fstream tmpstream;
tmpstream.open(file.c_str());
if(tmpstream){
return false;
}
tmpstream.open(file.c_str(), std::ios::out);
tmpstream.seekp(0, std::ios::beg);
FileHeadStruct fileHead = {recordSize, 0};
tmpstream.write((char *)(&fileHead), sizeof(FileHeadStruct));
tmpstream.close();
return true;
}
template <class TYPE>
bool ArrayFile<TYPE>::open(const std::string &file) {
_stream.open(file.c_str(), std::ios::in | std::ios::out);
if(!_stream){
_isOpen = false;
return false;
}
_isOpen = true;
bool ret = _readFileHead();
return ret;
}
template <class TYPE>
void ArrayFile<TYPE>::close(){
_stream.close();
_isOpen = false;
}
template <class TYPE>
size_t ArrayFile<TYPE>::insert(TYPE &data, NGT::ObjectSpace *objectSpace) {
_stream.seekp(sizeof(RecordStruct), std::ios::end);
int64_t write_pos = _stream.tellg();
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
_stream.seekp(write_pos, std::ios::beg);
data.serialize(_stream, objectSpace);
int64_t offset_pos = _stream.tellg();
offset_pos -= sizeof(FileHeadStruct);
size_t id = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
if(offset_pos % (sizeof(RecordStruct) + _fileHead.recordSize) == 0){
id -= 1;
}
return id;
}
template <class TYPE>
void ArrayFile<TYPE>::put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
offset_pos += sizeof(RecordStruct);
_stream.seekp(offset_pos, std::ios::beg);
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
_stream.seekp(offset_pos, std::ios::beg);
data.serialize(_stream, objectSpace);
}
template <class TYPE>
bool ArrayFile<TYPE>::get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
pthread_mutex_lock(&_mutex);
if( size() <= id ){
pthread_mutex_unlock(&_mutex);
return false;
}
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
offset_pos += sizeof(RecordStruct);
_stream.seekg(offset_pos, std::ios::beg);
if (!_stream.fail()) {
data.deserialize(_stream, objectSpace);
}
if (_stream.fail()) {
const int trialCount = 10;
for (int tc = 0; tc < trialCount; tc++) {
_stream.clear();
_stream.seekg(offset_pos, std::ios::beg);
if (_stream.fail()) {
continue;
}
data.deserialize(_stream, objectSpace);
if (_stream.fail()) {
continue;
} else {
break;
}
}
if (_stream.fail()) {
throw std::runtime_error("ArrayFile::get: Error!");
}
}
pthread_mutex_unlock(&_mutex);
return true;
}
template <class TYPE>
void ArrayFile<TYPE>::remove(const size_t id) {
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
_stream.seekp(offset_pos, std::ios::beg);
RecordStruct recordHead = {1, 0};
_stream.write((char *)(&recordHead), sizeof(RecordStruct));
}
template <class TYPE>
bool ArrayFile<TYPE>::isOpen() const
{
return _isOpen;
}
template <class TYPE>
size_t ArrayFile<TYPE>::size()
{
_stream.seekp(0, std::ios::end);
int64_t offset_pos = _stream.tellg();
offset_pos -= sizeof(FileHeadStruct);
size_t num = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
return num;
}
template <class TYPE>
bool ArrayFile<TYPE>::_readFileHead() {
_stream.seekp(0, std::ios::beg);
_stream.read((char *)(&_fileHead), sizeof(FileHeadStruct));
if(_stream.bad()){
return false;
}
return true;
}

View File

@ -0,0 +1,40 @@
if( ${UNIX} )
option(NGT_SHARED_MEMORY_ALLOCATOR "enable shared memory" OFF)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/defines.h.in ${CMAKE_CURRENT_BINARY_DIR}/defines.h)
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/lib" "${PROJECT_BINARY_DIR}/lib/")
include_directories("${PROJECT_SOURCE_DIR}/../")
file(GLOB NGT_SOURCES *.cpp)
file(GLOB HEADER_FILES *.h *.hpp)
file(GLOB NGTQ_HEADER_FILES NGTQ/*.h NGTQ/*.hpp)
add_library(ngtstatic STATIC ${NGT_SOURCES})
set_target_properties(ngtstatic PROPERTIES OUTPUT_NAME ngt)
set_target_properties(ngtstatic PROPERTIES COMPILE_FLAGS "-fPIC")
target_link_libraries(ngtstatic)
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
target_link_libraries(ngtstatic OpenMP::OpenMP_CXX)
endif()
add_library(ngt SHARED ${NGT_SOURCES})
set_target_properties(ngt PROPERTIES VERSION ${ngt_VERSION})
set_target_properties(ngt PROPERTIES SOVERSION ${ngt_SOVERSION})
add_dependencies(ngt ngtstatic)
if(${APPLE})
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
target_link_libraries(ngt OpenMP::OpenMP_CXX)
else()
target_link_libraries(ngt gomp)
endif()
else(${APPLE})
target_link_libraries(ngt gomp rt)
endif(${APPLE})
install(TARGETS
ngt
ngtstatic
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
endif()

View File

@ -0,0 +1,988 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <string>
#include <iostream>
#include <sstream>
#include "NGT/Index.h"
#include "NGT/GraphOptimizer.h"
#include "Capi.h"
static bool operate_error_string_(const std::stringstream &ss, NGTError error){
if(error != NULL){
try{
std::string *error_str = static_cast<std::string*>(error);
*error_str = ss.str();
}catch(std::exception &err){
std::cerr << ss.str() << " > " << err.what() << std::endl;
return false;
}
}else{
std::cerr << ss.str() << std::endl;
}
return true;
}
NGTIndex ngt_open_index(const char *index_path, NGTError error) {
try{
std::string index_path_str(index_path);
NGT::Index *index = new NGT::Index(index_path_str);
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
NGTIndex ngt_create_graph_and_tree(const char *database, NGTProperty prop, NGTError error) {
NGT::Index *index = NULL;
try{
std::string database_str(database);
NGT::Property prop_i = *(static_cast<NGT::Property*>(prop));
NGT::Index::createGraphAndTree(database_str, prop_i, true);
index = new NGT::Index(database_str);
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
delete index;
return NULL;
}
}
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty prop, NGTError error) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << __FUNCTION__ << " is unavailable for shared-memory-type NGT.";
operate_error_string_(ss, error);
return NULL;
#else
try{
NGT::Index *index = new NGT::GraphAndTreeIndex(*(static_cast<NGT::Property*>(prop)));
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
#endif
}
NGTProperty ngt_create_property(NGTError error) {
try{
return static_cast<NGTProperty>(new NGT::Property());
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
bool ngt_save_index(const NGTIndex index, const char *database, NGTError error) {
try{
std::string database_str(database);
(static_cast<NGT::Index*>(index))->saveIndex(database_str);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_get_property(NGTIndex index, NGTProperty prop, NGTError error) {
if(index == NULL || prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " prop = " << prop;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->getProperty(*(static_cast<NGT::Property*>(prop)));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
int32_t ngt_get_property_dimension(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).dimension;
}
bool ngt_set_property_dimension(NGTProperty prop, int32_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).dimension = value;
return true;
}
bool ngt_set_property_edge_size_for_creation(NGTProperty prop, int16_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).edgeSizeForCreation = value;
return true;
}
bool ngt_set_property_edge_size_for_search(NGTProperty prop, int16_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).edgeSizeForSearch = value;
return true;
}
int32_t ngt_get_property_object_type(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).objectType;
}
bool ngt_is_property_object_type_float(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Float);
}
bool ngt_is_property_object_type_integer(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Uint8);
}
bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Float;
return true;
}
bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Uint8;
return true;
}
bool ngt_set_property_distance_type_l1(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1;
return true;
}
bool ngt_set_property_distance_type_l2(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
return true;
}
bool ngt_set_property_distance_type_angle(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle;
return true;
}
bool ngt_set_property_distance_type_hamming(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
return true;
}
bool ngt_set_property_distance_type_jaccard(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
return true;
}
bool ngt_set_property_distance_type_cosine(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine;
return true;
}
bool ngt_set_property_distance_type_normalized_angle(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle;
return true;
}
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine;
return true;
}
NGTObjectDistances ngt_create_empty_results(NGTError error) {
try{
return static_cast<NGTObjectDistances>(new NGT::ObjectDistances());
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
static bool ngt_search_index_(NGT::Index* pindex, NGT::Object *ngtquery, size_t size, float epsilon, float radius, NGTObjectDistances results, int edge_size = INT_MIN) {
// set search prameters.
NGT::SearchContainer sc(*ngtquery); // search parametera container.
sc.setResults(static_cast<NGT::ObjectDistances*>(results)); // set the result set.
sc.setSize(size); // the number of resultant objects.
sc.setRadius(radius); // search radius.
sc.setEpsilon(epsilon); // set exploration coefficient.
if (edge_size != INT_MIN) {
sc.setEdgeSize(edge_size);// set # of edges for each node
}
pindex->search(sc);
// delete the query object.
pindex->deleteObject(ngtquery);
return true;
}
bool ngt_search_index(NGTIndex index, double *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::Object *ngtquery = NULL;
if(radius < 0.0){
radius = FLT_MAX;
}
try{
std::vector<double> vquery(&query[0], &query[query_dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
bool ngt_search_index_as_float(NGTIndex index, float *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::Object *ngtquery = NULL;
if(radius < 0.0){
radius = FLT_MAX;
}
try{
std::vector<float> vquery(&query[0], &query[query_dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
bool ngt_search_index_with_query(NGTIndex index, NGTQuery query, NGTObjectDistances results, NGTError error) {
if(index == NULL || query.query == NULL || results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query.query << " results = " << results;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
int32_t dim = pindex->getObjectSpace().getDimension();
NGT::Object *ngtquery = NULL;
if(query.radius < 0.0){
query.radius = FLT_MAX;
}
try{
std::vector<float> vquery(&query.query[0], &query.query[dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, query.size, query.epsilon, query.radius, results, query.edge_size);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
// * deprecated *
int32_t ngt_get_size(NGTObjectDistances results, NGTError error) {
if(results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
operate_error_string_(ss, error);
return -1;
}
return (static_cast<NGT::ObjectDistances*>(results))->size();
}
uint32_t ngt_get_result_size(NGTObjectDistances results, NGTError error) {
if(results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
operate_error_string_(ss, error);
return 0;
}
return (static_cast<NGT::ObjectDistances*>(results))->size();
}
NGTObjectDistance ngt_get_result(const NGTObjectDistances results, const uint32_t i, NGTError error) {
try{
NGT::ObjectDistances objects = *(static_cast<NGT::ObjectDistances*>(results));
NGTObjectDistance ret_val = {objects[i].id, objects[i].distance};
return ret_val;
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
NGTObjectDistance err_val = {0};
return err_val;
}
}
ObjectID ngt_insert_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
return pindex->insert(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_append_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_insert_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
return pindex->insert(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_append_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
bool ngt_batch_append_index(NGTIndex index, float *obj, uint32_t data_count, NGTError error) {
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
pindex->append(obj, data_count);
return true;
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
}
bool ngt_batch_insert_index(NGTIndex index, float *obj, uint32_t data_count, uint32_t *ids, NGTError error) {
NGT::Index* pindex = static_cast<NGT::Index*>(index);
int32_t dim = pindex->getObjectSpace().getDimension();
bool status = true;
float *objptr = obj;
for (size_t idx = 0; idx < data_count; idx++, objptr += dim) {
try{
std::vector<double> vobj(objptr, objptr + dim);
ids[idx] = pindex->insert(vobj);
}catch(std::exception &err) {
status = false;
ids[idx] = 0;
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
}
}
return status;
}
bool ngt_create_index(NGTIndex index, uint32_t pool_size, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->createIndex(pool_size);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_remove_index(NGTIndex index, ObjectID id, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->remove(id);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
NGTObjectSpace ngt_get_object_space(NGTIndex index, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<NGTObjectSpace>(&(static_cast<NGT::Index*>(index))->getObjectSpace());
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
float* ngt_get_object_as_float(NGTObjectSpace object_space, ObjectID id, NGTError error) {
if(object_space == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<float*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
uint8_t* ngt_get_object_as_integer(NGTObjectSpace object_space, ObjectID id, NGTError error) {
if(object_space == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<uint8_t*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
void ngt_destroy_results(NGTObjectDistances results) {
if(results == NULL) return;
delete(static_cast<NGT::ObjectDistances*>(results));
}
void ngt_destroy_property(NGTProperty prop) {
if(prop == NULL) return;
delete(static_cast<NGT::Property*>(prop));
}
void ngt_close_index(NGTIndex index) {
if(index == NULL) return;
(static_cast<NGT::Index*>(index))->close();
delete(static_cast<NGT::Index*>(index));
}
int16_t ngt_get_property_edge_size_for_creation(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).edgeSizeForCreation;
}
int16_t ngt_get_property_edge_size_for_search(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).edgeSizeForSearch;
}
int32_t ngt_get_property_distance_type(NGTProperty prop, NGTError error){
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).distanceType;
}
NGTError ngt_create_error_object()
{
try{
std::string *error_str = new std::string();
return static_cast<NGTError>(error_str);
}catch(std::exception &err){
std::cerr << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
return NULL;
}
}
const char *ngt_get_error_string(const NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
return error_str->c_str();
}
void ngt_clear_error_string(NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
*error_str = "";
}
void ngt_destroy_error_object(NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
delete error_str;
}
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError error)
{
try{
return static_cast<NGTOptimizer>(new NGT::GraphOptimizer(logDisabled));
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer optimizer, const char *index, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->adjustSearchCoefficients(std::string(index));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_execute(NGTOptimizer optimizer, const char *inIndex, const char *outIndex, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->execute(std::string(inIndex), std::string(outIndex));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
// obsolute because of a lack of a parameter
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
int nofqs, int nofrs, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, nofrs);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->setExtension(baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
bool prefetchParameter, bool accuracyTable, NGTError error)
{
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
(static_cast<NGT::GraphOptimizer*>(optimizer))->setProcessingModes(searchParameter, prefetchParameter,
accuracyTable);
return true;
}
void ngt_destroy_optimizer(NGTOptimizer optimizer)
{
if(optimizer == NULL) return;
delete(static_cast<NGT::GraphOptimizer*>(optimizer));
}
bool ngt_refine_anng(NGTIndex index, float epsilon, float accuracy, int noOfEdges, int exploreEdgeSize, size_t batchSize, NGTError error)
{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
try {
NGT::GraphReconstructor::refineANNG(*pindex, true, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
} catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error)
{
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(pindex->getIndex());
try {
NGT::ObjectDistances &objects = *static_cast<NGT::ObjectDistances*>(edges);
objects = *graph.getNode(id);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error)
{
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return false;
}
NGT::Index& pindex = *static_cast<NGT::Index*>(index);
return pindex.getObjectRepositorySize();
}
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter()
{
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter gp;
NGTAnngEdgeOptimizationParameter parameter;
parameter.no_of_queries = gp.noOfQueries;
parameter.no_of_results = gp.noOfResults;
parameter.no_of_threads = gp.noOfThreads;
parameter.target_accuracy = gp.targetAccuracy;
parameter.target_no_of_objects = gp.targetNoOfObjects;
parameter.no_of_sample_objects = gp.noOfSampleObjects;
parameter.max_of_no_of_edges = gp.maxNoOfEdges;
parameter.log = false;
return parameter;
}
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error)
{
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter p;
p.noOfQueries = parameter.no_of_queries;
p.noOfResults = parameter.no_of_results;
p.noOfThreads = parameter.no_of_threads;
p.targetAccuracy = parameter.target_accuracy;
p.targetNoOfObjects = parameter.target_no_of_objects;
p.noOfSampleObjects = parameter.no_of_sample_objects;
p.maxNoOfEdges = parameter.max_of_no_of_edges;
try {
NGT::GraphOptimizer graphOptimizer(!parameter.log); // false=log
std::string path(indexPath);
auto edge = graphOptimizer.optimizeNumberOfEdgesForANNG(path, p);
if (parameter.log) {
std::cerr << "the optimized number of edges is" << edge.first << "(" << edge.second << ")" << std::endl;
}
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}

View File

@ -0,0 +1,210 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
typedef unsigned int ObjectID;
typedef void* NGTIndex;
typedef void* NGTProperty;
typedef void* NGTObjectSpace;
typedef void* NGTObjectDistances;
typedef void* NGTError;
typedef void* NGTOptimizer;
typedef struct {
ObjectID id;
float distance;
} NGTObjectDistance;
typedef struct {
float *query;
size_t size; // # of returned objects
float epsilon;
float accuracy; // expected accuracy
float radius;
size_t edge_size; // # of edges to explore for each node
} NGTQuery;
typedef struct {
size_t no_of_queries;
size_t no_of_results;
size_t no_of_threads;
float target_accuracy;
size_t target_no_of_objects;
size_t no_of_sample_objects;
size_t max_of_no_of_edges;
bool log;
} NGTAnngEdgeOptimizationParameter;
NGTIndex ngt_open_index(const char *, NGTError);
NGTIndex ngt_create_graph_and_tree(const char *, NGTProperty, NGTError);
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty, NGTError);
NGTProperty ngt_create_property(NGTError);
bool ngt_save_index(const NGTIndex, const char *, NGTError);
bool ngt_get_property(const NGTIndex, NGTProperty, NGTError);
int32_t ngt_get_property_dimension(NGTProperty, NGTError);
bool ngt_set_property_dimension(NGTProperty, int32_t, NGTError);
bool ngt_set_property_edge_size_for_creation(NGTProperty, int16_t, NGTError);
bool ngt_set_property_edge_size_for_search(NGTProperty, int16_t, NGTError);
int32_t ngt_get_property_object_type(NGTProperty, NGTError);
bool ngt_is_property_object_type_float(int32_t);
bool ngt_is_property_object_type_integer(int32_t);
bool ngt_set_property_object_type_float(NGTProperty, NGTError);
bool ngt_set_property_object_type_integer(NGTProperty, NGTError);
bool ngt_set_property_distance_type_l1(NGTProperty, NGTError);
bool ngt_set_property_distance_type_l2(NGTProperty, NGTError);
bool ngt_set_property_distance_type_angle(NGTProperty, NGTError);
bool ngt_set_property_distance_type_hamming(NGTProperty, NGTError);
bool ngt_set_property_distance_type_jaccard(NGTProperty, NGTError);
bool ngt_set_property_distance_type_cosine(NGTProperty, NGTError);
bool ngt_set_property_distance_type_normalized_angle(NGTProperty, NGTError);
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty, NGTError);
NGTObjectDistances ngt_create_empty_results(NGTError);
bool ngt_search_index(NGTIndex, double*, int32_t, size_t, float, float, NGTObjectDistances, NGTError);
bool ngt_search_index_as_float(NGTIndex, float*, int32_t, size_t, float, float, NGTObjectDistances, NGTError);
bool ngt_search_index_with_query(NGTIndex, NGTQuery, NGTObjectDistances, NGTError);
int32_t ngt_get_size(NGTObjectDistances, NGTError); // deprecated
uint32_t ngt_get_result_size(NGTObjectDistances, NGTError);
NGTObjectDistance ngt_get_result(const NGTObjectDistances, const uint32_t, NGTError);
ObjectID ngt_insert_index(NGTIndex, double*, uint32_t, NGTError);
ObjectID ngt_append_index(NGTIndex, double*, uint32_t, NGTError);
ObjectID ngt_insert_index_as_float(NGTIndex, float*, uint32_t, NGTError);
ObjectID ngt_append_index_as_float(NGTIndex, float*, uint32_t, NGTError);
bool ngt_batch_append_index(NGTIndex, float*, uint32_t, NGTError);
bool ngt_batch_insert_index(NGTIndex, float*, uint32_t, uint32_t *, NGTError);
bool ngt_create_index(NGTIndex, uint32_t, NGTError);
bool ngt_remove_index(NGTIndex, ObjectID, NGTError);
NGTObjectSpace ngt_get_object_space(NGTIndex, NGTError);
float* ngt_get_object_as_float(NGTObjectSpace, ObjectID, NGTError);
uint8_t* ngt_get_object_as_integer(NGTObjectSpace, ObjectID, NGTError);
void ngt_destroy_results(NGTObjectDistances);
void ngt_destroy_property(NGTProperty);
void ngt_close_index(NGTIndex);
int16_t ngt_get_property_edge_size_for_creation(NGTProperty, NGTError);
int16_t ngt_get_property_edge_size_for_search(NGTProperty, NGTError);
int32_t ngt_get_property_distance_type(NGTProperty, NGTError);
NGTError ngt_create_error_object();
const char *ngt_get_error_string(const NGTError);
void ngt_clear_error_string(NGTError);
void ngt_destroy_error_object(NGTError);
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError);
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer, const char *, NGTError);
bool ngt_optimizer_execute(NGTOptimizer, const char *, const char *, NGTError);
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error);
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
int nofqs, int nofrs, NGTError error);
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error);
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
bool prefetchParameter, bool accuracyTable, NGTError error);
void ngt_destroy_optimizer(NGTOptimizer);
// refine: the specified index by searching each node.
// epsilon, exepectedAccuracy and edgeSize: the same as the prameters for search. but if edgeSize is INT_MIN, default is used.
// noOfEdges: if this is not 0, kNNG with k = noOfEdges is build
// batchSize: batch size for parallelism.
bool ngt_refine_anng(NGTIndex index, float epsilon, float expectedAccuracy,
int noOfEdges, int edgeSize, size_t batchSize, NGTError error);
// get edges of the node that is specified with id.
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error);
// get the size of the specified object repository.
// Since the size includes empty objects, the size is not the number of objects.
// The size is mostly the largest ID of the objects - 1;
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error);
// return parameters for ngt_optimize_number_of_edges. You can customize them before calling ngt_optimize_number_of_edges.
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter();
// optimize the number of initial edges for ANNG that is specified with indexPath.
// The parameter should be a struct which is returned by nt_get_optimization_parameter.
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,857 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Index.h"
using namespace std;
#if defined(NGT_AVX_DISABLED)
#define NGT_CLUSTER_NO_AVX
#else
#if defined(__AVX2__)
#define NGT_CLUSTER_AVX2
#else
#define NGT_CLUSTER_NO_AVX
#endif
#endif
#if defined(NGT_CLUSTER_NO_AVX)
// #warning "*** SIMD is *NOT* available! ***"
#else
#include <immintrin.h>
#endif
#include <omp.h>
#include <random>
namespace NGT {
class Clustering {
public:
enum InitializationMode {
InitializationModeHead = 0,
InitializationModeRandom = 1,
InitializationModeKmeansPlusPlus = 2
};
enum ClusteringType {
ClusteringTypeKmeansWithNGT = 0,
ClusteringTypeKmeansWithoutNGT = 1,
ClusteringTypeKmeansWithIteration = 2,
ClusteringTypeKmeansWithNGTForCentroids = 3
};
class Entry {
public:
Entry() : vectorID(0), centroidID(0), distance(0.0) {
}
Entry(size_t vid, size_t cid, double d) : vectorID(vid), centroidID(cid), distance(d) {
}
bool
operator<(const Entry& e) const {
return distance > e.distance;
}
uint32_t vectorID;
uint32_t centroidID;
double distance;
};
class DescendingEntry {
public:
DescendingEntry(size_t vid, double d) : vectorID(vid), distance(d) {
}
bool
operator<(const DescendingEntry& e) const {
return distance < e.distance;
}
size_t vectorID;
double distance;
};
class Cluster {
public:
Cluster(std::vector<float>& c) : centroid(c), radius(0.0) {
}
Cluster(const Cluster& c) {
*this = c;
}
Cluster&
operator=(const Cluster& c) {
members = c.members;
centroid = c.centroid;
radius = c.radius;
return *this;
}
std::vector<Entry> members;
std::vector<float> centroid;
double radius;
};
Clustering(InitializationMode im = InitializationModeHead, ClusteringType ct = ClusteringTypeKmeansWithNGT,
size_t mi = 100)
: clusteringType(ct), initializationMode(im), maximumIteration(mi) {
initialize();
}
void
initialize() {
epsilonFrom = 0.12;
epsilonTo = epsilonFrom;
epsilonStep = 0.04;
resultSizeCoefficient = 5;
}
static void
convert(std::vector<std::string>& strings, std::vector<float>& vector) {
vector.clear();
for (auto it = strings.begin(); it != strings.end(); ++it) {
vector.push_back(stod(*it));
}
}
static void
extractVector(const std::string& str, std::vector<float>& vec) {
std::vector<std::string> tokens;
NGT::Common::tokenize(str, tokens, " \t");
convert(tokens, vec);
}
static void
loadVectors(const std::string& file, std::vector<std::vector<float> >& vectors) {
std::ifstream is(file);
if (!is) {
throw std::runtime_error("loadVectors::Cannot open " + file);
}
std::string line;
while (getline(is, line)) {
std::vector<float> v;
extractVector(line, v);
vectors.push_back(v);
}
}
static void
saveVectors(const std::string& file, std::vector<std::vector<float> >& vectors) {
std::ofstream os(file);
for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) {
std::vector<float>& v = *vit;
for (auto it = v.begin(); it != v.end(); ++it) {
os << std::setprecision(9) << (*it);
if (it + 1 != v.end()) {
os << "\t";
}
}
os << std::endl;
}
}
static void
saveVector(const std::string& file, std::vector<size_t>& vectors) {
std::ofstream os(file);
for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) {
os << *vit << std::endl;
}
}
static void
loadClusters(const std::string& file, std::vector<Cluster>& clusters, size_t numberOfClusters = 0) {
std::ifstream is(file);
if (!is) {
throw std::runtime_error("loadClusters::Cannot open " + file);
}
std::string line;
while (getline(is, line)) {
std::vector<float> v;
extractVector(line, v);
clusters.push_back(v);
if ((numberOfClusters != 0) && (clusters.size() >= numberOfClusters)) {
break;
}
}
if ((numberOfClusters != 0) && (clusters.size() < numberOfClusters)) {
std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
<< std::endl;
exit(1);
}
}
#if !defined(NGT_CLUSTER_NO_AVX)
static double
sumOfSquares(float* a, float* b, size_t size) {
__m256 sum = _mm256_setzero_ps();
float* last = a + size;
float* lastgroup = last - 7;
while (a < lastgroup) {
__m256 v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
sum = _mm256_add_ps(sum, _mm256_mul_ps(v, v));
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[8];
_mm256_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7];
while (a < last) {
double d = *a++ - *b++;
s += d * d;
}
return s;
}
#else // !defined(NGT_AVX_DISABLED) && defined(__AVX__)
static double
sumOfSquares(float* a, float* b, size_t size) {
double csum = 0.0;
float* x = a;
float* y = b;
for (size_t i = 0; i < size; i++) {
double d = (double)*x++ - (double)*y++;
csum += d * d;
}
return csum;
}
#endif // !defined(NGT_AVX_DISABLED) && defined(__AVX__)
static double
distanceL2(std::vector<float>& vector1, std::vector<float>& vector2) {
return sqrt(sumOfSquares(&vector1[0], &vector2[0], vector1.size()));
}
static double
distanceL2(std::vector<std::vector<float> >& vector1, std::vector<std::vector<float> >& vector2) {
assert(vector1.size() == vector2.size());
double distance = 0.0;
for (size_t i = 0; i < vector1.size(); i++) {
distance += distanceL2(vector1[i], vector2[i]);
}
distance /= (double)vector1.size();
return distance;
}
static double
meanSumOfSquares(std::vector<float>& vector1, std::vector<float>& vector2) {
return sumOfSquares(&vector1[0], &vector2[0], vector1.size()) / (double)vector1.size();
}
static void
subtract(std::vector<float>& a, std::vector<float>& b) {
assert(a.size() == b.size());
auto bit = b.begin();
for (auto ait = a.begin(); ait != a.end(); ++ait, ++bit) {
*ait = *ait - *bit;
}
}
static void
getInitialCentroidsFromHead(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
size_t size) {
size = size > vectors.size() ? vectors.size() : size;
clusters.clear();
for (size_t i = 0; i < size; i++) {
clusters.push_back(Cluster(vectors[i]));
}
}
static void
getInitialCentroidsRandomly(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters, size_t size,
size_t seed) {
clusters.clear();
std::random_device rnd;
if (seed == 0) {
seed = rnd();
}
std::mt19937 mt(seed);
for (size_t i = 0; i < size; i++) {
size_t idx = mt() * vectors.size() / mt.max();
if (idx >= size) {
i--;
continue;
}
clusters.push_back(Cluster(vectors[idx]));
}
assert(clusters.size() == size);
}
static void
getInitialCentroidsKmeansPlusPlus(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
size_t size) {
size = size > vectors.size() ? vectors.size() : size;
clusters.clear();
std::random_device rnd;
std::mt19937 mt(rnd());
size_t idx = (long long)mt() * (long long)vectors.size() / (long long)mt.max();
clusters.push_back(Cluster(vectors[idx]));
NGT::Timer timer;
for (size_t k = 1; k < size; k++) {
double sum = 0;
std::priority_queue<DescendingEntry> sortedObjects;
// get d^2 and sort
#pragma omp parallel for
for (size_t vi = 0; vi < vectors.size(); vi++) {
auto vit = vectors.begin() + vi;
double mind = DBL_MAX;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
double d = distanceL2(*vit, (*cit).centroid);
d *= d;
if (d < mind) {
mind = d;
}
}
#pragma omp critical
{
sortedObjects.push(DescendingEntry(distance(vectors.begin(), vit), mind));
sum += mind;
}
}
double l = (double)mt() / (double)mt.max() * sum;
while (!sortedObjects.empty()) {
sum -= sortedObjects.top().distance;
if (l >= sum) {
clusters.push_back(Cluster(vectors[sortedObjects.top().vectorID]));
break;
}
sortedObjects.pop();
}
}
}
static void
assign(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
size_t clusterSize = std::numeric_limits<size_t>::max()) {
// compute distances to the nearest clusters, and construct heap by the distances.
NGT::Timer timer;
timer.start();
std::vector<Entry> sortedObjects(vectors.size());
#pragma omp parallel for
for (size_t vi = 0; vi < vectors.size(); vi++) {
auto vit = vectors.begin() + vi;
{
double mind = DBL_MAX;
size_t mincidx = -1;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
double d = distanceL2(*vit, (*cit).centroid);
if (d < mind) {
mind = d;
mincidx = distance(clusters.begin(), cit);
}
}
sortedObjects[vi] = Entry(vi, mincidx, mind);
}
}
std::sort(sortedObjects.begin(), sortedObjects.end());
// clear
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
(*cit).members.clear();
}
// distribute objects to the nearest clusters in the same size constraint.
for (auto soi = sortedObjects.rbegin(); soi != sortedObjects.rend();) {
Entry& entry = *soi;
if (entry.centroidID >= clusters.size()) {
std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
soi++;
continue;
}
if (clusters[entry.centroidID].members.size() < clusterSize) {
clusters[entry.centroidID].members.push_back(entry);
soi++;
} else {
double mind = DBL_MAX;
size_t mincidx = -1;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
if ((*cit).members.size() >= clusterSize) {
continue;
}
double d = distanceL2(vectors[entry.vectorID], (*cit).centroid);
if (d < mind) {
mind = d;
mincidx = distance(clusters.begin(), cit);
}
}
entry = Entry(entry.vectorID, mincidx, mind);
int pt = distance(sortedObjects.rbegin(), soi);
std::sort(sortedObjects.begin(), soi.base());
soi = sortedObjects.rbegin() + pt;
assert(pt == distance(sortedObjects.rbegin(), soi));
}
}
moveFartherObjectsToEmptyClusters(clusters);
}
static void
moveFartherObjectsToEmptyClusters(std::vector<Cluster>& clusters) {
size_t emptyClusterCount = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
if ((*cit).members.size() == 0) {
emptyClusterCount++;
double max = 0.0;
auto maxit = clusters.begin();
for (auto scit = clusters.begin(); scit != clusters.end(); ++scit) {
if ((*scit).members.size() >= 2 && (*scit).members.back().distance > max) {
maxit = scit;
max = (*scit).members.back().distance;
}
}
(*cit).members.push_back((*maxit).members.back());
(*cit).members.back().centroidID = distance(clusters.begin(), cit);
(*maxit).members.pop_back();
}
}
emptyClusterCount = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
if ((*cit).members.size() == 0) {
emptyClusterCount++;
}
}
}
static void
assignWithNGT(NGT::Index& index, std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
float& radius, size_t& resultSize, float epsilon = 0.12, size_t notRetrievedObjectCount = 0) {
size_t dataSize = vectors.size();
assert(index.getObjectRepositorySize() - 1 == vectors.size());
vector<vector<Entry> > results(clusters.size());
#pragma omp parallel for
for (size_t ci = 0; ci < clusters.size(); ci++) {
auto cit = clusters.begin() + ci;
NGT::ObjectDistances objects; // result set
NGT::Object* query = 0;
query = index.allocateObject((*cit).centroid);
// set search prameters.
NGT::SearchContainer sc(*query); // search parametera container.
sc.setResults(&objects); // set the result set.
sc.setEpsilon(epsilon); // set exploration coefficient.
if (radius > 0.0) {
sc.setRadius(radius);
sc.setSize(dataSize / 2);
} else {
sc.setSize(resultSize); // the number of resultant objects.
}
index.search(sc);
results[ci].reserve(objects.size());
for (size_t idx = 0; idx < objects.size(); idx++) {
size_t oidx = objects[idx].id - 1;
results[ci].push_back(Entry(oidx, ci, objects[idx].distance));
}
index.deleteObject(query);
}
size_t resultCount = 0;
for (auto ri = results.begin(); ri != results.end(); ++ri) {
resultCount += (*ri).size();
}
vector<Entry> sortedResults;
sortedResults.reserve(resultCount);
for (auto ri = results.begin(); ri != results.end(); ++ri) {
auto end = (*ri).begin();
for (; end != (*ri).end(); ++end) {
}
std::copy((*ri).begin(), end, std::back_inserter(sortedResults));
}
vector<bool> processedObjects(dataSize, false);
for (auto i = sortedResults.begin(); i != sortedResults.end(); ++i) {
processedObjects[(*i).vectorID] = true;
}
notRetrievedObjectCount = 0;
vector<uint32_t> notRetrievedObjectIDs;
for (size_t idx = 0; idx < dataSize; idx++) {
if (!processedObjects[idx]) {
notRetrievedObjectCount++;
notRetrievedObjectIDs.push_back(idx);
}
}
sort(sortedResults.begin(), sortedResults.end());
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
(*cit).members.clear();
}
for (auto i = sortedResults.rbegin(); i != sortedResults.rend(); ++i) {
size_t objectID = (*i).vectorID;
size_t clusterID = (*i).centroidID;
if (processedObjects[objectID]) {
processedObjects[objectID] = false;
clusters[clusterID].members.push_back(*i);
clusters[clusterID].members.back().centroidID = clusterID;
radius = (*i).distance;
}
}
vector<Entry> notRetrievedObjects(notRetrievedObjectIDs.size());
#pragma omp parallel for
for (size_t vi = 0; vi < notRetrievedObjectIDs.size(); vi++) {
auto vit = notRetrievedObjectIDs.begin() + vi;
{
double mind = DBL_MAX;
size_t mincidx = -1;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
double d = distanceL2(vectors[*vit], (*cit).centroid);
if (d < mind) {
mind = d;
mincidx = distance(clusters.begin(), cit);
}
}
notRetrievedObjects[vi] = Entry(*vit, mincidx, mind); // Entry(vectorID, centroidID, distance)
}
}
sort(notRetrievedObjects.begin(), notRetrievedObjects.end());
for (auto nroit = notRetrievedObjects.begin(); nroit != notRetrievedObjects.end(); ++nroit) {
clusters[(*nroit).centroidID].members.push_back(*nroit);
}
moveFartherObjectsToEmptyClusters(clusters);
}
static double
calculateCentroid(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
double distance = 0;
size_t memberCount = 0;
for (auto it = clusters.begin(); it != clusters.end(); ++it) {
memberCount += (*it).members.size();
if ((*it).members.size() != 0) {
std::vector<float> mean(vectors[0].size(), 0.0);
for (auto memit = (*it).members.begin(); memit != (*it).members.end(); ++memit) {
auto mit = mean.begin();
auto& v = vectors[(*memit).vectorID];
for (auto vit = v.begin(); vit != v.end(); ++vit, ++mit) {
*mit += *vit;
}
}
for (auto mit = mean.begin(); mit != mean.end(); ++mit) {
*mit /= (*it).members.size();
}
distance += distanceL2((*it).centroid, mean);
(*it).centroid = mean;
} else {
cerr << "Clustering: Fatal Error. No member!" << endl;
abort();
}
}
return distance;
}
static void
saveClusters(const std::string& file, std::vector<Cluster>& clusters) {
std::ofstream os(file);
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
std::vector<float>& v = (*cit).centroid;
for (auto it = v.begin(); it != v.end(); ++it) {
os << std::setprecision(9) << (*it);
if (it + 1 != v.end()) {
os << "\t";
}
}
os << std::endl;
}
}
double
kmeansWithoutNGT(std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
std::vector<Cluster>& clusters) {
size_t clusterSize = std::numeric_limits<size_t>::max();
if (clusterSizeConstraint) {
clusterSize = ceil((double)vectors.size() / (double)numberOfClusters);
}
double diff = 0;
for (size_t i = 0; i < maximumIteration; i++) {
std::cerr << "iteration=" << i << std::endl;
assign(vectors, clusters, clusterSize);
// centroid is recomputed.
// diff is distance between the current centroids and the previous centroids.
diff = calculateCentroid(vectors, clusters);
if (diff == 0) {
break;
}
}
return diff == 0;
}
double
kmeansWithNGT(NGT::Index& index, std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
std::vector<Cluster>& clusters, float epsilon) {
diffHistory.clear();
NGT::Timer timer;
timer.start();
float radius;
double diff = 0.0;
size_t resultSize;
resultSize = resultSizeCoefficient * vectors.size() / clusters.size();
for (size_t i = 0; i < maximumIteration; i++) {
size_t notRetrievedObjectCount = 0;
radius = -1.0;
assignWithNGT(index, vectors, clusters, radius, resultSize, epsilon, notRetrievedObjectCount);
// centroid is recomputed.
// diff is distance between the current centroids and the previous centroids.
std::vector<Cluster> prevClusters = clusters;
diff = calculateCentroid(vectors, clusters);
timer.stop();
std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
timer.start();
diffHistory.push_back(diff);
if (diff == 0) {
break;
}
}
return diff;
}
double
kmeansWithNGT(std::vector<std::vector<float> >& vectors, size_t numberOfClusters, std::vector<Cluster>& clusters) {
pid_t pid = getpid();
std::stringstream str;
str << "cluster-ngt." << pid;
string database = str.str();
string dataFile;
size_t dataSize = 0;
size_t dim = clusters.front().centroid.size();
NGT::Property property;
property.dimension = dim;
property.graphType = NGT::Property::GraphType::GraphTypeANNG;
property.objectType = NGT::Index::Property::ObjectType::Float;
property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
NGT::Index::createGraphAndTree(database, property, dataFile, dataSize);
float* data = new float[vectors.size() * dim];
float* ptr = data;
dataSize = vectors.size();
for (auto vi = vectors.begin(); vi != vectors.end(); ++vi) {
memcpy(ptr, &((*vi)[0]), dim * sizeof(float));
ptr += dim;
}
size_t threadSize = 20;
NGT::Index::append(database, data, dataSize, threadSize);
delete[] data;
NGT::Index index(database);
return kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilonFrom);
}
double
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, std::vector<Cluster>& clusters) {
NGT::GraphIndex& graph = static_cast<NGT::GraphIndex&>(index.getIndex());
NGT::ObjectSpace& os = graph.getObjectSpace();
size_t size = os.getRepository().size();
std::vector<std::vector<float> > vectors(size - 1);
for (size_t idx = 1; idx < size; idx++) {
try {
os.getObject(idx, vectors[idx - 1]);
} catch (...) {
cerr << "Cannot get object " << idx << endl;
}
}
cerr << "# of data for clustering=" << vectors.size() << endl;
double diff = DBL_MAX;
clusters.clear();
setupInitialClusters(vectors, numberOfClusters, clusters);
for (float epsilon = epsilonFrom; epsilon <= epsilonTo; epsilon += epsilonStep) {
cerr << "epsilon=" << epsilon << endl;
diff = kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilon);
if (diff == 0.0) {
return diff;
}
}
return diff;
}
double
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, NGT::Index& outIndex) {
std::vector<Cluster> clusters;
double diff = kmeansWithNGT(index, numberOfClusters, clusters);
for (auto i = clusters.begin(); i != clusters.end(); ++i) {
outIndex.insert((*i).centroid);
}
outIndex.createIndex(16);
return diff;
}
double
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters) {
NGT::Property prop;
index.getProperty(prop);
string path = index.getPath();
index.save();
index.close();
string outIndexName = path;
string inIndexName = path + ".tmp";
std::rename(outIndexName.c_str(), inIndexName.c_str());
NGT::Index::createGraphAndTree(outIndexName, prop);
index.open(outIndexName);
NGT::Index inIndex(inIndexName);
double diff = kmeansWithNGT(inIndex, numberOfClusters, index);
inIndex.close();
NGT::Index::destroy(inIndexName);
return diff;
}
double
kmeansWithNGT(string& indexName, size_t numberOfClusters) {
NGT::Index inIndex(indexName);
double diff = kmeansWithNGT(inIndex, numberOfClusters);
inIndex.save();
inIndex.close();
return diff;
}
static double
calculateMSE(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
double mse = 0.0;
size_t count = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
count += (*cit).members.size();
for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) {
mse += meanSumOfSquares((*cit).centroid, vectors[(*mit).vectorID]);
}
}
assert(vectors.size() == count);
return mse / (double)vectors.size();
}
static double
calculateML2(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
double d = 0.0;
size_t count = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
count += (*cit).members.size();
double localD = 0.0;
for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) {
double distance = distanceL2((*cit).centroid, vectors[(*mit).vectorID]);
d += distance;
localD += distance;
}
}
if (vectors.size() != count) {
std::cerr << "Warning! vectors.size() != count" << std::endl;
}
return d / (double)vectors.size();
}
static double
calculateML2FromSpecifiedCentroids(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
std::vector<size_t>& centroidIds) {
double d = 0.0;
size_t count = 0;
for (auto it = centroidIds.begin(); it != centroidIds.end(); ++it) {
Cluster& cluster = clusters[(*it)];
count += cluster.members.size();
for (auto mit = cluster.members.begin(); mit != cluster.members.end(); ++mit) {
d += distanceL2(cluster.centroid, vectors[(*mit).vectorID]);
}
}
return d / (double)vectors.size();
}
void
setupInitialClusters(std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
std::vector<Cluster>& clusters) {
if (clusters.empty()) {
switch (initializationMode) {
case InitializationModeHead: {
getInitialCentroidsFromHead(vectors, clusters, numberOfClusters);
break;
}
case InitializationModeRandom: {
getInitialCentroidsRandomly(vectors, clusters, numberOfClusters, 0);
break;
}
case InitializationModeKmeansPlusPlus: {
getInitialCentroidsKmeansPlusPlus(vectors, clusters, numberOfClusters);
break;
}
default:
std::cerr << "proper initMode is not specified." << std::endl;
exit(1);
}
}
}
bool
kmeans(std::vector<std::vector<float> >& vectors, size_t numberOfClusters, std::vector<Cluster>& clusters) {
setupInitialClusters(vectors, numberOfClusters, clusters);
switch (clusteringType) {
case ClusteringTypeKmeansWithoutNGT:
return kmeansWithoutNGT(vectors, numberOfClusters, clusters);
break;
case ClusteringTypeKmeansWithNGT:
return kmeansWithNGT(vectors, numberOfClusters, clusters);
break;
default:
cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
abort();
break;
}
}
static void
evaluate(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters, char mode,
std::vector<size_t> centroidIds = std::vector<size_t>()) {
size_t clusterSize = std::numeric_limits<size_t>::max();
assign(vectors, clusters, clusterSize);
std::cout << "The number of vectors=" << vectors.size() << std::endl;
std::cout << "The number of centroids=" << clusters.size() << std::endl;
if (centroidIds.size() == 0) {
switch (mode) {
case 'e':
std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
break;
case '2':
default:
std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
break;
}
} else {
switch (mode) {
case 'e':
break;
case '2':
default:
std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
<< std::endl;
break;
}
}
}
ClusteringType clusteringType;
InitializationMode initializationMode;
size_t numberOfClusters;
bool clusterSizeConstraint;
size_t maximumIteration;
float epsilonFrom;
float epsilonTo;
float epsilonStep;
size_t resultSizeCoefficient;
vector<double> diffHistory;
};
} // namespace NGT

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,127 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Index.h"
namespace NGT {
class Command {
public:
class SearchParameter {
public:
SearchParameter() {
openMode = 'r';
query = "";
querySize = 0;
indexType = 't';
size = 20;
edgeSize = -1;
outputMode = "-";
radius = FLT_MAX;
step = 0;
trial = 1;
beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1;
accuracy = 0.0;
}
SearchParameter(Args &args) { parse(args); }
void parse(Args &args) {
openMode = args.getChar("m", 'r');
try {
query = args.get("#2");
} catch (...) {
NGTThrowException("ngt: Error: Query is not specified");
}
querySize = args.getl("Q", 0);
indexType = args.getChar("i", 't');
size = args.getl("n", 20);
// edgeSize
// -1(default) : using the size which was specified at the index creation.
// 0 : no limitation for the edge size.
// -2('e') : automatically set it according to epsilon.
if (args.getChar("E", '-') == 'e') {
edgeSize = -2;
} else {
edgeSize = args.getl("E", -1);
}
outputMode = args.getString("o", "-");
radius = args.getf("r", FLT_MAX);
trial = args.getl("t", 1);
{
beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1;
std::string epsilon = args.getString("e", "0.1");
std::vector<std::string> tokens;
NGT::Common::tokenize(epsilon, tokens, ":");
if (tokens.size() >= 1) { beginOfEpsilon = endOfEpsilon = NGT::Common::strtod(tokens[0]); }
if (tokens.size() >= 2) { endOfEpsilon = NGT::Common::strtod(tokens[1]); }
if (tokens.size() >= 3) { stepOfEpsilon = NGT::Common::strtod(tokens[2]); }
step = 0;
if (tokens.size() >= 4) { step = NGT::Common::strtol(tokens[3]); }
}
accuracy = args.getf("a", 0.0);
}
char openMode;
std::string query;
size_t querySize;
char indexType;
int size;
long edgeSize;
std::string outputMode;
float radius;
float beginOfEpsilon;
float endOfEpsilon;
float stepOfEpsilon;
float accuracy;
size_t step;
size_t trial;
};
Command():debugLevel(0) {}
void create(Args &args);
void append(Args &args);
static void search(NGT::Index &index, SearchParameter &searchParameter, std::ostream &stream)
{
std::ifstream is(searchParameter.query);
if (!is) {
std::cerr << "Cannot open the specified file. " << searchParameter.query << std::endl;
return;
}
search(index, searchParameter, is, stream);
}
static void search(NGT::Index &index, SearchParameter &searchParameter, std::istream &is, std::ostream &stream);
void search(Args &args);
void remove(Args &args);
void exportIndex(Args &args);
void importIndex(Args &args);
void prune(Args &args);
void reconstructGraph(Args &args);
void optimizeSearchParameters(Args &args);
void optimizeNumberOfEdgesForANNG(Args &args);
void refineANNG(Args &args);
void repair(Args &args);
void info(Args &args);
void setDebugLevel(int level) { debugLevel = level; }
int getDebugLevel() { return debugLevel; }
protected:
int debugLevel;
};
}; // NGT

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,213 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "faiss/utils/ConcurrentBitset.h"
#include <cstring>
namespace faiss {
ConcurrentBitset::ConcurrentBitset(id_type_t capacity, uint8_t init_value) : capacity_(capacity), bitset_(((capacity + 8 - 1) >> 3)) {
if (init_value) {
memset(mutable_data(), init_value, (capacity + 8 - 1) >> 3);
}
}
std::vector<std::atomic<uint8_t>>&
ConcurrentBitset::bitset() {
return bitset_;
}
ConcurrentBitset&
ConcurrentBitset::operator&=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_and(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] &= u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
u8_1[i] &= u8_2[i];
}
return *this;
}
std::shared_ptr<ConcurrentBitset>
ConcurrentBitset::operator&(const std::shared_ptr<ConcurrentBitset>& bitset) {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset->data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
result_64[i] = u64_1[i] & u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
result_8 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
result_8[i] = u8_1[i] & u8_2[i];
}
return result_bitset;
}
ConcurrentBitset&
ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_or(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] |= u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
u8_1[i] |= u8_2[i];
}
return *this;
}
std::shared_ptr<ConcurrentBitset>
ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset->data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
result_64[i] = u64_1[i] | u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
result_8 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
result_8[i] = u8_1[i] | u8_2[i];
}
return result_bitset;
}
ConcurrentBitset&
ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_xor(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] &= u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
u8_1[i] ^= u8_2[i];
}
return *this;
}
bool
ConcurrentBitset::test(id_type_t id) {
return bitset_[id >> 3].load() & (0x1 << (id & 0x7));
}
void
ConcurrentBitset::set(id_type_t id) {
bitset_[id >> 3].fetch_or(0x1 << (id & 0x7));
}
void
ConcurrentBitset::clear(id_type_t id) {
bitset_[id >> 3].fetch_and(~(0x1 << (id & 0x7)));
}
size_t
ConcurrentBitset::capacity() {
return capacity_;
}
size_t
ConcurrentBitset::size() {
return ((capacity_ + 8 - 1) >> 3);
}
const uint8_t*
ConcurrentBitset::data() {
return reinterpret_cast<const uint8_t*>(bitset_.data());
}
uint8_t*
ConcurrentBitset::mutable_data() {
return reinterpret_cast<uint8_t*>(bitset_.data());
}
} // namespace faiss

View File

@ -0,0 +1,15 @@
#include "NGT/GetCoreNumber.h"
namespace NGT
{
int getCoreNumber()
{
#ifndef __linux__
SYSTEM_INFO sys_info;
GetSystemInfo(&sys_info);
return sysInfo.dwNumberOfProcessors;
#else
return get_nprocs();
#endif
}
}

View File

@ -0,0 +1,12 @@
#ifndef __linux__
# include "windows.h"
#else
# include "sys/sysinfo.h"
# include "unistd.h"
#endif
namespace NGT
{
int getCoreNumber();
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,948 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <bitset>
#include <sstream>
#include "NGT/defines.h"
#include "NGT/Common.h"
#include "NGT/ObjectSpaceRepository.h"
#include "faiss/utils/ConcurrentBitset.h"
#include "NGT/HashBasedBooleanSet.h"
#ifndef NGT_GRAPH_CHECK_VECTOR
#include <unordered_set>
#endif
#ifdef NGT_GRAPH_UNCHECK_STACK
#include <stack>
#endif
#ifndef NGT_EXPLORATION_COEFFICIENT
#define NGT_EXPLORATION_COEFFICIENT 1.1
#endif
#ifndef NGT_INSERTION_EXPLORATION_COEFFICIENT
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
#endif
#ifndef NGT_TRUNCATION_THRESHOLD
#define NGT_TRUNCATION_THRESHOLD 50
#endif
#ifndef NGT_SEED_SIZE
#define NGT_SEED_SIZE 10
#endif
#ifndef NGT_CREATION_EDGE_SIZE
#define NGT_CREATION_EDGE_SIZE 10
#endif
namespace NGT {
class Property;
typedef GraphNode GRAPH_NODE;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class GraphRepository: public PersistentRepository<GRAPH_NODE> {
#else
class GraphRepository: public Repository<GRAPH_NODE> {
#endif
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
typedef PersistentRepository<GRAPH_NODE> VECTOR;
#else
typedef Repository<GRAPH_NODE> VECTOR;
GraphRepository() {
prevsize = new vector<unsigned short>;
}
virtual ~GraphRepository() {
deleteAll();
if (prevsize != 0) {
delete prevsize;
}
}
#endif
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void open(const std::string &file, size_t sharedMemorySize) {
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
off_t *entryTable = (off_t*)allocator.construct(file, sharedMemorySize);
if (entryTable == 0) {
entryTable = (off_t*)construct();
allocator.setEntry(entryTable);
}
assert(entryTable != 0);
this->initialize(entryTable);
}
void *construct() {
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
off_t *entryTable = new(allocator) off_t[2];
entryTable[0] = allocator.getOffset(PersistentRepository<GRAPH_NODE>::construct());
entryTable[1] = allocator.getOffset(new(allocator) Vector<unsigned short>);
return entryTable;
}
void initialize(void *e) {
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
off_t *entryTable = (off_t*)e;
array = (ARRAY*)allocator.getAddr(entryTable[0]);
PersistentRepository<GRAPH_NODE>::initialize(allocator.getAddr(entryTable[0]));
prevsize = (Vector<unsigned short>*)allocator.getAddr(entryTable[1]);
}
#endif
void insert(ObjectID id, ObjectDistances &objects) {
GRAPH_NODE *r = allocate();
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
(*r).copy(objects, VECTOR::getAllocator());
#else
*r = objects;
#endif
try {
put(id, r);
} catch (Exception &exp) {
delete r;
throw exp;
}
if (id >= prevsize->size()) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
prevsize->resize(id + 1, VECTOR::getAllocator(), 0);
#else
prevsize->resize(id + 1, 0);
#endif
} else {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
(*prevsize).at(id, VECTOR::getAllocator()) = 0;
#else
(*prevsize)[id] = 0;
#endif
}
return;
}
inline GRAPH_NODE *get(ObjectID fid, size_t &minsize) {
GRAPH_NODE *rs = VECTOR::get(fid);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
minsize = (*prevsize).at(fid, VECTOR::getAllocator());
#else
minsize = (*prevsize)[fid];
#endif
return rs;
}
void serialize(std::ofstream &os) {
VECTOR::serialize(os);
Serializer::write(os, *prevsize);
}
// for milvus
void serialize(std::stringstream & grp)
{
VECTOR::serialize(grp);
Serializer::write(grp, *prevsize);
}
void deserialize(std::ifstream &is) {
VECTOR::deserialize(is);
Serializer::read(is, *prevsize);
}
// for milvus
void deserialize(std::stringstream & is)
{
VECTOR::deserialize(is);
Serializer::read(is, *prevsize);
}
void show() {
for (size_t i = 0; i < this->size(); i++) {
std::cout << "Show graph " << i << " ";
if ((*this)[i] == 0) {
std::cout << std::endl;
continue;
}
for (size_t j = 0; j < (*this)[i]->size(); j++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cout << (*this)[i]->at(j, VECTOR::getAllocator()).id << ":" << (*this)[i]->at(j, VECTOR::getAllocator()).distance << " ";
#else
std::cout << (*this)[i]->at(j).id << ":" << (*this)[i]->at(j).distance << " ";
#endif
}
std::cout << std::endl;
}
}
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Vector<unsigned short> *prevsize;
#else
std::vector<unsigned short> *prevsize;
#endif
};
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
class ReadOnlyGraphNode : public std::vector<std::pair<uint64_t, PersistentObject*>> {
public:
ReadOnlyGraphNode():reservedSize(0), usedSize(0) {}
void reserve(size_t s) {
reservedSize = ((s & 7) == 0) ? s : (s & 0xFFFFFFFFFFFFFFF8) + 8;
resize(reservedSize);
for (size_t i = (reservedSize & 0xFFFFFFFFFFFFFFF8); i < reservedSize; i++) {
(*this)[i].first = 0;
}
}
void push_back(std::pair<uint32_t, PersistentObject*> node) {
(*this)[usedSize] = node;
usedSize++;
}
size_t size() { return usedSize; }
size_t reservedSize;
size_t usedSize;
};
class SearchGraphRepository : public std::vector<ReadOnlyGraphNode> {
public:
SearchGraphRepository() {}
bool isEmpty(size_t idx) { return (*this)[idx].empty(); }
void deserialize(std::ifstream &is, ObjectRepository &objectRepository) {
if (!is.is_open()) {
NGTThrowException("NGT::SearchGraph: Not open the specified stream yet.");
}
clear();
size_t s;
NGT::Serializer::read(is, s);
resize(s);
for (size_t id = 0; id < s; id++) {
char type;
NGT::Serializer::read(is, type);
switch(type) {
case '-':
break;
case '+':
{
ObjectDistances node;
node.deserialize(is);
ReadOnlyGraphNode &searchNode = at(id);
searchNode.reserve(node.size());
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
for (auto ni = node.begin(); ni != node.end(); ni++) {
std::cerr << "not implement" << std::endl;
abort();
}
#else
for (auto ni = node.begin(); ni != node.end(); ni++) {
searchNode.push_back(std::pair<uint32_t, Object*>((*ni).id, objectRepository.get((*ni).id)));
}
#endif
}
break;
default:
{
assert(type == '-' || type == '+');
break;
}
}
}
}
};
#endif // NGT_GRAPH_READ_ONLY_GRAPH
class NeighborhoodGraph {
public:
enum GraphType {
GraphTypeNone = 0,
GraphTypeANNG = 1,
GraphTypeKNNG = 2,
GraphTypeBKNNG = 3,
GraphTypeONNG = 4,
GraphTypeIANNG = 5, // Improved ANNG
GraphTypeDNNG = 6
};
enum SeedType {
SeedTypeNone = 0,
SeedTypeRandomNodes = 1,
SeedTypeFixedNodes = 2,
SeedTypeFirstNode = 3,
SeedTypeAllLeafNodes = 4
};
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
class Search {
public:
static void (*getMethod(NGT::ObjectSpace::DistanceType dtype, NGT::ObjectSpace::ObjectType otype, size_t size))(NGT::NeighborhoodGraph&, NGT::SearchContainer&, NGT::ObjectDistances&) {
if (size < 5000000) {
switch (otype) {
default:
case NGT::ObjectSpace::Float:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloat;
case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloat;
case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloat;
case NGT::ObjectSpace::DistanceTypeAngle : return angleFloat;
case NGT::ObjectSpace::DistanceTypeL2 : return l2Float;
case NGT::ObjectSpace::DistanceTypeL1 : return l1Float;
case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloat;
default: return l2Float;
}
break;
case NGT::ObjectSpace::Uint8:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8;
case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8;
case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8;
case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8;
default : return l2Uint8;
}
break;
}
return l1Uint8;
} else {
switch (otype) {
default:
case NGT::ObjectSpace::Float:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeAngle : return angleFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL2 : return l2FloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL1 : return l1FloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloatForLargeDataset;
default: return l2FloatForLargeDataset;
}
break;
case NGT::ObjectSpace::Uint8:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8ForLargeDataset;
case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8ForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8ForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8ForLargeDataset;
default : return l2Uint8ForLargeDataset;
}
break;
}
return l1Uint8ForLargeDataset;
}
}
static void l1Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l1Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void hammingUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void jaccardUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void sparseJaccardFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void cosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void angleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedCosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedAngleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l1Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l1FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void hammingUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void jaccardUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void sparseJaccardFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void cosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void angleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedCosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedAngleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
};
#endif
class Property {
public:
Property() { setDefault(); }
void setDefault() {
truncationThreshold = 0;
edgeSizeForCreation = NGT_CREATION_EDGE_SIZE;
edgeSizeForSearch = 0;
edgeSizeLimitForCreation = 5;
insertionRadiusCoefficient = NGT_INSERTION_EXPLORATION_COEFFICIENT;
seedSize = NGT_SEED_SIZE;
seedType = SeedTypeNone;
truncationThreadPoolSize = 8;
batchSizeForCreation = 200;
graphType = GraphTypeANNG;
dynamicEdgeSizeBase = 30;
dynamicEdgeSizeRate = 20;
buildTimeLimit = 0.0;
outgoingEdge = 10;
incomingEdge = 80;
}
void clear() {
truncationThreshold = -1;
edgeSizeForCreation = -1;
edgeSizeForSearch = -1;
edgeSizeLimitForCreation = -1;
insertionRadiusCoefficient = -1;
seedSize = -1;
seedType = SeedTypeNone;
truncationThreadPoolSize = -1;
batchSizeForCreation = -1;
graphType = GraphTypeNone;
dynamicEdgeSizeBase = -1;
dynamicEdgeSizeRate = -1;
buildTimeLimit = -1;
outgoingEdge = -1;
incomingEdge = -1;
}
void set(NGT::Property &prop);
void get(NGT::Property &prop);
void exportProperty(NGT::PropertySet &p) {
p.set("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold);
p.set("EdgeSizeForCreation", edgeSizeForCreation);
p.set("EdgeSizeForSearch", edgeSizeForSearch);
p.set("EdgeSizeLimitForCreation", edgeSizeLimitForCreation);
assert(insertionRadiusCoefficient >= 1.0);
p.set("EpsilonForCreation", insertionRadiusCoefficient - 1.0);
p.set("BatchSizeForCreation", batchSizeForCreation);
p.set("SeedSize", seedSize);
p.set("TruncationThreadPoolSize", truncationThreadPoolSize);
p.set("DynamicEdgeSizeBase", dynamicEdgeSizeBase);
p.set("DynamicEdgeSizeRate", dynamicEdgeSizeRate);
p.set("BuildTimeLimit", buildTimeLimit);
p.set("OutgoingEdge", outgoingEdge);
p.set("IncomingEdge", incomingEdge);
switch (graphType) {
case NeighborhoodGraph::GraphTypeKNNG: p.set("GraphType", "KNNG"); break;
case NeighborhoodGraph::GraphTypeANNG: p.set("GraphType", "ANNG"); break;
case NeighborhoodGraph::GraphTypeBKNNG: p.set("GraphType", "BKNNG"); break;
case NeighborhoodGraph::GraphTypeONNG: p.set("GraphType", "ONNG"); break;
case NeighborhoodGraph::GraphTypeIANNG: p.set("GraphType", "IANNG"); break;
default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Graph Type." << std::endl; abort();
}
switch (seedType) {
case NeighborhoodGraph::SeedTypeRandomNodes: p.set("SeedType", "RandomNodes"); break;
case NeighborhoodGraph::SeedTypeFixedNodes: p.set("SeedType", "FixedNodes"); break;
case NeighborhoodGraph::SeedTypeFirstNode: p.set("SeedType", "FirstNode"); break;
case NeighborhoodGraph::SeedTypeNone: p.set("SeedType", "None"); break;
case NeighborhoodGraph::SeedTypeAllLeafNodes: p.set("SeedType", "AllLeafNodes"); break;
default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Seed Type." << std::endl; abort();
}
}
void importProperty(NGT::PropertySet &p) {
setDefault();
truncationThreshold = p.getl("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold);
edgeSizeForCreation = p.getl("EdgeSizeForCreation", edgeSizeForCreation);
edgeSizeForSearch = p.getl("EdgeSizeForSearch", edgeSizeForSearch);
edgeSizeLimitForCreation = p.getl("EdgeSizeLimitForCreation", edgeSizeLimitForCreation);
insertionRadiusCoefficient = p.getf("EpsilonForCreation", insertionRadiusCoefficient);
insertionRadiusCoefficient += 1.0;
batchSizeForCreation = p.getl("BatchSizeForCreation", batchSizeForCreation);
seedSize = p.getl("SeedSize", seedSize);
truncationThreadPoolSize = p.getl("TruncationThreadPoolSize", truncationThreadPoolSize);
dynamicEdgeSizeBase = p.getl("DynamicEdgeSizeBase", dynamicEdgeSizeBase);
dynamicEdgeSizeRate = p.getl("DynamicEdgeSizeRate", dynamicEdgeSizeRate);
buildTimeLimit = p.getf("BuildTimeLimit", buildTimeLimit);
outgoingEdge = p.getl("OutgoingEdge", outgoingEdge);
incomingEdge = p.getl("IncomingEdge", incomingEdge);
PropertySet::iterator it = p.find("GraphType");
if (it != p.end()) {
if (it->second == "KNNG") graphType = NeighborhoodGraph::GraphTypeKNNG;
else if (it->second == "ANNG") graphType = NeighborhoodGraph::GraphTypeANNG;
else if (it->second == "BKNNG") graphType = NeighborhoodGraph::GraphTypeBKNNG;
else if (it->second == "ONNG") graphType = NeighborhoodGraph::GraphTypeONNG;
else if (it->second == "IANNG") graphType = NeighborhoodGraph::GraphTypeIANNG;
else { std::cerr << "Graph::importProperty: Fatal error! Invalid Graph Type. " << it->second << std::endl; abort(); }
}
it = p.find("SeedType");
if (it != p.end()) {
if (it->second == "RandomNodes") seedType = NeighborhoodGraph::SeedTypeRandomNodes;
else if (it->second == "FixedNodes") seedType = NeighborhoodGraph::SeedTypeFixedNodes;
else if (it->second == "FirstNode") seedType = NeighborhoodGraph::SeedTypeFirstNode;
else if (it->second == "None") seedType = NeighborhoodGraph::SeedTypeNone;
else if (it->second == "AllLeafNodes") seedType = NeighborhoodGraph::SeedTypeAllLeafNodes;
else { std::cerr << "Graph::importProperty: Fatal error! Invalid Seed Type. " << it->second << std::endl; abort(); }
}
}
friend std::ostream & operator<<(std::ostream& os, const Property& p) {
os << "truncationThreshold=" << p.truncationThreshold << std::endl;
os << "edgeSizeForCreation=" << p.edgeSizeForCreation << std::endl;
os << "edgeSizeForSearch=" << p.edgeSizeForSearch << std::endl;
os << "edgeSizeLimitForCreation=" << p.edgeSizeLimitForCreation << std::endl;
os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl;
os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl;
os << "seedSize=" << p.seedSize << std::endl;
os << "seedType=" << p.seedType << std::endl;
os << "truncationThreadPoolSize=" << p.truncationThreadPoolSize << std::endl;
os << "batchSizeForCreation=" << p.batchSizeForCreation << std::endl;
os << "graphType=" << p.graphType << std::endl;
os << "dynamicEdgeSizeBase=" << p.dynamicEdgeSizeBase << std::endl;
os << "dynamicEdgeSizeRate=" << p.dynamicEdgeSizeRate << std::endl;
os << "outgoingEdge=" << p.outgoingEdge << std::endl;
os << "incomingEdge=" << p.incomingEdge << std::endl;
return os;
}
int16_t truncationThreshold;
int16_t edgeSizeForCreation;
int16_t edgeSizeForSearch;
int16_t edgeSizeLimitForCreation;
double insertionRadiusCoefficient;
int16_t seedSize;
SeedType seedType;
int16_t truncationThreadPoolSize;
int16_t batchSizeForCreation;
GraphType graphType;
int16_t dynamicEdgeSizeBase;
int16_t dynamicEdgeSizeRate;
float buildTimeLimit;
int16_t outgoingEdge;
int16_t incomingEdge;
};
NeighborhoodGraph(): objectSpace(0) {
property.truncationThreshold = NGT_TRUNCATION_THRESHOLD;
// initialize random to generate random seeds
#ifdef NGT_DISABLE_SRAND_FOR_RANDOM
struct timeval randTime;
gettimeofday(&randTime, 0);
srand(randTime.tv_usec);
#endif
}
inline GraphNode *getNode(ObjectID fid, size_t &minsize) { return repository.get(fid, minsize); }
inline GraphNode *getNode(ObjectID fid) { return repository.VECTOR::get(fid); }
void insertNode(ObjectID id, ObjectDistances &objects) {
switch (property.graphType) {
case GraphTypeANNG:
insertANNGNode(id, objects);
break;
case GraphTypeIANNG:
insertIANNGNode(id, objects);
break;
case GraphTypeONNG:
insertONNGNode(id, objects);
break;
case GraphTypeKNNG:
insertKNNGNode(id, objects);
break;
case GraphTypeBKNNG:
insertBKNNGNode(id, objects);
break;
case GraphTypeNone:
NGTThrowException("NGT::insertNode: GraphType is not specified.");
break;
default:
NGTThrowException("NGT::insertNode: GraphType is invalid.");
break;
}
}
void insertBKNNGNode(ObjectID id, ObjectDistances &results) {
if (repository.isEmpty(id)) {
repository.insert(id, results);
} else {
GraphNode &rs = *getNode(id);
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
rs.push_back((*ri), repository.allocator);
#else
rs.push_back((*ri));
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(rs.begin(repository.allocator), rs.end(repository.allocator));
ObjectID prev = 0;
for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator);) {
if (prev == (*ri).id) {
ri = rs.erase(ri, repository.allocator);
continue;
}
prev = (*ri).id;
ri++;
}
#else
std::sort(rs.begin(), rs.end());
ObjectID prev = 0;
for (GraphNode::iterator ri = rs.begin(); ri != rs.end();) {
if (prev == (*ri).id) {
ri = rs.erase(ri);
continue;
}
prev = (*ri).id;
ri++;
}
#endif
}
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
assert(id != (*ri).id);
addBKNNGEdge((*ri).id, id, (*ri).distance);
}
return;
}
void insertKNNGNode(ObjectID id, ObjectDistances &results) {
repository.insert(id, results);
}
void insertANNGNode(ObjectID id, ObjectDistances &results) {
repository.insert(id, results);
std::queue<ObjectID> truncateQueue;
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
assert(id != (*ri).id);
if (addEdge((*ri).id, id, (*ri).distance)) {
truncateQueue.push((*ri).id);
}
}
while (!truncateQueue.empty()) {
ObjectID tid = truncateQueue.front();
truncateEdges(tid);
truncateQueue.pop();
}
return;
}
void insertIANNGNode(ObjectID id, ObjectDistances &results) {
repository.insert(id, results);
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
assert(id != (*ri).id);
addEdgeDeletingExcessEdges((*ri).id, id, (*ri).distance);
}
return;
}
void insertONNGNode(ObjectID id, ObjectDistances &results) {
if (property.truncationThreshold != 0) {
std::stringstream msg;
msg << "NGT::insertONNGNode: truncation should be disabled!" << std::endl;
NGTThrowException(msg);
}
int count = 0;
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++, count++) {
assert(id != (*ri).id);
if (count >= property.incomingEdge) {
break;
}
addEdge((*ri).id, id, (*ri).distance);
}
if (static_cast<int>(results.size()) > property.outgoingEdge) {
results.resize(property.outgoingEdge);
}
repository.insert(id, results);
}
void removeEdgesReliably(ObjectID id);
int truncateEdgesOptimally(ObjectID id, GraphNode &results, size_t truncationSize);
int truncateEdges(ObjectID id) {
GraphNode &results = *getNode(id);
if (results.size() == 0) {
return -1;
}
size_t truncationSize = NGT_TRUNCATION_THRESHOLD;
if (truncationSize < (size_t)property.edgeSizeForCreation) {
truncationSize = property.edgeSizeForCreation;
}
return truncateEdgesOptimally(id, results, truncationSize);
}
// setup edgeSize
inline size_t getEdgeSize(NGT::SearchContainer &sc) {
size_t edgeSize = INT_MAX;
if (sc.edgeSize < 0) {
if (sc.edgeSize == -2) {
double add = pow(10, (sc.explorationCoefficient - 1.0) * static_cast<float>(property.dynamicEdgeSizeRate));
edgeSize = add >= static_cast<double>(INT_MAX) ? INT_MAX : property.dynamicEdgeSizeBase + add;
} else {
edgeSize = property.edgeSizeForSearch == 0 ? INT_MAX : property.edgeSizeForSearch;
}
} else {
edgeSize = sc.edgeSize == 0 ? INT_MAX : sc.edgeSize;
}
return edgeSize;
}
void search(NGT::SearchContainer &sc, ObjectDistances &seeds);
// for milvus
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, faiss::ConcurrentBitsetPtr & bitset);
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
template <typename COMPARATOR, typename CHECK_LIST> void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds);
#endif
void removeEdge(ObjectID fid, ObjectID rmid) {
GraphNode &rs = *getNode(fid);
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator); ri++) {
if ((*ri).id == rmid) {
rs.erase(ri, repository.allocator);
break;
}
}
#else
for (GraphNode::iterator ri = rs.begin(); ri != rs.end(); ri++) {
if ((*ri).id == rmid) {
rs.erase(ri);
break;
}
}
#endif
}
void removeEdge(GraphNode &node, ObjectDistance &edge) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), edge);
if (ni != node.end(repository.allocator) && (*ni).id == edge.id) {
node.erase(ni, repository.allocator);
#else
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge);
if (ni != node.end() && (*ni).id == edge.id) {
node.erase(ni);
#endif
return;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (ni == node.end(repository.allocator)) {
#else
if (ni == node.end()) {
#endif
std::stringstream msg;
msg << "NGT::removeEdge: Cannot found " << edge.id;
NGTThrowException(msg);
} else {
std::stringstream msg;
msg << "NGT::removeEdge: Cannot found " << (*ni).id << ":" << edge.id;
NGTThrowException(msg);
}
}
void
removeNode(ObjectID id) {
repository.erase(id);
}
class BooleanVector : public std::vector<bool> {
public:
inline BooleanVector(size_t s):std::vector<bool>(s, false) {}
inline void insert(size_t i) { std::vector<bool>::operator[](i) = true; }
};
#ifdef NGT_GRAPH_VECTOR_RESULT
typedef ObjectDistances ResultSet;
#else
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > ResultSet;
#endif
#if defined(NGT_GRAPH_CHECK_BOOLEANSET)
typedef BooleanSet DistanceCheckedSet;
#elif defined(NGT_GRAPH_CHECK_VECTOR)
typedef BooleanVector DistanceCheckedSet;
#elif defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET)
typedef HashBasedBooleanSet DistanceCheckedSet;
#else
class DistanceCheckedSet : public unordered_set<ObjectID> {
public:
bool operator[](ObjectID id) { return find(id) != end(); }
};
#endif
typedef HashBasedBooleanSet DistanceCheckedSetForLargeDataset;
class NodeWithPosition : public ObjectDistance {
public:
NodeWithPosition(uint32_t p = 0):position(p){}
NodeWithPosition(ObjectDistance &o):ObjectDistance(o), position(0){}
NodeWithPosition &operator=(const NodeWithPosition &n) {
ObjectDistance::operator=(static_cast<const ObjectDistance&>(n));
position = n.position;
assert(id != 0);
return *this;
}
uint32_t position;
};
#ifdef NGT_GRAPH_UNCHECK_STACK
typedef std::stack<ObjectDistance> UncheckedSet;
#else
#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE
typedef std::priority_queue<NodeWithPosition, std::vector<NodeWithPosition>, std::greater<NodeWithPosition> > UncheckedSet;
#else
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::greater<ObjectDistance> > UncheckedSet;
#endif
#endif
void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds);
void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds, double (&comparator)(const void*, const void*, size_t));
void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results,
UncheckedSet &unchecked, DistanceCheckedSet &distanceChecked);
#if !defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET)
void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results,
UncheckedSet &unchecked, DistanceCheckedSetForLargeDataset &distanceChecked);
#endif
int getEdgeSize() {return property.edgeSizeForCreation;}
ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); }
ObjectSpace &getObjectSpace() { return *objectSpace; }
void deleteInMemory() {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
assert(0);
#else
for (std::vector<NGT::GraphNode*>::iterator i = repository.begin(); i != repository.end(); i++) {
if ((*i) != 0) {
delete (*i);
}
}
repository.clear();
#endif
}
static double (*getComparator())(const void*, const void*, size_t);
protected:
void
addBKNNGEdge(ObjectID target, ObjectID addID, Distance addDistance) {
if (repository.isEmpty(target)) {
ObjectDistances objs;
objs.push_back(ObjectDistance(addID, addDistance));
repository.insert(target, objs);
return;
}
addEdge(target, addID, addDistance, false);
}
public:
void addEdge(GraphNode &node, ObjectID addID, Distance addDistance, bool identityCheck = true) {
ObjectDistance obj(addID, addDistance);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), obj);
if ((ni != node.end(repository.allocator)) && ((*ni).id == addID)) {
if (identityCheck) {
std::stringstream msg;
msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID;
NGTThrowException(msg);
}
return;
}
#else
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), obj);
if ((ni != node.end()) && ((*ni).id == addID)) {
if (identityCheck) {
std::stringstream msg;
msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID;
NGTThrowException(msg);
}
return;
}
#endif
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.insert(ni, obj, repository.allocator);
#else
node.insert(ni, obj);
#endif
}
// identityCheck is checking whether the same edge has already added to the node.
// return whether truncation is needed that means the node has too many edges.
bool addEdge(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) {
size_t minsize = 0;
GraphNode &node = property.truncationThreshold == 0 ? *getNode(target) : *getNode(target, minsize);
addEdge(node, addID, addDistance, identityCheck);
if ((size_t)property.truncationThreshold != 0 && node.size() - minsize >
(size_t)property.truncationThreshold) {
return true;
}
return false;
}
void addEdgeDeletingExcessEdges(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) {
GraphNode &node = *getNode(target);
size_t kEdge = property.edgeSizeForCreation - 1;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
if (node.size() > kEdge && node.at(kEdge, repository.allocator).distance >= addDistance) {
GraphNode &linkedNode = *getNode(node.at(kEdge, repository.allocator).id);
ObjectDistance linkedNodeEdge(target, node.at(kEdge, repository.allocator).distance);
if ((linkedNode.size() > kEdge) && node.at(kEdge, repository.allocator).distance >=
linkedNode.at(kEdge, repository.allocator).distance) {
#else
if (node.size() > kEdge && node[kEdge].distance >= addDistance) {
GraphNode &linkedNode = *getNode(node[kEdge].id);
ObjectDistance linkedNodeEdge(target, node[kEdge].distance);
if ((linkedNode.size() > kEdge) && node[kEdge].distance >= linkedNode[kEdge].distance) {
#endif
try {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
removeEdge(node, node.at(kEdge, repository.allocator));
#else
removeEdge(node, node[kEdge]);
#endif
} catch (Exception &exp) {
std::stringstream msg;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance;
#else
msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance;
#endif
msg << ":" << exp.what();
NGTThrowException(msg.str());
}
try {
removeEdge(linkedNode, linkedNodeEdge);
} catch (Exception &exp) {
std::stringstream msg;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance;
#else
msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance;
#endif
msg << ":" << exp.what();
NGTThrowException(msg.str());
}
}
}
addEdge(node, addID, addDistance, identityCheck);
}
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
void loadSearchGraph(const std::string &database) {
std::ifstream isg(database + "/grp");
NeighborhoodGraph::searchRepository.deserialize(isg, NeighborhoodGraph::getObjectRepository());
}
#endif
public:
GraphRepository repository;
ObjectSpace *objectSpace;
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
SearchGraphRepository searchRepository;
#endif
NeighborhoodGraph::Property property;
}; // NeighborhoodGraph
} // NGT

View File

@ -0,0 +1,789 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "GraphReconstructor.h"
#include "Optimizer.h"
namespace NGT {
class GraphOptimizer {
public:
class ANNGEdgeOptimizationParameter {
public:
ANNGEdgeOptimizationParameter() {
initialize();
}
void initialize() {
noOfQueries = 200;
noOfResults = 50;
noOfThreads = 16;
targetAccuracy = 0.9; // when epsilon is 0.0 and all of the edges are used
targetNoOfObjects = 0;
noOfSampleObjects = 100000;
maxNoOfEdges = 100;
}
size_t noOfQueries;
size_t noOfResults;
size_t noOfThreads;
float targetAccuracy;
size_t targetNoOfObjects;
size_t noOfSampleObjects;
size_t maxNoOfEdges;
};
GraphOptimizer(bool unlog = false) {
init();
logDisabled = unlog;
}
GraphOptimizer(int outgoing, int incoming, int nofqs, int nofrs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m,
bool unlog // stderr log is disabled.
) {
init();
set(outgoing, incoming, nofqs, nofrs, baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
logDisabled = unlog;
}
void init() {
numOfOutgoingEdges = 10;
numOfIncomingEdges= 120;
numOfQueries = 100;
numOfResults = 20;
baseAccuracyRange = std::pair<float, float>(0.30, 0.50);
rateAccuracyRange = std::pair<float, float>(0.80, 0.90);
gtEpsilon = 0.1;
margin = 0.2;
logDisabled = false;
shortcutReduction = true;
searchParameterOptimization = true;
prefetchParameterOptimization = true;
accuracyTableGeneration = true;
}
void adjustSearchCoefficients(const std::string indexPath){
NGT::Index index(indexPath);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(index.getIndex());
NGT::Optimizer optimizer(index);
if (logDisabled) {
optimizer.disableLog();
} else {
optimizer.enableLog();
}
try {
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty();
prop.dynamicEdgeSizeBase = coefficients.first;
prop.dynamicEdgeSizeRate = coefficients.second;
prop.edgeSizeForSearch = -2;
} catch(NGT::Exception &err) {
std::stringstream msg;
msg << "Optimizer::adjustSearchCoefficients: Cannot adjust the search coefficients. " << err.what();
NGTThrowException(msg);
}
graph.saveIndex(indexPath);
}
static double measureQueryTime(NGT::Index &index, size_t start) {
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
NGT::ObjectRepository &objectRepository = objectSpace.getRepository();
size_t nQueries = 200;
nQueries = objectRepository.size() - 1 < nQueries ? objectRepository.size() - 1 : nQueries;
size_t step = objectRepository.size() / nQueries;
assert(step != 0);
std::vector<size_t> ids;
for (size_t startID = start; startID < step; startID++) {
for (size_t id = startID; id < objectRepository.size(); id += step) {
if (!objectRepository.isEmpty(id)) {
ids.push_back(id);
}
}
if (ids.size() >= nQueries) {
ids.resize(nQueries);
break;
}
}
if (nQueries > ids.size()) {
std::cerr << "# of Queries is not enough." << std::endl;
return DBL_MAX;
}
NGT::Timer timer;
timer.reset();
for (auto id = ids.begin(); id != ids.end(); id++) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
NGT::Object *obj = objectSpace.allocateObject(*objectRepository.get(*id));
NGT::SearchContainer searchContainer(*obj);
#else
NGT::SearchContainer searchContainer(*objectRepository.get(*id));
#endif
NGT::ObjectDistances objects;
searchContainer.setResults(&objects);
searchContainer.setSize(10);
searchContainer.setEpsilon(0.1);
timer.restart();
index.search(searchContainer);
timer.stop();
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
objectSpace.deleteObject(obj);
#endif
}
return timer.time * 1000.0;
}
static std::pair<size_t, double> searchMinimumQueryTime(NGT::Index &index, size_t prefetchOffset,
int maxPrefetchSize, size_t seedID) {
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
int step = 256;
int prevPrefetchSize = 64;
size_t minPrefetchSize = 0;
double minTime = DBL_MAX;
for (step = 256; step != 32; step /= 2) {
double prevTime = DBL_MAX;
for (int prefetchSize = prevPrefetchSize - step < 64 ? 64 : prevPrefetchSize - step; prefetchSize <= maxPrefetchSize; prefetchSize += step) {
objectSpace.setPrefetchOffset(prefetchOffset);
objectSpace.setPrefetchSize(prefetchSize);
double time = measureQueryTime(index, seedID);
if (prevTime < time) {
break;
}
prevTime = time;
prevPrefetchSize = prefetchSize;
}
if (minTime > prevTime) {
minTime = prevTime;
minPrefetchSize = prevPrefetchSize;
}
}
return std::make_pair(minPrefetchSize, minTime);
}
static std::pair<size_t, size_t> adjustPrefetchParameters(NGT::Index &index) {
bool gridSearch = false;
{
double time = measureQueryTime(index, 1);
if (time < 500.0) {
gridSearch = true;
}
}
size_t prefetchOffset = 0;
size_t prefetchSize = 0;
std::vector<std::pair<size_t, size_t>> mins;
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
int maxSize = objectSpace.getByteSizeOfObject() * 4;
maxSize = maxSize < 64 * 28 ? maxSize : 64 * 28;
for (int trial = 0; trial < 10; trial++) {
size_t minps = 0;
size_t minpo = 0;
if (gridSearch) {
double minTime = DBL_MAX;
for (size_t po = 1; po <= 10; po++) {
auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1);
if (minTime > min.second) {
minTime = min.second;
minps = min.first;
minpo = po;
}
}
} else {
double prevTime = DBL_MAX;
for (size_t po = 1; po <= 10; po++) {
auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1);
if (prevTime < min.second) {
break;
}
prevTime = min.second;
minps = min.first;
minpo = po;
}
}
if (std::find(mins.begin(), mins.end(), std::make_pair(minpo, minps)) != mins.end()) {
prefetchOffset = minpo;
prefetchSize = minps;
mins.push_back(std::make_pair(minpo, minps));
break;
}
mins.push_back(std::make_pair(minpo, minps));
}
return std::make_pair(prefetchOffset, prefetchSize);
}
void execute(NGT::Index & index_)
{
NGT::GraphIndex & graphIndex = static_cast<NGT::GraphIndex &>(index_.getIndex());
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0)
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
}
NGT::Timer timer;
timer.start();
std::vector<NGT::ObjectDistances> graph;
try
{
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
// extract only edges from the index to reduce the memory usage.
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
NeighborhoodGraph::Property & prop = graphIndex.getGraphProperty();
if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG)
{
NGT::GraphReconstructor::convertToANNG(graph);
}
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
timer.stop();
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
}
catch (NGT::Exception & err)
{
throw(err);
}
}
if (shortcutReduction)
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
}
try
{
NGT::Timer timer;
timer.start();
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
timer.stop();
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
}
catch (NGT::Exception & err)
{
throw(err);
}
}
}
void optimizeSearchParameters(NGT::Index & outIndex)
{
if (searchParameterOptimization)
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
}
NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(outIndex.getIndex());
NGT::Optimizer optimizer(outIndex);
if (logDisabled)
{
optimizer.disableLog();
}
else
{
optimizer.enableLog();
}
try
{
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
NGT::NeighborhoodGraph::Property & prop = outGraph.getGraphProperty();
prop.dynamicEdgeSizeBase = coefficients.first;
prop.dynamicEdgeSizeRate = coefficients.second;
prop.edgeSizeForSearch = -2;
}
catch (NGT::Exception & err)
{
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what();
NGTThrowException(msg);
}
}
if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration)
{
// NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(*outIndex.getIndex());
if (prefetchParameterOptimization)
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
}
try
{
auto prefetch = adjustPrefetchParameters(outIndex);
NGT::Property prop;
outIndex.getProperty(prop);
prop.prefetchOffset = prefetch.first;
prop.prefetchSize = prefetch.second;
outIndex.setProperty(prop);
}
catch (NGT::Exception & err)
{
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what();
NGTThrowException(msg);
}
}
if (accuracyTableGeneration)
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
}
try
{
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
NGT::Index::AccuracyTable accuracyTable(table);
NGT::Property prop;
outIndex.getProperty(prop);
prop.accuracyTable = accuracyTable.getString();
outIndex.setProperty(prop);
}
catch (NGT::Exception & err)
{
std::stringstream msg;
msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what();
NGTThrowException(msg);
}
}
}
}
void execute(
const std::string inIndexPath,
const std::string outIndexPath
){
if (access(outIndexPath.c_str(), 0) == 0) {
std::stringstream msg;
msg << "Optimizer::execute: The specified index exists. " << outIndexPath;
NGTThrowException(msg);
}
const std::string com = "cp -r " + inIndexPath + " " + outIndexPath;
int stat = system(com.c_str());
if (stat != 0) {
std::stringstream msg;
msg << "Optimizer::execute: Cannot create the specified index. " << outIndexPath;
NGTThrowException(msg);
}
{
NGT::StdOstreamRedirector redirector(logDisabled);
NGT::GraphIndex graphIndex(outIndexPath, false);
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
}
redirector.begin();
NGT::Timer timer;
timer.start();
std::vector<NGT::ObjectDistances> graph;
try {
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
// extract only edges from the index to reduce the memory usage.
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG) {
NGT::GraphReconstructor::convertToANNG(graph);
}
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
timer.stop();
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
graphIndex.saveGraph(outIndexPath);
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
graphIndex.saveProperty(outIndexPath);
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
if (shortcutReduction) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
}
try {
NGT::Timer timer;
timer.start();
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
timer.stop();
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
graphIndex.saveGraph(outIndexPath);
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
redirector.end();
}
optimizeSearchParameters(outIndexPath);
}
void optimizeSearchParameters(const std::string outIndexPath)
{
if (searchParameterOptimization) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
}
NGT::Index outIndex(outIndexPath);
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
NGT::Optimizer optimizer(outIndex);
if (logDisabled) {
optimizer.disableLog();
} else {
optimizer.enableLog();
}
try {
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
NGT::NeighborhoodGraph::Property &prop = outGraph.getGraphProperty();
prop.dynamicEdgeSizeBase = coefficients.first;
prop.dynamicEdgeSizeRate = coefficients.second;
prop.edgeSizeForSearch = -2;
outGraph.saveProperty(outIndexPath);
} catch(NGT::Exception &err) {
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what();
NGTThrowException(msg);
}
}
if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration) {
NGT::StdOstreamRedirector redirector(logDisabled);
redirector.begin();
NGT::Index outIndex(outIndexPath, true);
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
if (prefetchParameterOptimization) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
}
try {
auto prefetch = adjustPrefetchParameters(outIndex);
NGT::Property prop;
outIndex.getProperty(prop);
prop.prefetchOffset = prefetch.first;
prop.prefetchSize = prefetch.second;
outIndex.setProperty(prop);
outGraph.saveProperty(outIndexPath);
} catch(NGT::Exception &err) {
redirector.end();
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what();
NGTThrowException(msg);
}
}
if (accuracyTableGeneration) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
}
try {
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
NGT::Index::AccuracyTable accuracyTable(table);
NGT::Property prop;
outIndex.getProperty(prop);
prop.accuracyTable = accuracyTable.getString();
outIndex.setProperty(prop);
} catch(NGT::Exception &err) {
redirector.end();
std::stringstream msg;
msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what();
NGTThrowException(msg);
}
}
try {
outGraph.saveProperty(outIndexPath);
redirector.end();
} catch(NGT::Exception &err) {
redirector.end();
std::stringstream msg;
msg << "Optimizer::execute: Cannot save the index. " << outIndexPath << err.what();
NGTThrowException(msg);
}
}
}
static std::tuple<size_t, double, double> // optimized # of edges, accuracy, accuracy gain per edge
optimizeNumberOfEdgesForANNG(NGT::Optimizer &optimizer, std::vector<std::vector<float>> &queries,
size_t nOfResults, float targetAccuracy, size_t maxNoOfEdges) {
NGT::Index &index = optimizer.index;
std::stringstream queryStream;
std::stringstream gtStream;
float maxEpsilon = 0.0;
optimizer.generatePseudoGroundTruth(queries, maxEpsilon, queryStream, gtStream);
size_t nOfEdges = 0;
double accuracy = 0.0;
size_t prevEdge = 0;
double prevAccuracy = 0.0;
double gain = 0.0;
{
std::vector<NGT::ObjectDistances> graph;
NGT::GraphReconstructor::extractGraph(graph, static_cast<NGT::GraphIndex&>(index.getIndex()));
float epsilon = 0.0;
for (size_t edgeSize = 5; edgeSize <= maxNoOfEdges; edgeSize += (edgeSize >= 10 ? 10 : 5) ) {
NGT::GraphReconstructor::reconstructANNGFromANNG(graph, index, edgeSize);
NGT::Command::SearchParameter searchParameter;
searchParameter.size = nOfResults;
searchParameter.outputMode = 'e';
searchParameter.edgeSize = 0;
searchParameter.beginOfEpsilon = searchParameter.endOfEpsilon = epsilon;
queryStream.clear();
queryStream.seekg(0, std::ios_base::beg);
std::vector<NGT::Optimizer::MeasuredValue> acc;
NGT::Optimizer::search(index, queryStream, gtStream, searchParameter, acc);
if (acc.size() == 0) {
NGTThrowException("Fatal error! Cannot get any accuracy value.");
}
accuracy = acc[0].meanAccuracy;
nOfEdges = edgeSize;
if (prevEdge != 0) {
gain = (accuracy - prevAccuracy) / (edgeSize - prevEdge);
}
if (accuracy >= targetAccuracy) {
break;
}
prevEdge = edgeSize;
prevAccuracy = accuracy;
}
}
return std::make_tuple(nOfEdges, accuracy, gain);
}
static std::pair<size_t, float>
optimizeNumberOfEdgesForANNG(NGT::Index &index, ANNGEdgeOptimizationParameter &parameter)
{
if (parameter.targetNoOfObjects == 0) {
parameter.targetNoOfObjects = index.getObjectRepositorySize();
}
NGT::Optimizer optimizer(index, parameter.noOfResults);
NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository();
NGT::GraphIndex &graphIndex = static_cast<NGT::GraphIndex&>(index.getIndex());
NGT::GraphAndTreeIndex &treeIndex = static_cast<NGT::GraphAndTreeIndex&>(index.getIndex());
NGT::GraphRepository &graphRepository = graphIndex.NeighborhoodGraph::repository;
//float targetAccuracy = parameter.targetAccuracy + FLT_EPSILON;
std::vector<std::vector<float>> queries;
optimizer.extractAndRemoveRandomQueries(parameter.noOfQueries, queries);
{
graphRepository.deleteAll();
treeIndex.DVPTree::deleteAll();
treeIndex.DVPTree::insertNode(treeIndex.DVPTree::leafNodes.allocate());
}
NGT::NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
prop.edgeSizeForCreation = parameter.maxNoOfEdges;
std::vector<std::pair<size_t, std::tuple<size_t, double, double>>> transition;
size_t targetNo = 12500;
for (;targetNo <= objectRepository.size() && targetNo <= parameter.noOfSampleObjects;
targetNo *= 2) {
ObjectID id = 0;
size_t noOfObjects = 0;
for (id = 1; id < objectRepository.size(); ++id) {
if (!objectRepository.isEmpty(id)) {
noOfObjects++;
}
if (noOfObjects >= targetNo) {
break;
}
}
id++;
index.createIndex(parameter.noOfThreads, id);
auto edge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(optimizer, queries, parameter.noOfResults, parameter.targetAccuracy, parameter.maxNoOfEdges);
transition.push_back(make_pair(noOfObjects, edge));
}
if (transition.size() < 2) {
std::stringstream msg;
msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. Too small object set. # of objects=" << objectRepository.size() << " target No.=" << targetNo;
NGTThrowException(msg);
}
double edgeRate = 0.0;
double accuracyRate = 0.0;
for (auto i = transition.begin(); i != transition.end() - 1; ++i) {
edgeRate += std::get<0>((*(i + 1)).second) - std::get<0>((*i).second);
accuracyRate += std::get<1>((*(i + 1)).second) - std::get<1>((*i).second);
}
edgeRate /= (transition.size() - 1);
accuracyRate /= (transition.size() - 1);
size_t estimatedEdge = std::get<0>(transition[0].second) +
edgeRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first));
float estimatedAccuracy = std::get<1>(transition[0].second) +
accuracyRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first));
if (estimatedAccuracy < parameter.targetAccuracy) {
estimatedEdge += (parameter.targetAccuracy - estimatedAccuracy) / std::get<2>(transition.back().second);
estimatedAccuracy = parameter.targetAccuracy;
}
if (estimatedEdge == 0) {
std::stringstream msg;
msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. "
<< estimatedEdge << ":" << estimatedAccuracy << " # of objects=" << objectRepository.size();
NGTThrowException(msg);
}
return std::make_pair(estimatedEdge, estimatedAccuracy);
}
std::pair<size_t, float>
optimizeNumberOfEdgesForANNG(const std::string indexPath, GraphOptimizer::ANNGEdgeOptimizationParameter &parameter) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
NGTThrowException("Not implemented for NGT with the shared memory option.");
#endif
NGT::StdOstreamRedirector redirector(logDisabled);
redirector.begin();
try {
NGT::Index index(indexPath, false);
auto optimizedEdge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(index, parameter);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(index.getIndex());
size_t noOfEdges = (optimizedEdge.first + 10) / 5 * 5;
if (noOfEdges > parameter.maxNoOfEdges) {
noOfEdges = parameter.maxNoOfEdges;
}
NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty();
prop.edgeSizeForCreation = noOfEdges;
static_cast<NGT::GraphIndex&>(index.getIndex()).saveProperty(indexPath);
optimizedEdge.first = noOfEdges;
redirector.end();
return optimizedEdge;
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
void set(int outgoing, int incoming, int nofqs, int nofrs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m
) {
set(outgoing, incoming, nofqs, nofrs);
setExtension(baseAccuracyFrom, baseAccuracyTo, rateAccuracyFrom, rateAccuracyTo, gte, m);
}
void set(int outgoing, int incoming, int nofqs, int nofrs) {
if (outgoing >= 0) {
numOfOutgoingEdges = outgoing;
}
if (incoming >= 0) {
numOfIncomingEdges = incoming;
}
if (nofqs > 0) {
numOfQueries = nofqs;
}
if (nofrs > 0) {
numOfResults = nofrs;
}
}
void setExtension(float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m
) {
if (baseAccuracyFrom > 0.0) {
baseAccuracyRange.first = baseAccuracyFrom;
}
if (baseAccuracyTo > 0.0) {
baseAccuracyRange.second = baseAccuracyTo;
}
if (rateAccuracyFrom > 0.0) {
rateAccuracyRange.first = rateAccuracyFrom;
}
if (rateAccuracyTo > 0.0) {
rateAccuracyRange.second = rateAccuracyTo;
}
if (gte >= -1.0) {
gtEpsilon = gte;
}
if (m > 0.0) {
margin = m;
}
}
// obsolete because of a lack of a parameter
void set(int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m
) {
if (outgoing >= 0) {
numOfOutgoingEdges = outgoing;
}
if (incoming >= 0) {
numOfIncomingEdges = incoming;
}
if (nofqs > 0) {
numOfQueries = nofqs;
}
if (baseAccuracyFrom > 0.0) {
baseAccuracyRange.first = baseAccuracyFrom;
}
if (baseAccuracyTo > 0.0) {
baseAccuracyRange.second = baseAccuracyTo;
}
if (rateAccuracyFrom > 0.0) {
rateAccuracyRange.first = rateAccuracyFrom;
}
if (rateAccuracyTo > 0.0) {
rateAccuracyRange.second = rateAccuracyTo;
}
if (gte >= -1.0) {
gtEpsilon = gte;
}
if (m > 0.0) {
margin = m;
}
}
void setProcessingModes(bool shortcut = true, bool searchParameter = true, bool prefetchParameter = true,
bool accuracyTable = true) {
shortcutReduction = shortcut;
searchParameterOptimization = searchParameter;
prefetchParameterOptimization = prefetchParameter;
accuracyTableGeneration = accuracyTable;
}
size_t numOfOutgoingEdges;
size_t numOfIncomingEdges;
std::pair<float, float> baseAccuracyRange;
std::pair<float, float> rateAccuracyRange;
size_t numOfQueries;
size_t numOfResults;
double gtEpsilon;
double margin;
bool logDisabled;
bool shortcutReduction;
bool searchParameterOptimization;
bool prefetchParameterOptimization;
bool accuracyTableGeneration;
};
}; // NGT

View File

@ -0,0 +1,907 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <unordered_map>
#include <unordered_set>
#include <list>
#ifdef _OPENMP
#include <omp.h>
#else
#warning "*** OMP is *NOT* available! ***"
#endif
namespace NGT {
class GraphReconstructor {
public:
static void extractGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &graphIndex) {
graph.reserve(graphIndex.repository.size());
for (size_t id = 1; id < graphIndex.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
}
try {
NGT::GraphNode &node = *graphIndex.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::ObjectDistances nd;
nd.reserve(node.size());
for (auto n = node.begin(graphIndex.repository.allocator); n != node.end(graphIndex.repository.allocator); ++n) {
nd.push_back(ObjectDistance((*n).id, (*n).distance));
}
graph.push_back(nd);
#else
graph.push_back(node);
#endif
if (graph.back().size() != graph.back().capacity()) {
std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
}
} catch(NGT::Exception &err) {
graph.push_back(NGT::ObjectDistances());
continue;
}
}
}
static void
adjustPaths(NGT::Index &outIndex)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "construct index is not implemented." << std::endl;
exit(1);
#else
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
size_t rStartRank = 0;
std::list<std::pair<size_t, NGT::GraphNode> > tmpGraph;
for (size_t id = 1; id < outGraph.repository.size(); id++) {
NGT::GraphNode &node = *outGraph.getNode(id);
tmpGraph.push_back(std::pair<size_t, NGT::GraphNode>(id, node));
if (node.size() > rStartRank) {
node.resize(rStartRank);
}
}
size_t removeCount = 0;
for (size_t rank = rStartRank; ; rank++) {
bool edge = false;
Timer timer;
for (auto it = tmpGraph.begin(); it != tmpGraph.end();) {
size_t id = (*it).first;
try {
NGT::GraphNode &node = (*it).second;
if (rank >= node.size()) {
it = tmpGraph.erase(it);
continue;
}
edge = true;
if (rank >= 1 && node[rank - 1].distance > node[rank].distance) {
std::cerr << "distance order is wrong!" << std::endl;
std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
}
NGT::GraphNode &tn = *outGraph.getNode(id);
volatile bool found = false;
if (rank < 1000) {
for (size_t tni = 0; tni < tn.size() && !found; tni++) {
if (tn[tni].id == node[rank].id) {
continue;
}
NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id);
for (size_t dni = 0; dni < dstNode.size(); dni++) {
if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) {
found = true;
break;
}
}
}
} else {
#ifdef _OPENMP
#pragma omp parallel for num_threads(10)
#endif
for (size_t tni = 0; tni < tn.size(); tni++) {
if (found) {
continue;
}
if (tn[tni].id == node[rank].id) {
continue;
}
NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id);
for (size_t dni = 0; dni < dstNode.size(); dni++) {
if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) {
found = true;
}
}
}
}
if (!found) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
outGraph.addEdge(id, node.at(i, outGraph.repository.allocator).id,
node.at(i, outGraph.repository.allocator).distance, true);
#else
tn.push_back(NGT::ObjectDistance(node[rank].id, node[rank].distance));
#endif
} else {
removeCount++;
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
it++;
continue;
}
it++;
}
if (edge == false) {
break;
}
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
}
static void
adjustPathsEffectively(NGT::Index &outIndex)
{
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
adjustPathsEffectively(outGraph);
}
static bool edgeComp(NGT::ObjectDistance a, NGT::ObjectDistance b) {
return a.id < b.id;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance, NGT::GraphIndex &graph) {
NGT::ObjectDistance edge(edgeID, edgeDistance);
GraphNode::iterator ni = std::lower_bound(node.begin(graph.repository.allocator), node.end(graph.repository.allocator), edge, edgeComp);
node.insert(ni, edge, graph.repository.allocator);
}
static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID)
{
NGT::GraphNode &srcNode = *graph.getNode(srcNodeID);
GraphNode::iterator ni = std::lower_bound(srcNode.begin(graph.repository.allocator), srcNode.end(graph.repository.allocator), ObjectDistance(dstNodeID, 0.0), edgeComp);
return (ni != srcNode.end(graph.repository.allocator)) && ((*ni).id == dstNodeID);
}
#else
static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance) {
NGT::ObjectDistance edge(edgeID, edgeDistance);
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge, edgeComp);
node.insert(ni, edge);
}
static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID)
{
NGT::GraphNode &srcNode = *graph.getNode(srcNodeID);
GraphNode::iterator ni = std::lower_bound(srcNode.begin(), srcNode.end(), ObjectDistance(dstNodeID, 0.0), edgeComp);
return (ni != srcNode.end()) && ((*ni).id == dstNodeID);
}
#endif
static void
adjustPathsEffectively(NGT::GraphIndex &outGraph)
{
Timer timer;
timer.start();
std::vector<NGT::GraphNode> tmpGraph;
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &node = *outGraph.getNode(id);
tmpGraph.push_back(node);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.clear(outGraph.repository.allocator);
#else
node.clear();
#endif
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
tmpGraph.push_back(NGT::GraphNode(outGraph.repository.allocator));
#else
tmpGraph.push_back(NGT::GraphNode());
#endif
}
}
if (outGraph.repository.size() != tmpGraph.size() + 1) {
std::stringstream msg;
msg << "GraphReconstructor: Fatal inner error. " << outGraph.repository.size() << ":" << tmpGraph.size();
NGTThrowException(msg);
}
timer.stop();
std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
timer.reset();
timer.start();
std::vector<std::vector<std::pair<uint32_t, uint32_t> > > removeCandidates(tmpGraph.size());
int removeCandidateCount = 0;
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (size_t idx = 0; idx < tmpGraph.size(); ++idx) {
auto it = tmpGraph.begin() + idx;
size_t id = idx + 1;
try {
NGT::GraphNode &srcNode = *it;
std::unordered_map<uint32_t, std::pair<size_t, double> > neighbors;
for (size_t sni = 0; sni < srcNode.size(); ++sni) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
neighbors[srcNode.at(sni, outGraph.repository.allocator).id] = std::pair<size_t, double>(sni, srcNode.at(sni, outGraph.repository.allocator).distance);
#else
neighbors[srcNode[sni].id] = std::pair<size_t, double>(sni, srcNode[sni].distance);
#endif
}
std::vector<std::pair<int, std::pair<uint32_t, uint32_t> > > candidates;
for (size_t sni = 0; sni < srcNode.size(); sni++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::GraphNode &pathNode = tmpGraph[srcNode.at(sni, outGraph.repository.allocator).id - 1];
#else
NGT::GraphNode &pathNode = tmpGraph[srcNode[sni].id - 1];
#endif
for (size_t pni = 0; pni < pathNode.size(); pni++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
auto dstNodeID = pathNode.at(pni, outGraph.repository.allocator).id;
#else
auto dstNodeID = pathNode[pni].id;
#endif
auto dstNode = neighbors.find(dstNodeID);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (dstNode != neighbors.end()
&& srcNode.at(sni, outGraph.repository.allocator).distance < (*dstNode).second.second
&& pathNode.at(pni, outGraph.repository.allocator).distance < (*dstNode).second.second
) {
#else
if (dstNode != neighbors.end()
&& srcNode[sni].distance < (*dstNode).second.second
&& pathNode[pni].distance < (*dstNode).second.second
) {
#endif
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
candidates.push_back(std::pair<int, std::pair<uint32_t, uint32_t> >((*dstNode).second.first, std::pair<uint32_t, uint32_t>(srcNode.at(sni, outGraph.repository.allocator).id, dstNodeID)));
#else
candidates.push_back(std::pair<int, std::pair<uint32_t, uint32_t> >((*dstNode).second.first, std::pair<uint32_t, uint32_t>(srcNode[sni].id, dstNodeID)));
#endif
removeCandidateCount++;
}
}
}
sort(candidates.begin(), candidates.end(), std::greater<std::pair<int, std::pair<uint32_t, uint32_t>>>());
removeCandidates[id - 1].reserve(candidates.size());
for (size_t i = 0; i < candidates.size(); i++) {
removeCandidates[id - 1].push_back(candidates[i].second);
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
timer.stop();
std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
timer.reset();
timer.start();
std::list<size_t> ids;
for (size_t idx = 0; idx < tmpGraph.size(); ++idx) {
ids.push_back(idx + 1);
}
int removeCount = 0;
removeCandidateCount = 0;
for (size_t rank = 0; ids.size() != 0; rank++) {
for (auto it = ids.begin(); it != ids.end(); ) {
size_t id = *it;
size_t idx = id - 1;
try {
NGT::GraphNode &srcNode = tmpGraph[idx];
if (rank >= srcNode.size()) {
if (!removeCandidates[idx].empty()) {
std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
abort();
}
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::GraphNode empty;
tmpGraph[idx] = empty;
#endif
it = ids.erase(it);
continue;
}
if (removeCandidates[idx].size() > 0) {
removeCandidateCount++;
bool pathExist = false;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) {
#else
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) {
#endif
size_t path = removeCandidates[idx].back().first;
size_t dst = removeCandidates[idx].back().second;
removeCandidates[idx].pop_back();
if (removeCandidates[idx].empty()) {
std::vector<std::pair<uint32_t, uint32_t>> empty;
removeCandidates[idx] = empty;
}
if ((hasEdge(outGraph, id, path)) && (hasEdge(outGraph, path, dst))) {
pathExist = true;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) {
#else
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) {
#endif
removeCandidates[idx].pop_back();
if (removeCandidates[idx].empty()) {
std::vector<std::pair<uint32_t, uint32_t>> empty;
removeCandidates[idx] = empty;
}
}
break;
}
}
if (pathExist) {
removeCount++;
it++;
continue;
}
}
NGT::GraphNode &outSrcNode = *outGraph.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
insert(outSrcNode, srcNode.at(rank, outGraph.repository.allocator).id, srcNode.at(rank, outGraph.repository.allocator).distance, outGraph);
#else
insert(outSrcNode, srcNode[rank].id, srcNode[rank].distance);
#endif
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
it++;
continue;
}
it++;
}
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &node = *outGraph.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(node.begin(outGraph.repository.allocator), node.end(outGraph.repository.allocator));
#else
std::sort(node.begin(), node.end());
#endif
} catch(...) {}
}
}
static
void convertToANNG(std::vector<NGT::ObjectDistances> &graph)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
return;
#else
std::cerr << "convertToANNG begin" << std::endl;
for (size_t idx = 0; idx < graph.size(); idx++) {
NGT::GraphNode &node = graph[idx];
for (auto ni = node.begin(); ni != node.end(); ++ni) {
graph[(*ni).id - 1].push_back(NGT::ObjectDistance(idx + 1, (*ni).distance));
}
}
for (size_t idx = 0; idx < graph.size(); idx++) {
NGT::GraphNode &node = graph[idx];
if (node.size() == 0) {
continue;
}
std::sort(node.begin(), node.end());
NGT::ObjectID prev = 0;
for (auto it = node.begin(); it != node.end();) {
if (prev == (*it).id) {
it = node.erase(it);
continue;
}
prev = (*it).id;
it++;
}
NGT::GraphNode tmp = node;
node.swap(tmp);
}
std::cerr << "convertToANNG end" << std::endl;
#endif
}
static
void reconstructGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph, size_t originalEdgeSize, size_t reverseEdgeSize)
{
if (reverseEdgeSize > 10000) {
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
exit(1);
}
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
originalEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &node = *outGraph.getNode(id);
if (originalEdgeSize == 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.clear(outGraph.repository.allocator);
#else
NGT::GraphNode empty;
node.swap(empty);
#endif
} else {
NGT::ObjectDistances n = graph[id - 1];
if (n.size() < originalEdgeSize) {
std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
continue;
}
n.resize(originalEdgeSize);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.copy(n, outGraph.repository.allocator);
#else
node.swap(n);
#endif
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
originalEdgeTimer.stop();
reverseEdgeTimer.start();
int insufficientNodeCount = 0;
for (size_t id = 1; id <= graph.size(); ++id) {
try {
NGT::ObjectDistances &node = graph[id - 1];
size_t rsize = reverseEdgeSize;
if (rsize > node.size()) {
insufficientNodeCount++;
rsize = node.size();
}
for (size_t i = 0; i < rsize; ++i) {
NGT::Distance distance = node[i].distance;
size_t nodeID = node[i].id;
try {
NGT::GraphNode &n = *outGraph.getNode(nodeID);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
n.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator);
#else
n.push_back(NGT::ObjectDistance(id, distance));
#endif
} catch(...) {}
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
reverseEdgeTimer.stop();
if (insufficientNodeCount != 0) {
std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
}
normalizeEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
if (id % 100000 == 0) {
std::cerr << "Processed " << id << " nodes" << std::endl;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(n.begin(outGraph.repository.allocator), n.end(outGraph.repository.allocator));
#else
std::sort(n.begin(), n.end());
#endif
NGT::ObjectID prev = 0;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
for (auto it = n.begin(outGraph.repository.allocator); it != n.end(outGraph.repository.allocator);) {
#else
for (auto it = n.begin(); it != n.end();) {
#endif
if (prev == (*it).id) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
it = n.erase(it, outGraph.repository.allocator);
#else
it = n.erase(it);
#endif
continue;
}
prev = (*it).id;
it++;
}
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::GraphNode tmp = n;
n.swap(tmp);
#endif
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
normalizeEdgeTimer.stop();
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
<< ":" << normalizeEdgeTimer.time << std::endl;
NGT::Property prop;
outGraph.getProperty().get(prop);
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
outGraph.getProperty().set(prop);
}
static
void reconstructGraphWithConstraint(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph,
size_t originalEdgeSize, size_t reverseEdgeSize,
char mode = 'a')
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
abort();
#else
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
if (reverseEdgeSize > 10000) {
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
exit(1);
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "Processed " << id << std::endl;
}
try {
NGT::GraphNode &node = *outGraph.getNode(id);
if (node.size() == 0) {
continue;
}
node.clear();
NGT::GraphNode empty;
node.swap(empty);
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
std::vector<ObjectDistances> reverse(graph.size() + 1);
for (size_t id = 1; id <= graph.size(); ++id) {
try {
NGT::GraphNode &node = graph[id - 1];
if (id % 100000 == 0) {
std::cerr << "Processed (summing up) " << id << std::endl;
}
for (size_t rank = 0; rank < node.size(); rank++) {
reverse[node[rank].id].push_back(ObjectDistance(id, node[rank].distance));
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
std::vector<std::pair<size_t, size_t> > reverseSize(graph.size() + 1);
reverseSize[0] = std::pair<size_t, size_t>(0, 0);
for (size_t rid = 1; rid <= graph.size(); ++rid) {
reverseSize[rid] = std::pair<size_t, size_t>(reverse[rid].size(), rid);
}
std::sort(reverseSize.begin(), reverseSize.end());
std::vector<uint32_t> indegreeCount(graph.size(), 0);
size_t zeroCount = 0;
for (size_t sizerank = 0; sizerank <= reverseSize.size(); sizerank++) {
if (reverseSize[sizerank].first == 0) {
zeroCount++;
continue;
}
size_t rid = reverseSize[sizerank].second;
ObjectDistances &rnode = reverse[rid];
for (auto rni = rnode.begin(); rni != rnode.end(); ++rni) {
if (indegreeCount[(*rni).id] >= reverseEdgeSize) {
continue;
}
NGT::GraphNode &node = *outGraph.getNode(rid);
if (indegreeCount[(*rni).id] > 0 && node.size() >= originalEdgeSize) {
continue;
}
node.push_back(NGT::ObjectDistance((*rni).id, (*rni).distance));
indegreeCount[(*rni).id]++;
}
}
reverseEdgeTimer.stop();
std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
normalizeEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
if (id % 100000 == 0) {
std::cerr << "Processed " << id << std::endl;
}
std::sort(n.begin(), n.end());
NGT::ObjectID prev = 0;
for (auto it = n.begin(); it != n.end();) {
if (prev == (*it).id) {
it = n.erase(it);
continue;
}
prev = (*it).id;
it++;
}
NGT::GraphNode tmp = n;
n.swap(tmp);
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
normalizeEdgeTimer.stop();
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
originalEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "Processed " << id << std::endl;
}
NGT::GraphNode &node = graph[id - 1];
try {
NGT::GraphNode &onode = *outGraph.getNode(id);
bool stop = false;
for (size_t rank = 0; (rank < node.size() && rank < originalEdgeSize) && stop == false; rank++) {
switch (mode) {
case 'a':
if (onode.size() >= originalEdgeSize) {
stop = true;
continue;
}
break;
case 'c':
break;
}
NGT::Distance distance = node[rank].distance;
size_t nodeID = node[rank].id;
outGraph.addEdge(id, nodeID, distance, false);
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
originalEdgeTimer.stop();
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
<< ":" << normalizeEdgeTimer.time << std::endl;
#endif
}
// reconstruct a pseudo ANNG with a fewer edges from an actual ANNG with more edges.
// graph is a source ANNG
// index is an index with a reconstructed ANNG
static
void reconstructANNGFromANNG(std::vector<NGT::ObjectDistances> &graph, NGT::Index &index, size_t edgeSize)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
abort();
#else
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(index.getIndex());
// remove all edges in the index.
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "Processed " << id << " nodes." << std::endl;
}
try {
NGT::GraphNode &node = *outGraph.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.clear(outGraph.repository.allocator);
#else
NGT::GraphNode empty;
node.swap(empty);
#endif
} catch(NGT::Exception &err) {
}
}
for (size_t id = 1; id <= graph.size(); ++id) {
size_t edgeCount = 0;
try {
NGT::ObjectDistances &node = graph[id - 1];
NGT::GraphNode &n = *outGraph.getNode(id);
NGT::Distance prevDistance = 0.0;
assert(n.size() == 0);
for (size_t i = 0; i < node.size(); ++i) {
NGT::Distance distance = node[i].distance;
if (prevDistance > distance) {
NGTThrowException("Edge distance order is invalid");
}
prevDistance = distance;
size_t nodeID = node[i].id;
if (node[i].id < id) {
try {
NGT::GraphNode &dn = *outGraph.getNode(nodeID);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
n.push_back(NGT::ObjectDistance(nodeID, distance), outGraph.repository.allocator);
dn.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator);
#else
n.push_back(NGT::ObjectDistance(nodeID, distance));
dn.push_back(NGT::ObjectDistance(id, distance));
#endif
} catch(...) {}
edgeCount++;
}
if (edgeCount >= edgeSize) {
break;
}
}
} catch(NGT::Exception &err) {
}
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
std::sort(n.begin(), n.end());
NGT::ObjectID prev = 0;
for (auto it = n.begin(); it != n.end();) {
if (prev == (*it).id) {
it = n.erase(it);
continue;
}
prev = (*it).id;
it++;
}
NGT::GraphNode tmp = n;
n.swap(tmp);
} catch (...) {
}
}
#endif
}
static void refineANNG(NGT::Index &index, bool unlog, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) {
NGT::StdOstreamRedirector redirector(unlog);
redirector.begin();
try {
refineANNG(index, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
static void refineANNG(NGT::Index &index, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGTThrowException("GraphReconstructor::refineANNG: Not implemented for the shared memory option.");
#else
auto prop = static_cast<GraphIndex&>(index.getIndex()).getGraphProperty();
NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository();
NGT::GraphIndex &graphIndex = static_cast<GraphIndex&>(index.getIndex());
size_t nOfObjects = objectRepository.size();
bool error = false;
std::string errorMessage;
for (size_t bid = 1; bid < nOfObjects; bid += batchSize) {
NGT::ObjectDistances results[batchSize];
// search
#pragma omp parallel for
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (id % 100000 == 0) {
std::cerr << "# of processed objects=" << id << std::endl;
}
if (objectRepository.isEmpty(id)) {
continue;
}
NGT::SearchContainer searchContainer(*objectRepository.get(id));
searchContainer.setResults(&results[idx]);
assert(prop.edgeSizeForCreation > 0);
searchContainer.setSize(noOfEdges > prop.edgeSizeForCreation ? noOfEdges : prop.edgeSizeForCreation);
if (accuracy > 0.0) {
searchContainer.setExpectedAccuracy(accuracy);
} else {
searchContainer.setEpsilon(epsilon);
}
if (exploreEdgeSize != INT_MIN) {
searchContainer.setEdgeSize(exploreEdgeSize);
}
if (!error) {
try {
index.search(searchContainer);
} catch (NGT::Exception &err) {
#pragma omp critical
{
error = true;
errorMessage = err.what();
}
}
}
}
if (error) {
std::stringstream msg;
msg << "GraphReconstructor::refineANNG: " << errorMessage;
NGTThrowException(msg);
}
// outgoing edges
#pragma omp parallel for
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (objectRepository.isEmpty(id)) {
continue;
}
NGT::GraphNode &node = *graphIndex.getNode(id);
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
if ((*i).id != id) {
node.push_back(*i);
}
}
std::sort(node.begin(), node.end());
// dedupe
ObjectID prev = 0;
for (GraphNode::iterator ni = node.begin(); ni != node.end();) {
if (prev == (*ni).id) {
ni = node.erase(ni);
continue;
}
prev = (*ni).id;
ni++;
}
}
// incomming edges
if (noOfEdges != 0) {
continue;
}
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (id % 10000 == 0) {
std::cerr << "# of processed objects=" << id << std::endl;
}
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
if ((*i).id != id) {
NGT::GraphNode &node = *graphIndex.getNode((*i).id);
graphIndex.addEdge(node, id, (*i).distance, false);
}
}
}
}
if (noOfEdges != 0) {
// prune to build knng
size_t nedges = noOfEdges < 0 ? -noOfEdges : noOfEdges;
#pragma omp parallel for
for (ObjectID id = 1; id < nOfObjects; ++id) {
if (objectRepository.isEmpty(id)) {
continue;
}
NGT::GraphNode &node = *graphIndex.getNode(id);
if (node.size() > nedges) {
node.resize(nedges);
}
}
}
#endif // defined(NGT_SHARED_MEMORY_ALLOCATOR)
}
};
}; // NGT

View File

@ -0,0 +1,110 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <iostream>
#include <cstring>
#include <stdint.h>
#include <climits>
#include <unordered_set>
class HashBasedBooleanSet{
private:
uint32_t *_table;
uint32_t _tableSize;
uint32_t _mask;
std::unordered_set<uint32_t> _stlHash;
inline uint32_t _hash1(const uint32_t value){
return value & _mask;
}
public:
HashBasedBooleanSet():_table(NULL), _tableSize(0), _mask(0) {}
HashBasedBooleanSet(const uint64_t size):_table(NULL), _tableSize(0), _mask(0) {
size_t bitSize = 0;
size_t bit = size;
while (bit != 0) {
bitSize++;
bit >>= 1;
}
size_t bucketSize = 0x1 << ((bitSize + 4) / 2 + 3);
initialize(bucketSize);
}
void initialize(const uint32_t tableSize) {
_tableSize = tableSize;
_mask = _tableSize - 1;
const uint32_t checkValue = _hash1(tableSize);
if(checkValue != 0){
std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
}
_table = new uint32_t[tableSize];
memset(_table, 0, tableSize * sizeof(uint32_t));
}
~HashBasedBooleanSet(){
delete[] _table;
_stlHash.clear();
}
inline bool operator[](const uint32_t num){
const uint32_t hashValue = _hash1(num);
auto v = _table[hashValue];
if (v == num){
return true;
}
if (v == 0){
return false;
}
if (_stlHash.count(num) <= 0) {
return false;
}
return true;
}
inline void set(const uint32_t num){
uint32_t &value = _table[_hash1(num)];
if(value == 0){
value = num;
}else{
if(value != num){
_stlHash.insert(num);
}
}
}
inline void insert(const uint32_t num){
set(num);
}
inline void reset(const uint32_t num){
const uint32_t hashValue = _hash1(num);
if(_table[hashValue] != 0){
if(_table[hashValue] != num){
_stlHash.erase(num);
}else{
_table[hashValue] = UINT_MAX;
}
}
}
};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,457 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "MmapManagerImpl.hpp"
namespace MemoryManager{
// static method ---
void MmapManager::setDefaultOptionValue(init_option_st &optionst)
{
optionst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND;
optionst.reuse_type = REUSE_DATA_CLASSIFY;
}
size_t MmapManager::getAlignSize(size_t size){
if((size % MMAP_MEMORY_ALIGN) == 0){
return size;
}else{
return ( (size >> MMAP_MEMORY_ALIGN_EXP ) + 1 ) * MMAP_MEMORY_ALIGN;
}
}
// static method ---
MmapManager::MmapManager():_impl(new MmapManager::Impl(*this))
{
for(uint64_t i = 0; i < MMAP_MAX_UNIT_NUM; ++i){
_impl->mmapDataAddr[i] = NULL;
}
}
MmapManager::~MmapManager() = default;
void MmapManager::dumpHeap() const
{
_impl->dumpHeap();
}
bool MmapManager::isOpen() const
{
return _impl->isOpen;
}
void *MmapManager::getEntryHook() const {
return getAbsAddr(_impl->mmapCntlHead->entry_p);
}
void MmapManager::setEntryHook(const void *entry_p){
_impl->mmapCntlHead->entry_p = getRelAddr(entry_p);
}
bool MmapManager::init(const std::string &filePath, size_t size, const init_option_st *optionst) const
{
try{
const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX;
struct stat st;
if(stat(controlFile.c_str(), &st) == 0){
return false;
}
if(filePath.length() > MMAP_MAX_FILE_NAME_LENGTH){
std::cerr << "too long filepath" << std::endl;
return false;
}
if((size % sysconf(_SC_PAGESIZE) != 0) || ( size < MMAP_LOWER_SIZE )){
std::cerr << "input size error" << std::endl;
return false;
}
int32_t fd = _impl->formatFile(controlFile, MMAP_CNTL_FILE_SIZE);
assert(fd >= 0);
errno = 0;
char *cntl_p = (char *)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if(cntl_p == MAP_FAILED){
const std::string err_str = getErrorStr(errno);
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(controlFile + " " + err_str);
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
try {
fd = _impl->formatFile(filePath, size);
} catch (MmapManagerException &err) {
if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) {
throw MmapManagerException("[ERR] : munmap error : " + getErrorStr(errno) +
" : Through the exception : " + err.what());
}
throw err;
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
boot_st bootStruct = {0};
control_st controlStruct = {0};
_impl->initBootStruct(bootStruct, size);
_impl->initControlStruct(controlStruct, size);
char *cntl_head = cntl_p;
cntl_head += sizeof(boot_st);
if(optionst != NULL){
controlStruct.use_expand = optionst->use_expand;
controlStruct.reuse_type = optionst->reuse_type;
}
memcpy(cntl_p, (char *)&bootStruct, sizeof(boot_st));
memcpy(cntl_head, (char *)&controlStruct, sizeof(control_st));
errno = 0;
if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
return true;
}catch(MmapManagerException &e){
std::cerr << "init error. " << e.what() << std::endl;
throw e;
}
}
bool MmapManager::openMemory(const std::string &filePath)
{
try{
if(_impl->isOpen == true){
std::string err_str = "[ERROR] : openMemory error (double open).";
throw MmapManagerException(err_str);
}
const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX;
_impl->filePath = filePath;
int32_t fd;
errno = 0;
if((fd = open(controlFile.c_str(), O_RDWR, 0666)) == -1){
const std::string err_str = getErrorStr(errno);
throw MmapManagerException("file open error" + err_str);
}
errno = 0;
boot_st *boot_p = (boot_st*)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if(boot_p == MAP_FAILED){
const std::string err_str = getErrorStr(errno);
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(controlFile + " " + err_str);
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
if(boot_p->version != MMAP_MANAGER_VERSION){
std::cerr << "[WARN] : version error" << std::endl;
errno = 0;
if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
throw MmapManagerException("MemoryManager version error");
}
errno = 0;
if((fd = open(filePath.c_str(), O_RDWR, 0666)) == -1){
const std::string err_str = getErrorStr(errno);
errno = 0;
if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
throw MmapManagerException("file open error = " + std::string(filePath.c_str()) + err_str);
}
_impl->mmapCntlHead = (control_st*)( (char *)boot_p + sizeof(boot_st));
_impl->mmapCntlAddr = (void *)boot_p;
for(uint64_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){
off_t offset = _impl->mmapCntlHead->base_size * i;
errno = 0;
_impl->mmapDataAddr[i] = mmap(NULL, _impl->mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset);
if(_impl->mmapDataAddr[i] == MAP_FAILED){
if (errno == EINVAL) {
std::cerr << "MmapManager::openMemory: Fatal error. EINVAL" << std::endl
<< " If you use valgrind, this error might occur when the DB is created." << std::endl
<< " In the case of that, reduce bsize in SharedMemoryAllocator." << std::endl;
assert(errno != EINVAL);
}
const std::string err_str = getErrorStr(errno);
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
closeMemory(true);
throw MmapManagerException(err_str);
}
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
_impl->isOpen = true;
return true;
}catch(MmapManagerException &e){
std::cerr << "open error" << std::endl;
throw e;
}
}
void MmapManager::closeMemory(const bool force)
{
try{
if(force || _impl->isOpen){
uint16_t count = 0;
void *error_ids[MMAP_MAX_UNIT_NUM] = {0};
for(uint16_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){
if(_impl->mmapDataAddr[i] != NULL){
if(munmap(_impl->mmapDataAddr[i], _impl->mmapCntlHead->base_size) == -1){
error_ids[i] = _impl->mmapDataAddr[i];;
count++;
}
_impl->mmapDataAddr[i] = NULL;
}
}
if(count > 0){
std::string msg = "";
for(uint16_t i = 0; i < count; i++){
std::stringstream ss;
ss << error_ids[i];
msg += ss.str() + ", ";
}
throw MmapManagerException("unmap error : ids = " + msg);
}
if(_impl->mmapCntlAddr != NULL){
if(munmap(_impl->mmapCntlAddr, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
_impl->mmapCntlAddr = NULL;
}
_impl->isOpen = false;
}
}catch(MmapManagerException &e){
std::cerr << "close error" << std::endl;
throw e;
}
}
off_t MmapManager::alloc(const size_t size, const bool not_reuse_flag)
{
try{
if(!_impl->isOpen){
std::cerr << "not open this file" << std::endl;
return -1;
}
size_t alloc_size = getAlignSize(size);
if( (alloc_size + sizeof(chunk_head_st)) >= _impl->mmapCntlHead->base_size ){
std::cerr << "alloc size over. size=" << size << "." << std::endl;
return -1;
}
if(!not_reuse_flag){
if( _impl->mmapCntlHead->reuse_type == REUSE_DATA_CLASSIFY
|| _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE
|| _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE_PLUS){
off_t ret_offset;
reuse_state_t reuse_state = REUSE_STATE_OK;
ret_offset = reuse(alloc_size, reuse_state);
if(reuse_state != REUSE_STATE_ALLOC){
return ret_offset;
}
}
}
head_st *unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit];
if((unit_header->break_p + sizeof(chunk_head_st) + alloc_size) >= _impl->mmapCntlHead->base_size){
if(_impl->mmapCntlHead->use_expand == true){
if(_impl->expandMemory() == false){
std::cerr << __func__ << ": cannot expand" << std::endl;
return -1;
}
unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit];
}else{
std::cerr << __func__ << ": total size over" << std::endl;
return -1;
}
}
const off_t file_offset = _impl->mmapCntlHead->active_unit * _impl->mmapCntlHead->base_size;
const off_t ret_p = file_offset + ( unit_header->break_p + sizeof(chunk_head_st) );
chunk_head_st *chunk_head = (chunk_head_st*)(unit_header->break_p + (char *)_impl->mmapDataAddr[_impl->mmapCntlHead->active_unit]);
_impl->setupChunkHead(chunk_head, false, _impl->mmapCntlHead->active_unit, -1, alloc_size);
unit_header->break_p += alloc_size + sizeof(chunk_head_st);
unit_header->chunk_num++;
return ret_p;
}catch(MmapManagerException &e){
std::cerr << "allocation error" << std::endl;
throw e;
}
}
void MmapManager::free(const off_t p)
{
switch(_impl->mmapCntlHead->reuse_type){
case REUSE_DATA_CLASSIFY:
_impl->free_data_classify(p);
break;
case REUSE_DATA_QUEUE:
_impl->free_data_queue(p);
break;
case REUSE_DATA_QUEUE_PLUS:
_impl->free_data_queue_plus(p);
break;
default:
_impl->free_data_classify(p);
break;
}
}
off_t MmapManager::reuse(const size_t size, reuse_state_t &reuse_state)
{
off_t ret_off;
switch(_impl->mmapCntlHead->reuse_type){
case REUSE_DATA_CLASSIFY:
ret_off = _impl->reuse_data_classify(size, reuse_state);
break;
case REUSE_DATA_QUEUE:
ret_off = _impl->reuse_data_queue(size, reuse_state);
break;
case REUSE_DATA_QUEUE_PLUS:
ret_off = _impl->reuse_data_queue_plus(size, reuse_state);
break;
default:
ret_off = _impl->reuse_data_classify(size, reuse_state);
break;
}
return ret_off;
}
void *MmapManager::getAbsAddr(off_t p) const
{
if(p < 0){
return NULL;
}
const uint16_t unit_id = p / _impl->mmapCntlHead->base_size;
const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size;
const off_t ret_p = p - file_offset;
return ABS_ADDR(ret_p, _impl->mmapDataAddr[unit_id]);
}
off_t MmapManager::getRelAddr(const void *p) const
{
const chunk_head_st *chunk_head = (chunk_head_st *)((char *)p - sizeof(chunk_head_st));
const uint16_t unit_id = chunk_head->unit_id;
const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size;
off_t ret_p = (off_t)((char *)p - (char *)_impl->mmapDataAddr[unit_id]);
ret_p += file_offset;
return ret_p;
}
std::string getErrorStr(int32_t err_num){
char err_msg[256];
#ifdef _GNU_SOURCE
char *msg = strerror_r(err_num, err_msg, 256);
return std::string(msg);
#else
strerror_r(err_num, err_msg, 256);
return std::string(err_msg);
#endif
}
size_t MmapManager::getTotalSize() const
{
const uint16_t active_unit = _impl->mmapCntlHead->active_unit;
const size_t ret_size = ((_impl->mmapCntlHead->unit_num - 1) * _impl->mmapCntlHead->base_size) + _impl->mmapCntlHead->data_headers[active_unit].break_p;
return ret_size;
}
size_t MmapManager::getUseSize() const
{
size_t total_size = 0;
void *ref_addr = (void *)&total_size;
_impl->scanAllData(ref_addr, CHECK_STATS_USE_SIZE);
return total_size;
}
uint64_t MmapManager::getUseNum() const
{
uint64_t total_chunk_num = 0;
void *ref_addr = (void *)&total_chunk_num;
_impl->scanAllData(ref_addr, CHECK_STATS_USE_NUM);
return total_chunk_num;
}
size_t MmapManager::getFreeSize() const
{
size_t total_size = 0;
void *ref_addr = (void *)&total_size;
_impl->scanAllData(ref_addr, CHECK_STATS_FREE_SIZE);
return total_size;
}
uint64_t MmapManager::getFreeNum() const
{
uint64_t total_chunk_num = 0;
void *ref_addr = (void *)&total_chunk_num;
_impl->scanAllData(ref_addr, CHECK_STATS_FREE_NUM);
return total_chunk_num;
}
uint16_t MmapManager::getUnitNum() const
{
return _impl->mmapCntlHead->unit_num;
}
size_t MmapManager::getQueueCapacity() const
{
free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue;
return free_queue->capacity;
}
uint64_t MmapManager::getQueueNum() const
{
free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue;
return free_queue->tail;
}
uint64_t MmapManager::getLargeListNum() const
{
uint64_t count = 0;
free_list_st *free_list = &_impl->mmapCntlHead->free_data.large_list;
if(free_list->free_p == -1){
return count;
}
off_t current_off = free_list->free_p;
chunk_head_st *current_chunk_head = (chunk_head_st *)getAbsAddr(current_off);
while(current_chunk_head != NULL){
count++;
current_off = current_chunk_head->free_next;
current_chunk_head = (chunk_head_st *)getAbsAddr(current_off);
}
return count;
}
}

View File

@ -0,0 +1,95 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <sys/types.h>
#include <stdint.h>
#include <string>
#include <memory>
#define ABS_ADDR(x, y) (void *)(x + (char *)y);
#define USE_MMAP_MANAGER
namespace MemoryManager{
typedef enum _option_reuse_t{
REUSE_DATA_CLASSIFY,
REUSE_DATA_QUEUE,
REUSE_DATA_QUEUE_PLUS,
}option_reuse_t;
typedef enum _reuse_state_t{
REUSE_STATE_OK,
REUSE_STATE_FALSE,
REUSE_STATE_ALLOC,
}reuse_state_t;
typedef enum _check_statistics_t{
CHECK_STATS_USE_SIZE,
CHECK_STATS_USE_NUM,
CHECK_STATS_FREE_SIZE,
CHECK_STATS_FREE_NUM,
}check_statistics_t;
typedef struct _init_option_st{
bool use_expand;
option_reuse_t reuse_type;
}init_option_st;
class MmapManager{
public:
MmapManager();
~MmapManager();
bool init(const std::string &filePath, size_t size, const init_option_st *optionst = NULL) const;
bool openMemory(const std::string &filePath);
void closeMemory(const bool force = false);
off_t alloc(const size_t size, const bool not_reuse_flag = false);
void free(const off_t p);
off_t reuse(const size_t size, reuse_state_t &reuse_state);
void *getAbsAddr(off_t p) const;
off_t getRelAddr(const void *p) const;
size_t getTotalSize() const;
size_t getUseSize() const;
uint64_t getUseNum() const;
size_t getFreeSize() const;
uint64_t getFreeNum() const;
uint16_t getUnitNum() const;
size_t getQueueCapacity() const;
uint64_t getQueueNum() const;
uint64_t getLargeListNum() const;
void dumpHeap() const;
bool isOpen() const;
void *getEntryHook() const;
void setEntryHook(const void *entry_p);
// static method ---
static void setDefaultOptionValue(init_option_st &optionst);
static size_t getAlignSize(size_t size);
private:
class Impl;
std::unique_ptr<Impl> _impl;
};
std::string getErrorStr(int32_t err_num);
}

View File

@ -0,0 +1,98 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "MmapManager.h"
#include <unistd.h>
namespace MemoryManager{
const uint64_t MMAP_MANAGER_VERSION = 5;
const bool MMAP_DEFAULT_ALLOW_EXPAND = false;
const uint64_t MMAP_CNTL_FILE_RANGE = 16;
const size_t MMAP_CNTL_FILE_SIZE = MMAP_CNTL_FILE_RANGE * sysconf(_SC_PAGESIZE);
const uint64_t MMAP_MAX_FILE_NAME_LENGTH = 1024;
const std::string MMAP_CNTL_FILE_SUFFIX = "c";
const size_t MMAP_LOWER_SIZE = 1;
const size_t MMAP_MEMORY_ALIGN = 8;
const size_t MMAP_MEMORY_ALIGN_EXP = 3;
#ifndef MMANAGER_TEST_MODE
const uint64_t MMAP_MAX_UNIT_NUM = 1024;
#else
const uint64_t MMAP_MAX_UNIT_NUM = 8;
#endif
const uint64_t MMAP_FREE_QUEUE_SIZE = 1024;
const uint64_t MMAP_FREE_LIST_NUM = 64;
typedef struct _boot_st{
uint32_t version;
uint64_t reserve;
size_t size;
}boot_st;
typedef struct _head_st{
off_t break_p;
uint64_t chunk_num;
uint64_t reserve;
}head_st;
typedef struct _free_list_st{
off_t free_p;
off_t free_last_p;
}free_list_st;
typedef struct _free_st{
free_list_st large_list;
free_list_st free_lists[MMAP_FREE_LIST_NUM];
}free_st;
typedef struct _free_queue_st{
off_t data;
size_t capacity;
uint64_t tail;
}free_queue_st;
typedef struct _control_st{
bool use_expand;
uint16_t unit_num;
uint16_t active_unit;
uint64_t reserve;
size_t base_size;
off_t entry_p;
option_reuse_t reuse_type;
free_st free_data;
free_queue_st free_queue;
head_st data_headers[MMAP_MAX_UNIT_NUM];
}control_st;
typedef struct _chunk_head_st{
bool delete_flg;
uint16_t unit_id;
off_t free_next;
size_t size;
}chunk_head_st;
}

View File

@ -0,0 +1,28 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <iostream>
#include <exception>
#include <stdexcept>
namespace MemoryManager{
class MmapManagerException : public std::domain_error{
public:
MmapManagerException(const std::string &msg) : std::domain_error(msg){}
};
}

View File

@ -0,0 +1,644 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "MmapManagerDefs.h"
#include "MmapManagerException.h"
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <iostream>
#include <sstream>
#include <cstring>
#include <cassert>
namespace MemoryManager{
class MmapManager::Impl{
public:
Impl() = delete;
Impl(MmapManager &ommanager);
virtual ~Impl(){}
MmapManager &mmanager;
bool isOpen;
void *mmapCntlAddr;
control_st *mmapCntlHead;
std::string filePath;
void *mmapDataAddr[MMAP_MAX_UNIT_NUM];
void initBootStruct(boot_st &bst, size_t size) const;
void initFreeStruct(free_st &fst) const;
void initFreeQueue(free_queue_st &fqst) const;
void initControlStruct(control_st &cntlst, size_t size) const;
void setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const;
bool expandMemory();
int32_t formatFile(const std::string &targetFile, size_t size) const;
void clearChunk(const off_t chunk_off) const;
void free_data_classify(const off_t p, const bool force_large_list = false) const;
off_t reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list = false) const;
void free_data_queue(const off_t p);
off_t reuse_data_queue(const size_t size, reuse_state_t &reuse_state);
void free_data_queue_plus(const off_t p);
off_t reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state);
bool scanAllData(void *target, const check_statistics_t stats_type) const;
void upHeap(free_queue_st *free_queue, uint64_t index) const;
void downHeap(free_queue_st *free_queue)const;
bool insertHeap(free_queue_st *free_queue, const off_t p) const;
bool getHeap(free_queue_st *free_queue, off_t *p) const;
size_t getMaxHeapValue(free_queue_st *free_queue) const;
void dumpHeap() const;
void divChunk(const off_t chunk_offset, const size_t size);
};
MmapManager::Impl::Impl(MmapManager &ommanager):mmanager(ommanager), isOpen(false), mmapCntlAddr(NULL), mmapCntlHead(NULL){}
void MmapManager::Impl::initBootStruct(boot_st &bst, size_t size) const
{
bst.version = MMAP_MANAGER_VERSION;
bst.reserve = 0;
bst.size = size;
}
void MmapManager::Impl::initFreeStruct(free_st &fst) const
{
fst.large_list.free_p = -1;
fst.large_list.free_last_p = -1;
for(uint32_t i = 0; i < MMAP_FREE_LIST_NUM; ++i){
fst.free_lists[i].free_p = -1;
fst.free_lists[i].free_last_p = -1;
}
}
void MmapManager::Impl::initFreeQueue(free_queue_st &fqst) const
{
fqst.data = -1;
fqst.capacity = MMAP_FREE_QUEUE_SIZE;
fqst.tail = 1;
}
void MmapManager::Impl::initControlStruct(control_st &cntlst, size_t size) const
{
cntlst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND;
cntlst.unit_num = 1;
cntlst.active_unit = 0;
cntlst.reserve = 0;
cntlst.base_size = size;
cntlst.entry_p = 0;
cntlst.reuse_type = REUSE_DATA_CLASSIFY;
initFreeStruct(cntlst.free_data);
initFreeQueue(cntlst.free_queue);
memset(cntlst.data_headers, 0, sizeof(head_st) * MMAP_MAX_UNIT_NUM);
}
void MmapManager::Impl::setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const
{
chunk_head_st chunk_buffer;
chunk_buffer.delete_flg = delete_flg;
chunk_buffer.unit_id = unit_id;
chunk_buffer.free_next = free_next;
chunk_buffer.size = size;
memcpy(chunk_head, &chunk_buffer, sizeof(chunk_head_st));
}
bool MmapManager::Impl::expandMemory()
{
const uint16_t new_unit_num = mmapCntlHead->unit_num + 1;
const size_t new_file_size = mmapCntlHead->base_size * new_unit_num;
const off_t old_file_size = mmapCntlHead->base_size * mmapCntlHead->unit_num;
if(new_unit_num >= MMAP_MAX_UNIT_NUM){
std::cerr << "over max unit num" << std::endl;
return false;
}
int32_t fd = formatFile(filePath, new_file_size);
assert(fd >= 0);
const off_t offset = mmapCntlHead->base_size * mmapCntlHead->unit_num;
errno = 0;
void *new_area = mmap(NULL, mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset);
if(new_area == MAP_FAILED){
const std::string err_str = getErrorStr(errno);
errno = 0;
if(ftruncate(fd, old_file_size) == -1){
const std::string err_str = getErrorStr(errno);
throw MmapManagerException("truncate error" + err_str);
}
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException("mmap error" + err_str);
}
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
mmapDataAddr[mmapCntlHead->unit_num] = new_area;
mmapCntlHead->unit_num = new_unit_num;
mmapCntlHead->active_unit++;
return true;
}
int32_t MmapManager::Impl::formatFile(const std::string &targetFile, size_t size) const
{
const char *c = "";
int32_t fd;
errno = 0;
if((fd = open(targetFile.c_str(), O_RDWR|O_CREAT, 0666)) == -1){
std::stringstream ss;
ss << "[ERR] Cannot open the file. " << targetFile << " " << getErrorStr(errno);
throw MmapManagerException(ss.str());
}
errno = 0;
if(lseek(fd, (off_t)size-1, SEEK_SET) < 0){
std::stringstream ss;
ss << "[ERR] Cannot seek the file. " << targetFile << " " << getErrorStr(errno);
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(ss.str());
}
errno = 0;
if(write(fd, &c, sizeof(char)) == -1){
std::stringstream ss;
ss << "[ERR] Cannot write the file. Check the disk space. " << targetFile << " " << getErrorStr(errno);
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(ss.str());
}
return fd;
}
void MmapManager::Impl::clearChunk(const off_t chunk_off) const
{
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_off);
const off_t payload_off = chunk_off + sizeof(chunk_head_st);
chunk_head->delete_flg = false;
chunk_head->free_next = -1;
char *payload_addr = (char *)mmanager.getAbsAddr(payload_off);
memset(payload_addr, 0, chunk_head->size);
}
void MmapManager::Impl::free_data_classify(const off_t p, const bool force_large_list) const
{
const off_t chunk_offset = p - sizeof(chunk_head_st);
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t p_size = chunk_head->size;
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
free_list_st *free_list;
if(p_size <= border_size && force_large_list == false){
uint32_t index = (p_size / MMAP_MEMORY_ALIGN) - 1;
free_list = &mmapCntlHead->free_data.free_lists[index];
}else{
free_list = &mmapCntlHead->free_data.large_list;
}
if(free_list->free_p == -1){
free_list->free_p = free_list->free_last_p = chunk_offset;
}else{
off_t last_off = free_list->free_last_p;
chunk_head_st *tmp_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(last_off);
free_list->free_last_p = tmp_chunk_head->free_next = chunk_offset;
}
chunk_head->delete_flg = true;
}
off_t MmapManager::Impl::reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list) const
{
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
free_list_st *free_list;
if(size <= border_size && force_large_list == false){
uint32_t index = (size / MMAP_MEMORY_ALIGN) - 1;
free_list = &mmapCntlHead->free_data.free_lists[index];
}else{
free_list = &mmapCntlHead->free_data.large_list;
}
if(free_list->free_p == -1){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
off_t current_off = free_list->free_p;
off_t ret_off = 0;
chunk_head_st *current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off);
chunk_head_st *ret_chunk_head = NULL;
if( (size <= border_size) && (free_list->free_last_p == free_list->free_p) ){
ret_off = current_off;
ret_chunk_head = current_chunk_head;
free_list->free_p = free_list->free_last_p = -1;
}else{
off_t ret_before_off = -1, before_off = -1;
bool found_candidate_flag = false;
while(current_chunk_head != NULL){
if( current_chunk_head->size >= size ) found_candidate_flag = true;
if(found_candidate_flag){
ret_off = current_off;
ret_chunk_head = current_chunk_head;
ret_before_off = before_off;
break;
}
before_off = current_off;
current_off = current_chunk_head->free_next;
current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off);
}
if(!found_candidate_flag){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
const off_t free_next = ret_chunk_head->free_next;
if(free_list->free_p == ret_off){
free_list->free_p = free_next;
}else{
chunk_head_st *before_chunk = (chunk_head_st *)mmanager.getAbsAddr(ret_before_off);
before_chunk->free_next = free_next;
}
if(free_list->free_last_p == ret_off){
free_list->free_last_p = ret_before_off;
}
}
clearChunk(ret_off);
ret_off = ret_off + sizeof(chunk_head_st);
return ret_off;
}
void MmapManager::Impl::free_data_queue(const off_t p)
{
free_queue_st *free_queue = &mmapCntlHead->free_queue;
if(free_queue->data == -1){
const size_t queue_size = sizeof(off_t) * free_queue->capacity;
const off_t alloc_offset = mmanager.alloc(queue_size);
if(alloc_offset == -1){
return free_data_classify(p, true);
}
free_queue->data = alloc_offset;
}else if(free_queue->tail >= free_queue->capacity){
const off_t tmp_old_queue = free_queue->data;
const size_t old_size = sizeof(off_t) * free_queue->capacity;
const size_t new_capacity = free_queue->capacity * 2;
const size_t new_size = sizeof(off_t) * new_capacity;
if(new_size > mmapCntlHead->base_size){
return free_data_classify(p, true);
}else{
const off_t alloc_offset = mmanager.alloc(new_size);
if(alloc_offset == -1){
return free_data_classify(p, true);
}
free_queue->data = alloc_offset;
const off_t *old_data = (off_t *)mmanager.getAbsAddr(tmp_old_queue);
off_t *new_data = (off_t *)mmanager.getAbsAddr(free_queue->data);
memcpy(new_data, old_data, old_size);
free_queue->capacity = new_capacity;
mmanager.free(tmp_old_queue);
}
}
const off_t chunk_offset = p - sizeof(chunk_head_st);
if(!insertHeap(free_queue, chunk_offset)){
return;
}
chunk_head_st *chunk_head = (chunk_head_st*)mmanager.getAbsAddr(chunk_offset);
chunk_head->delete_flg = 1;
return;
}
off_t MmapManager::Impl::reuse_data_queue(const size_t size, reuse_state_t &reuse_state)
{
free_queue_st *free_queue = &mmapCntlHead->free_queue;
if(free_queue->data == -1){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
if(getMaxHeapValue(free_queue) < size){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
off_t ret_off;
if(!getHeap(free_queue, &ret_off)){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
reuse_state_t list_state = REUSE_STATE_OK;
off_t candidate_off = reuse_data_classify(MMAP_MEMORY_ALIGN, list_state, true);
if(list_state == REUSE_STATE_OK){
mmanager.free(candidate_off);
}
const off_t c_ret_off = ret_off;
divChunk(c_ret_off, size);
clearChunk(ret_off);
ret_off = ret_off + sizeof(chunk_head_st);
return ret_off;
}
void MmapManager::Impl::free_data_queue_plus(const off_t p)
{
const off_t chunk_offset = p - sizeof(chunk_head_st);
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t p_size = chunk_head->size;
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
if(p_size <= border_size){
free_data_classify(p);
}else{
free_data_queue(p);
}
}
off_t MmapManager::Impl::reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state)
{
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
off_t ret_off;
if(size <= border_size){
ret_off = reuse_data_classify(size, reuse_state);
if(reuse_state == REUSE_STATE_ALLOC){
reuse_state = REUSE_STATE_OK;
ret_off = reuse_data_queue(size, reuse_state);
}
}else{
ret_off = reuse_data_queue(size, reuse_state);
}
return ret_off;
}
bool MmapManager::Impl::scanAllData(void *target, const check_statistics_t stats_type) const
{
const uint16_t unit_num = mmapCntlHead->unit_num;
size_t total_size = 0;
uint64_t total_chunk_num = 0;
for(int i = 0; i < unit_num; i++){
const head_st *target_unit_head = &mmapCntlHead->data_headers[i];
const uint64_t chunk_num = target_unit_head->chunk_num;
const off_t base_offset = i * mmapCntlHead->base_size;
off_t target_offset = base_offset;
chunk_head_st *target_chunk;
for(uint64_t j = 0; j < chunk_num; j++){
target_chunk = (chunk_head_st*)mmanager.getAbsAddr(target_offset);
if(stats_type == CHECK_STATS_USE_SIZE){
if(target_chunk->delete_flg == false){
total_size += target_chunk->size;
}
}else if(stats_type == CHECK_STATS_USE_NUM){
if(target_chunk->delete_flg == false){
total_chunk_num++;
}
}else if(stats_type == CHECK_STATS_FREE_SIZE){
if(target_chunk->delete_flg == true){
total_size += target_chunk->size;
}
}else if(stats_type == CHECK_STATS_FREE_NUM){
if(target_chunk->delete_flg == true){
total_chunk_num++;
}
}
const size_t chunk_size = sizeof(chunk_head_st) + target_chunk->size;
target_offset += chunk_size;
}
}
if(stats_type == CHECK_STATS_USE_SIZE || stats_type == CHECK_STATS_FREE_SIZE){
size_t *tmp_size = (size_t *)target;
*tmp_size = total_size;
}else if(stats_type == CHECK_STATS_USE_NUM || stats_type == CHECK_STATS_FREE_NUM){
uint64_t *tmp_chunk_num = (uint64_t *)target;
*tmp_chunk_num = total_chunk_num;
}
return true;
}
void MmapManager::Impl::upHeap(free_queue_st *free_queue, uint64_t index) const
{
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
while(index > 1){
uint64_t parent = index / 2;
const off_t parent_chunk_offset = queue[parent];
const off_t index_chunk_offset = queue[index];
const chunk_head_st *parent_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(parent_chunk_offset);
const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset);
if(parent_chunk_head->size < index_chunk_head->size){
const off_t tmp = queue[parent];
queue[parent] = queue[index];
queue[index] = tmp;
}
index = parent;
}
}
void MmapManager::Impl::downHeap(free_queue_st *free_queue)const
{
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
uint64_t index = 1;
while(index * 2 <= free_queue->tail){
uint64_t child = index * 2;
const off_t index_chunk_offset = queue[index];
const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset);
if(child + 1 < free_queue->tail){
const off_t left_chunk_offset = queue[child];
const off_t right_chunk_offset = queue[child+1];
const chunk_head_st *left_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(left_chunk_offset);
const chunk_head_st *right_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(right_chunk_offset);
if(left_chunk_head->size < right_chunk_head->size){
child = child + 1;
}
}
const off_t child_chunk_offset = queue[child];
const chunk_head_st *child_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(child_chunk_offset);
if(child_chunk_head->size > index_chunk_head->size){
const off_t tmp = queue[child];
queue[child] = queue[index];
queue[index] = tmp;
index = child;
}else{
break;
}
}
}
bool MmapManager::Impl::insertHeap(free_queue_st *free_queue, const off_t p) const
{
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
uint64_t index;
if(free_queue->capacity < free_queue->tail){
return false;
}
index = free_queue->tail;
queue[index] = p;
free_queue->tail += 1;
upHeap(free_queue, index);
return true;
}
bool MmapManager::Impl::getHeap(free_queue_st *free_queue, off_t *p) const
{
if( (free_queue->tail - 1) <= 0){
return false;
}
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
*p = queue[1];
free_queue->tail -= 1;
queue[1] = queue[free_queue->tail];
downHeap(free_queue);
return true;
}
size_t MmapManager::Impl::getMaxHeapValue(free_queue_st *free_queue) const
{
if(free_queue->data == -1){
return 0;
}
const off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(queue[1]);
return chunk_head->size;
}
void MmapManager::Impl::dumpHeap() const
{
free_queue_st *free_queue = &mmapCntlHead->free_queue;
if(free_queue->data == -1){
std::cout << "heap unused" << std::endl;
return;
}
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
for(uint32_t i = 1; i < free_queue->tail; ++i){
const off_t chunk_offset = queue[i];
const off_t payload_offset = chunk_offset + sizeof(chunk_head_st);
const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t size = chunk_head->size;
std::cout << "[" << chunk_offset << "(" << payload_offset << "), " << size << "] ";
}
std::cout << std::endl;
}
void MmapManager::Impl::divChunk(const off_t chunk_offset, const size_t size)
{
if((mmapCntlHead->reuse_type != REUSE_DATA_QUEUE)
&& (mmapCntlHead->reuse_type != REUSE_DATA_QUEUE_PLUS)){
return;
}
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t border_size = sizeof(chunk_head_st) + MMAP_MEMORY_ALIGN;
const size_t align_size = getAlignSize(size);
const size_t rest_size = chunk_head->size - align_size;
if(rest_size < border_size){
return;
}
chunk_head->size = align_size;
const off_t new_chunk_offset = chunk_offset + sizeof(chunk_head_st) + align_size;
chunk_head_st *new_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(new_chunk_offset);
const size_t new_size = rest_size - sizeof(chunk_head_st);
setupChunkHead(new_chunk_head, true, chunk_head->unit_id, -1, new_size);
head_st *unit_header = &mmapCntlHead->data_headers[mmapCntlHead->active_unit];
unit_header->chunk_num++;
const off_t payload_offset = new_chunk_offset + sizeof(chunk_head_st);
mmanager.free(payload_offset);
return;
}
}

View File

@ -0,0 +1,602 @@
//
// Copyright (C) 2016-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/NGTQ/Quantizer.h"
#define NGTQ_SEARCH_CODEBOOK_SIZE_FLUCTUATION
namespace NGTQ {
class Command {
public:
Command():debugLevel(0) {}
void
create(NGT::Args &args)
{
const string usage = "Usage: ngtq create "
"[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] "
"[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] "
"[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] "
"[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] "
"[-M global-centroid-creation-mode (d|s)] [-L global-centroid-creation-mode (d|k|s)] "
"[-S local-sample-coefficient] "
"index(output) data.tsv(input)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified." << endl;
cerr << usage << endl;
return;
}
string data;
try {
data = args.get("#2");
} catch (...) {
cerr << "Data is not specified." << endl;
}
char objectType = args.getChar("o", 'f');
char distanceType = args.getChar("D", '2');
size_t dataSize = args.getl("n", 0);
NGTQ::Property property;
property.threadSize = args.getl("p", 24);
property.dimension = args.getl("d", 0);
property.globalRange = args.getf("R", 0);
property.localRange = args.getf("r", 0);
property.globalCentroidLimit = args.getl("C", 1000000);
property.localCentroidLimit = args.getl("c", 65000);
property.localDivisionNo = args.getl("N", 8);
property.batchSize = args.getl("b", 1000);
property.localClusteringSampleCoefficient = args.getl("S", 10);
{
char localCentroidType = args.getChar("T", 'f');
property.singleLocalCodebook = localCentroidType == 't' ? true : false;
}
{
char centroidCreationMode = args.getChar("M", 'd');
switch(centroidCreationMode) {
case 'd': property.centroidCreationMode = NGTQ::CentroidCreationModeDynamic; break;
case 's': property.centroidCreationMode = NGTQ::CentroidCreationModeStatic; break;
default:
cerr << "ngt: Invalid centroid creation mode. " << centroidCreationMode << endl;
cerr << usage << endl;
return;
}
}
{
char localCentroidCreationMode = args.getChar("L", 'd');
switch(localCentroidCreationMode) {
case 'd': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamic; break;
case 's': property.localCentroidCreationMode = NGTQ::CentroidCreationModeStatic; break;
case 'k': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamicKmeans; break;
default:
cerr << "ngt: Invalid centroid creation mode. " << localCentroidCreationMode << endl;
cerr << usage << endl;
return;
}
}
NGT::Property globalProperty;
NGT::Property localProperty;
{
char indexType = args.getChar("i", 't');
globalProperty.indexType = indexType == 't' ? NGT::Property::GraphAndTree : NGT::Property::Graph;
localProperty.indexType = globalProperty.indexType;
}
globalProperty.insertionRadiusCoefficient = args.getf("e", 0.1) + 1.0;
localProperty.insertionRadiusCoefficient = globalProperty.insertionRadiusCoefficient;
if (debugLevel >= 1) {
cerr << "epsilon=" << globalProperty.insertionRadiusCoefficient << endl;
cerr << "data size=" << dataSize << endl;
cerr << "dimension=" << property.dimension << endl;
cerr << "thread size=" << property.threadSize << endl;
cerr << "batch size=" << localProperty.batchSizeForCreation << endl;;
cerr << "index type=" << globalProperty.indexType << endl;
}
switch (objectType) {
case 'f': property.dataType = NGTQ::DataTypeFloat; break;
case 'c': property.dataType = NGTQ::DataTypeUint8; break;
default:
cerr << "ngt: Invalid object type. " << objectType << endl;
cerr << usage << endl;
return;
}
switch (distanceType) {
case '2': property.distanceType = NGTQ::DistanceTypeL2; break;
case '1': property.distanceType = NGTQ::DistanceTypeL1; break;
case 'a': property.distanceType = NGTQ::DistanceTypeAngle; break;
default:
cerr << "ngt: Invalid distance type. " << distanceType << endl;
cerr << usage << endl;
return;
}
cerr << "ngtq: Create" << endl;
NGTQ::Index::create(database, property, globalProperty, localProperty);
cerr << "ngtq: Append" << endl;
NGTQ::Index::append(database, data, dataSize);
}
void
rebuild(NGT::Args &args)
{
const string usage = "Usage: ngtq rebuild "
"[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] "
"[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] "
"[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] "
"[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] "
"[-M centroid-creation_mode (d|s)] "
"index(output) data.tsv(input)";
string srcIndex;
try {
srcIndex = args.get("#1");
} catch (...) {
cerr << "DB is not specified." << endl;
cerr << usage << endl;
return;
}
string rebuiltIndex = srcIndex + ".tmp";
NGTQ::Property property;
NGT::Property globalProperty;
NGT::Property localProperty;
{
NGTQ::Index index(srcIndex);
property = index.getQuantizer().property;
index.getQuantizer().globalCodebook.getProperty(globalProperty);
index.getQuantizer().getLocalCodebook(0).getProperty(localProperty);
}
property.globalRange = args.getf("R", property.globalRange);
property.localRange = args.getf("r", property.localRange);
property.globalCentroidLimit = args.getl("C", property.globalCentroidLimit);
property.localCentroidLimit = args.getl("c", property.localCentroidLimit);
property.localDivisionNo = args.getl("N", property.localDivisionNo);
{
char localCentroidType = args.getChar("T", '-');
if (localCentroidType != '-') {
property.singleLocalCodebook = localCentroidType == 't' ? true : false;
}
}
{
char centroidCreationMode = args.getChar("M", '-');
if (centroidCreationMode != '-') {
property.centroidCreationMode = centroidCreationMode == 'd' ?
NGTQ::CentroidCreationModeDynamic : NGTQ::CentroidCreationModeStatic;
}
}
cerr << "global range=" << property.globalRange << endl;
cerr << "local range=" << property.localRange << endl;
cerr << "global centroid limit=" << property.globalCentroidLimit << endl;
cerr << "local centroid limit=" << property.localCentroidLimit << endl;
cerr << "local division no=" << property.localDivisionNo << endl;
NGTQ::Index::create(rebuiltIndex, property, globalProperty, localProperty);
cerr << "created a new db" << endl;
cerr << "start rebuilding..." << endl;
NGTQ::Index::rebuild(srcIndex, rebuiltIndex);
{
string src = srcIndex;
string dst = srcIndex + ".org";
if (std::rename(src.c_str(), dst.c_str()) != 0) {
stringstream msg;
msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ;
NGTThrowException(msg);
}
}
{
string src = rebuiltIndex;
string dst = srcIndex;
if (std::rename(src.c_str(), dst.c_str()) != 0) {
stringstream msg;
msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ;
NGTThrowException(msg);
}
}
}
void
append(NGT::Args &args)
{
const string usage = "Usage: ngtq append [-n data-size] "
"index(output) data.tsv(input)";
string index;
try {
index = args.get("#1");
} catch (...) {
cerr << "DB is not specified." << endl;
cerr << usage << endl;
return;
}
string data;
try {
data = args.get("#2");
} catch (...) {
cerr << "Data is not specified." << endl;
}
size_t dataSize = args.getl("n", 0);
if (debugLevel >= 1) {
cerr << "data size=" << dataSize << endl;
}
NGTQ::Index::append(index, data, dataSize);
}
void
search(NGT::Args &args)
{
const string usage = "Usage: ngtq search [-i g|t|s] [-n result-size] [-e epsilon] [-m mode(r|l|c|a)] "
"[-E edge-size] [-o output-mode] [-b result expansion(begin:end:[x]step)] "
"index(input) query.tsv(input)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
string query;
try {
query = args.get("#2");
} catch (...) {
cerr << "Query is not specified" << endl;
cerr << usage << endl;
return;
}
int size = args.getl("n", 20);
char outputMode = args.getChar("o", '-');
float epsilon = 0.1;
char mode = args.getChar("m", '-');
NGTQ::AggregationMode aggregationMode;
switch (mode) {
case 'r': aggregationMode = NGTQ::AggregationModeExactDistanceThroughApproximateDistance; break; // refine
case 'e': aggregationMode = NGTQ::AggregationModeExactDistance; break; // refine
case 'l': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithLookupTable; break; // lookup
case 'c': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithCache; break; // cache
case '-':
case 'a': aggregationMode = NGTQ::AggregationModeApproximateDistance; break; // cache
default:
cerr << "Invalid aggregation mode. " << mode << endl;
cerr << usage << endl;
return;
}
if (args.getString("e", "none") == "-") {
// linear search
epsilon = FLT_MAX;
} else {
epsilon = args.getf("e", 0.1);
}
size_t beginOfResultExpansion, endOfResultExpansion, stepOfResultExpansion;
bool mulStep = false;
{
beginOfResultExpansion = stepOfResultExpansion = 1;
endOfResultExpansion = 0;
string str = args.getString("b", "16");
vector<string> tokens;
NGT::Common::tokenize(str, tokens, ":");
if (tokens.size() >= 1) { beginOfResultExpansion = NGT::Common::strtod(tokens[0]); }
if (tokens.size() >= 2) { endOfResultExpansion = NGT::Common::strtod(tokens[1]); }
if (tokens.size() >= 3) {
if (tokens[2][0] == 'x') {
mulStep = true;
stepOfResultExpansion = NGT::Common::strtod(tokens[2].substr(1));
} else {
stepOfResultExpansion = NGT::Common::strtod(tokens[2]);
}
}
}
if (debugLevel >= 1) {
cerr << "size=" << size << endl;
cerr << "result expansion=" << beginOfResultExpansion << "->" << endOfResultExpansion << "," << stepOfResultExpansion << endl;
}
NGTQ::Index index(database);
try {
ifstream is(query);
if (!is) {
cerr << "Cannot open the specified file. " << query << endl;
return;
}
if (outputMode == 's') { cout << "# Beginning of Evaluation" << endl; }
string line;
double totalTime = 0;
int queryCount = 0;
while(getline(is, line)) {
NGT::Object *query = index.allocateObject(line, " \t", 0);
queryCount++;
size_t resultExpansion = 0;
for (size_t base = beginOfResultExpansion;
resultExpansion <= endOfResultExpansion;
base = mulStep ? base * stepOfResultExpansion : base + stepOfResultExpansion) {
resultExpansion = base;
NGT::ObjectDistances objects;
if (outputMode == 'e') {
index.search(query, objects, size, resultExpansion, aggregationMode, epsilon);
objects.clear();
}
NGT::Timer timer;
timer.start();
// size : # of final resultant objects
// resultExpansion : # of resultant objects by using codebook search
index.search(query, objects, size, resultExpansion, aggregationMode, epsilon);
timer.stop();
totalTime += timer.time;
if (outputMode == 'e') {
cout << "# Query No.=" << queryCount << endl;
cout << "# Query=" << line.substr(0, 20) + " ..." << endl;
cout << "# Index Type=" << "----" << endl;
cout << "# Size=" << size << endl;
cout << "# Epsilon=" << epsilon << endl;
cout << "# Result expansion=" << resultExpansion << endl;
cout << "# Distance Computation=" << index.getQuantizer().distanceComputationCount << endl;
cout << "# Query Time (msec)=" << timer.time * 1000.0 << endl;
} else {
cout << "Query No." << queryCount << endl;
cout << "Rank\tIN-ID\tID\tDistance" << endl;
}
for (size_t i = 0; i < objects.size(); i++) {
cout << i + 1 << "\t" << objects[i].id << "\t";
cout << objects[i].distance << endl;
}
if (outputMode == 'e') {
cout << "# End of Search" << endl;
} else {
cout << "Query Time= " << timer.time << " (sec), " << timer.time * 1000.0 << " (msec)" << endl;
}
}
if (outputMode == 'e') {
cout << "# End of Query" << endl;
}
index.deleteObject(query);
}
if (outputMode == 'e') {
cout << "# Average Query Time (msec)=" << totalTime * 1000.0 / (double)queryCount << endl;
cout << "# Number of queries=" << queryCount << endl;
cout << "# End of Evaluation" << endl;
} else {
cout << "Average Query Time= " << totalTime / (double)queryCount << " (sec), "
<< totalTime * 1000.0 / (double)queryCount << " (msec), ("
<< totalTime << "/" << queryCount << ")" << endl;
}
} catch (NGT::Exception &err) {
cerr << "Error " << err.what() << endl;
cerr << usage << endl;
} catch (...) {
cerr << "Error" << endl;
cerr << usage << endl;
}
index.close();
}
void
remove(NGT::Args &args)
{
const string usage = "Usage: ngtq remove [-d object-ID-type(f|d)] index(input) object-ID(input)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
try {
args.get("#2");
} catch (...) {
cerr << "ID is not specified" << endl;
cerr << usage << endl;
return;
}
char dataType = args.getChar("d", 'f');
if (debugLevel >= 1) {
cerr << "dataType=" << dataType << endl;
}
try {
vector<NGT::ObjectID> objects;
if (dataType == 'f') {
string ids;
try {
ids = args.get("#2");
} catch (...) {
cerr << "Data file is not specified" << endl;
cerr << usage << endl;
return;
}
ifstream is(ids);
if (!is) {
cerr << "Cannot open the specified file. " << ids << endl;
return;
}
string line;
int count = 0;
while(getline(is, line)) {
count++;
vector<string> tokens;
NGT::Common::tokenize(line, tokens, "\t ");
if (tokens.size() == 0 || tokens[0].size() == 0) {
continue;
}
char *e;
size_t id;
try {
id = strtol(tokens[0].c_str(), &e, 10);
objects.push_back(id);
} catch (...) {
cerr << "Illegal data. " << tokens[0] << endl;
}
if (*e != 0) {
cerr << "Illegal data. " << e << endl;
}
cerr << "removed ID=" << id << endl;
}
} else {
size_t id = args.getl("#2", 0);
cerr << "removed ID=" << id << endl;
objects.push_back(id);
}
NGT::Index::remove(database, objects);
} catch (NGT::Exception &err) {
cerr << "Error " << err.what() << endl;
cerr << usage << endl;
} catch (...) {
cerr << "Error" << endl;
cerr << usage << endl;
}
}
void
info(NGT::Args &args)
{
const string usage = "Usage: ngtq info index";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
NGTQ::Index index(database);
index.info(cout);
}
void
validate(NGT::Args &args)
{
const string usage = "parameter";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
NGTQ::Index index(database);
index.getQuantizer().validate();
}
#ifdef NGTQ_SHARED_INVERTED_INDEX
void
compress(NGT::Args &args)
{
const string usage = "Usage: ngtq compress index)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
try {
NGTQ::Index::compress(database);
} catch (NGT::Exception &err) {
cerr << "Error " << err.what() << endl;
cerr << usage << endl;
} catch (...) {
cerr << "Error" << endl;
cerr << usage << endl;
}
}
#endif
void help() {
cerr << "Usage : ngtq command database data" << endl;
cerr << " command : create search remove append export import" << endl;
}
void execute(NGT::Args args) {
string command;
try {
command = args.get("#0");
} catch(...) {
help();
return;
}
debugLevel = args.getl("X", 0);
try {
if (debugLevel >= 1) {
cerr << "ngt::command=" << command << endl;
}
if (command == "search") {
search(args);
} else if (command == "create") {
create(args);
} else if (command == "append") {
append(args);
} else if (command == "remove") {
remove(args);
} else if (command == "info") {
info(args);
} else if (command == "validate") {
validate(args);
} else if (command == "rebuild") {
rebuild(args);
#ifdef NGTQ_SHARED_INVERTED_INDEX
} else if (command == "compress") {
compress(args);
#endif
} else {
cerr << "Illegal command. " << command << endl;
}
} catch(NGT::Exception &err) {
cerr << "ngt: Fatal error: " << err.what() << endl;
}
}
int debugLevel;
};
};

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,338 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/defines.h"
#include "NGT/Node.h"
#include "NGT/Tree.h"
#include <algorithm>
using namespace std;
const double NGT::Node::Object::Pivot = -1.0;
using namespace NGT;
void
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst,
SharedMemoryAllocator &allocator) {
#else
InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst) {
#endif
int cs = dvptree.internalChildrenSize;
for (int i = 0; i < cs; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (getChildren(allocator)[i] == src) {
getChildren(allocator)[i] = dst;
#else
if (getChildren()[i] == src) {
getChildren()[i] = dst;
#endif
return;
}
}
}
int
LeafNode::selectPivotByMaxDistance(Container &c, Node::Objects &fs)
{
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
int fsize = fs.size();
Distance maxd = 0.0;
int maxid = 0;
for (int i = 1; i < fsize; i++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[0].object, *fs[i].object);
if (d >= maxd) {
maxd = d;
maxid = i;
}
}
int aid = maxid;
maxd = 0.0;
maxid = 0;
for (int i = 0; i < fsize; i++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[aid].object, *fs[i].object);
if (i == aid) {
continue;
}
if (d >= maxd) {
maxd = d;
maxid = i;
}
}
int bid = maxid;
maxd = 0.0;
maxid = 0;
for (int i = 0; i < fsize; i++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[bid].object, *fs[i].object);
if (i == bid) {
continue;
}
if (d >= maxd) {
maxd = d;
maxid = i;
}
}
return maxid;
}
int
LeafNode::selectPivotByMaxVariance(Container &c, Node::Objects &fs)
{
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
int fsize = fs.size();
Distance *distance = new Distance[fsize * fsize];
for (int i = 0; i < fsize; i++) {
distance[i * fsize + i] = 0;
}
for (int i = 0; i < fsize; i++) {
for (int j = i + 1; j < fsize; j++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[i].object, *fs[j].object);
distance[i * fsize + j] = d;
distance[j * fsize + i] = d;
}
}
double *variance = new double[fsize];
for (int i = 0; i < fsize; i++) {
double avg = 0.0;
for (int j = 0; j < fsize; j++) {
avg += distance[i * fsize + j];
}
avg /= (double)fsize;
double v = 0.0;
for (int j = 0; j < fsize; j++) {
v += pow(distance[i * fsize + j] - avg, 2.0);
}
variance[i] = v / (double)fsize;
}
double maxv = variance[0];
int maxid = 0;
for (int i = 0; i < fsize; i++) {
if (variance[i] > maxv) {
maxv = variance[i];
maxid = i;
}
}
delete [] variance;
delete [] distance;
return maxid;
}
void
LeafNode::splitObjects(Container &c, Objects &fs, int pv)
{
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
// sort the objects by distance
int fsize = fs.size();
for (int i = 0; i < fsize; i++) {
if (i == pv) {
fs[i].distance = 0;
} else {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pv].object, *fs[i].object);
fs[i].distance = d;
}
}
sort(fs.begin(), fs.end());
int childrenSize = iobj.vptree->internalChildrenSize;
int cid = childrenSize - 1;
int cms = (fsize * cid) / childrenSize;
// divide the objects into child clusters.
fs[fsize - 1].clusterID = cid;
for (int i = fsize - 2; i >= 0; i--) {
if (i < cms && cid > 0) {
if (fs[i].distance != fs[i + 1].distance) {
cid--;
cms = (fsize * cid) / childrenSize;
}
}
fs[i].clusterID = cid;
}
if (cid != 0) {
// the required number of child nodes could not be acquired
stringstream msg;
msg << "LeafNode::splitObjects: Too many same distances. Reduce internal children size for the tree index or not use the tree index." << endl;
msg << " internalChildrenSize=" << childrenSize << endl;
msg << " # of the children=" << (childrenSize - cid) << endl;
msg << " Size=" << fsize << endl;
msg << " pivot=" << pv << endl;
msg << " cluster id=" << cid << endl;
msg << " Show distances for debug." << endl;
for (int i = 0; i < fsize; i++) {
msg << " " << fs[i].id << ":" << fs[i].distance << endl;
msg << " ";
PersistentObject &po = *fs[i].object;
iobj.vptree->objectSpace->show(msg, po);
msg << endl;
}
if (fs[fsize - 1].clusterID == cid) {
msg << "LeafNode::splitObjects: All of the object distances are the same!" << endl;;
NGTThrowException(msg.str());
} else {
cerr << msg.str() << endl;
cerr << "LeafNode::splitObjects: Anyway, continue..." << endl;
// sift the cluster IDs to start from 0 to continue.
for (int i = 0; i < fsize; i++) {
fs[i].clusterID -= cid;
}
}
}
long long *pivots = new long long[childrenSize];
for (int i = 0; i < childrenSize; i++) {
pivots[i] = -1;
}
// find the boundaries for the subspaces
for (int i = 0; i < fsize; i++) {
if (pivots[fs[i].clusterID] == -1) {
pivots[fs[i].clusterID] = i;
fs[i].leafDistance = Object::Pivot;
} else {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pivots[fs[i].clusterID]].object, *fs[i].object);
fs[i].leafDistance = d;
}
}
delete[] pivots;
return;
}
void
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
LeafNode::removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator) {
#else
LeafNode::removeObject(size_t id, size_t replaceId) {
#endif
size_t fsize = getObjectSize();
size_t idx;
for (idx = 0; idx < fsize; idx++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (getObjectIDs(allocator)[idx].id == id) {
if (replaceId != 0) {
getObjectIDs(allocator)[idx].id = replaceId;
#else
if (getObjectIDs()[idx].id == id) {
if (replaceId != 0) {
getObjectIDs()[idx].id = replaceId;
#endif
return;
} else {
break;
}
}
}
if (idx == fsize) {
if (pivot == 0) {
NGTThrowException("LeafNode::removeObject: Internal error!. the pivot is illegal.");
}
stringstream msg;
msg << "VpTree::Leaf::remove: Warning. Cannot find the specified object. ID=" << id << "," << replaceId << " idx=" << idx << " If the same objects were inserted into the index, ignore this message.";
NGTThrowException(msg.str());
}
#ifdef NGT_NODE_USE_VECTOR
for (; idx < objectIDs.size() - 1; idx++) {
getObjectIDs()[idx] = getObjectIDs()[idx + 1];
}
objectIDs.pop_back();
#else
objectSize--;
for (; idx < objectSize; idx++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getObjectIDs(allocator)[idx] = getObjectIDs(allocator)[idx + 1];
#else
getObjectIDs()[idx] = getObjectIDs()[idx + 1];
#endif
}
#endif
return;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool InternalNode::verify(PersistentRepository<InternalNode> &internalNodes, PersistentRepository<LeafNode> &leafNodes,
SharedMemoryAllocator &allocator) {
#else
bool InternalNode::verify(Repository<InternalNode> &internalNodes, Repository<LeafNode> &leafNodes) {
#endif
size_t isize = internalNodes.size();
size_t lsize = leafNodes.size();
bool valid = true;
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
size_t nid = getChildren(allocator)[i].getID();
ID::Type type = getChildren(allocator)[i].getType();
#else
size_t nid = getChildren()[i].getID();
ID::Type type = getChildren()[i].getType();
#endif
size_t size = type == ID::Leaf ? lsize : isize;
if (nid >= size) {
cerr << "Error! Internal children node id is too big." << nid << ":" << size << endl;
valid = false;
}
try {
if (type == ID::Leaf) {
leafNodes.get(nid);
} else {
internalNodes.get(nid);
}
} catch (...) {
cerr << "Error! Cannot get the node. " << ((type == ID::Leaf) ? "Leaf" : "Internal") << endl;
valid = false;
}
}
return valid;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool LeafNode::verify(size_t nobjs, vector<uint8_t> &status, SharedMemoryAllocator &allocator) {
#else
bool LeafNode::verify(size_t nobjs, vector<uint8_t> &status) {
#endif
bool valid = true;
for (size_t i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
size_t nid = getObjectIDs(allocator)[i].id;
#else
size_t nid = getObjectIDs()[i].id;
#endif
if (nid > nobjs) {
cerr << "Error! Object id is too big. " << nid << ":" << nobjs << endl;
valid =false;
continue;
}
status[nid] |= 0x04;
}
return valid;
}

View File

@ -0,0 +1,772 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <algorithm>
#include <sstream>
#include "NGT/Common.h"
#include "NGT/ObjectSpaceRepository.h"
#include "NGT/defines.h"
namespace NGT {
class DVPTree;
class InternalNode;
class LeafNode;
class Node {
public:
typedef unsigned int NodeID;
class ID {
public:
enum Type {
Leaf = 1,
Internal = 0
};
ID():id(0) {}
ID &operator=(const ID &n) {
id = n.id;
return *this;
}
ID &operator=(int i) {
setID(i);
return *this;
}
bool operator==(ID &n) { return id == n.id; }
bool operator<(ID &n) { return id < n.id; }
Type getType() { return (Type)((0x80000000 & id) >> 31); }
NodeID getID() { return 0x7fffffff & id; }
NodeID get() { return id; }
void setID(NodeID i) { id = (0x80000000 & id) | i; }
void setType(Type t) { id = (t << 31) | getID(); }
void setRaw(NodeID i) { id = i; }
void setNull() { id = 0; }
// for milvus
void serialize(std::stringstream & os) { NGT::Serializer::write(os, id); }
void serialize(std::ofstream &os) { NGT::Serializer::write(os, id); }
void deserialize(std::ifstream &is) { NGT::Serializer::read(is, id); }
// for milvus
void deserialize(std::stringstream & is) { NGT::Serializer::read(is, id); }
void serializeAsText(std::ofstream &os) { NGT::Serializer::writeAsText(os, id); }
void deserializeAsText(std::ifstream &is) { NGT::Serializer::readAsText(is, id); }
protected:
NodeID id;
};
class Object {
public:
Object():object(0) {}
bool operator<(const Object &o) const { return distance < o.distance; }
static const double Pivot;
ObjectID id;
PersistentObject *object;
Distance distance;
Distance leafDistance;
int clusterID;
};
typedef std::vector<Object> Objects;
Node() {
parent.setNull();
id.setNull();
}
virtual ~Node() {}
Node &operator=(const Node &n) {
id = n.id;
parent = n.parent;
return *this;
}
// for milvus
void serialize(std::stringstream & os)
{
id.serialize(os);
parent.serialize(os);
}
void serialize(std::ofstream &os) {
id.serialize(os);
parent.serialize(os);
}
void deserialize(std::ifstream &is) {
id.deserialize(is);
parent.deserialize(is);
}
void deserialize(std::stringstream & is)
{
id.deserialize(is);
parent.deserialize(is);
}
void serializeAsText(std::ofstream &os) {
id.serializeAsText(os);
os << " ";
parent.serializeAsText(os);
}
void deserializeAsText(std::ifstream &is) {
id.deserializeAsText(is);
parent.deserializeAsText(is);
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void setPivot(PersistentObject &f, ObjectSpace &os, SharedMemoryAllocator &allocator) {
if (pivot == 0) {
pivot = NGT::PersistentObject::allocate(os);
}
getPivot(os).set(f, os);
}
PersistentObject &getPivot(ObjectSpace &os) {
return *(PersistentObject*)os.getRepository().getAllocator().getAddr(pivot);
}
void deletePivot(ObjectSpace &os, SharedMemoryAllocator &allocator) {
os.deleteObject(&getPivot(os));
}
#else // NGT_SHARED_MEMORY_ALLOCATOR
void setPivot(NGT::Object &f, ObjectSpace &os) {
if (pivot == 0) {
pivot = NGT::Object::allocate(os);
}
os.copy(getPivot(), f);
}
NGT::Object &getPivot() { return *pivot; }
void deletePivot(ObjectSpace &os) {
os.deleteObject(pivot);
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
bool pivotIsEmpty() {
return pivot == 0;
}
ID id;
ID parent;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
off_t pivot;
#else
NGT::Object *pivot;
#endif
};
class InternalNode : public Node {
public:
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
InternalNode(size_t csize, SharedMemoryAllocator &allocator) : childrenSize(csize) { initialize(allocator); }
InternalNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(allocator); }
#else
InternalNode(size_t csize) : childrenSize(csize) { initialize(); }
InternalNode(NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(); }
#endif
~InternalNode() {
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
if (children != 0) {
delete[] children;
}
if (borders != 0) {
delete[] borders;
}
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void initialize(SharedMemoryAllocator &allocator) {
#else
void initialize() {
#endif
id = 0;
id.setType(ID::Internal);
pivot = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
children = allocator.getOffset(new(allocator) ID[childrenSize]);
#else
children = new ID[childrenSize];
#endif
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i] = 0;
#else
getChildren()[i] = 0;
#endif
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
borders = allocator.getOffset(new(allocator) Distance[childrenSize - 1]);
#else
borders = new Distance[childrenSize - 1];
#endif
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getBorders(allocator)[i] = 0;
#else
getBorders()[i] = 0;
#endif
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void updateChild(DVPTree &dvptree, ID src, ID dst, SharedMemoryAllocator &allocator);
#else
void updateChild(DVPTree &dvptree, ID src, ID dst);
#endif
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ID *getChildren(SharedMemoryAllocator &allocator) { return (ID*)allocator.getAddr(children); }
Distance *getBorders(SharedMemoryAllocator &allocator) { return (Distance*)allocator.getAddr(borders); }
#else // NGT_SHARED_MEMORY_ALLOCATOR
ID *getChildren() { return children; }
Distance *getBorders() { return borders; }
#endif // NGT_SHARED_MEMORY_ALLOCATOR
// for milvus
void serialize(std::stringstream & os, ObjectSpace * objectspace = 0)
{
Node::serialize(os);
if (pivot == 0)
{
NGTThrowException("Node::write: pivot is null!");
}
assert(objectspace != 0);
getPivot().serialize(os, objectspace);
NGT::Serializer::write(os, childrenSize);
for (size_t i = 0; i < childrenSize; i++)
{
getChildren()[i].serialize(os);
}
for (size_t i = 0; i < childrenSize - 1; i++)
{
NGT::Serializer::write(os, getBorders()[i]);
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void serialize(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) {
#endif
Node::serialize(os);
if (pivot == 0) {
NGTThrowException("Node::write: pivot is null!");
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).serialize(os, allocator, objectspace);
#else
getPivot().serialize(os, objectspace);
#endif
NGT::Serializer::write(os, childrenSize);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i].serialize(os);
#else
getChildren()[i].serialize(os);
#endif
}
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::Serializer::write(os, getBorders(allocator)[i]);
#else
NGT::Serializer::write(os, getBorders()[i]);
#endif
}
}
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) {
Node::deserialize(is);
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
NGT::Serializer::read(is, childrenSize);
assert(children != 0);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
getChildren()[i].deserialize(is);
#endif
}
assert(borders != 0);
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
NGT::Serializer::read(is, getBorders()[i]);
#endif
}
}
// for milvus
void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0)
{
Node::deserialize(is);
if (pivot == 0)
{
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
NGT::Serializer::read(is, childrenSize);
assert(children != 0);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
getChildren()[i].deserialize(is);
#endif
}
assert(borders != 0);
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
NGT::Serializer::read(is, getBorders()[i]);
#endif
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) {
#endif
Node::serializeAsText(os);
if (pivot == 0) {
NGTThrowException("Node::write: pivot is null!");
}
os << " ";
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).serializeAsText(os, objectspace);
#else
getPivot().serializeAsText(os, objectspace);
#endif
os << " ";
NGT::Serializer::writeAsText(os, childrenSize);
os << " ";
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i].serializeAsText(os);
#else
getChildren()[i].serializeAsText(os);
#endif
os << " ";
}
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::Serializer::writeAsText(os, getBorders(allocator)[i]);
#else
NGT::Serializer::writeAsText(os, getBorders()[i]);
#endif
os << " ";
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
#endif
Node::deserializeAsText(is);
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).deserializeAsText(is, objectspace);
#else
getPivot().deserializeAsText(is, objectspace);
#endif
size_t csize;
NGT::Serializer::readAsText(is, csize);
assert(children != 0);
assert(childrenSize == csize);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i].deserializeAsText(is);
#else
getChildren()[i].deserializeAsText(is);
#endif
}
assert(borders != 0);
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::Serializer::readAsText(is, getBorders(allocator)[i]);
#else
NGT::Serializer::readAsText(is, getBorders()[i]);
#endif
}
}
void show() {
std::cout << "Show internal node " << childrenSize << ":";
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
std::cout << getChildren()[i].getID() << " ";
#endif
}
std::cout << std::endl;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool verify(PersistentRepository<InternalNode> &internalNodes, PersistentRepository<LeafNode> &leafNodes,
SharedMemoryAllocator &allocator);
#else
bool verify(Repository<InternalNode> &internalNodes, Repository<LeafNode> &leafNodes);
#endif
static const int InternalChildrenSizeMax = 5;
const size_t childrenSize;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
off_t children;
off_t borders;
#else
ID *children;
Distance *borders;
#endif
};
class LeafNode : public Node {
public:
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
LeafNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {
#else
LeafNode(NGT::ObjectSpace *os = 0) {
#endif
id = 0;
id.setType(ID::Leaf);
pivot = 0;
#ifdef NGT_NODE_USE_VECTOR
objectIDs.reserve(LeafObjectsSizeMax);
#else
objectSize = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
objectIDs = allocator.getOffset(new(allocator) Object[LeafObjectsSizeMax]);
#else
objectIDs = new NGT::ObjectDistance[LeafObjectsSizeMax];
#endif
#endif
}
~LeafNode() {
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
#ifndef NGT_NODE_USE_VECTOR
if (objectIDs != 0) {
delete[] objectIDs;
}
#endif
#endif
}
static int
selectPivotByMaxDistance(Container &iobj, Node::Objects &fs);
static int
selectPivotByMaxVariance(Container &iobj, Node::Objects &fs);
static void
splitObjects(Container &insertedObject, Objects &splitObjectSet, int pivot);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator);
#else
void removeObject(size_t id, size_t replaceId);
#endif
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
#ifndef NGT_NODE_USE_VECTOR
NGT::ObjectDistance *getObjectIDs(SharedMemoryAllocator &allocator) {
return (NGT::ObjectDistance *)allocator.getAddr(objectIDs);
}
#endif
#else // NGT_SHARED_MEMORY_ALLOCATOR
NGT::ObjectDistance *getObjectIDs() { return objectIDs; }
#endif // NGT_SHARED_MEMORY_ALLOCATOR
// for milvus
void serialize(std::stringstream & os, ObjectSpace * objectspace = 0)
{
Node::serialize(os);
NGT::Serializer::write(os, objectSize);
for (int i = 0; i < objectSize; i++)
{
objectIDs[i].serialize(os);
}
if (pivot == 0)
{
// Before insertion, parent ID == 0 and object size == 0, that indicates an empty index
if (parent.getID() != 0 || objectSize != 0)
{
NGTThrowException("Node::write: pivot is null!");
}
}
else
{
assert(objectspace != 0);
pivot->serialize(os, objectspace);
}
}
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) {
Node::serialize(os);
#ifdef NGT_NODE_USE_VECTOR
NGT::Serializer::write(os, objectIDs);
#else
NGT::Serializer::write(os, objectSize);
for (int i = 0; i < objectSize; i++) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::cerr << "not implemented" << std::endl;
assert(0);
#else
objectIDs[i].serialize(os);
#endif
}
#endif // NGT_NODE_USE_VECTOR
if (pivot == 0) {
// Before insertion, parent ID == 0 and object size == 0, that indicates an empty index
if (parent.getID() != 0 || objectSize != 0) {
NGTThrowException("Node::write: pivot is null!");
}
} else {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::cerr << "not implemented" << std::endl;
assert(0);
#else
assert(objectspace != 0);
pivot->serialize(os, objectspace);
#endif
}
}
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) {
Node::deserialize(is);
#ifdef NGT_NODE_USE_VECTOR
objectIDs.clear();
NGT::Serializer::read(is, objectIDs);
#else
assert(objectIDs != 0);
NGT::Serializer::read(is, objectSize);
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getObjectIDs()[i].deserialize(is);
#endif
}
#endif
if (parent.getID() == 0 && objectSize == 0) {
// The index is empty
return;
}
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
assert(pivot != 0);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
}
// for milvus
void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0)
{
Node::deserialize(is);
#ifdef NGT_NODE_USE_VECTOR
objectIDs.clear();
NGT::Serializer::read(is, objectIDs);
#else
assert(objectIDs != 0);
NGT::Serializer::read(is, objectSize);
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getObjectIDs()[i].deserialize(is);
#endif
}
#endif
if (parent.getID() == 0 && objectSize == 0) {
// The index is empty
return;
}
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
assert(pivot != 0);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) {
#endif
Node::serializeAsText(os);
os << " ";
if (pivot == 0) {
NGTThrowException("Node::write: pivot is null!");
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
getPivot(*objectspace).serializeAsText(os, objectspace);
#else
assert(pivot != 0);
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
pivot->serializeAsText(os, allocator, objectspace);
#else
pivot->serializeAsText(os, objectspace);
#endif
#endif
os << " ";
#ifdef NGT_NODE_USE_VECTOR
NGT::Serializer::writeAsText(os, objectIDs);
#else
NGT::Serializer::writeAsText(os, objectSize);
for (int i = 0; i < objectSize; i++) {
os << " ";
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
getObjectIDs(allocator)[i].serializeAsText(os);
#else
objectIDs[i].serializeAsText(os);
#endif
}
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
#endif
Node::deserializeAsText(is);
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).deserializeAsText(is, objectspace);
#else
getPivot().deserializeAsText(is, objectspace);
#endif
#ifdef NGT_NODE_USE_VECTOR
objectIDs.clear();
NGT::Serializer::readAsText(is, objectIDs);
#else
assert(objectIDs != 0);
NGT::Serializer::readAsText(is, objectSize);
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getObjectIDs(allocator)[i].deserializeAsText(is);
#else
getObjectIDs()[i].deserializeAsText(is);
#endif
}
#endif
}
void show() {
std::cout << "Show leaf node " << objectSize << ":";
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
std::cout << getObjectIDs()[i].id << "," << getObjectIDs()[i].distance << " ";
#endif
}
std::cout << std::endl;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool verify(size_t nobjs, std::vector<uint8_t> &status, SharedMemoryAllocator &allocator);
#else
bool verify(size_t nobjs, std::vector<uint8_t> &status);
#endif
#ifdef NGT_NODE_USE_VECTOR
size_t getObjectSize() { return objectIDs.size(); }
#else
size_t getObjectSize() { return objectSize; }
#endif
static const size_t LeafObjectsSizeMax = 100;
#ifdef NGT_NODE_USE_VECTOR
std::vector<Object> objectIDs;
#else
unsigned short objectSize;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
off_t objectIDs;
#else
ObjectDistance *objectIDs;
#endif
#endif
};
}

View File

@ -0,0 +1,395 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <sstream>
namespace NGT {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class ObjectRepository :
public PersistentRepository<PersistentObject> {
public:
typedef PersistentRepository<PersistentObject> Parent;
void open(const std::string &smfile, size_t sharedMemorySize) {
std::string file = smfile;
file.append("po");
Parent::open(file, sharedMemorySize);
}
#else
class ObjectRepository : public Repository<Object> {
public:
typedef Repository<Object> Parent;
#endif
ObjectRepository(size_t dim, const std::type_info &ot):dimension(dim), type(ot), sparse(false) { }
void initialize() {
deleteAll();
Parent::push_back((PersistentObject*)0);
}
// for milvus
void serialize(std::stringstream & obj, ObjectSpace * ospace) { Parent::serialize(obj, ospace); }
void serialize(const std::string &ofile, ObjectSpace *ospace) {
std::ofstream objs(ofile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << ".";
NGTThrowException(msg);
}
Parent::serialize(objs, ospace);
}
void deserialize(std::stringstream & obj, ObjectSpace * ospace)
{
assert(ospace != 0);
Parent::deserialize(obj, ospace);
}
void deserialize(const std::string &ifile, ObjectSpace *ospace) {
assert(ospace != 0);
std::ifstream objs(ifile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << ".";
NGTThrowException(msg);
}
Parent::deserialize(objs, ospace);
}
void serializeAsText(const std::string &ofile, ObjectSpace *ospace) {
std::ofstream objs(ofile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << ".";
NGTThrowException(msg);
}
Parent::serializeAsText(objs, ospace);
}
void deserializeAsText(const std::string &ifile, ObjectSpace *ospace) {
std::ifstream objs(ifile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << ".";
NGTThrowException(msg);
}
Parent::deserializeAsText(objs, ospace);
}
void readText(std::istream &is, size_t dataSize = 0) {
initialize();
appendText(is, dataSize);
}
// For milvus
template <typename T>
void readRawData(const T * raw_data, size_t dataSize)
{
initialize();
append(raw_data, dataSize);
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const float *obj, size_t size) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
abort();
}
void appendText(std::istream &is, size_t dataSize = 0) {
if (dimension == 0) {
NGTThrowException("ObjectSpace::readText: Dimension is not specified.");
}
if (size() == 0) {
// First entry should be always a dummy entry.
// If it is empty, the dummy entry should be inserted.
push_back((PersistentObject*)0);
}
size_t prevDataSize = size();
if (dataSize > 0) {
reserve(size() + dataSize);
}
std::string line;
size_t lineNo = 0;
while (getline(is, line)) {
lineNo++;
if (dataSize > 0 && (dataSize <= size() - prevDataSize)) {
std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
<< dataSize << std::endl;
break;
}
std::vector<double> object;
try {
extractObjectFromText(line, "\t ", object);
PersistentObject *obj = 0;
try {
obj = allocateNormalizedPersistentObject(object);
} catch (Exception &err) {
std::cerr << err.what() << " continue..." << std::endl;
obj = allocatePersistentObject(object);
}
push_back(obj);
} catch (Exception &err) {
std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
}
}
}
template <typename T>
void append(T *data, size_t objectCount) {
if (dimension == 0) {
NGTThrowException("ObjectSpace::readText: Dimension is not specified.");
}
if (size() == 0) {
// First entry should be always a dummy entry.
// If it is empty, the dummy entry should be inserted.
push_back((PersistentObject*)0);
}
if (objectCount > 0) {
reserve(size() + objectCount);
}
for (size_t idx = 0; idx < objectCount; idx++, data += dimension) {
std::vector<double> object;
object.reserve(dimension);
for (size_t dataidx = 0; dataidx < dimension; dataidx++) {
object.push_back(data[dataidx]);
}
try {
PersistentObject *obj = 0;
try {
obj = allocateNormalizedPersistentObject(object);
} catch (Exception &err) {
std::cerr << err.what() << " continue..." << std::endl;
obj = allocatePersistentObject(object);
}
push_back(obj);
} catch (Exception &err) {
std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
}
}
}
Object *allocateObject() {
return (Object*) new Object(paddedByteSize);
}
// This method is called during search to generate query.
// Therefore the object is not persistent.
Object *allocateObject(const std::string &textLine, const std::string &sep) {
std::vector<double> object;
extractObjectFromText(textLine, sep, object);
Object *po = (Object*)allocateObject(object);
return (Object*)po;
}
void extractObjectFromText(const std::string &textLine, const std::string &sep, std::vector<double> &object) {
object.resize(dimension);
std::vector<std::string> tokens;
NGT::Common::tokenize(textLine, tokens, sep);
if (dimension > tokens.size()) {
std::stringstream msg;
msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":" << dimension << ". "
<< textLine;
NGTThrowException(msg);
}
size_t idx;
for (idx = 0; idx < dimension; idx++) {
if (tokens[idx].size() == 0) {
std::stringstream msg;
msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":"
<< dimension << ". " << textLine;
NGTThrowException(msg);
}
char *e;
object[idx] = strtod(tokens[idx].c_str(), &e);
if (*e != 0) {
std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
break;
}
}
}
template <typename T>
Object *allocateObject(T *o, size_t size) {
size_t osize = paddedByteSize;
if (sparse) {
size_t vsize = size * (type == typeid(float) ? 4 : 1);
osize = osize < vsize ? vsize : osize;
} else {
if (dimension != size) {
std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
<< dimension << " The specified object=" << size << std::endl;
assert(dimension == size);
}
}
Object *po = new Object(osize);
void *object = static_cast<void*>(&(*po)[0]);
if (type == typeid(uint8_t)) {
uint8_t *obj = static_cast<uint8_t*>(object);
for (size_t i = 0; i < size; i++) {
obj[i] = static_cast<uint8_t>(o[i]);
}
} else if (type == typeid(float)) {
float *obj = static_cast<float*>(object);
for (size_t i = 0; i < size; i++) {
obj[i] = static_cast<float>(o[i]);
}
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
return po;
}
template <typename T>
Object *allocateObject(const std::vector<T> &o) {
return allocateObject(o.data(), o.size());
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
PersistentObject *allocatePersistentObject(Object &o) {
SharedMemoryAllocator &objectAllocator = getAllocator();
size_t cpsize = dimension;
if (type == typeid(uint8_t)) {
cpsize *= sizeof(uint8_t);
} else if (type == typeid(float)) {
cpsize *= sizeof(float);
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
void *dsto = &(*po).at(0, allocator);
void *srco = &o[0];
memcpy(dsto, srco, cpsize);
return po;
}
template <typename T>
PersistentObject *allocatePersistentObject(T *o, size_t size) {
SharedMemoryAllocator &objectAllocator = getAllocator();
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
if (size != 0 && dimension != size) {
std::stringstream msg;
msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality="
<< (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << ".";
NGTThrowException(msg);
}
void *object = static_cast<void*>(&(*po).at(0, allocator));
if (type == typeid(uint8_t)) {
uint8_t *obj = static_cast<uint8_t*>(object);
for (size_t i = 0; i < dimension; i++) {
obj[i] = static_cast<uint8_t>(o[i]);
}
} else if (type == typeid(float)) {
float *obj = static_cast<float*>(object);
for (size_t i = 0; i < dimension; i++) {
obj[i] = static_cast<float>(o[i]);
}
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
return po;
}
template <typename T>
PersistentObject *allocatePersistentObject(const std::vector<T> &o) {
return allocatePersistentObject(o.data(), o.size());
}
#else
template <typename T>
PersistentObject *allocatePersistentObject(T *o, size_t size) {
if (size != 0 && dimension != size) {
std::stringstream msg;
msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality="
<< (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << ".";
NGTThrowException(msg);
}
return allocateObject(o, size);
}
template <typename T>
PersistentObject *allocatePersistentObject(const std::vector<T> &o) {
return allocatePersistentObject(o.data(), o.size());
}
#endif
void deleteObject(Object *po) {
delete po;
}
private:
void extractObject(void *object, std::vector<double> &d) {
if (type == typeid(uint8_t)) {
uint8_t *obj = (uint8_t*)object;
for (size_t i = 0; i < dimension; i++) {
d.push_back(obj[i]);
}
} else if (type == typeid(float)) {
float *obj = (float*)object;
for (size_t i = 0; i < dimension; i++) {
d.push_back(obj[i]);
}
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
}
public:
void extractObject(Object *o, std::vector<double> &d) {
void *object = (void*)(&(*o)[0]);
extractObject(object, d);
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void extractObject(PersistentObject *o, std::vector<double> &d) {
SharedMemoryAllocator &objectAllocator = getAllocator();
void *object = (void*)(&(*o).at(0, objectAllocator));
extractObject(object, d);
}
#endif
void setLength(size_t l) { byteSize = l; }
void setPaddedLength(size_t l) { paddedByteSize = l; }
void setSparse() { sparse = true; }
size_t getByteSize() { return byteSize; }
size_t insert(PersistentObject *obj) { return Parent::insert(obj); }
const size_t dimension;
const std::type_info &type;
protected:
size_t byteSize; // the length of all of elements.
size_t paddedByteSize;
bool sparse; // sparse data format
};
} // namespace NGT

View File

@ -0,0 +1,475 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "PrimitiveComparator.h"
class ObjectSpace;
namespace NGT {
class PersistentObjectDistances;
class ObjectDistances : public std::vector<ObjectDistance> {
public:
ObjectDistances(NGT::ObjectSpace *os = 0) {}
// for milvus
void serialize(std::stringstream & os, ObjectSpace * objspace = 0) { NGT::Serializer::write(os, (std::vector<ObjectDistance> &)*this); }
void serialize(std::ofstream &os, ObjectSpace *objspace = 0) { NGT::Serializer::write(os, (std::vector<ObjectDistance>&)*this);}
// for milvus
void deserialize(std::stringstream & is, ObjectSpace * objspace = 0)
{
NGT::Serializer::read(is, (std::vector<ObjectDistance> &)*this);
}
void deserialize(std::ifstream &is, ObjectSpace *objspace = 0) { NGT::Serializer::read(is, (std::vector<ObjectDistance>&)*this);}
void serializeAsText(std::ofstream &os, ObjectSpace *objspace = 0) {
NGT::Serializer::writeAsText(os, size());
os << " ";
for (size_t i = 0; i < size(); i++) {
(*this)[i].serializeAsText(os);
os << " ";
}
}
void deserializeAsText(std::ifstream &is, ObjectSpace *objspace = 0) {
size_t s;
NGT::Serializer::readAsText(is, s);
resize(s);
for (size_t i = 0; i < size(); i++) {
(*this)[i].deserializeAsText(is);
}
}
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq) {
this->clear();
this->resize(pq.size());
for (int i = pq.size() - 1; i >= 0; i--) {
(*this)[i] = pq.top();
pq.pop();
}
assert(pq.size() == 0);
}
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq, double (&f)(double)) {
this->clear();
this->resize(pq.size());
for (int i = pq.size() - 1; i >= 0; i--) {
(*this)[i] = pq.top();
(*this)[i].distance = f((*this)[i].distance);
pq.pop();
}
assert(pq.size() == 0);
}
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq, unsigned int id) {
this->clear();
if (pq.size() == 0) {
return;
}
this->resize(id == 0 ? pq.size() : pq.size() - 1);
int i = this->size() - 1;
while (pq.size() != 0 && i >= 0) {
if (pq.top().id != id) {
(*this)[i] = pq.top();
i--;
}
pq.pop();
}
if (pq.size() != 0 && pq.top().id != id) {
std::cerr << "moveFrom: Fatal error: somethig wrong! " << pq.size() << ":" << this->size() << ":" << id << ":" << pq.top().id << std::endl;
assert(pq.size() == 0 || pq.top().id == id);
}
}
ObjectDistances &operator=(PersistentObjectDistances &objs);
};
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class PersistentObjectDistances : public Vector<ObjectDistance> {
public:
PersistentObjectDistances(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {}
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { NGT::Serializer::write(os, (Vector<ObjectDistance>&)*this); }
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { NGT::Serializer::read(is, (Vector<ObjectDistance>&)*this); }
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) {
NGT::Serializer::writeAsText(os, size());
os << " ";
for (size_t i = 0; i < size(); i++) {
(*this).at(i, allocator).serializeAsText(os);
os << " ";
}
}
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) {
size_t s;
is >> s;
resize(s, allocator);
for (size_t i = 0; i < size(); i++) {
(*this).at(i, allocator).deserializeAsText(is);
}
}
PersistentObjectDistances &copy(ObjectDistances &objs, SharedMemoryAllocator &allocator) {
clear(allocator);
reserve(objs.size(), allocator);
for (ObjectDistances::iterator i = objs.begin(); i != objs.end(); i++) {
push_back(*i, allocator);
}
return *this;
}
};
typedef PersistentObjectDistances GraphNode;
inline ObjectDistances &ObjectDistances::operator=(PersistentObjectDistances &objs)
{
clear();
reserve(objs.size());
std::cerr << "not implemented" << std::endl;
assert(0);
return *this;
}
#else // NGT_SHARED_MEMORY_ALLOCATOR
typedef ObjectDistances GraphNode;
#endif // NGT_SHARED_MEMORY_ALLOCATOR
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class PersistentObject;
#else
typedef Object PersistentObject;
#endif
class ObjectRepository;
class ObjectSpace {
public:
class Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Comparator(size_t d, SharedMemoryAllocator &a) : dimension(d), allocator(a) {}
#else
Comparator(size_t d) : dimension(d) {}
#endif
virtual double operator()(Object &objecta, Object &objectb) = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
virtual double operator()(Object &objecta, PersistentObject &objectb) = 0;
virtual double operator()(PersistentObject &objecta, PersistentObject &objectb) = 0;
#endif
size_t dimension;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
SharedMemoryAllocator &allocator;
#endif
virtual ~Comparator(){}
};
enum DistanceType {
DistanceTypeNone = -1,
DistanceTypeL1 = 0,
DistanceTypeL2 = 1,
DistanceTypeHamming = 2,
DistanceTypeAngle = 3,
DistanceTypeCosine = 4,
DistanceTypeNormalizedAngle = 5,
DistanceTypeNormalizedCosine = 6,
DistanceTypeJaccard = 7,
DistanceTypeSparseJaccard = 8
};
enum ObjectType {
ObjectTypeNone = 0,
Uint8 = 1,
Float = 2
};
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > ResultSet;
ObjectSpace(size_t d):dimension(d), distanceType(DistanceTypeNone), comparator(0), normalization(false) {}
virtual ~ObjectSpace() { if (comparator != 0) { delete comparator; } }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
virtual void open(const std::string &f, size_t shareMemorySize) = 0;
virtual Object *allocateObject(Object &o) = 0;
virtual Object *allocateObject(PersistentObject &o) = 0;
virtual PersistentObject *allocatePersistentObject(Object &obj) = 0;
virtual void deleteObject(PersistentObject *) = 0;
virtual void copy(PersistentObject &objecta, PersistentObject &objectb) = 0;
virtual void show(std::ostream &os, PersistentObject &object) = 0;
virtual size_t insert(PersistentObject *obj) = 0;
#else
virtual size_t insert(Object *obj) = 0;
#endif
Comparator &getComparator() { return *comparator; }
virtual void serialize(const std::string &of) = 0;
// for milvus
virtual void serialize(std::stringstream & obj) = 0;
// for milvus
virtual void deserialize(std::stringstream & obj) = 0;
virtual void deserialize(const std::string &ifile) = 0;
virtual void serializeAsText(const std::string &of) = 0;
virtual void deserializeAsText(const std::string &of) = 0;
//for milvus
virtual void readRawData(const float * raw_data, size_t dataSize) = 0;
virtual void readText(std::istream &is, size_t dataSize) = 0;
virtual void appendText(std::istream &is, size_t dataSize) = 0;
virtual void append(const float *data, size_t dataSize) = 0;
virtual void append(const double *data, size_t dataSize) = 0;
virtual void copy(Object &objecta, Object &objectb) = 0;
virtual void linearSearch(Object &query, double radius, size_t size,
ObjectSpace::ResultSet &results) = 0;
virtual const std::type_info &getObjectType() = 0;
virtual void show(std::ostream &os, Object &object) = 0;
virtual size_t getSize() = 0;
virtual size_t getSizeOfElement() = 0;
virtual size_t getByteSizeOfObject() = 0;
virtual Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) = 0;
virtual Object *allocateNormalizedObject(const std::vector<double> &obj) = 0;
virtual Object *allocateNormalizedObject(const std::vector<float> &obj) = 0;
virtual Object *allocateNormalizedObject(const std::vector<uint8_t> &obj) = 0;
virtual Object *allocateNormalizedObject(const float *obj, size_t size) = 0;
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) = 0;
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) = 0;
virtual void deleteObject(Object *po) = 0;
virtual Object *allocateObject() = 0;
virtual void remove(size_t id) = 0;
virtual ObjectRepository &getRepository() = 0;
virtual void setDistanceType(DistanceType t) = 0;
virtual void *getObject(size_t idx) = 0;
virtual void getObject(size_t idx, std::vector<float> &v) = 0;
virtual void getObjects(const std::vector<size_t> &idxs, std::vector<std::vector<float>> &vs) = 0;
size_t getDimension() { return dimension; }
size_t getPaddedDimension() { return ((dimension - 1) / 16 + 1) * 16; }
template <typename T>
void normalize(T *data, size_t dim) {
double sum = 0.0;
for (size_t i = 0; i < dim; i++) {
sum += (double)data[i] * (double)data[i];
}
if (sum == 0.0) {
std::stringstream msg;
msg << "ObjectSpace::normalize: Error! the object is an invalid zero vector for the cosine similarity or angle distance.";
NGTThrowException(msg);
}
sum = sqrt(sum);
for (size_t i = 0; i < dim; i++) {
data[i] = (double)data[i] / sum;
}
}
uint32_t getPrefetchOffset() { return prefetchOffset; }
uint32_t setPrefetchOffset(size_t offset) {
if (offset == 0) {
prefetchOffset = floor(300.0 / (static_cast<float>(getPaddedDimension()) + 30.0) + 1.0);
} else {
prefetchOffset = offset;
}
return prefetchOffset;
}
uint32_t getPrefetchSize() { return prefetchSize; }
uint32_t setPrefetchSize(size_t size) {
if (size == 0) {
prefetchSize = getByteSizeOfObject();
} else {
prefetchSize = size;
}
return prefetchSize;
}
protected:
const size_t dimension;
DistanceType distanceType;
Comparator *comparator;
bool normalization;
uint32_t prefetchOffset;
uint32_t prefetchSize;
};
class BaseObject {
public:
virtual uint8_t &operator[](size_t idx) const = 0;
void serialize(std::ostream &os, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
size_t byteSize = objectspace->getByteSizeOfObject();
NGT::Serializer::write(os, (uint8_t*)&(*this)[0], byteSize);
}
void deserialize(std::istream &is, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
size_t byteSize = objectspace->getByteSizeOfObject();
assert(&(*this)[0] != 0);
NGT::Serializer::read(is, (uint8_t*)&(*this)[0], byteSize);
}
void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
const std::type_info &t = objectspace->getObjectType();
size_t dimension = objectspace->getDimension();
void *ref = (void*)&(*this)[0];
if (t == typeid(uint8_t)) {
NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::writeAsText(os, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::writeAsText(os, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
} else {
std::cerr << "Object::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
const std::type_info &t = objectspace->getObjectType();
size_t dimension = objectspace->getDimension();
void *ref = (void*)&(*this)[0];
assert(ref != 0);
if (t == typeid(uint8_t)) {
NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::readAsText(is, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::readAsText(is, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
} else {
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
};
class Object : public BaseObject {
public:
Object(NGT::ObjectSpace *os = 0):vector(0) {
assert(os != 0);
size_t s = os->getByteSizeOfObject();
construct(s);
}
Object(size_t s):vector(0) {
assert(s != 0);
construct(s);
}
void copy(Object &o, size_t s) {
assert(vector != 0);
for (size_t i = 0; i < s; i++) {
vector[i] = o[i];
}
}
virtual ~Object() { clear(); }
uint8_t &operator[](size_t idx) const { return vector[idx]; }
void *getPointer(size_t idx = 0) const { return vector + idx; }
static Object *allocate(ObjectSpace &objectspace) { return new Object(&objectspace); }
private:
void clear() {
if (vector != 0) {
MemoryCache::alignedFree(vector);
}
vector = 0;
}
void construct(size_t s) {
assert(vector == 0);
size_t allocsize = ((s - 1) / 64 + 1) * 64;
vector = static_cast<uint8_t*>(MemoryCache::alignedAlloc(allocsize));
memset(vector, 0, allocsize);
}
uint8_t* vector;
};
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class PersistentObject : public BaseObject {
public:
PersistentObject(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0):array(0) {
assert(os != 0);
size_t s = os->getByteSizeOfObject();
construct(s, allocator);
}
PersistentObject(SharedMemoryAllocator &allocator, size_t s):array(0) {
assert(s != 0);
construct(s, allocator);
}
~PersistentObject() {}
uint8_t &at(size_t idx, SharedMemoryAllocator &allocator) const {
uint8_t *a = (uint8_t *)allocator.getAddr(array);
return a[idx];
}
uint8_t &operator[](size_t idx) const {
std::cerr << "not implemented" << std::endl;
assert(0);
uint8_t *a = 0;
return a[idx];
}
void *getPointer(size_t idx, SharedMemoryAllocator &allocator) {
uint8_t *a = (uint8_t *)allocator.getAddr(array);
return a + idx;
}
// set v in objectspace to this object using allocator.
void set(PersistentObject &po, ObjectSpace &objectspace);
static off_t allocate(ObjectSpace &objectspace);
void serializeAsText(std::ostream &os, SharedMemoryAllocator &allocator,
ObjectSpace *objectspace = 0) {
serializeAsText(os, objectspace);
}
void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0);
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator,
ObjectSpace *objectspace = 0) {
deserializeAsText(is, objectspace);
}
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0);
void serialize(std::ostream &os, SharedMemoryAllocator &allocator,
ObjectSpace *objectspace = 0) {
std::cerr << "serialize is not implemented" << std::endl;
assert(0);
}
private:
void construct(size_t s, SharedMemoryAllocator &allocator) {
assert(array == 0);
assert(s != 0);
size_t allocsize = ((s - 1) / 64 + 1) * 64;
array = allocator.getOffset(new(allocator) uint8_t[allocsize]);
memset(getPointer(0, allocator), 0, allocsize);
}
off_t array;
};
#endif // NGT_SHARED_MEMORY_ALLOCATOR
}

View File

@ -0,0 +1,620 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <sstream>
#include "Common.h"
#include "ObjectSpace.h"
#include "ObjectRepository.h"
#include "PrimitiveComparator.h"
class ObjectSpace;
namespace NGT {
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
class ObjectSpaceRepository : public ObjectSpace, public ObjectRepository {
public:
class ComparatorL1 : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorL1(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorL1(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorL2 : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorL2(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorL2(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorHammingDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorHammingDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorHammingDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorJaccardDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorJaccardDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorSparseJaccardDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorSparseJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorSparseJaccardDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorAngleDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorAngleDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorNormalizedAngleDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorNormalizedAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorNormalizedAngleDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorCosineSimilarity : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorCosineSimilarity(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorNormalizedCosineSimilarity : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorNormalizedCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorNormalizedCosineSimilarity(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
ObjectSpaceRepository(size_t d, const std::type_info &ot, DistanceType t) : ObjectSpace(d), ObjectRepository(d, ot) {
size_t objectSize = 0;
if (ot == typeid(uint8_t)) {
objectSize = sizeof(uint8_t);
} else if (ot == typeid(float)) {
objectSize = sizeof(float);
} else {
std::stringstream msg;
msg << "ObjectSpace::constructor: Not supported type. " << ot.name();
NGTThrowException(msg);
}
setLength(objectSize * d);
setPaddedLength(objectSize * ObjectSpace::getPaddedDimension());
setDistanceType(t);
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void open(const std::string &f, size_t sharedMemorySize) { ObjectRepository::open(f, sharedMemorySize); }
void copy(PersistentObject &objecta, PersistentObject &objectb) { objecta = objectb; }
void show(std::ostream &os, PersistentObject &object) {
const std::type_info &t = getObjectType();
if (t == typeid(uint8_t)) {
unsigned char *optr = static_cast<unsigned char*>(&object.at(0,allocator));
for (size_t i = 0; i < getDimension(); i++) {
os << (int)optr[i] << " ";
}
} else if (t == typeid(float)) {
float *optr = reinterpret_cast<float*>(&object.at(0,allocator));
for (size_t i = 0; i < getDimension(); i++) {
os << optr[i] << " ";
}
} else {
os << " not implement for the type.";
}
}
Object *allocateObject(Object &o) {
Object *po = new Object(getByteSizeOfObject());
for (size_t i = 0; i < getByteSizeOfObject(); i++) {
(*po)[i] = o[i];
}
return po;
}
Object *allocateObject(PersistentObject &o) {
PersistentObject &spo = (PersistentObject &)o;
Object *po = new Object(getByteSizeOfObject());
for (size_t i = 0; i < getByteSizeOfObject(); i++) {
(*po)[i] = spo.at(i,ObjectRepository::allocator);
}
return (Object*)po;
}
void deleteObject(PersistentObject *po) {
delete po;
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
void copy(Object &objecta, Object &objectb) {
objecta.copy(objectb, getByteSizeOfObject());
}
void setDistanceType(DistanceType t) {
if (comparator != 0) {
delete comparator;
}
assert(ObjectSpace::dimension != 0);
distanceType = t;
switch (distanceType) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
case DistanceTypeL1:
comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeL2:
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeHamming:
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeJaccard:
comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeSparseJaccard:
comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
setSparse();
break;
case DistanceTypeAngle:
comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeCosine:
comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeNormalizedAngle:
comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
normalization = true;
break;
case DistanceTypeNormalizedCosine:
comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
normalization = true;
break;
#else
case DistanceTypeL1:
comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeL2:
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeHamming:
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeJaccard:
comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeSparseJaccard:
comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension());
setSparse();
break;
case DistanceTypeAngle:
comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeCosine:
comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeNormalizedAngle:
comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension());
normalization = true;
break;
case DistanceTypeNormalizedCosine:
comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension());
normalization = true;
break;
#endif
default:
std::cerr << "Distance type is not specified" << std::endl;
assert(distanceType != DistanceTypeNone);
abort();
}
}
void serialize(const std::string & ofile) { ObjectRepository::serialize(ofile, this); }
// for milvus
void serialize(std::stringstream & obj) { ObjectRepository::serialize(obj, this); }
// for milvus
void deserialize(std::stringstream & obj) { ObjectRepository::deserialize(obj, this); }
void deserialize(const std::string &ifile) { ObjectRepository::deserialize(ifile, this); }
void serializeAsText(const std::string &ofile) { ObjectRepository::serializeAsText(ofile, this); }
void deserializeAsText(const std::string &ifile) { ObjectRepository::deserializeAsText(ifile, this); }
// For milvus
void readRawData(const float * raw_data, size_t dataSize) { ObjectRepository::readRawData<float>(raw_data, dataSize); }
void readText(std::istream &is, size_t dataSize) { ObjectRepository::readText(is, dataSize); }
void appendText(std::istream &is, size_t dataSize) { ObjectRepository::appendText(is, dataSize); }
void append(const float *data, size_t dataSize) { ObjectRepository::append(data, dataSize); }
void append(const double *data, size_t dataSize) { ObjectRepository::append(data, dataSize); }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
PersistentObject *allocatePersistentObject(Object &obj) {
return ObjectRepository::allocatePersistentObject(obj);
}
size_t insert(PersistentObject *obj) { return ObjectRepository::insert(obj); }
#else
size_t insert(Object *obj) { return ObjectRepository::insert(obj); }
#endif
void remove(size_t id) { ObjectRepository::remove(id); }
void linearSearch(Object &query, double radius, size_t size, ObjectSpace::ResultSet &results) {
if (!results.empty()) {
NGTThrowException("lenearSearch: results is not empty");
}
#ifndef NGT_PREFETCH_DISABLED
size_t byteSizeOfObject = getByteSizeOfObject();
const size_t prefetchOffset = getPrefetchOffset();
#endif
ObjectRepository &rep = *this;
for (size_t idx = 0; idx < rep.size(); idx++) {
#ifndef NGT_PREFETCH_DISABLED
if (idx + prefetchOffset < rep.size() && rep[idx + prefetchOffset] != 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
MemoryCache::prefetch((unsigned char*)&(*static_cast<PersistentObject*>(ObjectRepository::get(idx + prefetchOffset))), byteSizeOfObject);
#else
MemoryCache::prefetch((unsigned char*)&(*static_cast<PersistentObject*>(rep[idx + prefetchOffset]))[0], byteSizeOfObject);
#endif
}
#endif
if (rep[idx] == 0) {
continue;
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Distance d = (*comparator)((Object&)query, (PersistentObject&)*rep[idx]);
#else
Distance d = (*comparator)((Object&)query, (Object&)*rep[idx]);
#endif
if (radius < 0.0 || d <= radius) {
NGT::ObjectDistance obj(idx, d);
results.push(obj);
if (results.size() > size) {
results.pop();
}
}
}
return;
}
void *getObject(size_t idx) {
if (isEmpty(idx)) {
std::stringstream msg;
msg << "NGT::ObjectSpaceRepository: The specified ID is out of the range. The object ID should be greater than zero. " << idx << ":" << ObjectRepository::size() << ".";
NGTThrowException(msg);
}
PersistentObject &obj = *(*this)[idx];
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
return reinterpret_cast<OBJECT_TYPE*>(&obj.at(0, allocator));
#else
return reinterpret_cast<OBJECT_TYPE*>(&obj[0]);
#endif
}
void getObject(size_t idx, std::vector<float> &v) {
OBJECT_TYPE *obj = static_cast<OBJECT_TYPE*>(getObject(idx));
size_t dim = getDimension();
v.resize(dim);
for (size_t i = 0; i < dim; i++) {
v[i] = static_cast<float>(obj[i]);
}
}
void getObjects(const std::vector<size_t> &idxs, std::vector<std::vector<float>> &vs) {
vs.resize(idxs.size());
auto v = vs.begin();
for (auto idx = idxs.begin(); idx != idxs.end(); idx++, v++) {
getObject(*idx, *v);
}
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void normalize(PersistentObject &object) {
OBJECT_TYPE *obj = (OBJECT_TYPE*)&object.at(0, getRepository().getAllocator());
ObjectSpace::normalize(obj, ObjectSpace::dimension);
}
#endif
void normalize(Object &object) {
OBJECT_TYPE *obj = (OBJECT_TYPE*)&object[0];
ObjectSpace::normalize(obj, ObjectSpace::dimension);
}
Object *allocateObject() { return ObjectRepository::allocateObject(); }
void deleteObject(Object *po) { ObjectRepository::deleteObject(po); }
Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) {
Object *allocatedObject = ObjectRepository::allocateObject(textLine, sep);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const std::vector<double> &obj) {
Object *allocatedObject = ObjectRepository::allocateObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const std::vector<float> &obj) {
Object *allocatedObject = ObjectRepository::allocateObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const std::vector<uint8_t> &obj) {
Object *allocatedObject = ObjectRepository::allocateObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const float *obj, size_t size) {
Object *allocatedObject = ObjectRepository::allocateObject(obj, size);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
size_t getSize() { return ObjectRepository::size(); }
size_t getSizeOfElement() { return sizeof(OBJECT_TYPE); }
const std::type_info &getObjectType() { return typeid(OBJECT_TYPE); };
size_t getByteSizeOfObject() { return getByteSize(); }
ObjectRepository &getRepository() { return *this; };
void show(std::ostream &os, Object &object) {
const std::type_info &t = getObjectType();
if (t == typeid(uint8_t)) {
unsigned char *optr = static_cast<unsigned char*>(&object[0]);
for (size_t i = 0; i < getDimension(); i++) {
os << (int)optr[i] << " ";
}
} else if (t == typeid(float)) {
float *optr = reinterpret_cast<float*>(&object[0]);
for (size_t i = 0; i < getDimension(); i++) {
os << optr[i] << " ";
}
} else {
os << " not implement for the type.";
}
}
};
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
// set v in objectspace to this object using allocator.
inline void PersistentObject::set(PersistentObject &po, ObjectSpace &objectspace) {
SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator();
uint8_t *src = (uint8_t *)&po.at(0, allocator);
uint8_t *dst = (uint8_t *)&(*this).at(0, allocator);
memcpy(dst, src, objectspace.getByteSizeOfObject());
}
inline off_t PersistentObject::allocate(ObjectSpace &objectspace) {
SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator();
return allocator.getOffset(new(allocator) PersistentObject(allocator, &objectspace));
}
inline void PersistentObject::serializeAsText(std::ostream &os, ObjectSpace *objectspace) {
assert(objectspace != 0);
SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator();
const std::type_info &t = objectspace->getObjectType();
void *ref = &(*this).at(0, allocator);
size_t dimension = objectspace->getDimension();
if (t == typeid(uint8_t)) {
NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::writeAsText(os, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::writeAsText(os, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
} else {
std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
inline void PersistentObject::deserializeAsText(std::ifstream &is, ObjectSpace *objectspace) {
assert(objectspace != 0);
SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator();
const std::type_info &t = objectspace->getObjectType();
size_t dimension = objectspace->getDimension();
void *ref = &(*this).at(0, allocator);
assert(ref != 0);
if (t == typeid(uint8_t)) {
NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::readAsText(is, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::readAsText(is, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
} else {
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
#endif
} // namespace NGT

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,781 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/defines.h"
#if defined(NGT_NO_AVX)
// #warning "*** SIMD is *NOT* available! ***"
#else
#include <immintrin.h>
#endif
namespace NGT {
class MemoryCache {
public:
inline static void
prefetch(unsigned char* ptr, const size_t byteSizeOfObject) {
#if !defined(NGT_NO_AVX)
switch ((byteSizeOfObject - 1) >> 6) {
default:
case 28:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 27:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 26:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 25:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 24:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 23:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 22:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 21:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 20:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 19:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 18:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 17:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 16:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 15:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 14:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 13:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 12:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 11:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 10:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 9:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 8:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 7:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 6:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 5:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 4:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 3:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 2:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 1:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 0:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
break;
}
#endif
}
inline static void*
alignedAlloc(const size_t allocSize) {
#ifdef NGT_NO_AVX
return new uint8_t[allocSize];
#else
#if defined(NGT_AVX512)
size_t alignment = 64;
uint64_t mask = 0xFFFFFFFFFFFFFFC0;
#elif defined(NGT_AVX2)
size_t alignment = 32;
uint64_t mask = 0xFFFFFFFFFFFFFFE0;
#else
size_t alignment = 16;
uint64_t mask = 0xFFFFFFFFFFFFFFF0;
#endif
uint8_t* p = new uint8_t[allocSize + alignment];
uint8_t* ptr = p + alignment;
ptr = reinterpret_cast<uint8_t*>((reinterpret_cast<uint64_t>(ptr) & mask));
*p++ = 0xAB;
while (p != ptr) *p++ = 0xCD;
return ptr;
#endif
}
inline static void
alignedFree(void* ptr) {
#ifdef NGT_NO_AVX
delete[] static_cast<uint8_t*>(ptr);
#else
uint8_t* p = static_cast<uint8_t*>(ptr);
p--;
while (*p == 0xCD) p--;
if (*p != 0xAB) {
NGTThrowException("MemoryCache::alignedFree: Fatal Error! Cannot find allocated address.");
}
delete[] p;
#endif
}
};
class PrimitiveComparator {
public:
static double
absolute(double v) {
return fabs(v);
}
static int
absolute(int v) {
return abs(v);
}
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
inline static double
compareL2(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const OBJECT_TYPE* last = a + size;
const OBJECT_TYPE* lastgroup = last - 3;
COMPARE_TYPE diff0, diff1, diff2, diff3;
double d = 0.0;
while (a < lastgroup) {
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
d += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;
}
while (a < last) {
diff0 = (COMPARE_TYPE)(*a++ - *b++);
d += diff0 * diff0;
}
return sqrt((double)d);
}
inline static double
compareL2(const uint8_t* a, const uint8_t* b, size_t size) {
return compareL2<uint8_t, int>(a, b, size);
}
inline static double
compareL2(const float* a, const float* b, size_t size) {
return compareL2<float, double>(a, b, size);
}
#else
inline static double
compareL2(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 sum512 = _mm512_setzero_ps();
while (a < last) {
__m512 v = _mm512_sub_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b));
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v, v));
a += 16;
b += 16;
}
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#elif defined(NGT_AVX2)
__m256 sum256 = _mm256_setzero_ps();
__m256 v;
while (a < last) {
v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v));
a += 8;
b += 8;
v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v));
a += 8;
b += 8;
}
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#else
__m128 sum128 = _mm_setzero_ps();
__m128 v;
while (a < last) {
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
return sqrt(s);
}
inline static double
compareL2(const unsigned char* a, const unsigned char* b, size_t size) {
__m128 sum = _mm_setzero_ps();
const unsigned char* last = a + size;
const unsigned char* lastgroup = last - 7;
const __m128i zero = _mm_setzero_si128();
while (a < lastgroup) {
__m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a));
__m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b));
x1 = _mm_subs_epi16(x1, x2);
__m128i v = _mm_mullo_epi16(x1, x1);
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(v, zero)));
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(v, zero)));
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3];
while (a < last) {
int d = (int)*a++ - (int)*b++;
s += d * d;
}
return sqrt(s);
}
#endif
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
static double
compareL1(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const OBJECT_TYPE* last = a + size;
const OBJECT_TYPE* lastgroup = last - 3;
COMPARE_TYPE diff0, diff1, diff2, diff3;
double d = 0.0;
while (a < lastgroup) {
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
d += absolute(diff0) + absolute(diff1) + absolute(diff2) + absolute(diff3);
a += 4;
b += 4;
}
while (a < last) {
diff0 = (COMPARE_TYPE)*a++ - (COMPARE_TYPE)*b++;
d += absolute(diff0);
}
return d;
}
inline static double
compareL1(const uint8_t* a, const uint8_t* b, size_t size) {
return compareL1<uint8_t, int>(a, b, size);
}
inline static double
compareL1(const float* a, const float* b, size_t size) {
return compareL1<float, double>(a, b, size);
}
#else
inline static double
compareL1(const float* a, const float* b, size_t size) {
__m256 sum = _mm256_setzero_ps();
const float* last = a + size;
const float* lastgroup = last - 7;
while (a < lastgroup) {
__m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
const __m256 mask = _mm256_set1_ps(-0.0f);
__m256 v = _mm256_andnot_ps(mask, x1);
sum = _mm256_add_ps(sum, v);
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[8];
_mm256_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7];
while (a < last) {
double d = fabs(*a++ - *b++);
s += d;
}
return s;
}
inline static double
compareL1(const unsigned char* a, const unsigned char* b, size_t size) {
__m128 sum = _mm_setzero_ps();
const unsigned char* last = a + size;
const unsigned char* lastgroup = last - 7;
const __m128i zero = _mm_setzero_si128();
while (a < lastgroup) {
__m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a));
__m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b));
x1 = _mm_subs_epi16(x1, x2);
x1 = _mm_sign_epi16(x1, x1);
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(x1, zero)));
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(x1, zero)));
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3];
while (a < last) {
double d = fabs((double)*a++ - (double)*b++);
s += d;
}
return s;
}
#endif
#if defined(NGT_NO_AVX) || !defined(__POPCNT__)
inline static double
popCount(uint32_t x) {
x = (x & 0x55555555) + (x >> 1 & 0x55555555);
x = (x & 0x33333333) + (x >> 2 & 0x33333333);
x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F);
x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF);
x = (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF);
return x;
}
template <typename OBJECT_TYPE>
inline static double
compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint32_t* last = reinterpret_cast<const uint32_t*>(a + size);
const uint32_t* uinta = reinterpret_cast<const uint32_t*>(a);
const uint32_t* uintb = reinterpret_cast<const uint32_t*>(b);
size_t count = 0;
while (uinta < last) {
count += popCount(*uinta++ ^ *uintb++);
}
return static_cast<double>(count);
}
#else
template <typename OBJECT_TYPE>
inline static double
compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint64_t* last = reinterpret_cast<const uint64_t*>(a + size);
const uint64_t* uinta = reinterpret_cast<const uint64_t*>(a);
const uint64_t* uintb = reinterpret_cast<const uint64_t*>(b);
size_t count = 0;
while (uinta < last) {
count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
}
return static_cast<double>(count);
}
#endif
#if defined(NGT_NO_AVX) || !defined(__POPCNT__)
template <typename OBJECT_TYPE>
inline static double
compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint32_t* last = reinterpret_cast<const uint32_t*>(a + size);
const uint32_t* uinta = reinterpret_cast<const uint32_t*>(a);
const uint32_t* uintb = reinterpret_cast<const uint32_t*>(b);
size_t count = 0;
size_t countDe = 0;
while (uinta < last) {
count += popCount(*uinta & *uintb);
countDe += popCount(*uinta++ | *uintb++);
count += popCount(*uinta & *uintb);
countDe += popCount(*uinta++ | *uintb++);
}
return 1.0 - static_cast<double>(count) / static_cast<double>(countDe);
}
#else
template <typename OBJECT_TYPE>
inline static double
compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint64_t* last = reinterpret_cast<const uint64_t*>(a + size);
const uint64_t* uinta = reinterpret_cast<const uint64_t*>(a);
const uint64_t* uintb = reinterpret_cast<const uint64_t*>(b);
size_t count = 0;
size_t countDe = 0;
while (uinta < last) {
count += _mm_popcnt_u64(*uinta & *uintb);
countDe += _mm_popcnt_u64(*uinta++ | *uintb++);
count += _mm_popcnt_u64(*uinta & *uintb);
countDe += _mm_popcnt_u64(*uinta++ | *uintb++);
}
return 1.0 - static_cast<double>(count) / static_cast<double>(countDe);
}
#endif
inline static double
compareSparseJaccardDistance(const unsigned char* a, unsigned char* b, size_t size) {
abort();
}
inline static double
compareSparseJaccardDistance(const float* a, const float* b, size_t size) {
size_t loca = 0;
size_t locb = 0;
const uint32_t* ai = reinterpret_cast<const uint32_t*>(a);
const uint32_t* bi = reinterpret_cast<const uint32_t*>(b);
size_t count = 0;
while (locb < size && ai[loca] != 0 && bi[loca] != 0) {
int64_t sub = static_cast<int64_t>(ai[loca]) - static_cast<int64_t>(bi[locb]);
count += sub == 0;
loca += sub <= 0;
locb += sub >= 0;
}
while (ai[loca] != 0) {
loca++;
}
while (locb < size && bi[locb] != 0) {
locb++;
}
return 1.0 - static_cast<double>(count) / static_cast<double>(loca + locb - count);
}
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE>
inline static double
compareDotProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += (double)a[loc] * (double)b[loc];
}
return sum;
}
template <typename OBJECT_TYPE>
inline static double
compareCosine(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double normA = 0.0;
double normB = 0.0;
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
normA += (double)a[loc] * (double)a[loc];
normB += (double)b[loc] * (double)b[loc];
sum += (double)a[loc] * (double)b[loc];
}
double cosine = sum / sqrt(normA * normB);
return cosine;
}
#else
inline static double
compareDotProduct(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 sum512 = _mm512_setzero_ps();
while (a < last) {
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)));
a += 16;
b += 16;
}
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#elif defined(NGT_AVX2)
__m256 sum256 = _mm256_setzero_ps();
while (a < last) {
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)));
a += 8;
b += 8;
}
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#else
__m128 sum128 = _mm_setzero_ps();
while (a < last) {
sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
return s;
}
inline static double
compareDotProduct(const unsigned char* a, const unsigned char* b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += (double)a[loc] * (double)b[loc];
}
return sum;
}
inline static double
compareCosine(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 normA = _mm512_setzero_ps();
__m512 normB = _mm512_setzero_ps();
__m512 sum = _mm512_setzero_ps();
while (a < last) {
__m512 am = _mm512_loadu_ps(a);
__m512 bm = _mm512_loadu_ps(b);
normA = _mm512_add_ps(normA, _mm512_mul_ps(am, am));
normB = _mm512_add_ps(normB, _mm512_mul_ps(bm, bm));
sum = _mm512_add_ps(sum, _mm512_mul_ps(am, bm));
a += 16;
b += 16;
}
__m256 am256 = _mm256_add_ps(_mm512_extractf32x8_ps(normA, 0), _mm512_extractf32x8_ps(normA, 1));
__m256 bm256 = _mm256_add_ps(_mm512_extractf32x8_ps(normB, 0), _mm512_extractf32x8_ps(normB, 1));
__m256 s256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum, 0), _mm512_extractf32x8_ps(sum, 1));
__m128 am128 = _mm_add_ps(_mm256_extractf128_ps(am256, 0), _mm256_extractf128_ps(am256, 1));
__m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(bm256, 0), _mm256_extractf128_ps(bm256, 1));
__m128 s128 = _mm_add_ps(_mm256_extractf128_ps(s256, 0), _mm256_extractf128_ps(s256, 1));
#elif defined(NGT_AVX2)
__m256 normA = _mm256_setzero_ps();
__m256 normB = _mm256_setzero_ps();
__m256 sum = _mm256_setzero_ps();
__m256 am, bm;
while (a < last) {
am = _mm256_loadu_ps(a);
bm = _mm256_loadu_ps(b);
normA = _mm256_add_ps(normA, _mm256_mul_ps(am, am));
normB = _mm256_add_ps(normB, _mm256_mul_ps(bm, bm));
sum = _mm256_add_ps(sum, _mm256_mul_ps(am, bm));
a += 8;
b += 8;
}
__m128 am128 = _mm_add_ps(_mm256_extractf128_ps(normA, 0), _mm256_extractf128_ps(normA, 1));
__m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(normB, 0), _mm256_extractf128_ps(normB, 1));
__m128 s128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
#else
__m128 am128 = _mm_setzero_ps();
__m128 bm128 = _mm_setzero_ps();
__m128 s128 = _mm_setzero_ps();
__m128 am, bm;
while (a < last) {
am = _mm_loadu_ps(a);
bm = _mm_loadu_ps(b);
am128 = _mm_add_ps(am128, _mm_mul_ps(am, am));
bm128 = _mm_add_ps(bm128, _mm_mul_ps(bm, bm));
s128 = _mm_add_ps(s128, _mm_mul_ps(am, bm));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, am128);
double na = f[0] + f[1] + f[2] + f[3];
_mm_store_ps(f, bm128);
double nb = f[0] + f[1] + f[2] + f[3];
_mm_store_ps(f, s128);
double s = f[0] + f[1] + f[2] + f[3];
double cosine = s / sqrt(na * nb);
return cosine;
}
inline static double
compareCosine(const unsigned char* a, const unsigned char* b, size_t size) {
double normA = 0.0;
double normB = 0.0;
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
normA += (double)a[loc] * (double)a[loc];
normB += (double)b[loc] * (double)b[loc];
sum += (double)a[loc] * (double)b[loc];
}
double cosine = sum / sqrt(normA * normB);
return cosine;
}
#endif // #if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE>
inline static double
compareAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double cosine = compareCosine(a, b, size);
if (cosine >= 1.0) {
return 0.0;
} else if (cosine <= -1.0) {
return acos(-1.0);
} else {
return acos(cosine);
}
}
template <typename OBJECT_TYPE>
inline static double
compareNormalizedAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double cosine = compareDotProduct(a, b, size);
if (cosine >= 1.0) {
return 0.0;
} else if (cosine <= -1.0) {
return acos(-1.0);
} else {
return acos(cosine);
}
}
template <typename OBJECT_TYPE>
inline static double
compareCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
return 1.0 - compareCosine(a, b, size);
}
template <typename OBJECT_TYPE>
inline static double
compareNormalizedCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double v = 1.0 - compareDotProduct(a, b, size);
return v < 0.0 ? 0.0 : v;
}
class L1Uint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareL1((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class L2Uint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareL2((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class HammingUint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareHammingDistance((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class JaccardUint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareJaccardDistance((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class SparseJaccardFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareSparseJaccardDistance((const float*)a, (const float*)b, size);
}
};
class L2Float {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
#if defined(NGT_NO_AVX)
return PrimitiveComparator::compareL2<float, double>((const float*)a, (const float*)b, size);
#else
return PrimitiveComparator::compareL2((const float*)a, (const float*)b, size);
#endif
}
};
class L1Float {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareL1((const float*)a, (const float*)b, size);
}
};
class CosineSimilarityFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareCosineSimilarity((const float*)a, (const float*)b, size);
}
};
class AngleFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareAngleDistance((const float*)a, (const float*)b, size);
}
};
class NormalizedCosineSimilarityFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((const float*)a, (const float*)b, size);
}
};
class NormalizedAngleFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareNormalizedAngleDistance((const float*)a, (const float*)b, size);
}
};
};
} // namespace NGT

View File

@ -0,0 +1,40 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/SharedMemoryAllocator.h"
void* operator
new(size_t size, SharedMemoryAllocator &allocator)
{
void *addr = allocator.allocate(size);
#ifdef MEMORY_ALLOCATOR_INFO
std::cerr << "new:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl;
#endif
return addr;
}
void* operator
new[](size_t size, SharedMemoryAllocator &allocator)
{
void *addr = allocator.allocate(size);
#ifdef MEMORY_ALLOCATOR_INFO
std::cerr << "new[]:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl;
#endif
return addr;
}

View File

@ -0,0 +1,209 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/defines.h"
#include "NGT/MmapManager.h"
#include <unistd.h>
#include <cstdlib>
#include <cstring>
#include <string>
#include <iostream>
#include <vector>
#include <exception>
#include <cassert>
#define MMAP_MANAGER
///////////////////////////////////////////////////////////////////////
class SharedMemoryAllocator {
public:
enum GetMemorySizeType {
GetTotalMemorySize = 0,
GetAllocatedMemorySize = 1,
GetFreedMemorySize = 2
};
SharedMemoryAllocator():isValid(false) {
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocatorSiglton::constructor" << std::endl;
#endif
}
SharedMemoryAllocator(const SharedMemoryAllocator& a){}
SharedMemoryAllocator& operator=(const SharedMemoryAllocator& a){ return *this; }
public:
void* allocate(size_t size) {
if (isValid == false) {
std::cerr << "SharedMemoryAllocator::allocate: Fatal error! " << std::endl;
assert(isValid);
}
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::allocate: size=" << size << std::endl;
std::cerr << "SharedMemoryAllocator::allocate: before " << getTotalSize() << ":" << getAllocatedSize() << ":" << getFreedSize() << std::endl;
#endif
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
if(!isValid){
return NULL;
}
off_t file_offset = mmanager->alloc(size, true);
if (file_offset == -1) {
std::cerr << "Fatal Error: Allocating memory size is too big for this settings." << std::endl;
std::cerr << " Max allocation size should be enlarged." << std::endl;
abort();
}
void *p = mmanager->getAbsAddr(file_offset);
std::memset(p, 0, size);
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::allocate: end" <<std::endl;
#endif
return p;
#else
void *ptr = std::malloc(size);
std::memset(ptr, 0, size);
return ptr;
#endif
}
void free(void *ptr) {
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::free: ptr=" << ptr << std::endl;
#endif
if (ptr == 0) {
std::cerr << "SharedMemoryAllocator::free: ptr is invalid! ptr=" << ptr << std::endl;
}
if (ptr == 0) {
return;
}
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
off_t file_offset = mmanager->getRelAddr(ptr);
mmanager->free(file_offset);
#else
std::free(ptr);
#endif
}
void *construct(const std::string &filePath, size_t memorysize = 0) {
file = filePath; // debug
#ifdef SMA_TRACE
std::cerr << "ObjectSharedMemoryAllocator::construct: file " << filePath << std::endl;
#endif
void *hook = 0;
#ifdef MMAP_MANAGER
mmanager = new MemoryManager::MmapManager();
// msize is the maximum allocated size (M byte) at once.
size_t msize = memorysize;
if (msize == 0) {
msize = NGT_SHARED_MEMORY_MAX_SIZE;
}
size_t bsize = msize * 1048576 / sysconf(_SC_PAGESIZE) + 1; // 1048576=1M
uint64_t size = bsize * sysconf(_SC_PAGESIZE);
MemoryManager::init_option_st option;
MemoryManager::MmapManager::setDefaultOptionValue(option);
option.use_expand = true;
option.reuse_type = MemoryManager::REUSE_DATA_CLASSIFY;
bool create = true;
if(!mmanager->init(filePath, size, &option)){
#ifdef SMA_TRACE
std::cerr << "SMA: info. already existed." << std::endl;
#endif
create = false;
} else {
#ifdef SMA_TRACE
std::cerr << "SMA::construct: msize=" << msize << ":" << memorysize << std::endl;
#endif
}
if(!mmanager->openMemory(filePath)){
std::cerr << "SMA: open error" << std::endl;
return 0;
}
if (!create) {
#ifdef SMA_TRACE
std::cerr << "SMA: get hook to initialize data structure" << std::endl;
#endif
hook = mmanager->getEntryHook();
assert(hook != 0);
}
#endif
isValid = true;
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::construct: " << filePath << " total="
<< getTotalSize() << " allocated=" << getAllocatedSize() << " freed="
<< getFreedSize() << " (" << (double)getFreedSize() / (double)getTotalSize() << ") " << std::endl;
#endif
return hook;
}
void destruct() {
if (!isValid) {
return;
}
isValid = false;
#ifdef MMAP_MANAGER
mmanager->closeMemory();
delete mmanager;
#endif
};
void setEntry(void *entry) {
#ifdef MMAP_MANAGER
mmanager->setEntryHook(entry);
#endif
}
void *getAddr(off_t oft) {
if (oft == 0) {
return 0;
}
assert(oft > 0);
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
return mmanager->getAbsAddr(oft);
#else
return (void*)oft;
#endif
}
off_t getOffset(void *adr) {
if (adr == 0) {
return 0;
}
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
return mmanager->getRelAddr(adr);
#else
return (off_t)adr;
#endif
}
size_t getMemorySize(GetMemorySizeType t) {
switch (t) {
case GetTotalMemorySize : return getTotalSize();
case GetAllocatedMemorySize : return getAllocatedSize();
case GetFreedMemorySize : return getFreedSize();
}
return getTotalSize();
}
size_t getTotalSize() { return mmanager->getTotalSize(); }
size_t getAllocatedSize() { return mmanager->getUseSize(); }
size_t getFreedSize() { return mmanager->getFreeSize(); }
bool isValid;
std::string file;
#ifdef MMAP_MANAGER
MemoryManager::MmapManager *mmanager;
#endif
};
/////////////////////////////////////////////////////////////////////////
void* operator new(size_t size, SharedMemoryAllocator &allocator);
void* operator new[](size_t size, SharedMemoryAllocator &allocator);

View File

@ -0,0 +1,128 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <pthread.h>
#include "Thread.h"
using namespace std;
using namespace NGT;
namespace NGT {
class ThreadInfo {
public:
pthread_t threadid;
pthread_attr_t threadAttr;
};
class ThreadMutex {
public:
pthread_mutex_t mutex;
pthread_cond_t condition;
};
}
Thread::Thread() {
threadInfo = new ThreadInfo;
threadInfo->threadid = 0;
threadNo = -1;
isTerminate = false;
}
Thread::~Thread() {
if (threadInfo != 0) {
delete threadInfo;
}
}
ThreadMutex *
Thread::constructThreadMutex()
{
return new ThreadMutex;
}
void
Thread::destructThreadMutex(ThreadMutex *t)
{
if (t != 0) {
pthread_mutex_destroy(&(t->mutex));
pthread_cond_destroy(&(t->condition));
delete t;
}
}
int
Thread::start()
{
pthread_attr_init(&(threadInfo->threadAttr));
size_t stackSize = 0;
pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize);
if (stackSize < 0xa00000) { // 64bit stack size
stackSize *= 4;
}
pthread_attr_setstacksize(&(threadInfo->threadAttr), stackSize);
pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize);
return pthread_create(&(threadInfo->threadid), &(threadInfo->threadAttr), Thread::startThread, this);
}
int
Thread::join()
{
return pthread_join(threadInfo->threadid, 0);
}
void
Thread::lock(ThreadMutex &m)
{
pthread_mutex_lock(&m.mutex);
}
void
Thread::unlock(ThreadMutex &m)
{
pthread_mutex_unlock(&m.mutex);
}
void
Thread::signal(ThreadMutex &m)
{
pthread_cond_signal(&m.condition);
}
void
Thread::wait(ThreadMutex &m)
{
if (pthread_cond_wait(&m.condition, &m.mutex) != 0) {
cerr << "waitForSignalFromThread: internal error" << endl;
NGTThrowException("waitForSignalFromThread: internal error");
}
}
void
Thread::broadcast(ThreadMutex &m)
{
pthread_cond_broadcast(&m.condition);
}
void
Thread::mutexInit(ThreadMutex &m)
{
if (pthread_mutex_init(&m.mutex, NULL) != 0) {
NGTThrowException("Thread::mutexInit: Cannot initialize mutex");
}
if (pthread_cond_init(&m.condition, NULL) != 0) {
NGTThrowException("Thread::mutexInit: Cannot initialize condition");
}
}

View File

@ -0,0 +1,291 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Common.h"
#include <cstdio>
#include <cstdlib>
#include <sys/time.h>
#include <unistd.h>
#include <iostream>
#include <deque>
namespace NGT {
void * evaluate_responce(void *);
class ThreadTerminationException : public Exception {
public:
ThreadTerminationException(const std::string &file, size_t line, std::stringstream &m) { set(file, line, m.str()); }
ThreadTerminationException(const std::string &file, size_t line, const std::string &m) { set(file, line, m); }
};
class ThreadInfo;
class ThreadMutex;
class Thread
{
public:
Thread();
virtual ~Thread();
virtual int start();
virtual int join();
static ThreadMutex *constructThreadMutex();
static void destructThreadMutex(ThreadMutex *t);
static void mutexInit(ThreadMutex &m);
static void lock(ThreadMutex &m);
static void unlock(ThreadMutex &m);
static void signal(ThreadMutex &m);
static void wait(ThreadMutex &m);
static void broadcast(ThreadMutex &m);
protected:
virtual int run() {
return 0;
}
private:
static void* startThread(void *thread) {
if (thread == 0) {
return 0;
}
Thread* p = (Thread*)thread;
p->run();
return thread;
}
public:
int threadNo;
bool isTerminate;
protected:
ThreadInfo *threadInfo;
};
template <class JOB, class SHARED_DATA, class THREAD>
class ThreadPool {
public:
class JobQueue : public std::deque<JOB> {
public:
JobQueue() {
threadMutex = Thread::constructThreadMutex();
Thread::mutexInit(*threadMutex);
}
~JobQueue() {
Thread::destructThreadMutex(threadMutex);
}
bool isDeficient() { return std::deque<JOB>::size() <= requestSize; }
bool isEmpty() { return std::deque<JOB>::size() == 0; }
bool isFull() { return std::deque<JOB>::size() >= maxSize; }
void setRequestSize(int s) { requestSize = s; }
void setMaxSize(int s) { maxSize = s; }
void lock() { Thread::lock(*threadMutex); }
void unlock() { Thread::unlock(*threadMutex); }
void signal() { Thread::signal(*threadMutex); }
void wait() { Thread::wait(*threadMutex); }
void wait(JobQueue &q) { wait(*q.threadMutex); }
void broadcast() { Thread::broadcast(*threadMutex); }
unsigned int requestSize;
unsigned int maxSize;
ThreadMutex *threadMutex;
};
class InputJobQueue : public JobQueue {
public:
InputJobQueue() {
isTerminate = false;
underPushing = false;
pushedSize = 0;
}
void popFront(JOB &d) {
JobQueue::lock();
while (JobQueue::isEmpty()) {
if (isTerminate) {
JobQueue::unlock();
NGTThrowSpecificException("Thread::termination", ThreadTerminationException);
}
JobQueue::wait();
}
d = std::deque<JOB>::front();
std::deque<JOB>::pop_front();
JobQueue::unlock();
return;
}
void popFront(std::deque<JOB> &d, size_t s) {
JobQueue::lock();
while (JobQueue::isEmpty()) {
if (isTerminate) {
JobQueue::unlock();
NGTThrowSpecificException("Thread::termination", ThreadTerminationException);
}
JobQueue::wait();
}
for (size_t i = 0; i < s; i++) {
d.push_back(std::deque<JOB>::front());
std::deque<JOB>::pop_front();
if (JobQueue::isEmpty()) {
break;
}
}
JobQueue::unlock();
return;
}
void pushBack(JOB &data) {
JobQueue::lock();
if (!underPushing) {
underPushing = true;
pushedSize = 0;
}
pushedSize++;
std::deque<JOB>::push_back(data);
JobQueue::unlock();
JobQueue::signal();
}
void pushBackEnd() {
underPushing = false;
}
void terminate() {
JobQueue::lock();
if (underPushing || !JobQueue::isEmpty()) {
JobQueue::unlock();
NGTThrowException("Thread::teminate:Under pushing!");
}
isTerminate = true;
JobQueue::unlock();
JobQueue::broadcast();
}
bool isTerminate;
bool underPushing;
size_t pushedSize;
};
class OutputJobQueue : public JobQueue {
public:
void waitForFull() {
JobQueue::wait();
JobQueue::unlock();
}
void pushBack(JOB &data) {
JobQueue::lock();
std::deque<JOB>::push_back(data);
if (!JobQueue::isFull()) {
JobQueue::unlock();
return;
}
JobQueue::unlock();
JobQueue::signal();
}
};
class SharedData {
public:
SharedData():isAvailable(false) {
inputJobs.requestSize = 5;
inputJobs.maxSize = 50;
}
SHARED_DATA sharedData;
InputJobQueue inputJobs;
OutputJobQueue outputJobs;
bool isAvailable;
};
class Thread : public THREAD {
public:
SHARED_DATA &getSharedData() {
if (threadPool->sharedData.isAvailable) {
return threadPool->sharedData.sharedData;
} else {
NGTThrowException("Thread::getSharedData: Shared data is unavailable. No set yet.");
}
}
InputJobQueue &getInputJobQueue() {
return threadPool->sharedData.inputJobs;
}
OutputJobQueue &getOutputJobQueue() {
return threadPool->sharedData.outputJobs;
}
ThreadPool *threadPool;
};
ThreadPool(int s) {
size = s;
threads = new Thread[s];
}
~ThreadPool() {
delete[] threads;
}
void setSharedData(SHARED_DATA d) {
sharedData.sharedData = d;
sharedData.isAvailable = true;
}
void create() {
for (unsigned int i = 0; i < size; i++) {
threads[i].threadPool = this;
threads[i].threadNo = i;
threads[i].start();
}
}
void pushInputQueue(JOB &data) {
if (!sharedData.inputJobs.underPushing) {
sharedData.outputJobs.lock();
}
sharedData.inputJobs.pushBack(data);
}
void waitForFinish() {
sharedData.inputJobs.pushBackEnd();
sharedData.outputJobs.setMaxSize(sharedData.inputJobs.pushedSize);
sharedData.inputJobs.pushedSize = 0;
sharedData.outputJobs.waitForFull();
}
void terminate() {
sharedData.inputJobs.terminate();
for (unsigned int i = 0; i < size; i++) {
threads[i].join();
}
}
InputJobQueue &getInputJobQueue() { return sharedData.inputJobs; }
OutputJobQueue &getOutputJobQueue() { return sharedData.outputJobs; }
SharedData sharedData; // shared data
Thread *threads; // thread set
unsigned int size; // thread size
};
}

View File

@ -0,0 +1,564 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/defines.h"
#include "NGT/Tree.h"
#include "NGT/Node.h"
#include <vector>
using namespace std;
using namespace NGT;
void
DVPTree::insert(InsertContainer &iobj) {
SearchContainer q(iobj.object);
q.mode = SearchContainer::SearchLeaf;
q.vptree = this;
q.radius = 0.0;
search(q);
iobj.vptree = this;
assert(q.nodeID.getType() == Node::ID::Leaf);
LeafNode *ln = (LeafNode*)getNode(q.nodeID);
insert(iobj, ln);
return;
}
void
DVPTree::insert(InsertContainer &iobj, LeafNode *leafNode)
{
LeafNode &leaf = *leafNode;
size_t fsize = leaf.getObjectSize();
if (fsize != 0) {
NGT::ObjectSpace::Comparator &comparator = objectSpace->getComparator();
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance d = comparator(iobj.object, leaf.getPivot(*objectSpace));
#else
Distance d = comparator(iobj.object, leaf.getPivot());
#endif
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::ObjectDistance *objects = leaf.getObjectIDs(leafNodes.allocator);
#else
NGT::ObjectDistance *objects = leaf.getObjectIDs();
#endif
for (size_t i = 0; i < fsize; i++) {
if (objects[i].distance == d) {
Distance idd = 0.0;
ObjectID loid;
try {
loid = objects[i].id;
idd = comparator(iobj.object, *getObjectRepository().get(loid));
} catch (Exception &e) {
stringstream msg;
msg << "LeafNode::insert: Cannot find object which belongs to a leaf node. id="
<< objects[i].id << ":" << e.what() << endl;
NGTThrowException(msg.str());
}
if (idd == 0.0) {
if (loid == iobj.id) {
stringstream msg;
msg << "DVPTree::insert:already existed. " << iobj.id;
NGTThrowException(msg);
}
return;
}
}
}
}
if (leaf.getObjectSize() >= leafObjectsSize) {
split(iobj, leaf);
} else {
insertObject(iobj, leaf);
}
return;
}
Node::ID
DVPTree::split(InsertContainer &iobj, LeafNode &leaf)
{
Node::Objects *fs = getObjects(leaf, iobj);
int pv = DVPTree::MaxVariance;
switch (splitMode) {
case DVPTree::MaxVariance:
pv = LeafNode::selectPivotByMaxVariance(iobj, *fs);
break;
case DVPTree::MaxDistance:
pv = LeafNode::selectPivotByMaxDistance(iobj, *fs);
break;
}
LeafNode::splitObjects(iobj, *fs, pv);
Node::ID nid = recombineNodes(iobj, *fs, leaf);
delete fs;
return nid;
}
Node::ID
DVPTree::recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf)
{
LeafNode *ln[internalChildrenSize];
Node::ID targetParent = leaf.parent;
Node::ID targetId = leaf.id;
ln[0] = &leaf;
ln[0]->objectSize = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
for (size_t i = 1; i < internalChildrenSize; i++) {
ln[i] = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
}
#else
for (size_t i = 1; i < internalChildrenSize; i++) {
ln[i] = new LeafNode;
}
#endif
InternalNode *in = createInternalNode();
Node::ID inid = in->id;
try {
if (targetParent.getID() != 0) {
InternalNode &pnode = *(InternalNode*)getNode(targetParent);
for (size_t i = 0; i < internalChildrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (pnode.getChildren(internalNodes.allocator)[i] == targetId) {
pnode.getChildren(internalNodes.allocator)[i] = inid;
#else
if (pnode.getChildren()[i] == targetId) {
pnode.getChildren()[i] = inid;
#endif
break;
}
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, internalNodes.allocator);
#else
in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
#endif
in->parent = targetParent;
int fsize = fs.size();
int cid = fs[0].clusterID;
#ifdef NGT_NODE_USE_VECTOR
LeafNode::ObjectIDs fid;
fid.id = fs[0].id;
fid.distance = 0.0;
ln[cid]->objectIDs.push_back(fid);
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize].id = fs[0].id;
ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize++].distance = 0.0;
#else
ln[cid]->getObjectIDs()[ln[cid]->objectSize].id = fs[0].id;
ln[cid]->getObjectIDs()[ln[cid]->objectSize++].distance = 0.0;
#endif
#endif
if (fs[0].leafDistance == Node::Object::Pivot) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator);
#else
ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
#endif
} else {
NGTThrowException("recombineNodes: internal error : illegal pivot.");
}
ln[cid]->parent = inid;
int maxClusterID = cid;
for (int i = 1; i < fsize; i++) {
int clusterID = fs[i].clusterID;
if (clusterID > maxClusterID) {
maxClusterID = clusterID;
}
Distance ld;
if (fs[i].leafDistance == Node::Object::Pivot) {
// pivot
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace, leafNodes.allocator);
#else
ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace);
#endif
ld = 0.0;
} else {
ld = fs[i].leafDistance;
}
#ifdef NGT_NODE_USE_VECTOR
fid.id = fs[i].id;
fid.distance = ld;
ln[clusterID]->objectIDs.push_back(fid);
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize].id = fs[i].id;
ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize++].distance = ld;
#else
ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize].id = fs[i].id;
ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize++].distance = ld;
#endif
#endif
ln[clusterID]->parent = inid;
if (clusterID != cid) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getBorders(internalNodes.allocator)[cid] = fs[i].distance;
#else
in->getBorders()[cid] = fs[i].distance;
#endif
cid = fs[i].clusterID;
}
}
// When the number of the children is less than the expected,
// proper values are set to the empty children.
for (size_t i = maxClusterID + 1; i < internalChildrenSize; i++) {
ln[i]->parent = inid;
// dummy
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator);
#else
ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
#endif
if (i < (internalChildrenSize - 1)) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getBorders(internalNodes.allocator)[i] = FLT_MAX;
#else
in->getBorders()[i] = FLT_MAX;
#endif
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getChildren(internalNodes.allocator)[0] = targetId;
#else
in->getChildren()[0] = targetId;
#endif
for (size_t i = 1; i < internalChildrenSize; i++) {
insertNode(ln[i]);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getChildren(internalNodes.allocator)[i] = ln[i]->id;
#else
in->getChildren()[i] = ln[i]->id;
#endif
}
} catch(Exception &e) {
throw e;
}
return inid;
}
void
DVPTree::insertObject(InsertContainer &ic, LeafNode &leaf) {
if (leaf.getObjectSize() == 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace, leafNodes.allocator);
#else
leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace);
#endif
#ifdef NGT_NODE_USE_VECTOR
LeafNode::ObjectIDs fid;
fid.id = ic.id;
fid.distance = 0;
leaf.objectIDs.push_back(fid);
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id;
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = 0;
#else
leaf.getObjectIDs()[leaf.objectSize].id = ic.id;
leaf.getObjectIDs()[leaf.objectSize++].distance = 0;
#endif
#endif
} else {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot(*objectSpace));
#else
Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot());
#endif
#ifdef NGT_NODE_USE_VECTOR
LeafNode::ObjectIDs fid;
fid.id = ic.id;
fid.distance = d;
leaf.objectIDs.push_back(fid);
std::sort(leaf.objectIDs.begin(), leaf.objectIDs.end(), LeafNode::ObjectIDs());
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id;
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = d;
#else
leaf.getObjectIDs()[leaf.objectSize].id = ic.id;
leaf.getObjectIDs()[leaf.objectSize++].distance = d;
#endif
#endif
}
}
Node::Objects *
DVPTree::getObjects(LeafNode &n, Container &iobj)
{
int size = n.getObjectSize() + 1;
Node::Objects *fs = new Node::Objects(size);
for (size_t i = 0; i < n.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
(*fs)[i].object = getObjectRepository().get(n.getObjectIDs(leafNodes.allocator)[i].id);
(*fs)[i].id = n.getObjectIDs(leafNodes.allocator)[i].id;
#else
(*fs)[i].object = getObjectRepository().get(n.getObjectIDs()[i].id);
(*fs)[i].id = n.getObjectIDs()[i].id;
#endif
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
(*fs)[n.getObjectSize()].object = getObjectRepository().get(iobj.id);
#else
(*fs)[n.getObjectSize()].object = &iobj.object;
#endif
(*fs)[n.getObjectSize()].id = iobj.id;
return fs;
}
void
DVPTree::removeEmptyNodes(InternalNode &inode) {
int csize = internalChildrenSize;
InternalNode *target = &inode;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Node::ID *children = target->getChildren(internalNodes.allocator);
#else
Node::ID *children = target->getChildren();
#endif
for(;;) {
for (int i = 0; i < csize; i++) {
if (children[i].getType() == Node::ID::Internal) {
return;
}
LeafNode &ln = *static_cast<LeafNode*>(getNode(children[i]));
if (ln.getObjectSize() != 0) {
return;
}
}
for (int i = 0; i < csize; i++) {
removeNode(children[i]);
}
if (target->parent.getID() == 0) {
removeNode(target->id);
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
LeafNode *root = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
#else
LeafNode *root = new LeafNode;
#endif
insertNode(root);
if (root->id.getID() != 1) {
NGTThrowException("Root id Error");
}
return;
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
LeafNode *ln = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
#else
LeafNode *ln = new LeafNode;
#endif
ln->parent = target->parent;
insertNode(ln);
InternalNode &in = *(InternalNode*)getNode(ln->parent);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in.updateChild(*this, target->id, ln->id, internalNodes.allocator);
#else
in.updateChild(*this, target->id, ln->id);
#endif
removeNode(target->id);
target = &in;
}
return;
}
void
DVPTree::search(SearchContainer &sc, InternalNode &node, UncheckedNode &uncheckedNode)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance d = objectSpace->getComparator()(sc.object, node.getPivot(*objectSpace));
#else
Distance d = objectSpace->getComparator()(sc.object, node.getPivot());
#endif
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
sc.distanceComputationCount++;
#endif
int bsize = internalChildrenSize - 1;
vector<ObjectDistance> regions;
regions.reserve(internalChildrenSize);
ObjectDistance child;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance *borders = node.getBorders(internalNodes.allocator);
#else
Distance *borders = node.getBorders();
#endif
int mid;
for (mid = 0; mid < bsize; mid++) {
if (d < borders[mid]) {
child.id = mid;
child.distance = 0.0;
regions.push_back(child);
if (d + sc.radius < borders[mid]) {
break;
} else {
continue;
}
} else {
if (d < borders[mid] + sc.radius) {
child.id = mid;
child.distance = d - borders[mid];
regions.push_back(child);
continue;
} else {
continue;
}
}
}
if (mid == bsize) {
if (d >= borders[mid - 1]) {
child.id = mid;
child.distance = 0.0;
regions.push_back(child);
} else {
child.id = mid;
child.distance = borders[mid - 1] - d;
regions.push_back(child);
}
}
sort(regions.begin(), regions.end());
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Node::ID *children = node.getChildren(internalNodes.allocator);
#else
Node::ID *children = node.getChildren();
#endif
vector<ObjectDistance>::iterator i;
if (sc.mode == DVPTree::SearchContainer::SearchLeaf) {
if (children[regions.front().id].getType() == Node::ID::Leaf) {
sc.nodeID.setRaw(children[regions.front().id].get());
assert(uncheckedNode.empty());
} else {
uncheckedNode.push(children[regions.front().id]);
}
} else {
for (i = regions.begin(); i != regions.end(); i++) {
uncheckedNode.push(children[i->id]);
}
}
}
void
DVPTree::search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode)
{
DVPTree::SearchContainer &q = (DVPTree::SearchContainer&)so;
if (node.getObjectSize() == 0) {
return;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance pq = objectSpace->getComparator()(q.object, node.getPivot(*objectSpace));
#else
Distance pq = objectSpace->getComparator()(q.object, node.getPivot());
#endif
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
so.distanceComputationCount++;
#endif
ObjectDistance r;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::ObjectDistance *objects = node.getObjectIDs(leafNodes.allocator);
#else
NGT::ObjectDistance *objects = node.getObjectIDs();
#endif
for (size_t i = 0; i < node.getObjectSize(); i++) {
if ((objects[i].distance <= pq + q.radius) &&
(objects[i].distance >= pq - q.radius)) {
Distance d = 0;
try {
d = objectSpace->getComparator()(q.object, *q.vptree->getObjectRepository().get(objects[i].id));
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
so.distanceComputationCount++;
#endif
} catch(...) {
NGTThrowException("VpTree::LeafNode::search: Internal fatal error : Cannot get object");
}
if (d <= q.radius) {
r.id = objects[i].id;
r.distance = d;
so.getResult().push_back(r);
std::sort(so.getResult().begin(), so.getResult().end());
if (so.getResult().size() > q.size) {
so.getResult().resize(q.size);
}
}
}
}
}
void
DVPTree::search(SearchContainer &sc) {
((SearchContainer&)sc).vptree = this;
Node *root = getRootNode();
assert(root != 0);
if (sc.mode == DVPTree::SearchContainer::SearchLeaf) {
if (root->id.getType() == Node::ID::Leaf) {
sc.nodeID.setRaw(root->id.get());
return;
}
}
UncheckedNode uncheckedNode;
uncheckedNode.push(root->id);
while (!uncheckedNode.empty()) {
Node::ID nodeid = uncheckedNode.top();
uncheckedNode.pop();
Node *cnode = getNode(nodeid);
if (cnode == 0) {
cerr << "Error! child node is null. but continue." << endl;
continue;
}
if (cnode->id.getType() == Node::ID::Internal) {
search(sc, (InternalNode&)*cnode, uncheckedNode);
} else if (cnode->id.getType() == Node::ID::Leaf) {
search(sc, (LeafNode&)*cnode, uncheckedNode);
} else {
cerr << "Tree: Inner fatal error!: Node type error!" << endl;
abort();
}
}
}

View File

@ -0,0 +1,511 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Common.h"
#include "NGT/Node.h"
#include "NGT/defines.h"
#include "faiss/utils/ConcurrentBitset.h"
#include <sstream>
#include <string>
#include <vector>
#include <stack>
#include <set>
namespace NGT {
class DVPTree {
public:
enum SplitMode {
MaxDistance = 0,
MaxVariance = 1
};
typedef std::vector<Node::ID> IDVector;
class Container : public NGT::Container {
public:
Container(Object &f, ObjectID i):NGT::Container(f, i) {}
DVPTree *vptree;
};
class SearchContainer : public NGT::SearchContainer {
public:
enum Mode {
SearchLeaf = 0,
SearchObject = 1
};
SearchContainer(Object &f, ObjectID i):NGT::SearchContainer(f, i) {}
SearchContainer(Object &f):NGT::SearchContainer(f, 0) {}
DVPTree *vptree;
Mode mode;
Node::ID nodeID;
};
class InsertContainer : public Container {
public:
InsertContainer(Object &f, ObjectID i):Container(f, i) {}
};
class RemoveContainer : public Container {
public:
RemoveContainer(Object &f, ObjectID i):Container(f, i) {}
};
DVPTree() {
leafObjectsSize = LeafNode::LeafObjectsSizeMax;
internalChildrenSize = InternalNode::InternalChildrenSizeMax;
splitMode = MaxVariance;
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
insertNode(new LeafNode);
#endif
}
virtual ~DVPTree() {
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
deleteAll();
#endif
}
void deleteAll() {
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leafNodes[i]->deletePivot(*objectSpace, leafNodes.allocator);
#else
leafNodes[i]->deletePivot(*objectSpace);
#endif
delete leafNodes[i];
}
}
leafNodes.clear();
for (size_t i = 0; i < internalNodes.size(); i++) {
if (internalNodes[i] != 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
internalNodes[i]->deletePivot(*objectSpace, internalNodes.allocator);
#else
internalNodes[i]->deletePivot(*objectSpace);
#endif
delete internalNodes[i];
}
}
internalNodes.clear();
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void open(const std::string &f, size_t sharedMemorySize) {
// If no file, then create a new file.
leafNodes.open(f + "l", sharedMemorySize);
internalNodes.open(f + "i", sharedMemorySize);
if (leafNodes.size() == 0) {
if (internalNodes.size() != 0) {
NGTThrowException("Tree::Open: Internal error. Internal and leaf are inconsistent.");
}
LeafNode *ln = leafNodes.allocate();
insertNode(ln);
}
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
void insert(InsertContainer &iobj);
void insert(InsertContainer &iobj, LeafNode *n);
Node::ID split(InsertContainer &iobj, LeafNode &leaf);
Node::ID recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf);
void insertObject(InsertContainer &obj, LeafNode &leaf);
typedef std::stack<Node::ID> UncheckedNode;
void search(SearchContainer &so);
void search(SearchContainer &so, InternalNode &node, UncheckedNode &uncheckedNode);
void search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode);
bool searchObject(ObjectID id) {
LeafNode &ln = getLeaf(id);
for (size_t i = 0; i < ln.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (ln.getObjectIDs(leafNodes.allocator)[i].id == id) {
#else
if (ln.getObjectIDs()[i].id == id) {
#endif
return true;
}
}
return false;
}
LeafNode &getLeaf(ObjectID id) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Object *qobject = objectSpace->allocateObject(*getObjectRepository().get(id));
SearchContainer q(*qobject);
#else
SearchContainer q(*getObjectRepository().get(id));
#endif
q.mode = SearchContainer::SearchLeaf;
q.vptree = this;
q.radius = 0.0;
q.size = 1;
search(q);
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
objectSpace->deleteObject(qobject);
#endif
return *(LeafNode*)getNode(q.nodeID);
}
void replace(ObjectID id, ObjectID replacedId) { remove(id, replacedId); }
// remove the specified object.
void remove(ObjectID id, ObjectID replaceId = 0) {
LeafNode &ln = getLeaf(id);
try {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln.removeObject(id, replaceId, leafNodes.allocator);
#else
ln.removeObject(id, replaceId);
#endif
} catch(Exception &err) {
std::stringstream msg;
msg << "VpTree::remove: Inner error. Cannot remove object. leafNode=" << ln.id.getID() << ":" << err.what();
NGTThrowException(msg);
}
if (ln.getObjectSize() == 0) {
if (ln.parent.getID() != 0) {
InternalNode &inode = *(InternalNode*)getNode(ln.parent);
removeEmptyNodes(inode);
}
}
return;
}
void removeNaively(ObjectID id, ObjectID replaceId = 0) {
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
try {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leafNodes[i]->removeObject(id, replaceId, leafNodes.allocator);
#else
leafNodes[i]->removeObject(id, replaceId);
#endif
break;
} catch(...) {}
}
}
}
Node *getRootNode() {
size_t nid = 1;
Node *root;
try {
root = internalNodes.get(nid);
} catch(Exception &err) {
try {
root = leafNodes.get(nid);
} catch(Exception &e) {
std::stringstream msg;
msg << "VpTree::getRootNode: Inner error. Cannot get a leaf root node. " << nid << ":" << e.what();
NGTThrowException(msg);
}
}
return root;
}
InternalNode *createInternalNode() {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
InternalNode *n = new(internalNodes.allocator) InternalNode(internalChildrenSize, internalNodes.allocator);
#else
InternalNode *n = new InternalNode(internalChildrenSize);
#endif
insertNode(n);
return n;
}
void
removeNode(Node::ID id) {
size_t idx = id.getID();
if (id.getType() == Node::ID::Leaf) {
leafNodes.remove(idx);
} else {
internalNodes.remove(idx);
}
}
void removeEmptyNodes(InternalNode &node);
Node::Objects * getObjects(LeafNode &n, Container &iobj);
// for milvus
void
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, faiss::ConcurrentBitsetPtr& bitset) {
LeafNode& ln = *(LeafNode*)getNode(nid);
rl.clear();
ObjectDistance r;
for (size_t i = 0; i < ln.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
r.id = ln.getObjectIDs(leafNodes.allocator)[i].id;
r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance;
#else
r.id = ln.getObjectIDs()[i].id;
r.distance = ln.getObjectIDs()[i].distance;
#endif
if (bitset != nullptr && bitset->test(r.id - 1)) {
continue;
}
rl.push_back(r);
}
return;
}
void getObjectIDsFromLeaf(Node::ID nid, ObjectDistances &rl) {
LeafNode &ln = *(LeafNode*)getNode(nid);
rl.clear();
ObjectDistance r;
for (size_t i = 0; i < ln.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
r.id = ln.getObjectIDs(leafNodes.allocator)[i].id;
r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance;
#else
r.id = ln.getObjectIDs()[i].id;
r.distance = ln.getObjectIDs()[i].distance;
#endif
rl.push_back(r);
}
return;
}
void
insertNode(LeafNode *n) {
size_t id = leafNodes.insert(n);
n->id.setID(id);
n->id.setType(Node::ID::Leaf);
}
// replace
void replaceNode(LeafNode *n) {
int id = n->id.getID();
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
leafNodes.set(id, n);
#else
leafNodes[id] = n;
#endif
}
void
insertNode(InternalNode *n)
{
size_t id = internalNodes.insert(n);
n->id.setID(id);
n->id.setType(Node::ID::Internal);
}
Node *getNode(Node::ID &id) {
Node *n = 0;
Node::NodeID idx = id.getID();
if (id.getType() == Node::ID::Leaf) {
n = leafNodes.get(idx);
} else {
n = internalNodes.get(idx);
}
return n;
}
void getAllLeafNodeIDs(std::vector<Node::ID> &leafIDs) {
leafIDs.clear();
Node *root = getRootNode();
if (root->id.getType() == Node::ID::Leaf) {
leafIDs.push_back(root->id);
return;
}
UncheckedNode uncheckedNode;
uncheckedNode.push(root->id);
while (!uncheckedNode.empty()) {
Node::ID nodeid = uncheckedNode.top();
uncheckedNode.pop();
Node *cnode = getNode(nodeid);
if (cnode->id.getType() == Node::ID::Internal) {
InternalNode &inode = static_cast<InternalNode&>(*cnode);
for (size_t ci = 0; ci < internalChildrenSize; ci++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
uncheckedNode.push(inode.getChildren(internalNodes.allocator)[ci]);
#else
uncheckedNode.push(inode.getChildren()[ci]);
#endif
}
} else if (cnode->id.getType() == Node::ID::Leaf) {
leafIDs.push_back(static_cast<LeafNode&>(*cnode).id);
} else {
std::cerr << "Tree: Inner fatal error!: Node type error!" << std::endl;
abort();
}
}
}
// for milvus
void serialize(std::stringstream & os)
{
leafNodes.serialize(os, objectSpace);
internalNodes.serialize(os, objectSpace);
}
void serialize(std::ofstream &os) {
leafNodes.serialize(os, objectSpace);
internalNodes.serialize(os, objectSpace);
}
void deserialize(std::ifstream &is) {
leafNodes.deserialize(is, objectSpace);
internalNodes.deserialize(is, objectSpace);
}
void deserialize(std::stringstream & is)
{
leafNodes.deserialize(is, objectSpace);
internalNodes.deserialize(is, objectSpace);
}
void serializeAsText(std::ofstream &os) {
leafNodes.serializeAsText(os, objectSpace);
internalNodes.serializeAsText(os, objectSpace);
}
void deserializeAsText(std::ifstream &is) {
leafNodes.deserializeAsText(is, objectSpace);
internalNodes.deserializeAsText(is, objectSpace);
}
void show() {
std::cout << "Show tree " << std::endl;
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
std::cout << i << ":";
(*leafNodes[i]).show();
}
}
for (size_t i = 0; i < internalNodes.size(); i++) {
if (internalNodes[i] != 0) {
std::cout << i << ":";
(*internalNodes[i]).show();
}
}
}
bool verify(size_t objCount, std::vector<uint8_t> &status) {
std::cerr << "Started verifying internal nodes. size=" << internalNodes.size() << "..." << std::endl;
bool valid = true;
for (size_t i = 0; i < internalNodes.size(); i++) {
if (internalNodes[i] != 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes, internalNodes.allocator);
#else
valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes);
#endif
}
}
std::cerr << "Started verifying leaf nodes. size=" << leafNodes.size() << " ..." << std::endl;
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
valid = valid && (*leafNodes[i]).verify(objCount, status, leafNodes.allocator);
#else
valid = valid && (*leafNodes[i]).verify(objCount, status);
#endif
}
}
return valid;
}
void deleteInMemory() {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
assert(0);
#else
for (std::vector<NGT::LeafNode*>::iterator i = leafNodes.begin(); i != leafNodes.end(); i++) {
if ((*i) != 0) {
delete (*i);
}
}
leafNodes.clear();
for (std::vector<NGT::InternalNode*>::iterator i = internalNodes.begin(); i != internalNodes.end(); i++) {
if ((*i) != 0) {
delete (*i);
}
}
internalNodes.clear();
#endif
}
ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); }
size_t getSharedMemorySize(std::ostream &os, SharedMemoryAllocator::GetMemorySizeType t) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
size_t isize = internalNodes.getAllocator().getMemorySize(t);
os << "internal=" << isize << std::endl;
size_t lsize = leafNodes.getAllocator().getMemorySize(t);
os << "leaf=" << lsize << std::endl;
return isize + lsize;
#else
return 0;
#endif
}
void getAllObjectIDs(std::set<ObjectID> &ids) {
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
LeafNode &ln = *leafNodes[i];
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
auto objs = ln.getObjectIDs(leafNodes.allocator);
#else
auto objs = ln.getObjectIDs();
#endif
for (size_t idx = 0; idx < ln.objectSize; ++idx) {
ids.insert(objs[idx].id);
}
}
}
}
public:
size_t internalChildrenSize;
size_t leafObjectsSize;
SplitMode splitMode;
std::string name;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
PersistentRepository<LeafNode> leafNodes;
PersistentRepository<InternalNode> internalNodes;
#else
Repository<LeafNode> leafNodes;
Repository<InternalNode> internalNodes;
#endif
ObjectSpace *objectSpace;
};
} // namespace DVPTree

View File

@ -0,0 +1,58 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/Version.h"
void
NGT::Version::get(std::ostream &os)
{
os << " Version:" << NGT::Version::getVersion() << std::endl;
os << " Built date:" << NGT::Version::getBuildDate() << std::endl;
os << " The last git tag:" << Version::getGitTag() << std::endl;
os << " The last git commit hash:" << Version::getGitHash() << std::endl;
os << " The last git commit date:" << Version::getGitDate() << std::endl;
}
const std::string
NGT::Version::getVersion()
{
return NGT_VERSION;
}
const std::string
NGT::Version::getBuildDate()
{
return NGT_BUILD_DATE;
}
const std::string
NGT::Version::getGitHash()
{
return NGT_GIT_HASH;
}
const std::string
NGT::Version::getGitDate()
{
return NGT_GIT_DATE;
}
const std::string
NGT::Version::getGitTag()
{
return NGT_GIT_TAG;
}

View File

@ -0,0 +1,61 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <ostream>
#include <string>
#ifndef NGT_VERSION
#define NGT_VERSION "-"
#endif
#ifndef NGT_BUILD_DATE
#define NGT_BUILD_DATE "-"
#endif
#ifndef NGT_GIT_HASH
#define NGT_GIT_HASH "-"
#endif
#ifndef NGT_GIT_DATE
#define NGT_GIT_DATE "-"
#endif
#ifndef NGT_GIT_TAG
#define NGT_GIT_TAG "-"
#endif
namespace NGT {
class Version {
public:
static void
get(std::ostream& os);
static const std::string
getVersion();
static const std::string
getBuildDate();
static const std::string
getGitHash();
static const std::string
getGitDate();
static const std::string
getGitTag();
static const std::string
get();
};
}; // namespace NGT
#ifdef NGT_VERSION_FOR_HEADER
#include "Version.cpp"
#endif

View File

@ -0,0 +1,60 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
// Begin of cmake defines
#if 0
#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes
#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed
#cmakedefine NGT_AVX_DISABLED // not use avx to compare
#cmakedefine NGT_LARGE_DATASET // more than 10M objects
#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations
#endif
// End of cmake defines
//////////////////////////////////////////////////////////////////////////
// Release Definitions for OSS
//#define NGT_DISTANCE_COMPUTATION_COUNT
#define NGT_CREATION_EDGE_SIZE 10
#define NGT_EXPLORATION_COEFFICIENT 1.1
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB
#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them.
#define NGT_COMPACT_VECTOR
#define NGT_GRAPH_READ_ONLY_GRAPH
#ifdef NGT_LARGE_DATASET
#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET
#else
#define NGT_GRAPH_CHECK_VECTOR
#endif
#if defined(NGT_AVX_DISABLED)
#define NGT_NO_AVX
#else
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#define NGT_AVX512
#elif defined(__AVX2__)
#define NGT_AVX2
#else
#define NGT_NO_AVX
#endif
#endif

View File

@ -0,0 +1,58 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
// Begin of cmake defines
#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes
#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed
#cmakedefine NGT_AVX_DISABLED // not use avx to compare
#cmakedefine NGT_LARGE_DATASET // more than 10M objects
#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations
// End of cmake defines
//////////////////////////////////////////////////////////////////////////
// Release Definitions for OSS
//#define NGT_DISTANCE_COMPUTATION_COUNT
#define NGT_CREATION_EDGE_SIZE 10
#define NGT_EXPLORATION_COEFFICIENT 1.1
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB
#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them.
#define NGT_COMPACT_VECTOR
#define NGT_GRAPH_READ_ONLY_GRAPH
#ifdef NGT_LARGE_DATASET
#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET
#else
#define NGT_GRAPH_CHECK_VECTOR
#endif
#if defined(NGT_AVX_DISABLED)
#define NGT_NO_AVX
#else
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#define NGT_AVX512
#elif defined(__AVX2__)
#define NGT_AVX2
#else
#define NGT_NO_AVX
#endif
#endif

View File

@ -1,11 +1,13 @@
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib)
include_directories(${INDEX_SOURCE_DIR}/knowhere)
include_directories(${INDEX_SOURCE_DIR})
set(depend_libs
gtest gmock gtest_main gmock_main
faiss fiu
ngt
)
if (FAISS_WITH_MKL)
set(depend_libs ${depend_libs}
@ -268,3 +270,28 @@ install(TARGETS test_structured_index_sort DESTINATION unittest)
#add_subdirectory(faiss_benchmark)
#add_subdirectory(metric_alg_benchmark)
################################################################################
#<NGTPANNG-TEST>
set(ngtpanng_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGTPANNG.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGT.cpp
)
if (NOT TARGET test_ngtpanng)
add_executable(test_ngtpanng test_ngtpanng.cpp ${ngtpanng_srcs} ${util_srcs})
endif ()
target_link_libraries(test_ngtpanng ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_ngtpanng DESTINATION unittest)
################################################################################
#<NGTPANNG-TEST>
set(ngtonng_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGTONNG.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGT.cpp
)
if (NOT TARGET test_ngtonng)
add_executable(test_ngtonng test_ngtonng.cpp ${ngtonng_srcs} ${util_srcs})
endif ()
target_link_libraries(test_ngtonng ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_ngtonng DESTINATION unittest)
################################################################################

View File

@ -0,0 +1,146 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <sstream>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexNGTONNG.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class NGTONNGTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
Generate(128, 10000, 10);
index_ = std::make_shared<milvus::knowhere::IndexNGTONNG>();
conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
}
protected:
milvus::knowhere::Config conf;
std::shared_ptr<milvus::knowhere::IndexNGTONNG> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(NGTONNGParameters, NGTONNGTest, Values("NGTONNG"));
TEST_P(NGTONNGTest, ngtonng_basic) {
assert(!xb.empty());
// null index
{
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Add(base_dataset, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
ASSERT_ANY_THROW(index_->Count());
ASSERT_ANY_THROW(index_->Dim());
}
index_->BuildAll(base_dataset, conf); // Train + Add
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k);
}
TEST_P(NGTONNGTest, ngtonng_delete) {
assert(!xb.empty());
index_->BuildAll(base_dataset, conf); // Train + Add
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++i) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
}
TEST_P(NGTONNGTest, ngtonng_serialize) {
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
{
// write and flush
FileIOWriter writer(filename);
writer(static_cast<void*>(bin->data.get()), bin->size);
}
FileIOReader reader(filename);
reader(ret, bin->size);
};
{
// serialize index
index_->BuildAll(base_dataset, conf);
auto binaryset = index_->Serialize(milvus::knowhere::Config());
auto bin_obj_data = binaryset.GetByName("ngt_obj_data");
std::string filename1 = "/tmp/ngt_obj_data_serialize.bin";
auto load_data1 = new uint8_t[bin_obj_data->size];
serialize(filename1, bin_obj_data, load_data1);
auto bin_grp_data = binaryset.GetByName("ngt_grp_data");
std::string filename2 = "/tmp/ngt_grp_data_serialize.bin";
auto load_data2 = new uint8_t[bin_grp_data->size];
serialize(filename2, bin_grp_data, load_data2);
auto bin_prf_data = binaryset.GetByName("ngt_prf_data");
std::string filename3 = "/tmp/ngt_prf_data_serialize.bin";
auto load_data3 = new uint8_t[bin_prf_data->size];
serialize(filename3, bin_prf_data, load_data3);
auto bin_tre_data = binaryset.GetByName("ngt_tre_data");
std::string filename4 = "/tmp/ngt_tre_data_serialize.bin";
auto load_data4 = new uint8_t[bin_tre_data->size];
serialize(filename4, bin_tre_data, load_data4);
binaryset.clear();
std::shared_ptr<uint8_t[]> obj_data(load_data1);
binaryset.Append("ngt_obj_data", obj_data, bin_obj_data->size);
std::shared_ptr<uint8_t[]> grp_data(load_data2);
binaryset.Append("ngt_grp_data", grp_data, bin_grp_data->size);
std::shared_ptr<uint8_t[]> prf_data(load_data3);
binaryset.Append("ngt_prf_data", prf_data, bin_prf_data->size);
std::shared_ptr<uint8_t[]> tre_data(load_data4);
binaryset.Append("ngt_tre_data", tre_data, bin_tre_data->size);
index_->Load(binaryset);
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -0,0 +1,146 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <sstream>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class NGTPANNGTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
Generate(128, 10000, 10);
index_ = std::make_shared<milvus::knowhere::IndexNGTPANNG>();
conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
}
protected:
milvus::knowhere::Config conf;
std::shared_ptr<milvus::knowhere::IndexNGTPANNG> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(NGTPANNGParameters, NGTPANNGTest, Values("NGTPANNG"));
TEST_P(NGTPANNGTest, ngtpanng_basic) {
assert(!xb.empty());
// null index
{
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Add(base_dataset, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
ASSERT_ANY_THROW(index_->Count());
ASSERT_ANY_THROW(index_->Dim());
}
index_->BuildAll(base_dataset, conf);
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k);
}
TEST_P(NGTPANNGTest, ngtpanng_delete) {
assert(!xb.empty());
index_->BuildAll(base_dataset, conf);
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++i) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
}
TEST_P(NGTPANNGTest, ngtpanng_serialize) {
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
{
// write and flush
FileIOWriter writer(filename);
writer(static_cast<void*>(bin->data.get()), bin->size);
}
FileIOReader reader(filename);
reader(ret, bin->size);
};
{
// serialize index
index_->BuildAll(base_dataset, conf);
auto binaryset = index_->Serialize(milvus::knowhere::Config());
auto bin_obj_data = binaryset.GetByName("ngt_obj_data");
std::string filename1 = "/tmp/ngt_obj_data_serialize.bin";
auto load_data1 = new uint8_t[bin_obj_data->size];
serialize(filename1, bin_obj_data, load_data1);
auto bin_grp_data = binaryset.GetByName("ngt_grp_data");
std::string filename2 = "/tmp/ngt_grp_data_serialize.bin";
auto load_data2 = new uint8_t[bin_grp_data->size];
serialize(filename2, bin_grp_data, load_data2);
auto bin_prf_data = binaryset.GetByName("ngt_prf_data");
std::string filename3 = "/tmp/ngt_prf_data_serialize.bin";
auto load_data3 = new uint8_t[bin_prf_data->size];
serialize(filename3, bin_prf_data, load_data3);
auto bin_tre_data = binaryset.GetByName("ngt_tre_data");
std::string filename4 = "/tmp/ngt_tre_data_serialize.bin";
auto load_data4 = new uint8_t[bin_tre_data->size];
serialize(filename4, bin_tre_data, load_data4);
binaryset.clear();
std::shared_ptr<uint8_t[]> obj_data(load_data1);
binaryset.Append("ngt_obj_data", obj_data, bin_obj_data->size);
std::shared_ptr<uint8_t[]> grp_data(load_data2);
binaryset.Append("ngt_grp_data", grp_data, bin_grp_data->size);
std::shared_ptr<uint8_t[]> prf_data(load_data3);
binaryset.Append("ngt_prf_data", prf_data, bin_prf_data->size);
std::shared_ptr<uint8_t[]> tre_data(load_data4);
binaryset.Append("ngt_tre_data", tre_data, bin_tre_data->size);
index_->Load(binaryset);
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -191,6 +191,8 @@ ValidateIndexType(std::string& index_type) {
knowhere::IndexEnum::INDEX_RHNSWFlat,
knowhere::IndexEnum::INDEX_RHNSWPQ,
knowhere::IndexEnum::INDEX_RHNSWSQ,
knowhere::IndexEnum::INDEX_NGTPANNG,
knowhere::IndexEnum::INDEX_NGTONNG,
// structured index names
engine::DEFAULT_STRUCTURED_INDEX,

View File

@ -26,6 +26,8 @@ const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY";
const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSW_FLAT";
const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSW_PQ";
const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSW_SQ8";
const char* NAME_ENGINE_TYPE_NGTPANNG = "NGTPANNG";
const char* NAME_ENGINE_TYPE_NGTONNG = "NGTONNG";
const char* NAME_METRIC_TYPE_L2 = "L2";
const char* NAME_METRIC_TYPE_IP = "IP";

View File

@ -29,6 +29,8 @@ extern const char* NAME_ENGINE_TYPE_ANNOY;
extern const char* NAME_ENGINE_TYPE_RHNSWFLAT;
extern const char* NAME_ENGINE_TYPE_RHNSWPQ;
extern const char* NAME_ENGINE_TYPE_RHNSWSQ;
extern const char* NAME_ENGINE_TYPE_NGTPANNG;
extern const char* NAME_ENGINE_TYPE_NGTONNG;
extern const char* NAME_METRIC_TYPE_L2;
extern const char* NAME_METRIC_TYPE_IP;

View File

@ -263,7 +263,7 @@ void
ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) {
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
JSON json_params = {{"index_type", "IVF_FLAT"}, {"metric_type", "L2"}, {"params", {{"nlist", nlist}}}};
JSON json_params = {{"index_type", "NGT_PANNG"}, {"metric_type", "L2"}};
milvus::IndexParam index1 = {collection_name, "field_vec", json_params.dump()};
milvus_sdk::Utils::PrintIndexParam(index1);
milvus::Status stat = conn_->CreateIndex(index1);

View File

@ -123,6 +123,10 @@ Utils::IndexTypeName(const milvus::IndexType& index_type) {
return "RHNSWPQ";
case milvus::IndexType::ANNOY:
return "ANNOY";
case milvus::IndexType::NGTPANNG:
return "NGTPANNG";
case milvus::IndexType::NGTONNG:
return "NGTONNG";
default:
return "Unknown index type";
}

View File

@ -43,6 +43,8 @@ enum class IndexType {
RHNSWFLAT = 13,
RHNSWPQ = 14,
RHNSWSQ = 15,
NGTPANNG = 16,
NGTONNG = 17,
};
enum class MetricType {