Fix missing file

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/4973/head^2
zhenshan.cao 2020-12-18 18:35:03 +08:00 committed by yefu.chen
parent 0110ba6bd2
commit 0d75840ed6
262 changed files with 7922 additions and 4584 deletions

View File

@ -1,17 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/zilliztech/milvus-distributed/internal/storage"
)
// main prints the binlog files named on the command line by handing the
// argument list to storage.PrintBinlogFiles.
//
// The process exits with status 1 when no file is given or when printing
// fails, so callers and scripts can detect the failure.
func main() {
	// Without file arguments there is nothing to print: show the usage line and
	// stop instead of falling through with an empty file list.
	if len(os.Args) < 2 {
		fmt.Println("usage: binlog file1 file2 ...")
		os.Exit(1)
	}
	if err := storage.PrintBinlogFiles(os.Args[1:]); err != nil {
		fmt.Printf("error: %s\n", err.Error())
		os.Exit(1)
	}
}

View File

@ -17,4 +17,4 @@ target_sources( cache PRIVATE ${CACHE_FILES}
Cache.inl
)
target_include_directories( cache PUBLIC ${MILVUS_ENGINE_SRC}/cache )
target_link_libraries( cache PRIVATE fiu)

View File

@ -12,7 +12,6 @@
#pragma once
#include "Cache.h"
// #include "s/Metrics.h"
#include "utils/Log.h"
#include <memory>

View File

@ -47,7 +47,6 @@ CacheMgr<ItemObj>::GetItem(const std::string& key) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return nullptr;
}
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
return cache_->get(key);
}
@ -59,7 +58,6 @@ CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data) {
return;
}
cache_->insert(key, data);
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
}
template <typename ItemObj>
@ -70,7 +68,6 @@ CacheMgr<ItemObj>::EraseItem(const std::string& key) {
return;
}
cache_->erase(key);
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
}
template <typename ItemObj>

View File

@ -13,36 +13,42 @@
#include <utility>
// #include <fiu/fiu-local.h>
#include <fiu/fiu-local.h>
#include "config/ServerConfig.h"
#include "utils/Log.h"
#include "value/config/ServerConfig.h"
namespace milvus {
namespace cache {
CpuCacheMgr::CpuCacheMgr() {
// cache_ = std::make_shared<Cache<DataObjPtr>>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]");
// if (config.cache.cpu_cache_threshold() > 0.0) {
// cache_->set_freemem_percent(config.cache.cpu_cache_threshold());
// }
ConfigMgr::GetInstance().Attach("cache.cache_size", this);
}
CpuCacheMgr::~CpuCacheMgr() {
ConfigMgr::GetInstance().Detach("cache.cache_size", this);
}
CpuCacheMgr&
CpuCacheMgr::GetInstance() {
static CpuCacheMgr s_mgr;
return s_mgr;
}
CpuCacheMgr::CpuCacheMgr() {
cache_ = std::make_shared<Cache<DataObjPtr>>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]");
if (config.cache.cpu_cache_threshold() > 0.0) {
cache_->set_freemem_percent(config.cache.cpu_cache_threshold());
}
ConfigMgr::GetInstance().Attach("cache.cache_size", this);
}
CpuCacheMgr::~CpuCacheMgr() {
ConfigMgr::GetInstance().Detach("cache.cache_size", this);
}
DataObjPtr
CpuCacheMgr::GetItem(const std::string& key) {
auto ret = CacheMgr<DataObjPtr>::GetItem(key);
return ret;
}
void
CpuCacheMgr::ConfigUpdate(const std::string& name) {
// SetCapacity(config.cache.cache_size());
SetCapacity(config.cache.cache_size());
}
} // namespace cache

View File

@ -16,20 +16,24 @@
#include "cache/CacheMgr.h"
#include "cache/DataObj.h"
#include "config/ConfigMgr.h"
#include "value/config/ConfigMgr.h"
namespace milvus {
namespace cache {
class CpuCacheMgr : public CacheMgr<DataObjPtr>, public ConfigObserver {
public:
static CpuCacheMgr&
GetInstance();
private:
CpuCacheMgr();
~CpuCacheMgr();
public:
static CpuCacheMgr&
GetInstance();
DataObjPtr
GetItem(const std::string& key) override;
public:
void

View File

@ -20,8 +20,6 @@ class DataObj {
public:
virtual int64_t
Size() = 0;
public:
virtual ~DataObj() = default;
};

View File

@ -10,10 +10,10 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "cache/GpuCacheMgr.h"
#include "config/ServerConfig.h"
#include "utils/Log.h"
#include "value/config/ServerConfig.h"
// #include <fiu/fiu-local.h>
#include <fiu/fiu-local.h>
#include <sstream>
#include <utility>

View File

@ -17,7 +17,7 @@
#include "cache/CacheMgr.h"
#include "cache/DataObj.h"
#include "config/ConfigMgr.h"
#include "value/config/ConfigMgr.h"
namespace milvus {
namespace cache {

View File

@ -1,31 +0,0 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
# library
# Headers and sources compiled into the `config` library.
set( CONFIG_SRCS ConfigMgr.h
                 ConfigMgr.cpp
                 ConfigType.h
                 ConfigType.cpp
                 ServerConfig.h
                 ServerConfig.cpp
   )
# Libraries the `config` target links against (ConfigMgr.cpp includes <yaml-cpp/yaml.h>).
set( CONFIG_LIBS yaml-cpp
   )
# Declare the `config` target from the two lists above.
create_library(
    TARGET config
    SRCS ${CONFIG_SRCS}
    LIBS ${CONFIG_LIBS}
)

View File

@ -1,219 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <yaml-cpp/yaml.h>
#include <cstring>
#include <limits>
#include <unordered_map>
#include <iostream>
#include "config/ConfigMgr.h"
#include "config/ServerConfig.h"
namespace {
// Number of bytes in one mebibyte; used for the size bounds in the config table below.
const int64_t MB = (1024ll * 1024);
// Number of bytes in one gibibyte.
const int64_t GB = (1024ll * 1024 * 1024);
// Recursively copies a YAML mapping into a flat string table.
//
// Each entry of `node` is stored in `target` under "<prefix>.<key>" (or just
// "<key>" when `prefix` is empty):
//   - null      -> empty string
//   - scalar    -> its string form
//   - sequence  -> every element followed by a ',' (so the value ends with a comma)
//   - map       -> flattened recursively with the composed key as the new prefix
// An undefined child node makes the function throw the C string "Unexpected".
void
Flatten(const YAML::Node& node, std::unordered_map<std::string, std::string>& target, const std::string& prefix) {
    for (auto& entry : node) {
        auto path = prefix.empty() ? entry.first.as<std::string>() : prefix + "." + entry.first.as<std::string>();
        const auto kind = entry.second.Type();

        if (kind == YAML::NodeType::Undefined) {
            throw "Unexpected";
        }
        if (kind == YAML::NodeType::Map) {
            Flatten(entry.second, target, path);
            continue;
        }

        if (kind == YAML::NodeType::Null) {
            target[path] = "";
        } else if (kind == YAML::NodeType::Scalar) {
            target[path] = entry.second.as<std::string>();
        } else if (kind == YAML::NodeType::Sequence) {
            std::string joined;
            for (auto& element : entry.second) {
                joined += element.as<std::string>() + ",";
            }
            target[path] = joined;
        }
        // Any other node type is ignored.
    }
}
// Turns a non-successful ConfigStatus into an exception: throws a copy of `cs`
// unless its set_return is SetReturn::SUCCESS.
void
ThrowIfNotSuccess(const milvus::ConfigStatus& cs) {
    const bool succeeded = (cs.set_return == milvus::SetReturn::SUCCESS);
    if (succeeded) {
        return;
    }
    throw cs;
}
}; // namespace
namespace milvus {
ConfigMgr ConfigMgr::instance;
// Builds the table of known configuration entries.
//
// Every entry maps a dotted key to a typed config object that writes into the
// matching field of the global `config` structure. The second macro argument is
// the `modifiable` flag; the last two arguments are the optional validation and
// update callbacks (all left as nullptr here). Values are only written to the
// bound fields once Init() runs.
ConfigMgr::ConfigMgr() {
    config_list_ = {
        /* general */
        {"timezone", CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},
        /* network */
        {"network.address",
         CreateStringConfig("network.address", false, &config.network.address.value, "0.0.0.0", nullptr, nullptr)},
        // Ports are restricted to [0, 65535].
        {"network.port",
         CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, 19530, nullptr, nullptr)},
        /* pulsar */
        {"pulsar.address",
         CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, "localhost", nullptr, nullptr)},
        {"pulsar.port",
         CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, 6650, nullptr, nullptr)},
        /* log */
        {"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)},
        {"logs.trace.enable",
         CreateBoolConfig("logs.trace.enable", false, &config.logs.trace.enable.value, true, nullptr, nullptr)},
        {"logs.path",
         CreateStringConfig("logs.path", false, &config.logs.path.value, "/var/lib/milvus/logs", nullptr, nullptr)},
        // Size in bytes: allowed range [512 MB, 4096 MB], default 1024 MB.
        {"logs.max_log_file_size", CreateSizeConfig("logs.max_log_file_size", false, 512 * MB, 4096 * MB,
                                                    &config.logs.max_log_file_size.value, 1024 * MB, nullptr, nullptr)},
        {"logs.log_rotate_num", CreateIntegerConfig("logs.log_rotate_num", false, 0, 1024,
                                                    &config.logs.log_rotate_num.value, 0, nullptr, nullptr)},
        /* tracing */
        {"tracing.json_config_path", CreateStringConfig("tracing.json_config_path", false,
                                                        &config.tracing.json_config_path.value, "", nullptr, nullptr)},
        /* invisible */
        /* engine */
        {"engine.build_index_threshold",
         CreateIntegerConfig("engine.build_index_threshold", false, 0, std::numeric_limits<int64_t>::max(),
                             &config.engine.build_index_threshold.value, 4096, nullptr, nullptr)},
        // The next three entries are the only ones registered as modifiable.
        {"engine.search_combine_nq",
         CreateIntegerConfig("engine.search_combine_nq", true, 0, std::numeric_limits<int64_t>::max(),
                             &config.engine.search_combine_nq.value, 64, nullptr, nullptr)},
        {"engine.use_blas_threshold",
         CreateIntegerConfig("engine.use_blas_threshold", true, 0, std::numeric_limits<int64_t>::max(),
                             &config.engine.use_blas_threshold.value, 1100, nullptr, nullptr)},
        {"engine.omp_thread_num",
         CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits<int64_t>::max(),
                             &config.engine.omp_thread_num.value, 0, nullptr, nullptr)},
        {"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value,
                                              SimdType::AUTO, nullptr, nullptr)},
    };
}
// Initializes every registered config entry (each entry writes its default
// value into the field it is bound to) while holding the config mutex.
void
ConfigMgr::Init() {
    std::lock_guard<std::mutex> guard(GetConfigMutex());
    for (auto it = config_list_.begin(); it != config_list_.end(); ++it) {
        it->second->Init();
    }
}
void
ConfigMgr::Load(const std::string& path) {
/* load from milvus.yaml */
auto yaml = YAML::LoadFile(path);
/* make it flattened */
std::unordered_map<std::string, std::string> flattened;
// auto proxy_yaml = yaml["porxy"];
auto other_yaml = YAML::Node{};
other_yaml["pulsar"] = yaml["pulsar"];
Flatten(yaml["proxy"], flattened, "");
Flatten(other_yaml, flattened, "");
// Flatten(yaml["proxy"], flattened, "");
/* update config */
for (auto& it : flattened) Set(it.first, it.second, false);
}
// Sets the config entry `name` to `value`.
//
// @param name    dotted config key; unknown keys are reported on stdout and ignored.
// @param value   textual value, parsed by the entry's own Set().
// @param update  false when loading from the config file (the entry is set
//                regardless of its `modifiable_` flag and observers are not
//                notified); true for a manual update, which is only allowed on
//                modifiable entries and is followed by Notify(name).
// @throws ConfigStatus when the entry rejects the value or is not modifiable.
// @throws std::string  for any other exception raised while setting/notifying.
void
ConfigMgr::Set(const std::string& name, const std::string& value, bool update) {
    std::cout << "InSet Config " << name << std::endl;
    // Single lookup: the iterator is reused below instead of calling at() again.
    const auto found = config_list_.find(name);
    if (found == config_list_.end()) {
        std::cout << "Config " << name << " not found!" << std::endl;
        return;
    }
    try {
        auto& config = found->second;
        std::unique_lock<std::mutex> lock(GetConfigMutex());
        /* update=false when loading from config file */
        if (not update) {
            ThrowIfNotSuccess(config->Set(value, update));
        } else if (config->modifiable_) {
            /* set manually */
            ThrowIfNotSuccess(config->Set(value, update));
            // Release the config mutex before calling observers.
            lock.unlock();
            Notify(name);
        } else {
            throw ConfigStatus(SetReturn::IMMUTABLE, "Config " + name + " is not modifiable");
        }
    } catch (ConfigStatus&) {
        // Re-throw the in-flight exception object instead of throwing a copy.
        throw;
    } catch (...) {
        // NOTE(review): the key is known to exist at this point, so this message
        // is misleading for whatever error ends up here; text kept for compatibility.
        throw "Config " + name + " not found.";
    }
}
// Returns the current value of config entry `name` as text, read under the
// config mutex. Any failure (including an unknown key) is reported by throwing
// the std::string "Config <name> not found.".
std::string
ConfigMgr::Get(const std::string& name) const {
    try {
        auto& entry = config_list_.at(name);
        std::lock_guard<std::mutex> guard(GetConfigMutex());
        std::string current = entry->Get();
        return current;
    } catch (...) {
        throw "Config " + name + " not found.";
    }
}
// Renders every registered entry as a "<name>: <value>" line and returns the
// concatenated text.
// NOTE(review): unlike Get(), this reads the entries without taking the config
// mutex - confirm whether that is intended.
std::string
ConfigMgr::Dump() const {
    std::stringstream out;
    for (auto it = config_list_.begin(); it != config_list_.end(); ++it) {
        auto& entry = it->second;
        out << entry->name_ << ": " << entry->Get() << std::endl;
    }
    return out.str();
}
// Registers `observer` for change notifications on config key `name`.
// The same observer may be attached more than once; no de-duplication is done.
void
ConfigMgr::Attach(const std::string& name, ConfigObserver* observer) {
    std::lock_guard<std::mutex> guard(observer_mutex_);
    auto& subscribers = observers_[name];
    subscribers.push_back(observer);
}
// Removes every registration of `observer` for config key `name`.
// Does nothing when no observer list exists for that key.
void
ConfigMgr::Detach(const std::string& name, ConfigObserver* observer) {
    std::lock_guard<std::mutex> guard(observer_mutex_);
    auto entry = observers_.find(name);
    if (entry != observers_.end()) {
        entry->second.remove(observer);
    }
}
// Calls ConfigUpdate(name) on every observer attached to config key `name`.
// The observer mutex is held for the whole round of callbacks.
void
ConfigMgr::Notify(const std::string& name) {
    std::lock_guard<std::mutex> guard(observer_mutex_);
    auto entry = observers_.find(name);
    if (entry == observers_.end()) {
        return;
    }
    for (auto& subscriber : entry->second) {
        subscriber->ConfigUpdate(name);
    }
}
} // namespace milvus

View File

@ -1,556 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "config/ConfigType.h"
#include "config/ServerConfig.h"
#include <strings.h>
#include <algorithm>
#include <cassert>
#include <functional>
#include <sstream>
#include <string>
namespace {

// Multiplier for each single-letter size suffix accepted by parse_bytes (keys are lower-case).
std::unordered_map<std::string, int64_t> BYTE_UNITS = {
    {"b", 1},
    {"k", 1024},
    {"m", 1024 * 1024},
    {"g", 1024 * 1024 * 1024},
};

// Returns true when `s` is an optional leading '-' followed by at least one
// decimal digit. A lone "-" is rejected so that callers can hand the string to
// std::stoll without triggering std::invalid_argument.
bool
is_integer(const std::string& s) {
    if (s.empty()) {
        return false;
    }
    const std::string::size_type first_digit = (s[0] == '-') ? 1 : 0;
    if (first_digit == s.size()) {
        return false;
    }
    // The lambda takes unsigned char: passing a negative char to std::isdigit is undefined.
    return std::all_of(s.begin() + first_digit, s.end(), [](unsigned char c) { return std::isdigit(c) != 0; });
}

// Returns true when `s` is non-empty and consists of decimal digits only.
bool
is_number(const std::string& s) {
    return !s.empty() && std::all_of(s.begin(), s.end(), [](unsigned char c) { return std::isdigit(c) != 0; });
}

// Returns true when `s` is non-empty and consists of alphabetic characters only.
bool
is_alpha(const std::string& s) {
    return !s.empty() && std::all_of(s.begin(), s.end(), [](unsigned char c) { return std::isalpha(c) != 0; });
}

// Returns true when lower_bound <= val <= upper_bound.
template <typename T>
bool
boundary_check(T val, T lower_bound, T upper_bound) {
    return lower_bound <= val && val <= upper_bound;
}

// Parses "true"/"false" (case-insensitive). For any other text the function
// returns false and stores a message in `err`; `err` is left untouched on success.
bool
parse_bool(const std::string& str, std::string& err) {
    if (!strcasecmp(str.c_str(), "true")) {
        return true;
    }
    if (!strcasecmp(str.c_str(), "false")) {
        return false;
    }
    err = "The specified value must be true or false";
    return false;
}

// Returns a lower-cased copy of `s`.
std::string
str_tolower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
    return s;
}

// Converts a size such as "1024", "512k", "4MB" or "2g" into a byte count.
//
// Accepted forms: plain digits (bytes), or digits followed by one of the
// units b/k/m/g in either case, optionally followed by 'b'/'B' ("kb", "MB").
// An empty string yields 0 without an error.
//
// On failure the function returns 0 and stores a description in `err`:
// values containing '-', values with a missing/unknown unit or no digits,
// and values whose byte count does not fit into int64_t.
int64_t
parse_bytes(const std::string& str, std::string& err) {
    try {
        if (str.find_first_of('-') != std::string::npos) {
            std::stringstream ss;
            ss << "The specified value for memory (" << str << ") should be a positive integer.";
            err = ss.str();
            return 0;
        }

        std::string s = str;
        if (is_number(s)) {
            return std::stoll(s);
        }
        if (s.empty()) {
            return 0;
        }

        // Reduce a two-letter unit ("kb", "MB", ...) to its first letter. The
        // length check keeps one-character input away from substr(length - 2).
        if (s.length() >= 2 && is_alpha(s.substr(s.length() - 2))) {
            const char last = s.back();
            if (last == 'b' || last == 'B') {
                s.pop_back();
            }
        }

        auto suffix = str_tolower(s.substr(s.length() - 1));
        std::string digits_part;
        if (is_number(suffix)) {
            // No unit letter at the end: treat the whole text as a byte count.
            digits_part = s;
            suffix = "b";
        } else {
            digits_part = s.substr(0, s.length() - 1);
        }

        const auto unit = BYTE_UNITS.find(suffix);
        if (is_number(digits_part) && unit != BYTE_UNITS.end()) {
            const int64_t digits = std::stoll(digits_part);
            // digits is non-negative here; refuse products that would overflow int64_t.
            if (digits > INT64_MAX / unit->second) {
                std::stringstream ss;
                ss << "The specified value for memory (" << str << ") is too large.";
                err = ss.str();
                return 0;
            }
            return digits * unit->second;
        }

        std::stringstream ss;
        ss << "The specified value for memory (" << str << ") should specify the units."
           << "The postfix should be one of the `b` `k` `m` `g` characters";
        err = ss.str();
    } catch (...) {
        err = "Unknown error happened on parse bytes.";
    }
    return 0;
}

}  // namespace
// Use (void) to silent unused warnings.
#define assertm(exp, msg) assert(((void)msg, exp))
namespace milvus {
std::vector<std::string>
OptionValue(const configEnum& ce) {
std::vector<std::string> ret;
for (auto& e : ce) {
ret.emplace_back(e.first);
}
return ret;
}
// Stores the name/alias pointers (the strings are not copied) and the
// modifiable flag; the entry stays uninitialized until Init() is called.
BaseConfig::BaseConfig(const char* name, const char* alias, bool modifiable)
    : name_(name),
      alias_(alias),
      modifiable_(modifiable) {
}
// Marks the entry as initialized; asserts that this happens only once.
void
BaseConfig::Init() {
    assertm(!inited_, "already initialized");
    inited_ = true;
}
// Binds a boolean setting to the storage at `config`.
// `is_valid_fn` and `update_fn` are optional callbacks (may be empty) used by Set().
BoolConfig::BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
                       std::function<bool(bool val, std::string& err)> is_valid_fn,
                       std::function<bool(bool val, bool prev, std::string& err)> update_fn)
    : BaseConfig(name, alias, modifiable),
      config_(config),
      default_value_(default_value),
      is_valid_fn_(std::move(is_valid_fn)),
      update_fn_(std::move(update_fn)) {
}
// Marks the entry initialized and writes the default value into the bound bool.
void
BoolConfig::Init() {
    BaseConfig::Init();

    assert(config_ != nullptr);
    *config_ = default_value_;
}
// Parses `val` as "true"/"false" and stores it in the bound bool.
//
// Returns IMMUTABLE when `update` is requested on a non-modifiable entry,
// INVALID when parsing or the validation callback fails, UPDATE_FAILURE when
// the update callback rejects the change (the previous value is restored),
// and SUCCESS otherwise. Exceptions are converted to EXCEPTION / UNEXPECTED.
ConfigStatus
BoolConfig::Set(const std::string& val, bool update) {
    assertm(inited_, "uninitialized");
    try {
        if (update && !modifiable_) {
            std::stringstream msg;
            msg << "Config " << name_ << " is immutable.";
            return ConfigStatus(SetReturn::IMMUTABLE, msg.str());
        }

        std::string error;
        const bool parsed = parse_bool(val, error);
        if (!error.empty()) {
            return ConfigStatus(SetReturn::INVALID, error);
        }
        if (is_valid_fn_ && !is_valid_fn_(parsed, error)) {
            return ConfigStatus(SetReturn::INVALID, error);
        }

        // Apply first, then let the update callback veto and roll back.
        const bool previous = *config_;
        *config_ = parsed;
        if (update && update_fn_ && !update_fn_(parsed, previous, error)) {
            *config_ = previous;
            return ConfigStatus(SetReturn::UPDATE_FAILURE, error);
        }
        return ConfigStatus(SetReturn::SUCCESS, "");
    } catch (std::exception& e) {
        return ConfigStatus(SetReturn::EXCEPTION, e.what());
    } catch (...) {
        return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
    }
}
// Returns the bound value as the text "true" or "false".
std::string
BoolConfig::Get() {
    assertm(inited_, "uninitialized");
    if (*config_) {
        return "true";
    }
    return "false";
}
// Binds a string setting to the storage at `config`.
// `default_value` is kept as a pointer (not copied) until Init() assigns it.
StringConfig::StringConfig(const char* name, const char* alias, bool modifiable, std::string* config,
                           const char* default_value,
                           std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
                           std::function<bool(const std::string& val, const std::string& prev, std::string& err)>
                               update_fn)
    : BaseConfig(name, alias, modifiable),
      config_(config),
      default_value_(default_value),
      is_valid_fn_(std::move(is_valid_fn)),
      update_fn_(std::move(update_fn)) {
}
// Marks the entry initialized and writes the default text into the bound string.
void
StringConfig::Init() {
    BaseConfig::Init();

    assert(config_ != nullptr);
    *config_ = default_value_;
}
// Stores `val` in the bound string.
//
// Returns IMMUTABLE when `update` is requested on a non-modifiable entry,
// INVALID when the validation callback rejects the value, UPDATE_FAILURE when
// the update callback rejects the change (the previous value is restored),
// and SUCCESS otherwise. Exceptions are converted to EXCEPTION / UNEXPECTED.
ConfigStatus
StringConfig::Set(const std::string& val, bool update) {
    assertm(inited_, "uninitialized");
    try {
        if (update && !modifiable_) {
            std::stringstream msg;
            msg << "Config " << name_ << " is immutable.";
            return ConfigStatus(SetReturn::IMMUTABLE, msg.str());
        }

        std::string error;
        if (is_valid_fn_ && !is_valid_fn_(val, error)) {
            return ConfigStatus(SetReturn::INVALID, error);
        }

        // Apply first, then let the update callback veto and roll back.
        std::string previous = *config_;
        *config_ = val;
        if (update && update_fn_ && !update_fn_(val, previous, error)) {
            *config_ = previous;
            return ConfigStatus(SetReturn::UPDATE_FAILURE, error);
        }
        return ConfigStatus(SetReturn::SUCCESS, "");
    } catch (std::exception& e) {
        return ConfigStatus(SetReturn::EXCEPTION, e.what());
    } catch (...) {
        return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
    }
}
// Returns a copy of the bound string.
std::string
StringConfig::Get() {
    assertm(inited_, "uninitialized");
    std::string current = *config_;
    return current;
}
// Binds an enum-like setting to the int64 storage at `config`.
// `enumd` points to the name -> value table used to translate textual values.
EnumConfig::EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
                       int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
                       std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
    : BaseConfig(name, alias, modifiable),
      config_(config),
      enum_value_(enumd),
      default_value_(default_value),
      is_valid_fn_(std::move(is_valid_fn)),
      update_fn_(std::move(update_fn)) {
}
// Marks the entry initialized, asserts that an enum table and a bound field
// exist (and that the table is not empty), then writes the default value.
void
EnumConfig::Init() {
    BaseConfig::Init();

    assert(enum_value_ != nullptr);
    assertm(!enum_value_->empty(), "enum value empty");
    assert(config_ != nullptr);

    *config_ = default_value_;
}
// Translates the option name `val` through the enum table and stores the
// resulting value in the bound int64.
//
// Returns IMMUTABLE when `update` is requested on a non-modifiable entry,
// ENUM_VALUE_NOTFOUND (with the list of accepted names) when `val` is not in
// the table, INVALID when the validation callback rejects the value,
// UPDATE_FAILURE when the update callback rejects the change (the previous
// value is restored), and SUCCESS otherwise. Exceptions are converted to
// EXCEPTION / UNEXPECTED.
ConfigStatus
EnumConfig::Set(const std::string& val, bool update) {
    assertm(inited_, "uninitialized");
    try {
        if (update and not modifiable_) {
            std::stringstream ss;
            ss << "Config " << name_ << " is immutable.";
            return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
        }

        // Single lookup; the iterator is reused to read the mapped value.
        const auto found = enum_value_->find(val);
        if (found == enum_value_->end()) {
            const auto option_values = OptionValue(*enum_value_);
            std::stringstream ss;
            ss << "Config " << name_ << "(" << val << ") must be one of following: ";
            // Index-against-size form stays well defined for an empty table
            // (the emptiness assert in Init() is compiled out in release builds).
            for (size_t i = 0; i < option_values.size(); ++i) {
                ss << option_values[i] << (i + 1 < option_values.size() ? ", " : ".");
            }
            return ConfigStatus(SetReturn::ENUM_VALUE_NOTFOUND, ss.str());
        }

        int64_t value = found->second;

        std::string err;
        if (is_valid_fn_ && not is_valid_fn_(value, err)) {
            return ConfigStatus(SetReturn::INVALID, err);
        }

        // Apply first, then let the update callback veto and roll back.
        int64_t prev = *config_;
        *config_ = value;
        if (update && update_fn_ && not update_fn_(value, prev, err)) {
            *config_ = prev;
            return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
        }

        return ConfigStatus(SetReturn::SUCCESS, "");
    } catch (std::exception& e) {
        return ConfigStatus(SetReturn::EXCEPTION, e.what());
    } catch (...) {
        return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
    }
}
std::string
EnumConfig::Get() {
assertm(inited_, "uninitialized");
for (auto& it : *enum_value_) {
if (*config_ == it.second) {
return it.first;
}
}
return "unknown";
}
// Binds an integer setting with the inclusive range [lower_bound, upper_bound]
// to the int64 storage at `config`.
IntegerConfig::IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound,
                             int64_t upper_bound, int64_t* config, int64_t default_value,
                             std::function<bool(int64_t val, std::string& err)> is_valid_fn,
                             std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
    : BaseConfig(name, alias, modifiable),
      config_(config),
      lower_bound_(lower_bound),
      upper_bound_(upper_bound),
      default_value_(default_value),
      is_valid_fn_(std::move(is_valid_fn)),
      update_fn_(std::move(update_fn)) {
}
// Marks the entry initialized and writes the default value into the bound int64.
void
IntegerConfig::Init() {
    BaseConfig::Init();

    assert(config_ != nullptr);
    *config_ = default_value_;
}
// Parses `val` as a decimal integer and stores it in the bound int64.
//
// Returns IMMUTABLE when `update` is requested on a non-modifiable entry,
// INVALID when `val` is not an integer or the validation callback rejects it,
// OUT_OF_RANGE when the value lies outside [lower_bound_, upper_bound_],
// UPDATE_FAILURE when the update callback rejects the change (the previous
// value is restored), and SUCCESS otherwise. Exceptions are converted to
// EXCEPTION / UNEXPECTED.
ConfigStatus
IntegerConfig::Set(const std::string& val, bool update) {
    assertm(inited_, "uninitialized");
    try {
        if (update && !modifiable_) {
            std::stringstream msg;
            msg << "Config " << name_ << " is immutable.";
            return ConfigStatus(SetReturn::IMMUTABLE, msg.str());
        }

        if (!is_integer(val)) {
            std::stringstream msg;
            msg << "Config " << name_ << "(" << val << ") must be a integer.";
            return ConfigStatus(SetReturn::INVALID, msg.str());
        }

        const int64_t parsed = std::stoll(val);
        const bool in_range = boundary_check<int64_t>(parsed, lower_bound_, upper_bound_);
        if (!in_range) {
            std::stringstream msg;
            msg << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_
                << "].";
            return ConfigStatus(SetReturn::OUT_OF_RANGE, msg.str());
        }

        std::string error;
        if (is_valid_fn_ && !is_valid_fn_(parsed, error)) {
            return ConfigStatus(SetReturn::INVALID, error);
        }

        // Apply first, then let the update callback veto and roll back.
        const int64_t previous = *config_;
        *config_ = parsed;
        if (update && update_fn_ && !update_fn_(parsed, previous, error)) {
            *config_ = previous;
            return ConfigStatus(SetReturn::UPDATE_FAILURE, error);
        }
        return ConfigStatus(SetReturn::SUCCESS, "");
    } catch (std::exception& e) {
        return ConfigStatus(SetReturn::EXCEPTION, e.what());
    } catch (...) {
        return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
    }
}
// Returns the bound value formatted with std::to_string.
std::string
IntegerConfig::Get() {
    assertm(inited_, "uninitialized");
    const int64_t current = *config_;
    return std::to_string(current);
}
// Binds a floating-point setting with the inclusive range
// [lower_bound, upper_bound] to the double storage at `config`.
FloatingConfig::FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound,
                               double upper_bound, double* config, double default_value,
                               std::function<bool(double val, std::string& err)> is_valid_fn,
                               std::function<bool(double val, double prev, std::string& err)> update_fn)
    : BaseConfig(name, alias, modifiable),
      config_(config),
      lower_bound_(lower_bound),
      upper_bound_(upper_bound),
      default_value_(default_value),
      is_valid_fn_(std::move(is_valid_fn)),
      update_fn_(std::move(update_fn)) {
}
// Marks the entry initialized and writes the default value into the bound double.
void
FloatingConfig::Init() {
    BaseConfig::Init();

    assert(config_ != nullptr);
    *config_ = default_value_;
}
// Parses `val` as a floating-point number and stores it in the bound double.
//
// Returns IMMUTABLE when `update` is requested on a non-modifiable entry,
// INVALID when `val` has characters left over after the number or the
// validation callback rejects it, OUT_OF_RANGE when the value lies outside
// [lower_bound_, upper_bound_], UPDATE_FAILURE when the update callback
// rejects the change (the previous value is restored), and SUCCESS otherwise.
// Exceptions (including those thrown by std::stod for non-numeric text) are
// converted to EXCEPTION / UNEXPECTED.
ConfigStatus
FloatingConfig::Set(const std::string& val, bool update) {
    assertm(inited_, "uninitialized");
    try {
        if (update and not modifiable_) {
            std::stringstream ss;
            ss << "Config " << name_ << " is immutable.";
            return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
        }

        // std::stod stops at the first character it cannot use; require that the
        // whole string was consumed so that input such as "1.5abc" is rejected
        // instead of being silently read as 1.5.
        std::string::size_type consumed = 0;
        double value = std::stod(val, &consumed);
        if (consumed != val.size()) {
            std::stringstream ss;
            ss << "Config " << name_ << "(" << val << ") must be a number.";
            return ConfigStatus(SetReturn::INVALID, ss.str());
        }

        if (not boundary_check<double>(value, lower_bound_, upper_bound_)) {
            std::stringstream ss;
            ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_
               << "].";
            return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
        }

        std::string err;
        if (is_valid_fn_ && not is_valid_fn_(value, err))
            return ConfigStatus(SetReturn::INVALID, err);

        // Apply first, then let the update callback veto and roll back.
        double prev = *config_;
        *config_ = value;
        if (update && update_fn_ && not update_fn_(value, prev, err)) {
            *config_ = prev;
            return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
        }

        return ConfigStatus(SetReturn::SUCCESS, "");
    } catch (std::exception& e) {
        return ConfigStatus(SetReturn::EXCEPTION, e.what());
    } catch (...) {
        return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
    }
}
// Returns the bound value formatted with std::to_string.
std::string
FloatingConfig::Get() {
    assertm(inited_, "uninitialized");
    const double current = *config_;
    return std::to_string(current);
}
// Binds a byte-size setting with the inclusive range [lower_bound, upper_bound]
// (in bytes) to the int64 storage at `config`.
SizeConfig::SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound,
                       int64_t upper_bound, int64_t* config, int64_t default_value,
                       std::function<bool(int64_t val, std::string& err)> is_valid_fn,
                       std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
    : BaseConfig(name, alias, modifiable),
      config_(config),
      lower_bound_(lower_bound),
      upper_bound_(upper_bound),
      default_value_(default_value),
      is_valid_fn_(std::move(is_valid_fn)),
      update_fn_(std::move(update_fn)) {
}
// Marks the entry initialized and writes the default size into the bound int64.
void
SizeConfig::Init() {
    BaseConfig::Init();

    assert(config_ != nullptr);
    *config_ = default_value_;
}
// Parses `val` as a byte size (see parse_bytes) and stores it in the bound int64.
//
// Returns IMMUTABLE when `update` is requested on a non-modifiable entry,
// INVALID when parsing or the validation callback fails, OUT_OF_RANGE when the
// size lies outside [lower_bound_, upper_bound_], UPDATE_FAILURE when the
// update callback rejects the change (the previous value is restored), and
// SUCCESS otherwise. Exceptions are converted to EXCEPTION / UNEXPECTED.
ConfigStatus
SizeConfig::Set(const std::string& val, bool update) {
    assertm(inited_, "uninitialized");
    try {
        if (update && !modifiable_) {
            std::stringstream msg;
            msg << "Config " << name_ << " is immutable.";
            return ConfigStatus(SetReturn::IMMUTABLE, msg.str());
        }

        std::string error;
        const int64_t bytes = parse_bytes(val, error);
        if (!error.empty()) {
            return ConfigStatus(SetReturn::INVALID, error);
        }

        const bool in_range = boundary_check<int64_t>(bytes, lower_bound_, upper_bound_);
        if (!in_range) {
            std::stringstream msg;
            msg << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << " Byte, " << upper_bound_
                << " Byte].";
            return ConfigStatus(SetReturn::OUT_OF_RANGE, msg.str());
        }

        if (is_valid_fn_ && !is_valid_fn_(bytes, error)) {
            return ConfigStatus(SetReturn::INVALID, error);
        }

        // Apply first, then let the update callback veto and roll back.
        const int64_t previous = *config_;
        *config_ = bytes;
        if (update && update_fn_ && !update_fn_(bytes, previous, error)) {
            *config_ = previous;
            return ConfigStatus(SetReturn::UPDATE_FAILURE, error);
        }
        return ConfigStatus(SetReturn::SUCCESS, "");
    } catch (std::exception& e) {
        return ConfigStatus(SetReturn::EXCEPTION, e.what());
    } catch (...) {
        return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
    }
}
// Returns the bound size in bytes formatted with std::to_string.
std::string
SizeConfig::Get() {
    assertm(inited_, "uninitialized");
    const int64_t bytes = *config_;
    return std::to_string(bytes);
}
} // namespace milvus

