mirror of https://github.com/milvus-io/milvus.git
add NGT index (#3555)
* add NGT index Signed-off-by: fenglv <fenglv15@mails.ucas.ac.cn>pull/3676/head
parent
16169bc7a4
commit
f70d766475
|
@ -23,3 +23,4 @@
|
|||
| hnswlib | [Apache 2.0](https://github.com/nmslib/hnswlib/blob/master/LICENSE) |
|
||||
| annoy | [Apache 2.0](https://github.com/spotify/annoy/blob/master/LICENSE) |
|
||||
| crc32c | [BSD 3-Clause](https://github.com/google/crc32c/blob/master/LICENSE) |
|
||||
| NGT | [Apache 2.0](https://github.com/yahoojapan/NGT/blob/master/LICENSE) |
|
||||
|
|
|
@ -653,3 +653,5 @@ if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep)
|
|||
include_directories(SYSTEM "${FAISS_INCLUDE_DIR}")
|
||||
link_directories(SYSTEM ${FAISS_PREFIX}/lib/)
|
||||
endif ()
|
||||
|
||||
add_subdirectory(thirdparty/NGT)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
include_directories(${INDEX_SOURCE_DIR}/knowhere)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib)
|
||||
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
|
||||
|
@ -68,6 +69,9 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/IndexRHNSWFlat.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWSQ.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWPQ.cpp
|
||||
knowhere/index/vector_index/IndexNGT.cpp
|
||||
knowhere/index/vector_index/IndexNGTPANNG.cpp
|
||||
knowhere/index/vector_index/IndexNGTONNG.cpp
|
||||
)
|
||||
|
||||
set(vector_offset_index_srcs
|
||||
|
@ -91,6 +95,7 @@ set(depend_libs
|
|||
gfortran
|
||||
pthread
|
||||
fiu
|
||||
ngt
|
||||
)
|
||||
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
|
|
|
@ -37,6 +37,8 @@ const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
|
|||
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
|
||||
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
|
||||
const char* INDEX_ANNOY = "ANNOY";
|
||||
const char* INDEX_NGTPANNG = "NGT_PANNG";
|
||||
const char* INDEX_NGTONNG = "NGT_ONNG";
|
||||
} // namespace IndexEnum
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -64,6 +64,8 @@ extern const char* INDEX_RHNSWFlat;
|
|||
extern const char* INDEX_RHNSWPQ;
|
||||
extern const char* INDEX_RHNSWSQ;
|
||||
extern const char* INDEX_ANNOY;
|
||||
extern const char* INDEX_NGTPANNG;
|
||||
extern const char* INDEX_NGTONNG;
|
||||
} // namespace IndexEnum
|
||||
|
||||
enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 };
|
||||
|
|
|
@ -389,5 +389,37 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM
|
|||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -126,5 +126,24 @@ class RHNSWSQConfAdapter : public ConfAdapter {
|
|||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class NGTPANNGConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class NGTONNGConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -53,6 +53,8 @@ AdapterMgr::RegisterAdapter() {
|
|||
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
|
||||
REGISTER_CONF_ADAPTER(NGTPANNGConfAdapter, IndexEnum::INDEX_NGTPANNG, ngtpanng_adapter);
|
||||
REGISTER_CONF_ADAPTER(NGTONNGConfAdapter, IndexEnum::INDEX_NGTONNG, ngtonng_adapter);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGT.h"
|
||||
|
||||
#include <omp.h>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
IndexNGT::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
std::stringstream obj, grp, prf, tre;
|
||||
index_->saveIndex(obj, grp, prf, tre);
|
||||
|
||||
auto obj_str = obj.str();
|
||||
auto grp_str = grp.str();
|
||||
auto prf_str = prf.str();
|
||||
auto tre_str = tre.str();
|
||||
uint64_t obj_size = obj_str.size();
|
||||
uint64_t grp_size = grp_str.size();
|
||||
uint64_t prf_size = prf_str.size();
|
||||
uint64_t tre_size = tre_str.size();
|
||||
|
||||
std::shared_ptr<uint8_t[]> obj_data(new uint8_t[obj_size]);
|
||||
memcpy(obj_data.get(), obj_str.data(), obj_size);
|
||||
std::shared_ptr<uint8_t[]> grp_data(new uint8_t[grp_size]);
|
||||
memcpy(grp_data.get(), grp_str.data(), grp_size);
|
||||
std::shared_ptr<uint8_t[]> prf_data(new uint8_t[prf_size]);
|
||||
memcpy(prf_data.get(), prf_str.data(), prf_size);
|
||||
std::shared_ptr<uint8_t[]> tre_data(new uint8_t[tre_size]);
|
||||
memcpy(tre_data.get(), tre_str.data(), tre_size);
|
||||
|
||||
BinarySet res_set;
|
||||
res_set.Append("ngt_obj_data", obj_data, obj_size);
|
||||
res_set.Append("ngt_grp_data", grp_data, grp_size);
|
||||
res_set.Append("ngt_prf_data", prf_data, prf_size);
|
||||
res_set.Append("ngt_tre_data", tre_data, tre_size);
|
||||
return res_set;
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::Load(const BinarySet& index_binary) {
|
||||
auto obj_data = index_binary.GetByName("ngt_obj_data");
|
||||
std::string obj_str(reinterpret_cast<char*>(obj_data->data.get()), obj_data->size);
|
||||
|
||||
auto grp_data = index_binary.GetByName("ngt_grp_data");
|
||||
std::string grp_str(reinterpret_cast<char*>(grp_data->data.get()), grp_data->size);
|
||||
|
||||
auto prf_data = index_binary.GetByName("ngt_prf_data");
|
||||
std::string prf_str(reinterpret_cast<char*>(prf_data->data.get()), prf_data->size);
|
||||
|
||||
auto tre_data = index_binary.GetByName("ngt_tre_data");
|
||||
std::string tre_str(reinterpret_cast<char*>(tre_data->data.get()), tre_data->size);
|
||||
|
||||
std::stringstream obj(obj_str);
|
||||
std::stringstream grp(grp_str);
|
||||
std::stringstream prf(prf_str);
|
||||
std::stringstream tre(tre_str);
|
||||
|
||||
index_ = std::shared_ptr<NGT::Index>(NGT::Index::loadIndex(obj, grp, prf, tre));
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("IndexNGT has no implementation of BuildAll, please use IndexNGT(PANNG/ONNG) instead!");
|
||||
}
|
||||
|
||||
#if 0
|
||||
void
|
||||
IndexNGT::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("IndexNGT has no implementation of Train, please use IndexNGT(PANNG/ONNG) instead!");
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr);
|
||||
|
||||
NGT::Property prop;
|
||||
prop.setDefaultForCreateIndex();
|
||||
prop.dimension = dim;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2)
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
else if (metric_type == Metric::HAMMING)
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
else if (metric_type == Metric::JACCARD)
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
else
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
index_ =
|
||||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr);
|
||||
|
||||
index_->append(reinterpret_cast<const float*>(p_data), rows);
|
||||
}
|
||||
#endif
|
||||
|
||||
DatasetPtr
|
||||
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr);
|
||||
|
||||
size_t k = config[meta::TOPK].get<int64_t>();
|
||||
size_t id_size = sizeof(int64_t) * k;
|
||||
size_t dist_size = sizeof(float) * k;
|
||||
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
|
||||
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
|
||||
|
||||
NGT::Command::SearchParameter sp;
|
||||
sp.size = k;
|
||||
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
const float* single_query = reinterpret_cast<float*>(const_cast<void*>(p_data)) + i * Dim();
|
||||
|
||||
NGT::Object* object = index_->allocateObject(single_query, Dim());
|
||||
NGT::SearchContainer sc(*object);
|
||||
|
||||
double epsilon = sp.beginOfEpsilon;
|
||||
|
||||
NGT::ObjectDistances res;
|
||||
sc.setResults(&res);
|
||||
sc.setSize(sp.size);
|
||||
sc.setRadius(sp.radius);
|
||||
|
||||
if (sp.accuracy > 0.0) {
|
||||
sc.setExpectedAccuracy(sp.accuracy);
|
||||
} else {
|
||||
sc.setEpsilon(epsilon);
|
||||
}
|
||||
sc.setEdgeSize(sp.edgeSize);
|
||||
|
||||
try {
|
||||
index_->search(sc, blacklist);
|
||||
} catch (NGT::Exception& err) {
|
||||
KNOWHERE_THROW_MSG("Query failed");
|
||||
}
|
||||
|
||||
auto local_id = p_id + i * k;
|
||||
auto local_dist = p_dist + i * k;
|
||||
|
||||
int64_t res_num = res.size();
|
||||
for (int64_t idx = 0; idx < res_num; ++idx) {
|
||||
*(local_id + idx) = res[idx].id - 1;
|
||||
*(local_dist + idx) = res[idx].distance;
|
||||
}
|
||||
while (res_num < static_cast<int64_t>(k)) {
|
||||
*(local_id + res_num) = -1;
|
||||
*(local_dist + res_num) = 1.0 / 0.0;
|
||||
}
|
||||
index_->deleteObject(object);
|
||||
}
|
||||
|
||||
auto res_ds = std::make_shared<Dataset>();
|
||||
res_ds->Set(meta::IDS, p_id);
|
||||
res_ds->Set(meta::DISTANCE, p_dist);
|
||||
return res_ds;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexNGT::Count() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
return index_->getNumberOfVectors();
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexNGT::Dim() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
return index_->getDimension();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <NGT/lib/NGT/Command.h>
|
||||
#include <NGT/lib/NGT/Common.h>
|
||||
#include <NGT/lib/NGT/Index.h>
|
||||
|
||||
#include <knowhere/common/Exception.h>
|
||||
#include <knowhere/index/IndexType.h>
|
||||
#include <knowhere/index/vector_index/VecIndex.h>
|
||||
#include <memory>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexNGT : public VecIndex {
|
||||
public:
|
||||
IndexNGT() {
|
||||
index_type_ = IndexEnum::INVALID;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
Train(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
|
||||
}
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
|
||||
}
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) override {
|
||||
KNOWHERE_THROW_MSG("Incremental index is not supported");
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dim() override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<NGT::Index> index_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGTONNG.h"
|
||||
|
||||
#include "NGT/lib/NGT/GraphOptimizer.h"
|
||||
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
void
|
||||
IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr);
|
||||
|
||||
NGT::Property prop;
|
||||
prop.setDefaultForCreateIndex();
|
||||
prop.dimension = dim;
|
||||
prop.edgeSizeForCreation = 20;
|
||||
prop.insertionRadiusCoefficient = 1.0;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
} else if (metric_type == Metric::HAMMING) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
} else if (metric_type == Metric::JACCARD) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
}
|
||||
index_ =
|
||||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
|
||||
// reconstruct graph
|
||||
NGT::GraphOptimizer graphOptimizer(true);
|
||||
|
||||
size_t number_of_outgoing_edges = 5;
|
||||
size_t number_of_incoming_edges = 30;
|
||||
size_t number_of_queries = 1000;
|
||||
size_t number_of_res = 20;
|
||||
|
||||
graphOptimizer.shortcutReduction = true;
|
||||
graphOptimizer.searchParameterOptimization = true;
|
||||
graphOptimizer.prefetchParameterOptimization = false;
|
||||
graphOptimizer.accuracyTableGeneration = false;
|
||||
graphOptimizer.margin = 0.2;
|
||||
graphOptimizer.gtEpsilon = 0.1;
|
||||
|
||||
graphOptimizer.set(number_of_outgoing_edges, number_of_incoming_edges, number_of_queries, number_of_res);
|
||||
|
||||
graphOptimizer.execute(*index_);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGT.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexNGTONNG : public IndexNGT {
|
||||
public:
|
||||
IndexNGTONNG() {
|
||||
index_type_ = IndexEnum::INDEX_NGTONNG;
|
||||
}
|
||||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
void
|
||||
IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr);
|
||||
|
||||
NGT::Property prop;
|
||||
prop.setDefaultForCreateIndex();
|
||||
prop.dimension = dim;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
} else if (metric_type == Metric::HAMMING) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
} else if (metric_type == Metric::JACCARD) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
}
|
||||
index_ =
|
||||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
|
||||
size_t force_removed_edge_size = 60;
|
||||
size_t selective_removed_edge_size = 30;
|
||||
|
||||
// prune
|
||||
auto& graph = dynamic_cast<NGT::GraphIndex&>(index_->getIndex());
|
||||
for (size_t id = 1; id < graph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode& node = *graph.getNode(id);
|
||||
if (node.size() >= force_removed_edge_size) {
|
||||
node.resize(force_removed_edge_size);
|
||||
}
|
||||
if (node.size() >= selective_removed_edge_size) {
|
||||
size_t rank = 0;
|
||||
for (auto i = node.begin(); i != node.end(); ++rank) {
|
||||
if (rank >= selective_removed_edge_size) {
|
||||
bool found = false;
|
||||
for (size_t t1 = 0; t1 < node.size() && found == false; ++t1) {
|
||||
if (t1 >= selective_removed_edge_size) {
|
||||
break;
|
||||
}
|
||||
if (rank == t1) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode& node2 = *graph.getNode(node[t1].id);
|
||||
for (size_t t2 = 0; t2 < node2.size(); ++t2) {
|
||||
if (t2 >= selective_removed_edge_size) {
|
||||
break;
|
||||
}
|
||||
if (node2[t2].id == (*i).id) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
} // for
|
||||
} // for
|
||||
if (found) {
|
||||
// remove
|
||||
i = node.erase(i);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
} // for
|
||||
}
|
||||
} catch (NGT::Exception& err) {
|
||||
std::cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGT.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
class IndexNGTPANNG : public IndexNGT {
|
||||
public:
|
||||
IndexNGTPANNG() {
|
||||
index_type_ = IndexEnum::INDEX_NGTPANNG;
|
||||
}
|
||||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -21,6 +21,8 @@
|
|||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexNGTONNG.h"
|
||||
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
|
||||
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
|
||||
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
|
||||
|
@ -99,6 +101,10 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
|
|||
return std::make_shared<knowhere::IndexRHNSWPQ>();
|
||||
} else if (type == IndexEnum::INDEX_RHNSWSQ) {
|
||||
return std::make_shared<knowhere::IndexRHNSWSQ>();
|
||||
} else if (type == IndexEnum::INDEX_NGTPANNG) {
|
||||
return std::make_shared<knowhere::IndexNGTPANNG>();
|
||||
} else if (type == IndexEnum::INDEX_NGTONNG) {
|
||||
return std::make_shared<knowhere::IndexNGTONNG>();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
if(APPLE)
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
else()
|
||||
cmake_minimum_required(VERSION 2.8)
|
||||
endif()
|
||||
|
||||
project(ngt)
|
||||
|
||||
file(STRINGS "VERSION" ngt_VERSION)
|
||||
message(STATUS "VERSION: ${ngt_VERSION}")
|
||||
string(REGEX MATCH "^[0-9]+" ngt_VERSION_MAJOR ${ngt_VERSION})
|
||||
|
||||
set(ngt_VERSION ${ngt_VERSION})
|
||||
set(ngt_SOVERSION ${ngt_VERSION_MAJOR})
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE)
|
||||
set (CMAKE_BUILD_TYPE "Release")
|
||||
endif (NOT CMAKE_BUILD_TYPE)
|
||||
string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER)
|
||||
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS "CMAKE_BUILD_TYPE_LOWER: ${CMAKE_BUILD_TYPE_LOWER}")
|
||||
|
||||
if(${UNIX})
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
||||
|
||||
if(CMAKE_VERSION VERSION_LESS 3.1)
|
||||
set(BASE_OPTIONS "-Wall -std=gnu++0x -lrt")
|
||||
|
||||
if(${NGT_AVX_DISABLED})
|
||||
message(STATUS "AVX will not be used to compute distances.")
|
||||
endif()
|
||||
|
||||
if(${NGT_OPENMP_DISABLED})
|
||||
message(STATUS "OpenMP is disabled.")
|
||||
else()
|
||||
set(BASE_OPTIONS "${BASE_OPTIONS} -fopenmp")
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "-g ${BASE_OPTIONS}")
|
||||
|
||||
if(${NGT_MARCH_NATIVE_DISABLED})
|
||||
message(STATUS "Compile option -march=native is disabled.")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-O2 ${BASE_OPTIONS}")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native ${BASE_OPTIONS}")
|
||||
endif()
|
||||
else()
|
||||
if (CMAKE_BUILD_TYPE_LOWER STREQUAL "release")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "")
|
||||
if(${NGT_MARCH_NATIVE_DISABLED})
|
||||
message(STATUS "Compile option -march=native is disabled.")
|
||||
add_compile_options(-O2 -DNDEBUG)
|
||||
else()
|
||||
add_compile_options(-Ofast -march=native -DNDEBUG)
|
||||
endif()
|
||||
endif()
|
||||
add_compile_options(-Wall)
|
||||
if(${NGT_AVX_DISABLED})
|
||||
message(STATUS "AVX will not be used to compute distances.")
|
||||
endif()
|
||||
if(${NGT_OPENMP_DISABLED})
|
||||
message(STATUS "OpenMP is disabled.")
|
||||
else()
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "8.1.0")
|
||||
message(FATAL_ERROR "Insufficient AppleClang version")
|
||||
endif()
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
endif()
|
||||
find_package(OpenMP REQUIRED)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
set(CMAKE_CXX_STANDARD 11) # for std::unordered_set, std::unique_ptr
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
find_package(Threads REQUIRED)
|
||||
endif()
|
||||
|
||||
add_subdirectory("${PROJECT_SOURCE_DIR}/lib")
|
||||
endif( ${UNIX} )
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1 @@
|
|||
1.12.0
|
|
@ -0,0 +1,3 @@
|
|||
if( ${UNIX} )
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/lib/NGT)
|
||||
endif()
|
|
@ -0,0 +1,89 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "ArrayFile.h"
|
||||
#include <iostream>
|
||||
#include <assert.h>
|
||||
|
||||
class ItemID {
|
||||
public:
|
||||
void serialize(std::ostream &os, NGT::ObjectSpace *ospace = 0) {
|
||||
os.write((char*)&value, sizeof(value));
|
||||
}
|
||||
void deserialize(std::istream &is, NGT::ObjectSpace *ospace = 0) {
|
||||
is.read((char*)&value, sizeof(value));
|
||||
}
|
||||
static size_t getSerializedDataSize() {
|
||||
return sizeof(uint64_t);
|
||||
}
|
||||
uint64_t value;
|
||||
};
|
||||
|
||||
void
|
||||
sampleForUsage() {
|
||||
{
|
||||
ArrayFile<ItemID> itemIDFile;
|
||||
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
|
||||
itemIDFile.open("test.data");
|
||||
ItemID itemID;
|
||||
size_t id;
|
||||
|
||||
id = 1;
|
||||
itemID.value = 4910002490100;
|
||||
itemIDFile.put(id, itemID);
|
||||
itemID.value = 0;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490100);
|
||||
|
||||
id = 2;
|
||||
itemID.value = 4910002490101;
|
||||
itemIDFile.put(id, itemID);
|
||||
itemID.value = 0;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490101);
|
||||
|
||||
itemID.value = 4910002490102;
|
||||
id = itemIDFile.insert(itemID);
|
||||
itemID.value = 0;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490102);
|
||||
|
||||
itemIDFile.close();
|
||||
}
|
||||
{
|
||||
ArrayFile<ItemID> itemIDFile;
|
||||
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
|
||||
itemIDFile.open("test.data");
|
||||
ItemID itemID;
|
||||
size_t id;
|
||||
|
||||
id = 10;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490100);
|
||||
|
||||
id = 20;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490101);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <cstddef>
|
||||
#include <stdint.h>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
|
||||
namespace NGT {
|
||||
class ObjectSpace;
|
||||
};
|
||||
|
||||
template <class TYPE>
|
||||
class ArrayFile {
|
||||
private:
|
||||
struct FileHeadStruct {
|
||||
size_t recordSize;
|
||||
uint64_t extraData; // reserve
|
||||
};
|
||||
|
||||
struct RecordStruct {
|
||||
bool deleteFlag;
|
||||
uint64_t extraData; // reserve
|
||||
};
|
||||
|
||||
bool _isOpen;
|
||||
std::fstream _stream;
|
||||
FileHeadStruct _fileHead;
|
||||
|
||||
bool _readFileHead();
|
||||
pthread_mutex_t _mutex;
|
||||
|
||||
public:
|
||||
ArrayFile();
|
||||
~ArrayFile();
|
||||
bool create(const std::string &file, size_t recordSize);
|
||||
bool open(const std::string &file);
|
||||
void close();
|
||||
size_t insert(TYPE &data, NGT::ObjectSpace *objectSpace = 0);
|
||||
void put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
|
||||
bool get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
|
||||
void remove(const size_t id);
|
||||
bool isOpen() const;
|
||||
size_t size();
|
||||
size_t getRecordSize() { return _fileHead.recordSize; }
|
||||
};
|
||||
|
||||
|
||||
// constructor
|
||||
template <class TYPE>
|
||||
ArrayFile<TYPE>::ArrayFile()
|
||||
: _isOpen(false), _mutex((pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER){
|
||||
if(pthread_mutex_init(&_mutex, NULL) < 0) throw std::runtime_error("pthread init error.");
|
||||
}
|
||||
|
||||
// destructor
|
||||
template <class TYPE>
|
||||
ArrayFile<TYPE>::~ArrayFile() {
|
||||
pthread_mutex_destroy(&_mutex);
|
||||
close();
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::create(const std::string &file, size_t recordSize) {
|
||||
std::fstream tmpstream;
|
||||
tmpstream.open(file.c_str());
|
||||
if(tmpstream){
|
||||
return false;
|
||||
}
|
||||
|
||||
tmpstream.open(file.c_str(), std::ios::out);
|
||||
tmpstream.seekp(0, std::ios::beg);
|
||||
FileHeadStruct fileHead = {recordSize, 0};
|
||||
tmpstream.write((char *)(&fileHead), sizeof(FileHeadStruct));
|
||||
tmpstream.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::open(const std::string &file) {
|
||||
_stream.open(file.c_str(), std::ios::in | std::ios::out);
|
||||
if(!_stream){
|
||||
_isOpen = false;
|
||||
return false;
|
||||
}
|
||||
_isOpen = true;
|
||||
|
||||
bool ret = _readFileHead();
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
void ArrayFile<TYPE>::close(){
|
||||
_stream.close();
|
||||
_isOpen = false;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
size_t ArrayFile<TYPE>::insert(TYPE &data, NGT::ObjectSpace *objectSpace) {
|
||||
_stream.seekp(sizeof(RecordStruct), std::ios::end);
|
||||
int64_t write_pos = _stream.tellg();
|
||||
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
|
||||
_stream.seekp(write_pos, std::ios::beg);
|
||||
data.serialize(_stream, objectSpace);
|
||||
|
||||
int64_t offset_pos = _stream.tellg();
|
||||
offset_pos -= sizeof(FileHeadStruct);
|
||||
size_t id = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
|
||||
if(offset_pos % (sizeof(RecordStruct) + _fileHead.recordSize) == 0){
|
||||
id -= 1;
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
void ArrayFile<TYPE>::put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
|
||||
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
|
||||
offset_pos += sizeof(RecordStruct);
|
||||
_stream.seekp(offset_pos, std::ios::beg);
|
||||
|
||||
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
|
||||
_stream.seekp(offset_pos, std::ios::beg);
|
||||
data.serialize(_stream, objectSpace);
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
|
||||
pthread_mutex_lock(&_mutex);
|
||||
|
||||
if( size() <= id ){
|
||||
pthread_mutex_unlock(&_mutex);
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
|
||||
offset_pos += sizeof(RecordStruct);
|
||||
_stream.seekg(offset_pos, std::ios::beg);
|
||||
if (!_stream.fail()) {
|
||||
data.deserialize(_stream, objectSpace);
|
||||
}
|
||||
if (_stream.fail()) {
|
||||
const int trialCount = 10;
|
||||
for (int tc = 0; tc < trialCount; tc++) {
|
||||
_stream.clear();
|
||||
_stream.seekg(offset_pos, std::ios::beg);
|
||||
if (_stream.fail()) {
|
||||
continue;
|
||||
}
|
||||
data.deserialize(_stream, objectSpace);
|
||||
if (_stream.fail()) {
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (_stream.fail()) {
|
||||
throw std::runtime_error("ArrayFile::get: Error!");
|
||||
}
|
||||
}
|
||||
|
||||
pthread_mutex_unlock(&_mutex);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
void ArrayFile<TYPE>::remove(const size_t id) {
|
||||
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
|
||||
_stream.seekp(offset_pos, std::ios::beg);
|
||||
RecordStruct recordHead = {1, 0};
|
||||
_stream.write((char *)(&recordHead), sizeof(RecordStruct));
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::isOpen() const
|
||||
{
|
||||
return _isOpen;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
size_t ArrayFile<TYPE>::size()
|
||||
{
|
||||
_stream.seekp(0, std::ios::end);
|
||||
int64_t offset_pos = _stream.tellg();
|
||||
offset_pos -= sizeof(FileHeadStruct);
|
||||
size_t num = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
|
||||
|
||||
return num;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::_readFileHead() {
|
||||
_stream.seekp(0, std::ios::beg);
|
||||
_stream.read((char *)(&_fileHead), sizeof(FileHeadStruct));
|
||||
if(_stream.bad()){
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
if( ${UNIX} )
|
||||
option(NGT_SHARED_MEMORY_ALLOCATOR "enable shared memory" OFF)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/defines.h.in ${CMAKE_CURRENT_BINARY_DIR}/defines.h)
|
||||
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/lib" "${PROJECT_BINARY_DIR}/lib/")
|
||||
include_directories("${PROJECT_SOURCE_DIR}/../")
|
||||
|
||||
file(GLOB NGT_SOURCES *.cpp)
|
||||
file(GLOB HEADER_FILES *.h *.hpp)
|
||||
file(GLOB NGTQ_HEADER_FILES NGTQ/*.h NGTQ/*.hpp)
|
||||
|
||||
add_library(ngtstatic STATIC ${NGT_SOURCES})
|
||||
set_target_properties(ngtstatic PROPERTIES OUTPUT_NAME ngt)
|
||||
set_target_properties(ngtstatic PROPERTIES COMPILE_FLAGS "-fPIC")
|
||||
target_link_libraries(ngtstatic)
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
target_link_libraries(ngtstatic OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
add_library(ngt SHARED ${NGT_SOURCES})
|
||||
set_target_properties(ngt PROPERTIES VERSION ${ngt_VERSION})
|
||||
set_target_properties(ngt PROPERTIES SOVERSION ${ngt_SOVERSION})
|
||||
add_dependencies(ngt ngtstatic)
|
||||
if(${APPLE})
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
target_link_libraries(ngt OpenMP::OpenMP_CXX)
|
||||
else()
|
||||
target_link_libraries(ngt gomp)
|
||||
endif()
|
||||
else(${APPLE})
|
||||
target_link_libraries(ngt gomp rt)
|
||||
endif(${APPLE})
|
||||
|
||||
install(TARGETS
|
||||
ngt
|
||||
ngtstatic
|
||||
RUNTIME DESTINATION bin
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib)
|
||||
|
||||
endif()
|
|
@ -0,0 +1,988 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "NGT/Index.h"
|
||||
#include "NGT/GraphOptimizer.h"
|
||||
#include "Capi.h"
|
||||
|
||||
static bool operate_error_string_(const std::stringstream &ss, NGTError error){
|
||||
if(error != NULL){
|
||||
try{
|
||||
std::string *error_str = static_cast<std::string*>(error);
|
||||
*error_str = ss.str();
|
||||
}catch(std::exception &err){
|
||||
std::cerr << ss.str() << " > " << err.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}else{
|
||||
std::cerr << ss.str() << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
NGTIndex ngt_open_index(const char *index_path, NGTError error) {
|
||||
try{
|
||||
std::string index_path_str(index_path);
|
||||
NGT::Index *index = new NGT::Index(index_path_str);
|
||||
index->disableLog();
|
||||
return static_cast<NGTIndex>(index);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
NGTIndex ngt_create_graph_and_tree(const char *database, NGTProperty prop, NGTError error) {
|
||||
NGT::Index *index = NULL;
|
||||
try{
|
||||
std::string database_str(database);
|
||||
NGT::Property prop_i = *(static_cast<NGT::Property*>(prop));
|
||||
NGT::Index::createGraphAndTree(database_str, prop_i, true);
|
||||
index = new NGT::Index(database_str);
|
||||
index->disableLog();
|
||||
return static_cast<NGTIndex>(index);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
delete index;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty prop, NGTError error) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << __FUNCTION__ << " is unavailable for shared-memory-type NGT.";
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
#else
|
||||
try{
|
||||
NGT::Index *index = new NGT::GraphAndTreeIndex(*(static_cast<NGT::Property*>(prop)));
|
||||
index->disableLog();
|
||||
return static_cast<NGTIndex>(index);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
NGTProperty ngt_create_property(NGTError error) {
|
||||
try{
|
||||
return static_cast<NGTProperty>(new NGT::Property());
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
bool ngt_save_index(const NGTIndex index, const char *database, NGTError error) {
|
||||
try{
|
||||
std::string database_str(database);
|
||||
(static_cast<NGT::Index*>(index))->saveIndex(database_str);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_get_property(NGTIndex index, NGTProperty prop, NGTError error) {
|
||||
if(index == NULL || prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
try{
|
||||
(static_cast<NGT::Index*>(index))->getProperty(*(static_cast<NGT::Property*>(prop)));
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t ngt_get_property_dimension(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).dimension;
|
||||
}
|
||||
|
||||
bool ngt_set_property_dimension(NGTProperty prop, int32_t value, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
(*static_cast<NGT::Property*>(prop)).dimension = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_edge_size_for_creation(NGTProperty prop, int16_t value, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
(*static_cast<NGT::Property*>(prop)).edgeSizeForCreation = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_edge_size_for_search(NGTProperty prop, int16_t value, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
(*static_cast<NGT::Property*>(prop)).edgeSizeForSearch = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t ngt_get_property_object_type(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).objectType;
|
||||
}
|
||||
|
||||
bool ngt_is_property_object_type_float(int32_t object_type) {
|
||||
return (object_type == NGT::ObjectSpace::ObjectType::Float);
|
||||
}
|
||||
|
||||
bool ngt_is_property_object_type_integer(int32_t object_type) {
|
||||
return (object_type == NGT::ObjectSpace::ObjectType::Uint8);
|
||||
}
|
||||
|
||||
bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Float;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Uint8;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_l1(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_l2(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_angle(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_hamming(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_jaccard(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_cosine(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_normalized_angle(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine;
|
||||
return true;
|
||||
}
|
||||
|
||||
NGTObjectDistances ngt_create_empty_results(NGTError error) {
|
||||
try{
|
||||
return static_cast<NGTObjectDistances>(new NGT::ObjectDistances());
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ngt_search_index_(NGT::Index* pindex, NGT::Object *ngtquery, size_t size, float epsilon, float radius, NGTObjectDistances results, int edge_size = INT_MIN) {
|
||||
// set search prameters.
|
||||
NGT::SearchContainer sc(*ngtquery); // search parametera container.
|
||||
|
||||
sc.setResults(static_cast<NGT::ObjectDistances*>(results)); // set the result set.
|
||||
sc.setSize(size); // the number of resultant objects.
|
||||
sc.setRadius(radius); // search radius.
|
||||
sc.setEpsilon(epsilon); // set exploration coefficient.
|
||||
if (edge_size != INT_MIN) {
|
||||
sc.setEdgeSize(edge_size);// set # of edges for each node
|
||||
}
|
||||
|
||||
pindex->search(sc);
|
||||
|
||||
// delete the query object.
|
||||
pindex->deleteObject(ngtquery);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_search_index(NGTIndex index, double *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
|
||||
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
NGT::Object *ngtquery = NULL;
|
||||
|
||||
if(radius < 0.0){
|
||||
radius = FLT_MAX;
|
||||
}
|
||||
|
||||
try{
|
||||
std::vector<double> vquery(&query[0], &query[query_dim]);
|
||||
ngtquery = pindex->allocateObject(vquery);
|
||||
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
if(ngtquery != NULL){
|
||||
pindex->deleteObject(ngtquery);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_search_index_as_float(NGTIndex index, float *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
|
||||
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
NGT::Object *ngtquery = NULL;
|
||||
|
||||
if(radius < 0.0){
|
||||
radius = FLT_MAX;
|
||||
}
|
||||
|
||||
try{
|
||||
std::vector<float> vquery(&query[0], &query[query_dim]);
|
||||
ngtquery = pindex->allocateObject(vquery);
|
||||
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
if(ngtquery != NULL){
|
||||
pindex->deleteObject(ngtquery);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_search_index_with_query(NGTIndex index, NGTQuery query, NGTObjectDistances results, NGTError error) {
|
||||
if(index == NULL || query.query == NULL || results == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query.query << " results = " << results;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
int32_t dim = pindex->getObjectSpace().getDimension();
|
||||
|
||||
NGT::Object *ngtquery = NULL;
|
||||
|
||||
if(query.radius < 0.0){
|
||||
query.radius = FLT_MAX;
|
||||
}
|
||||
|
||||
try{
|
||||
std::vector<float> vquery(&query.query[0], &query.query[dim]);
|
||||
ngtquery = pindex->allocateObject(vquery);
|
||||
ngt_search_index_(pindex, ngtquery, query.size, query.epsilon, query.radius, results, query.edge_size);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
if(ngtquery != NULL){
|
||||
pindex->deleteObject(ngtquery);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// * deprecated *
|
||||
int32_t ngt_get_size(NGTObjectDistances results, NGTError error) {
|
||||
if(results == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return (static_cast<NGT::ObjectDistances*>(results))->size();
|
||||
}
|
||||
|
||||
uint32_t ngt_get_result_size(NGTObjectDistances results, NGTError error) {
|
||||
if(results == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (static_cast<NGT::ObjectDistances*>(results))->size();
|
||||
}
|
||||
|
||||
NGTObjectDistance ngt_get_result(const NGTObjectDistances results, const uint32_t i, NGTError error) {
|
||||
try{
|
||||
NGT::ObjectDistances objects = *(static_cast<NGT::ObjectDistances*>(results));
|
||||
NGTObjectDistance ret_val = {objects[i].id, objects[i].distance};
|
||||
return ret_val;
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
|
||||
NGTObjectDistance err_val = {0};
|
||||
return err_val;
|
||||
}
|
||||
}
|
||||
|
||||
ObjectID ngt_insert_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
|
||||
if(index == NULL || obj == NULL || obj_dim == 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
try{
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
|
||||
return pindex->insert(vobj);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
ObjectID ngt_append_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
|
||||
if(index == NULL || obj == NULL || obj_dim == 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
try{
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
|
||||
return pindex->append(vobj);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
ObjectID ngt_insert_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
|
||||
if(index == NULL || obj == NULL || obj_dim == 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
try{
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
|
||||
return pindex->insert(vobj);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
ObjectID ngt_append_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
|
||||
if(index == NULL || obj == NULL || obj_dim == 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
try{
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
|
||||
return pindex->append(vobj);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool ngt_batch_append_index(NGTIndex index, float *obj, uint32_t data_count, NGTError error) {
|
||||
try{
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
pindex->append(obj, data_count);
|
||||
return true;
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool ngt_batch_insert_index(NGTIndex index, float *obj, uint32_t data_count, uint32_t *ids, NGTError error) {
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
int32_t dim = pindex->getObjectSpace().getDimension();
|
||||
|
||||
bool status = true;
|
||||
float *objptr = obj;
|
||||
for (size_t idx = 0; idx < data_count; idx++, objptr += dim) {
|
||||
try{
|
||||
std::vector<double> vobj(objptr, objptr + dim);
|
||||
ids[idx] = pindex->insert(vobj);
|
||||
}catch(std::exception &err) {
|
||||
status = false;
|
||||
ids[idx] = 0;
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
bool ngt_create_index(NGTIndex index, uint32_t pool_size, NGTError error) {
|
||||
if(index == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
try{
|
||||
(static_cast<NGT::Index*>(index))->createIndex(pool_size);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_remove_index(NGTIndex index, ObjectID id, NGTError error) {
|
||||
if(index == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
try{
|
||||
(static_cast<NGT::Index*>(index))->remove(id);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
NGTObjectSpace ngt_get_object_space(NGTIndex index, NGTError error) {
|
||||
if(index == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
try{
|
||||
return static_cast<NGTObjectSpace>(&(static_cast<NGT::Index*>(index))->getObjectSpace());
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
float* ngt_get_object_as_float(NGTObjectSpace object_space, ObjectID id, NGTError error) {
|
||||
if(object_space == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
try{
|
||||
return static_cast<float*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t* ngt_get_object_as_integer(NGTObjectSpace object_space, ObjectID id, NGTError error) {
|
||||
if(object_space == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
try{
|
||||
return static_cast<uint8_t*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void ngt_destroy_results(NGTObjectDistances results) {
|
||||
if(results == NULL) return;
|
||||
delete(static_cast<NGT::ObjectDistances*>(results));
|
||||
}
|
||||
|
||||
void ngt_destroy_property(NGTProperty prop) {
|
||||
if(prop == NULL) return;
|
||||
delete(static_cast<NGT::Property*>(prop));
|
||||
}
|
||||
|
||||
void ngt_close_index(NGTIndex index) {
|
||||
if(index == NULL) return;
|
||||
(static_cast<NGT::Index*>(index))->close();
|
||||
delete(static_cast<NGT::Index*>(index));
|
||||
}
|
||||
|
||||
int16_t ngt_get_property_edge_size_for_creation(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).edgeSizeForCreation;
|
||||
}
|
||||
|
||||
int16_t ngt_get_property_edge_size_for_search(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).edgeSizeForSearch;
|
||||
}
|
||||
|
||||
int32_t ngt_get_property_distance_type(NGTProperty prop, NGTError error){
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).distanceType;
|
||||
}
|
||||
|
||||
NGTError ngt_create_error_object()
|
||||
{
|
||||
try{
|
||||
std::string *error_str = new std::string();
|
||||
return static_cast<NGTError>(error_str);
|
||||
}catch(std::exception &err){
|
||||
std::cerr << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
const char *ngt_get_error_string(const NGTError error)
|
||||
{
|
||||
std::string *error_str = static_cast<std::string*>(error);
|
||||
return error_str->c_str();
|
||||
}
|
||||
|
||||
void ngt_clear_error_string(NGTError error)
|
||||
{
|
||||
std::string *error_str = static_cast<std::string*>(error);
|
||||
*error_str = "";
|
||||
}
|
||||
|
||||
void ngt_destroy_error_object(NGTError error)
|
||||
{
|
||||
std::string *error_str = static_cast<std::string*>(error);
|
||||
delete error_str;
|
||||
}
|
||||
|
||||
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError error)
|
||||
{
|
||||
try{
|
||||
return static_cast<NGTOptimizer>(new NGT::GraphOptimizer(logDisabled));
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer optimizer, const char *index, NGTError error) {
|
||||
if(optimizer == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
try{
|
||||
(static_cast<NGT::GraphOptimizer*>(optimizer))->adjustSearchCoefficients(std::string(index));
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_optimizer_execute(NGTOptimizer optimizer, const char *inIndex, const char *outIndex, NGTError error) {
|
||||
if(optimizer == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
try{
|
||||
(static_cast<NGT::GraphOptimizer*>(optimizer))->execute(std::string(inIndex), std::string(outIndex));
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// obsolute because of a lack of a parameter
|
||||
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m, NGTError error) {
|
||||
if(optimizer == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
try{
|
||||
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, baseAccuracyFrom, baseAccuracyTo,
|
||||
rateAccuracyFrom, rateAccuracyTo, gte, m);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
|
||||
int nofqs, int nofrs, NGTError error) {
|
||||
if(optimizer == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
try{
|
||||
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, nofrs);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m, NGTError error) {
|
||||
if(optimizer == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
try{
|
||||
(static_cast<NGT::GraphOptimizer*>(optimizer))->setExtension(baseAccuracyFrom, baseAccuracyTo,
|
||||
rateAccuracyFrom, rateAccuracyTo, gte, m);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
|
||||
bool prefetchParameter, bool accuracyTable, NGTError error)
|
||||
{
|
||||
if(optimizer == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(static_cast<NGT::GraphOptimizer*>(optimizer))->setProcessingModes(searchParameter, prefetchParameter,
|
||||
accuracyTable);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ngt_destroy_optimizer(NGTOptimizer optimizer)
|
||||
{
|
||||
if(optimizer == NULL) return;
|
||||
delete(static_cast<NGT::GraphOptimizer*>(optimizer));
|
||||
}
|
||||
|
||||
bool ngt_refine_anng(NGTIndex index, float epsilon, float accuracy, int noOfEdges, int exploreEdgeSize, size_t batchSize, NGTError error)
|
||||
{
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
try {
|
||||
NGT::GraphReconstructor::refineANNG(*pindex, true, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
|
||||
} catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error)
|
||||
{
|
||||
if(index == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(pindex->getIndex());
|
||||
|
||||
try {
|
||||
NGT::ObjectDistances &objects = *static_cast<NGT::ObjectDistances*>(edges);
|
||||
objects = *graph.getNode(id);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error)
|
||||
{
|
||||
if(index == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
NGT::Index& pindex = *static_cast<NGT::Index*>(index);
|
||||
return pindex.getObjectRepositorySize();
|
||||
}
|
||||
|
||||
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter()
|
||||
{
|
||||
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter gp;
|
||||
NGTAnngEdgeOptimizationParameter parameter;
|
||||
|
||||
parameter.no_of_queries = gp.noOfQueries;
|
||||
parameter.no_of_results = gp.noOfResults;
|
||||
parameter.no_of_threads = gp.noOfThreads;
|
||||
parameter.target_accuracy = gp.targetAccuracy;
|
||||
parameter.target_no_of_objects = gp.targetNoOfObjects;
|
||||
parameter.no_of_sample_objects = gp.noOfSampleObjects;
|
||||
parameter.max_of_no_of_edges = gp.maxNoOfEdges;
|
||||
parameter.log = false;
|
||||
|
||||
return parameter;
|
||||
}
|
||||
|
||||
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error)
|
||||
{
|
||||
|
||||
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter p;
|
||||
|
||||
p.noOfQueries = parameter.no_of_queries;
|
||||
p.noOfResults = parameter.no_of_results;
|
||||
p.noOfThreads = parameter.no_of_threads;
|
||||
p.targetAccuracy = parameter.target_accuracy;
|
||||
p.targetNoOfObjects = parameter.target_no_of_objects;
|
||||
p.noOfSampleObjects = parameter.no_of_sample_objects;
|
||||
p.maxNoOfEdges = parameter.max_of_no_of_edges;
|
||||
|
||||
try {
|
||||
NGT::GraphOptimizer graphOptimizer(!parameter.log); // false=log
|
||||
std::string path(indexPath);
|
||||
auto edge = graphOptimizer.optimizeNumberOfEdgesForANNG(path, p);
|
||||
if (parameter.log) {
|
||||
std::cerr << "the optimized number of edges is" << edge.first << "(" << edge.second << ")" << std::endl;
|
||||
}
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
typedef unsigned int ObjectID;
|
||||
typedef void* NGTIndex;
|
||||
typedef void* NGTProperty;
|
||||
typedef void* NGTObjectSpace;
|
||||
typedef void* NGTObjectDistances;
|
||||
typedef void* NGTError;
|
||||
typedef void* NGTOptimizer;
|
||||
|
||||
typedef struct {
|
||||
ObjectID id;
|
||||
float distance;
|
||||
} NGTObjectDistance;
|
||||
|
||||
typedef struct {
|
||||
float *query;
|
||||
size_t size; // # of returned objects
|
||||
float epsilon;
|
||||
float accuracy; // expected accuracy
|
||||
float radius;
|
||||
size_t edge_size; // # of edges to explore for each node
|
||||
} NGTQuery;
|
||||
|
||||
typedef struct {
|
||||
size_t no_of_queries;
|
||||
size_t no_of_results;
|
||||
size_t no_of_threads;
|
||||
float target_accuracy;
|
||||
size_t target_no_of_objects;
|
||||
size_t no_of_sample_objects;
|
||||
size_t max_of_no_of_edges;
|
||||
bool log;
|
||||
} NGTAnngEdgeOptimizationParameter;
|
||||
|
||||
NGTIndex ngt_open_index(const char *, NGTError);
|
||||
|
||||
NGTIndex ngt_create_graph_and_tree(const char *, NGTProperty, NGTError);
|
||||
|
||||
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty, NGTError);
|
||||
|
||||
NGTProperty ngt_create_property(NGTError);
|
||||
|
||||
bool ngt_save_index(const NGTIndex, const char *, NGTError);
|
||||
|
||||
bool ngt_get_property(const NGTIndex, NGTProperty, NGTError);
|
||||
|
||||
int32_t ngt_get_property_dimension(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_dimension(NGTProperty, int32_t, NGTError);
|
||||
|
||||
bool ngt_set_property_edge_size_for_creation(NGTProperty, int16_t, NGTError);
|
||||
|
||||
bool ngt_set_property_edge_size_for_search(NGTProperty, int16_t, NGTError);
|
||||
|
||||
int32_t ngt_get_property_object_type(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_is_property_object_type_float(int32_t);
|
||||
|
||||
bool ngt_is_property_object_type_integer(int32_t);
|
||||
|
||||
bool ngt_set_property_object_type_float(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_object_type_integer(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_l1(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_l2(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_angle(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_hamming(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_jaccard(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_cosine(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_normalized_angle(NGTProperty, NGTError);
|
||||
|
||||
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty, NGTError);
|
||||
|
||||
NGTObjectDistances ngt_create_empty_results(NGTError);
|
||||
|
||||
bool ngt_search_index(NGTIndex, double*, int32_t, size_t, float, float, NGTObjectDistances, NGTError);
|
||||
|
||||
bool ngt_search_index_as_float(NGTIndex, float*, int32_t, size_t, float, float, NGTObjectDistances, NGTError);
|
||||
|
||||
bool ngt_search_index_with_query(NGTIndex, NGTQuery, NGTObjectDistances, NGTError);
|
||||
|
||||
int32_t ngt_get_size(NGTObjectDistances, NGTError); // deprecated
|
||||
|
||||
uint32_t ngt_get_result_size(NGTObjectDistances, NGTError);
|
||||
|
||||
NGTObjectDistance ngt_get_result(const NGTObjectDistances, const uint32_t, NGTError);
|
||||
|
||||
ObjectID ngt_insert_index(NGTIndex, double*, uint32_t, NGTError);
|
||||
|
||||
ObjectID ngt_append_index(NGTIndex, double*, uint32_t, NGTError);
|
||||
|
||||
ObjectID ngt_insert_index_as_float(NGTIndex, float*, uint32_t, NGTError);
|
||||
|
||||
ObjectID ngt_append_index_as_float(NGTIndex, float*, uint32_t, NGTError);
|
||||
|
||||
bool ngt_batch_append_index(NGTIndex, float*, uint32_t, NGTError);
|
||||
|
||||
bool ngt_batch_insert_index(NGTIndex, float*, uint32_t, uint32_t *, NGTError);
|
||||
|
||||
bool ngt_create_index(NGTIndex, uint32_t, NGTError);
|
||||
|
||||
bool ngt_remove_index(NGTIndex, ObjectID, NGTError);
|
||||
|
||||
NGTObjectSpace ngt_get_object_space(NGTIndex, NGTError);
|
||||
|
||||
float* ngt_get_object_as_float(NGTObjectSpace, ObjectID, NGTError);
|
||||
|
||||
uint8_t* ngt_get_object_as_integer(NGTObjectSpace, ObjectID, NGTError);
|
||||
|
||||
void ngt_destroy_results(NGTObjectDistances);
|
||||
|
||||
void ngt_destroy_property(NGTProperty);
|
||||
|
||||
void ngt_close_index(NGTIndex);
|
||||
|
||||
int16_t ngt_get_property_edge_size_for_creation(NGTProperty, NGTError);
|
||||
|
||||
int16_t ngt_get_property_edge_size_for_search(NGTProperty, NGTError);
|
||||
|
||||
int32_t ngt_get_property_distance_type(NGTProperty, NGTError);
|
||||
|
||||
NGTError ngt_create_error_object();
|
||||
|
||||
const char *ngt_get_error_string(const NGTError);
|
||||
|
||||
void ngt_clear_error_string(NGTError);
|
||||
|
||||
void ngt_destroy_error_object(NGTError);
|
||||
|
||||
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError);
|
||||
|
||||
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer, const char *, NGTError);
|
||||
|
||||
bool ngt_optimizer_execute(NGTOptimizer, const char *, const char *, NGTError);
|
||||
|
||||
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m, NGTError error);
|
||||
|
||||
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
|
||||
int nofqs, int nofrs, NGTError error);
|
||||
|
||||
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m, NGTError error);
|
||||
|
||||
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
|
||||
bool prefetchParameter, bool accuracyTable, NGTError error);
|
||||
|
||||
void ngt_destroy_optimizer(NGTOptimizer);
|
||||
|
||||
// refine: the specified index by searching each node.
|
||||
// epsilon, exepectedAccuracy and edgeSize: the same as the prameters for search. but if edgeSize is INT_MIN, default is used.
|
||||
// noOfEdges: if this is not 0, kNNG with k = noOfEdges is build
|
||||
// batchSize: batch size for parallelism.
|
||||
bool ngt_refine_anng(NGTIndex index, float epsilon, float expectedAccuracy,
|
||||
int noOfEdges, int edgeSize, size_t batchSize, NGTError error);
|
||||
|
||||
// get edges of the node that is specified with id.
|
||||
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error);
|
||||
|
||||
// get the size of the specified object repository.
|
||||
// Since the size includes empty objects, the size is not the number of objects.
|
||||
// The size is mostly the largest ID of the objects - 1;
|
||||
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error);
|
||||
|
||||
// return parameters for ngt_optimize_number_of_edges. You can customize them before calling ngt_optimize_number_of_edges.
|
||||
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter();
|
||||
|
||||
// optimize the number of initial edges for ANNG that is specified with indexPath.
|
||||
// The parameter should be a struct which is returned by nt_get_optimization_parameter.
|
||||
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,857 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "NGT/Index.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
#if defined(NGT_AVX_DISABLED)
|
||||
#define NGT_CLUSTER_NO_AVX
|
||||
#else
|
||||
#if defined(__AVX2__)
|
||||
#define NGT_CLUSTER_AVX2
|
||||
#else
|
||||
#define NGT_CLUSTER_NO_AVX
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(NGT_CLUSTER_NO_AVX)
|
||||
// #warning "*** SIMD is *NOT* available! ***"
|
||||
#else
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#include <omp.h>
|
||||
#include <random>
|
||||
|
||||
namespace NGT {
|
||||
|
||||
class Clustering {
|
||||
public:
|
||||
enum InitializationMode {
|
||||
InitializationModeHead = 0,
|
||||
InitializationModeRandom = 1,
|
||||
InitializationModeKmeansPlusPlus = 2
|
||||
};
|
||||
|
||||
enum ClusteringType {
|
||||
ClusteringTypeKmeansWithNGT = 0,
|
||||
ClusteringTypeKmeansWithoutNGT = 1,
|
||||
ClusteringTypeKmeansWithIteration = 2,
|
||||
ClusteringTypeKmeansWithNGTForCentroids = 3
|
||||
};
|
||||
|
||||
class Entry {
|
||||
public:
|
||||
Entry() : vectorID(0), centroidID(0), distance(0.0) {
|
||||
}
|
||||
Entry(size_t vid, size_t cid, double d) : vectorID(vid), centroidID(cid), distance(d) {
|
||||
}
|
||||
bool
|
||||
operator<(const Entry& e) const {
|
||||
return distance > e.distance;
|
||||
}
|
||||
uint32_t vectorID;
|
||||
uint32_t centroidID;
|
||||
double distance;
|
||||
};
|
||||
|
||||
class DescendingEntry {
|
||||
public:
|
||||
DescendingEntry(size_t vid, double d) : vectorID(vid), distance(d) {
|
||||
}
|
||||
bool
|
||||
operator<(const DescendingEntry& e) const {
|
||||
return distance < e.distance;
|
||||
}
|
||||
size_t vectorID;
|
||||
double distance;
|
||||
};
|
||||
|
||||
class Cluster {
|
||||
public:
|
||||
Cluster(std::vector<float>& c) : centroid(c), radius(0.0) {
|
||||
}
|
||||
Cluster(const Cluster& c) {
|
||||
*this = c;
|
||||
}
|
||||
Cluster&
|
||||
operator=(const Cluster& c) {
|
||||
members = c.members;
|
||||
centroid = c.centroid;
|
||||
radius = c.radius;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::vector<Entry> members;
|
||||
std::vector<float> centroid;
|
||||
double radius;
|
||||
};
|
||||
|
||||
Clustering(InitializationMode im = InitializationModeHead, ClusteringType ct = ClusteringTypeKmeansWithNGT,
|
||||
size_t mi = 100)
|
||||
: clusteringType(ct), initializationMode(im), maximumIteration(mi) {
|
||||
initialize();
|
||||
}
|
||||
|
||||
void
|
||||
initialize() {
|
||||
epsilonFrom = 0.12;
|
||||
epsilonTo = epsilonFrom;
|
||||
epsilonStep = 0.04;
|
||||
resultSizeCoefficient = 5;
|
||||
}
|
||||
|
||||
static void
|
||||
convert(std::vector<std::string>& strings, std::vector<float>& vector) {
|
||||
vector.clear();
|
||||
for (auto it = strings.begin(); it != strings.end(); ++it) {
|
||||
vector.push_back(stod(*it));
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
extractVector(const std::string& str, std::vector<float>& vec) {
|
||||
std::vector<std::string> tokens;
|
||||
NGT::Common::tokenize(str, tokens, " \t");
|
||||
convert(tokens, vec);
|
||||
}
|
||||
|
||||
static void
|
||||
loadVectors(const std::string& file, std::vector<std::vector<float> >& vectors) {
|
||||
std::ifstream is(file);
|
||||
if (!is) {
|
||||
throw std::runtime_error("loadVectors::Cannot open " + file);
|
||||
}
|
||||
std::string line;
|
||||
while (getline(is, line)) {
|
||||
std::vector<float> v;
|
||||
extractVector(line, v);
|
||||
vectors.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
saveVectors(const std::string& file, std::vector<std::vector<float> >& vectors) {
|
||||
std::ofstream os(file);
|
||||
for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) {
|
||||
std::vector<float>& v = *vit;
|
||||
for (auto it = v.begin(); it != v.end(); ++it) {
|
||||
os << std::setprecision(9) << (*it);
|
||||
if (it + 1 != v.end()) {
|
||||
os << "\t";
|
||||
}
|
||||
}
|
||||
os << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
saveVector(const std::string& file, std::vector<size_t>& vectors) {
|
||||
std::ofstream os(file);
|
||||
for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) {
|
||||
os << *vit << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
loadClusters(const std::string& file, std::vector<Cluster>& clusters, size_t numberOfClusters = 0) {
|
||||
std::ifstream is(file);
|
||||
if (!is) {
|
||||
throw std::runtime_error("loadClusters::Cannot open " + file);
|
||||
}
|
||||
std::string line;
|
||||
while (getline(is, line)) {
|
||||
std::vector<float> v;
|
||||
extractVector(line, v);
|
||||
clusters.push_back(v);
|
||||
if ((numberOfClusters != 0) && (clusters.size() >= numberOfClusters)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if ((numberOfClusters != 0) && (clusters.size() < numberOfClusters)) {
|
||||
std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
#if !defined(NGT_CLUSTER_NO_AVX)
|
||||
static double
|
||||
sumOfSquares(float* a, float* b, size_t size) {
|
||||
__m256 sum = _mm256_setzero_ps();
|
||||
float* last = a + size;
|
||||
float* lastgroup = last - 7;
|
||||
while (a < lastgroup) {
|
||||
__m256 v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(v, v));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__attribute__((aligned(32))) float f[8];
|
||||
_mm256_store_ps(f, sum);
|
||||
double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7];
|
||||
while (a < last) {
|
||||
double d = *a++ - *b++;
|
||||
s += d * d;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
#else // !defined(NGT_AVX_DISABLED) && defined(__AVX__)
|
||||
static double
|
||||
sumOfSquares(float* a, float* b, size_t size) {
|
||||
double csum = 0.0;
|
||||
float* x = a;
|
||||
float* y = b;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
double d = (double)*x++ - (double)*y++;
|
||||
csum += d * d;
|
||||
}
|
||||
return csum;
|
||||
}
|
||||
#endif // !defined(NGT_AVX_DISABLED) && defined(__AVX__)
|
||||
|
||||
static double
|
||||
distanceL2(std::vector<float>& vector1, std::vector<float>& vector2) {
|
||||
return sqrt(sumOfSquares(&vector1[0], &vector2[0], vector1.size()));
|
||||
}
|
||||
|
||||
static double
|
||||
distanceL2(std::vector<std::vector<float> >& vector1, std::vector<std::vector<float> >& vector2) {
|
||||
assert(vector1.size() == vector2.size());
|
||||
double distance = 0.0;
|
||||
for (size_t i = 0; i < vector1.size(); i++) {
|
||||
distance += distanceL2(vector1[i], vector2[i]);
|
||||
}
|
||||
distance /= (double)vector1.size();
|
||||
return distance;
|
||||
}
|
||||
|
||||
static double
|
||||
meanSumOfSquares(std::vector<float>& vector1, std::vector<float>& vector2) {
|
||||
return sumOfSquares(&vector1[0], &vector2[0], vector1.size()) / (double)vector1.size();
|
||||
}
|
||||
|
||||
static void
|
||||
subtract(std::vector<float>& a, std::vector<float>& b) {
|
||||
assert(a.size() == b.size());
|
||||
auto bit = b.begin();
|
||||
for (auto ait = a.begin(); ait != a.end(); ++ait, ++bit) {
|
||||
*ait = *ait - *bit;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
getInitialCentroidsFromHead(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
|
||||
size_t size) {
|
||||
size = size > vectors.size() ? vectors.size() : size;
|
||||
clusters.clear();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
clusters.push_back(Cluster(vectors[i]));
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
getInitialCentroidsRandomly(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters, size_t size,
|
||||
size_t seed) {
|
||||
clusters.clear();
|
||||
std::random_device rnd;
|
||||
if (seed == 0) {
|
||||
seed = rnd();
|
||||
}
|
||||
std::mt19937 mt(seed);
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
size_t idx = mt() * vectors.size() / mt.max();
|
||||
if (idx >= size) {
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
clusters.push_back(Cluster(vectors[idx]));
|
||||
}
|
||||
assert(clusters.size() == size);
|
||||
}
|
||||
|
||||
static void
|
||||
getInitialCentroidsKmeansPlusPlus(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
|
||||
size_t size) {
|
||||
size = size > vectors.size() ? vectors.size() : size;
|
||||
clusters.clear();
|
||||
std::random_device rnd;
|
||||
std::mt19937 mt(rnd());
|
||||
size_t idx = (long long)mt() * (long long)vectors.size() / (long long)mt.max();
|
||||
clusters.push_back(Cluster(vectors[idx]));
|
||||
|
||||
NGT::Timer timer;
|
||||
for (size_t k = 1; k < size; k++) {
|
||||
double sum = 0;
|
||||
std::priority_queue<DescendingEntry> sortedObjects;
|
||||
// get d^2 and sort
|
||||
#pragma omp parallel for
|
||||
for (size_t vi = 0; vi < vectors.size(); vi++) {
|
||||
auto vit = vectors.begin() + vi;
|
||||
double mind = DBL_MAX;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
double d = distanceL2(*vit, (*cit).centroid);
|
||||
d *= d;
|
||||
if (d < mind) {
|
||||
mind = d;
|
||||
}
|
||||
}
|
||||
#pragma omp critical
|
||||
{
|
||||
sortedObjects.push(DescendingEntry(distance(vectors.begin(), vit), mind));
|
||||
sum += mind;
|
||||
}
|
||||
}
|
||||
double l = (double)mt() / (double)mt.max() * sum;
|
||||
while (!sortedObjects.empty()) {
|
||||
sum -= sortedObjects.top().distance;
|
||||
if (l >= sum) {
|
||||
clusters.push_back(Cluster(vectors[sortedObjects.top().vectorID]));
|
||||
break;
|
||||
}
|
||||
sortedObjects.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
assign(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
|
||||
size_t clusterSize = std::numeric_limits<size_t>::max()) {
|
||||
// compute distances to the nearest clusters, and construct heap by the distances.
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
|
||||
std::vector<Entry> sortedObjects(vectors.size());
|
||||
#pragma omp parallel for
|
||||
for (size_t vi = 0; vi < vectors.size(); vi++) {
|
||||
auto vit = vectors.begin() + vi;
|
||||
{
|
||||
double mind = DBL_MAX;
|
||||
size_t mincidx = -1;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
double d = distanceL2(*vit, (*cit).centroid);
|
||||
if (d < mind) {
|
||||
mind = d;
|
||||
mincidx = distance(clusters.begin(), cit);
|
||||
}
|
||||
}
|
||||
sortedObjects[vi] = Entry(vi, mincidx, mind);
|
||||
}
|
||||
}
|
||||
std::sort(sortedObjects.begin(), sortedObjects.end());
|
||||
|
||||
// clear
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
(*cit).members.clear();
|
||||
}
|
||||
|
||||
// distribute objects to the nearest clusters in the same size constraint.
|
||||
for (auto soi = sortedObjects.rbegin(); soi != sortedObjects.rend();) {
|
||||
Entry& entry = *soi;
|
||||
if (entry.centroidID >= clusters.size()) {
|
||||
std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
|
||||
soi++;
|
||||
continue;
|
||||
}
|
||||
if (clusters[entry.centroidID].members.size() < clusterSize) {
|
||||
clusters[entry.centroidID].members.push_back(entry);
|
||||
soi++;
|
||||
} else {
|
||||
double mind = DBL_MAX;
|
||||
size_t mincidx = -1;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
if ((*cit).members.size() >= clusterSize) {
|
||||
continue;
|
||||
}
|
||||
double d = distanceL2(vectors[entry.vectorID], (*cit).centroid);
|
||||
if (d < mind) {
|
||||
mind = d;
|
||||
mincidx = distance(clusters.begin(), cit);
|
||||
}
|
||||
}
|
||||
entry = Entry(entry.vectorID, mincidx, mind);
|
||||
int pt = distance(sortedObjects.rbegin(), soi);
|
||||
std::sort(sortedObjects.begin(), soi.base());
|
||||
soi = sortedObjects.rbegin() + pt;
|
||||
assert(pt == distance(sortedObjects.rbegin(), soi));
|
||||
}
|
||||
}
|
||||
|
||||
moveFartherObjectsToEmptyClusters(clusters);
|
||||
}
|
||||
|
||||
static void
|
||||
moveFartherObjectsToEmptyClusters(std::vector<Cluster>& clusters) {
|
||||
size_t emptyClusterCount = 0;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
if ((*cit).members.size() == 0) {
|
||||
emptyClusterCount++;
|
||||
double max = 0.0;
|
||||
auto maxit = clusters.begin();
|
||||
for (auto scit = clusters.begin(); scit != clusters.end(); ++scit) {
|
||||
if ((*scit).members.size() >= 2 && (*scit).members.back().distance > max) {
|
||||
maxit = scit;
|
||||
max = (*scit).members.back().distance;
|
||||
}
|
||||
}
|
||||
(*cit).members.push_back((*maxit).members.back());
|
||||
(*cit).members.back().centroidID = distance(clusters.begin(), cit);
|
||||
(*maxit).members.pop_back();
|
||||
}
|
||||
}
|
||||
emptyClusterCount = 0;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
if ((*cit).members.size() == 0) {
|
||||
emptyClusterCount++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
assignWithNGT(NGT::Index& index, std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
|
||||
float& radius, size_t& resultSize, float epsilon = 0.12, size_t notRetrievedObjectCount = 0) {
|
||||
size_t dataSize = vectors.size();
|
||||
assert(index.getObjectRepositorySize() - 1 == vectors.size());
|
||||
vector<vector<Entry> > results(clusters.size());
|
||||
#pragma omp parallel for
|
||||
for (size_t ci = 0; ci < clusters.size(); ci++) {
|
||||
auto cit = clusters.begin() + ci;
|
||||
NGT::ObjectDistances objects; // result set
|
||||
NGT::Object* query = 0;
|
||||
query = index.allocateObject((*cit).centroid);
|
||||
// set search prameters.
|
||||
NGT::SearchContainer sc(*query); // search parametera container.
|
||||
sc.setResults(&objects); // set the result set.
|
||||
sc.setEpsilon(epsilon); // set exploration coefficient.
|
||||
if (radius > 0.0) {
|
||||
sc.setRadius(radius);
|
||||
sc.setSize(dataSize / 2);
|
||||
} else {
|
||||
sc.setSize(resultSize); // the number of resultant objects.
|
||||
}
|
||||
index.search(sc);
|
||||
results[ci].reserve(objects.size());
|
||||
for (size_t idx = 0; idx < objects.size(); idx++) {
|
||||
size_t oidx = objects[idx].id - 1;
|
||||
results[ci].push_back(Entry(oidx, ci, objects[idx].distance));
|
||||
}
|
||||
|
||||
index.deleteObject(query);
|
||||
}
|
||||
size_t resultCount = 0;
|
||||
for (auto ri = results.begin(); ri != results.end(); ++ri) {
|
||||
resultCount += (*ri).size();
|
||||
}
|
||||
vector<Entry> sortedResults;
|
||||
sortedResults.reserve(resultCount);
|
||||
for (auto ri = results.begin(); ri != results.end(); ++ri) {
|
||||
auto end = (*ri).begin();
|
||||
for (; end != (*ri).end(); ++end) {
|
||||
}
|
||||
std::copy((*ri).begin(), end, std::back_inserter(sortedResults));
|
||||
}
|
||||
|
||||
vector<bool> processedObjects(dataSize, false);
|
||||
for (auto i = sortedResults.begin(); i != sortedResults.end(); ++i) {
|
||||
processedObjects[(*i).vectorID] = true;
|
||||
}
|
||||
|
||||
notRetrievedObjectCount = 0;
|
||||
vector<uint32_t> notRetrievedObjectIDs;
|
||||
for (size_t idx = 0; idx < dataSize; idx++) {
|
||||
if (!processedObjects[idx]) {
|
||||
notRetrievedObjectCount++;
|
||||
notRetrievedObjectIDs.push_back(idx);
|
||||
}
|
||||
}
|
||||
|
||||
sort(sortedResults.begin(), sortedResults.end());
|
||||
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
(*cit).members.clear();
|
||||
}
|
||||
|
||||
for (auto i = sortedResults.rbegin(); i != sortedResults.rend(); ++i) {
|
||||
size_t objectID = (*i).vectorID;
|
||||
size_t clusterID = (*i).centroidID;
|
||||
if (processedObjects[objectID]) {
|
||||
processedObjects[objectID] = false;
|
||||
clusters[clusterID].members.push_back(*i);
|
||||
clusters[clusterID].members.back().centroidID = clusterID;
|
||||
radius = (*i).distance;
|
||||
}
|
||||
}
|
||||
|
||||
vector<Entry> notRetrievedObjects(notRetrievedObjectIDs.size());
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t vi = 0; vi < notRetrievedObjectIDs.size(); vi++) {
|
||||
auto vit = notRetrievedObjectIDs.begin() + vi;
|
||||
{
|
||||
double mind = DBL_MAX;
|
||||
size_t mincidx = -1;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
double d = distanceL2(vectors[*vit], (*cit).centroid);
|
||||
if (d < mind) {
|
||||
mind = d;
|
||||
mincidx = distance(clusters.begin(), cit);
|
||||
}
|
||||
}
|
||||
notRetrievedObjects[vi] = Entry(*vit, mincidx, mind); // Entry(vectorID, centroidID, distance)
|
||||
}
|
||||
}
|
||||
|
||||
sort(notRetrievedObjects.begin(), notRetrievedObjects.end());
|
||||
|
||||
for (auto nroit = notRetrievedObjects.begin(); nroit != notRetrievedObjects.end(); ++nroit) {
|
||||
clusters[(*nroit).centroidID].members.push_back(*nroit);
|
||||
}
|
||||
|
||||
moveFartherObjectsToEmptyClusters(clusters);
|
||||
}
|
||||
|
||||
static double
|
||||
calculateCentroid(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
|
||||
double distance = 0;
|
||||
size_t memberCount = 0;
|
||||
for (auto it = clusters.begin(); it != clusters.end(); ++it) {
|
||||
memberCount += (*it).members.size();
|
||||
if ((*it).members.size() != 0) {
|
||||
std::vector<float> mean(vectors[0].size(), 0.0);
|
||||
for (auto memit = (*it).members.begin(); memit != (*it).members.end(); ++memit) {
|
||||
auto mit = mean.begin();
|
||||
auto& v = vectors[(*memit).vectorID];
|
||||
for (auto vit = v.begin(); vit != v.end(); ++vit, ++mit) {
|
||||
*mit += *vit;
|
||||
}
|
||||
}
|
||||
for (auto mit = mean.begin(); mit != mean.end(); ++mit) {
|
||||
*mit /= (*it).members.size();
|
||||
}
|
||||
distance += distanceL2((*it).centroid, mean);
|
||||
(*it).centroid = mean;
|
||||
} else {
|
||||
cerr << "Clustering: Fatal Error. No member!" << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
static void
|
||||
saveClusters(const std::string& file, std::vector<Cluster>& clusters) {
|
||||
std::ofstream os(file);
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
std::vector<float>& v = (*cit).centroid;
|
||||
for (auto it = v.begin(); it != v.end(); ++it) {
|
||||
os << std::setprecision(9) << (*it);
|
||||
if (it + 1 != v.end()) {
|
||||
os << "\t";
|
||||
}
|
||||
}
|
||||
os << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithoutNGT(std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
|
||||
std::vector<Cluster>& clusters) {
|
||||
size_t clusterSize = std::numeric_limits<size_t>::max();
|
||||
if (clusterSizeConstraint) {
|
||||
clusterSize = ceil((double)vectors.size() / (double)numberOfClusters);
|
||||
}
|
||||
|
||||
double diff = 0;
|
||||
for (size_t i = 0; i < maximumIteration; i++) {
|
||||
std::cerr << "iteration=" << i << std::endl;
|
||||
assign(vectors, clusters, clusterSize);
|
||||
// centroid is recomputed.
|
||||
// diff is distance between the current centroids and the previous centroids.
|
||||
diff = calculateCentroid(vectors, clusters);
|
||||
if (diff == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return diff == 0;
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithNGT(NGT::Index& index, std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
|
||||
std::vector<Cluster>& clusters, float epsilon) {
|
||||
diffHistory.clear();
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
float radius;
|
||||
double diff = 0.0;
|
||||
size_t resultSize;
|
||||
resultSize = resultSizeCoefficient * vectors.size() / clusters.size();
|
||||
for (size_t i = 0; i < maximumIteration; i++) {
|
||||
size_t notRetrievedObjectCount = 0;
|
||||
radius = -1.0;
|
||||
assignWithNGT(index, vectors, clusters, radius, resultSize, epsilon, notRetrievedObjectCount);
|
||||
// centroid is recomputed.
|
||||
// diff is distance between the current centroids and the previous centroids.
|
||||
std::vector<Cluster> prevClusters = clusters;
|
||||
diff = calculateCentroid(vectors, clusters);
|
||||
timer.stop();
|
||||
std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
|
||||
timer.start();
|
||||
diffHistory.push_back(diff);
|
||||
|
||||
if (diff == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithNGT(std::vector<std::vector<float> >& vectors, size_t numberOfClusters, std::vector<Cluster>& clusters) {
|
||||
pid_t pid = getpid();
|
||||
std::stringstream str;
|
||||
str << "cluster-ngt." << pid;
|
||||
string database = str.str();
|
||||
string dataFile;
|
||||
size_t dataSize = 0;
|
||||
size_t dim = clusters.front().centroid.size();
|
||||
NGT::Property property;
|
||||
property.dimension = dim;
|
||||
property.graphType = NGT::Property::GraphType::GraphTypeANNG;
|
||||
property.objectType = NGT::Index::Property::ObjectType::Float;
|
||||
property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
|
||||
NGT::Index::createGraphAndTree(database, property, dataFile, dataSize);
|
||||
|
||||
float* data = new float[vectors.size() * dim];
|
||||
float* ptr = data;
|
||||
dataSize = vectors.size();
|
||||
for (auto vi = vectors.begin(); vi != vectors.end(); ++vi) {
|
||||
memcpy(ptr, &((*vi)[0]), dim * sizeof(float));
|
||||
ptr += dim;
|
||||
}
|
||||
size_t threadSize = 20;
|
||||
NGT::Index::append(database, data, dataSize, threadSize);
|
||||
delete[] data;
|
||||
|
||||
NGT::Index index(database);
|
||||
|
||||
return kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilonFrom);
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, std::vector<Cluster>& clusters) {
|
||||
NGT::GraphIndex& graph = static_cast<NGT::GraphIndex&>(index.getIndex());
|
||||
NGT::ObjectSpace& os = graph.getObjectSpace();
|
||||
size_t size = os.getRepository().size();
|
||||
std::vector<std::vector<float> > vectors(size - 1);
|
||||
for (size_t idx = 1; idx < size; idx++) {
|
||||
try {
|
||||
os.getObject(idx, vectors[idx - 1]);
|
||||
} catch (...) {
|
||||
cerr << "Cannot get object " << idx << endl;
|
||||
}
|
||||
}
|
||||
cerr << "# of data for clustering=" << vectors.size() << endl;
|
||||
double diff = DBL_MAX;
|
||||
clusters.clear();
|
||||
setupInitialClusters(vectors, numberOfClusters, clusters);
|
||||
for (float epsilon = epsilonFrom; epsilon <= epsilonTo; epsilon += epsilonStep) {
|
||||
cerr << "epsilon=" << epsilon << endl;
|
||||
diff = kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilon);
|
||||
if (diff == 0.0) {
|
||||
return diff;
|
||||
}
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, NGT::Index& outIndex) {
|
||||
std::vector<Cluster> clusters;
|
||||
double diff = kmeansWithNGT(index, numberOfClusters, clusters);
|
||||
for (auto i = clusters.begin(); i != clusters.end(); ++i) {
|
||||
outIndex.insert((*i).centroid);
|
||||
}
|
||||
outIndex.createIndex(16);
|
||||
return diff;
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters) {
|
||||
NGT::Property prop;
|
||||
index.getProperty(prop);
|
||||
string path = index.getPath();
|
||||
index.save();
|
||||
index.close();
|
||||
string outIndexName = path;
|
||||
string inIndexName = path + ".tmp";
|
||||
std::rename(outIndexName.c_str(), inIndexName.c_str());
|
||||
NGT::Index::createGraphAndTree(outIndexName, prop);
|
||||
index.open(outIndexName);
|
||||
NGT::Index inIndex(inIndexName);
|
||||
double diff = kmeansWithNGT(inIndex, numberOfClusters, index);
|
||||
inIndex.close();
|
||||
NGT::Index::destroy(inIndexName);
|
||||
return diff;
|
||||
}
|
||||
|
||||
double
|
||||
kmeansWithNGT(string& indexName, size_t numberOfClusters) {
|
||||
NGT::Index inIndex(indexName);
|
||||
double diff = kmeansWithNGT(inIndex, numberOfClusters);
|
||||
inIndex.save();
|
||||
inIndex.close();
|
||||
return diff;
|
||||
}
|
||||
|
||||
static double
|
||||
calculateMSE(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
|
||||
double mse = 0.0;
|
||||
size_t count = 0;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
count += (*cit).members.size();
|
||||
for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) {
|
||||
mse += meanSumOfSquares((*cit).centroid, vectors[(*mit).vectorID]);
|
||||
}
|
||||
}
|
||||
assert(vectors.size() == count);
|
||||
return mse / (double)vectors.size();
|
||||
}
|
||||
|
||||
static double
|
||||
calculateML2(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
|
||||
double d = 0.0;
|
||||
size_t count = 0;
|
||||
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
|
||||
count += (*cit).members.size();
|
||||
double localD = 0.0;
|
||||
for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) {
|
||||
double distance = distanceL2((*cit).centroid, vectors[(*mit).vectorID]);
|
||||
d += distance;
|
||||
localD += distance;
|
||||
}
|
||||
}
|
||||
if (vectors.size() != count) {
|
||||
std::cerr << "Warning! vectors.size() != count" << std::endl;
|
||||
}
|
||||
|
||||
return d / (double)vectors.size();
|
||||
}
|
||||
|
||||
static double
|
||||
calculateML2FromSpecifiedCentroids(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
|
||||
std::vector<size_t>& centroidIds) {
|
||||
double d = 0.0;
|
||||
size_t count = 0;
|
||||
for (auto it = centroidIds.begin(); it != centroidIds.end(); ++it) {
|
||||
Cluster& cluster = clusters[(*it)];
|
||||
count += cluster.members.size();
|
||||
for (auto mit = cluster.members.begin(); mit != cluster.members.end(); ++mit) {
|
||||
d += distanceL2(cluster.centroid, vectors[(*mit).vectorID]);
|
||||
}
|
||||
}
|
||||
return d / (double)vectors.size();
|
||||
}
|
||||
|
||||
void
|
||||
setupInitialClusters(std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
|
||||
std::vector<Cluster>& clusters) {
|
||||
if (clusters.empty()) {
|
||||
switch (initializationMode) {
|
||||
case InitializationModeHead: {
|
||||
getInitialCentroidsFromHead(vectors, clusters, numberOfClusters);
|
||||
break;
|
||||
}
|
||||
case InitializationModeRandom: {
|
||||
getInitialCentroidsRandomly(vectors, clusters, numberOfClusters, 0);
|
||||
break;
|
||||
}
|
||||
case InitializationModeKmeansPlusPlus: {
|
||||
getInitialCentroidsKmeansPlusPlus(vectors, clusters, numberOfClusters);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
std::cerr << "proper initMode is not specified." << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
kmeans(std::vector<std::vector<float> >& vectors, size_t numberOfClusters, std::vector<Cluster>& clusters) {
|
||||
setupInitialClusters(vectors, numberOfClusters, clusters);
|
||||
|
||||
switch (clusteringType) {
|
||||
case ClusteringTypeKmeansWithoutNGT:
|
||||
return kmeansWithoutNGT(vectors, numberOfClusters, clusters);
|
||||
break;
|
||||
case ClusteringTypeKmeansWithNGT:
|
||||
return kmeansWithNGT(vectors, numberOfClusters, clusters);
|
||||
break;
|
||||
default:
|
||||
cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
|
||||
abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
evaluate(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters, char mode,
|
||||
std::vector<size_t> centroidIds = std::vector<size_t>()) {
|
||||
size_t clusterSize = std::numeric_limits<size_t>::max();
|
||||
assign(vectors, clusters, clusterSize);
|
||||
|
||||
std::cout << "The number of vectors=" << vectors.size() << std::endl;
|
||||
std::cout << "The number of centroids=" << clusters.size() << std::endl;
|
||||
if (centroidIds.size() == 0) {
|
||||
switch (mode) {
|
||||
case 'e':
|
||||
std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
|
||||
break;
|
||||
case '2':
|
||||
default:
|
||||
std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (mode) {
|
||||
case 'e':
|
||||
break;
|
||||
case '2':
|
||||
default:
|
||||
std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
|
||||
<< std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClusteringType clusteringType;
|
||||
InitializationMode initializationMode;
|
||||
size_t numberOfClusters;
|
||||
bool clusterSizeConstraint;
|
||||
size_t maximumIteration;
|
||||
float epsilonFrom;
|
||||
float epsilonTo;
|
||||
float epsilonStep;
|
||||
size_t resultSizeCoefficient;
|
||||
vector<double> diffHistory;
|
||||
};
|
||||
|
||||
} // namespace NGT
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,127 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "NGT/Index.h"
|
||||
|
||||
namespace NGT {
|
||||
|
||||
|
||||
class Command {
|
||||
public:
|
||||
class SearchParameter {
|
||||
public:
|
||||
SearchParameter() {
|
||||
openMode = 'r';
|
||||
query = "";
|
||||
querySize = 0;
|
||||
indexType = 't';
|
||||
size = 20;
|
||||
edgeSize = -1;
|
||||
outputMode = "-";
|
||||
radius = FLT_MAX;
|
||||
step = 0;
|
||||
trial = 1;
|
||||
beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1;
|
||||
accuracy = 0.0;
|
||||
}
|
||||
SearchParameter(Args &args) { parse(args); }
|
||||
void parse(Args &args) {
|
||||
openMode = args.getChar("m", 'r');
|
||||
try {
|
||||
query = args.get("#2");
|
||||
} catch (...) {
|
||||
NGTThrowException("ngt: Error: Query is not specified");
|
||||
}
|
||||
querySize = args.getl("Q", 0);
|
||||
indexType = args.getChar("i", 't');
|
||||
size = args.getl("n", 20);
|
||||
// edgeSize
|
||||
// -1(default) : using the size which was specified at the index creation.
|
||||
// 0 : no limitation for the edge size.
|
||||
// -2('e') : automatically set it according to epsilon.
|
||||
if (args.getChar("E", '-') == 'e') {
|
||||
edgeSize = -2;
|
||||
} else {
|
||||
edgeSize = args.getl("E", -1);
|
||||
}
|
||||
outputMode = args.getString("o", "-");
|
||||
radius = args.getf("r", FLT_MAX);
|
||||
trial = args.getl("t", 1);
|
||||
{
|
||||
beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1;
|
||||
std::string epsilon = args.getString("e", "0.1");
|
||||
std::vector<std::string> tokens;
|
||||
NGT::Common::tokenize(epsilon, tokens, ":");
|
||||
if (tokens.size() >= 1) { beginOfEpsilon = endOfEpsilon = NGT::Common::strtod(tokens[0]); }
|
||||
if (tokens.size() >= 2) { endOfEpsilon = NGT::Common::strtod(tokens[1]); }
|
||||
if (tokens.size() >= 3) { stepOfEpsilon = NGT::Common::strtod(tokens[2]); }
|
||||
step = 0;
|
||||
if (tokens.size() >= 4) { step = NGT::Common::strtol(tokens[3]); }
|
||||
}
|
||||
accuracy = args.getf("a", 0.0);
|
||||
}
|
||||
char openMode;
|
||||
std::string query;
|
||||
size_t querySize;
|
||||
char indexType;
|
||||
int size;
|
||||
long edgeSize;
|
||||
std::string outputMode;
|
||||
float radius;
|
||||
float beginOfEpsilon;
|
||||
float endOfEpsilon;
|
||||
float stepOfEpsilon;
|
||||
float accuracy;
|
||||
size_t step;
|
||||
size_t trial;
|
||||
};
|
||||
|
||||
Command():debugLevel(0) {}
|
||||
|
||||
void create(Args &args);
|
||||
void append(Args &args);
|
||||
static void search(NGT::Index &index, SearchParameter &searchParameter, std::ostream &stream)
|
||||
{
|
||||
std::ifstream is(searchParameter.query);
|
||||
if (!is) {
|
||||
std::cerr << "Cannot open the specified file. " << searchParameter.query << std::endl;
|
||||
return;
|
||||
}
|
||||
search(index, searchParameter, is, stream);
|
||||
}
|
||||
static void search(NGT::Index &index, SearchParameter &searchParameter, std::istream &is, std::ostream &stream);
|
||||
void search(Args &args);
|
||||
void remove(Args &args);
|
||||
void exportIndex(Args &args);
|
||||
void importIndex(Args &args);
|
||||
void prune(Args &args);
|
||||
void reconstructGraph(Args &args);
|
||||
void optimizeSearchParameters(Args &args);
|
||||
void optimizeNumberOfEdgesForANNG(Args &args);
|
||||
void refineANNG(Args &args);
|
||||
void repair(Args &args);
|
||||
|
||||
void info(Args &args);
|
||||
void setDebugLevel(int level) { debugLevel = level; }
|
||||
int getDebugLevel() { return debugLevel; }
|
||||
|
||||
protected:
|
||||
int debugLevel;
|
||||
|
||||
};
|
||||
|
||||
}; // NGT
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,213 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
ConcurrentBitset::ConcurrentBitset(id_type_t capacity, uint8_t init_value) : capacity_(capacity), bitset_(((capacity + 8 - 1) >> 3)) {
|
||||
if (init_value) {
|
||||
memset(mutable_data(), init_value, (capacity + 8 - 1) >> 3);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::atomic<uint8_t>>&
|
||||
ConcurrentBitset::bitset() {
|
||||
return bitset_;
|
||||
}
|
||||
|
||||
ConcurrentBitset&
|
||||
ConcurrentBitset::operator&=(ConcurrentBitset& bitset) {
|
||||
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
|
||||
// bitset_[i].fetch_and(bitset.bitset()[i].load());
|
||||
// }
|
||||
|
||||
auto u8_1 = const_cast<uint8_t*>(data());
|
||||
auto u8_2 = const_cast<uint8_t*>(bitset.data());
|
||||
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
|
||||
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
|
||||
|
||||
size_t n8 = bitset_.size();
|
||||
size_t n64 = n8 / 8;
|
||||
|
||||
for (size_t i = 0; i < n64; i++) {
|
||||
u64_1[i] &= u64_2[i];
|
||||
}
|
||||
|
||||
size_t remain = n8 % 8;
|
||||
u8_1 += n64 * 8;
|
||||
u8_2 += n64 * 8;
|
||||
for (size_t i = 0; i < remain; i++) {
|
||||
u8_1[i] &= u8_2[i];
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::shared_ptr<ConcurrentBitset>
|
||||
ConcurrentBitset::operator&(const std::shared_ptr<ConcurrentBitset>& bitset) {
|
||||
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
|
||||
|
||||
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
|
||||
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
|
||||
|
||||
auto u8_1 = const_cast<uint8_t*>(data());
|
||||
auto u8_2 = const_cast<uint8_t*>(bitset->data());
|
||||
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
|
||||
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
|
||||
|
||||
size_t n8 = bitset_.size();
|
||||
size_t n64 = n8 / 8;
|
||||
|
||||
for (size_t i = 0; i < n64; i++) {
|
||||
result_64[i] = u64_1[i] & u64_2[i];
|
||||
}
|
||||
|
||||
size_t remain = n8 % 8;
|
||||
u8_1 += n64 * 8;
|
||||
u8_2 += n64 * 8;
|
||||
result_8 += n64 * 8;
|
||||
for (size_t i = 0; i < remain; i++) {
|
||||
result_8[i] = u8_1[i] & u8_2[i];
|
||||
}
|
||||
|
||||
|
||||
return result_bitset;
|
||||
}
|
||||
|
||||
ConcurrentBitset&
|
||||
ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
|
||||
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
|
||||
// bitset_[i].fetch_or(bitset.bitset()[i].load());
|
||||
// }
|
||||
|
||||
auto u8_1 = const_cast<uint8_t*>(data());
|
||||
auto u8_2 = const_cast<uint8_t*>(bitset.data());
|
||||
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
|
||||
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
|
||||
|
||||
size_t n8 = bitset_.size();
|
||||
size_t n64 = n8 / 8;
|
||||
|
||||
for (size_t i = 0; i < n64; i++) {
|
||||
u64_1[i] |= u64_2[i];
|
||||
}
|
||||
|
||||
size_t remain = n8 % 8;
|
||||
u8_1 += n64 * 8;
|
||||
u8_2 += n64 * 8;
|
||||
for (size_t i = 0; i < remain; i++) {
|
||||
u8_1[i] |= u8_2[i];
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::shared_ptr<ConcurrentBitset>
|
||||
ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
|
||||
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
|
||||
|
||||
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
|
||||
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
|
||||
|
||||
auto u8_1 = const_cast<uint8_t*>(data());
|
||||
auto u8_2 = const_cast<uint8_t*>(bitset->data());
|
||||
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
|
||||
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
|
||||
|
||||
size_t n8 = bitset_.size();
|
||||
size_t n64 = n8 / 8;
|
||||
|
||||
for (size_t i = 0; i < n64; i++) {
|
||||
result_64[i] = u64_1[i] | u64_2[i];
|
||||
}
|
||||
|
||||
size_t remain = n8 % 8;
|
||||
u8_1 += n64 * 8;
|
||||
u8_2 += n64 * 8;
|
||||
result_8 += n64 * 8;
|
||||
for (size_t i = 0; i < remain; i++) {
|
||||
result_8[i] = u8_1[i] | u8_2[i];
|
||||
}
|
||||
|
||||
return result_bitset;
|
||||
}
|
||||
|
||||
ConcurrentBitset&
|
||||
ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
|
||||
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
|
||||
// bitset_[i].fetch_xor(bitset.bitset()[i].load());
|
||||
// }
|
||||
|
||||
auto u8_1 = const_cast<uint8_t*>(data());
|
||||
auto u8_2 = const_cast<uint8_t*>(bitset.data());
|
||||
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
|
||||
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
|
||||
|
||||
size_t n8 = bitset_.size();
|
||||
size_t n64 = n8 / 8;
|
||||
|
||||
for (size_t i = 0; i < n64; i++) {
|
||||
u64_1[i] &= u64_2[i];
|
||||
}
|
||||
|
||||
size_t remain = n8 % 8;
|
||||
u8_1 += n64 * 8;
|
||||
u8_2 += n64 * 8;
|
||||
for (size_t i = 0; i < remain; i++) {
|
||||
u8_1[i] ^= u8_2[i];
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool
|
||||
ConcurrentBitset::test(id_type_t id) {
|
||||
return bitset_[id >> 3].load() & (0x1 << (id & 0x7));
|
||||
}
|
||||
|
||||
void
|
||||
ConcurrentBitset::set(id_type_t id) {
|
||||
bitset_[id >> 3].fetch_or(0x1 << (id & 0x7));
|
||||
}
|
||||
|
||||
void
|
||||
ConcurrentBitset::clear(id_type_t id) {
|
||||
bitset_[id >> 3].fetch_and(~(0x1 << (id & 0x7)));
|
||||
}
|
||||
|
||||
size_t
|
||||
ConcurrentBitset::capacity() {
|
||||
return capacity_;
|
||||
}
|
||||
|
||||
size_t
|
||||
ConcurrentBitset::size() {
|
||||
return ((capacity_ + 8 - 1) >> 3);
|
||||
}
|
||||
|
||||
const uint8_t*
|
||||
ConcurrentBitset::data() {
|
||||
return reinterpret_cast<const uint8_t*>(bitset_.data());
|
||||
}
|
||||
|
||||
uint8_t*
|
||||
ConcurrentBitset::mutable_data() {
|
||||
return reinterpret_cast<uint8_t*>(bitset_.data());
|
||||
}
|
||||
} // namespace faiss
|
|
@ -0,0 +1,15 @@
|
|||
#include "NGT/GetCoreNumber.h"
|
||||
|
||||
namespace NGT
|
||||
{
|
||||
int getCoreNumber()
|
||||
{
|
||||
#ifndef __linux__
|
||||
SYSTEM_INFO sys_info;
|
||||
GetSystemInfo(&sys_info);
|
||||
return sysInfo.dwNumberOfProcessors;
|
||||
#else
|
||||
return get_nprocs();
|
||||
#endif
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
#ifndef __linux__
|
||||
# include "windows.h"
|
||||
#else
|
||||
|
||||
# include "sys/sysinfo.h"
|
||||
# include "unistd.h"
|
||||
#endif
|
||||
|
||||
namespace NGT
|
||||
{
|
||||
int getCoreNumber();
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,948 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <bitset>
|
||||
#include <sstream>
|
||||
|
||||
#include "NGT/defines.h"
|
||||
#include "NGT/Common.h"
|
||||
#include "NGT/ObjectSpaceRepository.h"
|
||||
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
|
||||
#include "NGT/HashBasedBooleanSet.h"
|
||||
|
||||
#ifndef NGT_GRAPH_CHECK_VECTOR
|
||||
#include <unordered_set>
|
||||
#endif
|
||||
|
||||
#ifdef NGT_GRAPH_UNCHECK_STACK
|
||||
#include <stack>
|
||||
#endif
|
||||
|
||||
#ifndef NGT_EXPLORATION_COEFFICIENT
|
||||
#define NGT_EXPLORATION_COEFFICIENT 1.1
|
||||
#endif
|
||||
|
||||
#ifndef NGT_INSERTION_EXPLORATION_COEFFICIENT
|
||||
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
|
||||
#endif
|
||||
|
||||
#ifndef NGT_TRUNCATION_THRESHOLD
|
||||
#define NGT_TRUNCATION_THRESHOLD 50
|
||||
#endif
|
||||
|
||||
#ifndef NGT_SEED_SIZE
|
||||
#define NGT_SEED_SIZE 10
|
||||
#endif
|
||||
|
||||
#ifndef NGT_CREATION_EDGE_SIZE
|
||||
#define NGT_CREATION_EDGE_SIZE 10
|
||||
#endif
|
||||
|
||||
namespace NGT {
|
||||
class Property;
|
||||
|
||||
typedef GraphNode GRAPH_NODE;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
class GraphRepository: public PersistentRepository<GRAPH_NODE> {
|
||||
#else
|
||||
class GraphRepository: public Repository<GRAPH_NODE> {
|
||||
#endif
|
||||
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
typedef PersistentRepository<GRAPH_NODE> VECTOR;
|
||||
#else
|
||||
typedef Repository<GRAPH_NODE> VECTOR;
|
||||
|
||||
GraphRepository() {
|
||||
prevsize = new vector<unsigned short>;
|
||||
}
|
||||
virtual ~GraphRepository() {
|
||||
deleteAll();
|
||||
if (prevsize != 0) {
|
||||
delete prevsize;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void open(const std::string &file, size_t sharedMemorySize) {
|
||||
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
|
||||
off_t *entryTable = (off_t*)allocator.construct(file, sharedMemorySize);
|
||||
if (entryTable == 0) {
|
||||
entryTable = (off_t*)construct();
|
||||
allocator.setEntry(entryTable);
|
||||
}
|
||||
assert(entryTable != 0);
|
||||
this->initialize(entryTable);
|
||||
}
|
||||
|
||||
void *construct() {
|
||||
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
|
||||
off_t *entryTable = new(allocator) off_t[2];
|
||||
entryTable[0] = allocator.getOffset(PersistentRepository<GRAPH_NODE>::construct());
|
||||
entryTable[1] = allocator.getOffset(new(allocator) Vector<unsigned short>);
|
||||
return entryTable;
|
||||
}
|
||||
|
||||
void initialize(void *e) {
|
||||
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
|
||||
off_t *entryTable = (off_t*)e;
|
||||
array = (ARRAY*)allocator.getAddr(entryTable[0]);
|
||||
PersistentRepository<GRAPH_NODE>::initialize(allocator.getAddr(entryTable[0]));
|
||||
prevsize = (Vector<unsigned short>*)allocator.getAddr(entryTable[1]);
|
||||
}
|
||||
#endif
|
||||
|
||||
void insert(ObjectID id, ObjectDistances &objects) {
|
||||
GRAPH_NODE *r = allocate();
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
(*r).copy(objects, VECTOR::getAllocator());
|
||||
#else
|
||||
*r = objects;
|
||||
#endif
|
||||
try {
|
||||
put(id, r);
|
||||
} catch (Exception &exp) {
|
||||
delete r;
|
||||
throw exp;
|
||||
}
|
||||
if (id >= prevsize->size()) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
prevsize->resize(id + 1, VECTOR::getAllocator(), 0);
|
||||
#else
|
||||
prevsize->resize(id + 1, 0);
|
||||
#endif
|
||||
} else {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
(*prevsize).at(id, VECTOR::getAllocator()) = 0;
|
||||
#else
|
||||
(*prevsize)[id] = 0;
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
inline GRAPH_NODE *get(ObjectID fid, size_t &minsize) {
|
||||
GRAPH_NODE *rs = VECTOR::get(fid);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
minsize = (*prevsize).at(fid, VECTOR::getAllocator());
|
||||
#else
|
||||
minsize = (*prevsize)[fid];
|
||||
#endif
|
||||
return rs;
|
||||
}
|
||||
void serialize(std::ofstream &os) {
|
||||
VECTOR::serialize(os);
|
||||
Serializer::write(os, *prevsize);
|
||||
}
|
||||
// for milvus
|
||||
void serialize(std::stringstream & grp)
|
||||
{
|
||||
VECTOR::serialize(grp);
|
||||
Serializer::write(grp, *prevsize);
|
||||
}
|
||||
void deserialize(std::ifstream &is) {
|
||||
VECTOR::deserialize(is);
|
||||
Serializer::read(is, *prevsize);
|
||||
}
|
||||
// for milvus
|
||||
void deserialize(std::stringstream & is)
|
||||
{
|
||||
VECTOR::deserialize(is);
|
||||
Serializer::read(is, *prevsize);
|
||||
}
|
||||
void show() {
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
std::cout << "Show graph " << i << " ";
|
||||
if ((*this)[i] == 0) {
|
||||
std::cout << std::endl;
|
||||
continue;
|
||||
}
|
||||
for (size_t j = 0; j < (*this)[i]->size(); j++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cout << (*this)[i]->at(j, VECTOR::getAllocator()).id << ":" << (*this)[i]->at(j, VECTOR::getAllocator()).distance << " ";
|
||||
#else
|
||||
std::cout << (*this)[i]->at(j).id << ":" << (*this)[i]->at(j).distance << " ";
|
||||
#endif
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
Vector<unsigned short> *prevsize;
|
||||
#else
|
||||
std::vector<unsigned short> *prevsize;
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
class ReadOnlyGraphNode : public std::vector<std::pair<uint64_t, PersistentObject*>> {
|
||||
public:
|
||||
ReadOnlyGraphNode():reservedSize(0), usedSize(0) {}
|
||||
void reserve(size_t s) {
|
||||
reservedSize = ((s & 7) == 0) ? s : (s & 0xFFFFFFFFFFFFFFF8) + 8;
|
||||
resize(reservedSize);
|
||||
for (size_t i = (reservedSize & 0xFFFFFFFFFFFFFFF8); i < reservedSize; i++) {
|
||||
(*this)[i].first = 0;
|
||||
}
|
||||
}
|
||||
void push_back(std::pair<uint32_t, PersistentObject*> node) {
|
||||
(*this)[usedSize] = node;
|
||||
usedSize++;
|
||||
}
|
||||
size_t size() { return usedSize; }
|
||||
size_t reservedSize;
|
||||
size_t usedSize;
|
||||
};
|
||||
|
||||
class SearchGraphRepository : public std::vector<ReadOnlyGraphNode> {
|
||||
public:
|
||||
SearchGraphRepository() {}
|
||||
bool isEmpty(size_t idx) { return (*this)[idx].empty(); }
|
||||
|
||||
void deserialize(std::ifstream &is, ObjectRepository &objectRepository) {
|
||||
if (!is.is_open()) {
|
||||
NGTThrowException("NGT::SearchGraph: Not open the specified stream yet.");
|
||||
}
|
||||
clear();
|
||||
size_t s;
|
||||
NGT::Serializer::read(is, s);
|
||||
resize(s);
|
||||
for (size_t id = 0; id < s; id++) {
|
||||
char type;
|
||||
NGT::Serializer::read(is, type);
|
||||
switch(type) {
|
||||
case '-':
|
||||
break;
|
||||
case '+':
|
||||
{
|
||||
ObjectDistances node;
|
||||
node.deserialize(is);
|
||||
ReadOnlyGraphNode &searchNode = at(id);
|
||||
searchNode.reserve(node.size());
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
for (auto ni = node.begin(); ni != node.end(); ni++) {
|
||||
std::cerr << "not implement" << std::endl;
|
||||
abort();
|
||||
}
|
||||
#else
|
||||
for (auto ni = node.begin(); ni != node.end(); ni++) {
|
||||
searchNode.push_back(std::pair<uint32_t, Object*>((*ni).id, objectRepository.get((*ni).id)));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
break;
|
||||
default:
|
||||
{
|
||||
assert(type == '-' || type == '+');
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif // NGT_GRAPH_READ_ONLY_GRAPH
|
||||
|
||||
class NeighborhoodGraph {
|
||||
public:
|
||||
enum GraphType {
|
||||
GraphTypeNone = 0,
|
||||
GraphTypeANNG = 1,
|
||||
GraphTypeKNNG = 2,
|
||||
GraphTypeBKNNG = 3,
|
||||
GraphTypeONNG = 4,
|
||||
GraphTypeIANNG = 5, // Improved ANNG
|
||||
GraphTypeDNNG = 6
|
||||
};
|
||||
|
||||
enum SeedType {
|
||||
SeedTypeNone = 0,
|
||||
SeedTypeRandomNodes = 1,
|
||||
SeedTypeFixedNodes = 2,
|
||||
SeedTypeFirstNode = 3,
|
||||
SeedTypeAllLeafNodes = 4
|
||||
};
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
class Search {
|
||||
public:
|
||||
static void (*getMethod(NGT::ObjectSpace::DistanceType dtype, NGT::ObjectSpace::ObjectType otype, size_t size))(NGT::NeighborhoodGraph&, NGT::SearchContainer&, NGT::ObjectDistances&) {
|
||||
if (size < 5000000) {
|
||||
switch (otype) {
|
||||
default:
|
||||
case NGT::ObjectSpace::Float:
|
||||
switch (dtype) {
|
||||
case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloat;
|
||||
case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloat;
|
||||
case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloat;
|
||||
case NGT::ObjectSpace::DistanceTypeAngle : return angleFloat;
|
||||
case NGT::ObjectSpace::DistanceTypeL2 : return l2Float;
|
||||
case NGT::ObjectSpace::DistanceTypeL1 : return l1Float;
|
||||
case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloat;
|
||||
default: return l2Float;
|
||||
}
|
||||
break;
|
||||
case NGT::ObjectSpace::Uint8:
|
||||
switch (dtype) {
|
||||
case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8;
|
||||
case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8;
|
||||
case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8;
|
||||
case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8;
|
||||
default : return l2Uint8;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return l1Uint8;
|
||||
} else {
|
||||
switch (otype) {
|
||||
default:
|
||||
case NGT::ObjectSpace::Float:
|
||||
switch (dtype) {
|
||||
case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloatForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloatForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloatForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeAngle : return angleFloatForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeL2 : return l2FloatForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeL1 : return l1FloatForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloatForLargeDataset;
|
||||
default: return l2FloatForLargeDataset;
|
||||
}
|
||||
break;
|
||||
case NGT::ObjectSpace::Uint8:
|
||||
switch (dtype) {
|
||||
case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8ForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8ForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8ForLargeDataset;
|
||||
case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8ForLargeDataset;
|
||||
default : return l2Uint8ForLargeDataset;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return l1Uint8ForLargeDataset;
|
||||
}
|
||||
}
|
||||
static void l1Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void l2Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void l1Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void l2Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void hammingUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void jaccardUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void sparseJaccardFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void cosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void angleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void normalizedCosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void normalizedAngleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
|
||||
static void l1Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void l2Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void l1FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void l2FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void hammingUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void jaccardUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void sparseJaccardFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void cosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void angleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void normalizedCosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
static void normalizedAngleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
|
||||
};
|
||||
#endif
|
||||
|
||||
class Property {
|
||||
public:
|
||||
Property() { setDefault(); }
|
||||
void setDefault() {
|
||||
truncationThreshold = 0;
|
||||
edgeSizeForCreation = NGT_CREATION_EDGE_SIZE;
|
||||
edgeSizeForSearch = 0;
|
||||
edgeSizeLimitForCreation = 5;
|
||||
insertionRadiusCoefficient = NGT_INSERTION_EXPLORATION_COEFFICIENT;
|
||||
seedSize = NGT_SEED_SIZE;
|
||||
seedType = SeedTypeNone;
|
||||
truncationThreadPoolSize = 8;
|
||||
batchSizeForCreation = 200;
|
||||
graphType = GraphTypeANNG;
|
||||
dynamicEdgeSizeBase = 30;
|
||||
dynamicEdgeSizeRate = 20;
|
||||
buildTimeLimit = 0.0;
|
||||
outgoingEdge = 10;
|
||||
incomingEdge = 80;
|
||||
}
|
||||
void clear() {
|
||||
truncationThreshold = -1;
|
||||
edgeSizeForCreation = -1;
|
||||
edgeSizeForSearch = -1;
|
||||
edgeSizeLimitForCreation = -1;
|
||||
insertionRadiusCoefficient = -1;
|
||||
seedSize = -1;
|
||||
seedType = SeedTypeNone;
|
||||
truncationThreadPoolSize = -1;
|
||||
batchSizeForCreation = -1;
|
||||
graphType = GraphTypeNone;
|
||||
dynamicEdgeSizeBase = -1;
|
||||
dynamicEdgeSizeRate = -1;
|
||||
buildTimeLimit = -1;
|
||||
outgoingEdge = -1;
|
||||
incomingEdge = -1;
|
||||
}
|
||||
void set(NGT::Property &prop);
|
||||
void get(NGT::Property &prop);
|
||||
|
||||
void exportProperty(NGT::PropertySet &p) {
|
||||
p.set("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold);
|
||||
p.set("EdgeSizeForCreation", edgeSizeForCreation);
|
||||
p.set("EdgeSizeForSearch", edgeSizeForSearch);
|
||||
p.set("EdgeSizeLimitForCreation", edgeSizeLimitForCreation);
|
||||
assert(insertionRadiusCoefficient >= 1.0);
|
||||
p.set("EpsilonForCreation", insertionRadiusCoefficient - 1.0);
|
||||
p.set("BatchSizeForCreation", batchSizeForCreation);
|
||||
p.set("SeedSize", seedSize);
|
||||
p.set("TruncationThreadPoolSize", truncationThreadPoolSize);
|
||||
p.set("DynamicEdgeSizeBase", dynamicEdgeSizeBase);
|
||||
p.set("DynamicEdgeSizeRate", dynamicEdgeSizeRate);
|
||||
p.set("BuildTimeLimit", buildTimeLimit);
|
||||
p.set("OutgoingEdge", outgoingEdge);
|
||||
p.set("IncomingEdge", incomingEdge);
|
||||
switch (graphType) {
|
||||
case NeighborhoodGraph::GraphTypeKNNG: p.set("GraphType", "KNNG"); break;
|
||||
case NeighborhoodGraph::GraphTypeANNG: p.set("GraphType", "ANNG"); break;
|
||||
case NeighborhoodGraph::GraphTypeBKNNG: p.set("GraphType", "BKNNG"); break;
|
||||
case NeighborhoodGraph::GraphTypeONNG: p.set("GraphType", "ONNG"); break;
|
||||
case NeighborhoodGraph::GraphTypeIANNG: p.set("GraphType", "IANNG"); break;
|
||||
default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Graph Type." << std::endl; abort();
|
||||
}
|
||||
switch (seedType) {
|
||||
case NeighborhoodGraph::SeedTypeRandomNodes: p.set("SeedType", "RandomNodes"); break;
|
||||
case NeighborhoodGraph::SeedTypeFixedNodes: p.set("SeedType", "FixedNodes"); break;
|
||||
case NeighborhoodGraph::SeedTypeFirstNode: p.set("SeedType", "FirstNode"); break;
|
||||
case NeighborhoodGraph::SeedTypeNone: p.set("SeedType", "None"); break;
|
||||
case NeighborhoodGraph::SeedTypeAllLeafNodes: p.set("SeedType", "AllLeafNodes"); break;
|
||||
default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Seed Type." << std::endl; abort();
|
||||
}
|
||||
}
|
||||
void importProperty(NGT::PropertySet &p) {
|
||||
setDefault();
|
||||
truncationThreshold = p.getl("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold);
|
||||
edgeSizeForCreation = p.getl("EdgeSizeForCreation", edgeSizeForCreation);
|
||||
edgeSizeForSearch = p.getl("EdgeSizeForSearch", edgeSizeForSearch);
|
||||
edgeSizeLimitForCreation = p.getl("EdgeSizeLimitForCreation", edgeSizeLimitForCreation);
|
||||
insertionRadiusCoefficient = p.getf("EpsilonForCreation", insertionRadiusCoefficient);
|
||||
insertionRadiusCoefficient += 1.0;
|
||||
batchSizeForCreation = p.getl("BatchSizeForCreation", batchSizeForCreation);
|
||||
seedSize = p.getl("SeedSize", seedSize);
|
||||
truncationThreadPoolSize = p.getl("TruncationThreadPoolSize", truncationThreadPoolSize);
|
||||
dynamicEdgeSizeBase = p.getl("DynamicEdgeSizeBase", dynamicEdgeSizeBase);
|
||||
dynamicEdgeSizeRate = p.getl("DynamicEdgeSizeRate", dynamicEdgeSizeRate);
|
||||
buildTimeLimit = p.getf("BuildTimeLimit", buildTimeLimit);
|
||||
outgoingEdge = p.getl("OutgoingEdge", outgoingEdge);
|
||||
incomingEdge = p.getl("IncomingEdge", incomingEdge);
|
||||
PropertySet::iterator it = p.find("GraphType");
|
||||
if (it != p.end()) {
|
||||
if (it->second == "KNNG") graphType = NeighborhoodGraph::GraphTypeKNNG;
|
||||
else if (it->second == "ANNG") graphType = NeighborhoodGraph::GraphTypeANNG;
|
||||
else if (it->second == "BKNNG") graphType = NeighborhoodGraph::GraphTypeBKNNG;
|
||||
else if (it->second == "ONNG") graphType = NeighborhoodGraph::GraphTypeONNG;
|
||||
else if (it->second == "IANNG") graphType = NeighborhoodGraph::GraphTypeIANNG;
|
||||
else { std::cerr << "Graph::importProperty: Fatal error! Invalid Graph Type. " << it->second << std::endl; abort(); }
|
||||
}
|
||||
it = p.find("SeedType");
|
||||
if (it != p.end()) {
|
||||
if (it->second == "RandomNodes") seedType = NeighborhoodGraph::SeedTypeRandomNodes;
|
||||
else if (it->second == "FixedNodes") seedType = NeighborhoodGraph::SeedTypeFixedNodes;
|
||||
else if (it->second == "FirstNode") seedType = NeighborhoodGraph::SeedTypeFirstNode;
|
||||
else if (it->second == "None") seedType = NeighborhoodGraph::SeedTypeNone;
|
||||
else if (it->second == "AllLeafNodes") seedType = NeighborhoodGraph::SeedTypeAllLeafNodes;
|
||||
else { std::cerr << "Graph::importProperty: Fatal error! Invalid Seed Type. " << it->second << std::endl; abort(); }
|
||||
}
|
||||
}
|
||||
friend std::ostream & operator<<(std::ostream& os, const Property& p) {
|
||||
os << "truncationThreshold=" << p.truncationThreshold << std::endl;
|
||||
os << "edgeSizeForCreation=" << p.edgeSizeForCreation << std::endl;
|
||||
os << "edgeSizeForSearch=" << p.edgeSizeForSearch << std::endl;
|
||||
os << "edgeSizeLimitForCreation=" << p.edgeSizeLimitForCreation << std::endl;
|
||||
os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl;
|
||||
os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl;
|
||||
os << "seedSize=" << p.seedSize << std::endl;
|
||||
os << "seedType=" << p.seedType << std::endl;
|
||||
os << "truncationThreadPoolSize=" << p.truncationThreadPoolSize << std::endl;
|
||||
os << "batchSizeForCreation=" << p.batchSizeForCreation << std::endl;
|
||||
os << "graphType=" << p.graphType << std::endl;
|
||||
os << "dynamicEdgeSizeBase=" << p.dynamicEdgeSizeBase << std::endl;
|
||||
os << "dynamicEdgeSizeRate=" << p.dynamicEdgeSizeRate << std::endl;
|
||||
os << "outgoingEdge=" << p.outgoingEdge << std::endl;
|
||||
os << "incomingEdge=" << p.incomingEdge << std::endl;
|
||||
return os;
|
||||
}
|
||||
|
||||
int16_t truncationThreshold;
|
||||
int16_t edgeSizeForCreation;
|
||||
int16_t edgeSizeForSearch;
|
||||
int16_t edgeSizeLimitForCreation;
|
||||
double insertionRadiusCoefficient;
|
||||
int16_t seedSize;
|
||||
SeedType seedType;
|
||||
int16_t truncationThreadPoolSize;
|
||||
int16_t batchSizeForCreation;
|
||||
GraphType graphType;
|
||||
int16_t dynamicEdgeSizeBase;
|
||||
int16_t dynamicEdgeSizeRate;
|
||||
float buildTimeLimit;
|
||||
int16_t outgoingEdge;
|
||||
int16_t incomingEdge;
|
||||
};
|
||||
|
||||
NeighborhoodGraph(): objectSpace(0) {
|
||||
property.truncationThreshold = NGT_TRUNCATION_THRESHOLD;
|
||||
// initialize random to generate random seeds
|
||||
#ifdef NGT_DISABLE_SRAND_FOR_RANDOM
|
||||
struct timeval randTime;
|
||||
gettimeofday(&randTime, 0);
|
||||
srand(randTime.tv_usec);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline GraphNode *getNode(ObjectID fid, size_t &minsize) { return repository.get(fid, minsize); }
|
||||
inline GraphNode *getNode(ObjectID fid) { return repository.VECTOR::get(fid); }
|
||||
void insertNode(ObjectID id, ObjectDistances &objects) {
|
||||
switch (property.graphType) {
|
||||
case GraphTypeANNG:
|
||||
insertANNGNode(id, objects);
|
||||
break;
|
||||
case GraphTypeIANNG:
|
||||
insertIANNGNode(id, objects);
|
||||
break;
|
||||
case GraphTypeONNG:
|
||||
insertONNGNode(id, objects);
|
||||
break;
|
||||
case GraphTypeKNNG:
|
||||
insertKNNGNode(id, objects);
|
||||
break;
|
||||
case GraphTypeBKNNG:
|
||||
insertBKNNGNode(id, objects);
|
||||
break;
|
||||
case GraphTypeNone:
|
||||
NGTThrowException("NGT::insertNode: GraphType is not specified.");
|
||||
break;
|
||||
default:
|
||||
NGTThrowException("NGT::insertNode: GraphType is invalid.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void insertBKNNGNode(ObjectID id, ObjectDistances &results) {
|
||||
if (repository.isEmpty(id)) {
|
||||
repository.insert(id, results);
|
||||
} else {
|
||||
GraphNode &rs = *getNode(id);
|
||||
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
rs.push_back((*ri), repository.allocator);
|
||||
#else
|
||||
rs.push_back((*ri));
|
||||
#endif
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::sort(rs.begin(repository.allocator), rs.end(repository.allocator));
|
||||
ObjectID prev = 0;
|
||||
for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator);) {
|
||||
if (prev == (*ri).id) {
|
||||
ri = rs.erase(ri, repository.allocator);
|
||||
continue;
|
||||
}
|
||||
prev = (*ri).id;
|
||||
ri++;
|
||||
}
|
||||
#else
|
||||
std::sort(rs.begin(), rs.end());
|
||||
ObjectID prev = 0;
|
||||
for (GraphNode::iterator ri = rs.begin(); ri != rs.end();) {
|
||||
if (prev == (*ri).id) {
|
||||
ri = rs.erase(ri);
|
||||
continue;
|
||||
}
|
||||
prev = (*ri).id;
|
||||
ri++;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
|
||||
assert(id != (*ri).id);
|
||||
addBKNNGEdge((*ri).id, id, (*ri).distance);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void insertKNNGNode(ObjectID id, ObjectDistances &results) {
|
||||
repository.insert(id, results);
|
||||
}
|
||||
|
||||
void insertANNGNode(ObjectID id, ObjectDistances &results) {
|
||||
repository.insert(id, results);
|
||||
std::queue<ObjectID> truncateQueue;
|
||||
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
|
||||
assert(id != (*ri).id);
|
||||
if (addEdge((*ri).id, id, (*ri).distance)) {
|
||||
truncateQueue.push((*ri).id);
|
||||
}
|
||||
}
|
||||
while (!truncateQueue.empty()) {
|
||||
ObjectID tid = truncateQueue.front();
|
||||
truncateEdges(tid);
|
||||
truncateQueue.pop();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void insertIANNGNode(ObjectID id, ObjectDistances &results) {
|
||||
repository.insert(id, results);
|
||||
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
|
||||
assert(id != (*ri).id);
|
||||
addEdgeDeletingExcessEdges((*ri).id, id, (*ri).distance);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void insertONNGNode(ObjectID id, ObjectDistances &results) {
|
||||
if (property.truncationThreshold != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::insertONNGNode: truncation should be disabled!" << std::endl;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
int count = 0;
|
||||
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++, count++) {
|
||||
assert(id != (*ri).id);
|
||||
if (count >= property.incomingEdge) {
|
||||
break;
|
||||
}
|
||||
addEdge((*ri).id, id, (*ri).distance);
|
||||
}
|
||||
if (static_cast<int>(results.size()) > property.outgoingEdge) {
|
||||
results.resize(property.outgoingEdge);
|
||||
}
|
||||
repository.insert(id, results);
|
||||
}
|
||||
|
||||
void removeEdgesReliably(ObjectID id);
|
||||
|
||||
int truncateEdgesOptimally(ObjectID id, GraphNode &results, size_t truncationSize);
|
||||
|
||||
int truncateEdges(ObjectID id) {
|
||||
GraphNode &results = *getNode(id);
|
||||
if (results.size() == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t truncationSize = NGT_TRUNCATION_THRESHOLD;
|
||||
if (truncationSize < (size_t)property.edgeSizeForCreation) {
|
||||
truncationSize = property.edgeSizeForCreation;
|
||||
}
|
||||
return truncateEdgesOptimally(id, results, truncationSize);
|
||||
}
|
||||
|
||||
// setup edgeSize
|
||||
inline size_t getEdgeSize(NGT::SearchContainer &sc) {
|
||||
size_t edgeSize = INT_MAX;
|
||||
if (sc.edgeSize < 0) {
|
||||
if (sc.edgeSize == -2) {
|
||||
double add = pow(10, (sc.explorationCoefficient - 1.0) * static_cast<float>(property.dynamicEdgeSizeRate));
|
||||
edgeSize = add >= static_cast<double>(INT_MAX) ? INT_MAX : property.dynamicEdgeSizeBase + add;
|
||||
} else {
|
||||
edgeSize = property.edgeSizeForSearch == 0 ? INT_MAX : property.edgeSizeForSearch;
|
||||
}
|
||||
} else {
|
||||
edgeSize = sc.edgeSize == 0 ? INT_MAX : sc.edgeSize;
|
||||
}
|
||||
return edgeSize;
|
||||
}
|
||||
|
||||
void search(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
// for milvus
|
||||
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, faiss::ConcurrentBitsetPtr & bitset);
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
template <typename COMPARATOR, typename CHECK_LIST> void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
#endif
|
||||
|
||||
void removeEdge(ObjectID fid, ObjectID rmid) {
|
||||
GraphNode &rs = *getNode(fid);
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator); ri++) {
|
||||
if ((*ri).id == rmid) {
|
||||
rs.erase(ri, repository.allocator);
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (GraphNode::iterator ri = rs.begin(); ri != rs.end(); ri++) {
|
||||
if ((*ri).id == rmid) {
|
||||
rs.erase(ri);
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void removeEdge(GraphNode &node, ObjectDistance &edge) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), edge);
|
||||
if (ni != node.end(repository.allocator) && (*ni).id == edge.id) {
|
||||
node.erase(ni, repository.allocator);
|
||||
#else
|
||||
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge);
|
||||
if (ni != node.end() && (*ni).id == edge.id) {
|
||||
node.erase(ni);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
if (ni == node.end(repository.allocator)) {
|
||||
#else
|
||||
if (ni == node.end()) {
|
||||
#endif
|
||||
std::stringstream msg;
|
||||
msg << "NGT::removeEdge: Cannot found " << edge.id;
|
||||
NGTThrowException(msg);
|
||||
} else {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::removeEdge: Cannot found " << (*ni).id << ":" << edge.id;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
removeNode(ObjectID id) {
|
||||
repository.erase(id);
|
||||
}
|
||||
|
||||
class BooleanVector : public std::vector<bool> {
|
||||
public:
|
||||
inline BooleanVector(size_t s):std::vector<bool>(s, false) {}
|
||||
inline void insert(size_t i) { std::vector<bool>::operator[](i) = true; }
|
||||
};
|
||||
|
||||
#ifdef NGT_GRAPH_VECTOR_RESULT
|
||||
typedef ObjectDistances ResultSet;
|
||||
#else
|
||||
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > ResultSet;
|
||||
#endif
|
||||
|
||||
#if defined(NGT_GRAPH_CHECK_BOOLEANSET)
|
||||
typedef BooleanSet DistanceCheckedSet;
|
||||
#elif defined(NGT_GRAPH_CHECK_VECTOR)
|
||||
typedef BooleanVector DistanceCheckedSet;
|
||||
#elif defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET)
|
||||
typedef HashBasedBooleanSet DistanceCheckedSet;
|
||||
#else
|
||||
class DistanceCheckedSet : public unordered_set<ObjectID> {
|
||||
public:
|
||||
bool operator[](ObjectID id) { return find(id) != end(); }
|
||||
};
|
||||
#endif
|
||||
|
||||
typedef HashBasedBooleanSet DistanceCheckedSetForLargeDataset;
|
||||
|
||||
class NodeWithPosition : public ObjectDistance {
|
||||
public:
|
||||
NodeWithPosition(uint32_t p = 0):position(p){}
|
||||
NodeWithPosition(ObjectDistance &o):ObjectDistance(o), position(0){}
|
||||
NodeWithPosition &operator=(const NodeWithPosition &n) {
|
||||
ObjectDistance::operator=(static_cast<const ObjectDistance&>(n));
|
||||
position = n.position;
|
||||
assert(id != 0);
|
||||
return *this;
|
||||
}
|
||||
uint32_t position;
|
||||
};
|
||||
|
||||
#ifdef NGT_GRAPH_UNCHECK_STACK
|
||||
typedef std::stack<ObjectDistance> UncheckedSet;
|
||||
#else
|
||||
#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE
|
||||
typedef std::priority_queue<NodeWithPosition, std::vector<NodeWithPosition>, std::greater<NodeWithPosition> > UncheckedSet;
|
||||
#else
|
||||
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::greater<ObjectDistance> > UncheckedSet;
|
||||
#endif
|
||||
#endif
|
||||
void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds, double (&comparator)(const void*, const void*, size_t));
|
||||
|
||||
void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results,
|
||||
UncheckedSet &unchecked, DistanceCheckedSet &distanceChecked);
|
||||
|
||||
#if !defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET)
|
||||
void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results,
|
||||
UncheckedSet &unchecked, DistanceCheckedSetForLargeDataset &distanceChecked);
|
||||
#endif
|
||||
|
||||
|
||||
int getEdgeSize() {return property.edgeSizeForCreation;}
|
||||
|
||||
ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); }
|
||||
|
||||
ObjectSpace &getObjectSpace() { return *objectSpace; }
|
||||
|
||||
void deleteInMemory() {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
assert(0);
|
||||
#else
|
||||
for (std::vector<NGT::GraphNode*>::iterator i = repository.begin(); i != repository.end(); i++) {
|
||||
if ((*i) != 0) {
|
||||
delete (*i);
|
||||
}
|
||||
}
|
||||
repository.clear();
|
||||
#endif
|
||||
}
|
||||
|
||||
static double (*getComparator())(const void*, const void*, size_t);
|
||||
|
||||
|
||||
protected:
|
||||
void
|
||||
addBKNNGEdge(ObjectID target, ObjectID addID, Distance addDistance) {
|
||||
if (repository.isEmpty(target)) {
|
||||
ObjectDistances objs;
|
||||
objs.push_back(ObjectDistance(addID, addDistance));
|
||||
repository.insert(target, objs);
|
||||
return;
|
||||
}
|
||||
addEdge(target, addID, addDistance, false);
|
||||
}
|
||||
|
||||
public:
|
||||
void addEdge(GraphNode &node, ObjectID addID, Distance addDistance, bool identityCheck = true) {
|
||||
ObjectDistance obj(addID, addDistance);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), obj);
|
||||
if ((ni != node.end(repository.allocator)) && ((*ni).id == addID)) {
|
||||
if (identityCheck) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#else
|
||||
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), obj);
|
||||
if ((ni != node.end()) && ((*ni).id == addID)) {
|
||||
if (identityCheck) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
node.insert(ni, obj, repository.allocator);
|
||||
#else
|
||||
node.insert(ni, obj);
|
||||
#endif
|
||||
}
|
||||
|
||||
// identityCheck is checking whether the same edge has already added to the node.
|
||||
// return whether truncation is needed that means the node has too many edges.
|
||||
bool addEdge(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) {
|
||||
size_t minsize = 0;
|
||||
GraphNode &node = property.truncationThreshold == 0 ? *getNode(target) : *getNode(target, minsize);
|
||||
addEdge(node, addID, addDistance, identityCheck);
|
||||
if ((size_t)property.truncationThreshold != 0 && node.size() - minsize >
|
||||
(size_t)property.truncationThreshold) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void addEdgeDeletingExcessEdges(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) {
|
||||
GraphNode &node = *getNode(target);
|
||||
size_t kEdge = property.edgeSizeForCreation - 1;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
if (node.size() > kEdge && node.at(kEdge, repository.allocator).distance >= addDistance) {
|
||||
GraphNode &linkedNode = *getNode(node.at(kEdge, repository.allocator).id);
|
||||
ObjectDistance linkedNodeEdge(target, node.at(kEdge, repository.allocator).distance);
|
||||
if ((linkedNode.size() > kEdge) && node.at(kEdge, repository.allocator).distance >=
|
||||
linkedNode.at(kEdge, repository.allocator).distance) {
|
||||
#else
|
||||
if (node.size() > kEdge && node[kEdge].distance >= addDistance) {
|
||||
GraphNode &linkedNode = *getNode(node[kEdge].id);
|
||||
ObjectDistance linkedNodeEdge(target, node[kEdge].distance);
|
||||
if ((linkedNode.size() > kEdge) && node[kEdge].distance >= linkedNode[kEdge].distance) {
|
||||
#endif
|
||||
try {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
removeEdge(node, node.at(kEdge, repository.allocator));
|
||||
#else
|
||||
removeEdge(node, node[kEdge]);
|
||||
#endif
|
||||
} catch (Exception &exp) {
|
||||
std::stringstream msg;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance;
|
||||
#else
|
||||
msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance;
|
||||
#endif
|
||||
msg << ":" << exp.what();
|
||||
NGTThrowException(msg.str());
|
||||
}
|
||||
try {
|
||||
removeEdge(linkedNode, linkedNodeEdge);
|
||||
} catch (Exception &exp) {
|
||||
std::stringstream msg;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance;
|
||||
#else
|
||||
msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance;
|
||||
#endif
|
||||
msg << ":" << exp.what();
|
||||
NGTThrowException(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
addEdge(node, addID, addDistance, identityCheck);
|
||||
}
|
||||
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
void loadSearchGraph(const std::string &database) {
|
||||
std::ifstream isg(database + "/grp");
|
||||
NeighborhoodGraph::searchRepository.deserialize(isg, NeighborhoodGraph::getObjectRepository());
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
||||
GraphRepository repository;
|
||||
ObjectSpace *objectSpace;
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
SearchGraphRepository searchRepository;
|
||||
#endif
|
||||
|
||||
NeighborhoodGraph::Property property;
|
||||
|
||||
}; // NeighborhoodGraph
|
||||
|
||||
} // NGT
|
||||
|
|
@ -0,0 +1,789 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "GraphReconstructor.h"
|
||||
#include "Optimizer.h"
|
||||
|
||||
namespace NGT {
|
||||
class GraphOptimizer {
|
||||
public:
|
||||
class ANNGEdgeOptimizationParameter {
|
||||
public:
|
||||
ANNGEdgeOptimizationParameter() {
|
||||
initialize();
|
||||
}
|
||||
void initialize() {
|
||||
noOfQueries = 200;
|
||||
noOfResults = 50;
|
||||
noOfThreads = 16;
|
||||
targetAccuracy = 0.9; // when epsilon is 0.0 and all of the edges are used
|
||||
targetNoOfObjects = 0;
|
||||
noOfSampleObjects = 100000;
|
||||
maxNoOfEdges = 100;
|
||||
}
|
||||
size_t noOfQueries;
|
||||
size_t noOfResults;
|
||||
size_t noOfThreads;
|
||||
float targetAccuracy;
|
||||
size_t targetNoOfObjects;
|
||||
size_t noOfSampleObjects;
|
||||
size_t maxNoOfEdges;
|
||||
};
|
||||
|
||||
GraphOptimizer(bool unlog = false) {
|
||||
init();
|
||||
logDisabled = unlog;
|
||||
}
|
||||
|
||||
GraphOptimizer(int outgoing, int incoming, int nofqs, int nofrs,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m,
|
||||
bool unlog // stderr log is disabled.
|
||||
) {
|
||||
init();
|
||||
set(outgoing, incoming, nofqs, nofrs, baseAccuracyFrom, baseAccuracyTo,
|
||||
rateAccuracyFrom, rateAccuracyTo, gte, m);
|
||||
logDisabled = unlog;
|
||||
}
|
||||
|
||||
void init() {
|
||||
numOfOutgoingEdges = 10;
|
||||
numOfIncomingEdges= 120;
|
||||
numOfQueries = 100;
|
||||
numOfResults = 20;
|
||||
baseAccuracyRange = std::pair<float, float>(0.30, 0.50);
|
||||
rateAccuracyRange = std::pair<float, float>(0.80, 0.90);
|
||||
gtEpsilon = 0.1;
|
||||
margin = 0.2;
|
||||
logDisabled = false;
|
||||
shortcutReduction = true;
|
||||
searchParameterOptimization = true;
|
||||
prefetchParameterOptimization = true;
|
||||
accuracyTableGeneration = true;
|
||||
}
|
||||
|
||||
void adjustSearchCoefficients(const std::string indexPath){
|
||||
NGT::Index index(indexPath);
|
||||
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(index.getIndex());
|
||||
NGT::Optimizer optimizer(index);
|
||||
if (logDisabled) {
|
||||
optimizer.disableLog();
|
||||
} else {
|
||||
optimizer.enableLog();
|
||||
}
|
||||
try {
|
||||
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
|
||||
NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty();
|
||||
prop.dynamicEdgeSizeBase = coefficients.first;
|
||||
prop.dynamicEdgeSizeRate = coefficients.second;
|
||||
prop.edgeSizeForSearch = -2;
|
||||
} catch(NGT::Exception &err) {
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::adjustSearchCoefficients: Cannot adjust the search coefficients. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
graph.saveIndex(indexPath);
|
||||
}
|
||||
|
||||
static double measureQueryTime(NGT::Index &index, size_t start) {
|
||||
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
|
||||
NGT::ObjectRepository &objectRepository = objectSpace.getRepository();
|
||||
size_t nQueries = 200;
|
||||
nQueries = objectRepository.size() - 1 < nQueries ? objectRepository.size() - 1 : nQueries;
|
||||
|
||||
size_t step = objectRepository.size() / nQueries;
|
||||
assert(step != 0);
|
||||
std::vector<size_t> ids;
|
||||
for (size_t startID = start; startID < step; startID++) {
|
||||
for (size_t id = startID; id < objectRepository.size(); id += step) {
|
||||
if (!objectRepository.isEmpty(id)) {
|
||||
ids.push_back(id);
|
||||
}
|
||||
}
|
||||
if (ids.size() >= nQueries) {
|
||||
ids.resize(nQueries);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (nQueries > ids.size()) {
|
||||
std::cerr << "# of Queries is not enough." << std::endl;
|
||||
return DBL_MAX;
|
||||
}
|
||||
|
||||
NGT::Timer timer;
|
||||
timer.reset();
|
||||
for (auto id = ids.begin(); id != ids.end(); id++) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
NGT::Object *obj = objectSpace.allocateObject(*objectRepository.get(*id));
|
||||
NGT::SearchContainer searchContainer(*obj);
|
||||
#else
|
||||
NGT::SearchContainer searchContainer(*objectRepository.get(*id));
|
||||
#endif
|
||||
NGT::ObjectDistances objects;
|
||||
searchContainer.setResults(&objects);
|
||||
searchContainer.setSize(10);
|
||||
searchContainer.setEpsilon(0.1);
|
||||
timer.restart();
|
||||
index.search(searchContainer);
|
||||
timer.stop();
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
objectSpace.deleteObject(obj);
|
||||
#endif
|
||||
}
|
||||
return timer.time * 1000.0;
|
||||
}
|
||||
|
||||
static std::pair<size_t, double> searchMinimumQueryTime(NGT::Index &index, size_t prefetchOffset,
|
||||
int maxPrefetchSize, size_t seedID) {
|
||||
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
|
||||
int step = 256;
|
||||
int prevPrefetchSize = 64;
|
||||
size_t minPrefetchSize = 0;
|
||||
double minTime = DBL_MAX;
|
||||
for (step = 256; step != 32; step /= 2) {
|
||||
double prevTime = DBL_MAX;
|
||||
for (int prefetchSize = prevPrefetchSize - step < 64 ? 64 : prevPrefetchSize - step; prefetchSize <= maxPrefetchSize; prefetchSize += step) {
|
||||
objectSpace.setPrefetchOffset(prefetchOffset);
|
||||
objectSpace.setPrefetchSize(prefetchSize);
|
||||
double time = measureQueryTime(index, seedID);
|
||||
if (prevTime < time) {
|
||||
break;
|
||||
}
|
||||
prevTime = time;
|
||||
prevPrefetchSize = prefetchSize;
|
||||
}
|
||||
if (minTime > prevTime) {
|
||||
minTime = prevTime;
|
||||
minPrefetchSize = prevPrefetchSize;
|
||||
}
|
||||
}
|
||||
return std::make_pair(minPrefetchSize, minTime);
|
||||
}
|
||||
|
||||
static std::pair<size_t, size_t> adjustPrefetchParameters(NGT::Index &index) {
|
||||
|
||||
bool gridSearch = false;
|
||||
{
|
||||
double time = measureQueryTime(index, 1);
|
||||
if (time < 500.0) {
|
||||
gridSearch = true;
|
||||
}
|
||||
}
|
||||
|
||||
size_t prefetchOffset = 0;
|
||||
size_t prefetchSize = 0;
|
||||
std::vector<std::pair<size_t, size_t>> mins;
|
||||
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
|
||||
int maxSize = objectSpace.getByteSizeOfObject() * 4;
|
||||
maxSize = maxSize < 64 * 28 ? maxSize : 64 * 28;
|
||||
for (int trial = 0; trial < 10; trial++) {
|
||||
size_t minps = 0;
|
||||
size_t minpo = 0;
|
||||
if (gridSearch) {
|
||||
double minTime = DBL_MAX;
|
||||
for (size_t po = 1; po <= 10; po++) {
|
||||
auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1);
|
||||
if (minTime > min.second) {
|
||||
minTime = min.second;
|
||||
minps = min.first;
|
||||
minpo = po;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
double prevTime = DBL_MAX;
|
||||
for (size_t po = 1; po <= 10; po++) {
|
||||
auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1);
|
||||
if (prevTime < min.second) {
|
||||
break;
|
||||
}
|
||||
prevTime = min.second;
|
||||
minps = min.first;
|
||||
minpo = po;
|
||||
}
|
||||
}
|
||||
if (std::find(mins.begin(), mins.end(), std::make_pair(minpo, minps)) != mins.end()) {
|
||||
prefetchOffset = minpo;
|
||||
prefetchSize = minps;
|
||||
mins.push_back(std::make_pair(minpo, minps));
|
||||
break;
|
||||
}
|
||||
mins.push_back(std::make_pair(minpo, minps));
|
||||
}
|
||||
return std::make_pair(prefetchOffset, prefetchSize);
|
||||
}
|
||||
|
||||
void execute(NGT::Index & index_)
|
||||
{
|
||||
NGT::GraphIndex & graphIndex = static_cast<NGT::GraphIndex &>(index_.getIndex());
|
||||
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0)
|
||||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
|
||||
}
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
std::vector<NGT::ObjectDistances> graph;
|
||||
try
|
||||
{
|
||||
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
|
||||
// extract only edges from the index to reduce the memory usage.
|
||||
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
|
||||
NeighborhoodGraph::Property & prop = graphIndex.getGraphProperty();
|
||||
if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG)
|
||||
{
|
||||
NGT::GraphReconstructor::convertToANNG(graph);
|
||||
}
|
||||
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
|
||||
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
{
|
||||
throw(err);
|
||||
}
|
||||
}
|
||||
|
||||
if (shortcutReduction)
|
||||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
|
||||
}
|
||||
try
|
||||
{
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
{
|
||||
throw(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void optimizeSearchParameters(NGT::Index & outIndex)
|
||||
{
|
||||
if (searchParameterOptimization)
|
||||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
|
||||
}
|
||||
NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(outIndex.getIndex());
|
||||
NGT::Optimizer optimizer(outIndex);
|
||||
if (logDisabled)
|
||||
{
|
||||
optimizer.disableLog();
|
||||
}
|
||||
else
|
||||
{
|
||||
optimizer.enableLog();
|
||||
}
|
||||
try
|
||||
{
|
||||
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
|
||||
NGT::NeighborhoodGraph::Property & prop = outGraph.getGraphProperty();
|
||||
prop.dynamicEdgeSizeBase = coefficients.first;
|
||||
prop.dynamicEdgeSizeRate = coefficients.second;
|
||||
prop.edgeSizeForSearch = -2;
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
{
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
|
||||
if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration)
|
||||
{
|
||||
// NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(*outIndex.getIndex());
|
||||
if (prefetchParameterOptimization)
|
||||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
|
||||
}
|
||||
try
|
||||
{
|
||||
auto prefetch = adjustPrefetchParameters(outIndex);
|
||||
NGT::Property prop;
|
||||
outIndex.getProperty(prop);
|
||||
prop.prefetchOffset = prefetch.first;
|
||||
prop.prefetchSize = prefetch.second;
|
||||
outIndex.setProperty(prop);
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
{
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
if (accuracyTableGeneration)
|
||||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
|
||||
}
|
||||
try
|
||||
{
|
||||
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
|
||||
NGT::Index::AccuracyTable accuracyTable(table);
|
||||
NGT::Property prop;
|
||||
outIndex.getProperty(prop);
|
||||
prop.accuracyTable = accuracyTable.getString();
|
||||
outIndex.setProperty(prop);
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
{
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void execute(
|
||||
const std::string inIndexPath,
|
||||
const std::string outIndexPath
|
||||
){
|
||||
if (access(outIndexPath.c_str(), 0) == 0) {
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: The specified index exists. " << outIndexPath;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
|
||||
const std::string com = "cp -r " + inIndexPath + " " + outIndexPath;
|
||||
int stat = system(com.c_str());
|
||||
if (stat != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot create the specified index. " << outIndexPath;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
|
||||
{
|
||||
NGT::StdOstreamRedirector redirector(logDisabled);
|
||||
NGT::GraphIndex graphIndex(outIndexPath, false);
|
||||
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
|
||||
}
|
||||
redirector.begin();
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
std::vector<NGT::ObjectDistances> graph;
|
||||
try {
|
||||
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
|
||||
// extract only edges from the index to reduce the memory usage.
|
||||
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
|
||||
NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
|
||||
if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG) {
|
||||
NGT::GraphReconstructor::convertToANNG(graph);
|
||||
}
|
||||
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
|
||||
graphIndex.saveGraph(outIndexPath);
|
||||
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
|
||||
graphIndex.saveProperty(outIndexPath);
|
||||
} catch (NGT::Exception &err) {
|
||||
redirector.end();
|
||||
throw(err);
|
||||
}
|
||||
}
|
||||
|
||||
if (shortcutReduction) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
|
||||
}
|
||||
try {
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
|
||||
graphIndex.saveGraph(outIndexPath);
|
||||
} catch (NGT::Exception &err) {
|
||||
redirector.end();
|
||||
throw(err);
|
||||
}
|
||||
}
|
||||
redirector.end();
|
||||
}
|
||||
|
||||
optimizeSearchParameters(outIndexPath);
|
||||
|
||||
}
|
||||
|
||||
void optimizeSearchParameters(const std::string outIndexPath)
|
||||
{
|
||||
|
||||
if (searchParameterOptimization) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
|
||||
}
|
||||
NGT::Index outIndex(outIndexPath);
|
||||
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
NGT::Optimizer optimizer(outIndex);
|
||||
if (logDisabled) {
|
||||
optimizer.disableLog();
|
||||
} else {
|
||||
optimizer.enableLog();
|
||||
}
|
||||
try {
|
||||
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
|
||||
NGT::NeighborhoodGraph::Property &prop = outGraph.getGraphProperty();
|
||||
prop.dynamicEdgeSizeBase = coefficients.first;
|
||||
prop.dynamicEdgeSizeRate = coefficients.second;
|
||||
prop.edgeSizeForSearch = -2;
|
||||
outGraph.saveProperty(outIndexPath);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
|
||||
if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration) {
|
||||
NGT::StdOstreamRedirector redirector(logDisabled);
|
||||
redirector.begin();
|
||||
NGT::Index outIndex(outIndexPath, true);
|
||||
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
if (prefetchParameterOptimization) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
|
||||
}
|
||||
try {
|
||||
auto prefetch = adjustPrefetchParameters(outIndex);
|
||||
NGT::Property prop;
|
||||
outIndex.getProperty(prop);
|
||||
prop.prefetchOffset = prefetch.first;
|
||||
prop.prefetchSize = prefetch.second;
|
||||
outIndex.setProperty(prop);
|
||||
outGraph.saveProperty(outIndexPath);
|
||||
} catch(NGT::Exception &err) {
|
||||
redirector.end();
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
if (accuracyTableGeneration) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
|
||||
}
|
||||
try {
|
||||
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
|
||||
NGT::Index::AccuracyTable accuracyTable(table);
|
||||
NGT::Property prop;
|
||||
outIndex.getProperty(prop);
|
||||
prop.accuracyTable = accuracyTable.getString();
|
||||
outIndex.setProperty(prop);
|
||||
} catch(NGT::Exception &err) {
|
||||
redirector.end();
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
try {
|
||||
outGraph.saveProperty(outIndexPath);
|
||||
redirector.end();
|
||||
} catch(NGT::Exception &err) {
|
||||
redirector.end();
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::execute: Cannot save the index. " << outIndexPath << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
static std::tuple<size_t, double, double> // optimized # of edges, accuracy, accuracy gain per edge
|
||||
optimizeNumberOfEdgesForANNG(NGT::Optimizer &optimizer, std::vector<std::vector<float>> &queries,
|
||||
size_t nOfResults, float targetAccuracy, size_t maxNoOfEdges) {
|
||||
|
||||
NGT::Index &index = optimizer.index;
|
||||
std::stringstream queryStream;
|
||||
std::stringstream gtStream;
|
||||
float maxEpsilon = 0.0;
|
||||
|
||||
optimizer.generatePseudoGroundTruth(queries, maxEpsilon, queryStream, gtStream);
|
||||
|
||||
size_t nOfEdges = 0;
|
||||
double accuracy = 0.0;
|
||||
size_t prevEdge = 0;
|
||||
double prevAccuracy = 0.0;
|
||||
double gain = 0.0;
|
||||
{
|
||||
std::vector<NGT::ObjectDistances> graph;
|
||||
NGT::GraphReconstructor::extractGraph(graph, static_cast<NGT::GraphIndex&>(index.getIndex()));
|
||||
float epsilon = 0.0;
|
||||
for (size_t edgeSize = 5; edgeSize <= maxNoOfEdges; edgeSize += (edgeSize >= 10 ? 10 : 5) ) {
|
||||
NGT::GraphReconstructor::reconstructANNGFromANNG(graph, index, edgeSize);
|
||||
NGT::Command::SearchParameter searchParameter;
|
||||
searchParameter.size = nOfResults;
|
||||
searchParameter.outputMode = 'e';
|
||||
searchParameter.edgeSize = 0;
|
||||
searchParameter.beginOfEpsilon = searchParameter.endOfEpsilon = epsilon;
|
||||
queryStream.clear();
|
||||
queryStream.seekg(0, std::ios_base::beg);
|
||||
std::vector<NGT::Optimizer::MeasuredValue> acc;
|
||||
NGT::Optimizer::search(index, queryStream, gtStream, searchParameter, acc);
|
||||
if (acc.size() == 0) {
|
||||
NGTThrowException("Fatal error! Cannot get any accuracy value.");
|
||||
}
|
||||
accuracy = acc[0].meanAccuracy;
|
||||
nOfEdges = edgeSize;
|
||||
if (prevEdge != 0) {
|
||||
gain = (accuracy - prevAccuracy) / (edgeSize - prevEdge);
|
||||
}
|
||||
if (accuracy >= targetAccuracy) {
|
||||
break;
|
||||
}
|
||||
prevEdge = edgeSize;
|
||||
prevAccuracy = accuracy;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(nOfEdges, accuracy, gain);
|
||||
}
|
||||
|
||||
static std::pair<size_t, float>
|
||||
optimizeNumberOfEdgesForANNG(NGT::Index &index, ANNGEdgeOptimizationParameter ¶meter)
|
||||
{
|
||||
if (parameter.targetNoOfObjects == 0) {
|
||||
parameter.targetNoOfObjects = index.getObjectRepositorySize();
|
||||
}
|
||||
|
||||
NGT::Optimizer optimizer(index, parameter.noOfResults);
|
||||
|
||||
NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository();
|
||||
NGT::GraphIndex &graphIndex = static_cast<NGT::GraphIndex&>(index.getIndex());
|
||||
NGT::GraphAndTreeIndex &treeIndex = static_cast<NGT::GraphAndTreeIndex&>(index.getIndex());
|
||||
NGT::GraphRepository &graphRepository = graphIndex.NeighborhoodGraph::repository;
|
||||
//float targetAccuracy = parameter.targetAccuracy + FLT_EPSILON;
|
||||
|
||||
std::vector<std::vector<float>> queries;
|
||||
optimizer.extractAndRemoveRandomQueries(parameter.noOfQueries, queries);
|
||||
{
|
||||
graphRepository.deleteAll();
|
||||
treeIndex.DVPTree::deleteAll();
|
||||
treeIndex.DVPTree::insertNode(treeIndex.DVPTree::leafNodes.allocate());
|
||||
}
|
||||
|
||||
NGT::NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
|
||||
prop.edgeSizeForCreation = parameter.maxNoOfEdges;
|
||||
std::vector<std::pair<size_t, std::tuple<size_t, double, double>>> transition;
|
||||
size_t targetNo = 12500;
|
||||
for (;targetNo <= objectRepository.size() && targetNo <= parameter.noOfSampleObjects;
|
||||
targetNo *= 2) {
|
||||
ObjectID id = 0;
|
||||
size_t noOfObjects = 0;
|
||||
for (id = 1; id < objectRepository.size(); ++id) {
|
||||
if (!objectRepository.isEmpty(id)) {
|
||||
noOfObjects++;
|
||||
}
|
||||
if (noOfObjects >= targetNo) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
id++;
|
||||
index.createIndex(parameter.noOfThreads, id);
|
||||
auto edge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(optimizer, queries, parameter.noOfResults, parameter.targetAccuracy, parameter.maxNoOfEdges);
|
||||
transition.push_back(make_pair(noOfObjects, edge));
|
||||
}
|
||||
if (transition.size() < 2) {
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. Too small object set. # of objects=" << objectRepository.size() << " target No.=" << targetNo;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
double edgeRate = 0.0;
|
||||
double accuracyRate = 0.0;
|
||||
for (auto i = transition.begin(); i != transition.end() - 1; ++i) {
|
||||
edgeRate += std::get<0>((*(i + 1)).second) - std::get<0>((*i).second);
|
||||
accuracyRate += std::get<1>((*(i + 1)).second) - std::get<1>((*i).second);
|
||||
}
|
||||
edgeRate /= (transition.size() - 1);
|
||||
accuracyRate /= (transition.size() - 1);
|
||||
size_t estimatedEdge = std::get<0>(transition[0].second) +
|
||||
edgeRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first));
|
||||
float estimatedAccuracy = std::get<1>(transition[0].second) +
|
||||
accuracyRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first));
|
||||
if (estimatedAccuracy < parameter.targetAccuracy) {
|
||||
estimatedEdge += (parameter.targetAccuracy - estimatedAccuracy) / std::get<2>(transition.back().second);
|
||||
estimatedAccuracy = parameter.targetAccuracy;
|
||||
}
|
||||
|
||||
if (estimatedEdge == 0) {
|
||||
std::stringstream msg;
|
||||
msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. "
|
||||
<< estimatedEdge << ":" << estimatedAccuracy << " # of objects=" << objectRepository.size();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
|
||||
return std::make_pair(estimatedEdge, estimatedAccuracy);
|
||||
}
|
||||
|
||||
std::pair<size_t, float>
|
||||
optimizeNumberOfEdgesForANNG(const std::string indexPath, GraphOptimizer::ANNGEdgeOptimizationParameter ¶meter) {
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
NGTThrowException("Not implemented for NGT with the shared memory option.");
|
||||
#endif
|
||||
|
||||
NGT::StdOstreamRedirector redirector(logDisabled);
|
||||
redirector.begin();
|
||||
|
||||
try {
|
||||
NGT::Index index(indexPath, false);
|
||||
|
||||
auto optimizedEdge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(index, parameter);
|
||||
|
||||
|
||||
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(index.getIndex());
|
||||
size_t noOfEdges = (optimizedEdge.first + 10) / 5 * 5;
|
||||
if (noOfEdges > parameter.maxNoOfEdges) {
|
||||
noOfEdges = parameter.maxNoOfEdges;
|
||||
}
|
||||
|
||||
NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty();
|
||||
prop.edgeSizeForCreation = noOfEdges;
|
||||
static_cast<NGT::GraphIndex&>(index.getIndex()).saveProperty(indexPath);
|
||||
optimizedEdge.first = noOfEdges;
|
||||
redirector.end();
|
||||
return optimizedEdge;
|
||||
} catch (NGT::Exception &err) {
|
||||
redirector.end();
|
||||
throw(err);
|
||||
}
|
||||
}
|
||||
|
||||
void set(int outgoing, int incoming, int nofqs, int nofrs,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m
|
||||
) {
|
||||
set(outgoing, incoming, nofqs, nofrs);
|
||||
setExtension(baseAccuracyFrom, baseAccuracyTo, rateAccuracyFrom, rateAccuracyTo, gte, m);
|
||||
}
|
||||
|
||||
void set(int outgoing, int incoming, int nofqs, int nofrs) {
|
||||
if (outgoing >= 0) {
|
||||
numOfOutgoingEdges = outgoing;
|
||||
}
|
||||
if (incoming >= 0) {
|
||||
numOfIncomingEdges = incoming;
|
||||
}
|
||||
if (nofqs > 0) {
|
||||
numOfQueries = nofqs;
|
||||
}
|
||||
if (nofrs > 0) {
|
||||
numOfResults = nofrs;
|
||||
}
|
||||
}
|
||||
|
||||
void setExtension(float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m
|
||||
) {
|
||||
if (baseAccuracyFrom > 0.0) {
|
||||
baseAccuracyRange.first = baseAccuracyFrom;
|
||||
}
|
||||
if (baseAccuracyTo > 0.0) {
|
||||
baseAccuracyRange.second = baseAccuracyTo;
|
||||
}
|
||||
if (rateAccuracyFrom > 0.0) {
|
||||
rateAccuracyRange.first = rateAccuracyFrom;
|
||||
}
|
||||
if (rateAccuracyTo > 0.0) {
|
||||
rateAccuracyRange.second = rateAccuracyTo;
|
||||
}
|
||||
if (gte >= -1.0) {
|
||||
gtEpsilon = gte;
|
||||
}
|
||||
if (m > 0.0) {
|
||||
margin = m;
|
||||
}
|
||||
}
|
||||
|
||||
// obsolete because of a lack of a parameter
|
||||
void set(int outgoing, int incoming, int nofqs,
|
||||
float baseAccuracyFrom, float baseAccuracyTo,
|
||||
float rateAccuracyFrom, float rateAccuracyTo,
|
||||
double gte, double m
|
||||
) {
|
||||
if (outgoing >= 0) {
|
||||
numOfOutgoingEdges = outgoing;
|
||||
}
|
||||
if (incoming >= 0) {
|
||||
numOfIncomingEdges = incoming;
|
||||
}
|
||||
if (nofqs > 0) {
|
||||
numOfQueries = nofqs;
|
||||
}
|
||||
if (baseAccuracyFrom > 0.0) {
|
||||
baseAccuracyRange.first = baseAccuracyFrom;
|
||||
}
|
||||
if (baseAccuracyTo > 0.0) {
|
||||
baseAccuracyRange.second = baseAccuracyTo;
|
||||
}
|
||||
if (rateAccuracyFrom > 0.0) {
|
||||
rateAccuracyRange.first = rateAccuracyFrom;
|
||||
}
|
||||
if (rateAccuracyTo > 0.0) {
|
||||
rateAccuracyRange.second = rateAccuracyTo;
|
||||
}
|
||||
if (gte >= -1.0) {
|
||||
gtEpsilon = gte;
|
||||
}
|
||||
if (m > 0.0) {
|
||||
margin = m;
|
||||
}
|
||||
}
|
||||
|
||||
void setProcessingModes(bool shortcut = true, bool searchParameter = true, bool prefetchParameter = true,
|
||||
bool accuracyTable = true) {
|
||||
shortcutReduction = shortcut;
|
||||
searchParameterOptimization = searchParameter;
|
||||
prefetchParameterOptimization = prefetchParameter;
|
||||
accuracyTableGeneration = accuracyTable;
|
||||
}
|
||||
|
||||
size_t numOfOutgoingEdges;
|
||||
size_t numOfIncomingEdges;
|
||||
std::pair<float, float> baseAccuracyRange;
|
||||
std::pair<float, float> rateAccuracyRange;
|
||||
size_t numOfQueries;
|
||||
size_t numOfResults;
|
||||
double gtEpsilon;
|
||||
double margin;
|
||||
bool logDisabled;
|
||||
bool shortcutReduction;
|
||||
bool searchParameterOptimization;
|
||||
bool prefetchParameterOptimization;
|
||||
bool accuracyTableGeneration;
|
||||
};
|
||||
|
||||
}; // NGT
|
||||
|
|
@ -0,0 +1,907 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
#else
|
||||
#warning "*** OMP is *NOT* available! ***"
|
||||
#endif
|
||||
|
||||
namespace NGT {
|
||||
|
||||
class GraphReconstructor {
|
||||
public:
|
||||
static void extractGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &graphIndex) {
|
||||
graph.reserve(graphIndex.repository.size());
|
||||
for (size_t id = 1; id < graphIndex.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
|
||||
}
|
||||
try {
|
||||
NGT::GraphNode &node = *graphIndex.getNode(id);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::ObjectDistances nd;
|
||||
nd.reserve(node.size());
|
||||
for (auto n = node.begin(graphIndex.repository.allocator); n != node.end(graphIndex.repository.allocator); ++n) {
|
||||
nd.push_back(ObjectDistance((*n).id, (*n).distance));
|
||||
}
|
||||
graph.push_back(nd);
|
||||
#else
|
||||
graph.push_back(node);
|
||||
#endif
|
||||
if (graph.back().size() != graph.back().capacity()) {
|
||||
std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
graph.push_back(NGT::ObjectDistances());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
static void
|
||||
adjustPaths(NGT::Index &outIndex)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "construct index is not implemented." << std::endl;
|
||||
exit(1);
|
||||
#else
|
||||
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
size_t rStartRank = 0;
|
||||
std::list<std::pair<size_t, NGT::GraphNode> > tmpGraph;
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
tmpGraph.push_back(std::pair<size_t, NGT::GraphNode>(id, node));
|
||||
if (node.size() > rStartRank) {
|
||||
node.resize(rStartRank);
|
||||
}
|
||||
}
|
||||
size_t removeCount = 0;
|
||||
for (size_t rank = rStartRank; ; rank++) {
|
||||
bool edge = false;
|
||||
Timer timer;
|
||||
for (auto it = tmpGraph.begin(); it != tmpGraph.end();) {
|
||||
size_t id = (*it).first;
|
||||
try {
|
||||
NGT::GraphNode &node = (*it).second;
|
||||
if (rank >= node.size()) {
|
||||
it = tmpGraph.erase(it);
|
||||
continue;
|
||||
}
|
||||
edge = true;
|
||||
if (rank >= 1 && node[rank - 1].distance > node[rank].distance) {
|
||||
std::cerr << "distance order is wrong!" << std::endl;
|
||||
std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
|
||||
}
|
||||
NGT::GraphNode &tn = *outGraph.getNode(id);
|
||||
volatile bool found = false;
|
||||
if (rank < 1000) {
|
||||
for (size_t tni = 0; tni < tn.size() && !found; tni++) {
|
||||
if (tn[tni].id == node[rank].id) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id);
|
||||
for (size_t dni = 0; dni < dstNode.size(); dni++) {
|
||||
if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for num_threads(10)
|
||||
#endif
|
||||
for (size_t tni = 0; tni < tn.size(); tni++) {
|
||||
if (found) {
|
||||
continue;
|
||||
}
|
||||
if (tn[tni].id == node[rank].id) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id);
|
||||
for (size_t dni = 0; dni < dstNode.size(); dni++) {
|
||||
if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) {
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
outGraph.addEdge(id, node.at(i, outGraph.repository.allocator).id,
|
||||
node.at(i, outGraph.repository.allocator).distance, true);
|
||||
#else
|
||||
tn.push_back(NGT::ObjectDistance(node[rank].id, node[rank].distance));
|
||||
#endif
|
||||
} else {
|
||||
removeCount++;
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
it++;
|
||||
continue;
|
||||
}
|
||||
it++;
|
||||
}
|
||||
if (edge == false) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
}
|
||||
|
||||
static void
|
||||
adjustPathsEffectively(NGT::Index &outIndex)
|
||||
{
|
||||
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
adjustPathsEffectively(outGraph);
|
||||
}
|
||||
|
||||
static bool edgeComp(NGT::ObjectDistance a, NGT::ObjectDistance b) {
|
||||
return a.id < b.id;
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance, NGT::GraphIndex &graph) {
|
||||
NGT::ObjectDistance edge(edgeID, edgeDistance);
|
||||
GraphNode::iterator ni = std::lower_bound(node.begin(graph.repository.allocator), node.end(graph.repository.allocator), edge, edgeComp);
|
||||
node.insert(ni, edge, graph.repository.allocator);
|
||||
}
|
||||
|
||||
static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID)
|
||||
{
|
||||
NGT::GraphNode &srcNode = *graph.getNode(srcNodeID);
|
||||
GraphNode::iterator ni = std::lower_bound(srcNode.begin(graph.repository.allocator), srcNode.end(graph.repository.allocator), ObjectDistance(dstNodeID, 0.0), edgeComp);
|
||||
return (ni != srcNode.end(graph.repository.allocator)) && ((*ni).id == dstNodeID);
|
||||
}
|
||||
#else
|
||||
static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance) {
|
||||
NGT::ObjectDistance edge(edgeID, edgeDistance);
|
||||
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge, edgeComp);
|
||||
node.insert(ni, edge);
|
||||
}
|
||||
|
||||
static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID)
|
||||
{
|
||||
NGT::GraphNode &srcNode = *graph.getNode(srcNodeID);
|
||||
GraphNode::iterator ni = std::lower_bound(srcNode.begin(), srcNode.end(), ObjectDistance(dstNodeID, 0.0), edgeComp);
|
||||
return (ni != srcNode.end()) && ((*ni).id == dstNodeID);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
static void
|
||||
adjustPathsEffectively(NGT::GraphIndex &outGraph)
|
||||
{
|
||||
Timer timer;
|
||||
timer.start();
|
||||
std::vector<NGT::GraphNode> tmpGraph;
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
tmpGraph.push_back(node);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
node.clear(outGraph.repository.allocator);
|
||||
#else
|
||||
node.clear();
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
tmpGraph.push_back(NGT::GraphNode(outGraph.repository.allocator));
|
||||
#else
|
||||
tmpGraph.push_back(NGT::GraphNode());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if (outGraph.repository.size() != tmpGraph.size() + 1) {
|
||||
std::stringstream msg;
|
||||
msg << "GraphReconstructor: Fatal inner error. " << outGraph.repository.size() << ":" << tmpGraph.size();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
timer.stop();
|
||||
std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
|
||||
timer.reset();
|
||||
timer.start();
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, uint32_t> > > removeCandidates(tmpGraph.size());
|
||||
int removeCandidateCount = 0;
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (size_t idx = 0; idx < tmpGraph.size(); ++idx) {
|
||||
auto it = tmpGraph.begin() + idx;
|
||||
size_t id = idx + 1;
|
||||
try {
|
||||
NGT::GraphNode &srcNode = *it;
|
||||
std::unordered_map<uint32_t, std::pair<size_t, double> > neighbors;
|
||||
for (size_t sni = 0; sni < srcNode.size(); ++sni) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
neighbors[srcNode.at(sni, outGraph.repository.allocator).id] = std::pair<size_t, double>(sni, srcNode.at(sni, outGraph.repository.allocator).distance);
|
||||
#else
|
||||
neighbors[srcNode[sni].id] = std::pair<size_t, double>(sni, srcNode[sni].distance);
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<std::pair<int, std::pair<uint32_t, uint32_t> > > candidates;
|
||||
for (size_t sni = 0; sni < srcNode.size(); sni++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::GraphNode &pathNode = tmpGraph[srcNode.at(sni, outGraph.repository.allocator).id - 1];
|
||||
#else
|
||||
NGT::GraphNode &pathNode = tmpGraph[srcNode[sni].id - 1];
|
||||
#endif
|
||||
for (size_t pni = 0; pni < pathNode.size(); pni++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
auto dstNodeID = pathNode.at(pni, outGraph.repository.allocator).id;
|
||||
#else
|
||||
auto dstNodeID = pathNode[pni].id;
|
||||
#endif
|
||||
auto dstNode = neighbors.find(dstNodeID);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
if (dstNode != neighbors.end()
|
||||
&& srcNode.at(sni, outGraph.repository.allocator).distance < (*dstNode).second.second
|
||||
&& pathNode.at(pni, outGraph.repository.allocator).distance < (*dstNode).second.second
|
||||
) {
|
||||
#else
|
||||
if (dstNode != neighbors.end()
|
||||
&& srcNode[sni].distance < (*dstNode).second.second
|
||||
&& pathNode[pni].distance < (*dstNode).second.second
|
||||
) {
|
||||
#endif
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
candidates.push_back(std::pair<int, std::pair<uint32_t, uint32_t> >((*dstNode).second.first, std::pair<uint32_t, uint32_t>(srcNode.at(sni, outGraph.repository.allocator).id, dstNodeID)));
|
||||
#else
|
||||
candidates.push_back(std::pair<int, std::pair<uint32_t, uint32_t> >((*dstNode).second.first, std::pair<uint32_t, uint32_t>(srcNode[sni].id, dstNodeID)));
|
||||
#endif
|
||||
removeCandidateCount++;
|
||||
}
|
||||
}
|
||||
}
|
||||
sort(candidates.begin(), candidates.end(), std::greater<std::pair<int, std::pair<uint32_t, uint32_t>>>());
|
||||
removeCandidates[id - 1].reserve(candidates.size());
|
||||
for (size_t i = 0; i < candidates.size(); i++) {
|
||||
removeCandidates[id - 1].push_back(candidates[i].second);
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
timer.stop();
|
||||
std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
|
||||
timer.reset();
|
||||
timer.start();
|
||||
|
||||
std::list<size_t> ids;
|
||||
for (size_t idx = 0; idx < tmpGraph.size(); ++idx) {
|
||||
ids.push_back(idx + 1);
|
||||
}
|
||||
|
||||
int removeCount = 0;
|
||||
removeCandidateCount = 0;
|
||||
for (size_t rank = 0; ids.size() != 0; rank++) {
|
||||
for (auto it = ids.begin(); it != ids.end(); ) {
|
||||
size_t id = *it;
|
||||
size_t idx = id - 1;
|
||||
try {
|
||||
NGT::GraphNode &srcNode = tmpGraph[idx];
|
||||
if (rank >= srcNode.size()) {
|
||||
if (!removeCandidates[idx].empty()) {
|
||||
std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
|
||||
abort();
|
||||
}
|
||||
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::GraphNode empty;
|
||||
tmpGraph[idx] = empty;
|
||||
#endif
|
||||
it = ids.erase(it);
|
||||
continue;
|
||||
}
|
||||
if (removeCandidates[idx].size() > 0) {
|
||||
removeCandidateCount++;
|
||||
bool pathExist = false;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) {
|
||||
#else
|
||||
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) {
|
||||
#endif
|
||||
size_t path = removeCandidates[idx].back().first;
|
||||
size_t dst = removeCandidates[idx].back().second;
|
||||
removeCandidates[idx].pop_back();
|
||||
if (removeCandidates[idx].empty()) {
|
||||
std::vector<std::pair<uint32_t, uint32_t>> empty;
|
||||
removeCandidates[idx] = empty;
|
||||
}
|
||||
if ((hasEdge(outGraph, id, path)) && (hasEdge(outGraph, path, dst))) {
|
||||
pathExist = true;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) {
|
||||
#else
|
||||
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) {
|
||||
#endif
|
||||
removeCandidates[idx].pop_back();
|
||||
if (removeCandidates[idx].empty()) {
|
||||
std::vector<std::pair<uint32_t, uint32_t>> empty;
|
||||
removeCandidates[idx] = empty;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pathExist) {
|
||||
removeCount++;
|
||||
it++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
NGT::GraphNode &outSrcNode = *outGraph.getNode(id);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
insert(outSrcNode, srcNode.at(rank, outGraph.repository.allocator).id, srcNode.at(rank, outGraph.repository.allocator).distance, outGraph);
|
||||
#else
|
||||
insert(outSrcNode, srcNode[rank].id, srcNode[rank].distance);
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
it++;
|
||||
continue;
|
||||
}
|
||||
it++;
|
||||
}
|
||||
}
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::sort(node.begin(outGraph.repository.allocator), node.end(outGraph.repository.allocator));
|
||||
#else
|
||||
std::sort(node.begin(), node.end());
|
||||
#endif
|
||||
} catch(...) {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static
|
||||
void convertToANNG(std::vector<NGT::ObjectDistances> &graph)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
|
||||
return;
|
||||
#else
|
||||
std::cerr << "convertToANNG begin" << std::endl;
|
||||
for (size_t idx = 0; idx < graph.size(); idx++) {
|
||||
NGT::GraphNode &node = graph[idx];
|
||||
for (auto ni = node.begin(); ni != node.end(); ++ni) {
|
||||
graph[(*ni).id - 1].push_back(NGT::ObjectDistance(idx + 1, (*ni).distance));
|
||||
}
|
||||
}
|
||||
for (size_t idx = 0; idx < graph.size(); idx++) {
|
||||
NGT::GraphNode &node = graph[idx];
|
||||
if (node.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
std::sort(node.begin(), node.end());
|
||||
NGT::ObjectID prev = 0;
|
||||
for (auto it = node.begin(); it != node.end();) {
|
||||
if (prev == (*it).id) {
|
||||
it = node.erase(it);
|
||||
continue;
|
||||
}
|
||||
prev = (*it).id;
|
||||
it++;
|
||||
}
|
||||
NGT::GraphNode tmp = node;
|
||||
node.swap(tmp);
|
||||
}
|
||||
std::cerr << "convertToANNG end" << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
static
|
||||
void reconstructGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph, size_t originalEdgeSize, size_t reverseEdgeSize)
|
||||
{
|
||||
if (reverseEdgeSize > 10000) {
|
||||
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
|
||||
originalEdgeTimer.start();
|
||||
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
if (originalEdgeSize == 0) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
node.clear(outGraph.repository.allocator);
|
||||
#else
|
||||
NGT::GraphNode empty;
|
||||
node.swap(empty);
|
||||
#endif
|
||||
} else {
|
||||
NGT::ObjectDistances n = graph[id - 1];
|
||||
if (n.size() < originalEdgeSize) {
|
||||
std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
|
||||
continue;
|
||||
}
|
||||
n.resize(originalEdgeSize);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
node.copy(n, outGraph.repository.allocator);
|
||||
#else
|
||||
node.swap(n);
|
||||
#endif
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
originalEdgeTimer.stop();
|
||||
|
||||
reverseEdgeTimer.start();
|
||||
int insufficientNodeCount = 0;
|
||||
for (size_t id = 1; id <= graph.size(); ++id) {
|
||||
try {
|
||||
NGT::ObjectDistances &node = graph[id - 1];
|
||||
size_t rsize = reverseEdgeSize;
|
||||
if (rsize > node.size()) {
|
||||
insufficientNodeCount++;
|
||||
rsize = node.size();
|
||||
}
|
||||
for (size_t i = 0; i < rsize; ++i) {
|
||||
NGT::Distance distance = node[i].distance;
|
||||
size_t nodeID = node[i].id;
|
||||
try {
|
||||
NGT::GraphNode &n = *outGraph.getNode(nodeID);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
n.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator);
|
||||
#else
|
||||
n.push_back(NGT::ObjectDistance(id, distance));
|
||||
#endif
|
||||
} catch(...) {}
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
reverseEdgeTimer.stop();
|
||||
if (insufficientNodeCount != 0) {
|
||||
std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
|
||||
}
|
||||
|
||||
normalizeEdgeTimer.start();
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode &n = *outGraph.getNode(id);
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "Processed " << id << " nodes" << std::endl;
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::sort(n.begin(outGraph.repository.allocator), n.end(outGraph.repository.allocator));
|
||||
#else
|
||||
std::sort(n.begin(), n.end());
|
||||
#endif
|
||||
NGT::ObjectID prev = 0;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
for (auto it = n.begin(outGraph.repository.allocator); it != n.end(outGraph.repository.allocator);) {
|
||||
#else
|
||||
for (auto it = n.begin(); it != n.end();) {
|
||||
#endif
|
||||
if (prev == (*it).id) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
it = n.erase(it, outGraph.repository.allocator);
|
||||
#else
|
||||
it = n.erase(it);
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
prev = (*it).id;
|
||||
it++;
|
||||
}
|
||||
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::GraphNode tmp = n;
|
||||
n.swap(tmp);
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
normalizeEdgeTimer.stop();
|
||||
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
|
||||
<< ":" << normalizeEdgeTimer.time << std::endl;
|
||||
|
||||
NGT::Property prop;
|
||||
outGraph.getProperty().get(prop);
|
||||
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
|
||||
outGraph.getProperty().set(prop);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static
|
||||
void reconstructGraphWithConstraint(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph,
|
||||
size_t originalEdgeSize, size_t reverseEdgeSize,
|
||||
char mode = 'a')
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
|
||||
abort();
|
||||
#else
|
||||
|
||||
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
|
||||
|
||||
if (reverseEdgeSize > 10000) {
|
||||
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "Processed " << id << std::endl;
|
||||
}
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
if (node.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
node.clear();
|
||||
NGT::GraphNode empty;
|
||||
node.swap(empty);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
|
||||
|
||||
std::vector<ObjectDistances> reverse(graph.size() + 1);
|
||||
for (size_t id = 1; id <= graph.size(); ++id) {
|
||||
try {
|
||||
NGT::GraphNode &node = graph[id - 1];
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "Processed (summing up) " << id << std::endl;
|
||||
}
|
||||
for (size_t rank = 0; rank < node.size(); rank++) {
|
||||
reverse[node[rank].id].push_back(ObjectDistance(id, node[rank].distance));
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<size_t, size_t> > reverseSize(graph.size() + 1);
|
||||
reverseSize[0] = std::pair<size_t, size_t>(0, 0);
|
||||
for (size_t rid = 1; rid <= graph.size(); ++rid) {
|
||||
reverseSize[rid] = std::pair<size_t, size_t>(reverse[rid].size(), rid);
|
||||
}
|
||||
std::sort(reverseSize.begin(), reverseSize.end());
|
||||
|
||||
|
||||
std::vector<uint32_t> indegreeCount(graph.size(), 0);
|
||||
size_t zeroCount = 0;
|
||||
for (size_t sizerank = 0; sizerank <= reverseSize.size(); sizerank++) {
|
||||
|
||||
if (reverseSize[sizerank].first == 0) {
|
||||
zeroCount++;
|
||||
continue;
|
||||
}
|
||||
size_t rid = reverseSize[sizerank].second;
|
||||
ObjectDistances &rnode = reverse[rid];
|
||||
for (auto rni = rnode.begin(); rni != rnode.end(); ++rni) {
|
||||
if (indegreeCount[(*rni).id] >= reverseEdgeSize) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode &node = *outGraph.getNode(rid);
|
||||
if (indegreeCount[(*rni).id] > 0 && node.size() >= originalEdgeSize) {
|
||||
continue;
|
||||
}
|
||||
|
||||
node.push_back(NGT::ObjectDistance((*rni).id, (*rni).distance));
|
||||
indegreeCount[(*rni).id]++;
|
||||
}
|
||||
}
|
||||
reverseEdgeTimer.stop();
|
||||
std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
|
||||
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
|
||||
|
||||
normalizeEdgeTimer.start();
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode &n = *outGraph.getNode(id);
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "Processed " << id << std::endl;
|
||||
}
|
||||
std::sort(n.begin(), n.end());
|
||||
NGT::ObjectID prev = 0;
|
||||
for (auto it = n.begin(); it != n.end();) {
|
||||
if (prev == (*it).id) {
|
||||
it = n.erase(it);
|
||||
continue;
|
||||
}
|
||||
prev = (*it).id;
|
||||
it++;
|
||||
}
|
||||
NGT::GraphNode tmp = n;
|
||||
n.swap(tmp);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
normalizeEdgeTimer.stop();
|
||||
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
|
||||
|
||||
originalEdgeTimer.start();
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "Processed " << id << std::endl;
|
||||
}
|
||||
NGT::GraphNode &node = graph[id - 1];
|
||||
try {
|
||||
NGT::GraphNode &onode = *outGraph.getNode(id);
|
||||
bool stop = false;
|
||||
for (size_t rank = 0; (rank < node.size() && rank < originalEdgeSize) && stop == false; rank++) {
|
||||
switch (mode) {
|
||||
case 'a':
|
||||
if (onode.size() >= originalEdgeSize) {
|
||||
stop = true;
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
case 'c':
|
||||
break;
|
||||
}
|
||||
NGT::Distance distance = node[rank].distance;
|
||||
size_t nodeID = node[rank].id;
|
||||
outGraph.addEdge(id, nodeID, distance, false);
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
originalEdgeTimer.stop();
|
||||
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
|
||||
|
||||
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
|
||||
<< ":" << normalizeEdgeTimer.time << std::endl;
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
// reconstruct a pseudo ANNG with a fewer edges from an actual ANNG with more edges.
|
||||
// graph is a source ANNG
|
||||
// index is an index with a reconstructed ANNG
|
||||
static
|
||||
void reconstructANNGFromANNG(std::vector<NGT::ObjectDistances> &graph, NGT::Index &index, size_t edgeSize)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
|
||||
abort();
|
||||
#else
|
||||
|
||||
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(index.getIndex());
|
||||
|
||||
// remove all edges in the index.
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "Processed " << id << " nodes." << std::endl;
|
||||
}
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
node.clear(outGraph.repository.allocator);
|
||||
#else
|
||||
NGT::GraphNode empty;
|
||||
node.swap(empty);
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t id = 1; id <= graph.size(); ++id) {
|
||||
size_t edgeCount = 0;
|
||||
try {
|
||||
NGT::ObjectDistances &node = graph[id - 1];
|
||||
NGT::GraphNode &n = *outGraph.getNode(id);
|
||||
NGT::Distance prevDistance = 0.0;
|
||||
assert(n.size() == 0);
|
||||
for (size_t i = 0; i < node.size(); ++i) {
|
||||
NGT::Distance distance = node[i].distance;
|
||||
if (prevDistance > distance) {
|
||||
NGTThrowException("Edge distance order is invalid");
|
||||
}
|
||||
prevDistance = distance;
|
||||
size_t nodeID = node[i].id;
|
||||
if (node[i].id < id) {
|
||||
try {
|
||||
NGT::GraphNode &dn = *outGraph.getNode(nodeID);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
n.push_back(NGT::ObjectDistance(nodeID, distance), outGraph.repository.allocator);
|
||||
dn.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator);
|
||||
#else
|
||||
n.push_back(NGT::ObjectDistance(nodeID, distance));
|
||||
dn.push_back(NGT::ObjectDistance(id, distance));
|
||||
#endif
|
||||
} catch(...) {}
|
||||
edgeCount++;
|
||||
}
|
||||
if (edgeCount >= edgeSize) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode &n = *outGraph.getNode(id);
|
||||
std::sort(n.begin(), n.end());
|
||||
NGT::ObjectID prev = 0;
|
||||
for (auto it = n.begin(); it != n.end();) {
|
||||
if (prev == (*it).id) {
|
||||
it = n.erase(it);
|
||||
continue;
|
||||
}
|
||||
prev = (*it).id;
|
||||
it++;
|
||||
}
|
||||
NGT::GraphNode tmp = n;
|
||||
n.swap(tmp);
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void refineANNG(NGT::Index &index, bool unlog, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) {
|
||||
NGT::StdOstreamRedirector redirector(unlog);
|
||||
redirector.begin();
|
||||
try {
|
||||
refineANNG(index, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
|
||||
} catch (NGT::Exception &err) {
|
||||
redirector.end();
|
||||
throw(err);
|
||||
}
|
||||
}
|
||||
|
||||
static void refineANNG(NGT::Index &index, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGTThrowException("GraphReconstructor::refineANNG: Not implemented for the shared memory option.");
|
||||
#else
|
||||
auto prop = static_cast<GraphIndex&>(index.getIndex()).getGraphProperty();
|
||||
NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository();
|
||||
NGT::GraphIndex &graphIndex = static_cast<GraphIndex&>(index.getIndex());
|
||||
size_t nOfObjects = objectRepository.size();
|
||||
bool error = false;
|
||||
std::string errorMessage;
|
||||
for (size_t bid = 1; bid < nOfObjects; bid += batchSize) {
|
||||
NGT::ObjectDistances results[batchSize];
|
||||
// search
|
||||
#pragma omp parallel for
|
||||
for (size_t idx = 0; idx < batchSize; idx++) {
|
||||
size_t id = bid + idx;
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "# of processed objects=" << id << std::endl;
|
||||
}
|
||||
if (objectRepository.isEmpty(id)) {
|
||||
continue;
|
||||
}
|
||||
NGT::SearchContainer searchContainer(*objectRepository.get(id));
|
||||
searchContainer.setResults(&results[idx]);
|
||||
assert(prop.edgeSizeForCreation > 0);
|
||||
searchContainer.setSize(noOfEdges > prop.edgeSizeForCreation ? noOfEdges : prop.edgeSizeForCreation);
|
||||
if (accuracy > 0.0) {
|
||||
searchContainer.setExpectedAccuracy(accuracy);
|
||||
} else {
|
||||
searchContainer.setEpsilon(epsilon);
|
||||
}
|
||||
if (exploreEdgeSize != INT_MIN) {
|
||||
searchContainer.setEdgeSize(exploreEdgeSize);
|
||||
}
|
||||
if (!error) {
|
||||
try {
|
||||
index.search(searchContainer);
|
||||
} catch (NGT::Exception &err) {
|
||||
#pragma omp critical
|
||||
{
|
||||
error = true;
|
||||
errorMessage = err.what();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (error) {
|
||||
std::stringstream msg;
|
||||
msg << "GraphReconstructor::refineANNG: " << errorMessage;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
// outgoing edges
|
||||
#pragma omp parallel for
|
||||
for (size_t idx = 0; idx < batchSize; idx++) {
|
||||
size_t id = bid + idx;
|
||||
if (objectRepository.isEmpty(id)) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode &node = *graphIndex.getNode(id);
|
||||
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
|
||||
if ((*i).id != id) {
|
||||
node.push_back(*i);
|
||||
}
|
||||
}
|
||||
std::sort(node.begin(), node.end());
|
||||
// dedupe
|
||||
ObjectID prev = 0;
|
||||
for (GraphNode::iterator ni = node.begin(); ni != node.end();) {
|
||||
if (prev == (*ni).id) {
|
||||
ni = node.erase(ni);
|
||||
continue;
|
||||
}
|
||||
prev = (*ni).id;
|
||||
ni++;
|
||||
}
|
||||
}
|
||||
// incomming edges
|
||||
if (noOfEdges != 0) {
|
||||
continue;
|
||||
}
|
||||
for (size_t idx = 0; idx < batchSize; idx++) {
|
||||
size_t id = bid + idx;
|
||||
if (id % 10000 == 0) {
|
||||
std::cerr << "# of processed objects=" << id << std::endl;
|
||||
}
|
||||
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
|
||||
if ((*i).id != id) {
|
||||
NGT::GraphNode &node = *graphIndex.getNode((*i).id);
|
||||
graphIndex.addEdge(node, id, (*i).distance, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (noOfEdges != 0) {
|
||||
// prune to build knng
|
||||
size_t nedges = noOfEdges < 0 ? -noOfEdges : noOfEdges;
|
||||
#pragma omp parallel for
|
||||
for (ObjectID id = 1; id < nOfObjects; ++id) {
|
||||
if (objectRepository.isEmpty(id)) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode &node = *graphIndex.getNode(id);
|
||||
if (node.size() > nedges) {
|
||||
node.resize(nedges);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
}
|
||||
};
|
||||
|
||||
}; // NGT
|
|
@ -0,0 +1,110 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
#include <stdint.h>
|
||||
#include <climits>
|
||||
#include <unordered_set>
|
||||
|
||||
class HashBasedBooleanSet{
|
||||
private:
|
||||
uint32_t *_table;
|
||||
uint32_t _tableSize;
|
||||
uint32_t _mask;
|
||||
|
||||
std::unordered_set<uint32_t> _stlHash;
|
||||
|
||||
|
||||
inline uint32_t _hash1(const uint32_t value){
|
||||
return value & _mask;
|
||||
}
|
||||
|
||||
public:
|
||||
HashBasedBooleanSet():_table(NULL), _tableSize(0), _mask(0) {}
|
||||
|
||||
HashBasedBooleanSet(const uint64_t size):_table(NULL), _tableSize(0), _mask(0) {
|
||||
size_t bitSize = 0;
|
||||
size_t bit = size;
|
||||
while (bit != 0) {
|
||||
bitSize++;
|
||||
bit >>= 1;
|
||||
}
|
||||
size_t bucketSize = 0x1 << ((bitSize + 4) / 2 + 3);
|
||||
initialize(bucketSize);
|
||||
}
|
||||
void initialize(const uint32_t tableSize) {
|
||||
_tableSize = tableSize;
|
||||
_mask = _tableSize - 1;
|
||||
const uint32_t checkValue = _hash1(tableSize);
|
||||
if(checkValue != 0){
|
||||
std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
|
||||
}
|
||||
|
||||
_table = new uint32_t[tableSize];
|
||||
memset(_table, 0, tableSize * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
~HashBasedBooleanSet(){
|
||||
delete[] _table;
|
||||
_stlHash.clear();
|
||||
}
|
||||
|
||||
inline bool operator[](const uint32_t num){
|
||||
const uint32_t hashValue = _hash1(num);
|
||||
|
||||
auto v = _table[hashValue];
|
||||
if (v == num){
|
||||
return true;
|
||||
}
|
||||
if (v == 0){
|
||||
return false;
|
||||
}
|
||||
if (_stlHash.count(num) <= 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void set(const uint32_t num){
|
||||
uint32_t &value = _table[_hash1(num)];
|
||||
if(value == 0){
|
||||
value = num;
|
||||
}else{
|
||||
if(value != num){
|
||||
_stlHash.insert(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void insert(const uint32_t num){
|
||||
set(num);
|
||||
}
|
||||
|
||||
inline void reset(const uint32_t num){
|
||||
const uint32_t hashValue = _hash1(num);
|
||||
if(_table[hashValue] != 0){
|
||||
if(_table[hashValue] != num){
|
||||
_stlHash.erase(num);
|
||||
}else{
|
||||
_table[hashValue] = UINT_MAX;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,457 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "MmapManagerImpl.hpp"
|
||||
|
||||
namespace MemoryManager{
|
||||
// static method ---
|
||||
void MmapManager::setDefaultOptionValue(init_option_st &optionst)
|
||||
{
|
||||
optionst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND;
|
||||
optionst.reuse_type = REUSE_DATA_CLASSIFY;
|
||||
}
|
||||
|
||||
size_t MmapManager::getAlignSize(size_t size){
|
||||
if((size % MMAP_MEMORY_ALIGN) == 0){
|
||||
return size;
|
||||
}else{
|
||||
return ( (size >> MMAP_MEMORY_ALIGN_EXP ) + 1 ) * MMAP_MEMORY_ALIGN;
|
||||
}
|
||||
}
|
||||
// static method ---
|
||||
|
||||
|
||||
MmapManager::MmapManager():_impl(new MmapManager::Impl(*this))
|
||||
{
|
||||
for(uint64_t i = 0; i < MMAP_MAX_UNIT_NUM; ++i){
|
||||
_impl->mmapDataAddr[i] = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
MmapManager::~MmapManager() = default;
|
||||
|
||||
void MmapManager::dumpHeap() const
|
||||
{
|
||||
_impl->dumpHeap();
|
||||
}
|
||||
|
||||
bool MmapManager::isOpen() const
|
||||
{
|
||||
return _impl->isOpen;
|
||||
}
|
||||
|
||||
void *MmapManager::getEntryHook() const {
|
||||
return getAbsAddr(_impl->mmapCntlHead->entry_p);
|
||||
}
|
||||
|
||||
void MmapManager::setEntryHook(const void *entry_p){
|
||||
_impl->mmapCntlHead->entry_p = getRelAddr(entry_p);
|
||||
}
|
||||
|
||||
|
||||
bool MmapManager::init(const std::string &filePath, size_t size, const init_option_st *optionst) const
|
||||
{
|
||||
try{
|
||||
const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX;
|
||||
|
||||
struct stat st;
|
||||
if(stat(controlFile.c_str(), &st) == 0){
|
||||
return false;
|
||||
}
|
||||
if(filePath.length() > MMAP_MAX_FILE_NAME_LENGTH){
|
||||
std::cerr << "too long filepath" << std::endl;
|
||||
return false;
|
||||
}
|
||||
if((size % sysconf(_SC_PAGESIZE) != 0) || ( size < MMAP_LOWER_SIZE )){
|
||||
std::cerr << "input size error" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t fd = _impl->formatFile(controlFile, MMAP_CNTL_FILE_SIZE);
|
||||
assert(fd >= 0);
|
||||
|
||||
errno = 0;
|
||||
char *cntl_p = (char *)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
|
||||
if(cntl_p == MAP_FAILED){
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException(controlFile + " " + err_str);
|
||||
}
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
|
||||
try {
|
||||
fd = _impl->formatFile(filePath, size);
|
||||
} catch (MmapManagerException &err) {
|
||||
if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) {
|
||||
throw MmapManagerException("[ERR] : munmap error : " + getErrorStr(errno) +
|
||||
" : Through the exception : " + err.what());
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
|
||||
boot_st bootStruct = {0};
|
||||
control_st controlStruct = {0};
|
||||
_impl->initBootStruct(bootStruct, size);
|
||||
_impl->initControlStruct(controlStruct, size);
|
||||
|
||||
char *cntl_head = cntl_p;
|
||||
cntl_head += sizeof(boot_st);
|
||||
|
||||
if(optionst != NULL){
|
||||
controlStruct.use_expand = optionst->use_expand;
|
||||
controlStruct.reuse_type = optionst->reuse_type;
|
||||
}
|
||||
|
||||
memcpy(cntl_p, (char *)&bootStruct, sizeof(boot_st));
|
||||
memcpy(cntl_head, (char *)&controlStruct, sizeof(control_st));
|
||||
|
||||
errno = 0;
|
||||
if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
|
||||
|
||||
return true;
|
||||
}catch(MmapManagerException &e){
|
||||
std::cerr << "init error. " << e.what() << std::endl;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
bool MmapManager::openMemory(const std::string &filePath)
|
||||
{
|
||||
try{
|
||||
if(_impl->isOpen == true){
|
||||
std::string err_str = "[ERROR] : openMemory error (double open).";
|
||||
throw MmapManagerException(err_str);
|
||||
}
|
||||
|
||||
const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX;
|
||||
_impl->filePath = filePath;
|
||||
|
||||
int32_t fd;
|
||||
|
||||
errno = 0;
|
||||
if((fd = open(controlFile.c_str(), O_RDWR, 0666)) == -1){
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
throw MmapManagerException("file open error" + err_str);
|
||||
}
|
||||
|
||||
errno = 0;
|
||||
boot_st *boot_p = (boot_st*)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
|
||||
if(boot_p == MAP_FAILED){
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException(controlFile + " " + err_str);
|
||||
}
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
|
||||
if(boot_p->version != MMAP_MANAGER_VERSION){
|
||||
std::cerr << "[WARN] : version error" << std::endl;
|
||||
errno = 0;
|
||||
if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
|
||||
throw MmapManagerException("MemoryManager version error");
|
||||
}
|
||||
|
||||
errno = 0;
|
||||
if((fd = open(filePath.c_str(), O_RDWR, 0666)) == -1){
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
errno = 0;
|
||||
if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
|
||||
throw MmapManagerException("file open error = " + std::string(filePath.c_str()) + err_str);
|
||||
}
|
||||
|
||||
_impl->mmapCntlHead = (control_st*)( (char *)boot_p + sizeof(boot_st));
|
||||
_impl->mmapCntlAddr = (void *)boot_p;
|
||||
|
||||
for(uint64_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){
|
||||
off_t offset = _impl->mmapCntlHead->base_size * i;
|
||||
errno = 0;
|
||||
_impl->mmapDataAddr[i] = mmap(NULL, _impl->mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset);
|
||||
if(_impl->mmapDataAddr[i] == MAP_FAILED){
|
||||
if (errno == EINVAL) {
|
||||
std::cerr << "MmapManager::openMemory: Fatal error. EINVAL" << std::endl
|
||||
<< " If you use valgrind, this error might occur when the DB is created." << std::endl
|
||||
<< " In the case of that, reduce bsize in SharedMemoryAllocator." << std::endl;
|
||||
assert(errno != EINVAL);
|
||||
}
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
closeMemory(true);
|
||||
throw MmapManagerException(err_str);
|
||||
}
|
||||
}
|
||||
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
|
||||
_impl->isOpen = true;
|
||||
return true;
|
||||
}catch(MmapManagerException &e){
|
||||
std::cerr << "open error" << std::endl;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
void MmapManager::closeMemory(const bool force)
|
||||
{
|
||||
try{
|
||||
if(force || _impl->isOpen){
|
||||
uint16_t count = 0;
|
||||
void *error_ids[MMAP_MAX_UNIT_NUM] = {0};
|
||||
for(uint16_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){
|
||||
if(_impl->mmapDataAddr[i] != NULL){
|
||||
if(munmap(_impl->mmapDataAddr[i], _impl->mmapCntlHead->base_size) == -1){
|
||||
error_ids[i] = _impl->mmapDataAddr[i];;
|
||||
count++;
|
||||
}
|
||||
_impl->mmapDataAddr[i] = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
if(count > 0){
|
||||
std::string msg = "";
|
||||
|
||||
for(uint16_t i = 0; i < count; i++){
|
||||
std::stringstream ss;
|
||||
ss << error_ids[i];
|
||||
msg += ss.str() + ", ";
|
||||
}
|
||||
throw MmapManagerException("unmap error : ids = " + msg);
|
||||
}
|
||||
|
||||
if(_impl->mmapCntlAddr != NULL){
|
||||
if(munmap(_impl->mmapCntlAddr, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
|
||||
_impl->mmapCntlAddr = NULL;
|
||||
}
|
||||
_impl->isOpen = false;
|
||||
}
|
||||
}catch(MmapManagerException &e){
|
||||
std::cerr << "close error" << std::endl;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
off_t MmapManager::alloc(const size_t size, const bool not_reuse_flag)
|
||||
{
|
||||
try{
|
||||
if(!_impl->isOpen){
|
||||
std::cerr << "not open this file" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t alloc_size = getAlignSize(size);
|
||||
|
||||
if( (alloc_size + sizeof(chunk_head_st)) >= _impl->mmapCntlHead->base_size ){
|
||||
std::cerr << "alloc size over. size=" << size << "." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if(!not_reuse_flag){
|
||||
if( _impl->mmapCntlHead->reuse_type == REUSE_DATA_CLASSIFY
|
||||
|| _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE
|
||||
|| _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE_PLUS){
|
||||
off_t ret_offset;
|
||||
reuse_state_t reuse_state = REUSE_STATE_OK;
|
||||
ret_offset = reuse(alloc_size, reuse_state);
|
||||
if(reuse_state != REUSE_STATE_ALLOC){
|
||||
return ret_offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
head_st *unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit];
|
||||
if((unit_header->break_p + sizeof(chunk_head_st) + alloc_size) >= _impl->mmapCntlHead->base_size){
|
||||
if(_impl->mmapCntlHead->use_expand == true){
|
||||
if(_impl->expandMemory() == false){
|
||||
std::cerr << __func__ << ": cannot expand" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit];
|
||||
}else{
|
||||
std::cerr << __func__ << ": total size over" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
const off_t file_offset = _impl->mmapCntlHead->active_unit * _impl->mmapCntlHead->base_size;
|
||||
const off_t ret_p = file_offset + ( unit_header->break_p + sizeof(chunk_head_st) );
|
||||
|
||||
chunk_head_st *chunk_head = (chunk_head_st*)(unit_header->break_p + (char *)_impl->mmapDataAddr[_impl->mmapCntlHead->active_unit]);
|
||||
_impl->setupChunkHead(chunk_head, false, _impl->mmapCntlHead->active_unit, -1, alloc_size);
|
||||
unit_header->break_p += alloc_size + sizeof(chunk_head_st);
|
||||
unit_header->chunk_num++;
|
||||
|
||||
return ret_p;
|
||||
}catch(MmapManagerException &e){
|
||||
std::cerr << "allocation error" << std::endl;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
void MmapManager::free(const off_t p)
|
||||
{
|
||||
switch(_impl->mmapCntlHead->reuse_type){
|
||||
case REUSE_DATA_CLASSIFY:
|
||||
_impl->free_data_classify(p);
|
||||
break;
|
||||
case REUSE_DATA_QUEUE:
|
||||
_impl->free_data_queue(p);
|
||||
break;
|
||||
case REUSE_DATA_QUEUE_PLUS:
|
||||
_impl->free_data_queue_plus(p);
|
||||
break;
|
||||
default:
|
||||
_impl->free_data_classify(p);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
off_t MmapManager::reuse(const size_t size, reuse_state_t &reuse_state)
|
||||
{
|
||||
off_t ret_off;
|
||||
|
||||
switch(_impl->mmapCntlHead->reuse_type){
|
||||
case REUSE_DATA_CLASSIFY:
|
||||
ret_off = _impl->reuse_data_classify(size, reuse_state);
|
||||
break;
|
||||
case REUSE_DATA_QUEUE:
|
||||
ret_off = _impl->reuse_data_queue(size, reuse_state);
|
||||
break;
|
||||
case REUSE_DATA_QUEUE_PLUS:
|
||||
ret_off = _impl->reuse_data_queue_plus(size, reuse_state);
|
||||
break;
|
||||
default:
|
||||
ret_off = _impl->reuse_data_classify(size, reuse_state);
|
||||
break;
|
||||
}
|
||||
|
||||
return ret_off;
|
||||
}
|
||||
|
||||
void *MmapManager::getAbsAddr(off_t p) const
|
||||
{
|
||||
if(p < 0){
|
||||
return NULL;
|
||||
}
|
||||
const uint16_t unit_id = p / _impl->mmapCntlHead->base_size;
|
||||
const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size;
|
||||
const off_t ret_p = p - file_offset;
|
||||
|
||||
return ABS_ADDR(ret_p, _impl->mmapDataAddr[unit_id]);
|
||||
}
|
||||
|
||||
off_t MmapManager::getRelAddr(const void *p) const
|
||||
{
|
||||
const chunk_head_st *chunk_head = (chunk_head_st *)((char *)p - sizeof(chunk_head_st));
|
||||
const uint16_t unit_id = chunk_head->unit_id;
|
||||
|
||||
const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size;
|
||||
off_t ret_p = (off_t)((char *)p - (char *)_impl->mmapDataAddr[unit_id]);
|
||||
ret_p += file_offset;
|
||||
|
||||
return ret_p;
|
||||
}
|
||||
|
||||
std::string getErrorStr(int32_t err_num){
|
||||
char err_msg[256];
|
||||
#ifdef _GNU_SOURCE
|
||||
char *msg = strerror_r(err_num, err_msg, 256);
|
||||
return std::string(msg);
|
||||
#else
|
||||
strerror_r(err_num, err_msg, 256);
|
||||
return std::string(err_msg);
|
||||
#endif
|
||||
}
|
||||
|
||||
size_t MmapManager::getTotalSize() const
|
||||
{
|
||||
const uint16_t active_unit = _impl->mmapCntlHead->active_unit;
|
||||
const size_t ret_size = ((_impl->mmapCntlHead->unit_num - 1) * _impl->mmapCntlHead->base_size) + _impl->mmapCntlHead->data_headers[active_unit].break_p;
|
||||
|
||||
return ret_size;
|
||||
}
|
||||
|
||||
size_t MmapManager::getUseSize() const
|
||||
{
|
||||
size_t total_size = 0;
|
||||
void *ref_addr = (void *)&total_size;
|
||||
_impl->scanAllData(ref_addr, CHECK_STATS_USE_SIZE);
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
uint64_t MmapManager::getUseNum() const
|
||||
{
|
||||
uint64_t total_chunk_num = 0;
|
||||
void *ref_addr = (void *)&total_chunk_num;
|
||||
_impl->scanAllData(ref_addr, CHECK_STATS_USE_NUM);
|
||||
|
||||
return total_chunk_num;
|
||||
}
|
||||
|
||||
size_t MmapManager::getFreeSize() const
|
||||
{
|
||||
size_t total_size = 0;
|
||||
void *ref_addr = (void *)&total_size;
|
||||
_impl->scanAllData(ref_addr, CHECK_STATS_FREE_SIZE);
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
uint64_t MmapManager::getFreeNum() const
|
||||
{
|
||||
uint64_t total_chunk_num = 0;
|
||||
void *ref_addr = (void *)&total_chunk_num;
|
||||
_impl->scanAllData(ref_addr, CHECK_STATS_FREE_NUM);
|
||||
|
||||
return total_chunk_num;
|
||||
}
|
||||
|
||||
uint16_t MmapManager::getUnitNum() const
|
||||
{
|
||||
return _impl->mmapCntlHead->unit_num;
|
||||
}
|
||||
|
||||
size_t MmapManager::getQueueCapacity() const
|
||||
{
|
||||
free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue;
|
||||
return free_queue->capacity;
|
||||
}
|
||||
|
||||
uint64_t MmapManager::getQueueNum() const
|
||||
{
|
||||
free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue;
|
||||
return free_queue->tail;
|
||||
}
|
||||
|
||||
uint64_t MmapManager::getLargeListNum() const
|
||||
{
|
||||
uint64_t count = 0;
|
||||
free_list_st *free_list = &_impl->mmapCntlHead->free_data.large_list;
|
||||
|
||||
if(free_list->free_p == -1){
|
||||
return count;
|
||||
}
|
||||
|
||||
off_t current_off = free_list->free_p;
|
||||
chunk_head_st *current_chunk_head = (chunk_head_st *)getAbsAddr(current_off);
|
||||
|
||||
while(current_chunk_head != NULL){
|
||||
count++;
|
||||
current_off = current_chunk_head->free_next;
|
||||
current_chunk_head = (chunk_head_st *)getAbsAddr(current_off);
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#define ABS_ADDR(x, y) (void *)(x + (char *)y);
|
||||
|
||||
#define USE_MMAP_MANAGER
|
||||
|
||||
namespace MemoryManager{
|
||||
|
||||
typedef enum _option_reuse_t{
|
||||
REUSE_DATA_CLASSIFY,
|
||||
REUSE_DATA_QUEUE,
|
||||
REUSE_DATA_QUEUE_PLUS,
|
||||
}option_reuse_t;
|
||||
|
||||
typedef enum _reuse_state_t{
|
||||
REUSE_STATE_OK,
|
||||
REUSE_STATE_FALSE,
|
||||
REUSE_STATE_ALLOC,
|
||||
}reuse_state_t;
|
||||
|
||||
typedef enum _check_statistics_t{
|
||||
CHECK_STATS_USE_SIZE,
|
||||
CHECK_STATS_USE_NUM,
|
||||
CHECK_STATS_FREE_SIZE,
|
||||
CHECK_STATS_FREE_NUM,
|
||||
}check_statistics_t;
|
||||
|
||||
typedef struct _init_option_st{
|
||||
bool use_expand;
|
||||
option_reuse_t reuse_type;
|
||||
}init_option_st;
|
||||
|
||||
|
||||
class MmapManager{
|
||||
public:
|
||||
MmapManager();
|
||||
~MmapManager();
|
||||
|
||||
bool init(const std::string &filePath, size_t size, const init_option_st *optionst = NULL) const;
|
||||
bool openMemory(const std::string &filePath);
|
||||
void closeMemory(const bool force = false);
|
||||
off_t alloc(const size_t size, const bool not_reuse_flag = false);
|
||||
void free(const off_t p);
|
||||
off_t reuse(const size_t size, reuse_state_t &reuse_state);
|
||||
void *getAbsAddr(off_t p) const;
|
||||
off_t getRelAddr(const void *p) const;
|
||||
|
||||
size_t getTotalSize() const;
|
||||
size_t getUseSize() const;
|
||||
uint64_t getUseNum() const;
|
||||
size_t getFreeSize() const;
|
||||
uint64_t getFreeNum() const;
|
||||
uint16_t getUnitNum() const;
|
||||
size_t getQueueCapacity() const;
|
||||
uint64_t getQueueNum() const;
|
||||
uint64_t getLargeListNum() const;
|
||||
|
||||
void dumpHeap() const;
|
||||
|
||||
bool isOpen() const;
|
||||
void *getEntryHook() const;
|
||||
void setEntryHook(const void *entry_p);
|
||||
|
||||
// static method ---
|
||||
static void setDefaultOptionValue(init_option_st &optionst);
|
||||
static size_t getAlignSize(size_t size);
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
|
||||
std::string getErrorStr(int32_t err_num);
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "MmapManager.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
namespace MemoryManager{
|
||||
const uint64_t MMAP_MANAGER_VERSION = 5;
|
||||
|
||||
const bool MMAP_DEFAULT_ALLOW_EXPAND = false;
|
||||
const uint64_t MMAP_CNTL_FILE_RANGE = 16;
|
||||
const size_t MMAP_CNTL_FILE_SIZE = MMAP_CNTL_FILE_RANGE * sysconf(_SC_PAGESIZE);
|
||||
const uint64_t MMAP_MAX_FILE_NAME_LENGTH = 1024;
|
||||
const std::string MMAP_CNTL_FILE_SUFFIX = "c";
|
||||
|
||||
const size_t MMAP_LOWER_SIZE = 1;
|
||||
const size_t MMAP_MEMORY_ALIGN = 8;
|
||||
const size_t MMAP_MEMORY_ALIGN_EXP = 3;
|
||||
|
||||
#ifndef MMANAGER_TEST_MODE
|
||||
const uint64_t MMAP_MAX_UNIT_NUM = 1024;
|
||||
#else
|
||||
const uint64_t MMAP_MAX_UNIT_NUM = 8;
|
||||
#endif
|
||||
|
||||
const uint64_t MMAP_FREE_QUEUE_SIZE = 1024;
|
||||
|
||||
const uint64_t MMAP_FREE_LIST_NUM = 64;
|
||||
|
||||
typedef struct _boot_st{
|
||||
uint32_t version;
|
||||
uint64_t reserve;
|
||||
size_t size;
|
||||
}boot_st;
|
||||
|
||||
typedef struct _head_st{
|
||||
off_t break_p;
|
||||
uint64_t chunk_num;
|
||||
uint64_t reserve;
|
||||
}head_st;
|
||||
|
||||
|
||||
typedef struct _free_list_st{
|
||||
off_t free_p;
|
||||
off_t free_last_p;
|
||||
}free_list_st;
|
||||
|
||||
|
||||
typedef struct _free_st{
|
||||
free_list_st large_list;
|
||||
free_list_st free_lists[MMAP_FREE_LIST_NUM];
|
||||
}free_st;
|
||||
|
||||
|
||||
typedef struct _free_queue_st{
|
||||
off_t data;
|
||||
size_t capacity;
|
||||
uint64_t tail;
|
||||
}free_queue_st;
|
||||
|
||||
|
||||
|
||||
typedef struct _control_st{
|
||||
bool use_expand;
|
||||
uint16_t unit_num;
|
||||
uint16_t active_unit;
|
||||
uint64_t reserve;
|
||||
size_t base_size;
|
||||
off_t entry_p;
|
||||
option_reuse_t reuse_type;
|
||||
free_st free_data;
|
||||
free_queue_st free_queue;
|
||||
head_st data_headers[MMAP_MAX_UNIT_NUM];
|
||||
}control_st;
|
||||
|
||||
typedef struct _chunk_head_st{
|
||||
bool delete_flg;
|
||||
uint16_t unit_id;
|
||||
off_t free_next;
|
||||
size_t size;
|
||||
}chunk_head_st;
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace MemoryManager{
|
||||
class MmapManagerException : public std::domain_error{
|
||||
public:
|
||||
MmapManagerException(const std::string &msg) : std::domain_error(msg){}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,644 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "MmapManagerDefs.h"
|
||||
#include "MmapManagerException.h"
|
||||
|
||||
#include <sys/mman.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <fcntl.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
#include <cassert>
|
||||
|
||||
namespace MemoryManager{
|
||||
|
||||
class MmapManager::Impl{
|
||||
public:
|
||||
Impl() = delete;
|
||||
Impl(MmapManager &ommanager);
|
||||
virtual ~Impl(){}
|
||||
|
||||
MmapManager &mmanager;
|
||||
bool isOpen;
|
||||
void *mmapCntlAddr;
|
||||
control_st *mmapCntlHead;
|
||||
std::string filePath;
|
||||
void *mmapDataAddr[MMAP_MAX_UNIT_NUM];
|
||||
|
||||
void initBootStruct(boot_st &bst, size_t size) const;
|
||||
void initFreeStruct(free_st &fst) const;
|
||||
void initFreeQueue(free_queue_st &fqst) const;
|
||||
void initControlStruct(control_st &cntlst, size_t size) const;
|
||||
|
||||
void setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const;
|
||||
bool expandMemory();
|
||||
int32_t formatFile(const std::string &targetFile, size_t size) const;
|
||||
void clearChunk(const off_t chunk_off) const;
|
||||
|
||||
void free_data_classify(const off_t p, const bool force_large_list = false) const;
|
||||
off_t reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list = false) const;
|
||||
void free_data_queue(const off_t p);
|
||||
off_t reuse_data_queue(const size_t size, reuse_state_t &reuse_state);
|
||||
void free_data_queue_plus(const off_t p);
|
||||
off_t reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state);
|
||||
|
||||
bool scanAllData(void *target, const check_statistics_t stats_type) const;
|
||||
|
||||
void upHeap(free_queue_st *free_queue, uint64_t index) const;
|
||||
void downHeap(free_queue_st *free_queue)const;
|
||||
bool insertHeap(free_queue_st *free_queue, const off_t p) const;
|
||||
bool getHeap(free_queue_st *free_queue, off_t *p) const;
|
||||
size_t getMaxHeapValue(free_queue_st *free_queue) const;
|
||||
void dumpHeap() const;
|
||||
|
||||
void divChunk(const off_t chunk_offset, const size_t size);
|
||||
};
|
||||
|
||||
|
||||
MmapManager::Impl::Impl(MmapManager &ommanager):mmanager(ommanager), isOpen(false), mmapCntlAddr(NULL), mmapCntlHead(NULL){}
|
||||
|
||||
|
||||
void MmapManager::Impl::initBootStruct(boot_st &bst, size_t size) const
|
||||
{
|
||||
bst.version = MMAP_MANAGER_VERSION;
|
||||
bst.reserve = 0;
|
||||
bst.size = size;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::initFreeStruct(free_st &fst) const
|
||||
{
|
||||
fst.large_list.free_p = -1;
|
||||
fst.large_list.free_last_p = -1;
|
||||
for(uint32_t i = 0; i < MMAP_FREE_LIST_NUM; ++i){
|
||||
fst.free_lists[i].free_p = -1;
|
||||
fst.free_lists[i].free_last_p = -1;
|
||||
}
|
||||
}
|
||||
|
||||
void MmapManager::Impl::initFreeQueue(free_queue_st &fqst) const
|
||||
{
|
||||
fqst.data = -1;
|
||||
fqst.capacity = MMAP_FREE_QUEUE_SIZE;
|
||||
fqst.tail = 1;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::initControlStruct(control_st &cntlst, size_t size) const
|
||||
{
|
||||
cntlst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND;
|
||||
cntlst.unit_num = 1;
|
||||
cntlst.active_unit = 0;
|
||||
cntlst.reserve = 0;
|
||||
cntlst.base_size = size;
|
||||
cntlst.entry_p = 0;
|
||||
cntlst.reuse_type = REUSE_DATA_CLASSIFY;
|
||||
initFreeStruct(cntlst.free_data);
|
||||
initFreeQueue(cntlst.free_queue);
|
||||
memset(cntlst.data_headers, 0, sizeof(head_st) * MMAP_MAX_UNIT_NUM);
|
||||
}
|
||||
|
||||
void MmapManager::Impl::setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const
|
||||
{
|
||||
chunk_head_st chunk_buffer;
|
||||
chunk_buffer.delete_flg = delete_flg;
|
||||
chunk_buffer.unit_id = unit_id;
|
||||
chunk_buffer.free_next = free_next;
|
||||
chunk_buffer.size = size;
|
||||
|
||||
memcpy(chunk_head, &chunk_buffer, sizeof(chunk_head_st));
|
||||
}
|
||||
|
||||
bool MmapManager::Impl::expandMemory()
|
||||
{
|
||||
const uint16_t new_unit_num = mmapCntlHead->unit_num + 1;
|
||||
const size_t new_file_size = mmapCntlHead->base_size * new_unit_num;
|
||||
const off_t old_file_size = mmapCntlHead->base_size * mmapCntlHead->unit_num;
|
||||
|
||||
if(new_unit_num >= MMAP_MAX_UNIT_NUM){
|
||||
std::cerr << "over max unit num" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t fd = formatFile(filePath, new_file_size);
|
||||
assert(fd >= 0);
|
||||
|
||||
const off_t offset = mmapCntlHead->base_size * mmapCntlHead->unit_num;
|
||||
errno = 0;
|
||||
void *new_area = mmap(NULL, mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset);
|
||||
if(new_area == MAP_FAILED){
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
|
||||
errno = 0;
|
||||
if(ftruncate(fd, old_file_size) == -1){
|
||||
const std::string err_str = getErrorStr(errno);
|
||||
throw MmapManagerException("truncate error" + err_str);
|
||||
}
|
||||
|
||||
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException("mmap error" + err_str);
|
||||
}
|
||||
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
|
||||
|
||||
mmapDataAddr[mmapCntlHead->unit_num] = new_area;
|
||||
|
||||
mmapCntlHead->unit_num = new_unit_num;
|
||||
mmapCntlHead->active_unit++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t MmapManager::Impl::formatFile(const std::string &targetFile, size_t size) const
|
||||
{
|
||||
const char *c = "";
|
||||
int32_t fd;
|
||||
|
||||
errno = 0;
|
||||
if((fd = open(targetFile.c_str(), O_RDWR|O_CREAT, 0666)) == -1){
|
||||
std::stringstream ss;
|
||||
ss << "[ERR] Cannot open the file. " << targetFile << " " << getErrorStr(errno);
|
||||
throw MmapManagerException(ss.str());
|
||||
}
|
||||
errno = 0;
|
||||
if(lseek(fd, (off_t)size-1, SEEK_SET) < 0){
|
||||
std::stringstream ss;
|
||||
ss << "[ERR] Cannot seek the file. " << targetFile << " " << getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException(ss.str());
|
||||
}
|
||||
errno = 0;
|
||||
if(write(fd, &c, sizeof(char)) == -1){
|
||||
std::stringstream ss;
|
||||
ss << "[ERR] Cannot write the file. Check the disk space. " << targetFile << " " << getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException(ss.str());
|
||||
}
|
||||
|
||||
return fd;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::clearChunk(const off_t chunk_off) const
|
||||
{
|
||||
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_off);
|
||||
const off_t payload_off = chunk_off + sizeof(chunk_head_st);
|
||||
|
||||
chunk_head->delete_flg = false;
|
||||
chunk_head->free_next = -1;
|
||||
char *payload_addr = (char *)mmanager.getAbsAddr(payload_off);
|
||||
memset(payload_addr, 0, chunk_head->size);
|
||||
}
|
||||
|
||||
void MmapManager::Impl::free_data_classify(const off_t p, const bool force_large_list) const
|
||||
{
|
||||
const off_t chunk_offset = p - sizeof(chunk_head_st);
|
||||
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
|
||||
const size_t p_size = chunk_head->size;
|
||||
|
||||
|
||||
|
||||
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
|
||||
|
||||
free_list_st *free_list;
|
||||
if(p_size <= border_size && force_large_list == false){
|
||||
uint32_t index = (p_size / MMAP_MEMORY_ALIGN) - 1;
|
||||
free_list = &mmapCntlHead->free_data.free_lists[index];
|
||||
}else{
|
||||
free_list = &mmapCntlHead->free_data.large_list;
|
||||
}
|
||||
|
||||
if(free_list->free_p == -1){
|
||||
free_list->free_p = free_list->free_last_p = chunk_offset;
|
||||
}else{
|
||||
off_t last_off = free_list->free_last_p;
|
||||
chunk_head_st *tmp_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(last_off);
|
||||
free_list->free_last_p = tmp_chunk_head->free_next = chunk_offset;
|
||||
}
|
||||
chunk_head->delete_flg = true;
|
||||
}
|
||||
|
||||
off_t MmapManager::Impl::reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list) const
|
||||
{
|
||||
|
||||
|
||||
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
|
||||
|
||||
free_list_st *free_list;
|
||||
if(size <= border_size && force_large_list == false){
|
||||
uint32_t index = (size / MMAP_MEMORY_ALIGN) - 1;
|
||||
free_list = &mmapCntlHead->free_data.free_lists[index];
|
||||
}else{
|
||||
free_list = &mmapCntlHead->free_data.large_list;
|
||||
}
|
||||
|
||||
if(free_list->free_p == -1){
|
||||
reuse_state = REUSE_STATE_ALLOC;
|
||||
return -1;
|
||||
}
|
||||
|
||||
off_t current_off = free_list->free_p;
|
||||
off_t ret_off = 0;
|
||||
chunk_head_st *current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off);
|
||||
chunk_head_st *ret_chunk_head = NULL;
|
||||
|
||||
if( (size <= border_size) && (free_list->free_last_p == free_list->free_p) ){
|
||||
ret_off = current_off;
|
||||
ret_chunk_head = current_chunk_head;
|
||||
free_list->free_p = free_list->free_last_p = -1;
|
||||
}else{
|
||||
off_t ret_before_off = -1, before_off = -1;
|
||||
bool found_candidate_flag = false;
|
||||
|
||||
|
||||
while(current_chunk_head != NULL){
|
||||
if( current_chunk_head->size >= size ) found_candidate_flag = true;
|
||||
|
||||
if(found_candidate_flag){
|
||||
ret_off = current_off;
|
||||
ret_chunk_head = current_chunk_head;
|
||||
ret_before_off = before_off;
|
||||
break;
|
||||
}
|
||||
before_off = current_off;
|
||||
current_off = current_chunk_head->free_next;
|
||||
current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off);
|
||||
}
|
||||
|
||||
if(!found_candidate_flag){
|
||||
reuse_state = REUSE_STATE_ALLOC;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const off_t free_next = ret_chunk_head->free_next;
|
||||
if(free_list->free_p == ret_off){
|
||||
free_list->free_p = free_next;
|
||||
}else{
|
||||
chunk_head_st *before_chunk = (chunk_head_st *)mmanager.getAbsAddr(ret_before_off);
|
||||
before_chunk->free_next = free_next;
|
||||
}
|
||||
|
||||
if(free_list->free_last_p == ret_off){
|
||||
free_list->free_last_p = ret_before_off;
|
||||
}
|
||||
}
|
||||
|
||||
clearChunk(ret_off);
|
||||
|
||||
ret_off = ret_off + sizeof(chunk_head_st);
|
||||
return ret_off;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::free_data_queue(const off_t p)
|
||||
{
|
||||
free_queue_st *free_queue = &mmapCntlHead->free_queue;
|
||||
if(free_queue->data == -1){
|
||||
|
||||
const size_t queue_size = sizeof(off_t) * free_queue->capacity;
|
||||
const off_t alloc_offset = mmanager.alloc(queue_size);
|
||||
if(alloc_offset == -1){
|
||||
|
||||
return free_data_classify(p, true);
|
||||
}
|
||||
free_queue->data = alloc_offset;
|
||||
}else if(free_queue->tail >= free_queue->capacity){
|
||||
|
||||
const off_t tmp_old_queue = free_queue->data;
|
||||
const size_t old_size = sizeof(off_t) * free_queue->capacity;
|
||||
const size_t new_capacity = free_queue->capacity * 2;
|
||||
const size_t new_size = sizeof(off_t) * new_capacity;
|
||||
|
||||
if(new_size > mmapCntlHead->base_size){
|
||||
|
||||
|
||||
return free_data_classify(p, true);
|
||||
}else{
|
||||
const off_t alloc_offset = mmanager.alloc(new_size);
|
||||
if(alloc_offset == -1){
|
||||
|
||||
return free_data_classify(p, true);
|
||||
}
|
||||
free_queue->data = alloc_offset;
|
||||
const off_t *old_data = (off_t *)mmanager.getAbsAddr(tmp_old_queue);
|
||||
off_t *new_data = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
memcpy(new_data, old_data, old_size);
|
||||
|
||||
free_queue->capacity = new_capacity;
|
||||
mmanager.free(tmp_old_queue);
|
||||
}
|
||||
}
|
||||
|
||||
const off_t chunk_offset = p - sizeof(chunk_head_st);
|
||||
if(!insertHeap(free_queue, chunk_offset)){
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
chunk_head_st *chunk_head = (chunk_head_st*)mmanager.getAbsAddr(chunk_offset);
|
||||
chunk_head->delete_flg = 1;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
off_t MmapManager::Impl::reuse_data_queue(const size_t size, reuse_state_t &reuse_state)
|
||||
{
|
||||
free_queue_st *free_queue = &mmapCntlHead->free_queue;
|
||||
if(free_queue->data == -1){
|
||||
|
||||
reuse_state = REUSE_STATE_ALLOC;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if(getMaxHeapValue(free_queue) < size){
|
||||
reuse_state = REUSE_STATE_ALLOC;
|
||||
return -1;
|
||||
}
|
||||
|
||||
off_t ret_off;
|
||||
if(!getHeap(free_queue, &ret_off)){
|
||||
|
||||
reuse_state = REUSE_STATE_ALLOC;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
reuse_state_t list_state = REUSE_STATE_OK;
|
||||
|
||||
off_t candidate_off = reuse_data_classify(MMAP_MEMORY_ALIGN, list_state, true);
|
||||
if(list_state == REUSE_STATE_OK){
|
||||
|
||||
mmanager.free(candidate_off);
|
||||
}
|
||||
|
||||
const off_t c_ret_off = ret_off;
|
||||
divChunk(c_ret_off, size);
|
||||
|
||||
clearChunk(ret_off);
|
||||
|
||||
ret_off = ret_off + sizeof(chunk_head_st);
|
||||
|
||||
return ret_off;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::free_data_queue_plus(const off_t p)
|
||||
{
|
||||
const off_t chunk_offset = p - sizeof(chunk_head_st);
|
||||
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
|
||||
const size_t p_size = chunk_head->size;
|
||||
|
||||
|
||||
|
||||
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
|
||||
|
||||
if(p_size <= border_size){
|
||||
free_data_classify(p);
|
||||
}else{
|
||||
free_data_queue(p);
|
||||
}
|
||||
}
|
||||
|
||||
off_t MmapManager::Impl::reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state)
|
||||
{
|
||||
|
||||
|
||||
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
|
||||
|
||||
off_t ret_off;
|
||||
if(size <= border_size){
|
||||
ret_off = reuse_data_classify(size, reuse_state);
|
||||
if(reuse_state == REUSE_STATE_ALLOC){
|
||||
|
||||
reuse_state = REUSE_STATE_OK;
|
||||
ret_off = reuse_data_queue(size, reuse_state);
|
||||
}
|
||||
}else{
|
||||
ret_off = reuse_data_queue(size, reuse_state);
|
||||
}
|
||||
|
||||
return ret_off;
|
||||
}
|
||||
|
||||
|
||||
bool MmapManager::Impl::scanAllData(void *target, const check_statistics_t stats_type) const
|
||||
{
|
||||
const uint16_t unit_num = mmapCntlHead->unit_num;
|
||||
size_t total_size = 0;
|
||||
uint64_t total_chunk_num = 0;
|
||||
|
||||
for(int i = 0; i < unit_num; i++){
|
||||
const head_st *target_unit_head = &mmapCntlHead->data_headers[i];
|
||||
const uint64_t chunk_num = target_unit_head->chunk_num;
|
||||
const off_t base_offset = i * mmapCntlHead->base_size;
|
||||
off_t target_offset = base_offset;
|
||||
chunk_head_st *target_chunk;
|
||||
|
||||
for(uint64_t j = 0; j < chunk_num; j++){
|
||||
target_chunk = (chunk_head_st*)mmanager.getAbsAddr(target_offset);
|
||||
|
||||
if(stats_type == CHECK_STATS_USE_SIZE){
|
||||
if(target_chunk->delete_flg == false){
|
||||
total_size += target_chunk->size;
|
||||
}
|
||||
}else if(stats_type == CHECK_STATS_USE_NUM){
|
||||
if(target_chunk->delete_flg == false){
|
||||
total_chunk_num++;
|
||||
}
|
||||
}else if(stats_type == CHECK_STATS_FREE_SIZE){
|
||||
if(target_chunk->delete_flg == true){
|
||||
total_size += target_chunk->size;
|
||||
}
|
||||
}else if(stats_type == CHECK_STATS_FREE_NUM){
|
||||
if(target_chunk->delete_flg == true){
|
||||
total_chunk_num++;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t chunk_size = sizeof(chunk_head_st) + target_chunk->size;
|
||||
target_offset += chunk_size;
|
||||
}
|
||||
}
|
||||
|
||||
if(stats_type == CHECK_STATS_USE_SIZE || stats_type == CHECK_STATS_FREE_SIZE){
|
||||
size_t *tmp_size = (size_t *)target;
|
||||
*tmp_size = total_size;
|
||||
}else if(stats_type == CHECK_STATS_USE_NUM || stats_type == CHECK_STATS_FREE_NUM){
|
||||
uint64_t *tmp_chunk_num = (uint64_t *)target;
|
||||
*tmp_chunk_num = total_chunk_num;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::upHeap(free_queue_st *free_queue, uint64_t index) const
|
||||
{
|
||||
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
|
||||
while(index > 1){
|
||||
uint64_t parent = index / 2;
|
||||
|
||||
const off_t parent_chunk_offset = queue[parent];
|
||||
const off_t index_chunk_offset = queue[index];
|
||||
const chunk_head_st *parent_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(parent_chunk_offset);
|
||||
const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset);
|
||||
|
||||
if(parent_chunk_head->size < index_chunk_head->size){
|
||||
|
||||
const off_t tmp = queue[parent];
|
||||
queue[parent] = queue[index];
|
||||
queue[index] = tmp;
|
||||
}
|
||||
index = parent;
|
||||
}
|
||||
}
|
||||
|
||||
void MmapManager::Impl::downHeap(free_queue_st *free_queue)const
|
||||
{
|
||||
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
uint64_t index = 1;
|
||||
|
||||
while(index * 2 <= free_queue->tail){
|
||||
uint64_t child = index * 2;
|
||||
|
||||
const off_t index_chunk_offset = queue[index];
|
||||
const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset);
|
||||
|
||||
if(child + 1 < free_queue->tail){
|
||||
const off_t left_chunk_offset = queue[child];
|
||||
const off_t right_chunk_offset = queue[child+1];
|
||||
const chunk_head_st *left_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(left_chunk_offset);
|
||||
const chunk_head_st *right_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(right_chunk_offset);
|
||||
|
||||
|
||||
if(left_chunk_head->size < right_chunk_head->size){
|
||||
child = child + 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const off_t child_chunk_offset = queue[child];
|
||||
const chunk_head_st *child_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(child_chunk_offset);
|
||||
|
||||
if(child_chunk_head->size > index_chunk_head->size){
|
||||
|
||||
const off_t tmp = queue[child];
|
||||
queue[child] = queue[index];
|
||||
queue[index] = tmp;
|
||||
index = child;
|
||||
}else{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool MmapManager::Impl::insertHeap(free_queue_st *free_queue, const off_t p) const
|
||||
{
|
||||
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
uint64_t index;
|
||||
if(free_queue->capacity < free_queue->tail){
|
||||
return false;
|
||||
}
|
||||
|
||||
index = free_queue->tail;
|
||||
queue[index] = p;
|
||||
free_queue->tail += 1;
|
||||
|
||||
upHeap(free_queue, index);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MmapManager::Impl::getHeap(free_queue_st *free_queue, off_t *p) const
|
||||
{
|
||||
|
||||
if( (free_queue->tail - 1) <= 0){
|
||||
return false;
|
||||
}
|
||||
|
||||
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
*p = queue[1];
|
||||
free_queue->tail -= 1;
|
||||
queue[1] = queue[free_queue->tail];
|
||||
downHeap(free_queue);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t MmapManager::Impl::getMaxHeapValue(free_queue_st *free_queue) const
|
||||
{
|
||||
if(free_queue->data == -1){
|
||||
return 0;
|
||||
}
|
||||
const off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(queue[1]);
|
||||
|
||||
return chunk_head->size;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::dumpHeap() const
|
||||
{
|
||||
free_queue_st *free_queue = &mmapCntlHead->free_queue;
|
||||
if(free_queue->data == -1){
|
||||
std::cout << "heap unused" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
|
||||
for(uint32_t i = 1; i < free_queue->tail; ++i){
|
||||
const off_t chunk_offset = queue[i];
|
||||
const off_t payload_offset = chunk_offset + sizeof(chunk_head_st);
|
||||
const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
|
||||
const size_t size = chunk_head->size;
|
||||
std::cout << "[" << chunk_offset << "(" << payload_offset << "), " << size << "] ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void MmapManager::Impl::divChunk(const off_t chunk_offset, const size_t size)
|
||||
{
|
||||
if((mmapCntlHead->reuse_type != REUSE_DATA_QUEUE)
|
||||
&& (mmapCntlHead->reuse_type != REUSE_DATA_QUEUE_PLUS)){
|
||||
return;
|
||||
}
|
||||
|
||||
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
|
||||
const size_t border_size = sizeof(chunk_head_st) + MMAP_MEMORY_ALIGN;
|
||||
const size_t align_size = getAlignSize(size);
|
||||
const size_t rest_size = chunk_head->size - align_size;
|
||||
|
||||
if(rest_size < border_size){
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
chunk_head->size = align_size;
|
||||
|
||||
const off_t new_chunk_offset = chunk_offset + sizeof(chunk_head_st) + align_size;
|
||||
chunk_head_st *new_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(new_chunk_offset);
|
||||
const size_t new_size = rest_size - sizeof(chunk_head_st);
|
||||
setupChunkHead(new_chunk_head, true, chunk_head->unit_id, -1, new_size);
|
||||
|
||||
|
||||
head_st *unit_header = &mmapCntlHead->data_headers[mmapCntlHead->active_unit];
|
||||
unit_header->chunk_num++;
|
||||
|
||||
|
||||
const off_t payload_offset = new_chunk_offset + sizeof(chunk_head_st);
|
||||
mmanager.free(payload_offset);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,602 @@
|
|||
//
|
||||
// Copyright (C) 2016-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "NGT/NGTQ/Quantizer.h"
|
||||
|
||||
#define NGTQ_SEARCH_CODEBOOK_SIZE_FLUCTUATION
|
||||
|
||||
namespace NGTQ {
|
||||
|
||||
class Command {
|
||||
public:
|
||||
Command():debugLevel(0) {}
|
||||
|
||||
void
|
||||
create(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq create "
|
||||
"[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] "
|
||||
"[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] "
|
||||
"[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] "
|
||||
"[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] "
|
||||
"[-M global-centroid-creation-mode (d|s)] [-L global-centroid-creation-mode (d|k|s)] "
|
||||
"[-S local-sample-coefficient] "
|
||||
"index(output) data.tsv(input)";
|
||||
string database;
|
||||
try {
|
||||
database = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified." << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
string data;
|
||||
try {
|
||||
data = args.get("#2");
|
||||
} catch (...) {
|
||||
cerr << "Data is not specified." << endl;
|
||||
}
|
||||
|
||||
char objectType = args.getChar("o", 'f');
|
||||
char distanceType = args.getChar("D", '2');
|
||||
size_t dataSize = args.getl("n", 0);
|
||||
|
||||
NGTQ::Property property;
|
||||
property.threadSize = args.getl("p", 24);
|
||||
property.dimension = args.getl("d", 0);
|
||||
property.globalRange = args.getf("R", 0);
|
||||
property.localRange = args.getf("r", 0);
|
||||
property.globalCentroidLimit = args.getl("C", 1000000);
|
||||
property.localCentroidLimit = args.getl("c", 65000);
|
||||
property.localDivisionNo = args.getl("N", 8);
|
||||
property.batchSize = args.getl("b", 1000);
|
||||
property.localClusteringSampleCoefficient = args.getl("S", 10);
|
||||
{
|
||||
char localCentroidType = args.getChar("T", 'f');
|
||||
property.singleLocalCodebook = localCentroidType == 't' ? true : false;
|
||||
}
|
||||
{
|
||||
char centroidCreationMode = args.getChar("M", 'd');
|
||||
switch(centroidCreationMode) {
|
||||
case 'd': property.centroidCreationMode = NGTQ::CentroidCreationModeDynamic; break;
|
||||
case 's': property.centroidCreationMode = NGTQ::CentroidCreationModeStatic; break;
|
||||
default:
|
||||
cerr << "ngt: Invalid centroid creation mode. " << centroidCreationMode << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
{
|
||||
char localCentroidCreationMode = args.getChar("L", 'd');
|
||||
switch(localCentroidCreationMode) {
|
||||
case 'd': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamic; break;
|
||||
case 's': property.localCentroidCreationMode = NGTQ::CentroidCreationModeStatic; break;
|
||||
case 'k': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamicKmeans; break;
|
||||
default:
|
||||
cerr << "ngt: Invalid centroid creation mode. " << localCentroidCreationMode << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
NGT::Property globalProperty;
|
||||
NGT::Property localProperty;
|
||||
|
||||
{
|
||||
char indexType = args.getChar("i", 't');
|
||||
globalProperty.indexType = indexType == 't' ? NGT::Property::GraphAndTree : NGT::Property::Graph;
|
||||
localProperty.indexType = globalProperty.indexType;
|
||||
}
|
||||
globalProperty.insertionRadiusCoefficient = args.getf("e", 0.1) + 1.0;
|
||||
localProperty.insertionRadiusCoefficient = globalProperty.insertionRadiusCoefficient;
|
||||
|
||||
if (debugLevel >= 1) {
|
||||
cerr << "epsilon=" << globalProperty.insertionRadiusCoefficient << endl;
|
||||
cerr << "data size=" << dataSize << endl;
|
||||
cerr << "dimension=" << property.dimension << endl;
|
||||
cerr << "thread size=" << property.threadSize << endl;
|
||||
cerr << "batch size=" << localProperty.batchSizeForCreation << endl;;
|
||||
cerr << "index type=" << globalProperty.indexType << endl;
|
||||
}
|
||||
|
||||
|
||||
switch (objectType) {
|
||||
case 'f': property.dataType = NGTQ::DataTypeFloat; break;
|
||||
case 'c': property.dataType = NGTQ::DataTypeUint8; break;
|
||||
default:
|
||||
cerr << "ngt: Invalid object type. " << objectType << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
switch (distanceType) {
|
||||
case '2': property.distanceType = NGTQ::DistanceTypeL2; break;
|
||||
case '1': property.distanceType = NGTQ::DistanceTypeL1; break;
|
||||
case 'a': property.distanceType = NGTQ::DistanceTypeAngle; break;
|
||||
default:
|
||||
cerr << "ngt: Invalid distance type. " << distanceType << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
cerr << "ngtq: Create" << endl;
|
||||
NGTQ::Index::create(database, property, globalProperty, localProperty);
|
||||
|
||||
cerr << "ngtq: Append" << endl;
|
||||
NGTQ::Index::append(database, data, dataSize);
|
||||
}
|
||||
|
||||
void
|
||||
rebuild(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq rebuild "
|
||||
"[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] "
|
||||
"[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] "
|
||||
"[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] "
|
||||
"[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] "
|
||||
"[-M centroid-creation_mode (d|s)] "
|
||||
"index(output) data.tsv(input)";
|
||||
string srcIndex;
|
||||
try {
|
||||
srcIndex = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified." << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
string rebuiltIndex = srcIndex + ".tmp";
|
||||
|
||||
|
||||
NGTQ::Property property;
|
||||
NGT::Property globalProperty;
|
||||
NGT::Property localProperty;
|
||||
|
||||
{
|
||||
NGTQ::Index index(srcIndex);
|
||||
property = index.getQuantizer().property;
|
||||
index.getQuantizer().globalCodebook.getProperty(globalProperty);
|
||||
index.getQuantizer().getLocalCodebook(0).getProperty(localProperty);
|
||||
}
|
||||
|
||||
property.globalRange = args.getf("R", property.globalRange);
|
||||
property.localRange = args.getf("r", property.localRange);
|
||||
property.globalCentroidLimit = args.getl("C", property.globalCentroidLimit);
|
||||
property.localCentroidLimit = args.getl("c", property.localCentroidLimit);
|
||||
property.localDivisionNo = args.getl("N", property.localDivisionNo);
|
||||
{
|
||||
char localCentroidType = args.getChar("T", '-');
|
||||
if (localCentroidType != '-') {
|
||||
property.singleLocalCodebook = localCentroidType == 't' ? true : false;
|
||||
}
|
||||
}
|
||||
{
|
||||
char centroidCreationMode = args.getChar("M", '-');
|
||||
if (centroidCreationMode != '-') {
|
||||
property.centroidCreationMode = centroidCreationMode == 'd' ?
|
||||
NGTQ::CentroidCreationModeDynamic : NGTQ::CentroidCreationModeStatic;
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "global range=" << property.globalRange << endl;
|
||||
cerr << "local range=" << property.localRange << endl;
|
||||
cerr << "global centroid limit=" << property.globalCentroidLimit << endl;
|
||||
cerr << "local centroid limit=" << property.localCentroidLimit << endl;
|
||||
cerr << "local division no=" << property.localDivisionNo << endl;
|
||||
|
||||
NGTQ::Index::create(rebuiltIndex, property, globalProperty, localProperty);
|
||||
cerr << "created a new db" << endl;
|
||||
cerr << "start rebuilding..." << endl;
|
||||
NGTQ::Index::rebuild(srcIndex, rebuiltIndex);
|
||||
{
|
||||
string src = srcIndex;
|
||||
string dst = srcIndex + ".org";
|
||||
if (std::rename(src.c_str(), dst.c_str()) != 0) {
|
||||
stringstream msg;
|
||||
msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
{
|
||||
string src = rebuiltIndex;
|
||||
string dst = srcIndex;
|
||||
if (std::rename(src.c_str(), dst.c_str()) != 0) {
|
||||
stringstream msg;
|
||||
msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
append(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq append [-n data-size] "
|
||||
"index(output) data.tsv(input)";
|
||||
string index;
|
||||
try {
|
||||
index = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified." << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
string data;
|
||||
try {
|
||||
data = args.get("#2");
|
||||
} catch (...) {
|
||||
cerr << "Data is not specified." << endl;
|
||||
}
|
||||
|
||||
size_t dataSize = args.getl("n", 0);
|
||||
|
||||
if (debugLevel >= 1) {
|
||||
cerr << "data size=" << dataSize << endl;
|
||||
}
|
||||
|
||||
NGTQ::Index::append(index, data, dataSize);
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
search(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq search [-i g|t|s] [-n result-size] [-e epsilon] [-m mode(r|l|c|a)] "
|
||||
"[-E edge-size] [-o output-mode] [-b result expansion(begin:end:[x]step)] "
|
||||
"index(input) query.tsv(input)";
|
||||
string database;
|
||||
try {
|
||||
database = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
string query;
|
||||
try {
|
||||
query = args.get("#2");
|
||||
} catch (...) {
|
||||
cerr << "Query is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
int size = args.getl("n", 20);
|
||||
char outputMode = args.getChar("o", '-');
|
||||
float epsilon = 0.1;
|
||||
|
||||
char mode = args.getChar("m", '-');
|
||||
NGTQ::AggregationMode aggregationMode;
|
||||
switch (mode) {
|
||||
case 'r': aggregationMode = NGTQ::AggregationModeExactDistanceThroughApproximateDistance; break; // refine
|
||||
case 'e': aggregationMode = NGTQ::AggregationModeExactDistance; break; // refine
|
||||
case 'l': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithLookupTable; break; // lookup
|
||||
case 'c': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithCache; break; // cache
|
||||
case '-':
|
||||
case 'a': aggregationMode = NGTQ::AggregationModeApproximateDistance; break; // cache
|
||||
default:
|
||||
cerr << "Invalid aggregation mode. " << mode << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (args.getString("e", "none") == "-") {
|
||||
// linear search
|
||||
epsilon = FLT_MAX;
|
||||
} else {
|
||||
epsilon = args.getf("e", 0.1);
|
||||
}
|
||||
|
||||
size_t beginOfResultExpansion, endOfResultExpansion, stepOfResultExpansion;
|
||||
bool mulStep = false;
|
||||
{
|
||||
beginOfResultExpansion = stepOfResultExpansion = 1;
|
||||
endOfResultExpansion = 0;
|
||||
string str = args.getString("b", "16");
|
||||
vector<string> tokens;
|
||||
NGT::Common::tokenize(str, tokens, ":");
|
||||
if (tokens.size() >= 1) { beginOfResultExpansion = NGT::Common::strtod(tokens[0]); }
|
||||
if (tokens.size() >= 2) { endOfResultExpansion = NGT::Common::strtod(tokens[1]); }
|
||||
if (tokens.size() >= 3) {
|
||||
if (tokens[2][0] == 'x') {
|
||||
mulStep = true;
|
||||
stepOfResultExpansion = NGT::Common::strtod(tokens[2].substr(1));
|
||||
} else {
|
||||
stepOfResultExpansion = NGT::Common::strtod(tokens[2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (debugLevel >= 1) {
|
||||
cerr << "size=" << size << endl;
|
||||
cerr << "result expansion=" << beginOfResultExpansion << "->" << endOfResultExpansion << "," << stepOfResultExpansion << endl;
|
||||
}
|
||||
|
||||
NGTQ::Index index(database);
|
||||
try {
|
||||
ifstream is(query);
|
||||
if (!is) {
|
||||
cerr << "Cannot open the specified file. " << query << endl;
|
||||
return;
|
||||
}
|
||||
if (outputMode == 's') { cout << "# Beginning of Evaluation" << endl; }
|
||||
string line;
|
||||
double totalTime = 0;
|
||||
int queryCount = 0;
|
||||
while(getline(is, line)) {
|
||||
NGT::Object *query = index.allocateObject(line, " \t", 0);
|
||||
queryCount++;
|
||||
size_t resultExpansion = 0;
|
||||
for (size_t base = beginOfResultExpansion;
|
||||
resultExpansion <= endOfResultExpansion;
|
||||
base = mulStep ? base * stepOfResultExpansion : base + stepOfResultExpansion) {
|
||||
resultExpansion = base;
|
||||
NGT::ObjectDistances objects;
|
||||
|
||||
if (outputMode == 'e') {
|
||||
index.search(query, objects, size, resultExpansion, aggregationMode, epsilon);
|
||||
objects.clear();
|
||||
}
|
||||
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
// size : # of final resultant objects
|
||||
// resultExpansion : # of resultant objects by using codebook search
|
||||
index.search(query, objects, size, resultExpansion, aggregationMode, epsilon);
|
||||
timer.stop();
|
||||
|
||||
totalTime += timer.time;
|
||||
if (outputMode == 'e') {
|
||||
cout << "# Query No.=" << queryCount << endl;
|
||||
cout << "# Query=" << line.substr(0, 20) + " ..." << endl;
|
||||
cout << "# Index Type=" << "----" << endl;
|
||||
cout << "# Size=" << size << endl;
|
||||
cout << "# Epsilon=" << epsilon << endl;
|
||||
cout << "# Result expansion=" << resultExpansion << endl;
|
||||
cout << "# Distance Computation=" << index.getQuantizer().distanceComputationCount << endl;
|
||||
cout << "# Query Time (msec)=" << timer.time * 1000.0 << endl;
|
||||
} else {
|
||||
cout << "Query No." << queryCount << endl;
|
||||
cout << "Rank\tIN-ID\tID\tDistance" << endl;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < objects.size(); i++) {
|
||||
cout << i + 1 << "\t" << objects[i].id << "\t";
|
||||
cout << objects[i].distance << endl;
|
||||
}
|
||||
|
||||
if (outputMode == 'e') {
|
||||
cout << "# End of Search" << endl;
|
||||
} else {
|
||||
cout << "Query Time= " << timer.time << " (sec), " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
}
|
||||
}
|
||||
if (outputMode == 'e') {
|
||||
cout << "# End of Query" << endl;
|
||||
}
|
||||
index.deleteObject(query);
|
||||
}
|
||||
if (outputMode == 'e') {
|
||||
cout << "# Average Query Time (msec)=" << totalTime * 1000.0 / (double)queryCount << endl;
|
||||
cout << "# Number of queries=" << queryCount << endl;
|
||||
cout << "# End of Evaluation" << endl;
|
||||
} else {
|
||||
cout << "Average Query Time= " << totalTime / (double)queryCount << " (sec), "
|
||||
<< totalTime * 1000.0 / (double)queryCount << " (msec), ("
|
||||
<< totalTime << "/" << queryCount << ")" << endl;
|
||||
}
|
||||
} catch (NGT::Exception &err) {
|
||||
cerr << "Error " << err.what() << endl;
|
||||
cerr << usage << endl;
|
||||
} catch (...) {
|
||||
cerr << "Error" << endl;
|
||||
cerr << usage << endl;
|
||||
}
|
||||
index.close();
|
||||
}
|
||||
|
||||
void
|
||||
remove(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq remove [-d object-ID-type(f|d)] index(input) object-ID(input)";
|
||||
string database;
|
||||
try {
|
||||
database = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
try {
|
||||
args.get("#2");
|
||||
} catch (...) {
|
||||
cerr << "ID is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
char dataType = args.getChar("d", 'f');
|
||||
if (debugLevel >= 1) {
|
||||
cerr << "dataType=" << dataType << endl;
|
||||
}
|
||||
|
||||
try {
|
||||
vector<NGT::ObjectID> objects;
|
||||
if (dataType == 'f') {
|
||||
string ids;
|
||||
try {
|
||||
ids = args.get("#2");
|
||||
} catch (...) {
|
||||
cerr << "Data file is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
ifstream is(ids);
|
||||
if (!is) {
|
||||
cerr << "Cannot open the specified file. " << ids << endl;
|
||||
return;
|
||||
}
|
||||
string line;
|
||||
int count = 0;
|
||||
while(getline(is, line)) {
|
||||
count++;
|
||||
vector<string> tokens;
|
||||
NGT::Common::tokenize(line, tokens, "\t ");
|
||||
if (tokens.size() == 0 || tokens[0].size() == 0) {
|
||||
continue;
|
||||
}
|
||||
char *e;
|
||||
size_t id;
|
||||
try {
|
||||
id = strtol(tokens[0].c_str(), &e, 10);
|
||||
objects.push_back(id);
|
||||
} catch (...) {
|
||||
cerr << "Illegal data. " << tokens[0] << endl;
|
||||
}
|
||||
if (*e != 0) {
|
||||
cerr << "Illegal data. " << e << endl;
|
||||
}
|
||||
cerr << "removed ID=" << id << endl;
|
||||
}
|
||||
} else {
|
||||
size_t id = args.getl("#2", 0);
|
||||
cerr << "removed ID=" << id << endl;
|
||||
objects.push_back(id);
|
||||
}
|
||||
NGT::Index::remove(database, objects);
|
||||
} catch (NGT::Exception &err) {
|
||||
cerr << "Error " << err.what() << endl;
|
||||
cerr << usage << endl;
|
||||
} catch (...) {
|
||||
cerr << "Error" << endl;
|
||||
cerr << usage << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
info(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq info index";
|
||||
string database;
|
||||
try {
|
||||
database = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
NGTQ::Index index(database);
|
||||
index.info(cout);
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
validate(NGT::Args &args)
|
||||
{
|
||||
const string usage = "parameter";
|
||||
string database;
|
||||
try {
|
||||
database = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
NGTQ::Index index(database);
|
||||
|
||||
index.getQuantizer().validate();
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
#ifdef NGTQ_SHARED_INVERTED_INDEX
|
||||
void
|
||||
compress(NGT::Args &args)
|
||||
{
|
||||
const string usage = "Usage: ngtq compress index)";
|
||||
string database;
|
||||
try {
|
||||
database = args.get("#1");
|
||||
} catch (...) {
|
||||
cerr << "DB is not specified" << endl;
|
||||
cerr << usage << endl;
|
||||
return;
|
||||
}
|
||||
try {
|
||||
NGTQ::Index::compress(database);
|
||||
} catch (NGT::Exception &err) {
|
||||
cerr << "Error " << err.what() << endl;
|
||||
cerr << usage << endl;
|
||||
} catch (...) {
|
||||
cerr << "Error" << endl;
|
||||
cerr << usage << endl;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void help() {
|
||||
cerr << "Usage : ngtq command database data" << endl;
|
||||
cerr << " command : create search remove append export import" << endl;
|
||||
}
|
||||
|
||||
void execute(NGT::Args args) {
|
||||
string command;
|
||||
try {
|
||||
command = args.get("#0");
|
||||
} catch(...) {
|
||||
help();
|
||||
return;
|
||||
}
|
||||
|
||||
debugLevel = args.getl("X", 0);
|
||||
|
||||
try {
|
||||
if (debugLevel >= 1) {
|
||||
cerr << "ngt::command=" << command << endl;
|
||||
}
|
||||
if (command == "search") {
|
||||
search(args);
|
||||
} else if (command == "create") {
|
||||
create(args);
|
||||
} else if (command == "append") {
|
||||
append(args);
|
||||
} else if (command == "remove") {
|
||||
remove(args);
|
||||
} else if (command == "info") {
|
||||
info(args);
|
||||
} else if (command == "validate") {
|
||||
validate(args);
|
||||
} else if (command == "rebuild") {
|
||||
rebuild(args);
|
||||
#ifdef NGTQ_SHARED_INVERTED_INDEX
|
||||
} else if (command == "compress") {
|
||||
compress(args);
|
||||
#endif
|
||||
} else {
|
||||
cerr << "Illegal command. " << command << endl;
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
cerr << "ngt: Fatal error: " << err.what() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
int debugLevel;
|
||||
|
||||
};
|
||||
|
||||
};
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,338 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "NGT/defines.h"
|
||||
|
||||
#include "NGT/Node.h"
|
||||
#include "NGT/Tree.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
using namespace std;
|
||||
|
||||
const double NGT::Node::Object::Pivot = -1.0;
|
||||
|
||||
using namespace NGT;
|
||||
|
||||
void
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst,
|
||||
SharedMemoryAllocator &allocator) {
|
||||
#else
|
||||
InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst) {
|
||||
#endif
|
||||
int cs = dvptree.internalChildrenSize;
|
||||
for (int i = 0; i < cs; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
if (getChildren(allocator)[i] == src) {
|
||||
getChildren(allocator)[i] = dst;
|
||||
#else
|
||||
if (getChildren()[i] == src) {
|
||||
getChildren()[i] = dst;
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int
|
||||
LeafNode::selectPivotByMaxDistance(Container &c, Node::Objects &fs)
|
||||
{
|
||||
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
|
||||
int fsize = fs.size();
|
||||
Distance maxd = 0.0;
|
||||
int maxid = 0;
|
||||
for (int i = 1; i < fsize; i++) {
|
||||
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[0].object, *fs[i].object);
|
||||
if (d >= maxd) {
|
||||
maxd = d;
|
||||
maxid = i;
|
||||
}
|
||||
}
|
||||
|
||||
int aid = maxid;
|
||||
maxd = 0.0;
|
||||
maxid = 0;
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[aid].object, *fs[i].object);
|
||||
if (i == aid) {
|
||||
continue;
|
||||
}
|
||||
if (d >= maxd) {
|
||||
maxd = d;
|
||||
maxid = i;
|
||||
}
|
||||
}
|
||||
|
||||
int bid = maxid;
|
||||
maxd = 0.0;
|
||||
maxid = 0;
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[bid].object, *fs[i].object);
|
||||
if (i == bid) {
|
||||
continue;
|
||||
}
|
||||
if (d >= maxd) {
|
||||
maxd = d;
|
||||
maxid = i;
|
||||
}
|
||||
}
|
||||
return maxid;
|
||||
}
|
||||
|
||||
int
|
||||
LeafNode::selectPivotByMaxVariance(Container &c, Node::Objects &fs)
|
||||
{
|
||||
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
|
||||
|
||||
int fsize = fs.size();
|
||||
Distance *distance = new Distance[fsize * fsize];
|
||||
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
distance[i * fsize + i] = 0;
|
||||
}
|
||||
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
for (int j = i + 1; j < fsize; j++) {
|
||||
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[i].object, *fs[j].object);
|
||||
distance[i * fsize + j] = d;
|
||||
distance[j * fsize + i] = d;
|
||||
}
|
||||
}
|
||||
|
||||
double *variance = new double[fsize];
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
double avg = 0.0;
|
||||
for (int j = 0; j < fsize; j++) {
|
||||
avg += distance[i * fsize + j];
|
||||
}
|
||||
avg /= (double)fsize;
|
||||
|
||||
double v = 0.0;
|
||||
for (int j = 0; j < fsize; j++) {
|
||||
v += pow(distance[i * fsize + j] - avg, 2.0);
|
||||
}
|
||||
variance[i] = v / (double)fsize;
|
||||
}
|
||||
|
||||
double maxv = variance[0];
|
||||
int maxid = 0;
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
if (variance[i] > maxv) {
|
||||
maxv = variance[i];
|
||||
maxid = i;
|
||||
}
|
||||
}
|
||||
delete [] variance;
|
||||
delete [] distance;
|
||||
|
||||
return maxid;
|
||||
}
|
||||
|
||||
void
|
||||
LeafNode::splitObjects(Container &c, Objects &fs, int pv)
|
||||
{
|
||||
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
|
||||
|
||||
// sort the objects by distance
|
||||
int fsize = fs.size();
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
if (i == pv) {
|
||||
fs[i].distance = 0;
|
||||
} else {
|
||||
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pv].object, *fs[i].object);
|
||||
fs[i].distance = d;
|
||||
}
|
||||
}
|
||||
|
||||
sort(fs.begin(), fs.end());
|
||||
|
||||
int childrenSize = iobj.vptree->internalChildrenSize;
|
||||
int cid = childrenSize - 1;
|
||||
int cms = (fsize * cid) / childrenSize;
|
||||
|
||||
// divide the objects into child clusters.
|
||||
fs[fsize - 1].clusterID = cid;
|
||||
for (int i = fsize - 2; i >= 0; i--) {
|
||||
if (i < cms && cid > 0) {
|
||||
if (fs[i].distance != fs[i + 1].distance) {
|
||||
cid--;
|
||||
cms = (fsize * cid) / childrenSize;
|
||||
}
|
||||
}
|
||||
fs[i].clusterID = cid;
|
||||
}
|
||||
|
||||
if (cid != 0) {
|
||||
// the required number of child nodes could not be acquired
|
||||
stringstream msg;
|
||||
msg << "LeafNode::splitObjects: Too many same distances. Reduce internal children size for the tree index or not use the tree index." << endl;
|
||||
msg << " internalChildrenSize=" << childrenSize << endl;
|
||||
msg << " # of the children=" << (childrenSize - cid) << endl;
|
||||
msg << " Size=" << fsize << endl;
|
||||
msg << " pivot=" << pv << endl;
|
||||
msg << " cluster id=" << cid << endl;
|
||||
msg << " Show distances for debug." << endl;
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
msg << " " << fs[i].id << ":" << fs[i].distance << endl;
|
||||
msg << " ";
|
||||
PersistentObject &po = *fs[i].object;
|
||||
iobj.vptree->objectSpace->show(msg, po);
|
||||
msg << endl;
|
||||
}
|
||||
if (fs[fsize - 1].clusterID == cid) {
|
||||
msg << "LeafNode::splitObjects: All of the object distances are the same!" << endl;;
|
||||
NGTThrowException(msg.str());
|
||||
} else {
|
||||
cerr << msg.str() << endl;
|
||||
cerr << "LeafNode::splitObjects: Anyway, continue..." << endl;
|
||||
// sift the cluster IDs to start from 0 to continue.
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
fs[i].clusterID -= cid;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
long long *pivots = new long long[childrenSize];
|
||||
for (int i = 0; i < childrenSize; i++) {
|
||||
pivots[i] = -1;
|
||||
}
|
||||
|
||||
// find the boundaries for the subspaces
|
||||
for (int i = 0; i < fsize; i++) {
|
||||
if (pivots[fs[i].clusterID] == -1) {
|
||||
pivots[fs[i].clusterID] = i;
|
||||
fs[i].leafDistance = Object::Pivot;
|
||||
} else {
|
||||
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pivots[fs[i].clusterID]].object, *fs[i].object);
|
||||
fs[i].leafDistance = d;
|
||||
}
|
||||
}
|
||||
delete[] pivots;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
LeafNode::removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator) {
|
||||
#else
|
||||
LeafNode::removeObject(size_t id, size_t replaceId) {
|
||||
#endif
|
||||
|
||||
size_t fsize = getObjectSize();
|
||||
size_t idx;
|
||||
for (idx = 0; idx < fsize; idx++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
if (getObjectIDs(allocator)[idx].id == id) {
|
||||
if (replaceId != 0) {
|
||||
getObjectIDs(allocator)[idx].id = replaceId;
|
||||
#else
|
||||
if (getObjectIDs()[idx].id == id) {
|
||||
if (replaceId != 0) {
|
||||
getObjectIDs()[idx].id = replaceId;
|
||||
#endif
|
||||
return;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (idx == fsize) {
|
||||
if (pivot == 0) {
|
||||
NGTThrowException("LeafNode::removeObject: Internal error!. the pivot is illegal.");
|
||||
}
|
||||
stringstream msg;
|
||||
msg << "VpTree::Leaf::remove: Warning. Cannot find the specified object. ID=" << id << "," << replaceId << " idx=" << idx << " If the same objects were inserted into the index, ignore this message.";
|
||||
NGTThrowException(msg.str());
|
||||
}
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
for (; idx < objectIDs.size() - 1; idx++) {
|
||||
getObjectIDs()[idx] = getObjectIDs()[idx + 1];
|
||||
}
|
||||
objectIDs.pop_back();
|
||||
#else
|
||||
objectSize--;
|
||||
for (; idx < objectSize; idx++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getObjectIDs(allocator)[idx] = getObjectIDs(allocator)[idx + 1];
|
||||
#else
|
||||
getObjectIDs()[idx] = getObjectIDs()[idx + 1];
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
bool InternalNode::verify(PersistentRepository<InternalNode> &internalNodes, PersistentRepository<LeafNode> &leafNodes,
|
||||
SharedMemoryAllocator &allocator) {
|
||||
#else
|
||||
bool InternalNode::verify(Repository<InternalNode> &internalNodes, Repository<LeafNode> &leafNodes) {
|
||||
#endif
|
||||
size_t isize = internalNodes.size();
|
||||
size_t lsize = leafNodes.size();
|
||||
bool valid = true;
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
size_t nid = getChildren(allocator)[i].getID();
|
||||
ID::Type type = getChildren(allocator)[i].getType();
|
||||
#else
|
||||
size_t nid = getChildren()[i].getID();
|
||||
ID::Type type = getChildren()[i].getType();
|
||||
#endif
|
||||
size_t size = type == ID::Leaf ? lsize : isize;
|
||||
if (nid >= size) {
|
||||
cerr << "Error! Internal children node id is too big." << nid << ":" << size << endl;
|
||||
valid = false;
|
||||
}
|
||||
try {
|
||||
if (type == ID::Leaf) {
|
||||
leafNodes.get(nid);
|
||||
} else {
|
||||
internalNodes.get(nid);
|
||||
}
|
||||
} catch (...) {
|
||||
cerr << "Error! Cannot get the node. " << ((type == ID::Leaf) ? "Leaf" : "Internal") << endl;
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
return valid;
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
bool LeafNode::verify(size_t nobjs, vector<uint8_t> &status, SharedMemoryAllocator &allocator) {
|
||||
#else
|
||||
bool LeafNode::verify(size_t nobjs, vector<uint8_t> &status) {
|
||||
#endif
|
||||
bool valid = true;
|
||||
for (size_t i = 0; i < objectSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
size_t nid = getObjectIDs(allocator)[i].id;
|
||||
#else
|
||||
size_t nid = getObjectIDs()[i].id;
|
||||
#endif
|
||||
if (nid > nobjs) {
|
||||
cerr << "Error! Object id is too big. " << nid << ":" << nobjs << endl;
|
||||
valid =false;
|
||||
continue;
|
||||
}
|
||||
status[nid] |= 0x04;
|
||||
}
|
||||
return valid;
|
||||
}
|
|
@ -0,0 +1,772 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include "NGT/Common.h"
|
||||
#include "NGT/ObjectSpaceRepository.h"
|
||||
#include "NGT/defines.h"
|
||||
|
||||
namespace NGT {
|
||||
class DVPTree;
|
||||
class InternalNode;
|
||||
class LeafNode;
|
||||
class Node {
|
||||
public:
|
||||
typedef unsigned int NodeID;
|
||||
class ID {
|
||||
public:
|
||||
enum Type {
|
||||
Leaf = 1,
|
||||
Internal = 0
|
||||
};
|
||||
ID():id(0) {}
|
||||
ID &operator=(const ID &n) {
|
||||
id = n.id;
|
||||
return *this;
|
||||
}
|
||||
ID &operator=(int i) {
|
||||
setID(i);
|
||||
return *this;
|
||||
}
|
||||
bool operator==(ID &n) { return id == n.id; }
|
||||
bool operator<(ID &n) { return id < n.id; }
|
||||
Type getType() { return (Type)((0x80000000 & id) >> 31); }
|
||||
NodeID getID() { return 0x7fffffff & id; }
|
||||
NodeID get() { return id; }
|
||||
void setID(NodeID i) { id = (0x80000000 & id) | i; }
|
||||
void setType(Type t) { id = (t << 31) | getID(); }
|
||||
void setRaw(NodeID i) { id = i; }
|
||||
void setNull() { id = 0; }
|
||||
// for milvus
|
||||
void serialize(std::stringstream & os) { NGT::Serializer::write(os, id); }
|
||||
void serialize(std::ofstream &os) { NGT::Serializer::write(os, id); }
|
||||
void deserialize(std::ifstream &is) { NGT::Serializer::read(is, id); }
|
||||
// for milvus
|
||||
void deserialize(std::stringstream & is) { NGT::Serializer::read(is, id); }
|
||||
void serializeAsText(std::ofstream &os) { NGT::Serializer::writeAsText(os, id); }
|
||||
void deserializeAsText(std::ifstream &is) { NGT::Serializer::readAsText(is, id); }
|
||||
protected:
|
||||
NodeID id;
|
||||
};
|
||||
|
||||
class Object {
|
||||
public:
|
||||
Object():object(0) {}
|
||||
bool operator<(const Object &o) const { return distance < o.distance; }
|
||||
static const double Pivot;
|
||||
ObjectID id;
|
||||
PersistentObject *object;
|
||||
Distance distance;
|
||||
Distance leafDistance;
|
||||
int clusterID;
|
||||
};
|
||||
|
||||
typedef std::vector<Object> Objects;
|
||||
|
||||
Node() {
|
||||
parent.setNull();
|
||||
id.setNull();
|
||||
}
|
||||
|
||||
virtual ~Node() {}
|
||||
|
||||
Node &operator=(const Node &n) {
|
||||
id = n.id;
|
||||
parent = n.parent;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// for milvus
|
||||
void serialize(std::stringstream & os)
|
||||
{
|
||||
id.serialize(os);
|
||||
parent.serialize(os);
|
||||
}
|
||||
|
||||
void serialize(std::ofstream &os) {
|
||||
id.serialize(os);
|
||||
parent.serialize(os);
|
||||
}
|
||||
|
||||
void deserialize(std::ifstream &is) {
|
||||
id.deserialize(is);
|
||||
parent.deserialize(is);
|
||||
}
|
||||
|
||||
void deserialize(std::stringstream & is)
|
||||
{
|
||||
id.deserialize(is);
|
||||
parent.deserialize(is);
|
||||
}
|
||||
|
||||
void serializeAsText(std::ofstream &os) {
|
||||
id.serializeAsText(os);
|
||||
os << " ";
|
||||
parent.serializeAsText(os);
|
||||
}
|
||||
|
||||
void deserializeAsText(std::ifstream &is) {
|
||||
id.deserializeAsText(is);
|
||||
parent.deserializeAsText(is);
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void setPivot(PersistentObject &f, ObjectSpace &os, SharedMemoryAllocator &allocator) {
|
||||
if (pivot == 0) {
|
||||
pivot = NGT::PersistentObject::allocate(os);
|
||||
}
|
||||
getPivot(os).set(f, os);
|
||||
}
|
||||
PersistentObject &getPivot(ObjectSpace &os) {
|
||||
return *(PersistentObject*)os.getRepository().getAllocator().getAddr(pivot);
|
||||
}
|
||||
void deletePivot(ObjectSpace &os, SharedMemoryAllocator &allocator) {
|
||||
os.deleteObject(&getPivot(os));
|
||||
}
|
||||
#else // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void setPivot(NGT::Object &f, ObjectSpace &os) {
|
||||
if (pivot == 0) {
|
||||
pivot = NGT::Object::allocate(os);
|
||||
}
|
||||
os.copy(getPivot(), f);
|
||||
}
|
||||
NGT::Object &getPivot() { return *pivot; }
|
||||
void deletePivot(ObjectSpace &os) {
|
||||
os.deleteObject(pivot);
|
||||
}
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
bool pivotIsEmpty() {
|
||||
return pivot == 0;
|
||||
}
|
||||
|
||||
ID id;
|
||||
ID parent;
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
off_t pivot;
|
||||
#else
|
||||
NGT::Object *pivot;
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
|
||||
class InternalNode : public Node {
|
||||
public:
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
InternalNode(size_t csize, SharedMemoryAllocator &allocator) : childrenSize(csize) { initialize(allocator); }
|
||||
InternalNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(allocator); }
|
||||
#else
|
||||
InternalNode(size_t csize) : childrenSize(csize) { initialize(); }
|
||||
InternalNode(NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(); }
|
||||
#endif
|
||||
|
||||
~InternalNode() {
|
||||
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
if (children != 0) {
|
||||
delete[] children;
|
||||
}
|
||||
if (borders != 0) {
|
||||
delete[] borders;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void initialize(SharedMemoryAllocator &allocator) {
|
||||
#else
|
||||
void initialize() {
|
||||
#endif
|
||||
id = 0;
|
||||
id.setType(ID::Internal);
|
||||
pivot = 0;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
children = allocator.getOffset(new(allocator) ID[childrenSize]);
|
||||
#else
|
||||
children = new ID[childrenSize];
|
||||
#endif
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getChildren(allocator)[i] = 0;
|
||||
#else
|
||||
getChildren()[i] = 0;
|
||||
#endif
|
||||
}
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
borders = allocator.getOffset(new(allocator) Distance[childrenSize - 1]);
|
||||
#else
|
||||
borders = new Distance[childrenSize - 1];
|
||||
#endif
|
||||
for (size_t i = 0; i < childrenSize - 1; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getBorders(allocator)[i] = 0;
|
||||
#else
|
||||
getBorders()[i] = 0;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void updateChild(DVPTree &dvptree, ID src, ID dst, SharedMemoryAllocator &allocator);
|
||||
#else
|
||||
void updateChild(DVPTree &dvptree, ID src, ID dst);
|
||||
#endif
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ID *getChildren(SharedMemoryAllocator &allocator) { return (ID*)allocator.getAddr(children); }
|
||||
Distance *getBorders(SharedMemoryAllocator &allocator) { return (Distance*)allocator.getAddr(borders); }
|
||||
#else // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ID *getChildren() { return children; }
|
||||
Distance *getBorders() { return borders; }
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
// for milvus
|
||||
void serialize(std::stringstream & os, ObjectSpace * objectspace = 0)
|
||||
{
|
||||
Node::serialize(os);
|
||||
if (pivot == 0)
|
||||
{
|
||||
NGTThrowException("Node::write: pivot is null!");
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
getPivot().serialize(os, objectspace);
|
||||
NGT::Serializer::write(os, childrenSize);
|
||||
for (size_t i = 0; i < childrenSize; i++)
|
||||
{
|
||||
getChildren()[i].serialize(os);
|
||||
}
|
||||
for (size_t i = 0; i < childrenSize - 1; i++)
|
||||
{
|
||||
NGT::Serializer::write(os, getBorders()[i]);
|
||||
}
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void serialize(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
|
||||
#else
|
||||
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) {
|
||||
#endif
|
||||
Node::serialize(os);
|
||||
if (pivot == 0) {
|
||||
NGTThrowException("Node::write: pivot is null!");
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getPivot(*objectspace).serialize(os, allocator, objectspace);
|
||||
#else
|
||||
getPivot().serialize(os, objectspace);
|
||||
#endif
|
||||
NGT::Serializer::write(os, childrenSize);
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getChildren(allocator)[i].serialize(os);
|
||||
#else
|
||||
getChildren()[i].serialize(os);
|
||||
#endif
|
||||
}
|
||||
for (size_t i = 0; i < childrenSize - 1; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::Serializer::write(os, getBorders(allocator)[i]);
|
||||
#else
|
||||
NGT::Serializer::write(os, getBorders()[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) {
|
||||
Node::deserialize(is);
|
||||
if (pivot == 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#else
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#endif
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
getPivot().deserialize(is, objectspace);
|
||||
#endif
|
||||
NGT::Serializer::read(is, childrenSize);
|
||||
assert(children != 0);
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
assert(0);
|
||||
#else
|
||||
getChildren()[i].deserialize(is);
|
||||
#endif
|
||||
}
|
||||
assert(borders != 0);
|
||||
for (size_t i = 0; i < childrenSize - 1; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
assert(0);
|
||||
#else
|
||||
NGT::Serializer::read(is, getBorders()[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// for milvus
|
||||
void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0)
|
||||
{
|
||||
Node::deserialize(is);
|
||||
if (pivot == 0)
|
||||
{
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#else
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#endif
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
getPivot().deserialize(is, objectspace);
|
||||
#endif
|
||||
NGT::Serializer::read(is, childrenSize);
|
||||
assert(children != 0);
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
assert(0);
|
||||
#else
|
||||
getChildren()[i].deserialize(is);
|
||||
#endif
|
||||
}
|
||||
assert(borders != 0);
|
||||
for (size_t i = 0; i < childrenSize - 1; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
assert(0);
|
||||
#else
|
||||
NGT::Serializer::read(is, getBorders()[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
|
||||
#else
|
||||
void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) {
|
||||
#endif
|
||||
Node::serializeAsText(os);
|
||||
if (pivot == 0) {
|
||||
NGTThrowException("Node::write: pivot is null!");
|
||||
}
|
||||
os << " ";
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getPivot(*objectspace).serializeAsText(os, objectspace);
|
||||
#else
|
||||
getPivot().serializeAsText(os, objectspace);
|
||||
#endif
|
||||
os << " ";
|
||||
NGT::Serializer::writeAsText(os, childrenSize);
|
||||
os << " ";
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getChildren(allocator)[i].serializeAsText(os);
|
||||
#else
|
||||
getChildren()[i].serializeAsText(os);
|
||||
#endif
|
||||
os << " ";
|
||||
}
|
||||
for (size_t i = 0; i < childrenSize - 1; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::Serializer::writeAsText(os, getBorders(allocator)[i]);
|
||||
#else
|
||||
NGT::Serializer::writeAsText(os, getBorders()[i]);
|
||||
#endif
|
||||
os << " ";
|
||||
}
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
|
||||
#else
|
||||
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
|
||||
#endif
|
||||
Node::deserializeAsText(is);
|
||||
if (pivot == 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#else
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#endif
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getPivot(*objectspace).deserializeAsText(is, objectspace);
|
||||
#else
|
||||
getPivot().deserializeAsText(is, objectspace);
|
||||
#endif
|
||||
size_t csize;
|
||||
NGT::Serializer::readAsText(is, csize);
|
||||
assert(children != 0);
|
||||
assert(childrenSize == csize);
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getChildren(allocator)[i].deserializeAsText(is);
|
||||
#else
|
||||
getChildren()[i].deserializeAsText(is);
|
||||
#endif
|
||||
}
|
||||
assert(borders != 0);
|
||||
for (size_t i = 0; i < childrenSize - 1; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::Serializer::readAsText(is, getBorders(allocator)[i]);
|
||||
#else
|
||||
NGT::Serializer::readAsText(is, getBorders()[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void show() {
|
||||
std::cout << "Show internal node " << childrenSize << ":";
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
assert(0);
|
||||
#else
|
||||
std::cout << getChildren()[i].getID() << " ";
|
||||
#endif
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
bool verify(PersistentRepository<InternalNode> &internalNodes, PersistentRepository<LeafNode> &leafNodes,
|
||||
SharedMemoryAllocator &allocator);
|
||||
#else
|
||||
bool verify(Repository<InternalNode> &internalNodes, Repository<LeafNode> &leafNodes);
|
||||
#endif
|
||||
|
||||
static const int InternalChildrenSizeMax = 5;
|
||||
const size_t childrenSize;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
off_t children;
|
||||
off_t borders;
|
||||
#else
|
||||
ID *children;
|
||||
Distance *borders;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
class LeafNode : public Node {
|
||||
public:
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
LeafNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {
|
||||
#else
|
||||
LeafNode(NGT::ObjectSpace *os = 0) {
|
||||
#endif
|
||||
id = 0;
|
||||
id.setType(ID::Leaf);
|
||||
pivot = 0;
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
objectIDs.reserve(LeafObjectsSizeMax);
|
||||
#else
|
||||
objectSize = 0;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
objectIDs = allocator.getOffset(new(allocator) Object[LeafObjectsSizeMax]);
|
||||
#else
|
||||
objectIDs = new NGT::ObjectDistance[LeafObjectsSizeMax];
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
~LeafNode() {
|
||||
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
#ifndef NGT_NODE_USE_VECTOR
|
||||
if (objectIDs != 0) {
|
||||
delete[] objectIDs;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
static int
|
||||
selectPivotByMaxDistance(Container &iobj, Node::Objects &fs);
|
||||
|
||||
static int
|
||||
selectPivotByMaxVariance(Container &iobj, Node::Objects &fs);
|
||||
|
||||
static void
|
||||
splitObjects(Container &insertedObject, Objects &splitObjectSet, int pivot);
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator);
|
||||
#else
|
||||
void removeObject(size_t id, size_t replaceId);
|
||||
#endif
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
#ifndef NGT_NODE_USE_VECTOR
|
||||
NGT::ObjectDistance *getObjectIDs(SharedMemoryAllocator &allocator) {
|
||||
return (NGT::ObjectDistance *)allocator.getAddr(objectIDs);
|
||||
}
|
||||
#endif
|
||||
#else // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
NGT::ObjectDistance *getObjectIDs() { return objectIDs; }
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
// for milvus
|
||||
void serialize(std::stringstream & os, ObjectSpace * objectspace = 0)
|
||||
{
|
||||
Node::serialize(os);
|
||||
NGT::Serializer::write(os, objectSize);
|
||||
for (int i = 0; i < objectSize; i++)
|
||||
{
|
||||
objectIDs[i].serialize(os);
|
||||
}
|
||||
if (pivot == 0)
|
||||
{
|
||||
// Before insertion, parent ID == 0 and object size == 0, that indicates an empty index
|
||||
if (parent.getID() != 0 || objectSize != 0)
|
||||
{
|
||||
NGTThrowException("Node::write: pivot is null!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(objectspace != 0);
|
||||
pivot->serialize(os, objectspace);
|
||||
}
|
||||
}
|
||||
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) {
|
||||
Node::serialize(os);
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
NGT::Serializer::write(os, objectIDs);
|
||||
#else
|
||||
NGT::Serializer::write(os, objectSize);
|
||||
for (int i = 0; i < objectSize; i++) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
objectIDs[i].serialize(os);
|
||||
#endif
|
||||
}
|
||||
#endif // NGT_NODE_USE_VECTOR
|
||||
if (pivot == 0) {
|
||||
// Before insertion, parent ID == 0 and object size == 0, that indicates an empty index
|
||||
if (parent.getID() != 0 || objectSize != 0) {
|
||||
NGTThrowException("Node::write: pivot is null!");
|
||||
}
|
||||
} else {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
assert(objectspace != 0);
|
||||
pivot->serialize(os, objectspace);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) {
|
||||
Node::deserialize(is);
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
objectIDs.clear();
|
||||
NGT::Serializer::read(is, objectIDs);
|
||||
#else
|
||||
assert(objectIDs != 0);
|
||||
NGT::Serializer::read(is, objectSize);
|
||||
for (int i = 0; i < objectSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
getObjectIDs()[i].deserialize(is);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (parent.getID() == 0 && objectSize == 0) {
|
||||
// The index is empty
|
||||
return;
|
||||
}
|
||||
if (pivot == 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#else
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
assert(pivot != 0);
|
||||
#endif
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
getPivot().deserialize(is, objectspace);
|
||||
#endif
|
||||
}
|
||||
|
||||
// for milvus
|
||||
void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0)
|
||||
{
|
||||
Node::deserialize(is);
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
objectIDs.clear();
|
||||
NGT::Serializer::read(is, objectIDs);
|
||||
#else
|
||||
assert(objectIDs != 0);
|
||||
NGT::Serializer::read(is, objectSize);
|
||||
for (int i = 0; i < objectSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
getObjectIDs()[i].deserialize(is);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (parent.getID() == 0 && objectSize == 0) {
|
||||
// The index is empty
|
||||
return;
|
||||
}
|
||||
if (pivot == 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#else
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
assert(pivot != 0);
|
||||
#endif
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
getPivot().deserialize(is, objectspace);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
|
||||
#else
|
||||
void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) {
|
||||
#endif
|
||||
Node::serializeAsText(os);
|
||||
os << " ";
|
||||
if (pivot == 0) {
|
||||
NGTThrowException("Node::write: pivot is null!");
|
||||
}
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
getPivot(*objectspace).serializeAsText(os, objectspace);
|
||||
#else
|
||||
assert(pivot != 0);
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
pivot->serializeAsText(os, allocator, objectspace);
|
||||
#else
|
||||
pivot->serializeAsText(os, objectspace);
|
||||
#endif
|
||||
#endif
|
||||
os << " ";
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
NGT::Serializer::writeAsText(os, objectIDs);
|
||||
#else
|
||||
NGT::Serializer::writeAsText(os, objectSize);
|
||||
for (int i = 0; i < objectSize; i++) {
|
||||
os << " ";
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
getObjectIDs(allocator)[i].serializeAsText(os);
|
||||
#else
|
||||
objectIDs[i].serializeAsText(os);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
|
||||
#else
|
||||
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
|
||||
#endif
|
||||
Node::deserializeAsText(is);
|
||||
if (pivot == 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#else
|
||||
pivot = PersistentObject::allocate(*objectspace);
|
||||
#endif
|
||||
}
|
||||
assert(objectspace != 0);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getPivot(*objectspace).deserializeAsText(is, objectspace);
|
||||
#else
|
||||
getPivot().deserializeAsText(is, objectspace);
|
||||
#endif
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
objectIDs.clear();
|
||||
NGT::Serializer::readAsText(is, objectIDs);
|
||||
#else
|
||||
assert(objectIDs != 0);
|
||||
NGT::Serializer::readAsText(is, objectSize);
|
||||
for (int i = 0; i < objectSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
getObjectIDs(allocator)[i].deserializeAsText(is);
|
||||
#else
|
||||
getObjectIDs()[i].deserializeAsText(is);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void show() {
|
||||
std::cout << "Show leaf node " << objectSize << ":";
|
||||
for (int i = 0; i < objectSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
#else
|
||||
std::cout << getObjectIDs()[i].id << "," << getObjectIDs()[i].distance << " ";
|
||||
#endif
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
bool verify(size_t nobjs, std::vector<uint8_t> &status, SharedMemoryAllocator &allocator);
|
||||
#else
|
||||
bool verify(size_t nobjs, std::vector<uint8_t> &status);
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
size_t getObjectSize() { return objectIDs.size(); }
|
||||
#else
|
||||
size_t getObjectSize() { return objectSize; }
|
||||
#endif
|
||||
|
||||
static const size_t LeafObjectsSizeMax = 100;
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
std::vector<Object> objectIDs;
|
||||
#else
|
||||
unsigned short objectSize;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
off_t objectIDs;
|
||||
#else
|
||||
ObjectDistance *objectIDs;
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,395 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
#include <sstream>
|
||||
|
||||
namespace NGT {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
class ObjectRepository :
|
||||
public PersistentRepository<PersistentObject> {
|
||||
public:
|
||||
typedef PersistentRepository<PersistentObject> Parent;
|
||||
void open(const std::string &smfile, size_t sharedMemorySize) {
|
||||
std::string file = smfile;
|
||||
file.append("po");
|
||||
Parent::open(file, sharedMemorySize);
|
||||
}
|
||||
#else
|
||||
class ObjectRepository : public Repository<Object> {
|
||||
public:
|
||||
typedef Repository<Object> Parent;
|
||||
#endif
|
||||
ObjectRepository(size_t dim, const std::type_info &ot):dimension(dim), type(ot), sparse(false) { }
|
||||
|
||||
void initialize() {
|
||||
deleteAll();
|
||||
Parent::push_back((PersistentObject*)0);
|
||||
}
|
||||
|
||||
// for milvus
|
||||
void serialize(std::stringstream & obj, ObjectSpace * ospace) { Parent::serialize(obj, ospace); }
|
||||
|
||||
void serialize(const std::string &ofile, ObjectSpace *ospace) {
|
||||
std::ofstream objs(ofile);
|
||||
if (!objs.is_open()) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
Parent::serialize(objs, ospace);
|
||||
}
|
||||
|
||||
void deserialize(std::stringstream & obj, ObjectSpace * ospace)
|
||||
{
|
||||
assert(ospace != 0);
|
||||
Parent::deserialize(obj, ospace);
|
||||
}
|
||||
|
||||
void deserialize(const std::string &ifile, ObjectSpace *ospace) {
|
||||
assert(ospace != 0);
|
||||
std::ifstream objs(ifile);
|
||||
if (!objs.is_open()) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
Parent::deserialize(objs, ospace);
|
||||
}
|
||||
|
||||
void serializeAsText(const std::string &ofile, ObjectSpace *ospace) {
|
||||
std::ofstream objs(ofile);
|
||||
if (!objs.is_open()) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
Parent::serializeAsText(objs, ospace);
|
||||
}
|
||||
|
||||
void deserializeAsText(const std::string &ifile, ObjectSpace *ospace) {
|
||||
std::ifstream objs(ifile);
|
||||
if (!objs.is_open()) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
Parent::deserializeAsText(objs, ospace);
|
||||
}
|
||||
|
||||
void readText(std::istream &is, size_t dataSize = 0) {
|
||||
initialize();
|
||||
appendText(is, dataSize);
|
||||
}
|
||||
|
||||
// For milvus
|
||||
template <typename T>
|
||||
void readRawData(const T * raw_data, size_t dataSize)
|
||||
{
|
||||
initialize();
|
||||
append(raw_data, dataSize);
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const float *obj, size_t size) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
void appendText(std::istream &is, size_t dataSize = 0) {
|
||||
if (dimension == 0) {
|
||||
NGTThrowException("ObjectSpace::readText: Dimension is not specified.");
|
||||
}
|
||||
if (size() == 0) {
|
||||
// First entry should be always a dummy entry.
|
||||
// If it is empty, the dummy entry should be inserted.
|
||||
push_back((PersistentObject*)0);
|
||||
}
|
||||
size_t prevDataSize = size();
|
||||
if (dataSize > 0) {
|
||||
reserve(size() + dataSize);
|
||||
}
|
||||
std::string line;
|
||||
size_t lineNo = 0;
|
||||
while (getline(is, line)) {
|
||||
lineNo++;
|
||||
if (dataSize > 0 && (dataSize <= size() - prevDataSize)) {
|
||||
std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
|
||||
<< dataSize << std::endl;
|
||||
break;
|
||||
}
|
||||
std::vector<double> object;
|
||||
try {
|
||||
extractObjectFromText(line, "\t ", object);
|
||||
PersistentObject *obj = 0;
|
||||
try {
|
||||
obj = allocateNormalizedPersistentObject(object);
|
||||
} catch (Exception &err) {
|
||||
std::cerr << err.what() << " continue..." << std::endl;
|
||||
obj = allocatePersistentObject(object);
|
||||
}
|
||||
push_back(obj);
|
||||
} catch (Exception &err) {
|
||||
std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append(T *data, size_t objectCount) {
|
||||
if (dimension == 0) {
|
||||
NGTThrowException("ObjectSpace::readText: Dimension is not specified.");
|
||||
}
|
||||
if (size() == 0) {
|
||||
// First entry should be always a dummy entry.
|
||||
// If it is empty, the dummy entry should be inserted.
|
||||
push_back((PersistentObject*)0);
|
||||
}
|
||||
if (objectCount > 0) {
|
||||
reserve(size() + objectCount);
|
||||
}
|
||||
for (size_t idx = 0; idx < objectCount; idx++, data += dimension) {
|
||||
std::vector<double> object;
|
||||
object.reserve(dimension);
|
||||
for (size_t dataidx = 0; dataidx < dimension; dataidx++) {
|
||||
object.push_back(data[dataidx]);
|
||||
}
|
||||
try {
|
||||
PersistentObject *obj = 0;
|
||||
try {
|
||||
obj = allocateNormalizedPersistentObject(object);
|
||||
} catch (Exception &err) {
|
||||
std::cerr << err.what() << " continue..." << std::endl;
|
||||
obj = allocatePersistentObject(object);
|
||||
}
|
||||
push_back(obj);
|
||||
|
||||
} catch (Exception &err) {
|
||||
std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Object *allocateObject() {
|
||||
return (Object*) new Object(paddedByteSize);
|
||||
}
|
||||
|
||||
// This method is called during search to generate query.
|
||||
// Therefore the object is not persistent.
|
||||
Object *allocateObject(const std::string &textLine, const std::string &sep) {
|
||||
std::vector<double> object;
|
||||
extractObjectFromText(textLine, sep, object);
|
||||
Object *po = (Object*)allocateObject(object);
|
||||
return (Object*)po;
|
||||
}
|
||||
|
||||
void extractObjectFromText(const std::string &textLine, const std::string &sep, std::vector<double> &object) {
|
||||
object.resize(dimension);
|
||||
std::vector<std::string> tokens;
|
||||
NGT::Common::tokenize(textLine, tokens, sep);
|
||||
if (dimension > tokens.size()) {
|
||||
std::stringstream msg;
|
||||
msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":" << dimension << ". "
|
||||
<< textLine;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
size_t idx;
|
||||
for (idx = 0; idx < dimension; idx++) {
|
||||
if (tokens[idx].size() == 0) {
|
||||
std::stringstream msg;
|
||||
msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":"
|
||||
<< dimension << ". " << textLine;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
char *e;
|
||||
object[idx] = strtod(tokens[idx].c_str(), &e);
|
||||
if (*e != 0) {
|
||||
std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Object *allocateObject(T *o, size_t size) {
|
||||
size_t osize = paddedByteSize;
|
||||
if (sparse) {
|
||||
size_t vsize = size * (type == typeid(float) ? 4 : 1);
|
||||
osize = osize < vsize ? vsize : osize;
|
||||
} else {
|
||||
if (dimension != size) {
|
||||
std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
|
||||
<< dimension << " The specified object=" << size << std::endl;
|
||||
assert(dimension == size);
|
||||
}
|
||||
}
|
||||
Object *po = new Object(osize);
|
||||
void *object = static_cast<void*>(&(*po)[0]);
|
||||
if (type == typeid(uint8_t)) {
|
||||
uint8_t *obj = static_cast<uint8_t*>(object);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
obj[i] = static_cast<uint8_t>(o[i]);
|
||||
}
|
||||
} else if (type == typeid(float)) {
|
||||
float *obj = static_cast<float*>(object);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
obj[i] = static_cast<float>(o[i]);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
return po;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Object *allocateObject(const std::vector<T> &o) {
|
||||
return allocateObject(o.data(), o.size());
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
PersistentObject *allocatePersistentObject(Object &o) {
|
||||
SharedMemoryAllocator &objectAllocator = getAllocator();
|
||||
size_t cpsize = dimension;
|
||||
if (type == typeid(uint8_t)) {
|
||||
cpsize *= sizeof(uint8_t);
|
||||
} else if (type == typeid(float)) {
|
||||
cpsize *= sizeof(float);
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
|
||||
void *dsto = &(*po).at(0, allocator);
|
||||
void *srco = &o[0];
|
||||
memcpy(dsto, srco, cpsize);
|
||||
return po;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
PersistentObject *allocatePersistentObject(T *o, size_t size) {
|
||||
SharedMemoryAllocator &objectAllocator = getAllocator();
|
||||
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
|
||||
if (size != 0 && dimension != size) {
|
||||
std::stringstream msg;
|
||||
msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality="
|
||||
<< (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
void *object = static_cast<void*>(&(*po).at(0, allocator));
|
||||
if (type == typeid(uint8_t)) {
|
||||
uint8_t *obj = static_cast<uint8_t*>(object);
|
||||
for (size_t i = 0; i < dimension; i++) {
|
||||
obj[i] = static_cast<uint8_t>(o[i]);
|
||||
}
|
||||
} else if (type == typeid(float)) {
|
||||
float *obj = static_cast<float*>(object);
|
||||
for (size_t i = 0; i < dimension; i++) {
|
||||
obj[i] = static_cast<float>(o[i]);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
return po;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
PersistentObject *allocatePersistentObject(const std::vector<T> &o) {
|
||||
return allocatePersistentObject(o.data(), o.size());
|
||||
}
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
PersistentObject *allocatePersistentObject(T *o, size_t size) {
|
||||
if (size != 0 && dimension != size) {
|
||||
std::stringstream msg;
|
||||
msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality="
|
||||
<< (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
return allocateObject(o, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
PersistentObject *allocatePersistentObject(const std::vector<T> &o) {
|
||||
return allocatePersistentObject(o.data(), o.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
void deleteObject(Object *po) {
|
||||
delete po;
|
||||
}
|
||||
|
||||
private:
|
||||
void extractObject(void *object, std::vector<double> &d) {
|
||||
if (type == typeid(uint8_t)) {
|
||||
uint8_t *obj = (uint8_t*)object;
|
||||
for (size_t i = 0; i < dimension; i++) {
|
||||
d.push_back(obj[i]);
|
||||
}
|
||||
} else if (type == typeid(float)) {
|
||||
float *obj = (float*)object;
|
||||
for (size_t i = 0; i < dimension; i++) {
|
||||
d.push_back(obj[i]);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
public:
|
||||
void extractObject(Object *o, std::vector<double> &d) {
|
||||
void *object = (void*)(&(*o)[0]);
|
||||
extractObject(object, d);
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void extractObject(PersistentObject *o, std::vector<double> &d) {
|
||||
SharedMemoryAllocator &objectAllocator = getAllocator();
|
||||
void *object = (void*)(&(*o).at(0, objectAllocator));
|
||||
extractObject(object, d);
|
||||
}
|
||||
#endif
|
||||
|
||||
void setLength(size_t l) { byteSize = l; }
|
||||
void setPaddedLength(size_t l) { paddedByteSize = l; }
|
||||
void setSparse() { sparse = true; }
|
||||
size_t getByteSize() { return byteSize; }
|
||||
size_t insert(PersistentObject *obj) { return Parent::insert(obj); }
|
||||
const size_t dimension;
|
||||
const std::type_info &type;
|
||||
protected:
|
||||
size_t byteSize; // the length of all of elements.
|
||||
size_t paddedByteSize;
|
||||
bool sparse; // sparse data format
|
||||
};
|
||||
|
||||
} // namespace NGT
|
|
@ -0,0 +1,475 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "PrimitiveComparator.h"
|
||||
|
||||
class ObjectSpace;
|
||||
|
||||
namespace NGT {
|
||||
|
||||
class PersistentObjectDistances;
|
||||
class ObjectDistances : public std::vector<ObjectDistance> {
|
||||
public:
|
||||
ObjectDistances(NGT::ObjectSpace *os = 0) {}
|
||||
// for milvus
|
||||
void serialize(std::stringstream & os, ObjectSpace * objspace = 0) { NGT::Serializer::write(os, (std::vector<ObjectDistance> &)*this); }
|
||||
void serialize(std::ofstream &os, ObjectSpace *objspace = 0) { NGT::Serializer::write(os, (std::vector<ObjectDistance>&)*this);}
|
||||
// for milvus
|
||||
void deserialize(std::stringstream & is, ObjectSpace * objspace = 0)
|
||||
{
|
||||
NGT::Serializer::read(is, (std::vector<ObjectDistance> &)*this);
|
||||
}
|
||||
void deserialize(std::ifstream &is, ObjectSpace *objspace = 0) { NGT::Serializer::read(is, (std::vector<ObjectDistance>&)*this);}
|
||||
|
||||
void serializeAsText(std::ofstream &os, ObjectSpace *objspace = 0) {
|
||||
NGT::Serializer::writeAsText(os, size());
|
||||
os << " ";
|
||||
for (size_t i = 0; i < size(); i++) {
|
||||
(*this)[i].serializeAsText(os);
|
||||
os << " ";
|
||||
}
|
||||
}
|
||||
void deserializeAsText(std::ifstream &is, ObjectSpace *objspace = 0) {
|
||||
size_t s;
|
||||
NGT::Serializer::readAsText(is, s);
|
||||
resize(s);
|
||||
for (size_t i = 0; i < size(); i++) {
|
||||
(*this)[i].deserializeAsText(is);
|
||||
}
|
||||
}
|
||||
|
||||
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq) {
|
||||
this->clear();
|
||||
this->resize(pq.size());
|
||||
for (int i = pq.size() - 1; i >= 0; i--) {
|
||||
(*this)[i] = pq.top();
|
||||
pq.pop();
|
||||
}
|
||||
assert(pq.size() == 0);
|
||||
}
|
||||
|
||||
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq, double (&f)(double)) {
|
||||
this->clear();
|
||||
this->resize(pq.size());
|
||||
for (int i = pq.size() - 1; i >= 0; i--) {
|
||||
(*this)[i] = pq.top();
|
||||
(*this)[i].distance = f((*this)[i].distance);
|
||||
pq.pop();
|
||||
}
|
||||
assert(pq.size() == 0);
|
||||
}
|
||||
|
||||
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq, unsigned int id) {
|
||||
this->clear();
|
||||
if (pq.size() == 0) {
|
||||
return;
|
||||
}
|
||||
this->resize(id == 0 ? pq.size() : pq.size() - 1);
|
||||
int i = this->size() - 1;
|
||||
while (pq.size() != 0 && i >= 0) {
|
||||
if (pq.top().id != id) {
|
||||
(*this)[i] = pq.top();
|
||||
i--;
|
||||
}
|
||||
pq.pop();
|
||||
}
|
||||
if (pq.size() != 0 && pq.top().id != id) {
|
||||
std::cerr << "moveFrom: Fatal error: somethig wrong! " << pq.size() << ":" << this->size() << ":" << id << ":" << pq.top().id << std::endl;
|
||||
assert(pq.size() == 0 || pq.top().id == id);
|
||||
}
|
||||
}
|
||||
|
||||
ObjectDistances &operator=(PersistentObjectDistances &objs);
|
||||
};
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
class PersistentObjectDistances : public Vector<ObjectDistance> {
|
||||
public:
|
||||
PersistentObjectDistances(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {}
|
||||
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { NGT::Serializer::write(os, (Vector<ObjectDistance>&)*this); }
|
||||
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { NGT::Serializer::read(is, (Vector<ObjectDistance>&)*this); }
|
||||
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) {
|
||||
NGT::Serializer::writeAsText(os, size());
|
||||
os << " ";
|
||||
for (size_t i = 0; i < size(); i++) {
|
||||
(*this).at(i, allocator).serializeAsText(os);
|
||||
os << " ";
|
||||
}
|
||||
}
|
||||
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) {
|
||||
size_t s;
|
||||
is >> s;
|
||||
resize(s, allocator);
|
||||
for (size_t i = 0; i < size(); i++) {
|
||||
(*this).at(i, allocator).deserializeAsText(is);
|
||||
}
|
||||
}
|
||||
PersistentObjectDistances ©(ObjectDistances &objs, SharedMemoryAllocator &allocator) {
|
||||
clear(allocator);
|
||||
reserve(objs.size(), allocator);
|
||||
for (ObjectDistances::iterator i = objs.begin(); i != objs.end(); i++) {
|
||||
push_back(*i, allocator);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
typedef PersistentObjectDistances GraphNode;
|
||||
|
||||
inline ObjectDistances &ObjectDistances::operator=(PersistentObjectDistances &objs)
|
||||
{
|
||||
clear();
|
||||
reserve(objs.size());
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
return *this;
|
||||
}
|
||||
#else // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
typedef ObjectDistances GraphNode;
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
class PersistentObject;
|
||||
#else
|
||||
typedef Object PersistentObject;
|
||||
#endif
|
||||
|
||||
class ObjectRepository;
|
||||
|
||||
class ObjectSpace {
|
||||
public:
|
||||
class Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
Comparator(size_t d, SharedMemoryAllocator &a) : dimension(d), allocator(a) {}
|
||||
#else
|
||||
Comparator(size_t d) : dimension(d) {}
|
||||
#endif
|
||||
virtual double operator()(Object &objecta, Object &objectb) = 0;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
virtual double operator()(Object &objecta, PersistentObject &objectb) = 0;
|
||||
virtual double operator()(PersistentObject &objecta, PersistentObject &objectb) = 0;
|
||||
#endif
|
||||
size_t dimension;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
SharedMemoryAllocator &allocator;
|
||||
#endif
|
||||
virtual ~Comparator(){}
|
||||
};
|
||||
enum DistanceType {
|
||||
DistanceTypeNone = -1,
|
||||
DistanceTypeL1 = 0,
|
||||
DistanceTypeL2 = 1,
|
||||
DistanceTypeHamming = 2,
|
||||
DistanceTypeAngle = 3,
|
||||
DistanceTypeCosine = 4,
|
||||
DistanceTypeNormalizedAngle = 5,
|
||||
DistanceTypeNormalizedCosine = 6,
|
||||
DistanceTypeJaccard = 7,
|
||||
DistanceTypeSparseJaccard = 8
|
||||
};
|
||||
|
||||
enum ObjectType {
|
||||
ObjectTypeNone = 0,
|
||||
Uint8 = 1,
|
||||
Float = 2
|
||||
};
|
||||
|
||||
|
||||
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > ResultSet;
|
||||
ObjectSpace(size_t d):dimension(d), distanceType(DistanceTypeNone), comparator(0), normalization(false) {}
|
||||
virtual ~ObjectSpace() { if (comparator != 0) { delete comparator; } }
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
virtual void open(const std::string &f, size_t shareMemorySize) = 0;
|
||||
virtual Object *allocateObject(Object &o) = 0;
|
||||
virtual Object *allocateObject(PersistentObject &o) = 0;
|
||||
virtual PersistentObject *allocatePersistentObject(Object &obj) = 0;
|
||||
virtual void deleteObject(PersistentObject *) = 0;
|
||||
virtual void copy(PersistentObject &objecta, PersistentObject &objectb) = 0;
|
||||
virtual void show(std::ostream &os, PersistentObject &object) = 0;
|
||||
virtual size_t insert(PersistentObject *obj) = 0;
|
||||
#else
|
||||
virtual size_t insert(Object *obj) = 0;
|
||||
#endif
|
||||
|
||||
Comparator &getComparator() { return *comparator; }
|
||||
|
||||
virtual void serialize(const std::string &of) = 0;
|
||||
// for milvus
|
||||
virtual void serialize(std::stringstream & obj) = 0;
|
||||
// for milvus
|
||||
virtual void deserialize(std::stringstream & obj) = 0;
|
||||
virtual void deserialize(const std::string &ifile) = 0;
|
||||
virtual void serializeAsText(const std::string &of) = 0;
|
||||
virtual void deserializeAsText(const std::string &of) = 0;
|
||||
//for milvus
|
||||
virtual void readRawData(const float * raw_data, size_t dataSize) = 0;
|
||||
virtual void readText(std::istream &is, size_t dataSize) = 0;
|
||||
virtual void appendText(std::istream &is, size_t dataSize) = 0;
|
||||
virtual void append(const float *data, size_t dataSize) = 0;
|
||||
virtual void append(const double *data, size_t dataSize) = 0;
|
||||
|
||||
virtual void copy(Object &objecta, Object &objectb) = 0;
|
||||
|
||||
virtual void linearSearch(Object &query, double radius, size_t size,
|
||||
ObjectSpace::ResultSet &results) = 0;
|
||||
|
||||
virtual const std::type_info &getObjectType() = 0;
|
||||
virtual void show(std::ostream &os, Object &object) = 0;
|
||||
virtual size_t getSize() = 0;
|
||||
virtual size_t getSizeOfElement() = 0;
|
||||
virtual size_t getByteSizeOfObject() = 0;
|
||||
virtual Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) = 0;
|
||||
virtual Object *allocateNormalizedObject(const std::vector<double> &obj) = 0;
|
||||
virtual Object *allocateNormalizedObject(const std::vector<float> &obj) = 0;
|
||||
virtual Object *allocateNormalizedObject(const std::vector<uint8_t> &obj) = 0;
|
||||
virtual Object *allocateNormalizedObject(const float *obj, size_t size) = 0;
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) = 0;
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) = 0;
|
||||
virtual void deleteObject(Object *po) = 0;
|
||||
virtual Object *allocateObject() = 0;
|
||||
virtual void remove(size_t id) = 0;
|
||||
|
||||
virtual ObjectRepository &getRepository() = 0;
|
||||
|
||||
virtual void setDistanceType(DistanceType t) = 0;
|
||||
|
||||
virtual void *getObject(size_t idx) = 0;
|
||||
virtual void getObject(size_t idx, std::vector<float> &v) = 0;
|
||||
virtual void getObjects(const std::vector<size_t> &idxs, std::vector<std::vector<float>> &vs) = 0;
|
||||
|
||||
size_t getDimension() { return dimension; }
|
||||
size_t getPaddedDimension() { return ((dimension - 1) / 16 + 1) * 16; }
|
||||
|
||||
template <typename T>
|
||||
void normalize(T *data, size_t dim) {
|
||||
double sum = 0.0;
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
sum += (double)data[i] * (double)data[i];
|
||||
}
|
||||
if (sum == 0.0) {
|
||||
std::stringstream msg;
|
||||
msg << "ObjectSpace::normalize: Error! the object is an invalid zero vector for the cosine similarity or angle distance.";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
sum = sqrt(sum);
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
data[i] = (double)data[i] / sum;
|
||||
}
|
||||
}
|
||||
uint32_t getPrefetchOffset() { return prefetchOffset; }
|
||||
uint32_t setPrefetchOffset(size_t offset) {
|
||||
if (offset == 0) {
|
||||
prefetchOffset = floor(300.0 / (static_cast<float>(getPaddedDimension()) + 30.0) + 1.0);
|
||||
} else {
|
||||
prefetchOffset = offset;
|
||||
}
|
||||
return prefetchOffset;
|
||||
}
|
||||
uint32_t getPrefetchSize() { return prefetchSize; }
|
||||
uint32_t setPrefetchSize(size_t size) {
|
||||
if (size == 0) {
|
||||
prefetchSize = getByteSizeOfObject();
|
||||
} else {
|
||||
prefetchSize = size;
|
||||
}
|
||||
return prefetchSize;
|
||||
}
|
||||
protected:
|
||||
const size_t dimension;
|
||||
DistanceType distanceType;
|
||||
Comparator *comparator;
|
||||
bool normalization;
|
||||
uint32_t prefetchOffset;
|
||||
uint32_t prefetchSize;
|
||||
};
|
||||
|
||||
class BaseObject {
|
||||
public:
|
||||
virtual uint8_t &operator[](size_t idx) const = 0;
|
||||
void serialize(std::ostream &os, ObjectSpace *objectspace = 0) {
|
||||
assert(objectspace != 0);
|
||||
size_t byteSize = objectspace->getByteSizeOfObject();
|
||||
NGT::Serializer::write(os, (uint8_t*)&(*this)[0], byteSize);
|
||||
}
|
||||
void deserialize(std::istream &is, ObjectSpace *objectspace = 0) {
|
||||
assert(objectspace != 0);
|
||||
size_t byteSize = objectspace->getByteSizeOfObject();
|
||||
assert(&(*this)[0] != 0);
|
||||
NGT::Serializer::read(is, (uint8_t*)&(*this)[0], byteSize);
|
||||
}
|
||||
void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0) {
|
||||
assert(objectspace != 0);
|
||||
const std::type_info &t = objectspace->getObjectType();
|
||||
size_t dimension = objectspace->getDimension();
|
||||
void *ref = (void*)&(*this)[0];
|
||||
if (t == typeid(uint8_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension);
|
||||
} else if (t == typeid(float)) {
|
||||
NGT::Serializer::writeAsText(os, (float*)ref, dimension);
|
||||
} else if (t == typeid(double)) {
|
||||
NGT::Serializer::writeAsText(os, (double*)ref, dimension);
|
||||
} else if (t == typeid(uint16_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension);
|
||||
} else if (t == typeid(uint32_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
|
||||
} else {
|
||||
std::cerr << "Object::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
|
||||
assert(objectspace != 0);
|
||||
const std::type_info &t = objectspace->getObjectType();
|
||||
size_t dimension = objectspace->getDimension();
|
||||
void *ref = (void*)&(*this)[0];
|
||||
assert(ref != 0);
|
||||
if (t == typeid(uint8_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension);
|
||||
} else if (t == typeid(float)) {
|
||||
NGT::Serializer::readAsText(is, (float*)ref, dimension);
|
||||
} else if (t == typeid(double)) {
|
||||
NGT::Serializer::readAsText(is, (double*)ref, dimension);
|
||||
} else if (t == typeid(uint16_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension);
|
||||
} else if (t == typeid(uint32_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
|
||||
} else {
|
||||
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class Object : public BaseObject {
|
||||
public:
|
||||
Object(NGT::ObjectSpace *os = 0):vector(0) {
|
||||
assert(os != 0);
|
||||
size_t s = os->getByteSizeOfObject();
|
||||
construct(s);
|
||||
}
|
||||
|
||||
Object(size_t s):vector(0) {
|
||||
assert(s != 0);
|
||||
construct(s);
|
||||
}
|
||||
|
||||
void copy(Object &o, size_t s) {
|
||||
assert(vector != 0);
|
||||
for (size_t i = 0; i < s; i++) {
|
||||
vector[i] = o[i];
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~Object() { clear(); }
|
||||
|
||||
uint8_t &operator[](size_t idx) const { return vector[idx]; }
|
||||
|
||||
void *getPointer(size_t idx = 0) const { return vector + idx; }
|
||||
|
||||
static Object *allocate(ObjectSpace &objectspace) { return new Object(&objectspace); }
|
||||
private:
|
||||
void clear() {
|
||||
if (vector != 0) {
|
||||
MemoryCache::alignedFree(vector);
|
||||
}
|
||||
vector = 0;
|
||||
}
|
||||
|
||||
void construct(size_t s) {
|
||||
assert(vector == 0);
|
||||
size_t allocsize = ((s - 1) / 64 + 1) * 64;
|
||||
vector = static_cast<uint8_t*>(MemoryCache::alignedAlloc(allocsize));
|
||||
memset(vector, 0, allocsize);
|
||||
}
|
||||
|
||||
uint8_t* vector;
|
||||
};
|
||||
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
class PersistentObject : public BaseObject {
|
||||
public:
|
||||
PersistentObject(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0):array(0) {
|
||||
assert(os != 0);
|
||||
size_t s = os->getByteSizeOfObject();
|
||||
construct(s, allocator);
|
||||
}
|
||||
PersistentObject(SharedMemoryAllocator &allocator, size_t s):array(0) {
|
||||
assert(s != 0);
|
||||
construct(s, allocator);
|
||||
}
|
||||
|
||||
~PersistentObject() {}
|
||||
|
||||
uint8_t &at(size_t idx, SharedMemoryAllocator &allocator) const {
|
||||
uint8_t *a = (uint8_t *)allocator.getAddr(array);
|
||||
return a[idx];
|
||||
}
|
||||
uint8_t &operator[](size_t idx) const {
|
||||
std::cerr << "not implemented" << std::endl;
|
||||
assert(0);
|
||||
uint8_t *a = 0;
|
||||
return a[idx];
|
||||
}
|
||||
|
||||
void *getPointer(size_t idx, SharedMemoryAllocator &allocator) {
|
||||
uint8_t *a = (uint8_t *)allocator.getAddr(array);
|
||||
return a + idx;
|
||||
}
|
||||
|
||||
// set v in objectspace to this object using allocator.
|
||||
void set(PersistentObject &po, ObjectSpace &objectspace);
|
||||
|
||||
static off_t allocate(ObjectSpace &objectspace);
|
||||
|
||||
void serializeAsText(std::ostream &os, SharedMemoryAllocator &allocator,
|
||||
ObjectSpace *objectspace = 0) {
|
||||
serializeAsText(os, objectspace);
|
||||
}
|
||||
|
||||
void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0);
|
||||
|
||||
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator,
|
||||
ObjectSpace *objectspace = 0) {
|
||||
deserializeAsText(is, objectspace);
|
||||
}
|
||||
|
||||
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0);
|
||||
|
||||
void serialize(std::ostream &os, SharedMemoryAllocator &allocator,
|
||||
ObjectSpace *objectspace = 0) {
|
||||
std::cerr << "serialize is not implemented" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
|
||||
private:
|
||||
void construct(size_t s, SharedMemoryAllocator &allocator) {
|
||||
assert(array == 0);
|
||||
assert(s != 0);
|
||||
size_t allocsize = ((s - 1) / 64 + 1) * 64;
|
||||
array = allocator.getOffset(new(allocator) uint8_t[allocsize]);
|
||||
memset(getPointer(0, allocator), 0, allocsize);
|
||||
}
|
||||
off_t array;
|
||||
};
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,620 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include <sstream>
|
||||
#include "Common.h"
|
||||
#include "ObjectSpace.h"
|
||||
#include "ObjectRepository.h"
|
||||
#include "PrimitiveComparator.h"
|
||||
|
||||
class ObjectSpace;
|
||||
|
||||
namespace NGT {
|
||||
|
||||
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
|
||||
class ObjectSpaceRepository : public ObjectSpace, public ObjectRepository {
|
||||
public:
|
||||
|
||||
class ComparatorL1 : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorL1(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorL1(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorL2 : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorL2(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorL2(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorHammingDistance : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorHammingDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorHammingDistance(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorJaccardDistance : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorJaccardDistance(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorSparseJaccardDistance : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorSparseJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorSparseJaccardDistance(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorAngleDistance : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorAngleDistance(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorNormalizedAngleDistance : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorNormalizedAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorNormalizedAngleDistance(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorCosineSimilarity : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorCosineSimilarity(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorNormalizedCosineSimilarity : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorNormalizedCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorNormalizedCosineSimilarity(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
ObjectSpaceRepository(size_t d, const std::type_info &ot, DistanceType t) : ObjectSpace(d), ObjectRepository(d, ot) {
|
||||
size_t objectSize = 0;
|
||||
if (ot == typeid(uint8_t)) {
|
||||
objectSize = sizeof(uint8_t);
|
||||
} else if (ot == typeid(float)) {
|
||||
objectSize = sizeof(float);
|
||||
} else {
|
||||
std::stringstream msg;
|
||||
msg << "ObjectSpace::constructor: Not supported type. " << ot.name();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
setLength(objectSize * d);
|
||||
setPaddedLength(objectSize * ObjectSpace::getPaddedDimension());
|
||||
setDistanceType(t);
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void open(const std::string &f, size_t sharedMemorySize) { ObjectRepository::open(f, sharedMemorySize); }
|
||||
void copy(PersistentObject &objecta, PersistentObject &objectb) { objecta = objectb; }
|
||||
|
||||
void show(std::ostream &os, PersistentObject &object) {
|
||||
const std::type_info &t = getObjectType();
|
||||
if (t == typeid(uint8_t)) {
|
||||
unsigned char *optr = static_cast<unsigned char*>(&object.at(0,allocator));
|
||||
for (size_t i = 0; i < getDimension(); i++) {
|
||||
os << (int)optr[i] << " ";
|
||||
}
|
||||
} else if (t == typeid(float)) {
|
||||
float *optr = reinterpret_cast<float*>(&object.at(0,allocator));
|
||||
for (size_t i = 0; i < getDimension(); i++) {
|
||||
os << optr[i] << " ";
|
||||
}
|
||||
} else {
|
||||
os << " not implement for the type.";
|
||||
}
|
||||
}
|
||||
|
||||
Object *allocateObject(Object &o) {
|
||||
Object *po = new Object(getByteSizeOfObject());
|
||||
for (size_t i = 0; i < getByteSizeOfObject(); i++) {
|
||||
(*po)[i] = o[i];
|
||||
}
|
||||
return po;
|
||||
}
|
||||
Object *allocateObject(PersistentObject &o) {
|
||||
PersistentObject &spo = (PersistentObject &)o;
|
||||
Object *po = new Object(getByteSizeOfObject());
|
||||
for (size_t i = 0; i < getByteSizeOfObject(); i++) {
|
||||
(*po)[i] = spo.at(i,ObjectRepository::allocator);
|
||||
}
|
||||
return (Object*)po;
|
||||
}
|
||||
void deleteObject(PersistentObject *po) {
|
||||
delete po;
|
||||
}
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
void copy(Object &objecta, Object &objectb) {
|
||||
objecta.copy(objectb, getByteSizeOfObject());
|
||||
}
|
||||
|
||||
void setDistanceType(DistanceType t) {
|
||||
if (comparator != 0) {
|
||||
delete comparator;
|
||||
}
|
||||
assert(ObjectSpace::dimension != 0);
|
||||
distanceType = t;
|
||||
switch (distanceType) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
case DistanceTypeL1:
|
||||
comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
break;
|
||||
case DistanceTypeL2:
|
||||
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
break;
|
||||
case DistanceTypeHamming:
|
||||
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
break;
|
||||
case DistanceTypeJaccard:
|
||||
comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
break;
|
||||
case DistanceTypeSparseJaccard:
|
||||
comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
setSparse();
|
||||
break;
|
||||
case DistanceTypeAngle:
|
||||
comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
break;
|
||||
case DistanceTypeCosine:
|
||||
comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
break;
|
||||
case DistanceTypeNormalizedAngle:
|
||||
comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
normalization = true;
|
||||
break;
|
||||
case DistanceTypeNormalizedCosine:
|
||||
comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
|
||||
normalization = true;
|
||||
break;
|
||||
#else
|
||||
case DistanceTypeL1:
|
||||
comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeL2:
|
||||
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeHamming:
|
||||
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeJaccard:
|
||||
comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeSparseJaccard:
|
||||
comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension());
|
||||
setSparse();
|
||||
break;
|
||||
case DistanceTypeAngle:
|
||||
comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeCosine:
|
||||
comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeNormalizedAngle:
|
||||
comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension());
|
||||
normalization = true;
|
||||
break;
|
||||
case DistanceTypeNormalizedCosine:
|
||||
comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension());
|
||||
normalization = true;
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
std::cerr << "Distance type is not specified" << std::endl;
|
||||
assert(distanceType != DistanceTypeNone);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void serialize(const std::string & ofile) { ObjectRepository::serialize(ofile, this); }
|
||||
// for milvus
|
||||
void serialize(std::stringstream & obj) { ObjectRepository::serialize(obj, this); }
|
||||
// for milvus
|
||||
void deserialize(std::stringstream & obj) { ObjectRepository::deserialize(obj, this); }
|
||||
void deserialize(const std::string &ifile) { ObjectRepository::deserialize(ifile, this); }
|
||||
void serializeAsText(const std::string &ofile) { ObjectRepository::serializeAsText(ofile, this); }
|
||||
void deserializeAsText(const std::string &ifile) { ObjectRepository::deserializeAsText(ifile, this); }
|
||||
// For milvus
|
||||
void readRawData(const float * raw_data, size_t dataSize) { ObjectRepository::readRawData<float>(raw_data, dataSize); }
|
||||
void readText(std::istream &is, size_t dataSize) { ObjectRepository::readText(is, dataSize); }
|
||||
void appendText(std::istream &is, size_t dataSize) { ObjectRepository::appendText(is, dataSize); }
|
||||
|
||||
void append(const float *data, size_t dataSize) { ObjectRepository::append(data, dataSize); }
|
||||
void append(const double *data, size_t dataSize) { ObjectRepository::append(data, dataSize); }
|
||||
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
PersistentObject *allocatePersistentObject(Object &obj) {
|
||||
return ObjectRepository::allocatePersistentObject(obj);
|
||||
}
|
||||
size_t insert(PersistentObject *obj) { return ObjectRepository::insert(obj); }
|
||||
#else
|
||||
size_t insert(Object *obj) { return ObjectRepository::insert(obj); }
|
||||
#endif
|
||||
|
||||
void remove(size_t id) { ObjectRepository::remove(id); }
|
||||
|
||||
void linearSearch(Object &query, double radius, size_t size, ObjectSpace::ResultSet &results) {
|
||||
if (!results.empty()) {
|
||||
NGTThrowException("lenearSearch: results is not empty");
|
||||
}
|
||||
#ifndef NGT_PREFETCH_DISABLED
|
||||
size_t byteSizeOfObject = getByteSizeOfObject();
|
||||
const size_t prefetchOffset = getPrefetchOffset();
|
||||
#endif
|
||||
ObjectRepository &rep = *this;
|
||||
for (size_t idx = 0; idx < rep.size(); idx++) {
|
||||
#ifndef NGT_PREFETCH_DISABLED
|
||||
if (idx + prefetchOffset < rep.size() && rep[idx + prefetchOffset] != 0) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
MemoryCache::prefetch((unsigned char*)&(*static_cast<PersistentObject*>(ObjectRepository::get(idx + prefetchOffset))), byteSizeOfObject);
|
||||
#else
|
||||
MemoryCache::prefetch((unsigned char*)&(*static_cast<PersistentObject*>(rep[idx + prefetchOffset]))[0], byteSizeOfObject);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (rep[idx] == 0) {
|
||||
continue;
|
||||
}
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
Distance d = (*comparator)((Object&)query, (PersistentObject&)*rep[idx]);
|
||||
#else
|
||||
Distance d = (*comparator)((Object&)query, (Object&)*rep[idx]);
|
||||
#endif
|
||||
if (radius < 0.0 || d <= radius) {
|
||||
NGT::ObjectDistance obj(idx, d);
|
||||
results.push(obj);
|
||||
if (results.size() > size) {
|
||||
results.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void *getObject(size_t idx) {
|
||||
if (isEmpty(idx)) {
|
||||
std::stringstream msg;
|
||||
msg << "NGT::ObjectSpaceRepository: The specified ID is out of the range. The object ID should be greater than zero. " << idx << ":" << ObjectRepository::size() << ".";
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
PersistentObject &obj = *(*this)[idx];
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
return reinterpret_cast<OBJECT_TYPE*>(&obj.at(0, allocator));
|
||||
#else
|
||||
return reinterpret_cast<OBJECT_TYPE*>(&obj[0]);
|
||||
#endif
|
||||
}
|
||||
|
||||
void getObject(size_t idx, std::vector<float> &v) {
|
||||
OBJECT_TYPE *obj = static_cast<OBJECT_TYPE*>(getObject(idx));
|
||||
size_t dim = getDimension();
|
||||
v.resize(dim);
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
v[i] = static_cast<float>(obj[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void getObjects(const std::vector<size_t> &idxs, std::vector<std::vector<float>> &vs) {
|
||||
vs.resize(idxs.size());
|
||||
auto v = vs.begin();
|
||||
for (auto idx = idxs.begin(); idx != idxs.end(); idx++, v++) {
|
||||
getObject(*idx, *v);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void normalize(PersistentObject &object) {
|
||||
OBJECT_TYPE *obj = (OBJECT_TYPE*)&object.at(0, getRepository().getAllocator());
|
||||
ObjectSpace::normalize(obj, ObjectSpace::dimension);
|
||||
}
|
||||
#endif
|
||||
void normalize(Object &object) {
|
||||
OBJECT_TYPE *obj = (OBJECT_TYPE*)&object[0];
|
||||
ObjectSpace::normalize(obj, ObjectSpace::dimension);
|
||||
}
|
||||
|
||||
Object *allocateObject() { return ObjectRepository::allocateObject(); }
|
||||
void deleteObject(Object *po) { ObjectRepository::deleteObject(po); }
|
||||
|
||||
Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) {
|
||||
Object *allocatedObject = ObjectRepository::allocateObject(textLine, sep);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
Object *allocateNormalizedObject(const std::vector<double> &obj) {
|
||||
Object *allocatedObject = ObjectRepository::allocateObject(obj);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
Object *allocateNormalizedObject(const std::vector<float> &obj) {
|
||||
Object *allocatedObject = ObjectRepository::allocateObject(obj);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
Object *allocateNormalizedObject(const std::vector<uint8_t> &obj) {
|
||||
Object *allocatedObject = ObjectRepository::allocateObject(obj);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
Object *allocateNormalizedObject(const float *obj, size_t size) {
|
||||
Object *allocatedObject = ObjectRepository::allocateObject(obj, size);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
|
||||
PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
|
||||
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
|
||||
PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
|
||||
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
|
||||
PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
|
||||
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
|
||||
if (normalization) {
|
||||
normalize(*allocatedObject);
|
||||
}
|
||||
return allocatedObject;
|
||||
}
|
||||
|
||||
size_t getSize() { return ObjectRepository::size(); }
|
||||
size_t getSizeOfElement() { return sizeof(OBJECT_TYPE); }
|
||||
const std::type_info &getObjectType() { return typeid(OBJECT_TYPE); };
|
||||
size_t getByteSizeOfObject() { return getByteSize(); }
|
||||
|
||||
ObjectRepository &getRepository() { return *this; };
|
||||
|
||||
void show(std::ostream &os, Object &object) {
|
||||
const std::type_info &t = getObjectType();
|
||||
if (t == typeid(uint8_t)) {
|
||||
unsigned char *optr = static_cast<unsigned char*>(&object[0]);
|
||||
for (size_t i = 0; i < getDimension(); i++) {
|
||||
os << (int)optr[i] << " ";
|
||||
}
|
||||
} else if (t == typeid(float)) {
|
||||
float *optr = reinterpret_cast<float*>(&object[0]);
|
||||
for (size_t i = 0; i < getDimension(); i++) {
|
||||
os << optr[i] << " ";
|
||||
}
|
||||
} else {
|
||||
os << " not implement for the type.";
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
// set v in objectspace to this object using allocator.
|
||||
inline void PersistentObject::set(PersistentObject &po, ObjectSpace &objectspace) {
|
||||
SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator();
|
||||
uint8_t *src = (uint8_t *)&po.at(0, allocator);
|
||||
uint8_t *dst = (uint8_t *)&(*this).at(0, allocator);
|
||||
memcpy(dst, src, objectspace.getByteSizeOfObject());
|
||||
}
|
||||
|
||||
inline off_t PersistentObject::allocate(ObjectSpace &objectspace) {
|
||||
SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator();
|
||||
return allocator.getOffset(new(allocator) PersistentObject(allocator, &objectspace));
|
||||
}
|
||||
|
||||
inline void PersistentObject::serializeAsText(std::ostream &os, ObjectSpace *objectspace) {
|
||||
assert(objectspace != 0);
|
||||
SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator();
|
||||
const std::type_info &t = objectspace->getObjectType();
|
||||
void *ref = &(*this).at(0, allocator);
|
||||
size_t dimension = objectspace->getDimension();
|
||||
if (t == typeid(uint8_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension);
|
||||
} else if (t == typeid(float)) {
|
||||
NGT::Serializer::writeAsText(os, (float*)ref, dimension);
|
||||
} else if (t == typeid(double)) {
|
||||
NGT::Serializer::writeAsText(os, (double*)ref, dimension);
|
||||
} else if (t == typeid(uint16_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension);
|
||||
} else if (t == typeid(uint32_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
|
||||
} else {
|
||||
std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
|
||||
inline void PersistentObject::deserializeAsText(std::ifstream &is, ObjectSpace *objectspace) {
|
||||
assert(objectspace != 0);
|
||||
SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator();
|
||||
const std::type_info &t = objectspace->getObjectType();
|
||||
size_t dimension = objectspace->getDimension();
|
||||
void *ref = &(*this).at(0, allocator);
|
||||
assert(ref != 0);
|
||||
if (t == typeid(uint8_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension);
|
||||
} else if (t == typeid(float)) {
|
||||
NGT::Serializer::readAsText(is, (float*)ref, dimension);
|
||||
} else if (t == typeid(double)) {
|
||||
NGT::Serializer::readAsText(is, (double*)ref, dimension);
|
||||
} else if (t == typeid(uint16_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension);
|
||||
} else if (t == typeid(uint32_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
|
||||
} else {
|
||||
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace NGT
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,781 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "NGT/defines.h"
|
||||
|
||||
#if defined(NGT_NO_AVX)
|
||||
// #warning "*** SIMD is *NOT* available! ***"
|
||||
#else
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
namespace NGT {
|
||||
|
||||
class MemoryCache {
|
||||
public:
|
||||
inline static void
|
||||
prefetch(unsigned char* ptr, const size_t byteSizeOfObject) {
|
||||
#if !defined(NGT_NO_AVX)
|
||||
switch ((byteSizeOfObject - 1) >> 6) {
|
||||
default:
|
||||
case 28:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 27:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 26:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 25:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 24:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 23:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 22:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 21:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 20:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 19:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 18:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 17:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 16:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 15:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 14:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 13:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 12:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 11:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 10:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 9:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 8:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 7:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 6:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 5:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 4:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 3:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 2:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 1:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
case 0:
|
||||
_mm_prefetch(ptr, _MM_HINT_T0);
|
||||
ptr += 64;
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
inline static void*
|
||||
alignedAlloc(const size_t allocSize) {
|
||||
#ifdef NGT_NO_AVX
|
||||
return new uint8_t[allocSize];
|
||||
#else
|
||||
#if defined(NGT_AVX512)
|
||||
size_t alignment = 64;
|
||||
uint64_t mask = 0xFFFFFFFFFFFFFFC0;
|
||||
#elif defined(NGT_AVX2)
|
||||
size_t alignment = 32;
|
||||
uint64_t mask = 0xFFFFFFFFFFFFFFE0;
|
||||
#else
|
||||
size_t alignment = 16;
|
||||
uint64_t mask = 0xFFFFFFFFFFFFFFF0;
|
||||
#endif
|
||||
uint8_t* p = new uint8_t[allocSize + alignment];
|
||||
uint8_t* ptr = p + alignment;
|
||||
ptr = reinterpret_cast<uint8_t*>((reinterpret_cast<uint64_t>(ptr) & mask));
|
||||
*p++ = 0xAB;
|
||||
while (p != ptr) *p++ = 0xCD;
|
||||
return ptr;
|
||||
#endif
|
||||
}
|
||||
inline static void
|
||||
alignedFree(void* ptr) {
|
||||
#ifdef NGT_NO_AVX
|
||||
delete[] static_cast<uint8_t*>(ptr);
|
||||
#else
|
||||
uint8_t* p = static_cast<uint8_t*>(ptr);
|
||||
p--;
|
||||
while (*p == 0xCD) p--;
|
||||
if (*p != 0xAB) {
|
||||
NGTThrowException("MemoryCache::alignedFree: Fatal Error! Cannot find allocated address.");
|
||||
}
|
||||
delete[] p;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
class PrimitiveComparator {
|
||||
public:
|
||||
static double
|
||||
absolute(double v) {
|
||||
return fabs(v);
|
||||
}
|
||||
static int
|
||||
absolute(int v) {
|
||||
return abs(v);
|
||||
}
|
||||
|
||||
#if defined(NGT_NO_AVX)
|
||||
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
|
||||
inline static double
|
||||
compareL2(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
const OBJECT_TYPE* last = a + size;
|
||||
const OBJECT_TYPE* lastgroup = last - 3;
|
||||
COMPARE_TYPE diff0, diff1, diff2, diff3;
|
||||
double d = 0.0;
|
||||
while (a < lastgroup) {
|
||||
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
|
||||
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
|
||||
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
|
||||
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
|
||||
d += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
while (a < last) {
|
||||
diff0 = (COMPARE_TYPE)(*a++ - *b++);
|
||||
d += diff0 * diff0;
|
||||
}
|
||||
return sqrt((double)d);
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareL2(const uint8_t* a, const uint8_t* b, size_t size) {
|
||||
return compareL2<uint8_t, int>(a, b, size);
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareL2(const float* a, const float* b, size_t size) {
|
||||
return compareL2<float, double>(a, b, size);
|
||||
}
|
||||
|
||||
#else
|
||||
inline static double
|
||||
compareL2(const float* a, const float* b, size_t size) {
|
||||
const float* last = a + size;
|
||||
#if defined(NGT_AVX512)
|
||||
__m512 sum512 = _mm512_setzero_ps();
|
||||
while (a < last) {
|
||||
__m512 v = _mm512_sub_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b));
|
||||
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v, v));
|
||||
a += 16;
|
||||
b += 16;
|
||||
}
|
||||
|
||||
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
|
||||
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
#elif defined(NGT_AVX2)
|
||||
__m256 sum256 = _mm256_setzero_ps();
|
||||
__m256 v;
|
||||
while (a < last) {
|
||||
v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v));
|
||||
a += 8;
|
||||
b += 8;
|
||||
v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
#else
|
||||
__m128 sum128 = _mm_setzero_ps();
|
||||
__m128 v;
|
||||
while (a < last) {
|
||||
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
|
||||
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
|
||||
a += 4;
|
||||
b += 4;
|
||||
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
|
||||
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
|
||||
a += 4;
|
||||
b += 4;
|
||||
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
|
||||
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
|
||||
a += 4;
|
||||
b += 4;
|
||||
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
|
||||
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
#endif
|
||||
|
||||
__attribute__((aligned(32))) float f[4];
|
||||
_mm_store_ps(f, sum128);
|
||||
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
return sqrt(s);
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareL2(const unsigned char* a, const unsigned char* b, size_t size) {
|
||||
__m128 sum = _mm_setzero_ps();
|
||||
const unsigned char* last = a + size;
|
||||
const unsigned char* lastgroup = last - 7;
|
||||
const __m128i zero = _mm_setzero_si128();
|
||||
while (a < lastgroup) {
|
||||
__m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a));
|
||||
__m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b));
|
||||
x1 = _mm_subs_epi16(x1, x2);
|
||||
__m128i v = _mm_mullo_epi16(x1, x1);
|
||||
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(v, zero)));
|
||||
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(v, zero)));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__attribute__((aligned(32))) float f[4];
|
||||
_mm_store_ps(f, sum);
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
while (a < last) {
|
||||
int d = (int)*a++ - (int)*b++;
|
||||
s += d * d;
|
||||
}
|
||||
return sqrt(s);
|
||||
}
|
||||
#endif
|
||||
#if defined(NGT_NO_AVX)
|
||||
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
|
||||
static double
|
||||
compareL1(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
const OBJECT_TYPE* last = a + size;
|
||||
const OBJECT_TYPE* lastgroup = last - 3;
|
||||
COMPARE_TYPE diff0, diff1, diff2, diff3;
|
||||
double d = 0.0;
|
||||
while (a < lastgroup) {
|
||||
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
|
||||
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
|
||||
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
|
||||
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
|
||||
d += absolute(diff0) + absolute(diff1) + absolute(diff2) + absolute(diff3);
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
while (a < last) {
|
||||
diff0 = (COMPARE_TYPE)*a++ - (COMPARE_TYPE)*b++;
|
||||
d += absolute(diff0);
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareL1(const uint8_t* a, const uint8_t* b, size_t size) {
|
||||
return compareL1<uint8_t, int>(a, b, size);
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareL1(const float* a, const float* b, size_t size) {
|
||||
return compareL1<float, double>(a, b, size);
|
||||
}
|
||||
|
||||
#else
|
||||
inline static double
|
||||
compareL1(const float* a, const float* b, size_t size) {
|
||||
__m256 sum = _mm256_setzero_ps();
|
||||
const float* last = a + size;
|
||||
const float* lastgroup = last - 7;
|
||||
while (a < lastgroup) {
|
||||
__m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
|
||||
const __m256 mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 v = _mm256_andnot_ps(mask, x1);
|
||||
sum = _mm256_add_ps(sum, v);
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__attribute__((aligned(32))) float f[8];
|
||||
_mm256_store_ps(f, sum);
|
||||
double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7];
|
||||
while (a < last) {
|
||||
double d = fabs(*a++ - *b++);
|
||||
s += d;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
inline static double
|
||||
compareL1(const unsigned char* a, const unsigned char* b, size_t size) {
|
||||
__m128 sum = _mm_setzero_ps();
|
||||
const unsigned char* last = a + size;
|
||||
const unsigned char* lastgroup = last - 7;
|
||||
const __m128i zero = _mm_setzero_si128();
|
||||
while (a < lastgroup) {
|
||||
__m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a));
|
||||
__m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b));
|
||||
x1 = _mm_subs_epi16(x1, x2);
|
||||
x1 = _mm_sign_epi16(x1, x1);
|
||||
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(x1, zero)));
|
||||
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(x1, zero)));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__attribute__((aligned(32))) float f[4];
|
||||
_mm_store_ps(f, sum);
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
while (a < last) {
|
||||
double d = fabs((double)*a++ - (double)*b++);
|
||||
s += d;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(NGT_NO_AVX) || !defined(__POPCNT__)
|
||||
inline static double
|
||||
popCount(uint32_t x) {
|
||||
x = (x & 0x55555555) + (x >> 1 & 0x55555555);
|
||||
x = (x & 0x33333333) + (x >> 2 & 0x33333333);
|
||||
x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F);
|
||||
x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF);
|
||||
x = (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF);
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
const uint32_t* last = reinterpret_cast<const uint32_t*>(a + size);
|
||||
|
||||
const uint32_t* uinta = reinterpret_cast<const uint32_t*>(a);
|
||||
const uint32_t* uintb = reinterpret_cast<const uint32_t*>(b);
|
||||
size_t count = 0;
|
||||
while (uinta < last) {
|
||||
count += popCount(*uinta++ ^ *uintb++);
|
||||
}
|
||||
|
||||
return static_cast<double>(count);
|
||||
}
|
||||
#else
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
const uint64_t* last = reinterpret_cast<const uint64_t*>(a + size);
|
||||
|
||||
const uint64_t* uinta = reinterpret_cast<const uint64_t*>(a);
|
||||
const uint64_t* uintb = reinterpret_cast<const uint64_t*>(b);
|
||||
size_t count = 0;
|
||||
while (uinta < last) {
|
||||
count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
|
||||
count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
|
||||
}
|
||||
|
||||
return static_cast<double>(count);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(NGT_NO_AVX) || !defined(__POPCNT__)
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
const uint32_t* last = reinterpret_cast<const uint32_t*>(a + size);
|
||||
|
||||
const uint32_t* uinta = reinterpret_cast<const uint32_t*>(a);
|
||||
const uint32_t* uintb = reinterpret_cast<const uint32_t*>(b);
|
||||
size_t count = 0;
|
||||
size_t countDe = 0;
|
||||
while (uinta < last) {
|
||||
count += popCount(*uinta & *uintb);
|
||||
countDe += popCount(*uinta++ | *uintb++);
|
||||
count += popCount(*uinta & *uintb);
|
||||
countDe += popCount(*uinta++ | *uintb++);
|
||||
}
|
||||
|
||||
return 1.0 - static_cast<double>(count) / static_cast<double>(countDe);
|
||||
}
|
||||
#else
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
const uint64_t* last = reinterpret_cast<const uint64_t*>(a + size);
|
||||
|
||||
const uint64_t* uinta = reinterpret_cast<const uint64_t*>(a);
|
||||
const uint64_t* uintb = reinterpret_cast<const uint64_t*>(b);
|
||||
size_t count = 0;
|
||||
size_t countDe = 0;
|
||||
while (uinta < last) {
|
||||
count += _mm_popcnt_u64(*uinta & *uintb);
|
||||
countDe += _mm_popcnt_u64(*uinta++ | *uintb++);
|
||||
count += _mm_popcnt_u64(*uinta & *uintb);
|
||||
countDe += _mm_popcnt_u64(*uinta++ | *uintb++);
|
||||
}
|
||||
|
||||
return 1.0 - static_cast<double>(count) / static_cast<double>(countDe);
|
||||
}
|
||||
#endif
|
||||
|
||||
inline static double
|
||||
compareSparseJaccardDistance(const unsigned char* a, unsigned char* b, size_t size) {
|
||||
abort();
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareSparseJaccardDistance(const float* a, const float* b, size_t size) {
|
||||
size_t loca = 0;
|
||||
size_t locb = 0;
|
||||
const uint32_t* ai = reinterpret_cast<const uint32_t*>(a);
|
||||
const uint32_t* bi = reinterpret_cast<const uint32_t*>(b);
|
||||
size_t count = 0;
|
||||
while (locb < size && ai[loca] != 0 && bi[loca] != 0) {
|
||||
int64_t sub = static_cast<int64_t>(ai[loca]) - static_cast<int64_t>(bi[locb]);
|
||||
count += sub == 0;
|
||||
loca += sub <= 0;
|
||||
locb += sub >= 0;
|
||||
}
|
||||
while (ai[loca] != 0) {
|
||||
loca++;
|
||||
}
|
||||
while (locb < size && bi[locb] != 0) {
|
||||
locb++;
|
||||
}
|
||||
return 1.0 - static_cast<double>(count) / static_cast<double>(loca + locb - count);
|
||||
}
|
||||
|
||||
#if defined(NGT_NO_AVX)
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareDotProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
double sum = 0.0;
|
||||
for (size_t loc = 0; loc < size; loc++) {
|
||||
sum += (double)a[loc] * (double)b[loc];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareCosine(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
double normA = 0.0;
|
||||
double normB = 0.0;
|
||||
double sum = 0.0;
|
||||
for (size_t loc = 0; loc < size; loc++) {
|
||||
normA += (double)a[loc] * (double)a[loc];
|
||||
normB += (double)b[loc] * (double)b[loc];
|
||||
sum += (double)a[loc] * (double)b[loc];
|
||||
}
|
||||
|
||||
double cosine = sum / sqrt(normA * normB);
|
||||
|
||||
return cosine;
|
||||
}
|
||||
#else
|
||||
inline static double
|
||||
compareDotProduct(const float* a, const float* b, size_t size) {
|
||||
const float* last = a + size;
|
||||
#if defined(NGT_AVX512)
|
||||
__m512 sum512 = _mm512_setzero_ps();
|
||||
while (a < last) {
|
||||
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)));
|
||||
a += 16;
|
||||
b += 16;
|
||||
}
|
||||
|
||||
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
|
||||
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
#elif defined(NGT_AVX2)
|
||||
__m256 sum256 = _mm256_setzero_ps();
|
||||
while (a < last) {
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
#else
|
||||
__m128 sum128 = _mm_setzero_ps();
|
||||
while (a < last) {
|
||||
sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
#endif
|
||||
__attribute__((aligned(32))) float f[4];
|
||||
_mm_store_ps(f, sum128);
|
||||
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
return s;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareDotProduct(const unsigned char* a, const unsigned char* b, size_t size) {
|
||||
double sum = 0.0;
|
||||
for (size_t loc = 0; loc < size; loc++) {
|
||||
sum += (double)a[loc] * (double)b[loc];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareCosine(const float* a, const float* b, size_t size) {
|
||||
const float* last = a + size;
|
||||
#if defined(NGT_AVX512)
|
||||
__m512 normA = _mm512_setzero_ps();
|
||||
__m512 normB = _mm512_setzero_ps();
|
||||
__m512 sum = _mm512_setzero_ps();
|
||||
while (a < last) {
|
||||
__m512 am = _mm512_loadu_ps(a);
|
||||
__m512 bm = _mm512_loadu_ps(b);
|
||||
normA = _mm512_add_ps(normA, _mm512_mul_ps(am, am));
|
||||
normB = _mm512_add_ps(normB, _mm512_mul_ps(bm, bm));
|
||||
sum = _mm512_add_ps(sum, _mm512_mul_ps(am, bm));
|
||||
a += 16;
|
||||
b += 16;
|
||||
}
|
||||
__m256 am256 = _mm256_add_ps(_mm512_extractf32x8_ps(normA, 0), _mm512_extractf32x8_ps(normA, 1));
|
||||
__m256 bm256 = _mm256_add_ps(_mm512_extractf32x8_ps(normB, 0), _mm512_extractf32x8_ps(normB, 1));
|
||||
__m256 s256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum, 0), _mm512_extractf32x8_ps(sum, 1));
|
||||
__m128 am128 = _mm_add_ps(_mm256_extractf128_ps(am256, 0), _mm256_extractf128_ps(am256, 1));
|
||||
__m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(bm256, 0), _mm256_extractf128_ps(bm256, 1));
|
||||
__m128 s128 = _mm_add_ps(_mm256_extractf128_ps(s256, 0), _mm256_extractf128_ps(s256, 1));
|
||||
#elif defined(NGT_AVX2)
|
||||
__m256 normA = _mm256_setzero_ps();
|
||||
__m256 normB = _mm256_setzero_ps();
|
||||
__m256 sum = _mm256_setzero_ps();
|
||||
__m256 am, bm;
|
||||
while (a < last) {
|
||||
am = _mm256_loadu_ps(a);
|
||||
bm = _mm256_loadu_ps(b);
|
||||
normA = _mm256_add_ps(normA, _mm256_mul_ps(am, am));
|
||||
normB = _mm256_add_ps(normB, _mm256_mul_ps(bm, bm));
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(am, bm));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__m128 am128 = _mm_add_ps(_mm256_extractf128_ps(normA, 0), _mm256_extractf128_ps(normA, 1));
|
||||
__m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(normB, 0), _mm256_extractf128_ps(normB, 1));
|
||||
__m128 s128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
|
||||
#else
|
||||
__m128 am128 = _mm_setzero_ps();
|
||||
__m128 bm128 = _mm_setzero_ps();
|
||||
__m128 s128 = _mm_setzero_ps();
|
||||
__m128 am, bm;
|
||||
while (a < last) {
|
||||
am = _mm_loadu_ps(a);
|
||||
bm = _mm_loadu_ps(b);
|
||||
am128 = _mm_add_ps(am128, _mm_mul_ps(am, am));
|
||||
bm128 = _mm_add_ps(bm128, _mm_mul_ps(bm, bm));
|
||||
s128 = _mm_add_ps(s128, _mm_mul_ps(am, bm));
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
__attribute__((aligned(32))) float f[4];
|
||||
_mm_store_ps(f, am128);
|
||||
double na = f[0] + f[1] + f[2] + f[3];
|
||||
_mm_store_ps(f, bm128);
|
||||
double nb = f[0] + f[1] + f[2] + f[3];
|
||||
_mm_store_ps(f, s128);
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
|
||||
double cosine = s / sqrt(na * nb);
|
||||
return cosine;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareCosine(const unsigned char* a, const unsigned char* b, size_t size) {
|
||||
double normA = 0.0;
|
||||
double normB = 0.0;
|
||||
double sum = 0.0;
|
||||
for (size_t loc = 0; loc < size; loc++) {
|
||||
normA += (double)a[loc] * (double)a[loc];
|
||||
normB += (double)b[loc] * (double)b[loc];
|
||||
sum += (double)a[loc] * (double)b[loc];
|
||||
}
|
||||
|
||||
double cosine = sum / sqrt(normA * normB);
|
||||
|
||||
return cosine;
|
||||
}
|
||||
#endif // #if defined(NGT_NO_AVX)
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
double cosine = compareCosine(a, b, size);
|
||||
if (cosine >= 1.0) {
|
||||
return 0.0;
|
||||
} else if (cosine <= -1.0) {
|
||||
return acos(-1.0);
|
||||
} else {
|
||||
return acos(cosine);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareNormalizedAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
double cosine = compareDotProduct(a, b, size);
|
||||
if (cosine >= 1.0) {
|
||||
return 0.0;
|
||||
} else if (cosine <= -1.0) {
|
||||
return acos(-1.0);
|
||||
} else {
|
||||
return acos(cosine);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
return 1.0 - compareCosine(a, b, size);
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareNormalizedCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
double v = 1.0 - compareDotProduct(a, b, size);
|
||||
return v < 0.0 ? 0.0 : v;
|
||||
}
|
||||
|
||||
class L1Uint8 {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareL1((const uint8_t*)a, (const uint8_t*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class L2Uint8 {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareL2((const uint8_t*)a, (const uint8_t*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class HammingUint8 {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareHammingDistance((const uint8_t*)a, (const uint8_t*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class JaccardUint8 {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareJaccardDistance((const uint8_t*)a, (const uint8_t*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class SparseJaccardFloat {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareSparseJaccardDistance((const float*)a, (const float*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class L2Float {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
#if defined(NGT_NO_AVX)
|
||||
return PrimitiveComparator::compareL2<float, double>((const float*)a, (const float*)b, size);
|
||||
#else
|
||||
return PrimitiveComparator::compareL2((const float*)a, (const float*)b, size);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
class L1Float {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareL1((const float*)a, (const float*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class CosineSimilarityFloat {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareCosineSimilarity((const float*)a, (const float*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class AngleFloat {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareAngleDistance((const float*)a, (const float*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class NormalizedCosineSimilarityFloat {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareNormalizedCosineSimilarity((const float*)a, (const float*)b, size);
|
||||
}
|
||||
};
|
||||
|
||||
class NormalizedAngleFloat {
|
||||
public:
|
||||
inline static double
|
||||
compare(const void* a, const void* b, size_t size) {
|
||||
return PrimitiveComparator::compareNormalizedAngleDistance((const float*)a, (const float*)b, size);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace NGT
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "NGT/SharedMemoryAllocator.h"
|
||||
|
||||
|
||||
|
||||
void* operator
|
||||
new(size_t size, SharedMemoryAllocator &allocator)
|
||||
{
|
||||
void *addr = allocator.allocate(size);
|
||||
#ifdef MEMORY_ALLOCATOR_INFO
|
||||
std::cerr << "new:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl;
|
||||
#endif
|
||||
return addr;
|
||||
}
|
||||
|
||||
void* operator
|
||||
new[](size_t size, SharedMemoryAllocator &allocator)
|
||||
{
|
||||
|
||||
void *addr = allocator.allocate(size);
|
||||
#ifdef MEMORY_ALLOCATOR_INFO
|
||||
std::cerr << "new[]:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl;
|
||||
#endif
|
||||
return addr;
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "NGT/defines.h"
|
||||
#include "NGT/MmapManager.h"
|
||||
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <exception>
|
||||
#include <cassert>
|
||||
|
||||
#define MMAP_MANAGER
|
||||
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
class SharedMemoryAllocator {
|
||||
public:
|
||||
enum GetMemorySizeType {
|
||||
GetTotalMemorySize = 0,
|
||||
GetAllocatedMemorySize = 1,
|
||||
GetFreedMemorySize = 2
|
||||
};
|
||||
|
||||
SharedMemoryAllocator():isValid(false) {
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SharedMemoryAllocatorSiglton::constructor" << std::endl;
|
||||
#endif
|
||||
}
|
||||
SharedMemoryAllocator(const SharedMemoryAllocator& a){}
|
||||
SharedMemoryAllocator& operator=(const SharedMemoryAllocator& a){ return *this; }
|
||||
public:
|
||||
void* allocate(size_t size) {
|
||||
if (isValid == false) {
|
||||
std::cerr << "SharedMemoryAllocator::allocate: Fatal error! " << std::endl;
|
||||
assert(isValid);
|
||||
}
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SharedMemoryAllocator::allocate: size=" << size << std::endl;
|
||||
std::cerr << "SharedMemoryAllocator::allocate: before " << getTotalSize() << ":" << getAllocatedSize() << ":" << getFreedSize() << std::endl;
|
||||
#endif
|
||||
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
|
||||
if(!isValid){
|
||||
return NULL;
|
||||
}
|
||||
off_t file_offset = mmanager->alloc(size, true);
|
||||
if (file_offset == -1) {
|
||||
std::cerr << "Fatal Error: Allocating memory size is too big for this settings." << std::endl;
|
||||
std::cerr << " Max allocation size should be enlarged." << std::endl;
|
||||
abort();
|
||||
}
|
||||
void *p = mmanager->getAbsAddr(file_offset);
|
||||
std::memset(p, 0, size);
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SharedMemoryAllocator::allocate: end" <<std::endl;
|
||||
#endif
|
||||
return p;
|
||||
#else
|
||||
void *ptr = std::malloc(size);
|
||||
std::memset(ptr, 0, size);
|
||||
return ptr;
|
||||
#endif
|
||||
}
|
||||
void free(void *ptr) {
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SharedMemoryAllocator::free: ptr=" << ptr << std::endl;
|
||||
#endif
|
||||
if (ptr == 0) {
|
||||
std::cerr << "SharedMemoryAllocator::free: ptr is invalid! ptr=" << ptr << std::endl;
|
||||
}
|
||||
if (ptr == 0) {
|
||||
return;
|
||||
}
|
||||
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
|
||||
off_t file_offset = mmanager->getRelAddr(ptr);
|
||||
mmanager->free(file_offset);
|
||||
#else
|
||||
std::free(ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
void *construct(const std::string &filePath, size_t memorysize = 0) {
|
||||
file = filePath; // debug
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "ObjectSharedMemoryAllocator::construct: file " << filePath << std::endl;
|
||||
#endif
|
||||
void *hook = 0;
|
||||
#ifdef MMAP_MANAGER
|
||||
mmanager = new MemoryManager::MmapManager();
|
||||
// msize is the maximum allocated size (M byte) at once.
|
||||
size_t msize = memorysize;
|
||||
if (msize == 0) {
|
||||
msize = NGT_SHARED_MEMORY_MAX_SIZE;
|
||||
}
|
||||
size_t bsize = msize * 1048576 / sysconf(_SC_PAGESIZE) + 1; // 1048576=1M
|
||||
uint64_t size = bsize * sysconf(_SC_PAGESIZE);
|
||||
MemoryManager::init_option_st option;
|
||||
MemoryManager::MmapManager::setDefaultOptionValue(option);
|
||||
option.use_expand = true;
|
||||
option.reuse_type = MemoryManager::REUSE_DATA_CLASSIFY;
|
||||
bool create = true;
|
||||
if(!mmanager->init(filePath, size, &option)){
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SMA: info. already existed." << std::endl;
|
||||
#endif
|
||||
create = false;
|
||||
} else {
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SMA::construct: msize=" << msize << ":" << memorysize << std::endl;
|
||||
#endif
|
||||
}
|
||||
if(!mmanager->openMemory(filePath)){
|
||||
std::cerr << "SMA: open error" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
if (!create) {
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SMA: get hook to initialize data structure" << std::endl;
|
||||
#endif
|
||||
hook = mmanager->getEntryHook();
|
||||
assert(hook != 0);
|
||||
}
|
||||
#endif
|
||||
isValid = true;
|
||||
#ifdef SMA_TRACE
|
||||
std::cerr << "SharedMemoryAllocator::construct: " << filePath << " total="
|
||||
<< getTotalSize() << " allocated=" << getAllocatedSize() << " freed="
|
||||
<< getFreedSize() << " (" << (double)getFreedSize() / (double)getTotalSize() << ") " << std::endl;
|
||||
#endif
|
||||
return hook;
|
||||
}
|
||||
void destruct() {
|
||||
if (!isValid) {
|
||||
return;
|
||||
}
|
||||
isValid = false;
|
||||
#ifdef MMAP_MANAGER
|
||||
mmanager->closeMemory();
|
||||
delete mmanager;
|
||||
#endif
|
||||
};
|
||||
void setEntry(void *entry) {
|
||||
#ifdef MMAP_MANAGER
|
||||
mmanager->setEntryHook(entry);
|
||||
#endif
|
||||
}
|
||||
void *getAddr(off_t oft) {
|
||||
if (oft == 0) {
|
||||
return 0;
|
||||
}
|
||||
assert(oft > 0);
|
||||
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
|
||||
return mmanager->getAbsAddr(oft);
|
||||
#else
|
||||
return (void*)oft;
|
||||
#endif
|
||||
}
|
||||
off_t getOffset(void *adr) {
|
||||
if (adr == 0) {
|
||||
return 0;
|
||||
}
|
||||
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
|
||||
return mmanager->getRelAddr(adr);
|
||||
#else
|
||||
return (off_t)adr;
|
||||
#endif
|
||||
}
|
||||
size_t getMemorySize(GetMemorySizeType t) {
|
||||
switch (t) {
|
||||
case GetTotalMemorySize : return getTotalSize();
|
||||
case GetAllocatedMemorySize : return getAllocatedSize();
|
||||
case GetFreedMemorySize : return getFreedSize();
|
||||
}
|
||||
return getTotalSize();
|
||||
}
|
||||
size_t getTotalSize() { return mmanager->getTotalSize(); }
|
||||
size_t getAllocatedSize() { return mmanager->getUseSize(); }
|
||||
size_t getFreedSize() { return mmanager->getFreeSize(); }
|
||||
|
||||
bool isValid;
|
||||
std::string file;
|
||||
#ifdef MMAP_MANAGER
|
||||
MemoryManager::MmapManager *mmanager;
|
||||
#endif
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void* operator new(size_t size, SharedMemoryAllocator &allocator);
|
||||
void* operator new[](size_t size, SharedMemoryAllocator &allocator);
|
|
@ -0,0 +1,128 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include <pthread.h>
|
||||
|
||||
#include "Thread.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace NGT;
|
||||
|
||||
namespace NGT {
|
||||
class ThreadInfo {
|
||||
public:
|
||||
pthread_t threadid;
|
||||
pthread_attr_t threadAttr;
|
||||
};
|
||||
|
||||
class ThreadMutex {
|
||||
public:
|
||||
pthread_mutex_t mutex;
|
||||
pthread_cond_t condition;
|
||||
};
|
||||
}
|
||||
|
||||
Thread::Thread() {
|
||||
threadInfo = new ThreadInfo;
|
||||
threadInfo->threadid = 0;
|
||||
threadNo = -1;
|
||||
isTerminate = false;
|
||||
}
|
||||
|
||||
Thread::~Thread() {
|
||||
if (threadInfo != 0) {
|
||||
delete threadInfo;
|
||||
}
|
||||
}
|
||||
|
||||
ThreadMutex *
|
||||
Thread::constructThreadMutex()
|
||||
{
|
||||
return new ThreadMutex;
|
||||
}
|
||||
|
||||
void
|
||||
Thread::destructThreadMutex(ThreadMutex *t)
|
||||
{
|
||||
if (t != 0) {
|
||||
pthread_mutex_destroy(&(t->mutex));
|
||||
pthread_cond_destroy(&(t->condition));
|
||||
delete t;
|
||||
}
|
||||
}
|
||||
|
||||
int
|
||||
Thread::start()
|
||||
{
|
||||
pthread_attr_init(&(threadInfo->threadAttr));
|
||||
size_t stackSize = 0;
|
||||
pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize);
|
||||
if (stackSize < 0xa00000) { // 64bit stack size
|
||||
stackSize *= 4;
|
||||
}
|
||||
pthread_attr_setstacksize(&(threadInfo->threadAttr), stackSize);
|
||||
pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize);
|
||||
return pthread_create(&(threadInfo->threadid), &(threadInfo->threadAttr), Thread::startThread, this);
|
||||
|
||||
}
|
||||
|
||||
int
|
||||
Thread::join()
|
||||
{
|
||||
return pthread_join(threadInfo->threadid, 0);
|
||||
}
|
||||
|
||||
void
|
||||
Thread::lock(ThreadMutex &m)
|
||||
{
|
||||
pthread_mutex_lock(&m.mutex);
|
||||
}
|
||||
void
|
||||
Thread::unlock(ThreadMutex &m)
|
||||
{
|
||||
pthread_mutex_unlock(&m.mutex);
|
||||
}
|
||||
void
|
||||
Thread::signal(ThreadMutex &m)
|
||||
{
|
||||
pthread_cond_signal(&m.condition);
|
||||
}
|
||||
|
||||
void
|
||||
Thread::wait(ThreadMutex &m)
|
||||
{
|
||||
if (pthread_cond_wait(&m.condition, &m.mutex) != 0) {
|
||||
cerr << "waitForSignalFromThread: internal error" << endl;
|
||||
NGTThrowException("waitForSignalFromThread: internal error");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
Thread::broadcast(ThreadMutex &m)
|
||||
{
|
||||
pthread_cond_broadcast(&m.condition);
|
||||
}
|
||||
|
||||
void
|
||||
Thread::mutexInit(ThreadMutex &m)
|
||||
{
|
||||
if (pthread_mutex_init(&m.mutex, NULL) != 0) {
|
||||
NGTThrowException("Thread::mutexInit: Cannot initialize mutex");
|
||||
}
|
||||
if (pthread_cond_init(&m.condition, NULL) != 0) {
|
||||
NGTThrowException("Thread::mutexInit: Cannot initialize condition");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,291 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "NGT/Common.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <deque>
|
||||
|
||||
namespace NGT {
|
||||
void * evaluate_responce(void *);
|
||||
|
||||
class ThreadTerminationException : public Exception {
|
||||
public:
|
||||
ThreadTerminationException(const std::string &file, size_t line, std::stringstream &m) { set(file, line, m.str()); }
|
||||
ThreadTerminationException(const std::string &file, size_t line, const std::string &m) { set(file, line, m); }
|
||||
};
|
||||
|
||||
class ThreadInfo;
|
||||
class ThreadMutex;
|
||||
|
||||
class Thread
|
||||
{
|
||||
public:
|
||||
Thread();
|
||||
|
||||
virtual ~Thread();
|
||||
virtual int start();
|
||||
|
||||
virtual int join();
|
||||
|
||||
static ThreadMutex *constructThreadMutex();
|
||||
static void destructThreadMutex(ThreadMutex *t);
|
||||
|
||||
static void mutexInit(ThreadMutex &m);
|
||||
|
||||
static void lock(ThreadMutex &m);
|
||||
static void unlock(ThreadMutex &m);
|
||||
static void signal(ThreadMutex &m);
|
||||
static void wait(ThreadMutex &m);
|
||||
static void broadcast(ThreadMutex &m);
|
||||
|
||||
protected:
|
||||
virtual int run() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static void* startThread(void *thread) {
|
||||
if (thread == 0) {
|
||||
return 0;
|
||||
}
|
||||
Thread* p = (Thread*)thread;
|
||||
p->run();
|
||||
return thread;
|
||||
}
|
||||
|
||||
public:
|
||||
int threadNo;
|
||||
bool isTerminate;
|
||||
|
||||
protected:
|
||||
ThreadInfo *threadInfo;
|
||||
};
|
||||
|
||||
template <class JOB, class SHARED_DATA, class THREAD>
|
||||
class ThreadPool {
|
||||
public:
|
||||
class JobQueue : public std::deque<JOB> {
|
||||
public:
|
||||
JobQueue() {
|
||||
threadMutex = Thread::constructThreadMutex();
|
||||
Thread::mutexInit(*threadMutex);
|
||||
}
|
||||
~JobQueue() {
|
||||
Thread::destructThreadMutex(threadMutex);
|
||||
}
|
||||
bool isDeficient() { return std::deque<JOB>::size() <= requestSize; }
|
||||
bool isEmpty() { return std::deque<JOB>::size() == 0; }
|
||||
bool isFull() { return std::deque<JOB>::size() >= maxSize; }
|
||||
void setRequestSize(int s) { requestSize = s; }
|
||||
void setMaxSize(int s) { maxSize = s; }
|
||||
void lock() { Thread::lock(*threadMutex); }
|
||||
void unlock() { Thread::unlock(*threadMutex); }
|
||||
void signal() { Thread::signal(*threadMutex); }
|
||||
void wait() { Thread::wait(*threadMutex); }
|
||||
void wait(JobQueue &q) { wait(*q.threadMutex); }
|
||||
void broadcast() { Thread::broadcast(*threadMutex); }
|
||||
unsigned int requestSize;
|
||||
unsigned int maxSize;
|
||||
ThreadMutex *threadMutex;
|
||||
};
|
||||
class InputJobQueue : public JobQueue {
|
||||
public:
|
||||
InputJobQueue() {
|
||||
isTerminate = false;
|
||||
underPushing = false;
|
||||
pushedSize = 0;
|
||||
}
|
||||
|
||||
void popFront(JOB &d) {
|
||||
JobQueue::lock();
|
||||
while (JobQueue::isEmpty()) {
|
||||
if (isTerminate) {
|
||||
JobQueue::unlock();
|
||||
NGTThrowSpecificException("Thread::termination", ThreadTerminationException);
|
||||
}
|
||||
JobQueue::wait();
|
||||
}
|
||||
d = std::deque<JOB>::front();
|
||||
std::deque<JOB>::pop_front();
|
||||
JobQueue::unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
void popFront(std::deque<JOB> &d, size_t s) {
|
||||
JobQueue::lock();
|
||||
while (JobQueue::isEmpty()) {
|
||||
if (isTerminate) {
|
||||
JobQueue::unlock();
|
||||
NGTThrowSpecificException("Thread::termination", ThreadTerminationException);
|
||||
}
|
||||
JobQueue::wait();
|
||||
}
|
||||
for (size_t i = 0; i < s; i++) {
|
||||
d.push_back(std::deque<JOB>::front());
|
||||
std::deque<JOB>::pop_front();
|
||||
if (JobQueue::isEmpty()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
JobQueue::unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
void pushBack(JOB &data) {
|
||||
JobQueue::lock();
|
||||
if (!underPushing) {
|
||||
underPushing = true;
|
||||
pushedSize = 0;
|
||||
}
|
||||
pushedSize++;
|
||||
std::deque<JOB>::push_back(data);
|
||||
JobQueue::unlock();
|
||||
JobQueue::signal();
|
||||
}
|
||||
|
||||
void pushBackEnd() {
|
||||
underPushing = false;
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
JobQueue::lock();
|
||||
if (underPushing || !JobQueue::isEmpty()) {
|
||||
JobQueue::unlock();
|
||||
NGTThrowException("Thread::teminate:Under pushing!");
|
||||
}
|
||||
isTerminate = true;
|
||||
JobQueue::unlock();
|
||||
JobQueue::broadcast();
|
||||
}
|
||||
|
||||
bool isTerminate;
|
||||
bool underPushing;
|
||||
size_t pushedSize;
|
||||
|
||||
};
|
||||
|
||||
class OutputJobQueue : public JobQueue {
|
||||
public:
|
||||
void waitForFull() {
|
||||
JobQueue::wait();
|
||||
JobQueue::unlock();
|
||||
}
|
||||
|
||||
void pushBack(JOB &data) {
|
||||
JobQueue::lock();
|
||||
std::deque<JOB>::push_back(data);
|
||||
if (!JobQueue::isFull()) {
|
||||
JobQueue::unlock();
|
||||
return;
|
||||
}
|
||||
JobQueue::unlock();
|
||||
JobQueue::signal();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class SharedData {
|
||||
public:
|
||||
SharedData():isAvailable(false) {
|
||||
inputJobs.requestSize = 5;
|
||||
inputJobs.maxSize = 50;
|
||||
}
|
||||
SHARED_DATA sharedData;
|
||||
InputJobQueue inputJobs;
|
||||
OutputJobQueue outputJobs;
|
||||
bool isAvailable;
|
||||
};
|
||||
|
||||
class Thread : public THREAD {
|
||||
public:
|
||||
SHARED_DATA &getSharedData() {
|
||||
if (threadPool->sharedData.isAvailable) {
|
||||
return threadPool->sharedData.sharedData;
|
||||
} else {
|
||||
NGTThrowException("Thread::getSharedData: Shared data is unavailable. No set yet.");
|
||||
}
|
||||
}
|
||||
InputJobQueue &getInputJobQueue() {
|
||||
return threadPool->sharedData.inputJobs;
|
||||
}
|
||||
OutputJobQueue &getOutputJobQueue() {
|
||||
return threadPool->sharedData.outputJobs;
|
||||
}
|
||||
ThreadPool *threadPool;
|
||||
};
|
||||
|
||||
ThreadPool(int s) {
|
||||
size = s;
|
||||
threads = new Thread[s];
|
||||
}
|
||||
|
||||
~ThreadPool() {
|
||||
delete[] threads;
|
||||
}
|
||||
|
||||
void setSharedData(SHARED_DATA d) {
|
||||
sharedData.sharedData = d;
|
||||
sharedData.isAvailable = true;
|
||||
}
|
||||
|
||||
void create() {
|
||||
for (unsigned int i = 0; i < size; i++) {
|
||||
threads[i].threadPool = this;
|
||||
threads[i].threadNo = i;
|
||||
threads[i].start();
|
||||
}
|
||||
}
|
||||
|
||||
void pushInputQueue(JOB &data) {
|
||||
if (!sharedData.inputJobs.underPushing) {
|
||||
sharedData.outputJobs.lock();
|
||||
}
|
||||
sharedData.inputJobs.pushBack(data);
|
||||
}
|
||||
|
||||
void waitForFinish() {
|
||||
sharedData.inputJobs.pushBackEnd();
|
||||
sharedData.outputJobs.setMaxSize(sharedData.inputJobs.pushedSize);
|
||||
sharedData.inputJobs.pushedSize = 0;
|
||||
sharedData.outputJobs.waitForFull();
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
sharedData.inputJobs.terminate();
|
||||
for (unsigned int i = 0; i < size; i++) {
|
||||
threads[i].join();
|
||||
}
|
||||
}
|
||||
|
||||
InputJobQueue &getInputJobQueue() { return sharedData.inputJobs; }
|
||||
OutputJobQueue &getOutputJobQueue() { return sharedData.outputJobs; }
|
||||
|
||||
SharedData sharedData; // shared data
|
||||
Thread *threads; // thread set
|
||||
unsigned int size; // thread size
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,564 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "NGT/defines.h"
|
||||
|
||||
#include "NGT/Tree.h"
|
||||
#include "NGT/Node.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace NGT;
|
||||
|
||||
void
|
||||
DVPTree::insert(InsertContainer &iobj) {
|
||||
SearchContainer q(iobj.object);
|
||||
q.mode = SearchContainer::SearchLeaf;
|
||||
q.vptree = this;
|
||||
q.radius = 0.0;
|
||||
|
||||
search(q);
|
||||
|
||||
iobj.vptree = this;
|
||||
|
||||
assert(q.nodeID.getType() == Node::ID::Leaf);
|
||||
LeafNode *ln = (LeafNode*)getNode(q.nodeID);
|
||||
insert(iobj, ln);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void
|
||||
DVPTree::insert(InsertContainer &iobj, LeafNode *leafNode)
|
||||
{
|
||||
LeafNode &leaf = *leafNode;
|
||||
size_t fsize = leaf.getObjectSize();
|
||||
if (fsize != 0) {
|
||||
NGT::ObjectSpace::Comparator &comparator = objectSpace->getComparator();
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Distance d = comparator(iobj.object, leaf.getPivot(*objectSpace));
|
||||
#else
|
||||
Distance d = comparator(iobj.object, leaf.getPivot());
|
||||
#endif
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::ObjectDistance *objects = leaf.getObjectIDs(leafNodes.allocator);
|
||||
#else
|
||||
NGT::ObjectDistance *objects = leaf.getObjectIDs();
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < fsize; i++) {
|
||||
if (objects[i].distance == d) {
|
||||
Distance idd = 0.0;
|
||||
ObjectID loid;
|
||||
try {
|
||||
loid = objects[i].id;
|
||||
idd = comparator(iobj.object, *getObjectRepository().get(loid));
|
||||
} catch (Exception &e) {
|
||||
stringstream msg;
|
||||
msg << "LeafNode::insert: Cannot find object which belongs to a leaf node. id="
|
||||
<< objects[i].id << ":" << e.what() << endl;
|
||||
NGTThrowException(msg.str());
|
||||
}
|
||||
if (idd == 0.0) {
|
||||
if (loid == iobj.id) {
|
||||
stringstream msg;
|
||||
msg << "DVPTree::insert:already existed. " << iobj.id;
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (leaf.getObjectSize() >= leafObjectsSize) {
|
||||
split(iobj, leaf);
|
||||
} else {
|
||||
insertObject(iobj, leaf);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
Node::ID
|
||||
DVPTree::split(InsertContainer &iobj, LeafNode &leaf)
|
||||
{
|
||||
Node::Objects *fs = getObjects(leaf, iobj);
|
||||
int pv = DVPTree::MaxVariance;
|
||||
switch (splitMode) {
|
||||
case DVPTree::MaxVariance:
|
||||
pv = LeafNode::selectPivotByMaxVariance(iobj, *fs);
|
||||
break;
|
||||
case DVPTree::MaxDistance:
|
||||
pv = LeafNode::selectPivotByMaxDistance(iobj, *fs);
|
||||
break;
|
||||
}
|
||||
|
||||
LeafNode::splitObjects(iobj, *fs, pv);
|
||||
|
||||
Node::ID nid = recombineNodes(iobj, *fs, leaf);
|
||||
delete fs;
|
||||
|
||||
return nid;
|
||||
}
|
||||
|
||||
Node::ID
|
||||
DVPTree::recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf)
|
||||
{
|
||||
LeafNode *ln[internalChildrenSize];
|
||||
Node::ID targetParent = leaf.parent;
|
||||
Node::ID targetId = leaf.id;
|
||||
ln[0] = &leaf;
|
||||
ln[0]->objectSize = 0;
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
for (size_t i = 1; i < internalChildrenSize; i++) {
|
||||
ln[i] = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
|
||||
}
|
||||
#else
|
||||
for (size_t i = 1; i < internalChildrenSize; i++) {
|
||||
ln[i] = new LeafNode;
|
||||
}
|
||||
#endif
|
||||
InternalNode *in = createInternalNode();
|
||||
Node::ID inid = in->id;
|
||||
try {
|
||||
if (targetParent.getID() != 0) {
|
||||
InternalNode &pnode = *(InternalNode*)getNode(targetParent);
|
||||
for (size_t i = 0; i < internalChildrenSize; i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
if (pnode.getChildren(internalNodes.allocator)[i] == targetId) {
|
||||
pnode.getChildren(internalNodes.allocator)[i] = inid;
|
||||
#else
|
||||
if (pnode.getChildren()[i] == targetId) {
|
||||
pnode.getChildren()[i] = inid;
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, internalNodes.allocator);
|
||||
#else
|
||||
in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
|
||||
#endif
|
||||
|
||||
in->parent = targetParent;
|
||||
|
||||
int fsize = fs.size();
|
||||
int cid = fs[0].clusterID;
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
LeafNode::ObjectIDs fid;
|
||||
fid.id = fs[0].id;
|
||||
fid.distance = 0.0;
|
||||
ln[cid]->objectIDs.push_back(fid);
|
||||
#else
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize].id = fs[0].id;
|
||||
ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize++].distance = 0.0;
|
||||
#else
|
||||
ln[cid]->getObjectIDs()[ln[cid]->objectSize].id = fs[0].id;
|
||||
ln[cid]->getObjectIDs()[ln[cid]->objectSize++].distance = 0.0;
|
||||
#endif
|
||||
#endif
|
||||
if (fs[0].leafDistance == Node::Object::Pivot) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator);
|
||||
#else
|
||||
ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
|
||||
#endif
|
||||
} else {
|
||||
NGTThrowException("recombineNodes: internal error : illegal pivot.");
|
||||
}
|
||||
ln[cid]->parent = inid;
|
||||
int maxClusterID = cid;
|
||||
for (int i = 1; i < fsize; i++) {
|
||||
int clusterID = fs[i].clusterID;
|
||||
if (clusterID > maxClusterID) {
|
||||
maxClusterID = clusterID;
|
||||
}
|
||||
Distance ld;
|
||||
if (fs[i].leafDistance == Node::Object::Pivot) {
|
||||
// pivot
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace, leafNodes.allocator);
|
||||
#else
|
||||
ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace);
|
||||
#endif
|
||||
ld = 0.0;
|
||||
} else {
|
||||
ld = fs[i].leafDistance;
|
||||
}
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
fid.id = fs[i].id;
|
||||
fid.distance = ld;
|
||||
ln[clusterID]->objectIDs.push_back(fid);
|
||||
#else
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize].id = fs[i].id;
|
||||
ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize++].distance = ld;
|
||||
#else
|
||||
ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize].id = fs[i].id;
|
||||
ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize++].distance = ld;
|
||||
#endif
|
||||
#endif
|
||||
ln[clusterID]->parent = inid;
|
||||
if (clusterID != cid) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
in->getBorders(internalNodes.allocator)[cid] = fs[i].distance;
|
||||
#else
|
||||
in->getBorders()[cid] = fs[i].distance;
|
||||
#endif
|
||||
cid = fs[i].clusterID;
|
||||
}
|
||||
}
|
||||
// When the number of the children is less than the expected,
|
||||
// proper values are set to the empty children.
|
||||
for (size_t i = maxClusterID + 1; i < internalChildrenSize; i++) {
|
||||
ln[i]->parent = inid;
|
||||
// dummy
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator);
|
||||
#else
|
||||
ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
|
||||
#endif
|
||||
if (i < (internalChildrenSize - 1)) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
in->getBorders(internalNodes.allocator)[i] = FLT_MAX;
|
||||
#else
|
||||
in->getBorders()[i] = FLT_MAX;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
in->getChildren(internalNodes.allocator)[0] = targetId;
|
||||
#else
|
||||
in->getChildren()[0] = targetId;
|
||||
#endif
|
||||
for (size_t i = 1; i < internalChildrenSize; i++) {
|
||||
insertNode(ln[i]);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
in->getChildren(internalNodes.allocator)[i] = ln[i]->id;
|
||||
#else
|
||||
in->getChildren()[i] = ln[i]->id;
|
||||
#endif
|
||||
}
|
||||
} catch(Exception &e) {
|
||||
throw e;
|
||||
}
|
||||
return inid;
|
||||
}
|
||||
|
||||
void
|
||||
DVPTree::insertObject(InsertContainer &ic, LeafNode &leaf) {
|
||||
if (leaf.getObjectSize() == 0) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace, leafNodes.allocator);
|
||||
#else
|
||||
leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace);
|
||||
#endif
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
LeafNode::ObjectIDs fid;
|
||||
fid.id = ic.id;
|
||||
fid.distance = 0;
|
||||
leaf.objectIDs.push_back(fid);
|
||||
#else
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id;
|
||||
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = 0;
|
||||
#else
|
||||
leaf.getObjectIDs()[leaf.objectSize].id = ic.id;
|
||||
leaf.getObjectIDs()[leaf.objectSize++].distance = 0;
|
||||
#endif
|
||||
#endif
|
||||
} else {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot(*objectSpace));
|
||||
#else
|
||||
Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot());
|
||||
#endif
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
LeafNode::ObjectIDs fid;
|
||||
fid.id = ic.id;
|
||||
fid.distance = d;
|
||||
leaf.objectIDs.push_back(fid);
|
||||
std::sort(leaf.objectIDs.begin(), leaf.objectIDs.end(), LeafNode::ObjectIDs());
|
||||
#else
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id;
|
||||
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = d;
|
||||
#else
|
||||
leaf.getObjectIDs()[leaf.objectSize].id = ic.id;
|
||||
leaf.getObjectIDs()[leaf.objectSize++].distance = d;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
Node::Objects *
|
||||
DVPTree::getObjects(LeafNode &n, Container &iobj)
|
||||
{
|
||||
int size = n.getObjectSize() + 1;
|
||||
|
||||
Node::Objects *fs = new Node::Objects(size);
|
||||
for (size_t i = 0; i < n.getObjectSize(); i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
(*fs)[i].object = getObjectRepository().get(n.getObjectIDs(leafNodes.allocator)[i].id);
|
||||
(*fs)[i].id = n.getObjectIDs(leafNodes.allocator)[i].id;
|
||||
#else
|
||||
(*fs)[i].object = getObjectRepository().get(n.getObjectIDs()[i].id);
|
||||
(*fs)[i].id = n.getObjectIDs()[i].id;
|
||||
#endif
|
||||
}
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
(*fs)[n.getObjectSize()].object = getObjectRepository().get(iobj.id);
|
||||
#else
|
||||
(*fs)[n.getObjectSize()].object = &iobj.object;
|
||||
#endif
|
||||
(*fs)[n.getObjectSize()].id = iobj.id;
|
||||
return fs;
|
||||
}
|
||||
|
||||
void
|
||||
DVPTree::removeEmptyNodes(InternalNode &inode) {
|
||||
|
||||
int csize = internalChildrenSize;
|
||||
|
||||
|
||||
InternalNode *target = &inode;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Node::ID *children = target->getChildren(internalNodes.allocator);
|
||||
#else
|
||||
Node::ID *children = target->getChildren();
|
||||
#endif
|
||||
for(;;) {
|
||||
for (int i = 0; i < csize; i++) {
|
||||
if (children[i].getType() == Node::ID::Internal) {
|
||||
return;
|
||||
}
|
||||
LeafNode &ln = *static_cast<LeafNode*>(getNode(children[i]));
|
||||
if (ln.getObjectSize() != 0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < csize; i++) {
|
||||
removeNode(children[i]);
|
||||
}
|
||||
if (target->parent.getID() == 0) {
|
||||
removeNode(target->id);
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
LeafNode *root = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
|
||||
#else
|
||||
LeafNode *root = new LeafNode;
|
||||
#endif
|
||||
insertNode(root);
|
||||
if (root->id.getID() != 1) {
|
||||
NGTThrowException("Root id Error");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
LeafNode *ln = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
|
||||
#else
|
||||
LeafNode *ln = new LeafNode;
|
||||
#endif
|
||||
ln->parent = target->parent;
|
||||
insertNode(ln);
|
||||
|
||||
InternalNode &in = *(InternalNode*)getNode(ln->parent);
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
in.updateChild(*this, target->id, ln->id, internalNodes.allocator);
|
||||
#else
|
||||
in.updateChild(*this, target->id, ln->id);
|
||||
#endif
|
||||
removeNode(target->id);
|
||||
target = ∈
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
DVPTree::search(SearchContainer &sc, InternalNode &node, UncheckedNode &uncheckedNode)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Distance d = objectSpace->getComparator()(sc.object, node.getPivot(*objectSpace));
|
||||
#else
|
||||
Distance d = objectSpace->getComparator()(sc.object, node.getPivot());
|
||||
#endif
|
||||
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
|
||||
sc.distanceComputationCount++;
|
||||
#endif
|
||||
|
||||
int bsize = internalChildrenSize - 1;
|
||||
|
||||
vector<ObjectDistance> regions;
|
||||
regions.reserve(internalChildrenSize);
|
||||
|
||||
ObjectDistance child;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Distance *borders = node.getBorders(internalNodes.allocator);
|
||||
#else
|
||||
Distance *borders = node.getBorders();
|
||||
#endif
|
||||
int mid;
|
||||
for (mid = 0; mid < bsize; mid++) {
|
||||
if (d < borders[mid]) {
|
||||
child.id = mid;
|
||||
child.distance = 0.0;
|
||||
regions.push_back(child);
|
||||
if (d + sc.radius < borders[mid]) {
|
||||
break;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (d < borders[mid] + sc.radius) {
|
||||
child.id = mid;
|
||||
child.distance = d - borders[mid];
|
||||
regions.push_back(child);
|
||||
continue;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mid == bsize) {
|
||||
if (d >= borders[mid - 1]) {
|
||||
child.id = mid;
|
||||
child.distance = 0.0;
|
||||
regions.push_back(child);
|
||||
} else {
|
||||
child.id = mid;
|
||||
child.distance = borders[mid - 1] - d;
|
||||
regions.push_back(child);
|
||||
}
|
||||
}
|
||||
|
||||
sort(regions.begin(), regions.end());
|
||||
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Node::ID *children = node.getChildren(internalNodes.allocator);
|
||||
#else
|
||||
Node::ID *children = node.getChildren();
|
||||
#endif
|
||||
|
||||
vector<ObjectDistance>::iterator i;
|
||||
if (sc.mode == DVPTree::SearchContainer::SearchLeaf) {
|
||||
if (children[regions.front().id].getType() == Node::ID::Leaf) {
|
||||
sc.nodeID.setRaw(children[regions.front().id].get());
|
||||
assert(uncheckedNode.empty());
|
||||
} else {
|
||||
uncheckedNode.push(children[regions.front().id]);
|
||||
}
|
||||
} else {
|
||||
for (i = regions.begin(); i != regions.end(); i++) {
|
||||
uncheckedNode.push(children[i->id]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
DVPTree::search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode)
|
||||
{
|
||||
DVPTree::SearchContainer &q = (DVPTree::SearchContainer&)so;
|
||||
|
||||
if (node.getObjectSize() == 0) {
|
||||
return;
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
Distance pq = objectSpace->getComparator()(q.object, node.getPivot(*objectSpace));
|
||||
#else
|
||||
Distance pq = objectSpace->getComparator()(q.object, node.getPivot());
|
||||
#endif
|
||||
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
|
||||
so.distanceComputationCount++;
|
||||
#endif
|
||||
|
||||
ObjectDistance r;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
NGT::ObjectDistance *objects = node.getObjectIDs(leafNodes.allocator);
|
||||
#else
|
||||
NGT::ObjectDistance *objects = node.getObjectIDs();
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < node.getObjectSize(); i++) {
|
||||
if ((objects[i].distance <= pq + q.radius) &&
|
||||
(objects[i].distance >= pq - q.radius)) {
|
||||
Distance d = 0;
|
||||
try {
|
||||
d = objectSpace->getComparator()(q.object, *q.vptree->getObjectRepository().get(objects[i].id));
|
||||
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
|
||||
so.distanceComputationCount++;
|
||||
#endif
|
||||
} catch(...) {
|
||||
NGTThrowException("VpTree::LeafNode::search: Internal fatal error : Cannot get object");
|
||||
}
|
||||
if (d <= q.radius) {
|
||||
r.id = objects[i].id;
|
||||
r.distance = d;
|
||||
so.getResult().push_back(r);
|
||||
std::sort(so.getResult().begin(), so.getResult().end());
|
||||
if (so.getResult().size() > q.size) {
|
||||
so.getResult().resize(q.size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
DVPTree::search(SearchContainer &sc) {
|
||||
((SearchContainer&)sc).vptree = this;
|
||||
Node *root = getRootNode();
|
||||
assert(root != 0);
|
||||
if (sc.mode == DVPTree::SearchContainer::SearchLeaf) {
|
||||
if (root->id.getType() == Node::ID::Leaf) {
|
||||
sc.nodeID.setRaw(root->id.get());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
UncheckedNode uncheckedNode;
|
||||
uncheckedNode.push(root->id);
|
||||
|
||||
while (!uncheckedNode.empty()) {
|
||||
Node::ID nodeid = uncheckedNode.top();
|
||||
uncheckedNode.pop();
|
||||
Node *cnode = getNode(nodeid);
|
||||
if (cnode == 0) {
|
||||
cerr << "Error! child node is null. but continue." << endl;
|
||||
continue;
|
||||
}
|
||||
if (cnode->id.getType() == Node::ID::Internal) {
|
||||
search(sc, (InternalNode&)*cnode, uncheckedNode);
|
||||
} else if (cnode->id.getType() == Node::ID::Leaf) {
|
||||
search(sc, (LeafNode&)*cnode, uncheckedNode);
|
||||
} else {
|
||||
cerr << "Tree: Inner fatal error!: Node type error!" << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,511 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "NGT/Common.h"
|
||||
#include "NGT/Node.h"
|
||||
#include "NGT/defines.h"
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stack>
|
||||
#include <set>
|
||||
|
||||
namespace NGT {
|
||||
|
||||
class DVPTree {
|
||||
|
||||
public:
|
||||
enum SplitMode {
|
||||
MaxDistance = 0,
|
||||
MaxVariance = 1
|
||||
};
|
||||
|
||||
typedef std::vector<Node::ID> IDVector;
|
||||
|
||||
class Container : public NGT::Container {
|
||||
public:
|
||||
Container(Object &f, ObjectID i):NGT::Container(f, i) {}
|
||||
DVPTree *vptree;
|
||||
};
|
||||
|
||||
class SearchContainer : public NGT::SearchContainer {
|
||||
public:
|
||||
enum Mode {
|
||||
SearchLeaf = 0,
|
||||
SearchObject = 1
|
||||
};
|
||||
|
||||
SearchContainer(Object &f, ObjectID i):NGT::SearchContainer(f, i) {}
|
||||
SearchContainer(Object &f):NGT::SearchContainer(f, 0) {}
|
||||
|
||||
DVPTree *vptree;
|
||||
|
||||
Mode mode;
|
||||
Node::ID nodeID;
|
||||
};
|
||||
class InsertContainer : public Container {
|
||||
public:
|
||||
InsertContainer(Object &f, ObjectID i):Container(f, i) {}
|
||||
};
|
||||
|
||||
class RemoveContainer : public Container {
|
||||
public:
|
||||
RemoveContainer(Object &f, ObjectID i):Container(f, i) {}
|
||||
};
|
||||
|
||||
DVPTree() {
|
||||
leafObjectsSize = LeafNode::LeafObjectsSizeMax;
|
||||
internalChildrenSize = InternalNode::InternalChildrenSizeMax;
|
||||
splitMode = MaxVariance;
|
||||
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
insertNode(new LeafNode);
|
||||
#endif
|
||||
}
|
||||
|
||||
virtual ~DVPTree() {
|
||||
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
deleteAll();
|
||||
#endif
|
||||
}
|
||||
|
||||
void deleteAll() {
|
||||
for (size_t i = 0; i < leafNodes.size(); i++) {
|
||||
if (leafNodes[i] != 0) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
leafNodes[i]->deletePivot(*objectSpace, leafNodes.allocator);
|
||||
#else
|
||||
leafNodes[i]->deletePivot(*objectSpace);
|
||||
#endif
|
||||
delete leafNodes[i];
|
||||
}
|
||||
}
|
||||
leafNodes.clear();
|
||||
for (size_t i = 0; i < internalNodes.size(); i++) {
|
||||
if (internalNodes[i] != 0) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
internalNodes[i]->deletePivot(*objectSpace, internalNodes.allocator);
|
||||
#else
|
||||
internalNodes[i]->deletePivot(*objectSpace);
|
||||
#endif
|
||||
delete internalNodes[i];
|
||||
}
|
||||
}
|
||||
internalNodes.clear();
|
||||
}
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void open(const std::string &f, size_t sharedMemorySize) {
|
||||
// If no file, then create a new file.
|
||||
leafNodes.open(f + "l", sharedMemorySize);
|
||||
internalNodes.open(f + "i", sharedMemorySize);
|
||||
if (leafNodes.size() == 0) {
|
||||
if (internalNodes.size() != 0) {
|
||||
NGTThrowException("Tree::Open: Internal error. Internal and leaf are inconsistent.");
|
||||
}
|
||||
LeafNode *ln = leafNodes.allocate();
|
||||
insertNode(ln);
|
||||
}
|
||||
}
|
||||
#endif // NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
||||
void insert(InsertContainer &iobj);
|
||||
|
||||
void insert(InsertContainer &iobj, LeafNode *n);
|
||||
|
||||
Node::ID split(InsertContainer &iobj, LeafNode &leaf);
|
||||
|
||||
Node::ID recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf);
|
||||
|
||||
void insertObject(InsertContainer &obj, LeafNode &leaf);
|
||||
|
||||
typedef std::stack<Node::ID> UncheckedNode;
|
||||
|
||||
void search(SearchContainer &so);
|
||||
void search(SearchContainer &so, InternalNode &node, UncheckedNode &uncheckedNode);
|
||||
void search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode);
|
||||
|
||||
bool searchObject(ObjectID id) {
|
||||
LeafNode &ln = getLeaf(id);
|
||||
for (size_t i = 0; i < ln.getObjectSize(); i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
if (ln.getObjectIDs(leafNodes.allocator)[i].id == id) {
|
||||
#else
|
||||
if (ln.getObjectIDs()[i].id == id) {
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LeafNode &getLeaf(ObjectID id) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
Object *qobject = objectSpace->allocateObject(*getObjectRepository().get(id));
|
||||
SearchContainer q(*qobject);
|
||||
#else
|
||||
SearchContainer q(*getObjectRepository().get(id));
|
||||
#endif
|
||||
q.mode = SearchContainer::SearchLeaf;
|
||||
q.vptree = this;
|
||||
q.radius = 0.0;
|
||||
q.size = 1;
|
||||
|
||||
search(q);
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
objectSpace->deleteObject(qobject);
|
||||
#endif
|
||||
|
||||
return *(LeafNode*)getNode(q.nodeID);
|
||||
|
||||
}
|
||||
|
||||
void replace(ObjectID id, ObjectID replacedId) { remove(id, replacedId); }
|
||||
|
||||
// remove the specified object.
|
||||
void remove(ObjectID id, ObjectID replaceId = 0) {
|
||||
LeafNode &ln = getLeaf(id);
|
||||
try {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
ln.removeObject(id, replaceId, leafNodes.allocator);
|
||||
#else
|
||||
ln.removeObject(id, replaceId);
|
||||
#endif
|
||||
} catch(Exception &err) {
|
||||
std::stringstream msg;
|
||||
msg << "VpTree::remove: Inner error. Cannot remove object. leafNode=" << ln.id.getID() << ":" << err.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
if (ln.getObjectSize() == 0) {
|
||||
if (ln.parent.getID() != 0) {
|
||||
InternalNode &inode = *(InternalNode*)getNode(ln.parent);
|
||||
removeEmptyNodes(inode);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void removeNaively(ObjectID id, ObjectID replaceId = 0) {
|
||||
for (size_t i = 0; i < leafNodes.size(); i++) {
|
||||
if (leafNodes[i] != 0) {
|
||||
try {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
leafNodes[i]->removeObject(id, replaceId, leafNodes.allocator);
|
||||
#else
|
||||
leafNodes[i]->removeObject(id, replaceId);
|
||||
#endif
|
||||
break;
|
||||
} catch(...) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Node *getRootNode() {
|
||||
size_t nid = 1;
|
||||
Node *root;
|
||||
try {
|
||||
root = internalNodes.get(nid);
|
||||
} catch(Exception &err) {
|
||||
try {
|
||||
root = leafNodes.get(nid);
|
||||
} catch(Exception &e) {
|
||||
std::stringstream msg;
|
||||
msg << "VpTree::getRootNode: Inner error. Cannot get a leaf root node. " << nid << ":" << e.what();
|
||||
NGTThrowException(msg);
|
||||
}
|
||||
}
|
||||
|
||||
return root;
|
||||
}
|
||||
|
||||
InternalNode *createInternalNode() {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
InternalNode *n = new(internalNodes.allocator) InternalNode(internalChildrenSize, internalNodes.allocator);
|
||||
#else
|
||||
InternalNode *n = new InternalNode(internalChildrenSize);
|
||||
#endif
|
||||
insertNode(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
void
|
||||
removeNode(Node::ID id) {
|
||||
size_t idx = id.getID();
|
||||
if (id.getType() == Node::ID::Leaf) {
|
||||
leafNodes.remove(idx);
|
||||
} else {
|
||||
internalNodes.remove(idx);
|
||||
}
|
||||
}
|
||||
|
||||
void removeEmptyNodes(InternalNode &node);
|
||||
|
||||
Node::Objects * getObjects(LeafNode &n, Container &iobj);
|
||||
|
||||
// for milvus
|
||||
void
|
||||
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, faiss::ConcurrentBitsetPtr& bitset) {
|
||||
LeafNode& ln = *(LeafNode*)getNode(nid);
|
||||
rl.clear();
|
||||
ObjectDistance r;
|
||||
for (size_t i = 0; i < ln.getObjectSize(); i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
r.id = ln.getObjectIDs(leafNodes.allocator)[i].id;
|
||||
r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance;
|
||||
#else
|
||||
r.id = ln.getObjectIDs()[i].id;
|
||||
r.distance = ln.getObjectIDs()[i].distance;
|
||||
#endif
|
||||
if (bitset != nullptr && bitset->test(r.id - 1)) {
|
||||
continue;
|
||||
}
|
||||
rl.push_back(r);
|
||||
}
|
||||
return;
|
||||
}
|
||||
void getObjectIDsFromLeaf(Node::ID nid, ObjectDistances &rl) {
|
||||
LeafNode &ln = *(LeafNode*)getNode(nid);
|
||||
rl.clear();
|
||||
ObjectDistance r;
|
||||
for (size_t i = 0; i < ln.getObjectSize(); i++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
r.id = ln.getObjectIDs(leafNodes.allocator)[i].id;
|
||||
r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance;
|
||||
#else
|
||||
r.id = ln.getObjectIDs()[i].id;
|
||||
r.distance = ln.getObjectIDs()[i].distance;
|
||||
#endif
|
||||
rl.push_back(r);
|
||||
}
|
||||
return;
|
||||
}
|
||||
void
|
||||
insertNode(LeafNode *n) {
|
||||
size_t id = leafNodes.insert(n);
|
||||
n->id.setID(id);
|
||||
n->id.setType(Node::ID::Leaf);
|
||||
}
|
||||
|
||||
// replace
|
||||
void replaceNode(LeafNode *n) {
|
||||
int id = n->id.getID();
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
leafNodes.set(id, n);
|
||||
#else
|
||||
leafNodes[id] = n;
|
||||
#endif
|
||||
}
|
||||
|
||||
void
|
||||
insertNode(InternalNode *n)
|
||||
{
|
||||
size_t id = internalNodes.insert(n);
|
||||
n->id.setID(id);
|
||||
n->id.setType(Node::ID::Internal);
|
||||
}
|
||||
|
||||
Node *getNode(Node::ID &id) {
|
||||
Node *n = 0;
|
||||
Node::NodeID idx = id.getID();
|
||||
if (id.getType() == Node::ID::Leaf) {
|
||||
n = leafNodes.get(idx);
|
||||
} else {
|
||||
n = internalNodes.get(idx);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
void getAllLeafNodeIDs(std::vector<Node::ID> &leafIDs) {
|
||||
leafIDs.clear();
|
||||
Node *root = getRootNode();
|
||||
if (root->id.getType() == Node::ID::Leaf) {
|
||||
leafIDs.push_back(root->id);
|
||||
return;
|
||||
}
|
||||
UncheckedNode uncheckedNode;
|
||||
uncheckedNode.push(root->id);
|
||||
while (!uncheckedNode.empty()) {
|
||||
Node::ID nodeid = uncheckedNode.top();
|
||||
uncheckedNode.pop();
|
||||
Node *cnode = getNode(nodeid);
|
||||
if (cnode->id.getType() == Node::ID::Internal) {
|
||||
InternalNode &inode = static_cast<InternalNode&>(*cnode);
|
||||
for (size_t ci = 0; ci < internalChildrenSize; ci++) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
uncheckedNode.push(inode.getChildren(internalNodes.allocator)[ci]);
|
||||
#else
|
||||
uncheckedNode.push(inode.getChildren()[ci]);
|
||||
#endif
|
||||
}
|
||||
} else if (cnode->id.getType() == Node::ID::Leaf) {
|
||||
leafIDs.push_back(static_cast<LeafNode&>(*cnode).id);
|
||||
} else {
|
||||
std::cerr << "Tree: Inner fatal error!: Node type error!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for milvus
|
||||
void serialize(std::stringstream & os)
|
||||
{
|
||||
leafNodes.serialize(os, objectSpace);
|
||||
internalNodes.serialize(os, objectSpace);
|
||||
}
|
||||
|
||||
void serialize(std::ofstream &os) {
|
||||
leafNodes.serialize(os, objectSpace);
|
||||
internalNodes.serialize(os, objectSpace);
|
||||
}
|
||||
|
||||
void deserialize(std::ifstream &is) {
|
||||
leafNodes.deserialize(is, objectSpace);
|
||||
internalNodes.deserialize(is, objectSpace);
|
||||
}
|
||||
|
||||
void deserialize(std::stringstream & is)
|
||||
{
|
||||
leafNodes.deserialize(is, objectSpace);
|
||||
internalNodes.deserialize(is, objectSpace);
|
||||
}
|
||||
|
||||
void serializeAsText(std::ofstream &os) {
|
||||
leafNodes.serializeAsText(os, objectSpace);
|
||||
internalNodes.serializeAsText(os, objectSpace);
|
||||
}
|
||||
|
||||
void deserializeAsText(std::ifstream &is) {
|
||||
leafNodes.deserializeAsText(is, objectSpace);
|
||||
internalNodes.deserializeAsText(is, objectSpace);
|
||||
}
|
||||
|
||||
void show() {
|
||||
std::cout << "Show tree " << std::endl;
|
||||
for (size_t i = 0; i < leafNodes.size(); i++) {
|
||||
if (leafNodes[i] != 0) {
|
||||
std::cout << i << ":";
|
||||
(*leafNodes[i]).show();
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < internalNodes.size(); i++) {
|
||||
if (internalNodes[i] != 0) {
|
||||
std::cout << i << ":";
|
||||
(*internalNodes[i]).show();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool verify(size_t objCount, std::vector<uint8_t> &status) {
|
||||
std::cerr << "Started verifying internal nodes. size=" << internalNodes.size() << "..." << std::endl;
|
||||
bool valid = true;
|
||||
for (size_t i = 0; i < internalNodes.size(); i++) {
|
||||
if (internalNodes[i] != 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes, internalNodes.allocator);
|
||||
#else
|
||||
valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
std::cerr << "Started verifying leaf nodes. size=" << leafNodes.size() << " ..." << std::endl;
|
||||
for (size_t i = 0; i < leafNodes.size(); i++) {
|
||||
if (leafNodes[i] != 0) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
valid = valid && (*leafNodes[i]).verify(objCount, status, leafNodes.allocator);
|
||||
#else
|
||||
valid = valid && (*leafNodes[i]).verify(objCount, status);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return valid;
|
||||
}
|
||||
|
||||
void deleteInMemory() {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
assert(0);
|
||||
#else
|
||||
for (std::vector<NGT::LeafNode*>::iterator i = leafNodes.begin(); i != leafNodes.end(); i++) {
|
||||
if ((*i) != 0) {
|
||||
delete (*i);
|
||||
}
|
||||
}
|
||||
leafNodes.clear();
|
||||
for (std::vector<NGT::InternalNode*>::iterator i = internalNodes.begin(); i != internalNodes.end(); i++) {
|
||||
if ((*i) != 0) {
|
||||
delete (*i);
|
||||
}
|
||||
}
|
||||
internalNodes.clear();
|
||||
#endif
|
||||
}
|
||||
|
||||
ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); }
|
||||
|
||||
size_t getSharedMemorySize(std::ostream &os, SharedMemoryAllocator::GetMemorySizeType t) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
size_t isize = internalNodes.getAllocator().getMemorySize(t);
|
||||
os << "internal=" << isize << std::endl;
|
||||
size_t lsize = leafNodes.getAllocator().getMemorySize(t);
|
||||
os << "leaf=" << lsize << std::endl;
|
||||
return isize + lsize;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void getAllObjectIDs(std::set<ObjectID> &ids) {
|
||||
for (size_t i = 0; i < leafNodes.size(); i++) {
|
||||
if (leafNodes[i] != 0) {
|
||||
LeafNode &ln = *leafNodes[i];
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
auto objs = ln.getObjectIDs(leafNodes.allocator);
|
||||
#else
|
||||
auto objs = ln.getObjectIDs();
|
||||
#endif
|
||||
for (size_t idx = 0; idx < ln.objectSize; ++idx) {
|
||||
ids.insert(objs[idx].id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
size_t internalChildrenSize;
|
||||
size_t leafObjectsSize;
|
||||
|
||||
SplitMode splitMode;
|
||||
|
||||
std::string name;
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
PersistentRepository<LeafNode> leafNodes;
|
||||
PersistentRepository<InternalNode> internalNodes;
|
||||
#else
|
||||
Repository<LeafNode> leafNodes;
|
||||
Repository<InternalNode> internalNodes;
|
||||
#endif
|
||||
|
||||
ObjectSpace *objectSpace;
|
||||
|
||||
};
|
||||
} // namespace DVPTree
|
||||
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "NGT/Version.h"
|
||||
|
||||
void
|
||||
NGT::Version::get(std::ostream &os)
|
||||
{
|
||||
os << " Version:" << NGT::Version::getVersion() << std::endl;
|
||||
os << " Built date:" << NGT::Version::getBuildDate() << std::endl;
|
||||
os << " The last git tag:" << Version::getGitTag() << std::endl;
|
||||
os << " The last git commit hash:" << Version::getGitHash() << std::endl;
|
||||
os << " The last git commit date:" << Version::getGitDate() << std::endl;
|
||||
}
|
||||
|
||||
const std::string
|
||||
NGT::Version::getVersion()
|
||||
{
|
||||
return NGT_VERSION;
|
||||
}
|
||||
|
||||
const std::string
|
||||
NGT::Version::getBuildDate()
|
||||
{
|
||||
return NGT_BUILD_DATE;
|
||||
}
|
||||
|
||||
const std::string
|
||||
NGT::Version::getGitHash()
|
||||
{
|
||||
return NGT_GIT_HASH;
|
||||
}
|
||||
|
||||
const std::string
|
||||
NGT::Version::getGitDate()
|
||||
{
|
||||
return NGT_GIT_DATE;
|
||||
}
|
||||
|
||||
const std::string
|
||||
NGT::Version::getGitTag()
|
||||
{
|
||||
return NGT_GIT_TAG;
|
||||
}
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#ifndef NGT_VERSION
|
||||
#define NGT_VERSION "-"
|
||||
#endif
|
||||
#ifndef NGT_BUILD_DATE
|
||||
#define NGT_BUILD_DATE "-"
|
||||
#endif
|
||||
#ifndef NGT_GIT_HASH
|
||||
#define NGT_GIT_HASH "-"
|
||||
#endif
|
||||
#ifndef NGT_GIT_DATE
|
||||
#define NGT_GIT_DATE "-"
|
||||
#endif
|
||||
#ifndef NGT_GIT_TAG
|
||||
#define NGT_GIT_TAG "-"
|
||||
#endif
|
||||
|
||||
namespace NGT {
|
||||
class Version {
|
||||
public:
|
||||
static void
|
||||
get(std::ostream& os);
|
||||
static const std::string
|
||||
getVersion();
|
||||
static const std::string
|
||||
getBuildDate();
|
||||
static const std::string
|
||||
getGitHash();
|
||||
static const std::string
|
||||
getGitDate();
|
||||
static const std::string
|
||||
getGitTag();
|
||||
static const std::string
|
||||
get();
|
||||
};
|
||||
|
||||
}; // namespace NGT
|
||||
|
||||
#ifdef NGT_VERSION_FOR_HEADER
|
||||
#include "Version.cpp"
|
||||
#endif
|
|
@ -0,0 +1,60 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
// Begin of cmake defines
|
||||
#if 0
|
||||
#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes
|
||||
#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed
|
||||
#cmakedefine NGT_AVX_DISABLED // not use avx to compare
|
||||
#cmakedefine NGT_LARGE_DATASET // more than 10M objects
|
||||
#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations
|
||||
#endif
|
||||
// End of cmake defines
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Release Definitions for OSS
|
||||
|
||||
//#define NGT_DISTANCE_COMPUTATION_COUNT
|
||||
|
||||
#define NGT_CREATION_EDGE_SIZE 10
|
||||
#define NGT_EXPLORATION_COEFFICIENT 1.1
|
||||
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
|
||||
#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB
|
||||
#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them.
|
||||
|
||||
#define NGT_COMPACT_VECTOR
|
||||
#define NGT_GRAPH_READ_ONLY_GRAPH
|
||||
|
||||
#ifdef NGT_LARGE_DATASET
|
||||
#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET
|
||||
#else
|
||||
#define NGT_GRAPH_CHECK_VECTOR
|
||||
#endif
|
||||
|
||||
#if defined(NGT_AVX_DISABLED)
|
||||
#define NGT_NO_AVX
|
||||
#else
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
#define NGT_AVX512
|
||||
#elif defined(__AVX2__)
|
||||
#define NGT_AVX2
|
||||
#else
|
||||
#define NGT_NO_AVX
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
// Begin of cmake defines
|
||||
#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes
|
||||
#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed
|
||||
#cmakedefine NGT_AVX_DISABLED // not use avx to compare
|
||||
#cmakedefine NGT_LARGE_DATASET // more than 10M objects
|
||||
#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations
|
||||
// End of cmake defines
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Release Definitions for OSS
|
||||
|
||||
//#define NGT_DISTANCE_COMPUTATION_COUNT
|
||||
|
||||
#define NGT_CREATION_EDGE_SIZE 10
|
||||
#define NGT_EXPLORATION_COEFFICIENT 1.1
|
||||
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
|
||||
#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB
|
||||
#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them.
|
||||
|
||||
#define NGT_COMPACT_VECTOR
|
||||
#define NGT_GRAPH_READ_ONLY_GRAPH
|
||||
|
||||
#ifdef NGT_LARGE_DATASET
|
||||
#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET
|
||||
#else
|
||||
#define NGT_GRAPH_CHECK_VECTOR
|
||||
#endif
|
||||
|
||||
#if defined(NGT_AVX_DISABLED)
|
||||
#define NGT_NO_AVX
|
||||
#else
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
#define NGT_AVX512
|
||||
#elif defined(__AVX2__)
|
||||
#define NGT_AVX2
|
||||
#else
|
||||
#define NGT_NO_AVX
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib)
|
||||
include_directories(${INDEX_SOURCE_DIR}/knowhere)
|
||||
include_directories(${INDEX_SOURCE_DIR})
|
||||
|
||||
set(depend_libs
|
||||
gtest gmock gtest_main gmock_main
|
||||
faiss fiu
|
||||
ngt
|
||||
)
|
||||
if (FAISS_WITH_MKL)
|
||||
set(depend_libs ${depend_libs}
|
||||
|
@ -268,3 +270,28 @@ install(TARGETS test_structured_index_sort DESTINATION unittest)
|
|||
|
||||
#add_subdirectory(faiss_benchmark)
|
||||
#add_subdirectory(metric_alg_benchmark)
|
||||
################################################################################
|
||||
#<NGTPANNG-TEST>
|
||||
set(ngtpanng_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGTPANNG.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGT.cpp
|
||||
)
|
||||
if (NOT TARGET test_ngtpanng)
|
||||
add_executable(test_ngtpanng test_ngtpanng.cpp ${ngtpanng_srcs} ${util_srcs})
|
||||
endif ()
|
||||
target_link_libraries(test_ngtpanng ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
install(TARGETS test_ngtpanng DESTINATION unittest)
|
||||
|
||||
################################################################################
|
||||
#<NGTPANNG-TEST>
|
||||
set(ngtonng_srcs
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGTONNG.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGT.cpp
|
||||
)
|
||||
if (NOT TARGET test_ngtonng)
|
||||
add_executable(test_ngtonng test_ngtonng.cpp ${ngtonng_srcs} ${util_srcs})
|
||||
endif ()
|
||||
target_link_libraries(test_ngtonng ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
install(TARGETS test_ngtonng DESTINATION unittest)
|
||||
|
||||
################################################################################
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexNGTONNG.h"
|
||||
|
||||
#include "unittest/utils.h"
|
||||
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class NGTONNGTest : public DataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
IndexType = GetParam();
|
||||
Generate(128, 10000, 10);
|
||||
index_ = std::make_shared<milvus::knowhere::IndexNGTONNG>();
|
||||
conf = milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, dim},
|
||||
{milvus::knowhere::meta::TOPK, 10},
|
||||
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
};
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::Config conf;
|
||||
std::shared_ptr<milvus::knowhere::IndexNGTONNG> index_ = nullptr;
|
||||
std::string IndexType;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(NGTONNGParameters, NGTONNGTest, Values("NGTONNG"));
|
||||
|
||||
TEST_P(NGTONNGTest, ngtonng_basic) {
|
||||
assert(!xb.empty());
|
||||
|
||||
// null index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Serialize(conf));
|
||||
ASSERT_ANY_THROW(index_->Add(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Count());
|
||||
ASSERT_ANY_THROW(index_->Dim());
|
||||
}
|
||||
|
||||
index_->BuildAll(base_dataset, conf); // Train + Add
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, k);
|
||||
}
|
||||
|
||||
TEST_P(NGTONNGTest, ngtonng_delete) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->BuildAll(base_dataset, conf); // Train + Add
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
bitset->set(i);
|
||||
}
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result1, nq, k);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
}
|
||||
|
||||
TEST_P(NGTONNGTest, ngtonng_serialize) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
{
|
||||
// write and flush
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
}
|
||||
|
||||
FileIOReader reader(filename);
|
||||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
{
|
||||
// serialize index
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
auto binaryset = index_->Serialize(milvus::knowhere::Config());
|
||||
|
||||
auto bin_obj_data = binaryset.GetByName("ngt_obj_data");
|
||||
std::string filename1 = "/tmp/ngt_obj_data_serialize.bin";
|
||||
auto load_data1 = new uint8_t[bin_obj_data->size];
|
||||
serialize(filename1, bin_obj_data, load_data1);
|
||||
|
||||
auto bin_grp_data = binaryset.GetByName("ngt_grp_data");
|
||||
std::string filename2 = "/tmp/ngt_grp_data_serialize.bin";
|
||||
auto load_data2 = new uint8_t[bin_grp_data->size];
|
||||
serialize(filename2, bin_grp_data, load_data2);
|
||||
|
||||
auto bin_prf_data = binaryset.GetByName("ngt_prf_data");
|
||||
std::string filename3 = "/tmp/ngt_prf_data_serialize.bin";
|
||||
auto load_data3 = new uint8_t[bin_prf_data->size];
|
||||
serialize(filename3, bin_prf_data, load_data3);
|
||||
|
||||
auto bin_tre_data = binaryset.GetByName("ngt_tre_data");
|
||||
std::string filename4 = "/tmp/ngt_tre_data_serialize.bin";
|
||||
auto load_data4 = new uint8_t[bin_tre_data->size];
|
||||
serialize(filename4, bin_tre_data, load_data4);
|
||||
|
||||
binaryset.clear();
|
||||
std::shared_ptr<uint8_t[]> obj_data(load_data1);
|
||||
binaryset.Append("ngt_obj_data", obj_data, bin_obj_data->size);
|
||||
|
||||
std::shared_ptr<uint8_t[]> grp_data(load_data2);
|
||||
binaryset.Append("ngt_grp_data", grp_data, bin_grp_data->size);
|
||||
|
||||
std::shared_ptr<uint8_t[]> prf_data(load_data3);
|
||||
binaryset.Append("ngt_prf_data", prf_data, bin_prf_data->size);
|
||||
|
||||
std::shared_ptr<uint8_t[]> tre_data(load_data4);
|
||||
binaryset.Append("ngt_tre_data", tre_data, bin_tre_data->size);
|
||||
|
||||
index_->Load(binaryset);
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
|
||||
|
||||
#include "unittest/utils.h"
|
||||
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class NGTPANNGTest : public DataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
IndexType = GetParam();
|
||||
Generate(128, 10000, 10);
|
||||
index_ = std::make_shared<milvus::knowhere::IndexNGTPANNG>();
|
||||
conf = milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, dim},
|
||||
{milvus::knowhere::meta::TOPK, 10},
|
||||
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
};
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::Config conf;
|
||||
std::shared_ptr<milvus::knowhere::IndexNGTPANNG> index_ = nullptr;
|
||||
std::string IndexType;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(NGTPANNGParameters, NGTPANNGTest, Values("NGTPANNG"));
|
||||
|
||||
TEST_P(NGTPANNGTest, ngtpanng_basic) {
|
||||
assert(!xb.empty());
|
||||
|
||||
// null index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Serialize(conf));
|
||||
ASSERT_ANY_THROW(index_->Add(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Count());
|
||||
ASSERT_ANY_THROW(index_->Dim());
|
||||
}
|
||||
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, k);
|
||||
}
|
||||
|
||||
TEST_P(NGTPANNGTest, ngtpanng_delete) {
|
||||
assert(!xb.empty());
|
||||
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
bitset->set(i);
|
||||
}
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result1, nq, k);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
}
|
||||
|
||||
TEST_P(NGTPANNGTest, ngtpanng_serialize) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
{
|
||||
// write and flush
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
}
|
||||
|
||||
FileIOReader reader(filename);
|
||||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
{
|
||||
// serialize index
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
auto binaryset = index_->Serialize(milvus::knowhere::Config());
|
||||
|
||||
auto bin_obj_data = binaryset.GetByName("ngt_obj_data");
|
||||
std::string filename1 = "/tmp/ngt_obj_data_serialize.bin";
|
||||
auto load_data1 = new uint8_t[bin_obj_data->size];
|
||||
serialize(filename1, bin_obj_data, load_data1);
|
||||
|
||||
auto bin_grp_data = binaryset.GetByName("ngt_grp_data");
|
||||
std::string filename2 = "/tmp/ngt_grp_data_serialize.bin";
|
||||
auto load_data2 = new uint8_t[bin_grp_data->size];
|
||||
serialize(filename2, bin_grp_data, load_data2);
|
||||
|
||||
auto bin_prf_data = binaryset.GetByName("ngt_prf_data");
|
||||
std::string filename3 = "/tmp/ngt_prf_data_serialize.bin";
|
||||
auto load_data3 = new uint8_t[bin_prf_data->size];
|
||||
serialize(filename3, bin_prf_data, load_data3);
|
||||
|
||||
auto bin_tre_data = binaryset.GetByName("ngt_tre_data");
|
||||
std::string filename4 = "/tmp/ngt_tre_data_serialize.bin";
|
||||
auto load_data4 = new uint8_t[bin_tre_data->size];
|
||||
serialize(filename4, bin_tre_data, load_data4);
|
||||
|
||||
binaryset.clear();
|
||||
std::shared_ptr<uint8_t[]> obj_data(load_data1);
|
||||
binaryset.Append("ngt_obj_data", obj_data, bin_obj_data->size);
|
||||
|
||||
std::shared_ptr<uint8_t[]> grp_data(load_data2);
|
||||
binaryset.Append("ngt_grp_data", grp_data, bin_grp_data->size);
|
||||
|
||||
std::shared_ptr<uint8_t[]> prf_data(load_data3);
|
||||
binaryset.Append("ngt_prf_data", prf_data, bin_prf_data->size);
|
||||
|
||||
std::shared_ptr<uint8_t[]> tre_data(load_data4);
|
||||
binaryset.Append("ngt_tre_data", tre_data, bin_tre_data->size);
|
||||
|
||||
index_->Load(binaryset);
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
}
|
||||
}
|
|
@ -191,6 +191,8 @@ ValidateIndexType(std::string& index_type) {
|
|||
knowhere::IndexEnum::INDEX_RHNSWFlat,
|
||||
knowhere::IndexEnum::INDEX_RHNSWPQ,
|
||||
knowhere::IndexEnum::INDEX_RHNSWSQ,
|
||||
knowhere::IndexEnum::INDEX_NGTPANNG,
|
||||
knowhere::IndexEnum::INDEX_NGTONNG,
|
||||
|
||||
// structured index names
|
||||
engine::DEFAULT_STRUCTURED_INDEX,
|
||||
|
|
|
@ -26,6 +26,8 @@ const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY";
|
|||
const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSW_FLAT";
|
||||
const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSW_PQ";
|
||||
const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSW_SQ8";
|
||||
const char* NAME_ENGINE_TYPE_NGTPANNG = "NGTPANNG";
|
||||
const char* NAME_ENGINE_TYPE_NGTONNG = "NGTONNG";
|
||||
|
||||
const char* NAME_METRIC_TYPE_L2 = "L2";
|
||||
const char* NAME_METRIC_TYPE_IP = "IP";
|
||||
|
|
|
@ -29,6 +29,8 @@ extern const char* NAME_ENGINE_TYPE_ANNOY;
|
|||
extern const char* NAME_ENGINE_TYPE_RHNSWFLAT;
|
||||
extern const char* NAME_ENGINE_TYPE_RHNSWPQ;
|
||||
extern const char* NAME_ENGINE_TYPE_RHNSWSQ;
|
||||
extern const char* NAME_ENGINE_TYPE_NGTPANNG;
|
||||
extern const char* NAME_ENGINE_TYPE_NGTONNG;
|
||||
|
||||
extern const char* NAME_METRIC_TYPE_L2;
|
||||
extern const char* NAME_METRIC_TYPE_IP;
|
||||
|
|
|
@ -263,7 +263,7 @@ void
|
|||
ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) {
|
||||
milvus_sdk::TimeRecorder rc("Create index");
|
||||
std::cout << "Wait until create all index done" << std::endl;
|
||||
JSON json_params = {{"index_type", "IVF_FLAT"}, {"metric_type", "L2"}, {"params", {{"nlist", nlist}}}};
|
||||
JSON json_params = {{"index_type", "NGT_PANNG"}, {"metric_type", "L2"}};
|
||||
milvus::IndexParam index1 = {collection_name, "field_vec", json_params.dump()};
|
||||
milvus_sdk::Utils::PrintIndexParam(index1);
|
||||
milvus::Status stat = conn_->CreateIndex(index1);
|
||||
|
|
|
@ -123,6 +123,10 @@ Utils::IndexTypeName(const milvus::IndexType& index_type) {
|
|||
return "RHNSWPQ";
|
||||
case milvus::IndexType::ANNOY:
|
||||
return "ANNOY";
|
||||
case milvus::IndexType::NGTPANNG:
|
||||
return "NGTPANNG";
|
||||
case milvus::IndexType::NGTONNG:
|
||||
return "NGTONNG";
|
||||
default:
|
||||
return "Unknown index type";
|
||||
}
|
||||
|
|
|
@ -43,6 +43,8 @@ enum class IndexType {
|
|||
RHNSWFLAT = 13,
|
||||
RHNSWPQ = 14,
|
||||
RHNSWSQ = 15,
|
||||
NGTPANNG = 16,
|
||||
NGTONNG = 17,
|
||||
};
|
||||
|
||||
enum class MetricType {
|
||||
|
|
Loading…
Reference in New Issue