View File

@ -1,265 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace milvus {
using configEnum = const std::unordered_map<std::string, int64_t>;
std::vector<std::string>
OptionValue(const configEnum& ce);
// Outcome codes carried by ConfigStatus and returned from the Set() methods of
// the config types.
enum SetReturn {
    SUCCESS = 1,          // value accepted and stored
    IMMUTABLE,            // update requested on an entry that is not modifiable
    ENUM_VALUE_NOTFOUND,  // text is not a key of the enum table
    INVALID,              // text could not be parsed, or the validation callback rejected it
    OUT_OF_RANGE,         // value lies outside the configured bounds
    UPDATE_FAILURE,       // update callback rejected the change; previous value restored
    EXCEPTION,            // a std::exception was caught while setting
    UNEXPECTED,           // any other exception was caught while setting
};
// Result of a Set() call: an outcome code plus a human-readable message
// (empty on success in the visible implementations).
struct ConfigStatus {
    ConfigStatus(SetReturn sr, std::string msg) : set_return(sr), message(std::move(msg)) {
    }
    SetReturn set_return;  // outcome code
    std::string message;   // description of the failure
};
// Common interface of all typed config entries.
//
// An entry has a name, an optional alias and a flag telling whether it may be
// changed through an update. Derived classes bind the entry to a variable and
// implement the textual Set()/Get() conversion.
class BaseConfig {
 public:
    // Stores the given pointers as-is; the strings are not copied.
    BaseConfig(const char* name, const char* alias, bool modifiable);
    virtual ~BaseConfig() = default;

 public:
    bool inited_ = false;    // set by Init(); Set()/Get() assert on it
    const char* name_;       // config key, as passed to the constructor
    const char* alias_;      // alternative name; the Create*Config macros pass nullptr
    const bool modifiable_;  // whether Set(..., update=true) is permitted

 public:
    // Marks the entry initialized; asserts it is called only once.
    virtual void
    Init();
    // Parses `value` and stores it; `update` distinguishes a manual update
    // (true) from loading the initial configuration (false).
    virtual ConfigStatus
    Set(const std::string& value, bool update) = 0;
    // Returns the current value as text.
    virtual std::string
    Get() = 0;
};
using BaseConfigPtr = std::shared_ptr<BaseConfig>;
// Config entry bound to a bool variable; accepts "true"/"false" (case-insensitive).
class BoolConfig : public BaseConfig {
 public:
    // `config` is the variable to bind; `is_valid_fn` / `update_fn` are
    // optional callbacks (may be empty).
    BoolConfig(const char* name,
               const char* alias,
               bool modifiable,
               bool* config,
               bool default_value,
               std::function<bool(bool val, std::string& err)> is_valid_fn,
               std::function<bool(bool val, bool prev, std::string& err)> update_fn);

 private:
    bool* config_;              // bound variable (not owned)
    const bool default_value_;  // written to *config_ by Init()
    // Returns false (filling `err`) to reject a parsed value.
    std::function<bool(bool val, std::string& err)> is_valid_fn_;
    // Called on updates after the value was stored; false rolls it back.
    std::function<bool(bool val, bool prev, std::string& err)> update_fn_;

 public:
    void
    Init() override;
    ConfigStatus
    Set(const std::string& value, bool update) override;
    std::string
    Get() override;
};
// Config entry bound to a std::string variable.
class StringConfig : public BaseConfig {
 public:
    // `config` is the variable to bind; `is_valid_fn` / `update_fn` are
    // optional callbacks (may be empty).
    StringConfig(const char* name,
                 const char* alias,
                 bool modifiable,
                 std::string* config,
                 const char* default_value,
                 std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
                 std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn);

 private:
    std::string* config_;        // bound variable (not owned)
    const char* default_value_;  // pointer kept as passed; assigned to *config_ by Init()
    // Returns false (filling `err`) to reject a value.
    std::function<bool(const std::string& val, std::string& err)> is_valid_fn_;
    // Called on updates after the value was stored; false rolls it back.
    std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn_;

 public:
    void
    Init() override;
    ConfigStatus
    Set(const std::string& value, bool update) override;
    std::string
    Get() override;
};
// Config entry whose textual value is one of the names in an enum table; the
// mapped int64 value is what gets stored in the bound variable.
class EnumConfig : public BaseConfig {
 public:
    // `enumd` is the name -> value table; `config` is the variable to bind;
    // `is_valid_fn` / `update_fn` are optional callbacks (may be empty).
    EnumConfig(const char* name,
               const char* alias,
               bool modifiable,
               configEnum* enumd,
               int64_t* config,
               int64_t default_value,
               std::function<bool(int64_t val, std::string& err)> is_valid_fn,
               std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);

 private:
    int64_t* config_;              // bound variable (not owned)
    configEnum* enum_value_;       // name -> value table (not owned)
    const int64_t default_value_;  // written to *config_ by Init()
    // Returns false (filling `err`) to reject a mapped value.
    std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
    // Called on updates after the value was stored; false rolls it back.
    std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;

 public:
    void
    Init() override;
    ConfigStatus
    Set(const std::string& value, bool update) override;
    std::string
    Get() override;
};
// Config entry bound to an int64 variable, restricted to an inclusive range.
class IntegerConfig : public BaseConfig {
 public:
    // `config` is the variable to bind; `is_valid_fn` / `update_fn` are
    // optional callbacks (may be empty).
    IntegerConfig(const char* name,
                  const char* alias,
                  bool modifiable,
                  int64_t lower_bound,
                  int64_t upper_bound,
                  int64_t* config,
                  int64_t default_value,
                  std::function<bool(int64_t val, std::string& err)> is_valid_fn,
                  std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);

 private:
    int64_t* config_;              // bound variable (not owned)
    int64_t lower_bound_;          // smallest accepted value (inclusive)
    int64_t upper_bound_;          // largest accepted value (inclusive)
    const int64_t default_value_;  // written to *config_ by Init()
    // Returns false (filling `err`) to reject a parsed value.
    std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
    // Called on updates after the value was stored; false rolls it back.
    std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;

 public:
    void
    Init() override;
    ConfigStatus
    Set(const std::string& value, bool update) override;
    std::string
    Get() override;
};
// Config entry bound to a double variable, restricted to an inclusive range.
class FloatingConfig : public BaseConfig {
 public:
    // `config` is the variable to bind; `is_valid_fn` / `update_fn` are
    // optional callbacks (may be empty).
    FloatingConfig(const char* name,
                   const char* alias,
                   bool modifiable,
                   double lower_bound,
                   double upper_bound,
                   double* config,
                   double default_value,
                   std::function<bool(double val, std::string& err)> is_valid_fn,
                   std::function<bool(double val, double prev, std::string& err)> update_fn);

 private:
    double* config_;              // bound variable (not owned)
    double lower_bound_;          // smallest accepted value (inclusive)
    double upper_bound_;          // largest accepted value (inclusive)
    const double default_value_;  // written to *config_ by Init()
    // Returns false (filling `err`) to reject a parsed value.
    std::function<bool(double val, std::string& err)> is_valid_fn_;
    // Called on updates after the value was stored; false rolls it back.
    std::function<bool(double val, double prev, std::string& err)> update_fn_;

 public:
    void
    Init() override;
    ConfigStatus
    Set(const std::string& value, bool update) override;
    std::string
    Get() override;
};
// Config entry for a byte size: the text may carry a b/k/m/g unit and the
// resulting byte count is stored in the bound int64 variable.
class SizeConfig : public BaseConfig {
 public:
    // Bounds and default are expressed in bytes. `config` is the variable to
    // bind; `is_valid_fn` / `update_fn` are optional callbacks (may be empty).
    SizeConfig(const char* name,
               const char* alias,
               bool modifiable,
               int64_t lower_bound,
               int64_t upper_bound,
               int64_t* config,
               int64_t default_value,
               std::function<bool(int64_t val, std::string& err)> is_valid_fn,
               std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);

 private:
    int64_t* config_;              // bound variable (not owned)
    int64_t lower_bound_;          // smallest accepted size in bytes (inclusive)
    int64_t upper_bound_;          // largest accepted size in bytes (inclusive)
    const int64_t default_value_;  // written to *config_ by Init()
    // Returns false (filling `err`) to reject a parsed size.
    std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
    // Called on updates after the value was stored; false rolls it back.
    std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;

 public:
    void
    Init() override;
    ConfigStatus
    Set(const std::string& value, bool update) override;
    std::string
    Get() override;
};
// Factory macros: each expands to std::make_shared<XxxConfig>(...) for the
// matching config type, always passing nullptr as the alias. `default` is
// wrapped in parentheses in the expansion; `is_valid` and `update` are the
// optional validation / update callbacks.
#define CreateBoolConfig(name, modifiable, config_addr, default, is_valid, update) \
    std::make_shared<BoolConfig>(name, nullptr, modifiable, config_addr, (default), is_valid, update)
#define CreateStringConfig(name, modifiable, config_addr, default, is_valid, update) \
    std::make_shared<StringConfig>(name, nullptr, modifiable, config_addr, (default), is_valid, update)
#define CreateEnumConfig(name, modifiable, enumd, config_addr, default, is_valid, update) \
    std::make_shared<EnumConfig>(name, nullptr, modifiable, enumd, config_addr, (default), is_valid, update)
#define CreateIntegerConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
    std::make_shared<IntegerConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
                                    is_valid, update)
#define CreateFloatingConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
    std::make_shared<FloatingConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
                                     is_valid, update)
#define CreateSizeConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
    std::make_shared<SizeConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
                                 is_valid, update)
} // namespace milvus

View File

@ -1,493 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <cstring>
#include <functional>
#include "config/ServerConfig.h"
#include "gtest/gtest.h"
namespace milvus {
// Readability aliases for the `modifiable` argument of the Create*Config macros.
// NOTE(review): identifiers starting with an underscore followed by an uppercase letter are reserved to the
// implementation in C++; consider renaming these together with every use in this file.
#define _MODIFIABLE (true)
#define _IMMUTABLE (false)
/**
 * Test helper that records the arguments handed to the validation and update callbacks.
 *
 * Fixtures derive from Utils<T> and bind validate_fn / update_fn as the callbacks of a config entry, then
 * inspect validate_value / new_value / prev_value to see whether (and with what) the callbacks were invoked.
 */
template <typename T>
class Utils {
 public:
    /**
     * Validation callback: stores `value` in validate_value and always accepts it.
     * @param value value being validated
     * @param err   unused; left untouched
     * @return always true
     */
    bool
    validate_fn(const T& value, std::string& err) {
        validate_value = value;
        return true;
    }

    /**
     * Update callback: stores `value` in new_value and `prev` in prev_value, and always succeeds.
     * @param value new value
     * @param prev  previous value
     * @param err   unused; left untouched
     * @return always true
     */
    bool
    update_fn(const T& value, const T& prev, std::string& err) {
        new_value = value;
        prev_value = prev;
        return true;
    }

 protected:
    // Value-initialised so that reading a recorder before its callback has fired is well defined.
    T validate_value{};
    T new_value{};
    T prev_value{};
};
/* ValidBoolConfigTest */
// Fixture for BoolConfig happy-path tests; the callback recorders come from Utils<bool>.
class ValidBoolConfigTest : public testing::Test, public Utils<bool> {};
// Walks a modifiable BoolConfig through Init() -> Set(update=false) -> Set(update=true) and checks the backing
// variable, Get(), and which of the recorded callbacks fired at each step.
TEST_F(ValidBoolConfigTest, init_load_update_get_test) {
// Bind the fixture's recorder methods as the validation / update callbacks.
auto validate = std::bind(&ValidBoolConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidBoolConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
bool bool_value = true;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, false, validate, update);
// Construction alone must not touch the backing variable.
ASSERT_EQ(bool_value, true);
ASSERT_EQ(bool_config->modifiable_, true);
// Init() is expected to write the default (false) into the backing variable.
bool_config->Init();
ASSERT_EQ(bool_value, false);
ASSERT_EQ(bool_config->Get(), "false");
{
// now `bool_value` is `false`, calling Set(update=false) to set it to `true`, but not notify update_fn()
// The recorders are seeded with values that differ from what a callback invocation would store.
validate_value = false;
new_value = false;
prev_value = true;
ConfigStatus status(SetReturn::SUCCESS, "");
status = bool_config->Set("true", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(bool_value, true);
EXPECT_EQ(bool_config->Get(), "true");
// expect change
EXPECT_EQ(validate_value, true);
// expect not change
EXPECT_EQ(new_value, false);
EXPECT_EQ(prev_value, true);
}
{
// now `bool_value` is `true`, calling Set(update=true) to set it to `false`, will notify update_fn()
validate_value = true;
new_value = true;
prev_value = false;
ConfigStatus status(SetReturn::SUCCESS, "");
status = bool_config->Set("false", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(bool_value, false);
EXPECT_EQ(bool_config->Get(), "false");
// expect change
EXPECT_EQ(validate_value, false);
EXPECT_EQ(new_value, false);
EXPECT_EQ(prev_value, true);
}
}
/* ValidStringConfigTest */
// Fixture for StringConfig happy-path tests; the callback recorders come from Utils<std::string>.
class ValidStringConfigTest : public testing::Test, public Utils<std::string> {};
// Walks a modifiable StringConfig through Init() -> Set(update=false) -> Set(update=true) and checks the backing
// variable, Get(), and which of the recorded callbacks fired at each step.
TEST_F(ValidStringConfigTest, init_load_update_get_test) {
// Bind the fixture's recorder methods as the validation / update callbacks.
auto validate = std::bind(&ValidStringConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidStringConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", validate, update);
// Construction alone must not touch the backing variable.
ASSERT_EQ(string_value, "");
ASSERT_EQ(string_config->modifiable_, true);
// Init() is expected to write the default ("Magic") into the backing variable.
string_config->Init();
ASSERT_EQ(string_value, "Magic");
ASSERT_EQ(string_config->Get(), "Magic");
{
// now `string_value` is `Magic`, calling Set(update=false) to set it to `cigaM`, but not notify update_fn()
validate_value = "";
new_value = "";
prev_value = "";
ConfigStatus status(SetReturn::SUCCESS, "");
status = string_config->Set("cigaM", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(string_value, "cigaM");
EXPECT_EQ(string_config->Get(), "cigaM");
// expect change
EXPECT_EQ(validate_value, "cigaM");
// expect not change
EXPECT_EQ(new_value, "");
EXPECT_EQ(prev_value, "");
}
{
// now `string_value` is `cigaM`, calling Set(update=true) to set it to `Check`, will notify update_fn()
validate_value = "";
new_value = "";
prev_value = "";
ConfigStatus status(SetReturn::SUCCESS, "");
status = string_config->Set("Check", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(string_value, "Check");
EXPECT_EQ(string_config->Get(), "Check");
// expect change
EXPECT_EQ(validate_value, "Check");
EXPECT_EQ(new_value, "Check");
EXPECT_EQ(prev_value, "cigaM");
}
}
/* ValidIntegerConfigTest */
// Fixture for IntegerConfig happy-path tests; the callback recorders come from Utils<int64_t>.
class ValidIntegerConfigTest : public testing::Test, public Utils<int64_t> {};
// Walks a modifiable IntegerConfig (bounds -100..100, default 42) through Init() -> Set(update=false) ->
// Set(update=true) and checks the backing variable, Get(), and which of the recorded callbacks fired.
TEST_F(ValidIntegerConfigTest, init_load_update_get_test) {
// Bind the fixture's recorder methods as the validation / update callbacks.
auto validate = std::bind(&ValidIntegerConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidIntegerConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
int64_t integer_value = 0;
auto integer_config = CreateIntegerConfig("i", _MODIFIABLE, -100, 100, &integer_value, 42, validate, update);
// Construction alone must not touch the backing variable.
ASSERT_EQ(integer_value, 0);
ASSERT_EQ(integer_config->modifiable_, true);
// Init() is expected to write the default (42) into the backing variable.
integer_config->Init();
ASSERT_EQ(integer_value, 42);
ASSERT_EQ(integer_config->Get(), "42");
{
// now `integer_value` is `42`, calling Set(update=false) to set it to `24`, but not notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = integer_config->Set("24", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(integer_value, 24);
EXPECT_EQ(integer_config->Get(), "24");
// expect change
EXPECT_EQ(validate_value, 24);
// expect not change
EXPECT_EQ(new_value, 0);
EXPECT_EQ(prev_value, 0);
}
{
// now `integer_value` is `24`, calling Set(update=true) to set it to `36`, will notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = integer_config->Set("36", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(integer_value, 36);
EXPECT_EQ(integer_config->Get(), "36");
// expect change
EXPECT_EQ(validate_value, 36);
EXPECT_EQ(new_value, 36);
EXPECT_EQ(prev_value, 24);
}
}
/* ValidFloatingConfigTest */
// Fixture for FloatingConfig happy-path tests; the callback recorders come from Utils<double>.
class ValidFloatingConfigTest : public testing::Test, public Utils<double> {};
// Walks a modifiable FloatingConfig (bounds -10.0..10.0, default 3.14) through Init() -> Set(update=false) ->
// Set(update=true) and checks the backing variable, Get(), and which of the recorded callbacks fired.
// Values are compared with ASSERT_FLOAT_EQ, i.e. at float precision, and Get() is parsed back with std::stof.
TEST_F(ValidFloatingConfigTest, init_load_update_get_test) {
// Bind the fixture's recorder methods as the validation / update callbacks.
auto validate =
std::bind(&ValidFloatingConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidFloatingConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
double floating_value = 0.0;
auto floating_config = CreateFloatingConfig("f", _MODIFIABLE, -10.0, 10.0, &floating_value, 3.14, validate, update);
// Construction alone must not touch the backing variable.
ASSERT_FLOAT_EQ(floating_value, 0.0);
ASSERT_EQ(floating_config->modifiable_, true);
// Init() is expected to write the default (3.14) into the backing variable.
floating_config->Init();
ASSERT_FLOAT_EQ(floating_value, 3.14);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 3.14);
{
// now `floating_value` is `3.14`, calling Set(update=false) to set it to `6.22`, but not notify update_fn()
validate_value = 0.0;
new_value = 0.0;
prev_value = 0.0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = floating_config->Set("6.22", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_FLOAT_EQ(floating_value, 6.22);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 6.22);
// expect change
ASSERT_FLOAT_EQ(validate_value, 6.22);
// expect not change
ASSERT_FLOAT_EQ(new_value, 0.0);
ASSERT_FLOAT_EQ(prev_value, 0.0);
}
{
// now `floating_value` is `6.22`, calling Set(update=true) to set it to `-3.14`, will notify update_fn()
validate_value = 0.0;
new_value = 0.0;
prev_value = 0.0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = floating_config->Set("-3.14", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_FLOAT_EQ(floating_value, -3.14);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), -3.14);
// expect change
ASSERT_FLOAT_EQ(validate_value, -3.14);
ASSERT_FLOAT_EQ(new_value, -3.14);
ASSERT_FLOAT_EQ(prev_value, 6.22);
}
}
/* ValidEnumConfigTest */
// Fixture for EnumConfig happy-path tests; the callback recorders come from Utils<int64_t>.
class ValidEnumConfigTest : public testing::Test, public Utils<int64_t> {};
// template <>
// int64_t Utils<int64_t>::validate_value = 0;
// template <>
// int64_t Utils<int64_t>::new_value = 0;
// template <>
// int64_t Utils<int64_t>::prev_value = 0;
// Walks a modifiable EnumConfig over {a=1, b=2, c=3} (default 1) through Init() -> Set(update=false) ->
// Set(update=true). Set() takes the enum name, the backing variable holds the mapped integer, and Get()
// returns the name again.
TEST_F(ValidEnumConfigTest, init_load_update_get_test) {
// Bind the fixture's recorder methods as the validation / update callbacks.
auto validate = std::bind(&ValidEnumConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidEnumConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value = 0;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, validate, update);
// Construction alone must not touch the backing variable.
ASSERT_EQ(enum_value, 0);
ASSERT_EQ(enum_config->modifiable_, true);
// Init() is expected to write the default (1, i.e. "a") into the backing variable.
enum_config->Init();
ASSERT_EQ(enum_value, 1);
ASSERT_EQ(enum_config->Get(), "a");
{
// now `enum_value` is `a`, calling Set(update=false) to set it to `b`, but not notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = enum_config->Set("b", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_EQ(enum_value, 2);
ASSERT_EQ(enum_config->Get(), "b");
// expect change
ASSERT_EQ(validate_value, 2);
// expect not change
ASSERT_EQ(new_value, 0);
ASSERT_EQ(prev_value, 0);
}
{
// now `enum_value` is `b`, calling Set(update=true) to set it to `c`, will notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = enum_config->Set("c", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_EQ(enum_value, 3);
ASSERT_EQ(enum_config->Get(), "c");
// expect change
ASSERT_EQ(validate_value, 3);
ASSERT_EQ(new_value, 3);
ASSERT_EQ(prev_value, 2);
}
}
/* ValidSizeConfigTest */
// Fixture for SizeConfig happy-path tests; the callback recorders come from Utils<int64_t>.
class ValidSizeConfigTest : public testing::Test, public Utils<int64_t> {};
// template <>
// int64_t Utils<int64_t>::validate_value = 0;
// template <>
// int64_t Utils<int64_t>::new_value = 0;
// template <>
// int64_t Utils<int64_t>::prev_value = 0;
// Walks a modifiable SizeConfig (bounds 0..1024*1024, default 1024) through Init() -> Set(update=false) ->
// Set(update=true). The second Set() uses the suffixed form "256kb" and expects it to be stored as 262144.
TEST_F(ValidSizeConfigTest, init_load_update_get_test) {
// Bind the fixture's recorder methods as the validation / update callbacks.
auto validate = std::bind(&ValidSizeConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidSizeConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
int64_t size_value = 0;
auto size_config = CreateSizeConfig("i", _MODIFIABLE, 0, 1024 * 1024, &size_value, 1024, validate, update);
// Construction alone must not touch the backing variable.
ASSERT_EQ(size_value, 0);
ASSERT_EQ(size_config->modifiable_, true);
// Init() is expected to write the default (1024) into the backing variable.
size_config->Init();
ASSERT_EQ(size_value, 1024);
ASSERT_EQ(size_config->Get(), "1024");
{
// now `size_value` is `1024`, calling Set(update=false) to set it to `4096`, but not notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = size_config->Set("4096", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(size_value, 4096);
EXPECT_EQ(size_config->Get(), "4096");
// expect change
EXPECT_EQ(validate_value, 4096);
// expect not change
EXPECT_EQ(new_value, 0);
EXPECT_EQ(prev_value, 0);
}
{
// now `size_value` is `4096`, calling Set(update=true) to set it to `256kb`, will notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = size_config->Set("256kb", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(size_value, 256 * 1024);
// Get() reports the plain byte count, not the suffixed input.
EXPECT_EQ(size_config->Get(), "262144");
// expect change
EXPECT_EQ(validate_value, 262144);
EXPECT_EQ(new_value, 262144);
EXPECT_EQ(prev_value, 4096);
}
}
/**
 * Fixture that binds the fields of a small `Server` struct to a list of config entries and offers
 * name-based Init / Load / Set / Get helpers over that list.
 */
class ValidTest : public testing::Test {
 protected:
    // Name -> value table used by the "socket_family" enum entry.
    configEnum family{
        {"ipv4", 1},
        {"ipv6", 2},
    };

    // Backing storage for the config entries created in SetUp().
    struct Server {
        bool running = true;
        std::string hostname;
        int64_t family = 0;
        int64_t port = 0;
        double uptime = 0;
    };

    Server server;

 protected:
    // Creates one config entry per Server field; entries are not initialised here (see Init()).
    void
    SetUp() override {
        config_list = {
            CreateBoolConfig("running", true, &server.running, true, nullptr, nullptr),
            CreateStringConfig("hostname", true, &server.hostname, "Magic", nullptr, nullptr),
            CreateEnumConfig("socket_family", false, &family, &server.family, 2, nullptr, nullptr),
            CreateIntegerConfig("port", true, 1024, 65535, &server.port, 19530, nullptr, nullptr),
            CreateFloatingConfig("uptime", true, 0, 9999.0, &server.uptime, 0, nullptr, nullptr),
        };
    }

    void
    TearDown() override {
    }

 protected:
    // Calls Init() on every registered entry.
    void
    Init() {
        for (const auto& entry : config_list) {
            entry->Init();
        }
    }

    // Applies a hard-coded "config file" (running=false) through Set() with update=false.
    void
    Load() {
        std::unordered_map<std::string, std::string> config_file{
            {"running", "false"},
        };
        for (const auto& item : config_file) {
            Set(item.first, item.second, false);
        }
    }

    // Sets the entry called `name` to `value`; throws a std::string message when no entry has that name.
    // The status returned by the entry's Set() is discarded.
    void
    Set(const std::string& name, const std::string& value, bool update = true) {
        for (const auto& entry : config_list) {
            if (std::strcmp(name.c_str(), entry->name_) != 0) {
                continue;
            }
            entry->Set(value, update);
            return;
        }
        throw "Config " + name + " not found.";
    }

    // Returns Get() of the entry called `name`; throws a std::string message when no entry has that name.
    std::string
    Get(const std::string& name) {
        for (const auto& entry : config_list) {
            if (std::strcmp(name.c_str(), entry->name_) != 0) {
                continue;
            }
            return entry->Get();
        }
        throw "Config " + name + " not found.";
    }

    std::vector<BaseConfigPtr> config_list;
};
} // namespace milvus

View File

@ -1,861 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "config/ServerConfig.h"
#include "gtest/gtest.h"
namespace milvus {
// Readability aliases for the `modifiable` argument of the Create*Config macros.
// NOTE(review): identifiers starting with an underscore followed by an uppercase letter are reserved to the
// implementation in C++; consider renaming these together with every use in this file.
#define _MODIFIABLE (true)
#define _IMMUTABLE (false)
/**
 * Collection of always-failing callbacks used to drive the error paths of the config types.
 * All members are static, so fixtures can pass them directly as validation / update callbacks.
 */
template <typename T>
class Utils {
 public:
    /// Validation callback that rejects every value and reports "Value is invalid." through `err`.
    static bool
    valid_check_failure(const T& value, std::string& err) {
        err.assign("Value is invalid.");
        return false;
    }

    /// Update callback that always fails and reports "Update is failure" through `err`.
    static bool
    update_failure(const T& value, const T& prev, std::string& err) {
        err.assign("Update is failure");
        return false;
    }

    /// Validation callback that never returns: it throws a string literal (a `const char*`).
    static bool
    valid_check_raise_string(const T& value, std::string& err) {
        throw "string exception";
    }

    /// Validation callback that never returns: it throws std::bad_alloc.
    static bool
    valid_check_raise_exception(const T& value, std::string& err) {
        throw std::bad_alloc();
    }
};
/* BoolConfigTest */
// Fixture for BoolConfig failure-path tests; the static failing callbacks come from Utils<bool>.
class BoolConfigTest : public testing::Test, public Utils<bool> {};
// Init() on an entry created with a null backing pointer must terminate the process with a message
// matching "nullptr".
TEST_F(BoolConfigTest, nullptr_init_test) {
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, nullptr, true, nullptr, nullptr);
ASSERT_DEATH(bool_config->Init(), "nullptr");
}
// Calling Init() twice on the same entry must terminate the process with a message matching "initialized".
// NOTE(review): the pattern "initialized" would also match a message containing "uninitialized".
TEST_F(BoolConfigTest, init_twice_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
ASSERT_DEATH(
{
bool_config->Init();
bool_config->Init();
},
"initialized");
}
// Set() and Get() before Init() must each terminate the process with a message matching "uninitialized".
TEST_F(BoolConfigTest, non_init_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
ASSERT_DEATH(bool_config->Set("false", true), "uninitialized");
ASSERT_DEATH(bool_config->Get(), "uninitialized");
}
// Set(update=true) on an immutable entry must return IMMUTABLE and leave the backing variable at its default.
TEST_F(BoolConfigTest, immutable_update_test) {
bool bool_value = false;
auto bool_config = CreateBoolConfig("b", _IMMUTABLE, &bool_value, true, nullptr, nullptr);
bool_config->Init();
ASSERT_EQ(bool_value, true);
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(bool_value, true);
}
// Feeds strings that are not accepted as booleans (surrounding whitespace, extra characters, arbitrary text,
// digits, empty string). Each Set() must return INVALID and Get() must keep reporting the default "true".
TEST_F(BoolConfigTest, set_invalid_value_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
// leading whitespace
status = bool_config->Set(" false", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
// trailing whitespace
status = bool_config->Set("false ", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
// extra leading character
status = bool_config->Set("afalse", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
// extra trailing character
status = bool_config->Set("falsee", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
// arbitrary text
status = bool_config->Set("abcdefg", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
// digits
status = bool_config->Set("123456", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
// empty string
status = bool_config->Set("", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
}
// A syntactically valid value that the validation callback rejects must yield INVALID and leave the
// stored value untouched.
TEST_F(BoolConfigTest, valid_check_fail_test) {
    bool bool_value;
    auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_failure, nullptr);
    bool_config->Init();
    ConfigStatus status(SUCCESS, "");
    // Use a parsable boolean: set_invalid_value_test shows "123456" is rejected as INVALID even without a
    // validator, so it could never prove that the failing validator is what produced the status.
    status = bool_config->Set("false", true);
    ASSERT_EQ(status.set_return, SetReturn::INVALID);
    ASSERT_EQ(bool_config->Get(), "true");
}
// When the update callback reports failure, Set(update=true) must return UPDATE_FAILURE and Get() must
// still report the previous value.
TEST_F(BoolConfigTest, update_fail_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, update_failure);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(bool_config->Get(), "true");
}
// A validator that throws a string literal must be reported as UNEXPECTED, with the stored value unchanged.
TEST_F(BoolConfigTest, string_exception_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_string, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(bool_config->Get(), "true");
}
// A validator that throws std::bad_alloc must be reported as EXCEPTION, with the stored value unchanged.
TEST_F(BoolConfigTest, standard_exception_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_exception, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(bool_config->Get(), "true");
}
/* StringConfigTest */
// Fixture for StringConfig failure-path tests; the static failing callbacks come from Utils<std::string>.
class StringConfigTest : public testing::Test, public Utils<std::string> {};
// Init() on an entry created with a null backing pointer must terminate the process with a message
// matching "nullptr".
TEST_F(StringConfigTest, nullptr_init_test) {
auto string_config = CreateStringConfig("s", true, nullptr, "Magic", nullptr, nullptr);
ASSERT_DEATH(string_config->Init(), "nullptr");
}
// Calling Init() twice on the same entry must terminate the process with a message matching "initialized".
// NOTE(review): the pattern "initialized" would also match a message containing "uninitialized".
TEST_F(StringConfigTest, init_twice_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr);
ASSERT_DEATH(
{
string_config->Init();
string_config->Init();
},
"initialized");
}
// Set() and Get() before Init() must each terminate the process with a message matching "uninitialized".
TEST_F(StringConfigTest, non_init_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr);
ASSERT_DEATH(string_config->Set("value", true), "uninitialized");
ASSERT_DEATH(string_config->Get(), "uninitialized");
}
// Set(update=true) on an immutable entry must return IMMUTABLE and leave the backing variable at its default.
TEST_F(StringConfigTest, immutable_update_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _IMMUTABLE, &string_value, "Magic", nullptr, nullptr);
string_config->Init();
ASSERT_EQ(string_value, "Magic");
ConfigStatus status(SUCCESS, "");
status = string_config->Set("cigaM", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(string_value, "Magic");
}
// When the validation callback rejects the value, Set() must return INVALID and Get() must still report
// the default.
TEST_F(StringConfigTest, valid_check_fail_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_failure, nullptr);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("123456", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(string_config->Get(), "Magic");
}
// When the update callback reports failure, Set(update=true) must return UPDATE_FAILURE and Get() must
// still report the previous value.
TEST_F(StringConfigTest, update_fail_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, update_failure);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("Mi", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(string_config->Get(), "Magic");
}
// A validator that throws a string literal must be reported as UNEXPECTED, with the stored value unchanged.
TEST_F(StringConfigTest, string_exception_test) {
std::string string_value;
auto string_config =
CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_string, nullptr);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("any", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(string_config->Get(), "Magic");
}
// A validator that throws std::bad_alloc must be reported as EXCEPTION, with the stored value unchanged.
TEST_F(StringConfigTest, standard_exception_test) {
std::string string_value;
auto string_config =
CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_exception, nullptr);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("any", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(string_config->Get(), "Magic");
}
/* IntegerConfigTest */
// Fixture for IntegerConfig failure-path tests; the static failing callbacks come from Utils<int64_t>.
class IntegerConfigTest : public testing::Test, public Utils<int64_t> {};
// Init() on an entry created with a null backing pointer must terminate the process with a message
// matching "nullptr".
TEST_F(IntegerConfigTest, nullptr_init_test) {
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, nullptr, 19530, nullptr, nullptr);
ASSERT_DEATH(integer_config->Init(), "nullptr");
}
// Calling Init() twice on the same entry must terminate the process with a message matching "initialized".
// NOTE(review): the pattern "initialized" would also match a message containing "uninitialized".
TEST_F(IntegerConfigTest, init_twice_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
ASSERT_DEATH(
{
integer_config->Init();
integer_config->Init();
},
"initialized");
}
// Set() and Get() before Init() must each terminate the process with a message matching "uninitialized".
TEST_F(IntegerConfigTest, non_init_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
ASSERT_DEATH(integer_config->Set("42", true), "uninitialized");
ASSERT_DEATH(integer_config->Get(), "uninitialized");
}
// Set(update=true) on an immutable entry must return IMMUTABLE and leave the backing variable at its default.
TEST_F(IntegerConfigTest, immutable_update_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", _IMMUTABLE, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
integer_config->Init();
ASSERT_EQ(integer_value, 19530);
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(integer_value, 19530);
}
// Placeholder: the body is empty, so this test always passes without checking anything.
// TODO(review): either add assertions or remove it; malformed input is exercised by invalid_format_test.
TEST_F(IntegerConfigTest, set_invalid_value_test) {
}
// An in-range value rejected by the validation callback must yield INVALID, with Get() still reporting
// the default.
TEST_F(IntegerConfigTest, valid_check_fail_test) {
int64_t integer_value;
auto integer_config =
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_failure, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "19530");
}
// When the update callback reports failure, Set(update=true) must return UPDATE_FAILURE and Get() must
// still report the previous value.
TEST_F(IntegerConfigTest, update_fail_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, update_failure);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(integer_config->Get(), "19530");
}
// A validator that throws a string literal must be reported as UNEXPECTED, with the stored value unchanged.
TEST_F(IntegerConfigTest, string_exception_test) {
int64_t integer_value;
auto integer_config =
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_string, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(integer_config->Get(), "19530");
}
// A validator that throws std::bad_alloc must be reported as EXCEPTION, with the stored value unchanged.
TEST_F(IntegerConfigTest, standard_exception_test) {
int64_t integer_value;
auto integer_config =
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_exception, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(integer_config->Get(), "19530");
}
// With bounds 1024..65535, the values just outside each bound (1023 and 65536) must be rejected with
// OUT_OF_RANGE and must not change the stored value.
TEST_F(IntegerConfigTest, out_of_range_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
integer_config->Init();
{
// one below the lower bound
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("1023", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(integer_config->Get(), "19530");
}
{
// one above the upper bound
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("65536", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(integer_config->Get(), "19530");
}
}
// The entry is created with lower bound (100) greater than upper bound (0). Init() still applies the
// default 50, and a later Set("30") is expected to be rejected with OUT_OF_RANGE.
TEST_F(IntegerConfigTest, invalid_bound_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 100, 0, &integer_value, 50, nullptr, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("30", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(integer_config->Get(), "50");
}
// Feeds malformed integer strings (misplaced '-', leading '+', embedded letters). Each Set() must return
// INVALID and Get() must keep reporting the default "50".
TEST_F(IntegerConfigTest, invalid_format_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 0, 100, &integer_value, 50, nullptr, nullptr);
integer_config->Init();
{
// '-' in the middle
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("3-0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
// trailing '-'
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("30-", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
// explicit '+' sign
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("+30", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
// leading letter
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("a30", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
// trailing letter
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("30a", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
// letter in the middle
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("3a0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
}
/* FloatingConfigTest */
// Fixture for FloatingConfig failure-path tests; the static failing callbacks come from Utils<double>.
class FloatingConfigTest : public testing::Test, public Utils<double> {};
// Init() on an entry created with a null backing pointer must terminate the process with a message
// matching "nullptr".
TEST_F(FloatingConfigTest, nullptr_init_test) {
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, nullptr, 4.5, nullptr, nullptr);
ASSERT_DEATH(floating_config->Init(), "nullptr");
}
// Calling Init() twice on the same entry must terminate the process with a message matching "initialized".
// NOTE(review): the pattern "initialized" would also match a message containing "uninitialized".
TEST_F(FloatingConfigTest, init_twice_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
ASSERT_DEATH(
{
floating_config->Init();
floating_config->Init();
},
"initialized");
}
// Set() and Get() before Init() must each terminate the process with a message matching "uninitialized".
TEST_F(FloatingConfigTest, non_init_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
ASSERT_DEATH(floating_config->Set("3.14", true), "uninitialized");
ASSERT_DEATH(floating_config->Get(), "uninitialized");
}
// Set(update=true) on an immutable entry must return IMMUTABLE and Get() must still report the default.
TEST_F(FloatingConfigTest, immutable_update_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", _IMMUTABLE, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
floating_config->Init();
ASSERT_FLOAT_EQ(floating_value, 4.5);
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
// Placeholder: the body is empty, so this test always passes without checking anything.
// TODO(review): either add assertions or remove it; note that DISABLED_invalid_format_test is not run by default.
TEST_F(FloatingConfigTest, set_invalid_value_test) {
}
// An in-range value rejected by the validation callback must yield INVALID, with Get() still reporting
// the default.
TEST_F(FloatingConfigTest, valid_check_fail_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_failure, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
// When the update callback reports failure, Set(update=true) must return UPDATE_FAILURE and Get() must
// still report the previous value.
TEST_F(FloatingConfigTest, update_fail_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, update_failure);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
// A validator that throws a string literal must be reported as UNEXPECTED, with the stored value unchanged.
TEST_F(FloatingConfigTest, string_exception_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_string, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
// A validator that throws std::bad_alloc must be reported as EXCEPTION, with the stored value unchanged.
TEST_F(FloatingConfigTest, standard_exception_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
// With bounds 1.0..9.9, values outside either bound (0.99 and 10.00) must be rejected with OUT_OF_RANGE and
// must not change the stored value.
// NOTE(review): the validator passed here always throws, so the expected OUT_OF_RANGE presumes the range
// check runs before the validator is called — confirm this is intentional and not a copy-paste leftover
// (the IntegerConfig counterpart passes nullptr).
TEST_F(FloatingConfigTest, out_of_range_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr);
floating_config->Init();
{
// below the lower bound
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("0.99", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
{
// above the upper bound
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("10.00", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
}
// The entry is created with lower bound (9.9) greater than upper bound (1.0). Init() still applies the
// default 4.5, and a later Set("6.0") is expected to be rejected with OUT_OF_RANGE.
// NOTE(review): the validator passed here always throws, so the expected OUT_OF_RANGE presumes the range
// check runs before the validator is called — confirm this is intentional.
TEST_F(FloatingConfigTest, invalid_bound_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 9.9, 1.0, &floating_value, 4.5, valid_check_raise_exception, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("6.0", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
// Malformed floating-point strings ("6.0.1", "6a0") are expected to yield INVALID and leave the default in place.
// The DISABLED_ prefix keeps GoogleTest from running this test by default.
// NOTE(review): presumably disabled because the implementation does not reject these inputs yet — confirm.
TEST_F(FloatingConfigTest, DISABLED_invalid_format_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 100.0, &floating_value, 4.5, nullptr, nullptr);
floating_config->Init();
{
// two decimal points
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("6.0.1", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
{
// letter in the middle
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("6a0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
}
/* EnumConfigTest */
// Fixture for EnumConfig failure-path tests; the static failing callbacks come from Utils<int64_t>.
class EnumConfigTest : public testing::Test, public Utils<int64_t> {};
// Init() must terminate the process with a message matching "nullptr" when the backing pointer, the enum
// table pointer, or both are null.
TEST_F(EnumConfigTest, nullptr_init_test) {
configEnum testEnum{
{"e", 1},
};
int64_t testEnumValue;
// null backing variable
auto enum_config_1 = CreateEnumConfig("e", _MODIFIABLE, &testEnum, nullptr, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config_1->Init(), "nullptr");
// null enum table
auto enum_config_2 = CreateEnumConfig("e", _MODIFIABLE, nullptr, &testEnumValue, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config_2->Init(), "nullptr");
// both null
auto enum_config_3 = CreateEnumConfig("e", _MODIFIABLE, nullptr, nullptr, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config_3->Init(), "nullptr");
}
// Calling Init() a second time must abort with a death message matching "initialized".
TEST_F(EnumConfigTest, init_twice_test) {
    configEnum choices{{"e", 1}};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 2, nullptr, nullptr);

    ASSERT_DEATH(
        {
            cfg->Init();
            cfg->Init();
        },
        "initialized");
}
// Using Set() or Get() before Init() must abort with a death message matching "uninitialized".
TEST_F(EnumConfigTest, non_init_test) {
    configEnum choices{{"e", 1}};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 2, nullptr, nullptr);

    ASSERT_DEATH(cfg->Set("e", true), "uninitialized");
    ASSERT_DEATH(cfg->Get(), "uninitialized");
}
// A config created with _IMMUTABLE takes its initial value on Init() and refuses later updates.
TEST_F(EnumConfigTest, immutable_update_test) {
    configEnum choices{{"a", 1}, {"b", 2}, {"c", 3}};
    int64_t target = 0;
    auto cfg = CreateEnumConfig("e", _IMMUTABLE, &choices, &target, 1, nullptr, nullptr);

    cfg->Init();
    ASSERT_EQ(target, 1);  // Init() wrote the initial value into the bound variable

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("b", true);

    ASSERT_EQ(outcome.set_return, SetReturn::IMMUTABLE);
    ASSERT_EQ(target, 1);  // bound variable is untouched by the rejected update
}
// Setting a name that is not a key of the enum table must report ENUM_VALUE_NOTFOUND
// and keep the current value ("a").
TEST_F(EnumConfigTest, set_invalid_value_check) {
    configEnum choices{{"a", 1}};
    int64_t target = 0;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 1, nullptr, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("b", true);

    ASSERT_EQ(outcome.set_return, SetReturn::ENUM_VALUE_NOTFOUND);
    ASSERT_EQ(cfg->Get(), "a");
}
// Init() on a config backed by an empty enum table must abort with a death message matching "empty".
TEST_F(EnumConfigTest, empty_enum_test) {
    configEnum no_choices{};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &no_choices, &target, 2, nullptr, nullptr);

    ASSERT_DEATH(cfg->Init(), "empty");
}
// With valid_check_failure installed as the second-to-last argument, Set() must report INVALID
// and keep the current value ("a").
TEST_F(EnumConfigTest, valid_check_fail_test) {
    configEnum choices{{"a", 1}, {"b", 2}, {"c", 3}};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 1, valid_check_failure, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("b", true);

    ASSERT_EQ(outcome.set_return, SetReturn::INVALID);
    ASSERT_EQ(cfg->Get(), "a");
}
// With update_failure installed as the last argument, Set() must report UPDATE_FAILURE
// and keep the current value ("a").
TEST_F(EnumConfigTest, update_fail_test) {
    configEnum choices{{"a", 1}, {"b", 2}, {"c", 3}};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 1, nullptr, update_failure);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("b", true);

    ASSERT_EQ(outcome.set_return, SetReturn::UPDATE_FAILURE);
    ASSERT_EQ(cfg->Get(), "a");
}
// With valid_check_raise_string installed as the second-to-last argument, Set() must report
// UNEXPECTED and keep the current value ("a").
TEST_F(EnumConfigTest, string_exception_test) {
    configEnum choices{{"a", 1}, {"b", 2}, {"c", 3}};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 1, valid_check_raise_string, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("b", true);

    ASSERT_EQ(outcome.set_return, SetReturn::UNEXPECTED);
    ASSERT_EQ(cfg->Get(), "a");
}
// With valid_check_raise_exception installed as the second-to-last argument, Set() must report
// EXCEPTION and keep the current value ("a").
TEST_F(EnumConfigTest, standard_exception_test) {
    configEnum choices{{"a", 1}, {"b", 2}, {"c", 3}};
    int64_t target;
    auto cfg = CreateEnumConfig("e", _MODIFIABLE, &choices, &target, 1, valid_check_raise_exception, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("b", true);

    ASSERT_EQ(outcome.set_return, SetReturn::EXCEPTION);
    ASSERT_EQ(cfg->Get(), "a");
}
/* SizeConfigTest */
// Fixture shared by the size-config test cases below. In addition to testing::Test it derives
// from Utils<int64_t>, which is defined outside this excerpt.
// NOTE(review): presumably Utils supplies the valid_check_* / update_failure callbacks used by the tests — confirm.
class SizeConfigTest : public testing::Test, public Utils<int64_t> {};
// Init() must abort (death message matching "nullptr") when no target variable is bound.
TEST_F(SizeConfigTest, nullptr_init_test) {
    auto unbound = CreateSizeConfig("i", true, 1024, 4096, nullptr, 2048, nullptr, nullptr);

    ASSERT_DEATH(unbound->Init(), "nullptr");
}
// Calling Init() a second time must abort with a death message matching "initialized".
TEST_F(SizeConfigTest, init_twice_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, nullptr);

    ASSERT_DEATH(
        {
            cfg->Init();
            cfg->Init();
        },
        "initialized");
}
// Using Set() or Get() before Init() must abort with a death message matching "uninitialized".
TEST_F(SizeConfigTest, non_init_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, nullptr);

    ASSERT_DEATH(cfg->Set("3000", true), "uninitialized");
    ASSERT_DEATH(cfg->Get(), "uninitialized");
}
// A config created with _IMMUTABLE takes its initial value on Init() and refuses later updates.
TEST_F(SizeConfigTest, immutable_update_test) {
    int64_t target = 0;
    auto cfg = CreateSizeConfig("i", _IMMUTABLE, 1024, 4096, &target, 2048, nullptr, nullptr);

    cfg->Init();
    ASSERT_EQ(target, 2048);  // Init() wrote the initial value into the bound variable

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("3000", true);

    ASSERT_EQ(outcome.set_return, SetReturn::IMMUTABLE);
    ASSERT_EQ(target, 2048);  // bound variable is untouched by the rejected update
}
// TODO: placeholder — the body is empty, so this case asserts nothing and always passes.
TEST_F(SizeConfigTest, set_invalid_value_test) {
}
// With valid_check_failure installed as the second-to-last argument, Set() must report INVALID
// and keep the current value ("2048").
TEST_F(SizeConfigTest, valid_check_fail_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, valid_check_failure, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("3000", true);

    ASSERT_EQ(outcome.set_return, SetReturn::INVALID);
    ASSERT_EQ(cfg->Get(), "2048");
}
// With update_failure installed as the last argument, Set() must report UPDATE_FAILURE
// and keep the current value ("2048").
TEST_F(SizeConfigTest, update_fail_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, update_failure);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("3000", true);

    ASSERT_EQ(outcome.set_return, SetReturn::UPDATE_FAILURE);
    ASSERT_EQ(cfg->Get(), "2048");
}
// With valid_check_raise_string installed as the second-to-last argument, Set() must report
// UNEXPECTED and keep the current value ("2048").
TEST_F(SizeConfigTest, string_exception_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, valid_check_raise_string, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("3000", true);

    ASSERT_EQ(outcome.set_return, SetReturn::UNEXPECTED);
    ASSERT_EQ(cfg->Get(), "2048");
}
// With valid_check_raise_exception installed as the second-to-last argument, Set() must report
// EXCEPTION and keep the current value ("2048").
TEST_F(SizeConfigTest, standard_exception_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, valid_check_raise_exception, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("3000", true);

    ASSERT_EQ(outcome.set_return, SetReturn::EXCEPTION);
    ASSERT_EQ(cfg->Get(), "2048");
}
// Values just outside the configured bounds (1024 and 4096) must be rejected with OUT_OF_RANGE
// and must leave the stored value at "2048".
TEST_F(SizeConfigTest, out_of_range_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, nullptr);
    cfg->Init();

    // One below the smaller bound, one above the larger bound.
    for (const char* candidate : {"1023", "4097"}) {
        ConfigStatus outcome(SUCCESS, "");
        outcome = cfg->Set(candidate, true);
        ASSERT_EQ(outcome.set_return, SetReturn::OUT_OF_RANGE);
        ASSERT_EQ(cfg->Get(), "2048");
    }
}
// A negative size ("-3KB") must be reported as INVALID and must leave the stored value at "2048".
TEST_F(SizeConfigTest, negative_integer_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("-3KB", true);

    ASSERT_EQ(outcome.set_return, SetReturn::INVALID);
    ASSERT_EQ(cfg->Get(), "2048");
}
// The two bound arguments are passed in swapped order (100 before 0); a Set of "30" is then
// expected to be reported as OUT_OF_RANGE while the stored value stays at "50".
TEST_F(SizeConfigTest, invalid_bound_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 100, 0, &target, 50, nullptr, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("30", true);

    ASSERT_EQ(outcome.set_return, SetReturn::OUT_OF_RANGE);
    ASSERT_EQ(cfg->Get(), "50");
}
// The text "1 TB" must be reported as INVALID and must leave the stored value at "2048".
TEST_F(SizeConfigTest, invalid_unit_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, nullptr);
    cfg->Init();

    ConfigStatus outcome(SUCCESS, "");
    outcome = cfg->Set("1 TB", true);

    ASSERT_EQ(outcome.set_return, SetReturn::INVALID);
    ASSERT_EQ(cfg->Get(), "2048");
}
// Malformed size strings must each be reported as INVALID and must leave the stored value at "2048".
TEST_F(SizeConfigTest, invalid_format_test) {
    int64_t target;
    auto cfg = CreateSizeConfig("i", true, 1024, 4096, &target, 2048, nullptr, nullptr);
    cfg->Init();

    // Leading letter, embedded operator character, unknown suffix — in the original order.
    for (const char* malformed : {"a10GB", "200*0", "10AB"}) {
        ConfigStatus outcome(SUCCESS, "");
        outcome = cfg->Set(malformed, true);
        ASSERT_EQ(outcome.set_return, SetReturn::INVALID);
        ASSERT_EQ(cfg->Get(), "2048");
    }
}
} // namespace milvus

View File

@ -1,110 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "config/ConfigType.h"
namespace milvus {
extern std::mutex&
GetConfigMutex();
/**
 * Holder for one configuration value of type T.
 *
 * operator() takes the mutex returned by GetConfigMutex() while it fetches the value.
 * NOTE(review): the accessor hands back a reference, so the caller reads the value after the
 * lock has already been released; `value` is also public and reachable without the lock.
 */
template <typename T>
class ConfigValue {
 public:
    /// Stored value; publicly accessible.
    T value;

    /// Takes ownership of `init_value` by moving it into the holder.
    explicit ConfigValue(T init_value) : value(std::move(init_value)) {
    }

    /// Returns a const reference to the stored value; the config mutex is held only for the call.
    const T&
    operator()() {
        std::lock_guard<std::mutex> guard(GetConfigMutex());
        return value;
    }
};
// Cluster role identifiers; numbering starts at 1 (RW = 1, RO = 2).
enum ClusterRole {
RW = 1,
RO,
};
// SIMD selection identifiers; numbering starts at 1 (AUTO = 1 ... AVX512 = 4).
enum SimdType {
AUTO = 1,
SSE,
AVX2,
AVX512,
};
// Lookup table from the lower-case textual name of a SIMD option to its SimdType value.
// Being a namespace-scope const object defined in a header, each translation unit that
// includes this header gets its own copy.
const configEnum SimdMap{
{"auto", SimdType::AUTO},
{"sse", SimdType::SSE},
{"avx2", SimdType::AVX2},
{"avx512", SimdType::AVX512},
};
// Aggregate of all server configuration values, grouped into nested structs by area.
// Every field is a ConfigValue<...> holder constructed with the initial value shown in its
// brace initializer; string fields without a meaningful initial value start as "unknown".
struct ServerConfig {
// Shorthand aliases for the ConfigValue instantiations used below.
using String = ConfigValue<std::string>;
using Bool = ConfigValue<bool>;
using Integer = ConfigValue<int64_t>;
using Floating = ConfigValue<double>;
String timezone{"unknown"};
// Network endpoint settings.
struct Network {
String address{"unknown"};
Integer port{0};
} network;
// Pulsar endpoint settings; initial values are localhost:6650.
struct Pulsar {
String address{"localhost"};
Integer port{6650};
} pulsar;
// Engine tuning values.
struct Engine {
Integer build_index_threshold{4096};
Integer search_combine_nq{0};
Integer use_blas_threshold{0};
Integer omp_thread_num{0};
// Stored as a plain integer. NOTE(review): presumably holds a SimdType value — confirm against the code that fills it.
Integer simd_type{0};
} engine;
// Tracing settings.
struct Tracing {
String json_config_path{"unknown"};
} tracing;
// Logging settings.
struct Logs {
String level{"unknown"};
struct Trace {
Bool enable{false};
} trace;
String path{"unknown"};
Integer max_log_file_size{0};
Integer log_rotate_num{0};
} logs;
};
extern ServerConfig config;
extern std::mutex _config_mutex;
std::vector<std::string>
ParsePreloadCollection(const std::string&);
std::vector<int64_t>
ParseGPUDevices(const std::string&);
} // namespace milvus

View File

@ -51,6 +51,7 @@ include(DefineOptionsCore)
include(BuildUtilsCore)
using_ccache_if_defined( KNOWHERE_USE_CCACHE )
set_directory_properties(PROPERTIES RULE_LAUNCH_COMPILE "")
if (MILVUS_GPU_VERSION)
message(STATUS "Building Knowhere GPU version")

View File

@ -16,12 +16,19 @@
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h"
#include "NGT/lib/NGT/defines.h"
#include "faiss/FaissHook.h"
#include "faiss/common.h"
#include "faiss/utils/utils.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/IndexHNSW.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
#include "scheduler/Utils.h"
#include "utils/ConfigUtils.h"
#include "utils/Error.h"
#include "utils/Log.h"
#include "value/config/ServerConfig.h"
#include <fiu/fiu-local.h>
#include <map>
@ -63,20 +70,6 @@ KnowhereResource::Initialize() {
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
}
// engine config
int64_t omp_thread = config.engine.omp_thread_num();
if (omp_thread > 0) {
omp_set_num_threads(omp_thread);
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
} else {
int64_t sys_thread_cnt = 8;
if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) {
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
omp_set_num_threads(omp_thread);
}
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
@ -95,40 +88,49 @@ KnowhereResource::Initialize() {
#ifdef MILVUS_GPU_VERSION
bool enable_gpu = config.gpu.enable();
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
if (!enable_gpu) {
return Status::OK();
}
if (enable_gpu) {
struct GpuResourceSetting {
int64_t pinned_memory = 256 * M_BYTE;
int64_t temp_memory = 256 * M_BYTE;
int64_t resource_num = 2;
};
using GpuResourcesArray = std::map<int64_t, GpuResourceSetting>;
GpuResourcesArray gpu_resources;
struct GpuResourceSetting {
int64_t pinned_memory = 256 * M_BYTE;
int64_t temp_memory = 256 * M_BYTE;
int64_t resource_num = 2;
};
using GpuResourcesArray = std::map<int64_t, GpuResourceSetting>;
GpuResourcesArray gpu_resources;
// get build index gpu resource
std::vector<int64_t> build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices());
// get build index gpu resource
std::vector<int64_t> build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices());
for (auto gpu_id : build_index_gpus) {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
}
for (auto gpu_id : build_index_gpus) {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
}
// get search gpu resource
std::vector<int64_t> search_gpus = ParseGPUDevices(config.gpu.search_devices());
// get search gpu resource
std::vector<int64_t> search_gpus = ParseGPUDevices(config.gpu.search_devices());
for (auto& gpu_id : search_gpus) {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
}
for (auto& gpu_id : search_gpus) {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
}
// init gpu resources
for (auto& gpu_resource : gpu_resources) {
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(gpu_resource.first, gpu_resource.second.pinned_memory,
gpu_resource.second.temp_memory,
gpu_resource.second.resource_num);
// init gpu resources
for (auto& gpu_resource : gpu_resources) {
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(
gpu_resource.first, gpu_resource.second.pinned_memory, gpu_resource.second.temp_memory,
gpu_resource.second.resource_num);
}
}
#endif
faiss::LOG_ERROR_ = &knowhere::log_error_;
faiss::LOG_WARNING_ = &knowhere::log_warning_;
// faiss::LOG_DEBUG_ = &knowhere::log_debug_;
NGT_LOG_ERROR_ = &knowhere::log_error_;
NGT_LOG_WARNING_ = &knowhere::log_warning_;
// NGT_LOG_DEBUG_ = &knowhere::log_debug_;
auto stat_level = config.engine.statistics_level();
milvus::knowhere::STATISTICS_LEVEL = stat_level;
faiss::STATISTICS_LEVEL = stat_level;
return Status::OK();
}

View File

@ -39,11 +39,21 @@ endif ()
set(external_srcs
knowhere/common/Exception.cpp
knowhere/common/Log.cpp
knowhere/common/Timer.cpp
knowhere/common/Utils.cpp
)
set (LOG_SRC
knowhere/common/Log.cpp
${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc
)
add_library(index_log STATIC ${LOG_SRC})
set_target_properties(index_log PROPERTIES RULE_LAUNCH_COMPILE "")
set_target_properties(index_log PROPERTIES RULE_LAUNCH_LINK "")
include_directories(${MILVUS_THIRDPARTY_SRC})
set(vector_index_srcs
knowhere/index/IndexType.cpp
knowhere/index/vector_index/adapter/VectorAdapter.cpp
knowhere/index/vector_index/helpers/FaissIO.cpp
knowhere/index/vector_index/helpers/IndexParameter.cpp
@ -61,8 +71,6 @@ set(vector_index_srcs
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/IndexIVFPQ.cpp
knowhere/index/vector_index/IndexIVFSQ.cpp
knowhere/index/IndexType.cpp
knowhere/index/vector_index/VecIndexFactory.cpp
knowhere/index/vector_index/IndexAnnoy.cpp
knowhere/index/vector_index/IndexRHNSW.cpp
knowhere/index/vector_index/IndexHNSW.cpp
@ -72,6 +80,8 @@ set(vector_index_srcs
knowhere/index/vector_index/IndexNGT.cpp
knowhere/index/vector_index/IndexNGTPANNG.cpp
knowhere/index/vector_index/IndexNGTONNG.cpp
knowhere/index/vector_index/Statistics.cpp
knowhere/index/vector_index/VecIndexFactory.cpp
)
set(vector_offset_index_srcs
@ -96,6 +106,7 @@ set(depend_libs
pthread
fiu
ngt
index_log
)
if (MILVUS_SUPPORT_SPTAG)
@ -143,7 +154,6 @@ endif ()
target_link_libraries(
knowhere
milvus_utils
${depend_libs}
)

View File

@ -79,6 +79,11 @@ class BinarySet {
binary_map_.clear();
}
// Reports whether an entry is stored under `key` in binary_map_.
bool
Contains(const std::string& key) const {
    return binary_map_.count(key) != 0;
}
public:
std::map<std::string, BinaryPtr> binary_map_;
};

View File

@ -16,7 +16,7 @@
namespace milvus {
namespace knowhere {
using Config = milvus::Json;
using Config = milvus::json;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,104 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/common/Utils.h"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
namespace milvus {
namespace knowhere {
const char* INDEX_FILE_SLICE_SIZE_IN_MEGABYTE = "SLICE_SIZE";
const char* INDEX_FILE_SLICE_META = "SLICE_META";
static const char* META = "meta";
static const char* NAME = "name";
static const char* SLICE_NUM = "slice_num";
static const char* TOTAL_LEN = "total_len";
/**
 * Split the payload of `data_src` into consecutive chunks of at most `slice_len` bytes.
 *
 * Each chunk is copied into its own buffer and appended to `binarySet` under the key
 * "<prefix>_<index>" (index starting at 0). On return `ret` carries the prefix, the number of
 * slices produced and the total payload length.
 *
 * Does nothing (and leaves `ret` untouched) when `data_src` is null. A non-positive `slice_len`
 * cannot make progress, so the payload is then stored as one single slice instead of looping
 * forever.
 */
void
Slice(const std::string& prefix,
      const BinaryPtr& data_src,
      const int64_t& slice_len,
      BinarySet& binarySet,
      milvus::json& ret) {
    if (!data_src) {
        return;
    }

    const int64_t total_len = data_src->size;
    const int64_t step = slice_len > 0 ? slice_len : total_len;

    int slice_num = 0;
    for (int64_t begin = 0; begin < total_len; ++slice_num) {
        // Compare against the remaining length so that begin + step can never overflow.
        const int64_t end = (total_len - begin > step) ? begin + step : total_len;
        const auto size = static_cast<size_t>(end - begin);

        // Allocate with new[] so the allocation matches the delete[] issued by
        // std::default_delete<uint8_t[]>; pairing malloc with delete[] is undefined behaviour.
        std::shared_ptr<uint8_t[]> slice_i_sp(new uint8_t[size], std::default_delete<uint8_t[]>());
        memcpy(slice_i_sp.get(), data_src->data.get() + begin, size);

        binarySet.Append(prefix + "_" + std::to_string(slice_num), slice_i_sp, end - begin);
        begin = end;
    }

    ret[NAME] = prefix;
    ret[SLICE_NUM] = slice_num;
    ret[TOTAL_LEN] = data_src->size;
}
/**
 * Reverse of Disassemble(): rebuild every sliced entry of `binarySet` from its slices.
 *
 * The slice metadata entry (INDEX_FILE_SLICE_META) is removed from the set and parsed as JSON.
 * For each described item the slices "<name>_0" ... "<name>_<slice_num - 1>" are removed from the
 * set, concatenated in order, and the result is appended under the item's name. When the set has
 * no metadata entry the function returns without touching anything.
 *
 * @throws std::runtime_error if a slice named by the metadata is missing from the set, or if the
 *         slices together would exceed the total length recorded in the metadata.
 */
void
Assemble(BinarySet& binarySet) {
    auto slice_meta = binarySet.Erase(INDEX_FILE_SLICE_META);
    if (slice_meta == nullptr) {
        return;
    }

    milvus::json meta_data =
        milvus::json::parse(std::string(reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));

    for (auto& item : meta_data[META]) {
        std::string prefix = item[NAME];
        int slice_num = item[SLICE_NUM];
        auto total_len = static_cast<size_t>(item[TOTAL_LEN]);

        // Allocate with new[] so the allocation matches the delete[] issued by
        // std::default_delete<uint8_t[]>; owning it up front also releases it if we throw below.
        std::shared_ptr<uint8_t[]> integral_data(new uint8_t[total_len], std::default_delete<uint8_t[]>());

        size_t pos = 0;
        for (auto i = 0; i < slice_num; ++i) {
            const std::string slice_key = prefix + "_" + std::to_string(i);
            auto slice_i_sp = binarySet.Erase(slice_key);
            if (slice_i_sp == nullptr) {
                throw std::runtime_error("Assemble: missing index slice " + slice_key);
            }
            // Never write past the buffer, even when the metadata and the slices disagree.
            if (slice_i_sp->size < 0 || static_cast<size_t>(slice_i_sp->size) > total_len - pos) {
                throw std::runtime_error("Assemble: index slice " + slice_key + " exceeds recorded total length");
            }
            const auto slice_size = static_cast<size_t>(slice_i_sp->size);
            memcpy(integral_data.get() + pos, slice_i_sp->data.get(), slice_size);
            pos += slice_size;
        }

        binarySet.Append(prefix, integral_data, total_len);
    }
}
/**
 * Replace every entry of `binarySet` larger than `slice_size_in_byte` with a series of slices.
 *
 * Oversized entries are removed from the set and re-inserted through Slice(); the per-entry
 * metadata is collected and, when at least one entry was sliced, stored as a NUL-terminated
 * JSON string under INDEX_FILE_SLICE_META.
 *
 * A non-positive `slice_size_in_byte` describes no usable slice length, so the set is left
 * unchanged in that case.
 */
void
Disassemble(const int64_t& slice_size_in_byte, BinarySet& binarySet) {
    if (slice_size_in_byte <= 0) {
        return;
    }

    // Collect the keys first: slicing erases and appends entries, which must not happen while
    // iterating over binary_map_.
    std::vector<std::string> slice_key_list;
    for (auto& kv : binarySet.binary_map_) {
        if (kv.second->size > slice_size_in_byte) {
            slice_key_list.push_back(kv.first);
        }
    }
    if (slice_key_list.empty()) {
        return;
    }

    milvus::json meta_info;
    for (auto& key : slice_key_list) {
        milvus::json slice_i;
        Slice(key, binarySet.Erase(key), slice_size_in_byte, binarySet, slice_i);
        meta_info[META].emplace_back(slice_i);
    }

    // Store the metadata with a trailing NUL byte, which is counted in the appended size.
    auto meta_str = meta_info.dump();
    std::shared_ptr<uint8_t[]> meta_data(new uint8_t[meta_str.length() + 1], std::default_delete<uint8_t[]>());
    memcpy(meta_data.get(), meta_str.data(), meta_str.length());
    meta_data.get()[meta_str.length()] = 0;
    binarySet.Append(INDEX_FILE_SLICE_META, meta_data, meta_str.length() + 1);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,33 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <string>
#include "BinarySet.h"
#include "Config.h"
#include "Exception.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
// Key string constants defined in Utils.cpp ("SLICE_SIZE" and "SLICE_META").
// INDEX_FILE_SLICE_META is the BinarySet key under which Disassemble() stores the slice metadata.
extern const char* INDEX_FILE_SLICE_SIZE_IN_MEGABYTE;
extern const char* INDEX_FILE_SLICE_META;
// Rebuilds sliced entries of `binarySet` from the metadata stored under INDEX_FILE_SLICE_META;
// returns without changes when the set has no such entry.
void
Assemble(BinarySet& binarySet);
// Splits every entry of `binarySet` larger than `slice_size_in_byte` into slices and records
// the slice metadata under INDEX_FILE_SLICE_META.
void
Disassemble(const int64_t& slice_size_in_byte, BinarySet& binarySet);
} // namespace knowhere
} // namespace milvus

View File

@ -38,6 +38,8 @@ enum class OldIndexType {
RHNSW_FLAT,
RHNSW_PQ,
RHNSW_SQ,
NGTPANNG,
NGTONNG,
FAISS_BIN_IDMAP = 100,
FAISS_BIN_IVFLAT_CPU = 101,
};

View File

@ -20,13 +20,11 @@
namespace milvus {
namespace knowhere {
enum class OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 };
enum OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 };
static std::map<std::string, OperatorType> s_map_operator_type = {
{"LT", OperatorType::LT},
{"LE", OperatorType::LE},
{"GT", OperatorType::GT},
{"GE", OperatorType::GE},
{"LT", OperatorType::LT}, {"LTE", OperatorType::LE}, {"GT", OperatorType::GT}, {"GTE", OperatorType::GE},
{"lt", OperatorType::LT}, {"lte", OperatorType::LE}, {"gt", OperatorType::GT}, {"gte", OperatorType::GE},
};
template <typename T>

View File

@ -12,7 +12,7 @@
#include <algorithm>
#include <memory>
#include <utility>
#include "knowhere/common/Log.h"
#include "knowhere/knowhere/common/Log.h"
#include "knowhere/index/structured_index/StructuredIndexFlat.h"
namespace milvus {
@ -31,6 +31,16 @@ template <typename T>
StructuredIndexFlat<T>::~StructuredIndexFlat() {
}
// Serialization is not implemented for the flat structured index yet.
// Returns an empty BinarySet: the previous empty body flowed off the end of a non-void
// function, which is undefined behaviour as soon as the function is called.
template <typename T>
BinarySet
StructuredIndexFlat<T>::Serialize(const milvus::knowhere::Config& config) {
    (void)config;  // unused until serialization is implemented
    return BinarySet();
}
// Loading from a BinarySet is not implemented: the body is empty, so `index_binary` is ignored
// and the index is left unchanged.
template <typename T>
void
StructuredIndexFlat<T>::Load(const milvus::knowhere::BinarySet& index_binary) {
}
template <typename T>
void
StructuredIndexFlat<T>::Build(const size_t n, const T* values) {
@ -45,14 +55,11 @@ StructuredIndexFlat<T>::Build(const size_t n, const T* values) {
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::In(const size_t n, const T* values) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
for (size_t i = 0; i < n; ++i) {
for (const auto& index : data_) {
if (index->a_ == *(values + i)) {
bitset->set(index->idx_);
if (index.a_ == *(values + i)) {
bitset->set(index.idx_);
}
}
}
@ -62,14 +69,11 @@ StructuredIndexFlat<T>::In(const size_t n, const T* values) {
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size(), 0xff);
for (size_t i = 0; i < n; ++i) {
for (const auto& index : data_) {
if (index->a_ == *(values + i)) {
bitset->clear(index->idx_);
if (index.a_ == *(values + i)) {
bitset->clear(index.idx_);
}
}
}
@ -79,31 +83,28 @@ StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::Range(const T value, const OperatorType op) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
auto lb = data_.begin();
auto ub = data_.end();
for (; lb <= ub; lb++) {
switch (op) {
case OperatorType::LT:
if (lb < IndexStructure<T>(value)) {
if (*lb < IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
case OperatorType::LE:
if (lb <= IndexStructure<T>(value)) {
if (*lb <= IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
case OperatorType::GT:
if (lb > IndexStructure<T>(value)) {
if (*lb > IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
case OperatorType::GE:
if (lb >= IndexStructure<T>(value)) {
if (*lb >= IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
@ -117,9 +118,6 @@ StructuredIndexFlat<T>::Range(const T value, const OperatorType op) {
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
if (lower_bound_value > upper_bound_value) {
std::swap(lower_bound_value, upper_bound_value);
@ -129,19 +127,19 @@ StructuredIndexFlat<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bo
auto ub = data_.end();
for (; lb <= ub; ++lb) {
if (lb_inclusive && ub_inclusive) {
if (lb >= IndexStructure<T>(lower_bound_value) && lb <= IndexStructure<T>(upper_bound_value)) {
if (*lb >= IndexStructure<T>(lower_bound_value) && *lb <= IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
} else if (lb_inclusive && !ub_inclusive) {
if (lb >= IndexStructure<T>(lower_bound_value) && lb < IndexStructure<T>(upper_bound_value)) {
if (*lb >= IndexStructure<T>(lower_bound_value) && *lb < IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
} else if (!lb_inclusive && ub_inclusive) {
if (lb > IndexStructure<T>(lower_bound_value) && lb <= IndexStructure<T>(upper_bound_value)) {
if (*lb > IndexStructure<T>(lower_bound_value) && *lb <= IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
} else {
if (lb > IndexStructure<T>(lower_bound_value) && lb < IndexStructure<T>(upper_bound_value)) {
if (*lb > IndexStructure<T>(lower_bound_value) && *lb < IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
}

View File

@ -37,9 +37,6 @@ class StructuredIndexFlat : public StructuredIndex<T> {
void
Build(const size_t n, const T* values) override;
void
build();
const faiss::ConcurrentBitsetPtr
In(const size_t n, const T* values) override;
@ -59,7 +56,7 @@ class StructuredIndexFlat : public StructuredIndex<T> {
int64_t
Size() override {
return (int64_t)data_.size();
return (int64_t)data_.size() * sizeof(IndexStructure<T>);
}
bool

View File

@ -59,7 +59,7 @@ class StructuredIndexSort : public StructuredIndex<T> {
int64_t
Size() override {
return (int64_t)data_.size();
return (int64_t)data_.size() * sizeof(IndexStructure<T>);
}
bool

View File

@ -15,6 +15,7 @@
#include <memory>
#include <string>
#include <vector>
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
@ -24,14 +25,14 @@
namespace milvus {
namespace knowhere {
static const int64_t MIN_NBITS = 1;
static const int64_t MAX_NBITS = 16;
static const int64_t MIN_NLIST = 1;
static const int64_t MAX_NLIST = 65536;
static const int64_t MIN_NPROBE = 1;
static const int64_t MAX_NPROBE = MAX_NLIST;
static const int64_t DEFAULT_MIN_DIM = 1;
static const int64_t DEFAULT_MAX_DIM = 32768;
static const int64_t DEFAULT_MIN_ROWS = 1; // minimum size for build index
static const int64_t DEFAULT_MAX_ROWS = 50000000;
static const int64_t NGT_MIN_EDGE_SIZE = 1;
static const int64_t NGT_MAX_EDGE_SIZE = 200;
static const int64_t HNSW_MIN_EFCONSTRUCTION = 8;
@ -47,6 +48,12 @@ static const std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Me
return false; \
}
#define CheckFloatByRange(key, min, max) \
if (!oricfg.contains(key) || !oricfg[key].is_number_float() || oricfg[key].get<float>() > max || \
oricfg[key].get<float>() < min) { \
return false; \
}
#define CheckIntByValues(key, container) \
if (!oricfg.contains(key) || !oricfg[key].is_number_integer()) { \
return false; \
@ -84,33 +91,42 @@ ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode m
int64_t
MatchNlist(int64_t size, int64_t nlist) {
const int64_t TYPICAL_COUNT = 1000000;
const int64_t PER_NLIST = 16384;
const int64_t MIN_POINTS_PER_CENTROID = 40;
if (nlist * TYPICAL_COUNT > size * PER_NLIST) {
if (nlist * MIN_POINTS_PER_CENTROID > size) {
// nlist is too large, adjust to a proper value
nlist = std::max(1L, size * PER_NLIST / TYPICAL_COUNT);
nlist = std::max(1L, size / MIN_POINTS_PER_CENTROID);
LOG_KNOWHERE_WARNING_ << "Row num " << size << " match nlist " << nlist;
}
return nlist;
}
/**
 * Reduce `nbits` when the row count `size` is smaller than 2^nbits.
 *
 * When `size` is at least 2^nbits the value is returned unchanged. Otherwise it is lowered to
 * the largest of 8, 4, 2 whose power of two does not exceed `size`, or to 1 if none fits, and a
 * warning is logged. A negative `nbits` is returned unchanged.
 */
int64_t
MatchNbits(int64_t size, int64_t nbits) {
    // Shift in 64 bits: `1 << nbits` is evaluated as int and is undefined once nbits reaches the
    // width of int. 2^63 and above exceeds every int64_t row count, so it is always "too large".
    const bool too_large = nbits >= 63 || (nbits >= 0 && size < (int64_t{1} << nbits));
    if (too_large) {
        // nbits is too large, adjust to a proper value
        if (size >= (int64_t{1} << 8)) {
            nbits = 8;
        } else if (size >= (int64_t{1} << 4)) {
            nbits = 4;
        } else if (size >= (int64_t{1} << 2)) {
            nbits = 2;
        } else {
            nbits = 1;
        }
        LOG_KNOWHERE_WARNING_ << "Row num " << size << " match nbits " << nbits;
    }
    return nbits;
}
bool
IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// auto tune params
auto nq = oricfg[knowhere::meta::ROWS].get<int64_t>();
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
auto nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist);
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
oricfg[knowhere::IndexParams::nlist] = MatchNlist(rows, nlist);
return ConfAdapter::CheckTrain(oricfg, mode);
}
@ -138,49 +154,38 @@ IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
bool
IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
const int64_t DEFAULT_NBITS = 8;
if (!IVFConfAdapter::CheckTrain(oricfg, mode)) {
return false;
}
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
CheckIntByRange(knowhere::IndexParams::nbits, MIN_NBITS, MAX_NBITS);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
auto nbits = oricfg[knowhere::IndexParams::nbits].get<int64_t>();
oricfg[knowhere::IndexParams::nbits] = MatchNbits(rows, nbits);
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// auto tune params
oricfg[knowhere::IndexParams::nlist] =
MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());
auto m = oricfg[knowhere::IndexParams::m].get<int64_t>();
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
/*std::vector<int64_t> resset;
IVFPQConfAdapter::GetValidCPUM(dimension, resset);*/
IndexMode ivfpq_mode = mode;
return GetValidM(dimension, m, ivfpq_mode);
return CheckPQParams(dimension, m, nbits, ivfpq_mode);
}
// Decides whether the PQ parameters can be used, possibly downgrading the
// requested mode.
//
// `mode` is an in/out parameter: in builds with MILVUS_GPU_VERSION defined, a
// GPU request whose parameters fail the GPU check is rewritten to MODE_CPU.
// Returns false only when the (possibly rewritten) mode is MODE_CPU and the
// CPU check fails; every other path returns true.
//
// NOTE(review): this span comes from a rendered diff hunk and appears to hold
// both the removed and the added lines (two signature lines, paired `if`
// conditions) — confirm which lines form the post-change body.
bool
IVFPQConfAdapter::GetValidM(int64_t dimension, int64_t m, IndexMode& mode) {
IVFPQConfAdapter::CheckPQParams(int64_t dimension, int64_t m, int64_t nbits, IndexMode& mode) {
#ifdef MILVUS_GPU_VERSION
// Fall back to CPU instead of failing when the GPU cannot take these parameters.
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::GetValidGPUM(dimension, m)) {
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::CheckGPUPQParams(dimension, m, nbits)) {
mode = knowhere::IndexMode::MODE_CPU;
}
#endif
// Evaluated after the fallback above, so a downgraded GPU request is re-checked against the CPU rule.
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::GetValidCPUM(dimension, m)) {
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::CheckCPUPQParams(dimension, m)) {
return false;
}
return true;
}
bool
IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
IVFPQConfAdapter::CheckGPUPQParams(int64_t dimension, int64_t m, int64_t nbits) {
/*
* Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
@ -193,22 +198,12 @@ IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) !=
support_subquantizer.end()) &&
(std::find(std::begin(support_dim_per_subquantizer), std::end(support_dim_per_subquantizer), sub_dim) !=
support_dim_per_subquantizer.end());
/*resset.clear();
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
if (!(dimension % dimperquantizer)) {
auto subquantzier_num = dimension / dimperquantizer;
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);
if (finder != support_subquantizer.end()) {
resset.push_back(subquantzier_num);
}
}
}*/
support_dim_per_subquantizer.end()) &&
(nbits == 8);
}
// Checks whether a PQ configuration is usable on CPU: the vector dimension
// must split evenly into `m` sub-quantizers.
//
// @param dimension  dimension of the indexed vectors
// @param m          number of PQ sub-quantizers requested
// @return true when `m` is positive and divides `dimension` exactly
//
// A non-positive `m` is rejected before the remainder is computed: with
// m == 0 the expression `dimension % m` is a division by zero.
bool
IVFPQConfAdapter::CheckCPUPQParams(int64_t dimension, int64_t m) {
    return m > 0 && dimension % m == 0;
}
@ -224,7 +219,6 @@ NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
const int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::knng, MIN_KNNG, MAX_KNNG);
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
CheckIntByRange(knowhere::IndexParams::out_degree, MIN_OUT_DEGREE, MAX_OUT_DEGREE);
@ -251,7 +245,6 @@ NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod
bool
HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
@ -267,7 +260,6 @@ HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMo
bool
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
@ -283,13 +275,12 @@ RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const In
// Validates the training configuration of an RHNSW-PQ index.
//
// Range-checks ROWS, efConstruction and M, verifies that the PQ sub-quantizer
// count (PQM) is compatible with the vector dimension, and then delegates to
// ConfAdapter::CheckTrain for the remaining checks.
//
// @param oricfg  index configuration being validated
// @param mode    requested index mode, forwarded to the base check
// @return false as soon as any check fails, otherwise the base class verdict
bool
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
    CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
    CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
    CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);

    auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
    auto pqm = oricfg[knowhere::IndexParams::PQM].get<int64_t>();

    // The PQ check reports its verdict only through its return value, so the
    // result has to be acted on here; calling it and dropping the result
    // validates nothing.
    if (!IVFPQConfAdapter::CheckCPUPQParams(dimension, pqm)) {
        return false;
    }

    return ConfAdapter::CheckTrain(oricfg, mode);
}
@ -303,7 +294,6 @@ RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const Inde
bool
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
@ -334,18 +324,14 @@ BinIVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static const std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
knowhere::Metric::TANIMOTO};
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
int64_t nlist = oricfg[knowhere::IndexParams::nlist];
CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
// auto tune params
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
auto nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
oricfg[knowhere::IndexParams::nlist] = MatchNlist(rows, nlist);
return true;
}
@ -370,35 +356,37 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM
// Validates the training configuration of an NGT-PANNG index.
//
// Checks ROWS, DIM, the metric type (L2, HAMMING or JACCARD) and the three
// edge-size parameters, and rejects a configuration whose
// selectively_pruned_edge_size is not strictly smaller than
// forcedly_pruned_edge_size.
//
// NOTE(review): this span comes from a rendered diff hunk and appears to hold
// both the removed and the added final `return` — confirm which one belongs to
// the post-change body.
bool
NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
// NOTE(review): the Check* helpers are defined outside this view; presumably
// they make this function return false on a missing or out-of-range value —
// confirm against their definitions.
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
CheckIntByRange(knowhere::IndexParams::forcedly_pruned_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
CheckIntByRange(knowhere::IndexParams::selectively_pruned_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
// Cross-parameter rule: the selective pruning limit must stay below the forced one.
if (oricfg[knowhere::IndexParams::selectively_pruned_edge_size].get<int64_t>() >=
oricfg[knowhere::IndexParams::forcedly_pruned_edge_size].get<int64_t>()) {
return false;
}
return true;
return ConfAdapter::CheckTrain(oricfg, mode);
}
// Validates the search-time parameters of an NGT-PANNG index:
// max_search_edges is checked against the bounds -1 and NGT_MAX_EDGE_SIZE,
// epsilon against -1.0 and 1.0, and the final verdict is delegated to
// ConfAdapter::CheckSearch.
//
// NOTE(review): CheckIntByRange / CheckFloatByRange are defined outside this
// view; presumably they make this function return false on a missing or
// out-of-range value — confirm against their definitions.
bool
NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
CheckIntByRange(knowhere::IndexParams::max_search_edges, -1, NGT_MAX_EDGE_SIZE);
CheckFloatByRange(knowhere::IndexParams::epsilon, -1.0, 1.0);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
// Validates the training configuration of an NGT-ONNG index.
//
// Checks ROWS, DIM, the metric type (L2, HAMMING or JACCARD) and the
// edge_size / outgoing_edge_size / incoming_edge_size parameters, each against
// NGT_MIN_EDGE_SIZE and NGT_MAX_EDGE_SIZE.
//
// NOTE(review): this span comes from a rendered diff hunk and appears to hold
// both the removed and the added final `return` — confirm which one belongs to
// the post-change body.
bool
NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
// NOTE(review): the Check* helpers are defined outside this view; presumably
// they make this function return false on a missing or out-of-range value —
// confirm against their definitions.
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
CheckIntByRange(knowhere::IndexParams::outgoing_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
CheckIntByRange(knowhere::IndexParams::incoming_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
return true;
return ConfAdapter::CheckTrain(oricfg, mode);
}
// Validates the search-time parameters of an NGT-ONNG index:
// max_search_edges is checked against the bounds -1 and NGT_MAX_EDGE_SIZE,
// epsilon against -1.0 and 1.0, and the final verdict is delegated to
// ConfAdapter::CheckSearch.
//
// NOTE(review): CheckIntByRange / CheckFloatByRange are defined outside this
// view; presumably they make this function return false on a missing or
// out-of-range value — confirm against their definitions.
bool
NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
CheckIntByRange(knowhere::IndexParams::max_search_edges, -1, NGT_MAX_EDGE_SIZE);
CheckFloatByRange(knowhere::IndexParams::epsilon, -1.0, 1.0);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}

View File

@ -52,13 +52,13 @@ class IVFPQConfAdapter : public IVFConfAdapter {
CheckTrain(Config& oricfg, const IndexMode mode) override;
static bool
GetValidM(int64_t dimension, int64_t m, IndexMode& mode);
CheckPQParams(int64_t dimension, int64_t m, int64_t nbits, IndexMode& mode);
static bool
GetValidGPUM(int64_t dimension, int64_t m);
CheckGPUPQParams(int64_t dimension, int64_t m, int64_t nbits);
static bool
GetValidCPUM(int64_t dimension, int64_t m);
CheckCPUPQParams(int64_t dimension, int64_t m);
};
class NSGConfAdapter : public IVFConfAdapter {

View File

@ -48,11 +48,15 @@ IndexAnnoy::Serialize(const Config& config) {
res_set.Append("annoy_metric_type", metric_type, metric_type_length);
res_set.Append("annoy_dim", dim_data, sizeof(uint64_t));
res_set.Append("annoy_index_data", index_data, index_length);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
}
void
IndexAnnoy::Load(const BinarySet& index_binary) {
Assemble(const_cast<BinarySet&>(index_binary));
auto metric_type = index_binary.GetByName("annoy_metric_type");
metric_type_.resize(static_cast<size_t>(metric_type->size));
memcpy(metric_type_.data(), metric_type->data.get(), static_cast<size_t>(metric_type->size));
@ -105,7 +109,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

View File

@ -54,7 +54,7 @@ class IndexAnnoy : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
int64_t
Count() override;

View File

@ -30,17 +30,23 @@ BinaryIDMAP::Serialize(const Config& config) {
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
// return SerializeImpl(index_type_);
auto ret = SerializeImpl(index_type_);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
}
return ret;
}
void
BinaryIDMAP::Load(const BinarySet& index_binary) {
Assemble(const_cast<BinarySet&>(index_binary));
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(index_binary, index_type_);
}
DatasetPtr
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
@ -147,7 +153,7 @@ BinaryIDMAP::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
// assign the metric type
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());

View File

@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
int64_t
Count() override;
@ -75,7 +75,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);
const faiss::BitsetView& bitset);
protected:
std::mutex mutex_;

View File

@ -33,17 +33,27 @@ BinaryIVF::Serialize(const Config& config) {
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
auto ret = SerializeImpl(index_type_);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
}
return ret;
}
void
BinaryIVF::Load(const BinarySet& index_binary) {
Assemble(const_cast<BinarySet&>(index_binary));
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(index_binary, index_type_);
if (STATISTICS_LEVEL >= 3) {
auto ivf_index = static_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe_statistics.resize(ivf_index->nlist, 0);
}
}
DatasetPtr
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -104,6 +114,31 @@ BinaryIVF::UpdateIndexSize() {
index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size;
}
// Returns the statistics object of this index.
//
// With statistics disabled (STATISTICS_LEVEL == 0) the stored object is
// returned as is. Otherwise the per-list probe counters kept by the faiss
// index are merged into it first, under the statistics lock.
//
// @return the statistics object; unchanged when there is no faiss binary IVF
//         index or no IVF statistics object to merge into
StatisticsPtr
BinaryIVF::GetStatistics() {
    if (!STATISTICS_LEVEL) {
        return stats;
    }
    auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
    auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
    // Both casts yield null when their source is null or of another type, and
    // a BinaryIVF can be constructed without an underlying index. There is
    // nothing to merge in that case, so return the statistics untouched
    // instead of dereferencing a null pointer.
    if (!ivf_stats || !ivf_index) {
        return stats;
    }
    auto lock = ivf_stats->Lock();
    ivf_stats->update_ivf_access_stats(ivf_index->nprobe_statistics);
    return ivf_stats;
}
void
BinaryIVF::ClearStatistics() {
if (!STATISTICS_LEVEL) {
return;
}
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->clear_nprobe_statistics();
ivf_index->index_ivf_stats.reset();
auto lock = ivf_stats->Lock();
ivf_stats->clear();
}
void
BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR(dataset_ptr)
@ -112,6 +147,7 @@ BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, metric_type);
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, nlist, metric_type);
index->own_fields = true;
index->train(rows, static_cast<const uint8_t*>(p_data));
index->add_with_ids(rows, static_cast<const uint8_t*>(p_data), p_ids);
index_ = index;
@ -132,22 +168,37 @@ BinaryIVF::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
stdclock::time_point before = stdclock::now();
auto i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, data, k, i_distances, labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost
<< ", quantization cost: " << ivf_index->index_ivf_stats.quantization_time
<< ", data search cost: " << ivf_index->index_ivf_stats.search_time;
if (STATISTICS_LEVEL) {
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
auto lock = ivf_stats->Lock();
if (STATISTICS_LEVEL >= 1) {
ivf_stats->update_nq(n);
ivf_stats->count_nprobe(ivf_index->nprobe);
ivf_stats->update_total_query_time(ivf_index->index_ivf_stats.quantization_time +
ivf_index->index_ivf_stats.search_time);
ivf_index->index_ivf_stats.quantization_time = 0;
ivf_index->index_ivf_stats.search_time = 0;
}
if (STATISTICS_LEVEL >= 2) {
ivf_stats->update_filter_percentage(bitset);
}
}
// if hamming, it need transform int32 to float
if (ivf_index->metric_type == faiss::METRIC_Hamming) {

View File

@ -29,10 +29,12 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
public:
BinaryIVF() : FaissBaseBinaryIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
explicit BinaryIVF(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
BinarySet
@ -60,7 +62,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
int64_t
Count() override;
@ -71,6 +73,12 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
void
UpdateIndexSize() override;
StatisticsPtr
GetStatistics() override;
void
ClearStatistics() override;
protected:
virtual std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config& config);
@ -82,7 +90,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);
const faiss::BitsetView& bitset);
protected:
std::mutex mutex_;

View File

@ -13,6 +13,7 @@
#include <algorithm>
#include <cassert>
#include <chrono>
#include <iterator>
#include <string>
#include <utility>
@ -20,6 +21,7 @@
#include "faiss/BuilderSuspend.h"
#include "hnswlib/hnswalg.h"
#include "hnswlib/hnswlib.h"
#include "hnswlib/space_ip.h"
#include "hnswlib/space_l2.h"
#include "knowhere/common/Exception.h"
@ -51,6 +53,9 @@ IndexHNSW::Serialize(const Config& config) {
BinarySet res_set;
res_set.Append("HNSW", data, writer.rp);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -60,6 +65,7 @@ IndexHNSW::Serialize(const Config& config) {
void
IndexHNSW::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
auto binary = index_binary.GetByName("HNSW");
MemoryIOReader reader;
@ -68,7 +74,15 @@ IndexHNSW::Load(const BinarySet& index_binary) {
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space);
index_->stats_enable = (STATISTICS_LEVEL >= 3);
index_->loadIndex(reader);
auto hnsw_stats = std::static_pointer_cast<LibHNSWStatistics>(stats);
if (STATISTICS_LEVEL >= 3) {
auto lock = hnsw_stats->Lock();
hnsw_stats->update_level_distribution(index_->maxlevel_, index_->level_stats_);
}
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Load finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << hnsw_stats->ToString();
normalize = index_->metric_type_ == 1; // 1 == InnerProduct
} catch (std::exception& e) {
@ -94,6 +108,7 @@ IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
}
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
config[IndexParams::efConstruction].get<int64_t>());
index_->stats_enable = (STATISTICS_LEVEL >= 3);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
@ -133,10 +148,17 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
faiss::BuilderSuspend::check_wait();
index_->addPoint((reinterpret_cast<const float*>(p_data) + Dim() * i), p_ids[i]);
}
if (STATISTICS_LEVEL >= 3) {
auto hnsw_stats = std::static_pointer_cast<LibHNSWStatistics>(stats);
auto lock = hnsw_stats->Lock();
hnsw_stats->update_level_distribution(index_->maxlevel_, index_->level_stats_);
}
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Train finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
}
DatasetPtr
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -147,12 +169,23 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
size_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
std::vector<hnswlib::StatisticsInfo> query_stats;
auto hnsw_stats = std::dynamic_pointer_cast<LibHNSWStatistics>(stats);
if (STATISTICS_LEVEL >= 3) {
query_stats.resize(rows);
for (auto i = 0; i < rows; ++i) {
query_stats[i].target_level = hnsw_stats->target_level;
}
}
index_->setEf(config[IndexParams::ef]);
index_->setEf(config[IndexParams::ef].get<int64_t>());
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
std::chrono::high_resolution_clock::time_point query_start, query_end;
query_start = std::chrono::high_resolution_clock::now();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<P> ret;
@ -165,7 +198,12 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
// } else {
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn(single_query, k, compare, bitset);
if (STATISTICS_LEVEL >= 3) {
ret = index_->searchKnn(single_query, k, compare, bitset, query_stats[i]);
} else {
auto dummy_stat = hnswlib::StatisticsInfo();
ret = index_->searchKnn(single_query, k, compare, bitset, dummy_stat);
}
while (ret.size() < k) {
ret.emplace_back(std::make_pair(-1, -1));
@ -186,6 +224,33 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
memcpy(p_dist + i * k, dist.data(), dist_size);
memcpy(p_id + i * k, ids.data(), id_size);
}
query_end = std::chrono::high_resolution_clock::now();
if (STATISTICS_LEVEL) {
auto lock = hnsw_stats->Lock();
if (STATISTICS_LEVEL >= 1) {
hnsw_stats->update_nq(rows);
hnsw_stats->update_ef_sum(index_->ef_ * rows);
hnsw_stats->update_total_query_time(
std::chrono::duration_cast<std::chrono::milliseconds>(query_end - query_start).count());
}
if (STATISTICS_LEVEL >= 2) {
hnsw_stats->update_filter_percentage(bitset);
}
if (STATISTICS_LEVEL >= 3) {
for (auto i = 0; i < rows; ++i) {
for (auto j = 0; j < query_stats[i].accessed_points.size(); ++j) {
auto tgt = hnsw_stats->access_cnt_map.find(query_stats[i].accessed_points[j]);
if (tgt == hnsw_stats->access_cnt_map.end())
hnsw_stats->access_cnt_map[query_stats[i].accessed_points[j]] = 1;
else
tgt->second += 1;
}
}
}
}
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Query finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -217,5 +282,14 @@ IndexHNSW::UpdateIndexSize() {
index_size_ = index_->cal_size();
}
// Clears the collected HNSW statistics under the statistics lock.
// Does nothing when statistics are disabled (STATISTICS_LEVEL == 0).
void
IndexHNSW::ClearStatistics() {
    if (STATISTICS_LEVEL) {
        const auto collector = std::static_pointer_cast<LibHNSWStatistics>(stats);
        auto guard = collector->Lock();
        collector->clear();
    }
}
} // namespace knowhere
} // namespace milvus

View File

@ -26,6 +26,7 @@ class IndexHNSW : public VecIndex {
public:
IndexHNSW() {
index_type_ = IndexEnum::INDEX_HNSW;
stats = std::make_shared<milvus::knowhere::LibHNSWStatistics>(index_type_);
}
BinarySet
@ -46,7 +47,7 @@ class IndexHNSW : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
int64_t
Count() override;
@ -57,6 +58,9 @@ class IndexHNSW : public VecIndex {
void
UpdateIndexSize() override;
void
ClearStatistics() override;
private:
bool normalize = false;
std::mutex mutex_;

View File

@ -43,11 +43,16 @@ IDMAP::Serialize(const Config& config) {
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
auto ret = SerializeImpl(index_type_);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
}
return ret;
}
void
IDMAP::Load(const BinarySet& binary_set) {
Assemble(const_cast<BinarySet&>(binary_set));
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(binary_set, index_type_);
}
@ -95,7 +100,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
@ -229,7 +234,7 @@ IDMAP::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
// assign the metric type
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());

View File

@ -46,7 +46,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
#if 0
DatasetPtr
@ -80,7 +80,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
protected:
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView&);
protected:
std::mutex mutex_;

View File

@ -54,24 +54,37 @@ IVF::Serialize(const Config& config) {
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
auto ret = SerializeImpl(index_type_);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
}
return ret;
}
void
IVF::Load(const BinarySet& binary_set) {
Assemble(const_cast<BinarySet&>(binary_set));
std::lock_guard<std::mutex> lk(mutex_);
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
LoadImpl(binary_set, index_type_);
if (IndexMode() == IndexMode::MODE_CPU && STATISTICS_LEVEL >= 3) {
auto ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
ivf_index->nprobe_statistics.resize(ivf_index->nlist, 0);
}
}
// Builds and trains a faiss IVF-Flat index from the dataset, using `nlist` and
// the metric type taken from `config`, and stores it in `index_`.
//
// NOTE(review): this span comes from a rendered diff hunk and appears to hold
// both the removed and the added lines (two `nlist` declarations, two index
// constructions, two train calls) — confirm which lines form the post-change
// body.
void
IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// NOTE(review): `dim`, `rows` and `p_data` used below are not declared in this
// function; presumably this macro introduces them from the dataset — confirm.
GET_TENSOR_DATA_DIM(dataset_ptr)
int64_t nlist = config[IndexParams::nlist].get<int64_t>();
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
// Raw allocation: ownership of the quantizer is settled by the index built from it below.
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
auto nlist = config[IndexParams::nlist].get<int64_t>();
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type);
// faiss's own_fields flag makes the IVF index responsible for deleting the
// quantizer allocated above.
index->own_fields = true;
index->train(rows, reinterpret_cast<const float*>(p_data));
index_ = index;
}
void
@ -97,7 +110,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -245,7 +258,7 @@ IVF::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
auto ivf_index = dynamic_cast<faiss::IndexIVFFlat*>(index_.get());
auto ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
auto nb = ivf_index->invlists->compute_ntotal();
auto nlist = ivf_index->nlist;
auto code_size = ivf_index->code_size;
@ -324,7 +337,7 @@ IVF::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
@ -334,14 +347,30 @@ IVF::QueryImpl(int64_t n,
} else {
ivf_index->parallel_mode = 0;
}
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
ivf_index->search(n, data, k, distances, labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
if (STATISTICS_LEVEL) {
auto lock = ivf_stats->Lock();
if (STATISTICS_LEVEL >= 1) {
ivf_stats->update_nq(n);
ivf_stats->count_nprobe(ivf_index->nprobe);
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
<< ", quantization cost: " << ivf_index->index_ivf_stats.quantization_time
<< ", data search cost: " << ivf_index->index_ivf_stats.search_time;
ivf_stats->update_total_query_time(ivf_index->index_ivf_stats.quantization_time +
ivf_index->index_ivf_stats.search_time);
ivf_index->index_ivf_stats.quantization_time = 0;
ivf_index->index_ivf_stats.search_time = 0;
}
if (STATISTICS_LEVEL >= 2) {
ivf_stats->update_filter_percentage(bitset);
}
}
// LOG_KNOWHERE_DEBUG_ << "IndexIVF::QueryImpl finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
}
void
@ -355,5 +384,30 @@ IVF::SealImpl() {
#endif
}
// Returns the statistics object of this index.
//
// For a non-CPU index, or with statistics disabled (STATISTICS_LEVEL == 0),
// the stored object is returned as is. Otherwise the per-list probe counters
// kept by the faiss index are merged into it first, under the statistics lock.
//
// @return the statistics object; unchanged when no faiss index exists yet
StatisticsPtr
IVF::GetStatistics() {
    if (IndexMode() != IndexMode::MODE_CPU || !STATISTICS_LEVEL) {
        return stats;
    }
    // An IVF can be constructed without an underlying faiss index. There are
    // no probe counters to merge in that state, so return the statistics
    // untouched instead of dereferencing a null index.
    if (!index_) {
        return stats;
    }
    auto ivf_stats = std::static_pointer_cast<IVFStatistics>(stats);
    auto ivf_index = static_cast<faiss::IndexIVF*>(index_.get());
    auto lock = ivf_stats->Lock();
    ivf_stats->update_ivf_access_stats(ivf_index->nprobe_statistics);
    return ivf_stats;
}
// Resets every statistic collected for this index: the probe counters and
// timing stats held by the faiss index, then the statistics object itself.
// Does nothing for a non-CPU index or when statistics are disabled
// (STATISTICS_LEVEL == 0).
void
IVF::ClearStatistics() {
    if (IndexMode() != IndexMode::MODE_CPU || !STATISTICS_LEVEL) {
        return;
    }
    // An IVF can be constructed without an underlying faiss index; skip the
    // index-side counters in that state rather than dereference a null index.
    if (index_) {
        auto ivf_index = static_cast<faiss::IndexIVF*>(index_.get());
        ivf_index->clear_nprobe_statistics();
        ivf_index->index_ivf_stats.reset();
    }
    auto ivf_stats = std::static_pointer_cast<IVFStatistics>(stats);
    auto lock = ivf_stats->Lock();
    ivf_stats->clear();
}
} // namespace knowhere
} // namespace milvus

View File

@ -29,10 +29,12 @@ class IVF : public VecIndex, public FaissBaseIndex {
public:
IVF() : FaissBaseIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
BinarySet
@ -51,7 +53,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
#if 0
DatasetPtr
@ -67,6 +69,12 @@ class IVF : public VecIndex, public FaissBaseIndex {
void
UpdateIndexSize() override;
StatisticsPtr
GetStatistics() override;
void
ClearStatistics() override;
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
@ -86,7 +94,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
GenParams(const Config&);
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView&);
void
SealImpl() override;

View File

@ -38,11 +38,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFPQ(
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m].get<int64_t>(),
config[IndexParams::nbits].get<int64_t>(), metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
config[IndexParams::m].get<int64_t>(),
config[IndexParams::nbits].get<int64_t>(), metric_type);
index->own_fields = true;
index->train(rows, reinterpret_cast<const float*>(p_data));
index_ = index;
}
VecIndexPtr
@ -51,7 +52,8 @@ IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
auto ivfpq_index = dynamic_cast<faiss::IndexIVFPQ*>(index_.get());
int64_t dim = ivfpq_index->d;
int64_t m = ivfpq_index->pq.M;
if (!IVFPQConfAdapter::GetValidGPUM(dim, m)) {
int64_t nbits = ivfpq_index->pq.nbits;
if (!IVFPQConfAdapter::CheckGPUPQParams(dim, m, nbits)) {
return nullptr;
}
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {

View File

@ -23,10 +23,12 @@ class IVFPQ : public IVF {
public:
// Default constructor: delegates to IVF(), then re-tags the index as IVF_PQ.
IVFPQ() : IVF() {
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
// Replace the statistics object with a fresh one constructed from the updated type.
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
// Wraps an existing faiss index (moved through to the IVF base), then re-tags it as IVF_PQ.
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
// Replace the statistics object with a fresh one constructed from the updated type.
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
void

View File

@ -37,18 +37,13 @@ void
IVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr)
// std::stringstream index_type;
// index_type << "IVF" << config[IndexParams::nlist] << ","
// << "SQ" << config[IndexParams::nbits];
// index_ = std::shared_ptr<faiss::Index>(
// faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>())));
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFScalarQuantizer(
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit, metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
auto index = std::make_shared<faiss::IndexIVFScalarQuantizer>(
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit, metric_type);
index->own_fields = true;
index->train(rows, reinterpret_cast<const float*>(p_data));
index_ = index;
}
VecIndexPtr

View File

@ -23,10 +23,12 @@ class IVFSQ : public IVF {
public:
// Default constructor: delegates to IVF(), then re-tags the index as IVF_SQ8.
IVFSQ() : IVF() {
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
// Replace the statistics object with a fresh one constructed from the updated type.
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
// Wraps an existing faiss index (moved through to the IVF base), then re-tags it as IVF_SQ8.
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
// Replace the statistics object with a fresh one constructed from the updated type.
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
void

View File

@ -52,11 +52,15 @@ IndexNGT::Serialize(const Config& config) {
res_set.Append("ngt_grp_data", grp_data, grp_size);
res_set.Append("ngt_prf_data", prf_data, prf_size);
res_set.Append("ngt_tre_data", tre_data, tre_size);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
}
void
IndexNGT::Load(const BinarySet& index_binary) {
Assemble(const_cast<BinarySet&>(index_binary));
auto obj_data = index_binary.GetByName("ngt_obj_data");
std::string obj_str(reinterpret_cast<char*>(obj_data->data.get()), obj_data->size);
@ -118,13 +122,18 @@ IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
#endif
DatasetPtr
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr);
size_t k = config[meta::TOPK].get<int64_t>();
int k = config[meta::TOPK].get<int>();
auto epsilon = config[IndexParams::epsilon].get<float>();
auto edge_size = config[IndexParams::max_search_edges].get<int>();
if (edge_size == -1) { // pass -1
edge_size--;
}
size_t id_size = sizeof(int64_t) * k;
size_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
@ -140,11 +149,11 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
NGT::Object* object = index_->allocateObject(single_query, Dim());
NGT::SearchContainer sc(*object);
double epsilon = sp.beginOfEpsilon;
// double epsilon = sp.beginOfEpsilon;
NGT::ObjectDistances res;
sc.setResults(&res);
sc.setSize(sp.size);
sc.setSize(static_cast<size_t>(sp.size));
sc.setRadius(sp.radius);
if (sp.accuracy > 0.0) {
@ -152,7 +161,8 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
} else {
sc.setEpsilon(epsilon);
}
sc.setEdgeSize(sp.edgeSize);
// sc.setEdgeSize(sp.edgeSize);
sc.setEdgeSize(edge_size);
try {
index_->search(sc, bitset);
@ -164,9 +174,13 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
auto local_dist = p_dist + i * k;
int64_t res_num = res.size();
float dis_coefficient = 1.0;
if (index_->getObjectSpace().getDistanceType() == NGT::ObjectSpace::DistanceType::DistanceTypeIP) {
dis_coefficient = -1.0;
}
for (int64_t idx = 0; idx < res_num; ++idx) {
*(local_id + idx) = res[idx].id - 1;
*(local_dist + idx) = res[idx].distance;
*(local_dist + idx) = res[idx].distance * dis_coefficient;
}
while (res_num < static_cast<int64_t>(k)) {
*(local_id + res_num) = -1;
@ -197,5 +211,10 @@ IndexNGT::Dim() {
return index_->getDimension();
}
// The base NGT wrapper does not compute an index size: it always raises via KNOWHERE_THROW_MSG.
// The message directs callers to the PANNG/ONNG variants instead.
void
IndexNGT::UpdateIndexSize() {
KNOWHERE_THROW_MSG("IndexNGT has no implementation of UpdateIndexSize, please use IndexNGT(PANNG/ONNG) instead!");
}
} // namespace knowhere
} // namespace milvus

View File

@ -54,7 +54,7 @@ class IndexNGT : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
int64_t
Count() override;
@ -62,6 +62,9 @@ class IndexNGT : public VecIndex {
int64_t
Dim() override;
void
UpdateIndexSize() override;
protected:
std::shared_ptr<NGT::Index> index_ = nullptr;
};

View File

@ -38,10 +38,8 @@ IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
if (metric_type == Metric::L2) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
} else if (metric_type == Metric::HAMMING) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
} else if (metric_type == Metric::JACCARD) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
} else if (metric_type == Metric::IP) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeIP;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
@ -50,7 +48,7 @@ IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
// reconstruct graph
NGT::GraphOptimizer graphOptimizer(true);
NGT::GraphOptimizer graphOptimizer(false);
auto number_of_outgoing_edges = config[IndexParams::outgoing_edge_size].get<size_t>();
auto number_of_incoming_edges = config[IndexParams::incoming_edge_size].get<size_t>();
@ -67,5 +65,13 @@ IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
graphOptimizer.execute(*index_);
}
// Refresh index_size_ from the memory footprint reported by the NGT index.
// Raises through KNOWHERE_THROW_MSG when no index has been built or loaded.
void
IndexNGTONNG::UpdateIndexSize() {
    if (index_) {
        index_size_ = index_->memSize();
        return;
    }
    KNOWHERE_THROW_MSG("index not initialize");
}
} // namespace knowhere
} // namespace milvus

View File

@ -24,6 +24,9 @@ class IndexNGTONNG : public IndexNGT {
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
void
UpdateIndexSize() override;
};
} // namespace knowhere

View File

@ -27,16 +27,14 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
prop.dimension = dim;
auto edge_size = config[IndexParams::edge_size].get<int64_t>();
prop.edgeSizeLimitForCreation = edge_size;
prop.edgeSizeForCreation = edge_size;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
} else if (metric_type == Metric::HAMMING) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
} else if (metric_type == Metric::JACCARD) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
} else if (metric_type == Metric::IP) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeIP;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
@ -48,6 +46,8 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
auto selectively_pruned_edge_size = config[IndexParams::selectively_pruned_edge_size].get<int64_t>();
if (!forcedly_pruned_edge_size && !selectively_pruned_edge_size) {
KNOWHERE_THROW_MSG(
"a lack of parameters forcedly_pruned_edge_size and selectively_pruned_edge_size 4 index NGTPANNG");
return;
}
@ -56,11 +56,23 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("Selectively pruned edge size should less than remaining edge size");
}
// std::map<size_t, size_t> stats;
// size_t max_len = 0;
// prune
auto& graph = dynamic_cast<NGT::GraphIndex&>(index_->getIndex());
for (size_t id = 1; id < graph.repository.size(); id++) {
try {
NGT::GraphNode& node = *graph.getNode(id);
// auto sz = node.size();
// if (max_len < sz)
// max_len = sz;
// auto fd = stats.find(sz);
// if (fd != stats.end()) {
// fd->second ++;
// } else {
// stats[sz] = 1;
// }
if (node.size() >= forcedly_pruned_edge_size) {
node.resize(forcedly_pruned_edge_size);
}
@ -73,7 +85,7 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
if (t1 >= selectively_pruned_edge_size) {
break;
}
if (rank == t1) {
if (rank == t1) { // can't reach here
continue;
}
NGT::GraphNode& node2 = *graph.getNode(node[t1].id);
@ -101,6 +113,25 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
continue;
}
}
/*
std::vector<size_t> cnt(max_len, 0);
for (auto &pr : stats) {
cnt[pr.first] = pr.second;
}
for (auto i = 0; i < cnt.size(); ++ i) {
if (cnt[i]) {
std::cout << "len = " << i << ", cnt = " << cnt[i] << std::endl;
}
}
*/
}
// Refresh index_size_ from the memory footprint reported by the NGT index.
// Raises through KNOWHERE_THROW_MSG when no index has been built or loaded.
void
IndexNGTPANNG::UpdateIndexSize() {
    if (index_) {
        index_size_ = index_->memSize();
        return;
    }
    KNOWHERE_THROW_MSG("index not initialize");
}
} // namespace knowhere

View File

@ -24,6 +24,9 @@ class IndexNGTPANNG : public IndexNGT {
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
void
UpdateIndexSize() override;
};
} // namespace knowhere

View File

@ -48,6 +48,9 @@ NSG::Serialize(const Config& config) {
BinarySet res_set;
res_set.Append("NSG", data, writer.rp);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -57,6 +60,7 @@ NSG::Serialize(const Config& config) {
void
NSG::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
fiu_do_on("NSG.Load.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
auto binary = index_binary.GetByName("NSG");
@ -73,7 +77,7 @@ NSG::Load(const BinarySet& index_binary) {
}
DatasetPtr
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

View File

@ -59,7 +59,7 @@ class NSG : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
int64_t
Count() override;

View File

@ -13,6 +13,7 @@
#include <algorithm>
#include <cassert>
#include <chrono>
#include <iterator>
#include <utility>
#include <vector>
@ -49,6 +50,7 @@ IndexRHNSW::Serialize(const Config& config) {
void
IndexRHNSW::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
MemoryIOReader reader;
reader.name = this->index_type() + "_Index";
auto binary = index_binary.GetByName(reader.name);
@ -57,6 +59,15 @@ IndexRHNSW::Load(const BinarySet& index_binary) {
reader.data_ = binary->data.get();
auto idx = faiss::read_index(&reader);
auto hnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
if (STATISTICS_LEVEL >= 3) {
auto real_idx = static_cast<faiss::IndexRHNSW*>(idx);
auto lock = hnsw_stats->Lock();
hnsw_stats->update_level_distribution(real_idx->hnsw.max_level, real_idx->hnsw.level_stats);
real_idx->set_target_level(hnsw_stats->target_level);
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << hnsw_stats->ToString();
}
index_.reset(idx);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -76,10 +87,19 @@ IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA(dataset_ptr)
index_->add(rows, reinterpret_cast<const float*>(p_data));
auto hnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
if (STATISTICS_LEVEL >= 3) {
auto real_idx = static_cast<faiss::IndexRHNSW*>(index_.get());
auto lock = hnsw_stats->Lock();
hnsw_stats->update_level_distribution(real_idx->hnsw.max_level, real_idx->hnsw.level_stats);
real_idx->set_target_level(hnsw_stats->target_level);
}
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
}
DatasetPtr
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -90,6 +110,7 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fai
int64_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
auto hnsw_stats = std::dynamic_pointer_cast<RHNSWStatistics>(stats);
for (auto i = 0; i < k * rows; ++i) {
p_id[i] = -1;
p_dist[i] = -1;
@ -97,8 +118,26 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fai
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
real_index->hnsw.efSearch = (config[IndexParams::ef]);
real_index->hnsw.efSearch = (config[IndexParams::ef].get<int64_t>());
std::chrono::high_resolution_clock::time_point query_start, query_end;
query_start = std::chrono::high_resolution_clock::now();
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, bitset);
query_end = std::chrono::high_resolution_clock::now();
if (STATISTICS_LEVEL) {
auto lock = hnsw_stats->Lock();
if (STATISTICS_LEVEL >= 1) {
hnsw_stats->update_nq(rows);
hnsw_stats->update_ef_sum(real_index->hnsw.efSearch * rows);
hnsw_stats->update_total_query_time(
std::chrono::duration_cast<std::chrono::milliseconds>(query_end - query_start).count());
}
if (STATISTICS_LEVEL >= 2) {
hnsw_stats->update_filter_percentage(bitset);
}
}
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -122,6 +161,30 @@ IndexRHNSW::Dim() {
return index_->d;
}
// Return the statistics object for this index.
// When statistics are enabled, the sorted per-node access counts are first pulled from the
// faiss index into the statistics object (under the statistics lock).
StatisticsPtr
IndexRHNSW::GetStatistics() {
    if (STATISTICS_LEVEL != 0) {
        auto rhnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
        auto* faiss_index = static_cast<faiss::IndexRHNSW*>(index_.get());
        auto guard = rhnsw_stats->Lock();
        faiss_index->get_sorted_access_counts(rhnsw_stats->access_cnt, rhnsw_stats->access_total);
        return rhnsw_stats;
    }
    // Statistics disabled: hand back the object untouched.
    return stats;
}
// Reset both the faiss-side counters and the knowhere-side statistics object.
// Does nothing when statistics collection is disabled.
void
IndexRHNSW::ClearStatistics() {
    if (STATISTICS_LEVEL == 0) {
        return;
    }

    // Counters kept inside the faiss index.
    auto* faiss_index = static_cast<faiss::IndexRHNSW*>(index_.get());
    faiss_index->clear_stats();

    // Statistics object owned by this wrapper; cleared while holding its lock.
    auto rhnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
    auto guard = rhnsw_stats->Lock();
    rhnsw_stats->clear();
}
void
IndexRHNSW::UpdateIndexSize() {
KNOWHERE_THROW_MSG(

View File

@ -30,10 +30,12 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
public:
// Default constructor: no underlying faiss index yet; the type is left as INVALID here.
IndexRHNSW() : FaissBaseIndex(nullptr) {
index_type_ = IndexEnum::INVALID;
// Attach an RHNSW statistics collector constructed from this index's type.
stats = std::make_shared<milvus::knowhere::RHNSWStatistics>(index_type_);
}
// Wraps an existing faiss index (moved into FaissBaseIndex); the type is left as INVALID here.
explicit IndexRHNSW(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INVALID;
// Attach an RHNSW statistics collector constructed from this index's type.
stats = std::make_shared<milvus::knowhere::RHNSWStatistics>(index_type_);
}
BinarySet
@ -52,7 +54,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
int64_t
Count() override;
@ -62,6 +64,12 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
void
UpdateIndexSize() override;
StatisticsPtr
GetStatistics() override;
void
ClearStatistics() override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -52,6 +52,9 @@ IndexRHNSWFlat::Serialize(const Config& config) {
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -61,6 +64,7 @@ IndexRHNSWFlat::Serialize(const Config& config) {
void
IndexRHNSWFlat::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = this->index_type() + "_Data";

View File

@ -48,6 +48,9 @@ IndexRHNSWPQ::Serialize(const Config& config) {
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -57,6 +60,7 @@ IndexRHNSWPQ::Serialize(const Config& config) {
void
IndexRHNSWPQ::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = QUANTIZATION_DATA;

View File

@ -51,6 +51,9 @@ IndexRHNSWSQ::Serialize(const Config& config) {
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -60,6 +63,7 @@ IndexRHNSWSQ::Serialize(const Config& config) {
void
IndexRHNSWSQ::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = QUANTIZATION_DATA;

View File

@ -83,11 +83,15 @@ CPUSPTAGRNG::Serialize(const Config& config) {
binary_set.Append("config", x_cfg, length);
binary_set.Append("graph", graph, index_blobs[2].Length());
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, binary_set);
}
return binary_set;
}
void
CPUSPTAGRNG::Load(const BinarySet& binary_set) {
Assemble(const_cast<BinarySet&>(binary_set));
std::string index_config;
std::vector<SPTAG::ByteArray> index_blobs;
@ -176,7 +180,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
}
DatasetPtr
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
SetParameters(config);
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);

View File

@ -52,7 +52,7 @@ class CPUSPTAGRNG : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
int64_t
Count() override;

View File

@ -0,0 +1,208 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "IndexIVF.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/Statistics.h"
#include "faiss/utils/ConcurrentBitset.h"
namespace milvus {
namespace knowhere {
int32_t STATISTICS_LEVEL = 0;
std::string
Statistics::ToString() {
std::ostringstream ret;
if (STATISTICS_LEVEL == 0) {
ret << "There is nothing because configuration STATISTICS_LEVEL = 0" << std::endl;
return ret.str();
}
if (STATISTICS_LEVEL >= 1) {
ret << "Total batches: " << batch_cnt << std::endl;
ret << "Total queries: " << nq_cnt << std::endl;
ret << "Qps: " << Qps() << std::endl;
ret << "The frequency distribution of the num of queries:" << std::endl;
size_t left = 1, right = 1;
for (size_t i = 0; i < NQ_Histogram_Slices - 1; i++) {
ret << "[" << left << ", " << right << "].count = " << nq_stat[i] << std::endl;
left = right + 1;
right <<= 1;
}
ret << "[" << left << ", +00).count = " << nq_stat.back() << std::endl;
}
if (STATISTICS_LEVEL >= 2) {
ret << "The frequency distribution of filter: " << std::endl;
for (auto i = 0; i < 20; ++i) {
ret << "[" << i * 5 << "%, " << i * 5 + 5 << "%).count = " << filter_stat[i] << std::endl;
}
}
return ret.str();
}
std::string
HNSWStatistics::ToString() {
std::ostringstream ret;
if (STATISTICS_LEVEL >= 1) {
ret << "Avg Ef: " << AvgSearchEf() << std::endl;
}
if (STATISTICS_LEVEL >= 3) {
std::vector<size_t> axis_x = {5, 10, 20, 40};
std::vector<double> access_cdf = AccessCDF(axis_x);
ret << "There are " << access_total << " times point-access at level " << target_level << std::endl;
ret << "The CDF at level " << target_level << ":" << std::endl;
for (auto i = 0; i < axis_x.size(); ++i) {
ret << "(" << axis_x[i] << "," << access_cdf[i] << ") ";
}
ret << std::endl;
ret << "Level distribution: " << std::endl;
size_t point_cnt = 0;
for (int i = distribution.size() - 1; i >= 0; i--) {
point_cnt += distribution[i];
ret << "Level " << i << " has " << point_cnt << " points" << std::endl;
}
}
return Statistics::ToString() + ret.str();
}
// Translate percentage marks into element counts over a sequence of `size` entries.
// @param size:   number of entries in the sequence
// @param axis_x: percentage marks; values >= 100 are clamped to `size`
// @retval: for each mark, the rounded-to-nearest number of leading entries it covers
std::vector<size_t>
GenSplitIndex(size_t size, const std::vector<size_t>& axis_x) {
    std::vector<size_t> split_idx;
    split_idx.reserve(axis_x.size());
    for (const size_t percent : axis_x) {
        // Clamp at 100% so an index can never exceed size; otherwise round to nearest.
        split_idx.push_back(percent >= 100 ? size : (percent * size + 50) / 100);
    }
    return split_idx;
}

// Cumulative share of accesses covered by the leading entries of access_cnt.
// @param access_total: sum of all access counts (the denominator)
// @param access_cnt:   per-entry access counts, in the order they should be accumulated
// @param axis_x:       percentage marks, expected in ascending order
// @retval: one value per mark; 0 for a mark that is smaller than its predecessor, and all
//          zeros when access_total is 0 (previously this divided by zero and produced NaN/inf)
std::vector<double>
CaculateCDF(size_t access_total, const std::vector<size_t>& access_cnt, const std::vector<size_t>& axis_x) {
    const auto split_idx = GenSplitIndex(access_cnt.size(), axis_x);

    std::vector<double> access_cdf(split_idx.size(), 0.0);
    if (access_total == 0) {
        // No access recorded: every share is 0 rather than 0/0.
        return access_cdf;
    }

    size_t idx = 0;      // next entry of access_cnt to fold in
    size_t tmp_cnt = 0;  // running sum of access_cnt[0 .. idx)
    for (size_t i = 0; i < split_idx.size(); ++i) {
        if (i != 0 && split_idx[i] < split_idx[i - 1]) {
            // wrong split_idx (marks not ascending)
            // Todo: log output
            access_cdf[i] = 0;
            continue;
        }
        while (idx < split_idx[i]) {
            tmp_cnt += access_cnt[idx];
            idx++;
        }
        access_cdf[i] = static_cast<double>(tmp_cnt) / static_cast<double>(access_total);
    }
    return access_cdf;
}
std::vector<double>
LibHNSWStatistics::AccessCDF(const std::vector<size_t>& axis_x) {
// copy from std::map to std::vector
std::vector<size_t> access_cnt;
access_cnt.reserve(access_cnt_map.size());
access_total = 0;
for (auto& elem : access_cnt_map) {
access_cnt.push_back(elem.second);
access_total += elem.second;
}
std::sort(access_cnt.begin(), access_cnt.end(), std::greater<>());
return CaculateCDF(access_total, access_cnt, axis_x);
}
// Turn the stored access counts (access_cnt / access_total) into a CDF.
// @param axis_x: percentage marks, expected in ascending order
// @retval: cumulative access share at each mark
std::vector<double>
RHNSWStatistics::AccessCDF(const std::vector<size_t>& axis_x) {
    auto cdf = CaculateCDF(access_total, access_cnt, axis_x);
    return cdf;
}
// Render the IVF-specific statistics and append them to the generic report of the base class.
//   level >= 1: nlist and the (nprobe, count) pairs
//   level >= 3: bucket access CDF at 5/10/20/40 percent
// @retval: the formatted report
std::string
IVFStatistics::ToString() {
    std::ostringstream extra;
    if (STATISTICS_LEVEL >= 1) {
        extra << "nlist " << Nlist() << '\n';
        extra << "(nprobe, count): " << '\n';
        // SearchNprobe() returns a copy, so iterating here does not touch the live map.
        const auto nprobe_hist = SearchNprobe();
        for (const auto& entry : nprobe_hist) {
            extra << "(" << entry.first << ", " << entry.second << ") ";
        }
        extra << '\n';
    }
    if (STATISTICS_LEVEL >= 3) {
        const std::vector<size_t> percent_marks = {5, 10, 20, 40};
        extra << "Bucket CDF " << '\n';
        const auto cdf = AccessCDF(percent_marks);
        for (size_t k = 0; k < cdf.size(); ++k) {
            extra << "Top " << percent_marks[k] << "% access count " << cdf[k] << '\n';
        }
        extra << '\n';
    }
    return Statistics::ToString() + extra.str();
}
// Record one more use of the given nprobe value.
// @param nprobe: the nprobe value to count
void
IVFStatistics::count_nprobe(const int64_t nprobe) {
    // operator[] value-initializes a missing entry to 0, so the first sighting ends up as 1.
    ++nprobe_count[nprobe];
}
// Refresh the bucket-access statistics from the per-bucket counters.
// Sets nlist to the number of counters, stores the counters in descending order in access_cnt
// and stores their sum in access_total.
// @param nprobe_statistics: one access counter per bucket
void
IVFStatistics::update_ivf_access_stats(const std::vector<size_t>& nprobe_statistics) {
    nlist = nprobe_statistics.size();

    // The sum does not depend on ordering, so take it straight from the input.
    size_t total = 0;
    for (const size_t hits : nprobe_statistics) {
        total += hits;
    }

    access_cnt = nprobe_statistics;
    // Sorting the reversed range ascending leaves access_cnt in descending order.
    std::sort(access_cnt.rbegin(), access_cnt.rend());
    access_total = total;
}
// Turn the stored bucket access counts (access_cnt / access_total) into a CDF.
// @param axis_x: percentage marks, expected in ascending order
// @retval: cumulative access share at each mark
std::vector<double>
IVFStatistics::AccessCDF(const std::vector<size_t>& axis_x) {
    auto cdf = CaculateCDF(access_total, access_cnt, axis_x);
    return cdf;
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,377 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "faiss/utils/ConcurrentBitset.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/IndexType.h"
namespace milvus {
namespace knowhere {
extern int32_t STATISTICS_LEVEL;
// Round x up to the nearest power of two (a power of two maps to itself).
// For x == 0 the arithmetic wraps around and the result is 0.
inline uint64_t
upper_bound_of_pow2(uint64_t x) {
    uint64_t v = x - 1;
    // Smear the highest set bit into every lower position: shifts of 1, 2, 4, 8, 16, 32.
    for (unsigned shift = 1; shift <= 32; shift *= 2) {
        v |= v >> shift;
    }
    return v + 1;
}
// Exponent of a power of two: for x == 2^k, x - 1 has exactly k set bits, so the popcount is k.
// @param x: expected to be a power of two
// @retval: k such that x == 2^k
inline int
len_of_pow2(uint64_t x) {
    // Use the `long long` builtin: `unsigned long` may be only 32 bits wide and would then
    // truncate the 64-bit argument.
    return __builtin_popcountll(x - 1);
}
/*
* class: Statistics
*/
class Statistics {
 public:
    static const size_t NQ_Histogram_Slices = 13;
    static const size_t Filter_Histogram_Slices = 21;

    /*
     * @param idx_t: index type string. Only a reference is stored (see index_type), so the
     *               referenced string has to outlive this object.
     */
    explicit Statistics(std::string& idx_t)
        : index_type(idx_t),
          batch_cnt(0),
          nq_cnt(0),
          total_query_time(0.0),
          nq_stat(NQ_Histogram_Slices, 0),
          filter_stat(Filter_Histogram_Slices, 0),
          update_lock() {
    }

    /*
     * Get index type
     * @retval: index type in string
     */
    const std::string&
    IndexType() {
        return index_type;
    }

    /*
     * To string (may be for log output)
     * @retval: string output
     */
    virtual std::string
    ToString();

    virtual ~Statistics() = default;

    /*
     * Get batch count of the queries (Level 1)
     * @retval: query batch count
     */
    size_t
    BatchCount() {
        return batch_cnt;
    }

    /*
     * Get the statistics of the nq (Level 1)
     * @retval: count nq 1, 2, 3~4, 5~8, 9~16, ..., 1025~2048, larger than 2048 (13 slices)
     */
    const std::vector<size_t>&
    NQHistogram() {
        return nq_stat;
    }

    /*
     * Get query response per-second (Level 1)
     * @retval: Qps (0 when no query time has been recorded)
     */
    double
    Qps() {
        // ms -> s
        return total_query_time ? (nq_cnt * 1000.0 / total_query_time) : 0.0;
    }

    /*
     * Get the statistics of the filter for each batch (Level 2)
     * @retval: count 0~5%, 5~10%, 10~15%, ...95~100%, 100% (21 slices)
     */
    const std::vector<size_t>&
    FilterHistograms() {
        return filter_stat;
    }

    /*
     * Acquire the update mutex
     * @retval: a lock owning update_lock; the mutex is released when the lock is destroyed
     */
    std::unique_lock<std::mutex>
    Lock() {
        return std::unique_lock<std::mutex>(update_lock);
    }

 public:
    /*
     * Account for one query batch
     * @param: nq[in] number of queries in the batch; non-positive values are ignored
     */
    void
    update_nq(const int64_t nq) {
        if (nq <= 0) {
            // Nothing was queried. Without this guard the bucket computation below wraps
            // around and indexes nq_stat far out of range.
            return;
        }
        // batch
        batch_cnt++;
        // nq_cnt
        nq_cnt += static_cast<size_t>(nq);
        // nq_stat: bucket k holds nq in (2^(k-1), 2^k]; the last bucket holds everything above 2048
        if (nq > 2048) {
            nq_stat[NQ_Histogram_Slices - 1]++;
        } else {
            nq_stat[len_of_pow2(upper_bound_of_pow2(static_cast<size_t>(nq)))]++;
        }
    }

    /*
     * Add to the accumulated query time
     * @param: query_time[in] elapsed time in ms
     */
    void
    update_total_query_time(const double query_time) {
        total_query_time += query_time;
    }

    /*
     * Account for the share of set bits in the bitset of one batch
     * @param: bitset[in] an empty bitset counts as 0% filtered
     */
    void
    update_filter_percentage(const faiss::BitsetView& bitset) {
        double fps = !bitset.empty() ? static_cast<double>(bitset.count_1()) / bitset.size() : 0.0;
        // 5%-wide buckets; a ratio of exactly 1.0 lands in the extra last bucket
        filter_stat[static_cast<int>(fps * 100) / 5] += 1;
    }

    /*
     * Reset every counter and histogram to zero
     */
    virtual void
    clear() {
        total_query_time = 0.0;
        nq_cnt = 0;
        batch_cnt = 0;
        // assign() rewrites every slot; resize() would keep the old values because the
        // vectors already have this size.
        nq_stat.assign(NQ_Histogram_Slices, 0);
        filter_stat.assign(Filter_Histogram_Slices, 0);
    }

 public:
    std::string& index_type;          // reference to the string passed to the constructor
    size_t batch_cnt;                 // updated in query
    size_t nq_cnt;                    // updated in query
    double total_query_time;          // updated in query (unit: ms)
    std::vector<size_t> nq_stat;      // updated in query
    std::vector<size_t> filter_stat;  // updated in query
    std::mutex update_lock;
};
using StatisticsPtr = std::shared_ptr<Statistics>;
/*
* class: HNSWStatistics
*/
class HNSWStatistics : public Statistics {
 public:
    explicit HNSWStatistics(std::string& idx_t)
        : Statistics(idx_t), distribution(), target_level(1), access_total(0), ef_sum(0) {
    }

    ~HNSWStatistics() override = default;

    /*
     * To string (may be for log output)
     * @retval: string output
     */
    std::string
    ToString() override;

    /*
     * Get nodes count in each level
     * @retval: per-level counts as last stored by update_level_distribution
     */
    const std::vector<size_t>&
    LevelNodesNum() {
        return distribution;
    }

    /*
     * Get average search parameter ef (average over all queries) (Level 1)
     * @retval: avg Ef, 0 when no query has been recorded
     */
    double
    AvgSearchEf() {
        // Divide in floating point; integer division would truncate the average.
        return nq_cnt ? static_cast<double>(ef_sum) / static_cast<double>(nq_cnt) : 0.0;
    }

    /*
     * Cumulative distribution function of nodes access (Level 3)
     * @param: none (axis_x = {5,10,15,20,...100} by default)
     * @retval: Access CDF
     */
    virtual std::vector<double>
    AccessCDF() {
        std::vector<size_t> axis_x(20);
        for (size_t i = 0; i < 20; ++i) {
            axis_x[i] = (i + 1) * 5;
        }
        return AccessCDF(axis_x);
    }

    /*
     * Cumulative distribution function of nodes access
     * @param: axis_x[in] specified by users and should be in ascending order
     * @retval: Access CDF
     */
    virtual std::vector<double>
    AccessCDF(const std::vector<size_t>& axis_x) = 0;

 public:
    /*
     * Add to the accumulated ef
     * @param: ef[in] amount to add
     */
    void
    update_ef_sum(const int64_t ef) {
        ef_sum += ef;
    }

    /*
     * Store the per-level node counts and pick the target level
     * @param: max_level[in] highest level; levels[0..max_level] are read
     * @param: levels[in] node count per level
     * NOTE(review): levels is indexed up to max_level without a size check - confirm that
     * callers always pass at least max_level + 1 entries.
     */
    void
    update_level_distribution(const int max_level, const std::vector<int>& levels) {
        distribution.resize(max_level + 1);
        for (auto i = 0; i <= max_level; ++i) {
            distribution[i] = levels[i];
            // The highest level holding between 1000 and 9999 nodes becomes the target level;
            // target_level keeps its previous value when no level qualifies.
            if (distribution[i] >= 1000 && distribution[i] < 10000) {
                target_level = i;
            }
        }
    }

    /*
     * Reset the query counters; distribution and target_level are left untouched
     */
    void
    clear() override {
        Statistics::clear();
        access_total = 0;
        ef_sum = 0;
    }

 public:
    std::vector<size_t> distribution;
    size_t target_level;
    size_t access_total;  // depend on subclass type
    size_t ef_sum;        // updated in query
};
/*
 * class: LibHNSWStatistics
 * for index: HNSW
 *
 * Keeps the access counters in a hash map (int64_t key -> count).
 */
class LibHNSWStatistics : public HNSWStatistics {
 public:
    // The map needs no explicit initializer; it starts out empty.
    explicit LibHNSWStatistics(std::string& idx_t) : HNSWStatistics(idx_t) {
    }

    ~LibHNSWStatistics() override = default;

    /*
     * Cumulative distribution function of nodes access
     * @param: axis_x[in] specified by users and should be in ascending order
     * @retval: Access CDF
     */
    std::vector<double>
    AccessCDF(const std::vector<size_t>& axis_x) override;

    /*
     * Reset the inherited counters first, then drop every recorded access count.
     */
    void
    clear() override {
        HNSWStatistics::clear();
        access_cnt_map.clear();
    }

    // updated in query
    std::unordered_map<int64_t, size_t> access_cnt_map;
};
/*
 * class: RHNSWStatistics
 * for index: RHNSW_FLAT, RHNSW_SQ, RHNSW_PQ
 *
 * Keeps the access counters in a flat vector.
 */
class RHNSWStatistics : public HNSWStatistics {
 public:
    // The vector needs no explicit initializer; it starts out empty.
    explicit RHNSWStatistics(std::string& idx_t) : HNSWStatistics(idx_t) {
    }

    ~RHNSWStatistics() override = default;

    /*
     * Cumulative distribution function of nodes access
     * @param: axis_x[in] specified by users and should be in ascending order
     * @retval: Access CDF
     */
    std::vector<double>
    AccessCDF(const std::vector<size_t>& axis_x) override;

    // prepared in GetStatistics
    std::vector<size_t> access_cnt;
};
/*
 * class: IVFStatistics
 * for index: IVF_FLAT, IVF_PQ, IVF_SQ8
 */
class IVFStatistics : public Statistics {
 public:
    // access_total is initialised explicitly so that it never holds an
    // indeterminate value before the first call to clear().
    explicit IVFStatistics(std::string& idx_t)
        : Statistics(idx_t), nprobe_count(), access_cnt(), access_total(0), nlist(0) {
    }

    ~IVFStatistics() override = default;

    /*
     * To string (may be for log output)
     * @retval: string output
     */
    std::string
    ToString() override;

    /*
     * Get the stored nlist value
     * @retval: nlist
     */
    int64_t
    Nlist() {
        return nlist;
    }

    /*
     * Get the statistics of the search parameter nprobe (count of batches) (Level 1)
     * @retval: <nprobe, count>, a copy taken while holding the statistics lock
     */
    std::unordered_map<int64_t, size_t>
    SearchNprobe() {
        // The returned copy is built before `lock` goes out of scope.
        auto lock = Lock();
        return nprobe_count;
    }

    /*
     * Cumulative distribution function of bucket access (Level 3)
     * @param: axis_x[in] specified by users and should be in ascending order
     * @retval: Access CDF
     */
    std::vector<double>
    AccessCDF(const std::vector<size_t>& axis_x);

 public:
    void
    count_nprobe(const int64_t nprobe);

    void
    update_ivf_access_stats(const std::vector<size_t>& nprobe_statistics);

    /*
     * Reset the base-class counters, the nprobe histogram and access_total.
     * access_cnt and nlist are left untouched.
     */
    void
    clear() override {
        Statistics::clear();
        nprobe_count.clear();
        access_total = 0;
    }

 public:
    std::unordered_map<int64_t, size_t> nprobe_count;  // updated in query
    std::vector<size_t> access_cnt;                    // prepared in GetStatistics
    size_t access_total;                               // updated in query
    size_t nlist;
};
} // namespace knowhere
} // namespace milvus

View File

@ -11,6 +11,7 @@
#pragma once
#include <faiss/utils/BitsetView.h>
#include <faiss/utils/ConcurrentBitset.h>
#include <memory>
#include <utility>
@ -19,8 +20,10 @@
#include "knowhere/common/Dataset.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Typedef.h"
#include "knowhere/common/Utils.h"
#include "knowhere/index/Index.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/Statistics.h"
namespace milvus {
namespace knowhere {
@ -46,7 +49,7 @@ class VecIndex : public Index {
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
virtual DatasetPtr
Query(const DatasetPtr& dataset, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) = 0;
Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset) = 0;
#if 0
virtual DatasetPtr
@ -67,6 +70,15 @@ class VecIndex : public Index {
virtual int64_t
Count() = 0;
// Returns the statistics object held by this index. The member is
// initialised to nullptr, so the result may be a null pointer.
virtual StatisticsPtr
GetStatistics() {
return stats;
}
// Default implementation does nothing; index types that collect statistics
// override it to reset their counters.
virtual void
ClearStatistics() {
}
virtual IndexType
index_type() const {
return index_type_;
@ -84,39 +96,19 @@ class VecIndex : public Index {
}
#endif
faiss::ConcurrentBitsetPtr
GetBlacklist() {
return bitset_;
}
void
SetBlacklist(faiss::ConcurrentBitsetPtr bitset_ptr) {
bitset_ = std::move(bitset_ptr);
}
const std::vector<IDType>&
std::shared_ptr<std::vector<IDType>>
GetUids() const {
return uids_;
}
void
SetUids(std::vector<IDType>& uids) {
uids_.clear();
uids_.swap(uids);
}
size_t
BlacklistSize() {
if (bitset_) {
return bitset_->u8size() * sizeof(uint8_t);
} else {
return 0;
}
SetUids(std::shared_ptr<std::vector<IDType>> uids) {
uids_ = uids;
}
size_t
UidsSize() {
return uids_.size() * sizeof(IDType);
return uids_ ? uids_->size() * sizeof(IDType) : 0;
}
virtual int64_t
@ -138,17 +130,15 @@ class VecIndex : public Index {
int64_t
Size() override {
return BlacklistSize() + UidsSize() + IndexSize();
return UidsSize() + IndexSize();
}
protected:
IndexType index_type_ = "";
IndexMode index_mode_ = IndexMode::MODE_CPU;
std::vector<IDType> uids_;
std::shared_ptr<std::vector<IDType>> uids_ = nullptr;
int64_t index_size_ = -1;
private:
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
StatisticsPtr stats = nullptr;
};
using VecIndexPtr = std::shared_ptr<VecIndex>;

View File

@ -110,7 +110,7 @@ GPUIDMAP::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
ResScope rs(res_, gpu_id_);
// assign the metric type

View File

@ -55,8 +55,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
};
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;

View File

@ -143,7 +143,7 @@ GPUIVF::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);

View File

@ -51,8 +51,8 @@ class GPUIVF : public IVF, public GPUIndex {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
};
using GPUIVFPtr = std::shared_ptr<GPUIVF>;

View File

@ -247,7 +247,7 @@ IVFSQHybrid::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
if (gpu_mode_ == 2) {
GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset);
// index_->search(n, (float*)data, k, distances, labels);

View File

@ -88,8 +88,8 @@ class IVFSQHybrid : public GPUIVFSQ {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
protected:
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU

View File

@ -27,10 +27,7 @@ namespace cloner {
void
CopyIndexData(const VecIndexPtr& dst_index, const VecIndexPtr& src_index) {
/* do real copy */
auto uids = src_index->GetUids();
dst_index->SetUids(uids);
dst_index->SetBlacklist(src_index->GetBlacklist());
dst_index->SetUids(src_index->GetUids());
dst_index->SetIndexSize(src_index->IndexSize());
}

View File

@ -63,10 +63,5 @@ MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) {
return nitems;
}
void
enable_faiss_logging() {
faiss::LOG_DEBUG_ = &log_debug_;
}
} // namespace knowhere
} // namespace milvus

View File

@ -12,7 +12,6 @@
#pragma once
#include <faiss/impl/io.h>
#include <faiss/utils/utils.h>
namespace milvus {
namespace knowhere {
@ -47,8 +46,5 @@ struct MemoryIOReader : public faiss::IOReader {
}
};
void
enable_faiss_logging();
} // namespace knowhere
} // namespace milvus

View File

@ -54,6 +54,9 @@ constexpr const char* PQM = "PQM";
// NGT Params
constexpr const char* edge_size = "edge_size";
// NGT Search Params
constexpr const char* epsilon = "epsilon";
constexpr const char* max_search_edges = "max_search_edges";
// NGT_PANNG Params
constexpr const char* forcedly_pruned_edge_size = "forcedly_pruned_edge_size";
constexpr const char* selectively_pruned_edge_size = "selectively_pruned_edge_size";

View File

@ -864,7 +864,7 @@ NsgIndex::Search(const float* query,
float* dist,
int64_t* ids,
SearchParams& params,
faiss::ConcurrentBitsetPtr bitset) {
const faiss::BitsetView& bitset) {
std::vector<std::vector<Neighbor>> resset(nq);
TimeRecorder rc("NsgIndex::search", 1);
@ -886,7 +886,7 @@ NsgIndex::Search(const float* query,
if (pos >= k) {
break; // already top k
}
if (!bitset || !bitset->test(node.id)) {
if (!bitset || !bitset.test(node.id)) {
ids[i * k + pos] = ids_[node.id];
dist[i * k + pos] = is_ip ? -node.distance : node.distance;
++pos;

View File

@ -91,7 +91,7 @@ class NsgIndex {
float* dist,
int64_t* ids,
SearchParams& params,
faiss::ConcurrentBitsetPtr bitset = nullptr);
const faiss::BitsetView& bitset = nullptr);
int64_t
GetSize();

View File

@ -53,23 +53,32 @@ IVF_NM::Serialize(const Config& config) {
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
auto ret = SerializeImpl(index_type_);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
}
return ret;
}
void
IVF_NM::Load(const BinarySet& binary_set) {
Assemble(const_cast<BinarySet&>(binary_set));
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(binary_set, index_type_);
// Construct arranged data from original data
auto binary = binary_set.GetByName(RAW_DATA);
auto original_data = reinterpret_cast<const float*>(binary->data.get());
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
auto ivf_index = static_cast<faiss::IndexIVF*>(index_.get());
auto invlists = ivf_index->invlists;
auto d = ivf_index->d;
prefix_sum.resize(invlists->nlist);
size_t curr_index = 0;
if (STATISTICS_LEVEL >= 3) {
ivf_index->nprobe_statistics.resize(invlists->nlist, 0);
}
#ifndef MILVUS_GPU_VERSION
auto ails = dynamic_cast<faiss::ArrayInvertedLists*>(invlists);
size_t nb = binary->size / invlists->code_size;
@ -102,17 +111,22 @@ IVF_NM::Load(const BinarySet& binary_set) {
ro_codes = rol->pin_readonly_codes;
data_ = nullptr;
#endif
// LOG_KNOWHERE_DEBUG_ << "IndexIVF_FLAT::Load finished, show statistics:";
// auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
// LOG_KNOWHERE_DEBUG_ << ivf_stats->ToString();
}
void
IVF_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr)
int64_t nlist = config[IndexParams::nlist].get<int64_t>();
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
auto nlist = config[IndexParams::nlist].get<int64_t>();
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
auto coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type);
index->own_fields = true;
index->train(rows, reinterpret_cast<const float*>(p_data));
index_ = index;
}
void
@ -138,7 +152,7 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -311,7 +325,7 @@ IVF_NM::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
@ -328,16 +342,31 @@ IVF_NM::QueryImpl(int64_t n,
#else
auto data = static_cast<const uint8_t*>(ro_codes->data);
#endif
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
ivf_index->search_without_codes(n, reinterpret_cast<const float*>(query), data, prefix_sum, is_sq8, k, distances,
labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
<< ", quantization cost: " << ivf_index->index_ivf_stats.quantization_time
<< ", data search cost: " << ivf_index->index_ivf_stats.search_time;
if (STATISTICS_LEVEL) {
auto lock = ivf_stats->Lock();
if (STATISTICS_LEVEL >= 1) {
ivf_stats->update_nq(n);
ivf_stats->count_nprobe(ivf_index->nprobe);
ivf_stats->update_total_query_time(ivf_index->index_ivf_stats.quantization_time +
ivf_index->index_ivf_stats.search_time);
ivf_index->index_ivf_stats.quantization_time = 0;
ivf_index->index_ivf_stats.search_time = 0;
}
if (STATISTICS_LEVEL >= 2) {
ivf_stats->update_filter_percentage(bitset);
}
}
// LOG_KNOWHERE_DEBUG_ << "IndexIVF_FLAT::QueryImpl finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
}
void
@ -380,5 +409,30 @@ IVF_NM::UpdateIndexSize() {
index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size;
}
// Returns the statistics object of this index. When statistics are enabled
// (STATISTICS_LEVEL != 0) the access statistics are first refreshed from the
// nprobe_statistics of the underlying faiss IVF index, under the statistics lock.
StatisticsPtr
IVF_NM::GetStatistics() {
    if (!STATISTICS_LEVEL) {
        return stats;
    }
    auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
    auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
    if (!ivf_stats || !ivf_index) {
        // Either cast can yield null (e.g. index_ is still unset before
        // Train/Load); hand back the stored pointer instead of dereferencing null.
        return stats;
    }
    auto lock = ivf_stats->Lock();
    ivf_stats->update_ivf_access_stats(ivf_index->nprobe_statistics);
    return ivf_stats;
}
void
IVF_NM::ClearStatistics() {
if (!STATISTICS_LEVEL) {
return;
}
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->clear_nprobe_statistics();
ivf_index->index_ivf_stats.reset();
auto lock = ivf_stats->Lock();
ivf_stats->clear();
}
} // namespace knowhere
} // namespace milvus

View File

@ -29,10 +29,12 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
public:
IVF_NM() : OffsetBaseIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
explicit IVF_NM(std::shared_ptr<faiss::Index> index) : OffsetBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
}
BinarySet
@ -51,7 +53,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
#if 0
DatasetPtr
@ -67,6 +69,12 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
void
UpdateIndexSize() override;
StatisticsPtr
GetStatistics() override;
void
ClearStatistics() override;
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
@ -86,8 +94,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
GenParams(const Config&);
virtual void
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset);
void
SealImpl() override;

View File

@ -47,6 +47,9 @@ NSG_NM::Serialize(const Config& config) {
BinarySet res_set;
res_set.Append("NSG_NM", data, writer.rp);
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
}
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
@ -56,6 +59,7 @@ NSG_NM::Serialize(const Config& config) {
void
NSG_NM::Load(const BinarySet& index_binary) {
try {
Assemble(const_cast<BinarySet&>(index_binary));
fiu_do_on("NSG_NM.Load.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
auto binary = index_binary.GetByName("NSG_NM");
@ -74,7 +78,7 @@ NSG_NM::Load(const BinarySet& index_binary) {
}
DatasetPtr
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

View File

@ -59,7 +59,7 @@ class NSG_NM : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
int64_t
Count() override;

View File

@ -124,7 +124,7 @@ GPUIVF_NM::QueryImpl(int64_t n,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
const faiss::BitsetView& bitset) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);

View File

@ -51,8 +51,8 @@ class GPUIVF_NM : public IVF, public GPUIndex {
SerializeImpl(const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
protected:
uint8_t* arranged_data;

View File

@ -17,6 +17,7 @@
#pragma once
#include "NGT/Index.h"
#include "defines.h"
using namespace std;
@ -185,8 +186,10 @@ class Clustering {
}
}
if ((numberOfClusters != 0) && (clusters.size() < numberOfClusters)) {
std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
<< std::endl;
// std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
// << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("initial cluster data are not enough. " + std::to_string(clusters.size()) + ":" + std::to_string(numberOfClusters));
exit(1);
}
}
@ -365,7 +368,9 @@ class Clustering {
for (auto soi = sortedObjects.rbegin(); soi != sortedObjects.rend();) {
Entry& entry = *soi;
if (entry.centroidID >= clusters.size()) {
std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
// std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Something wrong. " + std::to_string(entry.centroidID) + ":" + std::to_string(clusters.size()));
soi++;
continue;
}
@ -547,7 +552,9 @@ class Clustering {
distance += distanceL2((*it).centroid, mean);
(*it).centroid = mean;
} else {
cerr << "Clustering: Fatal Error. No member!" << endl;
// cerr << "Clustering: Fatal Error. No member!" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Clustering: Fatal Error. No member!");
abort();
}
}
@ -579,7 +586,9 @@ class Clustering {
double diff = 0;
for (size_t i = 0; i < maximumIteration; i++) {
std::cerr << "iteration=" << i << std::endl;
// std::cerr << "iteration=" << i << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("iteration=" + std::to_string(i));
assign(vectors, clusters, clusterSize);
// centroid is recomputed.
// diff is distance between the current centroids and the previous centroids.
@ -610,7 +619,9 @@ class Clustering {
std::vector<Cluster> prevClusters = clusters;
diff = calculateCentroid(vectors, clusters);
timer.stop();
std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
// std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("iteration=" + std::to_string(i) + " time=" + std::to_string(timer.time)+ " diff=" + std::to_string(diff));
timer.start();
diffHistory.push_back(diff);
@ -664,15 +675,21 @@ class Clustering {
try {
os.getObject(idx, vectors[idx - 1]);
} catch (...) {
cerr << "Cannot get object " << idx << endl;
// cerr << "Cannot get object " << idx << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Cannot get object " + std::to_string(idx));
}
}
cerr << "# of data for clustering=" << vectors.size() << endl;
// cerr << "# of data for clustering=" << vectors.size() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of data for clustering=" + std::to_string(vectors.size()));
double diff = DBL_MAX;
clusters.clear();
setupInitialClusters(vectors, numberOfClusters, clusters);
for (float epsilon = epsilonFrom; epsilon <= epsilonTo; epsilon += epsilonStep) {
cerr << "epsilon=" << epsilon << endl;
// cerr << "epsilon=" << epsilon << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("epsilon=" + std::to_string(epsilon));
diff = kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilon);
if (diff == 0.0) {
return diff;
@ -748,7 +765,9 @@ class Clustering {
}
}
if (vectors.size() != count) {
std::cerr << "Warning! vectors.size() != count" << std::endl;
// std::cerr << "Warning! vectors.size() != count" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning! vectors.size() != count");
}
return d / (double)vectors.size();
@ -787,7 +806,9 @@ class Clustering {
break;
}
default:
std::cerr << "proper initMode is not specified." << std::endl;
// std::cerr << "proper initMode is not specified." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("proper initMode is not specified.");
exit(1);
}
}
@ -805,7 +826,9 @@ class Clustering {
return kmeansWithNGT(vectors, numberOfClusters, clusters);
break;
default:
cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
// cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("kmeans::fatal error!. invalid clustering type. " + std::to_string(clusteringType));
abort();
break;
}
@ -817,16 +840,16 @@ class Clustering {
size_t clusterSize = std::numeric_limits<size_t>::max();
assign(vectors, clusters, clusterSize);
std::cout << "The number of vectors=" << vectors.size() << std::endl;
std::cout << "The number of centroids=" << clusters.size() << std::endl;
// std::cout << "The number of vectors=" << vectors.size() << std::endl;
// std::cout << "The number of centroids=" << clusters.size() << std::endl;
if (centroidIds.size() == 0) {
switch (mode) {
case 'e':
std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
// std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
break;
case '2':
default:
std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
// std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
break;
}
} else {
@ -835,8 +858,8 @@ class Clustering {
break;
case '2':
default:
std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
<< std::endl;
// std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
// << std::endl;
break;
}
}

View File

@ -0,0 +1,24 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/lib/NGT/Common.h"
#include "NGT/lib/NGT/ObjectSpace.h"
// Size estimate for a SearchContainer: fixed-size terms, the result
// container, the working result queue and the base Container.
int64_t
NGT::SearchContainer::memSize() {
    // Working queue: element count times the size reported by its top element.
    // top() is only consulted when the queue is non-empty.
    int64_t working_bytes = 0;
    if (workingResult.size() != 0) {
        working_bytes = workingResult.size() * workingResult.top().memSize();
    }
    return sizeof(size_t) * 3 + sizeof(float) * 3 + result->memSize() + 1 + working_bytes + Container::memSize();
}

View File

@ -120,7 +120,9 @@ namespace NGT {
}
auto status = insert(std::make_pair(key,value));
if (!status.second) {
std::cerr << "Args: Duplicated options. [" << opt << "]" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Args: Duplicated options. [" + opt + "]");
// std::cerr << "Args: Duplicated options. [" << opt << "]" << std::endl;
}
}
}
@ -305,7 +307,9 @@ namespace NGT {
logFD = open(logFilePath.c_str(), O_CREAT|O_WRONLY|O_APPEND, mode);
}
if (logFD < 0) {
std::cerr << "Logger: Cannot begin logging." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Logger: Cannot begin logging.");
// std::cerr << "Logger: Cannot begin logging." << std::endl;
logFD = -1;
return;
}
@ -323,6 +327,8 @@ namespace NGT {
savedFdNo = -1;
}
int64_t memSize() { return sizeof(*this); }
std::string logFilePath;
mode_t mode;
int logFD;
@ -479,7 +485,9 @@ namespace NGT {
uint64_t size = vectorSize;
size <<= 1;
if (size > 0xffff) {
std::cerr << "CompactVector is too big. " << size << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("CompactVector is too big. " + std::to_string(size));
// std::cerr << "CompactVector is too big. " << size << std::endl;
abort();
}
reserve(size);
@ -487,6 +495,8 @@ namespace NGT {
}
}
virtual int64_t memSize() { return vector->memSize() * vectorSize + sizeof(vectorSize) * 2; }
TYPE *vector;
uint16_t vectorSize;
uint16_t allocatedSize;
@ -540,6 +550,8 @@ namespace NGT {
}
}
virtual int64_t memSize() { return size(); }
char *vector;
};
@ -563,6 +575,7 @@ namespace NGT {
inline void reset(size_t i) {
getEntry(i) &= ~getBitString(i);
}
inline int64_t memSize() { return size * sizeof(uint64_t); }
std::vector<uint64_t> bitvec;
uint64_t size;
};
@ -602,7 +615,9 @@ namespace NGT {
char *e = 0;
float val = strtof(it->second.c_str(), &e);
if (*e != 0) {
std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning: Illegal property. " + key + ":" + it->second + " (" + e + ")");
// std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
return defvalue;
}
return val;
@ -620,7 +635,9 @@ namespace NGT {
char *e = 0;
float val = strtol(it->second.c_str(), &e, 10);
if (*e != 0) {
std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning: Illegal property. " + key + ":" + it->second + " (" + e + ")");
// std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
}
return val;
}
@ -668,7 +685,9 @@ namespace NGT {
NGT::Common::tokenize(line, tokens, "\t");
if (tokens.size() != 2)
{
std::cerr << "Property file is illegal. " << line << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Property file is illegal. " + line);
// std::cerr << "Property file is illegal. " << line << std::endl;
continue;
}
set(tokens[0], tokens[1]);
@ -681,7 +700,9 @@ namespace NGT {
std::vector<std::string> tokens;
NGT::Common::tokenize(line, tokens, "\t");
if (tokens.size() != 2) {
std::cerr << "Property file is illegal. " << line << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Property file is illegal. " + line);
// std::cerr << "Property file is illegal. " << line << std::endl;
continue;
}
set(tokens[0], tokens[1]);
@ -719,7 +740,9 @@ namespace NGT {
unsigned int tmp;
is >> tmp;
if (tmp > 255) {
std::cerr << "Error! Invalid. " << tmp << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Error! Invalid. " + std::to_string(tmp));
// std::cerr << "Error! Invalid. " << tmp << std::endl;
}
v = (TYPE)tmp;
} else {
@ -819,7 +842,9 @@ namespace NGT {
unsigned int size;
is >> size;
if (s != size) {
std::cerr << "readAsText: something wrong. " << size << ":" << s << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("readAsText: something wrong. " + std::to_string(size) + ":" + std::to_string(s));
// std::cerr << "readAsText: something wrong. " << size << ":" << s << std::endl;
return;
}
for (unsigned int i = 0; i < s; i++) {
@ -1010,7 +1035,9 @@ namespace NGT {
size <<= 1;
} while (size <= idx);
if (size > 0xffffffff) {
std::cerr << "Vector is too big. " << size << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Vector is too big. " + std::to_string(size));
// std::cerr << "Vector is too big. " << size << std::endl;
abort();
}
reserve(size, allocator);
@ -1092,7 +1119,9 @@ namespace NGT {
Vector<size_t>::iterator rmi
= std::lower_bound(removedList->begin(allocator), removedList->end(allocator), id, std::greater<size_t>());
if ((rmi != removedList->end(allocator)) && ((*rmi) == id)) {
std::cerr << "removedListPush: already existed! continue... ID=" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("removedListPush: already existed! continue... ID=" + std::to_string(id));
// std::cerr << "removedListPush: already existed! continue... ID=" << id << std::endl;
return;
}
removedList->insert(rmi, id, allocator);
@ -1271,7 +1300,9 @@ namespace NGT {
size_t idx;
NGT::Serializer::readAsText(is, idx);
if (i != idx) {
std::cerr << "PersistentRepository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("PersistentRepository: Error. index of a specified import file is invalid. " + std::to_string(idx) + ":" + std::to_string(i));
// std::cerr << "PersistentRepository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
}
char type;
NGT::Serializer::readAsText(is, type);
@ -1599,7 +1630,9 @@ namespace NGT {
size_t idx;
NGT::Serializer::readAsText(is, idx);
if (i != idx) {
std::cerr << "Repository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Repository: Error. index of a specified import file is invalid. " + std::to_string(idx) + ":" + std::to_string(i));
// std::cerr << "Repository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
}
char type;
NGT::Serializer::readAsText(is, type);
@ -1655,6 +1688,7 @@ namespace NGT {
#ifdef ADVANCED_USE_REMOVED_LIST
size_t count() { return std::vector<TYPE*>::size() == 0 ? 0 : std::vector<TYPE*>::size() - removedList.size() - 1; }
virtual int64_t memSize() { return std::vector<TYPE*>::size() == 0 ? 0 : (*this)[1]->memSize() * std::vector<TYPE*>::size() + removedList.size() * sizeof(size_t); }
protected:
std::priority_queue<size_t, std::vector<size_t>, std::greater<size_t> > removedList;
#endif
@ -1724,6 +1758,7 @@ namespace NGT {
is >> o.distance;
return is;
}
int64_t memSize() const { return sizeof(id) + sizeof(distance); }
uint32_t id;
float distance;
};
@ -1737,6 +1772,7 @@ namespace NGT {
public:
Container(Object &o, ObjectID i):object(o), id(i) {}
Container(Container &c):object(c.object), id(c.id) {}
virtual int64_t memSize() { return sizeof(ObjectID); }
Object &object;
ObjectID id;
};
@ -1791,6 +1827,11 @@ namespace NGT {
ResultPriorityQueue &getWorkingResult() { return workingResult; }
virtual int64_t memSize();
// virtual int64_t memSize() {
// auto workres_size = workingResult.size() == 0 ? 0 : workingResult.size() * workingResult.top().memSize();
// return sizeof(size_t) * 3 + sizeof(float) * 3 + result->memSize() + 1 + workres_size + Container::memSize();
// }
size_t size;
Distance radius;
@ -1828,6 +1869,7 @@ namespace NGT {
}
void *getQuery() { return query; }
const std::type_info &getQueryType() { return *queryType; }
virtual int64_t memSize() { return std::strlen((char*)getQuery()) + sizeof(getQueryType()) + SearchContainer::memSize(); }
private:
void deleteQuery() {
if (query == 0) {

View File

@ -599,7 +599,7 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
}
// for milvus
void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset)
void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView& bitset)
{
if (sc.explorationCoefficient == 0.0)
{
@ -710,7 +710,7 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
distanceChecked.insert(neighbor.id);
// judge if id in blacklist
if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)neighbor.id - 1)) {
if (!bitset.empty() && bitset.test((faiss::ConcurrentBitset::id_type_t)neighbor.id - 1)) {
continue;
}

View File

@ -24,6 +24,7 @@
#include "NGT/ObjectSpaceRepository.h"
#include "faiss/utils/ConcurrentBitset.h"
#include "faiss/utils/BitsetView.h"
#include "NGT/HashBasedBooleanSet.h"
@ -187,6 +188,14 @@ namespace NGT {
}
}
// Total estimated bytes: one unsigned short per entry of *prevsize, plus the
// memSize() of every element from index 1 upwards (index 0 is not visited).
virtual int64_t memSize() {
  int64_t total = prevsize->size() * sizeof(unsigned short);
  size_t slot = 1;
  while (slot < this->size()) {
    total += (*this)[slot]->memSize();
    ++slot;
  }
  return total;
}
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Vector<unsigned short> *prevsize;
@ -211,6 +220,7 @@ namespace NGT {
usedSize++;
}
size_t size() { return usedSize; }
// Estimated bytes for the reserved entries: reservedSize entries, each costed
// as one uint64_t plus the memSize() of the object referenced by the first
// entry's second member.
virtual int64_t memSize() {
  // With nothing reserved the product below is zero whatever the per-entry
  // cost is, so return before dereferencing element 0 and its pointer, which
  // may not exist in that state.
  if (reservedSize == 0) {
    return 0;
  }
  // NOTE(review): assumes entry 0's pointer is non-null once space has been
  // reserved -- confirm against the code that fills the entries.
  const auto perEntryBytes = sizeof(uint64_t) + (*this)[0].second->memSize();
  return reservedSize * perEntryBytes;
}
size_t reservedSize;
size_t usedSize;
};
@ -219,6 +229,13 @@ namespace NGT {
public:
SearchGraphRepository() {}
bool isEmpty(size_t idx) { return (*this)[idx].empty(); }
// Sum of memSize() over every element from index 1 upwards; index 0 is left
// out of the total.
virtual int64_t memSize() {
  int64_t total = 0;
  size_t slot = 1;
  while (slot < this->size()) {
    total += (*this)[slot].memSize();
    ++slot;
  }
  return total;
}
void deserialize(std::ifstream &is, ObjectRepository &objectRepository) {
if (!is.is_open()) {
@ -496,6 +513,8 @@ namespace NGT {
return os;
}
// Size of this object as laid out by the compiler (sizeof of the whole
// object, so padding is included).
int64_t memSize() {
  const int64_t selfBytes = sizeof(*this);
  return selfBytes;
}
int16_t truncationThreshold;
int16_t edgeSizeForCreation;
int16_t edgeSizeForSearch;
@ -679,7 +698,7 @@ namespace NGT {
void search(NGT::SearchContainer &sc, ObjectDistances &seeds);
// for milvus
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset);
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView&bitset);
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
template <typename COMPARATOR, typename CHECK_LIST> void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds);
@ -933,6 +952,7 @@ namespace NGT {
public:
// Aggregate estimate: adds up what repository, searchRepository, property and
// the object pointed to by objectSpace each report from their own memSize().
virtual int64_t memSize() {
  int64_t total = repository.memSize();
  total += searchRepository.memSize();
  total += property.memSize();
  total += objectSpace->memSize();
  return total;
}
GraphRepository repository;
ObjectSpace *objectSpace;

View File

@ -120,8 +120,10 @@ namespace NGT {
}
}
if (nQueries > ids.size()) {
std::cerr << "# of Queries is not enough." << std::endl;
return DBL_MAX;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of Queries is not enough.");
// std::cerr << "# of Queries is not enough." << std::endl;
return DBL_MAX;
}
NGT::Timer timer;
@ -233,14 +235,18 @@ namespace NGT {
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: adjusting outgoing and incoming edges...");
// std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
}
NGT::Timer timer;
timer.start();
std::vector<NGT::ObjectDistances> graph;
try
{
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Extract the graph data.");
// std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
// extract only edges from the index to reduce the memory usage.
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
NeighborhoodGraph::Property & prop = graphIndex.getGraphProperty();
@ -250,7 +256,9 @@ namespace NGT {
}
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
timer.stop();
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Graph reconstruction time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
}
catch (NGT::Exception & err)
@ -263,7 +271,9 @@ namespace NGT {
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: redusing shortcut edges...");
// std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
}
try
{
@ -271,7 +281,9 @@ namespace NGT {
timer.start();
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
timer.stop();
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Path adjustment time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
}
catch (NGT::Exception & err)
{
@ -286,7 +298,9 @@ namespace NGT {
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing search parameters...");
// std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
}
NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(outIndex.getIndex());
NGT::Optimizer optimizer(outIndex);
@ -321,7 +335,9 @@ namespace NGT {
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing prefetch parameters...");
// std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
}
try
{
@ -343,7 +359,9 @@ namespace NGT {
{
if (!logDisabled)
{
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: generating the accuracy table...");
// std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
}
try
{
@ -387,14 +405,18 @@ namespace NGT {
NGT::GraphIndex graphIndex(outIndexPath, false);
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: adjusting outgoing and incoming edges...");
// std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
}
redirector.begin();
NGT::Timer timer;
timer.start();
std::vector<NGT::ObjectDistances> graph;
try {
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Extract the graph data.");
// std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
// extract only edges from the index to reduce the memory usage.
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
@ -403,7 +425,9 @@ namespace NGT {
}
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
timer.stop();
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Graph reconstruction time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
graphIndex.saveGraph(outIndexPath);
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
graphIndex.saveProperty(outIndexPath);
@ -415,14 +439,18 @@ namespace NGT {
if (shortcutReduction) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: redusing shortcut edges...");
// std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
}
try {
NGT::Timer timer;
timer.start();
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
timer.stop();
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("optimizer::execute: path adjustment time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "optimizer::execute: path adjustment time=" << timer.time << " (sec) " << std::endl;
graphIndex.saveGraph(outIndexPath);
} catch (NGT::Exception &err) {
redirector.end();
@ -441,7 +469,9 @@ namespace NGT {
if (searchParameterOptimization) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing search parameters...");
// std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
}
NGT::Index outIndex(outIndexPath);
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
@ -472,7 +502,9 @@ namespace NGT {
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
if (prefetchParameterOptimization) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing prefetch parameters...");
// std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
}
try {
auto prefetch = adjustPrefetchParameters(outIndex);
@ -491,7 +523,9 @@ namespace NGT {
}
if (accuracyTableGeneration) {
if (!logDisabled) {
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: generating the accuracy table...");
// std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
}
try {
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);

View File

@ -19,6 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include <list>
#include "defines.h"
#ifdef _OPENMP
#include <omp.h>
@ -34,7 +35,9 @@ class GraphReconstructor {
graph.reserve(graphIndex.repository.size());
for (size_t id = 1; id < graphIndex.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::extractGraph: Processed " + std::to_string(id) + " objects.");
// std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
}
try {
NGT::GraphNode &node = *graphIndex.getNode(id);
@ -49,7 +52,9 @@ class GraphReconstructor {
graph.push_back(node);
#endif
if (graph.back().size() != graph.back().capacity()) {
std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " + std::to_string(id));
// std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
}
} catch(NGT::Exception &err) {
graph.push_back(NGT::ObjectDistances());
@ -66,7 +71,9 @@ class GraphReconstructor {
adjustPaths(NGT::Index &outIndex)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "construct index is not implemented." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("construct index is not implemented.");
// std::cerr << "construct index is not implemented." << std::endl;
exit(1);
#else
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
@ -93,8 +100,12 @@ class GraphReconstructor {
}
edge = true;
if (rank >= 1 && node[rank - 1].distance > node[rank].distance) {
std::cerr << "distance order is wrong!" << std::endl;
std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
// std::cerr << "distance order is wrong!" << std::endl;
// std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("distance order is wrong!");
(*NGT_LOG_DEBUG_)(std::to_string(id) + ":" + std::to_string(rank) + ":" + std::to_string(node[rank - 1].id) + ":" + std::to_string(node[rank].id));
}
}
NGT::GraphNode &tn = *outGraph.getNode(id);
volatile bool found = false;
@ -141,7 +152,9 @@ class GraphReconstructor {
removeCount++;
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
it++;
continue;
}
@ -210,7 +223,9 @@ class GraphReconstructor {
node.clear();
#endif
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
tmpGraph.push_back(NGT::GraphNode(outGraph.repository.allocator));
#else
@ -224,7 +239,9 @@ class GraphReconstructor {
NGTThrowException(msg);
}
timer.stop();
std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
// std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::adjustPaths: graph preparing time=" + std::to_string(timer.time));
timer.reset();
timer.start();
@ -287,12 +304,16 @@ class GraphReconstructor {
removeCandidates[id - 1].push_back(candidates[i].second);
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
timer.stop();
std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
// std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::adjustPaths: extracting removed edge candidates time=" + std::to_string(timer.time));
timer.reset();
timer.start();
@ -311,7 +332,9 @@ class GraphReconstructor {
NGT::GraphNode &srcNode = tmpGraph[idx];
if (rank >= srcNode.size()) {
if (!removeCandidates[idx].empty()) {
std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
// std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Something wrong! ID=" + std::to_string(id) + " # of remaining candidates=" + std::to_string(removeCandidates[idx].size()));
abort();
}
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
@ -365,7 +388,9 @@ class GraphReconstructor {
insert(outSrcNode, srcNode[rank].id, srcNode[rank].distance);
#endif
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
it++;
continue;
}
@ -389,10 +414,14 @@ class GraphReconstructor {
void convertToANNG(std::vector<NGT::ObjectDistances> &graph)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("convertToANNG is not implemented for shared memory.");
// std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
return;
#else
std::cerr << "convertToANNG begin" << std::endl;
// std::cerr << "convertToANNG begin" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("convertToANNG begin");
for (size_t idx = 0; idx < graph.size(); idx++) {
NGT::GraphNode &node = graph[idx];
for (auto ni = node.begin(); ni != node.end(); ++ni) {
@ -417,7 +446,9 @@ class GraphReconstructor {
NGT::GraphNode tmp = node;
node.swap(tmp);
}
std::cerr << "convertToANNG end" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("convertToANNG end");
// std::cerr << "convertToANNG end" << std::endl;
#endif
}
@ -425,7 +456,9 @@ class GraphReconstructor {
void reconstructGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph, size_t originalEdgeSize, size_t reverseEdgeSize)
{
if (reverseEdgeSize > 10000) {
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
// std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("something wrong. Edge size=" + std::to_string(reverseEdgeSize));
exit(1);
}
@ -445,7 +478,9 @@ class GraphReconstructor {
} else {
NGT::ObjectDistances n = graph[id - 1];
if (n.size() < originalEdgeSize) {
std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
// std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. The edges are too few. " + std::to_string(n.size()) + ":" + std::to_string(originalEdgeSize) + " for " + std::to_string(id));
continue;
}
n.resize(originalEdgeSize);
@ -456,7 +491,9 @@ class GraphReconstructor {
#endif
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
@ -485,13 +522,17 @@ class GraphReconstructor {
} catch(...) {}
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
reverseEdgeTimer.stop();
if (insufficientNodeCount != 0) {
std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
// std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of the nodes edges of which are in short = " + std::to_string(insufficientNodeCount));
}
normalizeEdgeTimer.start();
@ -499,7 +540,9 @@ class GraphReconstructor {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
if (id % 100000 == 0) {
std::cerr << "Processed " << id << " nodes" << std::endl;
// std::cerr << "Processed " << id << " nodes" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id) + " nodes");
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(n.begin(outGraph.repository.allocator), n.end(outGraph.repository.allocator));
@ -528,13 +571,18 @@ class GraphReconstructor {
n.swap(tmp);
#endif
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
normalizeEdgeTimer.stop();
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
<< ":" << normalizeEdgeTimer.time << std::endl;
// std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
// << ":" << normalizeEdgeTimer.time << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Reconstruction time=" + std::to_string(originalEdgeTimer.time) + ":" + std::to_string(reverseEdgeTimer.time)
+ ":" + std::to_string(normalizeEdgeTimer.time));
NGT::Property prop;
outGraph.getProperty().get(prop);
@ -550,20 +598,26 @@ class GraphReconstructor {
char mode = 'a')
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("reconstructGraphWithConstraint is not implemented.");
// std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
abort();
#else
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
if (reverseEdgeSize > 10000) {
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
// std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("something wrong. Edge size=" + std::to_string(reverseEdgeSize));
exit(1);
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "Processed " << id << std::endl;
// std::cerr << "Processed " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
}
try {
NGT::GraphNode &node = *outGraph.getNode(id);
@ -574,7 +628,9 @@ class GraphReconstructor {
NGT::GraphNode empty;
node.swap(empty);
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
@ -585,13 +641,17 @@ class GraphReconstructor {
try {
NGT::GraphNode &node = graph[id - 1];
if (id % 100000 == 0) {
std::cerr << "Processed (summing up) " << id << std::endl;
// std::cerr << "Processed (summing up) " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed (summing up) " + std::to_string(id));
}
for (size_t rank = 0; rank < node.size(); rank++) {
reverse[node[rank].id].push_back(ObjectDistance(id, node[rank].distance));
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
@ -628,7 +688,9 @@ class GraphReconstructor {
}
}
reverseEdgeTimer.stop();
std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
// std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The number of nodes with zero outdegree by reverse edges=" + std::to_string(zeroCount));
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
normalizeEdgeTimer.start();
@ -636,7 +698,9 @@ class GraphReconstructor {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
if (id % 100000 == 0) {
std::cerr << "Processed " << id << std::endl;
// std::cerr << "Processed " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
}
std::sort(n.begin(), n.end());
NGT::ObjectID prev = 0;
@ -651,7 +715,9 @@ class GraphReconstructor {
NGT::GraphNode tmp = n;
n.swap(tmp);
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
@ -661,7 +727,9 @@ class GraphReconstructor {
originalEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "Processed " << id << std::endl;
// std::cerr << "Processed " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
}
NGT::GraphNode &node = graph[id - 1];
try {
@ -683,15 +751,20 @@ class GraphReconstructor {
outGraph.addEdge(id, nodeID, distance, false);
}
} catch(NGT::Exception &err) {
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
originalEdgeTimer.stop();
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
<< ":" << normalizeEdgeTimer.time << std::endl;
// std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
// << ":" << normalizeEdgeTimer.time << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Reconstruction time=" + std::to_string(originalEdgeTimer.time) + ":" + std::to_string(reverseEdgeTimer.time)
+ ":" + std::to_string(normalizeEdgeTimer.time));
#endif
}
@ -703,7 +776,9 @@ class GraphReconstructor {
void reconstructANNGFromANNG(std::vector<NGT::ObjectDistances> &graph, NGT::Index &index, size_t edgeSize)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("reconstructANNGFromANNG is not implemented.");
// std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
abort();
#else
@ -712,7 +787,9 @@ class GraphReconstructor {
// remove all edges in the index.
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
std::cerr << "Processed " << id << " nodes." << std::endl;
// std::cerr << "Processed " << id << " nodes." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id) + " nodes.");
}
try {
NGT::GraphNode &node = *outGraph.getNode(id);
@ -810,7 +887,9 @@ class GraphReconstructor {
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (id % 100000 == 0) {
std::cerr << "# of processed objects=" << id << std::endl;
// std::cerr << "# of processed objects=" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of processed objects=" + std::to_string(id));
}
if (objectRepository.isEmpty(id)) {
continue;
@ -876,7 +955,9 @@ class GraphReconstructor {
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (id % 10000 == 0) {
std::cerr << "# of processed objects=" << id << std::endl;
// std::cerr << "# of processed objects=" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of processed objects=" + std::to_string(id));
}
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
if ((*i).id != id) {

View File

@ -16,6 +16,7 @@
#pragma once
#include "defines.h"
#include <iostream>
#include <cstring>
#include <stdint.h>
@ -53,7 +54,9 @@ class HashBasedBooleanSet{
_mask = _tableSize - 1;
const uint32_t checkValue = _hash1(tableSize);
if(checkValue != 0){
std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("[WARN] table size is not 2^N : " + std::to_string(tableSize));
// std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
}
_table = new uint32_t[tableSize];

View File

@ -266,7 +266,9 @@ void NGT::Index::loadRawDataAndCreateIndex(NGT::Index * index_, const float * ro
timer.start();
index_->createIndex(threadSize);
timer.stop();
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
}
void
@ -280,17 +282,23 @@ NGT::Index::loadAndCreateIndex(Index &index, const string &database, const strin
return;
}
timer.stop();
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (index.getObjectRepositorySize() == 0) {
NGTThrowException("Index::create: Data file is empty.");
}
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
timer.reset();
timer.start();
index.createIndex(threadSize);
timer.stop();
index.saveIndex(database);
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
}
// For milvus
@ -307,13 +315,21 @@ void NGT::Index::append(NGT::Index * index_, const float * data, size_t dataSize
NGTThrowException("Index::append: No data.");
}
timer.stop();
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
cerr << "# of objects=" << index_->getObjectRepositorySize() - 1 << endl;
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "# of objects=" << index_->getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)(
"Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0)
+ " (msec)");
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index_->getObjectRepositorySize() - 1));
}
timer.reset();
timer.start();
index_->createIndex(threadSize);
timer.stop();
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
return;
}
@ -326,14 +342,20 @@ NGT::Index::append(const string &database, const string &dataFile, size_t thread
index.append(dataFile, dataSize);
}
timer.stop();
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
}
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
timer.reset();
timer.start();
index.createIndex(threadSize);
timer.stop();
index.saveIndex(database);
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
return;
}
@ -348,14 +370,20 @@ NGT::Index::append(const string &database, const float *data, size_t dataSize, s
NGTThrowException("Index::append: No data.");
}
timer.stop();
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
}
timer.reset();
timer.start();
index.createIndex(threadSize);
timer.stop();
index.saveIndex(database);
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
return;
}
@ -368,13 +396,19 @@ NGT::Index::remove(const string &database, vector<ObjectID> &objects, bool force
try {
index.remove(*i, force);
} catch (Exception &err) {
cerr << "Warning: Cannot remove the node. ID=" << *i << " : " << err.what() << endl;
// cerr << "Warning: Cannot remove the node. ID=" << *i << " : " << err.what() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning: Cannot remove the node. ID=" + std::to_string(*i) + " : " + err.what());
continue;
}
}
timer.stop();
cerr << "Data removing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
// cerr << "Data removing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Data removing time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
}
index.saveIndex(database);
return;
}
@ -411,8 +445,12 @@ NGT::Index::importIndex(const string &database, const string &file) {
}
idx->importIndex(file);
timer.stop();
cerr << "Data importing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
cerr << "# of objects=" << idx->getObjectRepositorySize() - 1 << endl;
// cerr << "Data importing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "# of objects=" << idx->getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Data importing time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(idx->getObjectRepositorySize() - 1));
}
idx->saveIndex(database);
delete idx;
}
@ -424,8 +462,12 @@ NGT::Index::exportIndex(const string &database, const string &file) {
timer.start();
idx.exportIndex(file);
timer.stop();
cerr << "Data exporting time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
cerr << "# of objects=" << idx.getObjectRepositorySize() - 1 << endl;
// cerr << "Data exporting time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
// cerr << "# of objects=" << idx.getObjectRepositorySize() - 1 << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Data exporting time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(idx.getObjectRepositorySize() - 1));
}
}
std::vector<float>
@ -535,7 +577,9 @@ CreateIndexThread::run() {
} catch(NGT::ThreadTerminationException &err) {
break;
} catch(NGT::Exception &err) {
cerr << "CreateIndex::search:Error! popFront " << err.what() << endl;
// cerr << "CreateIndex::search:Error! popFront " << err.what() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("CreateIndex::search:Error! popFront " + std::string(err.what()));
break;
}
ObjectDistances *rs = new ObjectDistances;
@ -547,7 +591,9 @@ CreateIndexThread::run() {
graphIndex.searchForNNGInsertion(obj, *rs);
}
} catch(NGT::Exception &err) {
cerr << "CreateIndex::search:Fatal error! ID=" << job.id << " " << err.what() << endl;
// cerr << "CreateIndex::search:Fatal error! ID=" << job.id << " " << err.what() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("CreateIndex::search:Fatal error! ID=" + std::to_string(job.id) + " " + err.what());
abort();
}
job.results = rs;
@ -684,7 +730,9 @@ GraphAndTreeIndex::createTreeIndex()
ObjectRepository &fr = GraphIndex::objectSpace->getRepository();
for (size_t id = 0; id < fr.size(); id++){
if (id % 100000 == 0) {
cerr << " Processed id=" << id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(" Processed id=" + std::to_string(id));
// cerr << " Processed id=" << id << endl;
}
if (fr.isEmpty(id)) {
continue;
@ -698,8 +746,12 @@ GraphAndTreeIndex::createTreeIndex()
try {
DVPTree::insert(tiobj);
} catch (Exception &err) {
cerr << "GraphAndTreeIndex::createTreeIndex: Warning. ID=" << id << ":";
cerr << err.what() << " continue.." << endl;
// cerr << "GraphAndTreeIndex::createTreeIndex: Warning. ID=" << id << ":";
// cerr << err.what() << " continue.." << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("GraphAndTreeIndex::createTreeIndex: Warning. ID=" + std::to_string(id) + ":");
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue..");
}
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
GraphIndex::objectSpace->deleteObject(f);
@ -840,10 +892,16 @@ insertMultipleSearchResults(GraphIndex &neighborhoodGraph,
}
if (static_cast<int>(gr.id) > neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation &&
static_cast<int>(gr.results->size()) < neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation) {
cerr << "createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set." << endl;
cerr << " The node id=" << gr.id << endl;
cerr << " The number of edges for the node=" << gr.results->size() << endl;
cerr << " The pruned parameter (edgeSizeForSearch [-S])=" << neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch << endl;
// cerr << "createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set." << endl;
// cerr << " The node id=" << gr.id << endl;
// cerr << " The number of edges for the node=" << gr.results->size() << endl;
// cerr << " The pruned parameter (edgeSizeForSearch [-S])=" << neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set.");
(*NGT_LOG_DEBUG_)(" The node id=" + std::to_string(gr.id));
(*NGT_LOG_DEBUG_)(" The number of edges for the node=" + std::to_string(gr.results->size()));
(*NGT_LOG_DEBUG_)(" The pruned parameter (edgeSizeForSearch [-S])=" + std::to_string(neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch));
}
}
neighborhoodGraph.insertNode(gr.id, *gr.results);
}
@ -886,7 +944,9 @@ GraphIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
// wait for the completion of the search
threads.waitForFinish();
if (output.size() != cnt) {
cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
// cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong.");
cnt = output.size();
}
// insertion
@ -903,7 +963,9 @@ GraphIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
count += cnt;
if (timerCount <= count) {
timer.stop();
cerr << "Processed " << timerCount << " time= " << timer << endl;
// cerr << "Processed " << timerCount << " time= " << timer << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(timerCount )+ " time= " + std::to_string(timer.time));
timerCount += timerInterval;
timer.start();
}
@ -956,13 +1018,17 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
try {
node = outGraph.getNode(id);
} catch(NGT::Exception &err) {
std::cerr << "ngt info: Error. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "ngt info: Error. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ngt info: Error. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
valid = false;
continue;
}
numberOfNodes++;
if (numberOfNodes % 1000000 == 0) {
std::cerr << "Processed " << numberOfNodes << std::endl;
// std::cerr << "Processed " << numberOfNodes << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(numberOfNodes));
}
size_t esize = node->size() > edgeSize ? edgeSize : node->size();
if (esize == 0) {
@ -988,7 +1054,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
NGT::ObjectDistance &n = (*node)[i];
#endif
if (n.id == 0) {
std::cerr << "ngt info: Warning. id is zero." << std::endl;
// std::cerr << "ngt info: Warning. id is zero." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ngt info: Warning. id is zero.");
valid = false;
}
indegreeCount[n.id]++;
@ -1029,10 +1097,15 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
} catch(NGT::Exception &err) {
count++;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no object. "
<< node.at(i, graph.allocator).id << std::endl;
// std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no object. "
// << node.at(i, graph.allocator).id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node.at(i, graph.allocator).id) + " no object. "
+ std::to_string(node.at(i, graph.allocator).id));
#else
std::cerr << "Directed edge! " << id << "->" << node[i].id << " no object. " << node[i].id << std::endl;
// std::cerr << "Directed edge! " << id << "->" << node[i].id << " no object. " << node[i].id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node[i].id) + " no object. " + std::to_string(node[i].id));
#endif
continue;
}
@ -1050,16 +1123,23 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
}
if (!found) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no edge. "
<< node.at(i, graph.allocator).id << "->" << id << std::endl;
// std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no edge. "
// << node.at(i, graph.allocator).id << "->" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node.at(i, graph.allocator).id) + " no edge. "
+ std::to_string(node.at(i, graph.allocator).id) + "->" + std::to_string(id));
#else
std::cerr << "Directed edge! " << id << "->" << node[i].id << " no edge. " << node[i].id << "->" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node[i].id) + " no edge. " + std::to_string(node[i].id) + "->" + std::to_string(id));
// std::cerr << "Directed edge! " << id << "->" << node[i].id << " no edge. " << node[i].id << "->" << id << std::endl;
#endif
count++;
}
}
}
std::cerr << "The number of directed edges=" << count << std::endl;
// std::cerr << "The number of directed edges=" << count << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The number of directed edges=" + std::to_string(count));
}
// calculate outdegree distance 10
@ -1075,7 +1155,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
try {
n = outGraph.getNode(id);
} catch(NGT::Exception &err) {
std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ngt info: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
NGT::GraphNode &node = *n;
@ -1131,7 +1213,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
try {
node = outGraph.getNode(id);
} catch(NGT::Exception &err) {
std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
// std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ngt info: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
size_t esize = node->size();
@ -1148,7 +1232,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
}
if (indegreeCount[id] == 0) {
numberOfNodesWithoutIndegree++;
std::cerr << "Error! The node without incoming edges. " << id << std::endl;
// std::cerr << "Error! The node without incoming edges. " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Error! The node without incoming edges. " + std::to_string(id));
valid = false;
}
if (indegreeCount[id] > static_cast<int>(maxNumberOfIndegree)) {
@ -1229,6 +1315,47 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
c5 /= (double)numberOfNodes * 0.05;
c1 /= (double)numberOfNodes * 0.01;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("The size of the object repository (not the number of the objects):\t" + std::to_string(repo.size() - 1));
(*NGT_LOG_DEBUG_)("The number of the removed objects:\t" + std::to_string(removedObjectCount) + "/" + std::to_string(repo.size() - 1));
(*NGT_LOG_DEBUG_)("The number of the nodes:\t" + std::to_string(numberOfNodes));
(*NGT_LOG_DEBUG_)("The number of the edges:\t" + std::to_string(numberOfOutdegree));
(*NGT_LOG_DEBUG_)("The mean of the edge lengths:\t" + std::to_string(distance / (double)numberOfOutdegree));
(*NGT_LOG_DEBUG_)("The mean of the number of the edges per node:\t" + std::to_string((double)numberOfOutdegree / (double)numberOfNodes));
(*NGT_LOG_DEBUG_)("The number of the nodes without edges:\t" + std::to_string(numberOfNodesWithoutEdges));
(*NGT_LOG_DEBUG_)("The maximum of the outdegrees:\t" + std::to_string(maxNumberOfOutdegree));
if (minNumberOfOutdegree == SIZE_MAX) {
(*NGT_LOG_DEBUG_)("The minimum of the outdegrees:\t-NA-");
} else {
(*NGT_LOG_DEBUG_)("The minimum of the outdegrees:\t" + std::to_string(minNumberOfOutdegree));
}
(*NGT_LOG_DEBUG_)("The number of the nodes where indegree is 0:\t" + std::to_string(numberOfNodesWithoutIndegree));
(*NGT_LOG_DEBUG_)("The maximum of the indegrees:\t" + std::to_string(maxNumberOfIndegree));
if (minNumberOfIndegree == INT64_MAX) {
(*NGT_LOG_DEBUG_)("The minimum of the indegrees:\t-NA-");
} else {
(*NGT_LOG_DEBUG_)("The minimum of the indegrees:\t" + std::to_string(minNumberOfIndegree));
}
(*NGT_LOG_DEBUG_)("#-nodes,#-edges,#-no-indegree,avg-edges,avg-dist,max-out,min-out,v-out,max-in,min-in,v-in,med-out,"
"med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:"
+ std::to_string(numberOfNodes) + ":" + std::to_string(numberOfOutdegree) + ":" + std::to_string(numberOfNodesWithoutIndegree) + ":"
+ std::to_string((double)numberOfOutdegree / (double)numberOfNodes) + ":"
+ std::to_string(distance / (double)numberOfOutdegree) + ":"
+ std::to_string(maxNumberOfOutdegree) + ":" + std::to_string(minNumberOfOutdegree) + ":" + std::to_string(sumOfSquareOfOutdegree / (double)numberOfOutdegree) + ":"
+ std::to_string(maxNumberOfIndegree) + ":" + std::to_string(minNumberOfIndegree) + ":" + std::to_string(sumOfSquareOfIndegree / (double)numberOfOutdegree) + ":"
+ std::to_string(medianOutdegree) + ":" + std::to_string(medianIndegree) + ":" + std::to_string(modeOutdegree) + ":" + std::to_string(modeIndegree)
+ ":" + std::to_string(c95) + ":" + std::to_string(c5) + ":" + std::to_string(c99) + ":" + std::to_string(c1) + ":" + std::to_string(distance10) + ":" + std::to_string(d10SkipCount) + ":"
+ std::to_string(indegreeDistance10) + ":" + std::to_string(ind10SkipCount));
if (mode == 'h') {
(*NGT_LOG_DEBUG_)("#\tout\tin");
for (size_t i = 0; i < outdegreeHistogram.size() || i < indegreeHistogram.size(); i++) {
size_t out = outdegreeHistogram.size() <= i ? 0 : outdegreeHistogram[i];
size_t in = indegreeHistogram.size() <= i ? 0 : indegreeHistogram[i];
(*NGT_LOG_DEBUG_)(std::to_string(i) + "\t" + std::to_string(out) + "\t" + std::to_string(in));
}
}
}
/*
std::cerr << "The size of the object repository (not the number of the objects):\t" << repo.size() - 1 << std::endl;
std::cerr << "The number of the removed objects:\t" << removedObjectCount << "/" << repo.size() - 1 << std::endl;
std::cerr << "The number of the nodes:\t" << numberOfNodes << std::endl;
@ -1250,13 +1377,13 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
std::cerr << "The minimum of the indegrees:\t" << minNumberOfIndegree << std::endl;
}
std::cerr << "#-nodes,#-edges,#-no-indegree,avg-edges,avg-dist,max-out,min-out,v-out,max-in,min-in,v-in,med-out,"
"med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:"
<< numberOfNodes << ":" << numberOfOutdegree << ":" << numberOfNodesWithoutIndegree << ":"
"med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:"
<< numberOfNodes << ":" << numberOfOutdegree << ":" << numberOfNodesWithoutIndegree << ":"
<< std::setprecision(10) << (double)numberOfOutdegree / (double)numberOfNodes << ":"
<< distance / (double)numberOfOutdegree << ":"
<< maxNumberOfOutdegree << ":" << minNumberOfOutdegree << ":" << sumOfSquareOfOutdegree / (double)numberOfOutdegree<< ":"
<< maxNumberOfIndegree << ":" << minNumberOfIndegree << ":" << sumOfSquareOfIndegree / (double)numberOfOutdegree << ":"
<< medianOutdegree << ":" << medianIndegree << ":" << modeOutdegree << ":" << modeIndegree
<< medianOutdegree << ":" << medianIndegree << ":" << modeOutdegree << ":" << modeIndegree
<< ":" << c95 << ":" << c5 << ":" << c99 << ":" << c1 << ":" << distance10 << ":" << d10SkipCount << ":"
<< indegreeDistance10 << ":" << ind10SkipCount << std::endl;
if (mode == 'h') {
@ -1267,6 +1394,7 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
std::cerr << i << "\t" << out << "\t" << in << std::endl;
}
}
*/
return valid;
}
@ -1309,7 +1437,9 @@ GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
threads.waitForFinish();
if (output.size() != cnt) {
cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
// cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong.");
cnt = output.size();
}
@ -1331,12 +1461,16 @@ GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
GraphIndex::objectSpace->deleteObject(f);
#endif
} catch (Exception &err) {
cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":";
// cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":";
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("NGT::createIndex: Fatal error. ID=" + std::to_string(job.id) + ":");
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
GraphIndex::objectSpace->deleteObject(f);
#endif
if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) {
cerr << err.what() << " continue.." << endl;
// cerr << err.what() << " continue.." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue..");
} else {
throw err;
}
@ -1354,10 +1488,14 @@ GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
count += cnt;
if (timerCount <= count) {
timer.stop();
cerr << "Processed " << timerCount << " objects. time= " << timer << endl;
timerCount += timerInterval;
timer.start();
timer.stop();
// cerr << "Processed " << timerCount << " objects. time= " << timer << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)(
"Processed " + std::to_string(timerCount) + " objects. time= " + std::to_string(timer.time));
}
timerCount += timerInterval;
timer.start();
}
buildTimeController.adjustEdgeSize(count);
if (pathAdjustCount > 0 && pathAdjustCount <= count) {
@ -1384,7 +1522,9 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
size_t count = 0;
timer.start();
if (threadPoolSize <= 0) {
cerr << "Not implemented!!" << endl;
// cerr << "Not implemented!!" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Not implemented!!");
abort();
} else {
CreateIndexThreadPool threads(threadPoolSize);
@ -1422,7 +1562,9 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
}
threads.waitForFinish();
if (output.size() != cnt) {
cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
// cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong.");
cnt = output.size();
}
{
@ -1488,19 +1630,25 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
GraphIndex::objectSpace->deleteObject(f);
#endif
} catch (Exception &err) {
cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":" << err.what();
// cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":" << err.what();
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("NGT::createIndex: Fatal error. ID=" + std::to_string(job.id) + ":" + err.what());
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
GraphIndex::objectSpace->deleteObject(f);
#endif
if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) {
cerr << err.what() << " continue.." << endl;
// cerr << err.what() << " continue.." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue..");
} else {
throw err;
}
}
}
if (((*job.results).size() == 0) && (job.id != 1)) {
cerr << "insert warning!! No searched nodes!. If the first time, no problem. " << job.id << endl;
// cerr << "insert warning!! No searched nodes!. If the first time, no problem. " << job.id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("insert warning!! No searched nodes!. If the first time, no problem. " + std::to_string(job.id));
}
GraphIndex::insertNode(job.id, *job.results);
}
@ -1513,13 +1661,17 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
count += cnt;
if (timerCount <= count) {
timer.stop();
cerr << "Processed " << timerCount << " time= " << timer << endl;
// cerr << "Processed " << timerCount << " time= " << timer << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(timerCount) + " time= " + std::to_string(timer.time));
timerCount += timerInterval;
timer.start();
}
}
} catch(Exception &err) {
cerr << "thread terminate!" << endl;
// cerr << "thread terminate!" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("thread terminate!");
threads.terminate();
throw err;
}
@ -1560,18 +1712,26 @@ bool
GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
bool valid = GraphIndex::verify(status, info);
if (!valid) {
cerr << "The graph or object is invalid!" << endl;
// cerr << "The graph or object is invalid!" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The graph or object is invalid!");
}
bool treeValid = DVPTree::verify(GraphIndex::objectSpace->getRepository().size(), status);
if (!treeValid) {
cerr << "The tree is invalid" << endl;
// cerr << "The tree is invalid" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The tree is invalid");
}
valid = valid && treeValid;
// status: tree|graph|object
cerr << "Started checking consistency..." << endl;
// cerr << "Started checking consistency..." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Started checking consistency...");
for (size_t id = 1; id < status.size(); id++) {
if (id % 100000 == 0) {
cerr << "The number of processed objects=" << id << endl;
// cerr << "The number of processed objects=" << id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The number of processed objects=" + std::to_string(id));
}
if (status[id] != 0x00 && status[id] != 0x07) {
if (status[id] == 0x03) {
@ -1596,7 +1756,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
GraphIndex::objectSpace->deleteObject(po);
#endif
cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
// cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Fatal Error!: Cannot search! " + std::string(err.what()));
objects.clear();
}
size_t n = 0;
@ -1609,7 +1771,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
}
if (!registeredIdenticalObject) {
if (info) {
cerr << "info: not found the registered same objects. id=" << id << " size=" << objects.size() << endl;
// cerr << "info: not found the registered same objects. id=" << id << " size=" << objects.size() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("info: not found the registered same objects. id=" + std::to_string(id) + " size=" + std::to_string(objects.size()));
}
sc.id = 0;
sc.radius = FLT_MAX;
@ -1622,7 +1786,10 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
try {
GraphIndex::search(sc, seeds);
} catch(Exception &err) {
cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
// cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Fatal Error!: Cannot search! " + std::string(err.what()));
}
objects.clear();
}
registeredIdenticalObject = false;
@ -1631,7 +1798,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
if (objects[n].id != id && status[objects[n].id] == 0x07) {
registeredIdenticalObject = true;
if (info) {
cerr << "info: found by using mode accurate search. " << objects[n].id << endl;
// cerr << "info: found by using mode accurate search. " << objects[n].id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("info: found by using mode accurate search. " + std::to_string(objects[n].id));
}
break;
}
@ -1639,7 +1808,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
}
if (!registeredIdenticalObject && mode != 's') {
if (info) {
cerr << "info: not found by using more accurate search." << endl;
// cerr << "info: not found by using more accurate search." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("info: not found by using more accurate search.");
}
sc.id = 0;
sc.radius = 0.0;
@ -1655,7 +1826,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
if (objects[n].id != id && status[objects[n].id] == 0x07) {
registeredIdenticalObject = true;
if (info) {
cerr << "info: found by using linear search. " << objects[n].id << endl;
// cerr << "info: found by using linear search. " << objects[n].id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("info: found by using linear search. " + std::to_string(objects[n].id));
}
break;
}
@ -1666,8 +1839,12 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
#endif
if (registeredIdenticalObject) {
if (info) {
cerr << "Info ID=" << id << ":" << static_cast<int>(status[id]) << endl;
cerr << " found the valid same objects. " << objects[n].id << endl;
// cerr << "Info ID=" << id << ":" << static_cast<int>(status[id]) << endl;
// cerr << " found the valid same objects. " << objects[n].id << endl;
if (NGT_LOG_DEBUG_){
(*NGT_LOG_DEBUG_)("Info ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
(*NGT_LOG_DEBUG_)(" found the valid same objects. " + std::to_string(objects[n].id));
}
}
GraphNode &fromNode = *GraphIndex::getNode(id);
bool fromFound = false;
@ -1694,40 +1871,67 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
if (!fromFound || !toFound) {
if (info) {
if (!fromFound && !toFound) {
cerr << "Warning no undirected edge between " << id << "(" << fromNode.size() << ") and "
<< objects[n].id << "(" << toNode.size() << ")." << endl;
// cerr << "Warning no undirected edge between " << id << "(" << fromNode.size() << ") and "
// << objects[n].id << "(" << toNode.size() << ")." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning no undirected edge between " + std::to_string(id) + "(" + std::to_string(fromNode.size()) + ") and "
+ std::to_string(objects[n].id) + "(" + std::to_string(toNode.size()) + ").");
} else if (!fromFound) {
cerr << "Warning no directed edge from " << id << "(" << fromNode.size() << ") to "
<< objects[n].id << "(" << toNode.size() << ")." << endl;
// cerr << "Warning no directed edge from " << id << "(" << fromNode.size() << ") to "
// << objects[n].id << "(" << toNode.size() << ")." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning no directed edge from " + std::to_string(id) + "(" + std::to_string(fromNode.size()) + ") to "
+ std::to_string(objects[n].id) + "(" + std::to_string(toNode.size()) + ").");
} else if (!toFound) {
cerr << "Warning no reverse directed edge from " << id << "(" << fromNode.size() << ") to "
<< objects[n].id << "(" << toNode.size() << ")." << endl;
// cerr << "Warning no reverse directed edge from " << id << "(" << fromNode.size() << ") to "
// << objects[n].id << "(" << toNode.size() << ")." << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning no reverse directed edge from " + std::to_string(id) + "(" + std::to_string(fromNode.size()) + ") to "
+ std::to_string(objects[n].id) + "(" + std::to_string(toNode.size()) + ").");
}
}
if (!findPathAmongIdenticalObjects(*this, id, objects[n].id)) {
cerr << "Warning no path from " << id << " to " << objects[n].id << endl;
// cerr << "Warning no path from " << id << " to " << objects[n].id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning no path from " + std::to_string(id) + " to " + std::to_string(objects[n].id));
}
if (!findPathAmongIdenticalObjects(*this, objects[n].id, id)) {
cerr << "Warning no reverse path from " << id << " to " << objects[n].id << endl;
// cerr << "Warning no reverse path from " << id << " to " << objects[n].id << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning no reverse path from " + std::to_string(id) + " to " + std::to_string(objects[n].id));
}
}
} else {
if (mode == 's') {
cerr << "Warning: not found the valid same object, but not try to use linear search." << endl;
cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
// cerr << "Warning: not found the valid same object, but not try to use linear search." << endl;
// cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Warning: not found the valid same object, but not try to use linear search.");
(*NGT_LOG_DEBUG_)("Error! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
}
} else {
cerr << "Warning: not found the valid same object even by using linear search." << endl;
cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
// cerr << "Warning: not found the valid same object even by using linear search." << endl;
// cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Warning: not found the valid same object even by using linear search.");
(*NGT_LOG_DEBUG_)("Error! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
}
valid = false;
}
}
} else if (status[id] == 0x01) {
if (info) {
cerr << "Warning! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
cerr << " not inserted into the indexes" << endl;
// cerr << "Warning! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
// cerr << " not inserted into the indexes" << endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("Warning! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
(*NGT_LOG_DEBUG_)(" not inserted into the indexes");
}
}
} else {
cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
// cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Error! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
valid = false;
}
}

View File

@ -140,6 +140,9 @@ public:
case DistanceType::DistanceTypeL2:
p.set("DistanceType", "L2");
break;
case DistanceType::DistanceTypeIP:
p.set("DistanceType", "IP");
break;
case DistanceType::DistanceTypeHamming:
p.set("DistanceType", "Hamming");
break;
@ -255,6 +258,10 @@ public:
{
distanceType = DistanceType::DistanceTypeL2;
}
else if (it->second == "IP")
{
distanceType = DistanceType::DistanceTypeIP;
}
else if (it->second == "Hamming")
{
distanceType = DistanceType::DistanceTypeHamming;
@ -379,6 +386,7 @@ public:
void set(NGT::Property & prop);
void get(NGT::Property & prop);
// Returns sizeof(*this) in bytes; nothing beyond the object itself is counted.
int64_t memSize() { return sizeof(*this); }
int dimension;
int threadPoolSize;
ObjectSpace::ObjectType objectType;
@ -403,6 +411,7 @@ public:
public:
InsertionResult() : id(0), identical(false), distance(0.0) {}
InsertionResult(size_t i, bool tf, Distance d) : id(i), identical(tf), distance(d) {}
// Returns sizeof(*this) in bytes (the id, identical flag and distance fields plus padding).
int64_t memSize() { return sizeof(*this); }
size_t id;
bool identical;
Distance distance; // the distance between the centroid and the inserted object.
@ -415,6 +424,7 @@ public:
AccuracyTable(std::vector<std::pair<float, double>> & t) { set(t); }
AccuracyTable(std::string str) { set(str); }
void set(std::vector<std::pair<float, double>> & t) { table = t; }
// Returns sizeof(*this) plus the storage reserved by `table`: its capacity (not its
// size) multiplied by the size of one element. sizeof(table[0]) is unevaluated, so
// this is safe on an empty table.
int64_t memSize() { return sizeof(*this) + table.capacity() * sizeof(table[0]); }
void set(std::string str)
{
std::vector<std::string> tokens;
@ -645,7 +655,7 @@ public:
virtual void linearSearch(NGT::SearchContainer & sc) { getIndex().linearSearch(sc); }
virtual void linearSearch(NGT::SearchQuery & sc) { getIndex().linearSearch(sc); }
// for milvus
virtual void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset) { getIndex().search(sc, bitset); }
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset) { getIndex().search(sc, bitset); }
virtual void search(NGT::SearchContainer & sc) { getIndex().search(sc); }
virtual void search(NGT::SearchQuery & sc) { getIndex().search(sc); }
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds) { getIndex().search(sc, seeds); }
@ -711,6 +721,8 @@ public:
static std::string getVersion();
std::string getPath() { return path; }
// Memory estimate: redirector.memSize() + sizeof(path) + the wrapped index's
// memSize() when `index` is non-null (0 otherwise).
// NOTE(review): sizeof(path) measures the path member object itself; if path is a
// std::string (presumably, given getPath()), the characters it holds are not counted.
virtual int64_t memSize() { return redirector.memSize() + sizeof(path) + (index ? getIndex().memSize() : 0); }
protected:
Object * allocateObject(void * vec, const std::type_info & objectType)
{
@ -1046,7 +1058,7 @@ public:
}
// for milvus
virtual void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset)
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset)
{
sc.distanceComputationCount = 0;
sc.visitCount = 0;
@ -1574,7 +1586,7 @@ protected:
}
// for milvus
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset)
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView&bitset)
{
if (sc.size == 0)
{
@ -1628,6 +1640,8 @@ protected:
throw e;
}
}
// Memory estimate: the property block, the readOnly flag, the accuracy table, and the
// estimates reported by the qualified Index::memSize() and NeighborhoodGraph::memSize().
virtual int64_t memSize() { return property.memSize() + sizeof(readOnly) + accuracyTable.memSize() + Index::memSize() + NeighborhoodGraph::memSize(); }
Index::Property property;
bool readOnly;
@ -1641,6 +1655,9 @@ protected:
class GraphAndTreeIndex : public GraphIndex, public DVPTree
{
public:
// Memory estimate: the sum of the two base-class estimates (GraphIndex and DVPTree).
virtual int64_t memSize() { return GraphIndex::memSize() + DVPTree::memSize(); }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
GraphAndTreeIndex(const std::string & allocator, bool rdOnly = false) : GraphIndex(allocator, false) { initialize(allocator, 0); }
GraphAndTreeIndex(const std::string & allocator, NGT::Property & prop);
@ -2130,7 +2147,7 @@ public:
// for milvus
void
getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::ConcurrentBitsetPtr& bitset) {
getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::BitsetView& bitset) {
DVPTree::SearchContainer tso(sc.object);
tso.mode = DVPTree::SearchContainer::SearchLeaf;
tso.radius = 0.0;
@ -2187,7 +2204,7 @@ public:
}
// for milvus
void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset)
void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset)
{
sc.distanceComputationCount = 0;
sc.visitCount = 0;

View File

@ -16,6 +16,7 @@
#pragma once
#include "defines.h"
#include "MmapManagerDefs.h"
#include "MmapManagerException.h"
@ -132,7 +133,9 @@ namespace MemoryManager{
const off_t old_file_size = mmapCntlHead->base_size * mmapCntlHead->unit_num;
if(new_unit_num >= MMAP_MAX_UNIT_NUM){
std::cerr << "over max unit num" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("over max unit num");
// std::cerr << "over max unit num" << std::endl;
return false;
}
@ -151,10 +154,14 @@ namespace MemoryManager{
throw MmapManagerException("truncate error" + err_str);
}
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(filePath + "[WARN] : filedescript cannot close");
// std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException("mmap error" + err_str);
}
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(filePath + "[WARN] : filedescript cannot close");
// std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
mmapDataAddr[mmapCntlHead->unit_num] = new_area;
@ -179,14 +186,18 @@ namespace MemoryManager{
if(lseek(fd, (off_t)size-1, SEEK_SET) < 0){
std::stringstream ss;
ss << "[ERR] Cannot seek the file. " << targetFile << " " << getErrorStr(errno);
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(targetFile + "[WARN] : filedescript cannot close");
// std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(ss.str());
}
errno = 0;
if(write(fd, &c, sizeof(char)) == -1){
std::stringstream ss;
ss << "[ERR] Cannot write the file. Check the disk space. " << targetFile << " " << getErrorStr(errno);
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(targetFile + "[WARN] : filedescript cannot close");
// std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(ss.str());
}

View File

@ -19,6 +19,7 @@
#include "NGT/Index.h"
#include "NGT/ArrayFile.h"
#include "NGT/Clustering.h"
#include "NGT/defines.h"

View File

@ -61,6 +61,7 @@ namespace NGT {
void deserialize(std::stringstream & is) { NGT::Serializer::read(is, id); }
void serializeAsText(std::ofstream &os) { NGT::Serializer::writeAsText(os, id); }
void deserializeAsText(std::ifstream &is) { NGT::Serializer::readAsText(is, id); }
// Returns the size in bytes of the stored NodeID value.
virtual int64_t memSize() { return sizeof(id); }
protected:
NodeID id;
};
@ -69,6 +70,7 @@ namespace NGT {
public:
Object():object(0) {}
bool operator<(const Object &o) const { return distance < o.distance; }
// Memory estimate: this record plus the estimate of the referenced object.
// `object` is set to 0 by the default constructor, so it is checked before being
// dereferenced; a record without an object contributes only sizeof(*this).
virtual int64_t memSize() { return sizeof(*this) + (object ? object->memSize() : 0); } // size of object cannot be decided accurately
static const double Pivot;
ObjectID id;
PersistentObject *object;
@ -126,6 +128,8 @@ namespace NGT {
parent.deserializeAsText(is);
}
// Memory estimate: two ID-sized fields (presumably `id` and `parent` — confirm) plus
// the pivot's estimate. setPivot() treats pivot == 0 as "not yet set", so the pivot is
// checked before being dereferenced and contributes 0 when absent.
virtual int64_t memSize() { return id.memSize() * 2 + (pivot ? pivot->memSize() : 0); }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void setPivot(PersistentObject &f, ObjectSpace &os, SharedMemoryAllocator &allocator) {
if (pivot == 0) {
@ -435,6 +439,8 @@ namespace NGT {
}
}
// Memory estimate: the childrenSize field, one children->memSize() term,
// childrenSize Distance-sized entries, and the Node::memSize() estimate.
// NOTE(review): children->memSize() is added once rather than multiplied by
// childrenSize, and `children` is dereferenced without a null check — confirm both.
virtual int64_t memSize() { return sizeof(childrenSize) + children->memSize() + childrenSize * sizeof(Distance) + Node::memSize(); }
void show() {
std::cout << "Show internal node " << childrenSize << ":";
for (size_t i = 0; i < childrenSize; i++) {
@ -747,6 +753,7 @@ namespace NGT {
bool verify(size_t nobjs, std::vector<uint8_t> &status);
#endif
// Memory estimate: the objectSize field, objectSize times the estimate of the first
// entry in objectIDs, and the Node::memSize() estimate.
// NOTE(review): objectIDs is dereferenced without a null check, and the
// NGT_NODE_USE_VECTOR branch right below accesses it as objectIDs.size() — confirm
// this `->` form is valid in that configuration.
virtual int64_t memSize() { return sizeof(objectSize) + objectIDs->memSize() * objectSize + Node::memSize(); }
#ifdef NGT_NODE_USE_VECTOR
size_t getObjectSize() { return objectIDs.size(); }

View File

@ -16,6 +16,7 @@
#pragma once
#include <sstream>
#include "defines.h"
namespace NGT {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
@ -104,22 +105,30 @@ namespace NGT {
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!");
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!");
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!");
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const float *obj, size_t size) {
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!");
abort();
}
@ -141,8 +150,11 @@ namespace NGT {
while (getline(is, line)) {
lineNo++;
if (dataSize > 0 && (dataSize <= size() - prevDataSize)) {
std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
<< dataSize << std::endl;
// std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
// << dataSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The size of data reached the specified size. The remaining data in the file are not inserted. "
+ std::to_string(dataSize));
break;
}
std::vector<double> object;
@ -152,12 +164,16 @@ namespace NGT {
try {
obj = allocateNormalizedPersistentObject(object);
} catch (Exception &err) {
std::cerr << err.what() << " continue..." << std::endl;
// std::cerr << err.what() << " continue..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue...");
obj = allocatePersistentObject(object);
}
push_back(obj);
} catch (Exception &err) {
std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
// std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Invalid line. [" + line + "] Skip the line " + std::to_string(lineNo) + " and continue.");
}
}
}
@ -186,13 +202,17 @@ namespace NGT {
try {
obj = allocateNormalizedPersistentObject(object);
} catch (Exception &err) {
std::cerr << err.what() << " continue..." << std::endl;
// std::cerr << err.what() << " continue..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue...");
obj = allocatePersistentObject(object);
}
push_back(obj);
} catch (Exception &err) {
std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
// std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Invalid data. Skip the data no. " + std::to_string(idx) + " and continue.");
}
}
}
@ -231,7 +251,9 @@ namespace NGT {
char *e;
object[idx] = strtod(tokens[idx].c_str(), &e);
if (*e != 0) {
std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
// std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Not numerical value. [" + std::string(e) + "]");
break;
}
}
@ -245,8 +267,11 @@ namespace NGT {
osize = osize < vsize ? vsize : osize;
} else {
if (dimension != size) {
std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
<< dimension << " The specified object=" << size << std::endl;
// std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
// << dimension << " The specified object=" << size << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
+ std::to_string(dimension) + " The specified object=" + std::to_string(size));
assert(dimension == size);
}
}
@ -263,7 +288,9 @@ namespace NGT {
obj[i] = static_cast<float>(o[i]);
}
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
abort();
}
return po;
@ -283,7 +310,9 @@ namespace NGT {
} else if (type == typeid(float)) {
cpsize *= sizeof(float);
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
@ -315,7 +344,9 @@ namespace NGT {
obj[i] = static_cast<float>(o[i]);
}
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
return po;
@ -361,7 +392,9 @@ namespace NGT {
d.push_back(obj[i]);
}
} else {
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
abort();
}
}

View File

@ -16,6 +16,7 @@
#pragma once
#include <cstring>
#include "PrimitiveComparator.h"
class ObjectSpace;
@ -94,6 +95,15 @@ namespace NGT {
}
}
// Memory estimate for the stored entries: the first element's memSize() multiplied by
// the number of elements, or 0 when the container is empty (every element is assumed
// to report the same size).
int64_t memSize() const {
    const auto count = this->size();
    return count == 0 ? 0 : (*this)[0].memSize() * count;
}
ObjectDistances &operator=(PersistentObjectDistances &objs);
};
@ -169,6 +179,7 @@ namespace NGT {
SharedMemoryAllocator &allocator;
#endif
virtual ~Comparator(){}
// Reports a fixed sizeof(size_t) bytes — presumably the `dimension` member; confirm
// against the class's fields.
int64_t memSize() { return sizeof(size_t); }
};
enum DistanceType {
DistanceTypeNone = -1,
@ -180,7 +191,8 @@ namespace NGT {
DistanceTypeNormalizedAngle = 5,
DistanceTypeNormalizedCosine = 6,
DistanceTypeJaccard = 7,
DistanceTypeSparseJaccard = 8
DistanceTypeSparseJaccard = 8,
DistanceTypeIP = 9
};
enum ObjectType {
@ -248,6 +260,7 @@ namespace NGT {
virtual ObjectRepository &getRepository() = 0;
virtual void setDistanceType(DistanceType t) = 0;
virtual DistanceType getDistanceType() = 0;
virtual void *getObject(size_t idx) = 0;
virtual void getObject(size_t idx, std::vector<float> &v) = 0;
@ -255,6 +268,7 @@ namespace NGT {
size_t getDimension() { return dimension; }
// Rounds `dimension` up to the next multiple of 16 (e.g. 16 -> 16, 17 -> 32).
// NOTE(review): with dimension == 0 and an unsigned dimension, `dimension - 1` wraps
// around — confirm callers never ask for this before the dimension is set.
size_t getPaddedDimension() { return ((dimension - 1) / 16 + 1) * 16; }
// Memory estimate: the scalar configuration fields plus the comparator's estimate.
// sizeof(prefetchOffset) is counted twice, presumably for two same-sized prefetch
// fields (confirm). The comparator is checked before being dereferenced because
// setDistanceType() treats comparator == 0 as "not yet created".
virtual int64_t memSize() { return sizeof(dimension) + sizeof(distanceType) + sizeof(prefetchOffset) * 2 + sizeof(normalization) + (comparator ? comparator->memSize() : 0); }
template <typename T>
void normalize(T *data, size_t dim) {
@ -384,6 +398,8 @@ namespace NGT {
void *getPointer(size_t idx = 0) const { return vector + idx; }
static Object *allocate(ObjectSpace &objectspace) { return new Object(&objectspace); }
// Rough memory estimate for the object's payload. Returns 0 when no buffer is
// attached (clear() treats vector == 0 as "empty"), instead of passing a null pointer
// to strlen, which is undefined behavior.
// NOTE(review): strlen stops at the first zero byte, so for binary vector data this
// is only a heuristic; the object does not appear to record its byte size here.
virtual int64_t memSize() { return vector != 0 ? std::strlen((char*)vector) : 0; }
private:
void clear() {
if (vector != 0) {

View File

@ -73,6 +73,27 @@ namespace NGT {
#endif
};
// Comparator for the inner-product (IP) distance type.
// Every overload forwards to PrimitiveComparator::compareInnerProduct over
// `dimension` elements; that function returns the negated inner product.
// The shared-memory overloads previously called a differently named helper
// (compareIP); they now use the same compareInnerProduct entry point as the
// non-shared build so both configurations resolve to the same implementation.
class ComparatorIP : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
    ComparatorIP(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
    // In-memory object vs. in-memory object.
    double operator()(Object &objecta, Object &objectb) {
        return PrimitiveComparator::compareInnerProduct((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
    }
    // In-memory object vs. persistent object resolved through the allocator.
    double operator()(Object &objecta, PersistentObject &objectb) {
        return PrimitiveComparator::compareInnerProduct((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
    }
    // Persistent object vs. persistent object, both resolved through the allocator.
    double operator()(PersistentObject &objecta, PersistentObject &objectb) {
        return PrimitiveComparator::compareInnerProduct((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
    }
#else
    ComparatorIP(size_t d) : Comparator(d) {}
    double operator()(Object &objecta, Object &objectb) {
        return PrimitiveComparator::compareInnerProduct((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
    }
#endif
};
class ComparatorHammingDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
@ -281,6 +302,8 @@ namespace NGT {
objecta.copy(objectb, getByteSizeOfObject());
}
// Returns the distance type currently stored in `distanceType`.
DistanceType getDistanceType() { return distanceType; }
void setDistanceType(DistanceType t) {
if (comparator != 0) {
delete comparator;
@ -326,6 +349,9 @@ namespace NGT {
case DistanceTypeL2:
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeIP:
comparator = new ObjectSpaceRepository::ComparatorIP(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeHamming:
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension());
break;
@ -352,7 +378,9 @@ namespace NGT {
break;
#endif
default:
std::cerr << "Distance type is not specified" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Distance type is not specified");
// std::cerr << "Distance type is not specified" << std::endl;
assert(distanceType != DistanceTypeNone);
abort();
}
@ -587,7 +615,9 @@ namespace NGT {
} else if (t == typeid(uint32_t)) {
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
} else {
std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
// type_info::name() returns a C string, and std::to_string only accepts arithmetic
// types, so the name is wrapped in std::string to build the message.
if (NGT_LOG_DEBUG_)
    (*NGT_LOG_DEBUG_)("ObjectT::serializeAsText: not supported data type. [" + std::string(t.name()) + "]");
// std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
@ -610,7 +640,9 @@ namespace NGT {
} else if (t == typeid(uint32_t)) {
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
} else {
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
// type_info::name() returns a C string, and std::to_string only accepts arithmetic
// types, so the name is wrapped in std::string to build the message.
if (NGT_LOG_DEBUG_)
    (*NGT_LOG_DEBUG_)("Object::deserializeAsText: not supported data type. [" + std::string(t.name()) + "]");
// std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}

View File

@ -195,7 +195,8 @@ class PrimitiveComparator {
diff0 = (COMPARE_TYPE)(*a++ - *b++);
d += diff0 * diff0;
}
return sqrt((double)d);
// return sqrt((double)d);
return d;
}
inline static double
@ -264,7 +265,8 @@ class PrimitiveComparator {
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
return sqrt(s);
// return sqrt(s);
return s;
}
inline static double
@ -290,7 +292,8 @@ class PrimitiveComparator {
int d = (int)*a++ - (int)*b++;
s += d * d;
}
return sqrt(s);
// return sqrt(s);
return s;
}
#endif
#if defined(NGT_NO_AVX)
@ -498,6 +501,17 @@ class PrimitiveComparator {
return sum;
}
// Generic inner-product distance.
// Accumulates the products of corresponding elements of `a` and `b` in double
// precision over `size` elements and returns the negated sum.
template <typename OBJECT_TYPE>
inline static double
compareInnerProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
    double dot = 0.0;
    const OBJECT_TYPE* const end = a + size;
    for (const OBJECT_TYPE *lhs = a, *rhs = b; lhs != end; ++lhs, ++rhs) {
        // Widen each operand before multiplying so the product is formed in double.
        dot += static_cast<double>(*lhs) * static_cast<double>(*rhs);
    }
    return -dot;
}
template <typename OBJECT_TYPE>
inline static double
compareCosine(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
@ -551,6 +565,42 @@ class PrimitiveComparator {
return s;
}
// Inner-product distance for float vectors, computed with SIMD intrinsics.
// The vector width is chosen at compile time: 16 floats per step when NGT_AVX512 is
// defined, 8 per step when NGT_AVX2 is defined, otherwise 4 per step (SSE).
// Each loop advances by a whole vector until `a` reaches `a + size` and there is no
// scalar tail, so `size` is expected to be a multiple of the step; otherwise the last
// load reads past the end of both inputs (presumably callers pass a padded
// dimension — confirm). Loads are unaligned (loadu).
// Returns the negated sum of products.
inline static double
compareInnerProduct(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 sum512 = _mm512_setzero_ps();
while (a < last) {
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)));
a += 16;
b += 16;
}
// Fold the 512-bit accumulator down to 256 bits, then to 128 bits.
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#elif defined(NGT_AVX2)
__m256 sum256 = _mm256_setzero_ps();
while (a < last) {
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)));
a += 8;
b += 8;
}
// Fold the 256-bit accumulator down to 128 bits.
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#else
__m128 sum128 = _mm_setzero_ps();
while (a < last) {
sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
a += 4;
b += 4;
}
#endif
// Spill the four partial sums to an aligned buffer and add them in double precision.
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
return -s;
}
inline static double
compareDotProduct(const unsigned char* a, const unsigned char* b, size_t size) {
double sum = 0.0;
@ -560,6 +610,15 @@ class PrimitiveComparator {
return sum;
}
// Inner-product distance for byte vectors.
// Multiplies corresponding bytes of `a` and `b` in double precision over `size`
// elements and returns the negated sum.
inline static double
compareInnerProduct(const unsigned char* a, const unsigned char* b, size_t size) {
    double dot = 0.0;
    const unsigned char* const end = a + size;
    while (a != end) {
        const double lhs = static_cast<double>(*a++);
        const double rhs = static_cast<double>(*b++);
        dot += lhs * rhs;
    }
    return -dot;
}
inline static double
compareCosine(const float* a, const float* b, size_t size) {
const float* last = a + size;

View File

@ -20,6 +20,7 @@
#include "NGT/Node.h"
#include "NGT/defines.h"
#include "faiss/utils/ConcurrentBitset.h"
#include "faiss/utils/BitsetView.h"
#include <sstream>
#include <string>
@ -262,7 +263,7 @@ namespace NGT {
// for milvus
void
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::ConcurrentBitsetPtr& bitset) {
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::BitsetView& bitset) {
LeafNode& ln = *(LeafNode*)getNode(nid);
rl.clear();
ObjectDistance r;
@ -274,7 +275,7 @@ namespace NGT {
r.id = ln.getObjectIDs()[i].id;
r.distance = ln.getObjectIDs()[i].distance;
#endif
if (bitset != nullptr && bitset->test(r.id - 1)) {
if (!bitset.empty() && bitset.test(r.id - 1)) {
continue;
}
rl.push_back(r);
@ -487,6 +488,8 @@ namespace NGT {
}
}
// Memory estimate: two size_t fields (presumably internalChildrenSize and
// leafObjectsSize — confirm), the splitMode field, the length of `name`, and the
// estimates of the leaf-node and internal-node repositories.
virtual int64_t memSize() { return sizeof(size_t) * 2 + sizeof(splitMode) + name.size() + leafNodes.memSize() + internalNodes.memSize(); }
public:
size_t internalChildrenSize;
size_t leafObjectsSize;

Some files were not shown because too many files have changed in this diff Show More