mirror of https://github.com/milvus-io/milvus.git
Fix missing file
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>pull/4973/head^2
parent
0110ba6bd2
commit
0d75840ed6
|
@ -1,17 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/storage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) == 1 {
|
||||
fmt.Println("usage: binlog file1 file2 ...")
|
||||
}
|
||||
if err := storage.PrintBinlogFiles(os.Args[1:]); err != nil {
|
||||
fmt.Printf("error: %s\n", err.Error())
|
||||
}
|
||||
}
|
|
@ -17,4 +17,4 @@ target_sources( cache PRIVATE ${CACHE_FILES}
|
|||
Cache.inl
|
||||
)
|
||||
target_include_directories( cache PUBLIC ${MILVUS_ENGINE_SRC}/cache )
|
||||
|
||||
target_link_libraries( cache PRIVATE fiu)
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "Cache.h"
|
||||
// #include "s/Metrics.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include <memory>
|
||||
|
|
|
@ -47,7 +47,6 @@ CacheMgr<ItemObj>::GetItem(const std::string& key) {
|
|||
LOG_SERVER_ERROR_ << "Cache doesn't exist";
|
||||
return nullptr;
|
||||
}
|
||||
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
|
||||
return cache_->get(key);
|
||||
}
|
||||
|
||||
|
@ -59,7 +58,6 @@ CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data) {
|
|||
return;
|
||||
}
|
||||
cache_->insert(key, data);
|
||||
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
|
||||
}
|
||||
|
||||
template <typename ItemObj>
|
||||
|
@ -70,7 +68,6 @@ CacheMgr<ItemObj>::EraseItem(const std::string& key) {
|
|||
return;
|
||||
}
|
||||
cache_->erase(key);
|
||||
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
|
||||
}
|
||||
|
||||
template <typename ItemObj>
|
||||
|
|
|
@ -13,36 +13,42 @@
|
|||
|
||||
#include <utility>
|
||||
|
||||
// #include <fiu/fiu-local.h>
|
||||
#include <fiu/fiu-local.h>
|
||||
|
||||
#include "config/ServerConfig.h"
|
||||
#include "utils/Log.h"
|
||||
#include "value/config/ServerConfig.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
CpuCacheMgr::CpuCacheMgr() {
|
||||
// cache_ = std::make_shared<Cache<DataObjPtr>>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]");
|
||||
|
||||
// if (config.cache.cpu_cache_threshold() > 0.0) {
|
||||
// cache_->set_freemem_percent(config.cache.cpu_cache_threshold());
|
||||
// }
|
||||
ConfigMgr::GetInstance().Attach("cache.cache_size", this);
|
||||
}
|
||||
|
||||
CpuCacheMgr::~CpuCacheMgr() {
|
||||
ConfigMgr::GetInstance().Detach("cache.cache_size", this);
|
||||
}
|
||||
|
||||
CpuCacheMgr&
|
||||
CpuCacheMgr::GetInstance() {
|
||||
static CpuCacheMgr s_mgr;
|
||||
return s_mgr;
|
||||
}
|
||||
|
||||
CpuCacheMgr::CpuCacheMgr() {
|
||||
cache_ = std::make_shared<Cache<DataObjPtr>>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]");
|
||||
|
||||
if (config.cache.cpu_cache_threshold() > 0.0) {
|
||||
cache_->set_freemem_percent(config.cache.cpu_cache_threshold());
|
||||
}
|
||||
ConfigMgr::GetInstance().Attach("cache.cache_size", this);
|
||||
}
|
||||
|
||||
CpuCacheMgr::~CpuCacheMgr() {
|
||||
ConfigMgr::GetInstance().Detach("cache.cache_size", this);
|
||||
}
|
||||
|
||||
DataObjPtr
|
||||
CpuCacheMgr::GetItem(const std::string& key) {
|
||||
auto ret = CacheMgr<DataObjPtr>::GetItem(key);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
CpuCacheMgr::ConfigUpdate(const std::string& name) {
|
||||
// SetCapacity(config.cache.cache_size());
|
||||
SetCapacity(config.cache.cache_size());
|
||||
}
|
||||
|
||||
} // namespace cache
|
||||
|
|
|
@ -16,20 +16,24 @@
|
|||
|
||||
#include "cache/CacheMgr.h"
|
||||
#include "cache/DataObj.h"
|
||||
#include "config/ConfigMgr.h"
|
||||
#include "value/config/ConfigMgr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
class CpuCacheMgr : public CacheMgr<DataObjPtr>, public ConfigObserver {
|
||||
public:
|
||||
static CpuCacheMgr&
|
||||
GetInstance();
|
||||
|
||||
private:
|
||||
CpuCacheMgr();
|
||||
|
||||
~CpuCacheMgr();
|
||||
|
||||
public:
|
||||
static CpuCacheMgr&
|
||||
GetInstance();
|
||||
DataObjPtr
|
||||
GetItem(const std::string& key) override;
|
||||
|
||||
public:
|
||||
void
|
||||
|
|
|
@ -20,8 +20,6 @@ class DataObj {
|
|||
public:
|
||||
virtual int64_t
|
||||
Size() = 0;
|
||||
|
||||
public:
|
||||
virtual ~DataObj() = default;
|
||||
};
|
||||
|
||||
|
|
|
@ -10,10 +10,10 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "cache/GpuCacheMgr.h"
|
||||
#include "config/ServerConfig.h"
|
||||
#include "utils/Log.h"
|
||||
#include "value/config/ServerConfig.h"
|
||||
|
||||
// #include <fiu/fiu-local.h>
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include "cache/CacheMgr.h"
|
||||
#include "cache/DataObj.h"
|
||||
#include "config/ConfigMgr.h"
|
||||
#include "value/config/ConfigMgr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
#-------------------------------------------------------------------------------
|
||||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# library
|
||||
set( CONFIG_SRCS ConfigMgr.h
|
||||
ConfigMgr.cpp
|
||||
ConfigType.h
|
||||
ConfigType.cpp
|
||||
ServerConfig.h
|
||||
ServerConfig.cpp
|
||||
)
|
||||
|
||||
set( CONFIG_LIBS yaml-cpp
|
||||
)
|
||||
|
||||
create_library(
|
||||
TARGET config
|
||||
SRCS ${CONFIG_SRCS}
|
||||
LIBS ${CONFIG_LIBS}
|
||||
)
|
||||
|
|
@ -1,219 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <yaml-cpp/yaml.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
#include "config/ConfigMgr.h"
|
||||
#include "config/ServerConfig.h"
|
||||
|
||||
namespace {
|
||||
const int64_t MB = (1024ll * 1024);
|
||||
const int64_t GB = (1024ll * 1024 * 1024);
|
||||
|
||||
void
|
||||
Flatten(const YAML::Node& node, std::unordered_map<std::string, std::string>& target, const std::string& prefix) {
|
||||
for (auto& it : node) {
|
||||
auto key = prefix.empty() ? it.first.as<std::string>() : prefix + "." + it.first.as<std::string>();
|
||||
switch (it.second.Type()) {
|
||||
case YAML::NodeType::Null: {
|
||||
target[key] = "";
|
||||
break;
|
||||
}
|
||||
case YAML::NodeType::Scalar: {
|
||||
target[key] = it.second.as<std::string>();
|
||||
break;
|
||||
}
|
||||
case YAML::NodeType::Sequence: {
|
||||
std::string value;
|
||||
for (auto& sub : it.second) value += sub.as<std::string>() + ",";
|
||||
target[key] = value;
|
||||
break;
|
||||
}
|
||||
case YAML::NodeType::Map: {
|
||||
Flatten(it.second, target, key);
|
||||
break;
|
||||
}
|
||||
case YAML::NodeType::Undefined: {
|
||||
throw "Unexpected";
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThrowIfNotSuccess(const milvus::ConfigStatus& cs) {
|
||||
if (cs.set_return != milvus::SetReturn::SUCCESS) {
|
||||
throw cs;
|
||||
}
|
||||
}
|
||||
|
||||
}; // namespace
|
||||
|
||||
namespace milvus {
|
||||
|
||||
ConfigMgr ConfigMgr::instance;
|
||||
|
||||
ConfigMgr::ConfigMgr() {
|
||||
config_list_ = {
|
||||
/* general */
|
||||
{"timezone", CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},
|
||||
|
||||
/* network */
|
||||
{"network.address",
|
||||
CreateStringConfig("network.address", false, &config.network.address.value, "0.0.0.0", nullptr, nullptr)},
|
||||
{"network.port",
|
||||
CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, 19530, nullptr, nullptr)},
|
||||
|
||||
/* pulsar */
|
||||
{"pulsar.address",
|
||||
CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, "localhost", nullptr, nullptr)},
|
||||
{"pulsar.port",
|
||||
CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, 6650, nullptr, nullptr)},
|
||||
|
||||
/* log */
|
||||
{"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)},
|
||||
{"logs.trace.enable",
|
||||
CreateBoolConfig("logs.trace.enable", false, &config.logs.trace.enable.value, true, nullptr, nullptr)},
|
||||
{"logs.path",
|
||||
CreateStringConfig("logs.path", false, &config.logs.path.value, "/var/lib/milvus/logs", nullptr, nullptr)},
|
||||
{"logs.max_log_file_size", CreateSizeConfig("logs.max_log_file_size", false, 512 * MB, 4096 * MB,
|
||||
&config.logs.max_log_file_size.value, 1024 * MB, nullptr, nullptr)},
|
||||
{"logs.log_rotate_num", CreateIntegerConfig("logs.log_rotate_num", false, 0, 1024,
|
||||
&config.logs.log_rotate_num.value, 0, nullptr, nullptr)},
|
||||
|
||||
/* tracing */
|
||||
{"tracing.json_config_path", CreateStringConfig("tracing.json_config_path", false,
|
||||
&config.tracing.json_config_path.value, "", nullptr, nullptr)},
|
||||
|
||||
/* invisible */
|
||||
/* engine */
|
||||
{"engine.build_index_threshold",
|
||||
CreateIntegerConfig("engine.build_index_threshold", false, 0, std::numeric_limits<int64_t>::max(),
|
||||
&config.engine.build_index_threshold.value, 4096, nullptr, nullptr)},
|
||||
{"engine.search_combine_nq",
|
||||
CreateIntegerConfig("engine.search_combine_nq", true, 0, std::numeric_limits<int64_t>::max(),
|
||||
&config.engine.search_combine_nq.value, 64, nullptr, nullptr)},
|
||||
{"engine.use_blas_threshold",
|
||||
CreateIntegerConfig("engine.use_blas_threshold", true, 0, std::numeric_limits<int64_t>::max(),
|
||||
&config.engine.use_blas_threshold.value, 1100, nullptr, nullptr)},
|
||||
{"engine.omp_thread_num",
|
||||
CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits<int64_t>::max(),
|
||||
&config.engine.omp_thread_num.value, 0, nullptr, nullptr)},
|
||||
{"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value,
|
||||
SimdType::AUTO, nullptr, nullptr)},
|
||||
};
|
||||
}
|
||||
|
||||
void
|
||||
ConfigMgr::Init() {
|
||||
std::lock_guard<std::mutex> lock(GetConfigMutex());
|
||||
for (auto& kv : config_list_) {
|
||||
kv.second->Init();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ConfigMgr::Load(const std::string& path) {
|
||||
/* load from milvus.yaml */
|
||||
auto yaml = YAML::LoadFile(path);
|
||||
/* make it flattened */
|
||||
std::unordered_map<std::string, std::string> flattened;
|
||||
// auto proxy_yaml = yaml["porxy"];
|
||||
auto other_yaml = YAML::Node{};
|
||||
other_yaml["pulsar"] = yaml["pulsar"];
|
||||
Flatten(yaml["proxy"], flattened, "");
|
||||
Flatten(other_yaml, flattened, "");
|
||||
// Flatten(yaml["proxy"], flattened, "");
|
||||
/* update config */
|
||||
for (auto& it : flattened) Set(it.first, it.second, false);
|
||||
}
|
||||
|
||||
void
|
||||
ConfigMgr::Set(const std::string& name, const std::string& value, bool update) {
|
||||
std::cout << "InSet Config " << name << std::endl;
|
||||
if (config_list_.find(name) == config_list_.end()) {
|
||||
std::cout << "Config " << name << " not found!" << std::endl;
|
||||
return;
|
||||
}
|
||||
try {
|
||||
auto& config = config_list_.at(name);
|
||||
std::unique_lock<std::mutex> lock(GetConfigMutex());
|
||||
/* update=false when loading from config file */
|
||||
if (not update) {
|
||||
ThrowIfNotSuccess(config->Set(value, update));
|
||||
} else if (config->modifiable_) {
|
||||
/* set manually */
|
||||
ThrowIfNotSuccess(config->Set(value, update));
|
||||
lock.unlock();
|
||||
Notify(name);
|
||||
} else {
|
||||
throw ConfigStatus(SetReturn::IMMUTABLE, "Config " + name + " is not modifiable");
|
||||
}
|
||||
} catch (ConfigStatus& cs) {
|
||||
throw cs;
|
||||
} catch (...) {
|
||||
throw "Config " + name + " not found.";
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
ConfigMgr::Get(const std::string& name) const {
|
||||
try {
|
||||
auto& config = config_list_.at(name);
|
||||
std::lock_guard<std::mutex> lock(GetConfigMutex());
|
||||
return config->Get();
|
||||
} catch (...) {
|
||||
throw "Config " + name + " not found.";
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
ConfigMgr::Dump() const {
|
||||
std::stringstream ss;
|
||||
for (auto& kv : config_list_) {
|
||||
auto& config = kv.second;
|
||||
ss << config->name_ << ": " << config->Get() << std::endl;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void
|
||||
ConfigMgr::Attach(const std::string& name, ConfigObserver* observer) {
|
||||
std::lock_guard<std::mutex> lock(observer_mutex_);
|
||||
observers_[name].push_back(observer);
|
||||
}
|
||||
|
||||
void
|
||||
ConfigMgr::Detach(const std::string& name, ConfigObserver* observer) {
|
||||
std::lock_guard<std::mutex> lock(observer_mutex_);
|
||||
if (observers_.find(name) == observers_.end())
|
||||
return;
|
||||
auto& ob_list = observers_[name];
|
||||
ob_list.remove(observer);
|
||||
}
|
||||
|
||||
void
|
||||
ConfigMgr::Notify(const std::string& name) {
|
||||
std::lock_guard<std::mutex> lock(observer_mutex_);
|
||||
if (observers_.find(name) == observers_.end())
|
||||
return;
|
||||
auto& ob_list = observers_[name];
|
||||
for (auto& ob : ob_list) {
|
||||
ob->ConfigUpdate(name);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus
|
|
@ -1,556 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "config/ConfigType.h"
|
||||
#include "config/ServerConfig.h"
|
||||
|
||||
#include <strings.h>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace {
|
||||
std::unordered_map<std::string, int64_t> BYTE_UNITS = {
|
||||
{"b", 1},
|
||||
{"k", 1024},
|
||||
{"m", 1024 * 1024},
|
||||
{"g", 1024 * 1024 * 1024},
|
||||
};
|
||||
|
||||
bool
|
||||
is_integer(const std::string& s) {
|
||||
if (not s.empty() && (std::isdigit(s[0]) || s[0] == '-')) {
|
||||
auto ss = s.substr(1);
|
||||
return std::find_if(ss.begin(), ss.end(), [](unsigned char c) { return !std::isdigit(c); }) == ss.end();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
is_number(const std::string& s) {
|
||||
return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isdigit(c); }) == s.end();
|
||||
}
|
||||
|
||||
bool
|
||||
is_alpha(const std::string& s) {
|
||||
return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isalpha(c); }) == s.end();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool
|
||||
boundary_check(T val, T lower_bound, T upper_bound) {
|
||||
return lower_bound <= val && val <= upper_bound;
|
||||
}
|
||||
|
||||
bool
|
||||
parse_bool(const std::string& str, std::string& err) {
|
||||
if (!strcasecmp(str.c_str(), "true"))
|
||||
return true;
|
||||
else if (!strcasecmp(str.c_str(), "false"))
|
||||
return false;
|
||||
else
|
||||
err = "The specified value must be true or false";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string
|
||||
str_tolower(std::string s) {
|
||||
std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
return s;
|
||||
}
|
||||
|
||||
int64_t
|
||||
parse_bytes(const std::string& str, std::string& err) {
|
||||
try {
|
||||
if (str.find_first_of('-') != std::string::npos) {
|
||||
std::stringstream ss;
|
||||
ss << "The specified value for memory (" << str << ") should be a positive integer.";
|
||||
err = ss.str();
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string s = str;
|
||||
if (is_number(s))
|
||||
return std::stoll(s);
|
||||
if (s.length() == 0)
|
||||
return 0;
|
||||
|
||||
auto last_two = s.substr(s.length() - 2, 2);
|
||||
auto last_one = s.substr(s.length() - 1);
|
||||
if (is_alpha(last_two) && is_alpha(last_one))
|
||||
if (last_one == "b" or last_one == "B")
|
||||
s = s.substr(0, s.length() - 1);
|
||||
auto& units = BYTE_UNITS;
|
||||
auto suffix = str_tolower(s.substr(s.length() - 1));
|
||||
|
||||
std::string digits_part;
|
||||
if (is_number(suffix)) {
|
||||
digits_part = s;
|
||||
suffix = 'b';
|
||||
} else {
|
||||
digits_part = s.substr(0, s.length() - 1);
|
||||
}
|
||||
|
||||
if (is_number(digits_part) && (units.find(suffix) != units.end() || is_number(suffix))) {
|
||||
auto digits = std::stoll(digits_part);
|
||||
return digits * units[suffix];
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "The specified value for memory (" << str << ") should specify the units."
|
||||
<< "The postfix should be one of the `b` `k` `m` `g` characters";
|
||||
err = ss.str();
|
||||
}
|
||||
} catch (...) {
|
||||
err = "Unknown error happened on parse bytes.";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Use (void) to silent unused warnings.
|
||||
#define assertm(exp, msg) assert(((void)msg, exp))
|
||||
|
||||
namespace milvus {
|
||||
|
||||
std::vector<std::string>
|
||||
OptionValue(const configEnum& ce) {
|
||||
std::vector<std::string> ret;
|
||||
for (auto& e : ce) {
|
||||
ret.emplace_back(e.first);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
BaseConfig::BaseConfig(const char* name, const char* alias, bool modifiable)
|
||||
: name_(name), alias_(alias), modifiable_(modifiable) {
|
||||
}
|
||||
|
||||
void
|
||||
BaseConfig::Init() {
|
||||
assertm(not inited_, "already initialized");
|
||||
inited_ = true;
|
||||
}
|
||||
|
||||
BoolConfig::BoolConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
bool* config,
|
||||
bool default_value,
|
||||
std::function<bool(bool val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(bool val, bool prev, std::string& err)> update_fn)
|
||||
: BaseConfig(name, alias, modifiable),
|
||||
config_(config),
|
||||
default_value_(default_value),
|
||||
is_valid_fn_(std::move(is_valid_fn)),
|
||||
update_fn_(std::move(update_fn)) {
|
||||
}
|
||||
|
||||
void
|
||||
BoolConfig::Init() {
|
||||
BaseConfig::Init();
|
||||
assert(config_ != nullptr);
|
||||
*config_ = default_value_;
|
||||
}
|
||||
|
||||
ConfigStatus
|
||||
BoolConfig::Set(const std::string& val, bool update) {
|
||||
assertm(inited_, "uninitialized");
|
||||
try {
|
||||
if (update and not modifiable_) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << " is immutable.";
|
||||
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
|
||||
}
|
||||
|
||||
std::string err;
|
||||
bool value = parse_bool(val, err);
|
||||
if (not err.empty())
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
|
||||
if (is_valid_fn_ && not is_valid_fn_(value, err))
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
|
||||
bool prev = *config_;
|
||||
*config_ = value;
|
||||
if (update && update_fn_ && not update_fn_(value, prev, err)) {
|
||||
*config_ = prev;
|
||||
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
|
||||
}
|
||||
|
||||
return ConfigStatus(SetReturn::SUCCESS, "");
|
||||
} catch (std::exception& e) {
|
||||
return ConfigStatus(SetReturn::EXCEPTION, e.what());
|
||||
} catch (...) {
|
||||
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
BoolConfig::Get() {
|
||||
assertm(inited_, "uninitialized");
|
||||
return *config_ ? "true" : "false";
|
||||
}
|
||||
|
||||
StringConfig::StringConfig(
|
||||
const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
std::string* config,
|
||||
const char* default_value,
|
||||
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn)
|
||||
: BaseConfig(name, alias, modifiable),
|
||||
config_(config),
|
||||
default_value_(default_value),
|
||||
is_valid_fn_(std::move(is_valid_fn)),
|
||||
update_fn_(std::move(update_fn)) {
|
||||
}
|
||||
|
||||
void
|
||||
StringConfig::Init() {
|
||||
BaseConfig::Init();
|
||||
assert(config_ != nullptr);
|
||||
*config_ = default_value_;
|
||||
}
|
||||
|
||||
ConfigStatus
|
||||
StringConfig::Set(const std::string& val, bool update) {
|
||||
assertm(inited_, "uninitialized");
|
||||
try {
|
||||
if (update and not modifiable_) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << " is immutable.";
|
||||
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
|
||||
}
|
||||
|
||||
std::string err;
|
||||
if (is_valid_fn_ && not is_valid_fn_(val, err))
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
|
||||
std::string prev = *config_;
|
||||
*config_ = val;
|
||||
if (update && update_fn_ && not update_fn_(val, prev, err)) {
|
||||
*config_ = prev;
|
||||
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
|
||||
}
|
||||
|
||||
return ConfigStatus(SetReturn::SUCCESS, "");
|
||||
} catch (std::exception& e) {
|
||||
return ConfigStatus(SetReturn::EXCEPTION, e.what());
|
||||
} catch (...) {
|
||||
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
StringConfig::Get() {
|
||||
assertm(inited_, "uninitialized");
|
||||
return *config_;
|
||||
}
|
||||
|
||||
EnumConfig::EnumConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
configEnum* enumd,
|
||||
int64_t* config,
|
||||
int64_t default_value,
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
|
||||
: BaseConfig(name, alias, modifiable),
|
||||
config_(config),
|
||||
enum_value_(enumd),
|
||||
default_value_(default_value),
|
||||
is_valid_fn_(std::move(is_valid_fn)),
|
||||
update_fn_(std::move(update_fn)) {
|
||||
}
|
||||
|
||||
void
|
||||
EnumConfig::Init() {
|
||||
BaseConfig::Init();
|
||||
assert(enum_value_ != nullptr);
|
||||
assertm(not enum_value_->empty(), "enum value empty");
|
||||
assert(config_ != nullptr);
|
||||
*config_ = default_value_;
|
||||
}
|
||||
|
||||
ConfigStatus
|
||||
EnumConfig::Set(const std::string& val, bool update) {
|
||||
assertm(inited_, "uninitialized");
|
||||
try {
|
||||
if (update and not modifiable_) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << " is immutable.";
|
||||
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
|
||||
}
|
||||
|
||||
if (enum_value_->find(val) == enum_value_->end()) {
|
||||
auto option_values = OptionValue(*enum_value_);
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << "(" << val << ") must be one of following: ";
|
||||
for (size_t i = 0; i < option_values.size() - 1; ++i) {
|
||||
ss << option_values[i] << ", ";
|
||||
}
|
||||
ss << option_values.back() << ".";
|
||||
return ConfigStatus(SetReturn::ENUM_VALUE_NOTFOUND, ss.str());
|
||||
}
|
||||
|
||||
int64_t value = enum_value_->at(val);
|
||||
std::string err;
|
||||
if (is_valid_fn_ && not is_valid_fn_(value, err)) {
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
}
|
||||
|
||||
int64_t prev = *config_;
|
||||
*config_ = value;
|
||||
if (update && update_fn_ && not update_fn_(value, prev, err)) {
|
||||
*config_ = prev;
|
||||
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
|
||||
}
|
||||
|
||||
return ConfigStatus(SetReturn::SUCCESS, "");
|
||||
} catch (std::exception& e) {
|
||||
return ConfigStatus(SetReturn::EXCEPTION, e.what());
|
||||
} catch (...) {
|
||||
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
EnumConfig::Get() {
|
||||
assertm(inited_, "uninitialized");
|
||||
for (auto& it : *enum_value_) {
|
||||
if (*config_ == it.second) {
|
||||
return it.first;
|
||||
}
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
IntegerConfig::IntegerConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
int64_t lower_bound,
|
||||
int64_t upper_bound,
|
||||
int64_t* config,
|
||||
int64_t default_value,
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
|
||||
: BaseConfig(name, alias, modifiable),
|
||||
config_(config),
|
||||
lower_bound_(lower_bound),
|
||||
upper_bound_(upper_bound),
|
||||
default_value_(default_value),
|
||||
is_valid_fn_(std::move(is_valid_fn)),
|
||||
update_fn_(std::move(update_fn)) {
|
||||
}
|
||||
|
||||
void
|
||||
IntegerConfig::Init() {
|
||||
BaseConfig::Init();
|
||||
assert(config_ != nullptr);
|
||||
*config_ = default_value_;
|
||||
}
|
||||
|
||||
ConfigStatus
|
||||
IntegerConfig::Set(const std::string& val, bool update) {
|
||||
assertm(inited_, "uninitialized");
|
||||
try {
|
||||
if (update and not modifiable_) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << " is immutable.";
|
||||
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
|
||||
}
|
||||
|
||||
if (not is_integer(val)) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << "(" << val << ") must be a integer.";
|
||||
return ConfigStatus(SetReturn::INVALID, ss.str());
|
||||
}
|
||||
|
||||
int64_t value = std::stoll(val);
|
||||
if (not boundary_check<int64_t>(value, lower_bound_, upper_bound_)) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_
|
||||
<< "].";
|
||||
return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
|
||||
}
|
||||
|
||||
std::string err;
|
||||
if (is_valid_fn_ && not is_valid_fn_(value, err))
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
|
||||
int64_t prev = *config_;
|
||||
*config_ = value;
|
||||
if (update && update_fn_ && not update_fn_(value, prev, err)) {
|
||||
*config_ = prev;
|
||||
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
|
||||
}
|
||||
|
||||
return ConfigStatus(SetReturn::SUCCESS, "");
|
||||
} catch (std::exception& e) {
|
||||
return ConfigStatus(SetReturn::EXCEPTION, e.what());
|
||||
} catch (...) {
|
||||
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
IntegerConfig::Get() {
|
||||
assertm(inited_, "uninitialized");
|
||||
return std::to_string(*config_);
|
||||
}
|
||||
|
||||
FloatingConfig::FloatingConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
double lower_bound,
|
||||
double upper_bound,
|
||||
double* config,
|
||||
double default_value,
|
||||
std::function<bool(double val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(double val, double prev, std::string& err)> update_fn)
|
||||
: BaseConfig(name, alias, modifiable),
|
||||
config_(config),
|
||||
lower_bound_(lower_bound),
|
||||
upper_bound_(upper_bound),
|
||||
default_value_(default_value),
|
||||
is_valid_fn_(std::move(is_valid_fn)),
|
||||
update_fn_(std::move(update_fn)) {
|
||||
}
|
||||
|
||||
void
|
||||
FloatingConfig::Init() {
|
||||
BaseConfig::Init();
|
||||
assert(config_ != nullptr);
|
||||
*config_ = default_value_;
|
||||
}
|
||||
|
||||
ConfigStatus
|
||||
FloatingConfig::Set(const std::string& val, bool update) {
|
||||
assertm(inited_, "uninitialized");
|
||||
try {
|
||||
if (update and not modifiable_) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << " is immutable.";
|
||||
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
|
||||
}
|
||||
|
||||
double value = std::stod(val);
|
||||
if (not boundary_check<double>(value, lower_bound_, upper_bound_)) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_
|
||||
<< "].";
|
||||
return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
|
||||
}
|
||||
|
||||
std::string err;
|
||||
if (is_valid_fn_ && not is_valid_fn_(value, err))
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
|
||||
double prev = *config_;
|
||||
*config_ = value;
|
||||
if (update && update_fn_ && not update_fn_(value, prev, err)) {
|
||||
*config_ = prev;
|
||||
|
||||
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
|
||||
}
|
||||
|
||||
return ConfigStatus(SetReturn::SUCCESS, "");
|
||||
} catch (std::exception& e) {
|
||||
return ConfigStatus(SetReturn::EXCEPTION, e.what());
|
||||
} catch (...) {
|
||||
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
FloatingConfig::Get() {
|
||||
assertm(inited_, "uninitialized");
|
||||
return std::to_string(*config_);
|
||||
}
|
||||
|
||||
SizeConfig::SizeConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
int64_t lower_bound,
|
||||
int64_t upper_bound,
|
||||
int64_t* config,
|
||||
int64_t default_value,
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
|
||||
: BaseConfig(name, alias, modifiable),
|
||||
config_(config),
|
||||
lower_bound_(lower_bound),
|
||||
upper_bound_(upper_bound),
|
||||
default_value_(default_value),
|
||||
is_valid_fn_(std::move(is_valid_fn)),
|
||||
update_fn_(std::move(update_fn)) {
|
||||
}
|
||||
|
||||
void
|
||||
SizeConfig::Init() {
|
||||
BaseConfig::Init();
|
||||
assert(config_ != nullptr);
|
||||
*config_ = default_value_;
|
||||
}
|
||||
|
||||
ConfigStatus
|
||||
SizeConfig::Set(const std::string& val, bool update) {
|
||||
assertm(inited_, "uninitialized");
|
||||
try {
|
||||
if (update and not modifiable_) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << " is immutable.";
|
||||
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
|
||||
}
|
||||
|
||||
std::string err;
|
||||
int64_t value = parse_bytes(val, err);
|
||||
if (not err.empty()) {
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
}
|
||||
|
||||
if (not boundary_check<int64_t>(value, lower_bound_, upper_bound_)) {
|
||||
std::stringstream ss;
|
||||
ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << " Byte, " << upper_bound_
|
||||
<< " Byte].";
|
||||
return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
|
||||
}
|
||||
|
||||
if (is_valid_fn_ && not is_valid_fn_(value, err)) {
|
||||
return ConfigStatus(SetReturn::INVALID, err);
|
||||
}
|
||||
|
||||
int64_t prev = *config_;
|
||||
*config_ = value;
|
||||
if (update && update_fn_ && not update_fn_(value, prev, err)) {
|
||||
*config_ = prev;
|
||||
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
|
||||
}
|
||||
|
||||
return ConfigStatus(SetReturn::SUCCESS, "");
|
||||
} catch (std::exception& e) {
|
||||
return ConfigStatus(SetReturn::EXCEPTION, e.what());
|
||||
} catch (...) {
|
||||
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
SizeConfig::Get() {
|
||||
assertm(inited_, "uninitialized");
|
||||
return std::to_string(*config_);
|
||||
}
|
||||
|
||||
} // namespace milvus
|
|
@ -1,265 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
using configEnum = const std::unordered_map<std::string, int64_t>;
|
||||
std::vector<std::string>
|
||||
OptionValue(const configEnum& ce);
|
||||
|
||||
enum SetReturn {
|
||||
SUCCESS = 1,
|
||||
IMMUTABLE,
|
||||
ENUM_VALUE_NOTFOUND,
|
||||
INVALID,
|
||||
OUT_OF_RANGE,
|
||||
UPDATE_FAILURE,
|
||||
EXCEPTION,
|
||||
UNEXPECTED,
|
||||
};
|
||||
|
||||
struct ConfigStatus {
|
||||
ConfigStatus(SetReturn sr, std::string msg) : set_return(sr), message(std::move(msg)) {
|
||||
}
|
||||
SetReturn set_return;
|
||||
std::string message;
|
||||
};
|
||||
|
||||
class BaseConfig {
|
||||
public:
|
||||
BaseConfig(const char* name, const char* alias, bool modifiable);
|
||||
virtual ~BaseConfig() = default;
|
||||
|
||||
public:
|
||||
bool inited_ = false;
|
||||
const char* name_;
|
||||
const char* alias_;
|
||||
const bool modifiable_;
|
||||
|
||||
public:
|
||||
virtual void
|
||||
Init();
|
||||
|
||||
virtual ConfigStatus
|
||||
Set(const std::string& value, bool update) = 0;
|
||||
|
||||
virtual std::string
|
||||
Get() = 0;
|
||||
};
|
||||
using BaseConfigPtr = std::shared_ptr<BaseConfig>;
|
||||
|
||||
class BoolConfig : public BaseConfig {
|
||||
public:
|
||||
BoolConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
bool* config,
|
||||
bool default_value,
|
||||
std::function<bool(bool val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(bool val, bool prev, std::string& err)> update_fn);
|
||||
|
||||
private:
|
||||
bool* config_;
|
||||
const bool default_value_;
|
||||
std::function<bool(bool val, std::string& err)> is_valid_fn_;
|
||||
std::function<bool(bool val, bool prev, std::string& err)> update_fn_;
|
||||
|
||||
public:
|
||||
void
|
||||
Init() override;
|
||||
|
||||
ConfigStatus
|
||||
Set(const std::string& value, bool update) override;
|
||||
|
||||
std::string
|
||||
Get() override;
|
||||
};
|
||||
|
||||
class StringConfig : public BaseConfig {
|
||||
public:
|
||||
StringConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
std::string* config,
|
||||
const char* default_value,
|
||||
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn);
|
||||
|
||||
private:
|
||||
std::string* config_;
|
||||
const char* default_value_;
|
||||
std::function<bool(const std::string& val, std::string& err)> is_valid_fn_;
|
||||
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn_;
|
||||
|
||||
public:
|
||||
void
|
||||
Init() override;
|
||||
|
||||
ConfigStatus
|
||||
Set(const std::string& value, bool update) override;
|
||||
|
||||
std::string
|
||||
Get() override;
|
||||
};
|
||||
|
||||
class EnumConfig : public BaseConfig {
|
||||
public:
|
||||
EnumConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
configEnum* enumd,
|
||||
int64_t* config,
|
||||
int64_t default_value,
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
|
||||
|
||||
private:
|
||||
int64_t* config_;
|
||||
configEnum* enum_value_;
|
||||
const int64_t default_value_;
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;
|
||||
|
||||
public:
|
||||
void
|
||||
Init() override;
|
||||
|
||||
ConfigStatus
|
||||
Set(const std::string& value, bool update) override;
|
||||
|
||||
std::string
|
||||
Get() override;
|
||||
};
|
||||
|
||||
class IntegerConfig : public BaseConfig {
|
||||
public:
|
||||
IntegerConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
int64_t lower_bound,
|
||||
int64_t upper_bound,
|
||||
int64_t* config,
|
||||
int64_t default_value,
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
|
||||
|
||||
private:
|
||||
int64_t* config_;
|
||||
int64_t lower_bound_;
|
||||
int64_t upper_bound_;
|
||||
const int64_t default_value_;
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;
|
||||
|
||||
public:
|
||||
void
|
||||
Init() override;
|
||||
|
||||
ConfigStatus
|
||||
Set(const std::string& value, bool update) override;
|
||||
|
||||
std::string
|
||||
Get() override;
|
||||
};
|
||||
|
||||
class FloatingConfig : public BaseConfig {
|
||||
public:
|
||||
FloatingConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
double lower_bound,
|
||||
double upper_bound,
|
||||
double* config,
|
||||
double default_value,
|
||||
std::function<bool(double val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(double val, double prev, std::string& err)> update_fn);
|
||||
|
||||
private:
|
||||
double* config_;
|
||||
double lower_bound_;
|
||||
double upper_bound_;
|
||||
const double default_value_;
|
||||
std::function<bool(double val, std::string& err)> is_valid_fn_;
|
||||
std::function<bool(double val, double prev, std::string& err)> update_fn_;
|
||||
|
||||
public:
|
||||
void
|
||||
Init() override;
|
||||
|
||||
ConfigStatus
|
||||
Set(const std::string& value, bool update) override;
|
||||
|
||||
std::string
|
||||
Get() override;
|
||||
};
|
||||
|
||||
class SizeConfig : public BaseConfig {
|
||||
public:
|
||||
SizeConfig(const char* name,
|
||||
const char* alias,
|
||||
bool modifiable,
|
||||
int64_t lower_bound,
|
||||
int64_t upper_bound,
|
||||
int64_t* config,
|
||||
int64_t default_value,
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
|
||||
|
||||
private:
|
||||
int64_t* config_;
|
||||
int64_t lower_bound_;
|
||||
int64_t upper_bound_;
|
||||
const int64_t default_value_;
|
||||
std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
|
||||
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;
|
||||
|
||||
public:
|
||||
void
|
||||
Init() override;
|
||||
|
||||
ConfigStatus
|
||||
Set(const std::string& value, bool update) override;
|
||||
|
||||
std::string
|
||||
Get() override;
|
||||
};
|
||||
|
||||
#define CreateBoolConfig(name, modifiable, config_addr, default, is_valid, update) \
|
||||
std::make_shared<BoolConfig>(name, nullptr, modifiable, config_addr, (default), is_valid, update)
|
||||
|
||||
#define CreateStringConfig(name, modifiable, config_addr, default, is_valid, update) \
|
||||
std::make_shared<StringConfig>(name, nullptr, modifiable, config_addr, (default), is_valid, update)
|
||||
|
||||
#define CreateEnumConfig(name, modifiable, enumd, config_addr, default, is_valid, update) \
|
||||
std::make_shared<EnumConfig>(name, nullptr, modifiable, enumd, config_addr, (default), is_valid, update)
|
||||
|
||||
#define CreateIntegerConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
|
||||
std::make_shared<IntegerConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
|
||||
is_valid, update)
|
||||
|
||||
#define CreateFloatingConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
|
||||
std::make_shared<FloatingConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
|
||||
is_valid, update)
|
||||
|
||||
#define CreateSizeConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
|
||||
std::make_shared<SizeConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
|
||||
is_valid, update)
|
||||
|
||||
} // namespace milvus
|
|
@ -1,493 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
|
||||
#include "config/ServerConfig.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
#define _MODIFIABLE (true)
|
||||
#define _IMMUTABLE (false)
|
||||
|
||||
template <typename T>
|
||||
class Utils {
|
||||
public:
|
||||
bool
|
||||
validate_fn(const T& value, std::string& err) {
|
||||
validate_value = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
update_fn(const T& value, const T& prev, std::string& err) {
|
||||
new_value = value;
|
||||
prev_value = prev;
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
T validate_value;
|
||||
T new_value;
|
||||
T prev_value;
|
||||
};
|
||||
|
||||
/* ValidBoolConfigTest */
|
||||
class ValidBoolConfigTest : public testing::Test, public Utils<bool> {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(ValidBoolConfigTest, init_load_update_get_test) {
|
||||
auto validate = std::bind(&ValidBoolConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
|
||||
auto update = std::bind(&ValidBoolConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3);
|
||||
|
||||
bool bool_value = true;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, false, validate, update);
|
||||
ASSERT_EQ(bool_value, true);
|
||||
ASSERT_EQ(bool_config->modifiable_, true);
|
||||
|
||||
bool_config->Init();
|
||||
ASSERT_EQ(bool_value, false);
|
||||
ASSERT_EQ(bool_config->Get(), "false");
|
||||
|
||||
{
|
||||
// now `bool_value` is `false`, calling Set(update=false) to set it to `true`, but not notify update_fn()
|
||||
validate_value = false;
|
||||
new_value = false;
|
||||
prev_value = true;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = bool_config->Set("true", false);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(bool_value, true);
|
||||
EXPECT_EQ(bool_config->Get(), "true");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, true);
|
||||
// expect not change
|
||||
EXPECT_EQ(new_value, false);
|
||||
EXPECT_EQ(prev_value, true);
|
||||
}
|
||||
|
||||
{
|
||||
// now `bool_value` is `true`, calling Set(update=true) to set it to `false`, will notify update_fn()
|
||||
validate_value = true;
|
||||
new_value = true;
|
||||
prev_value = false;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = bool_config->Set("false", true);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(bool_value, false);
|
||||
EXPECT_EQ(bool_config->Get(), "false");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, false);
|
||||
EXPECT_EQ(new_value, false);
|
||||
EXPECT_EQ(prev_value, true);
|
||||
}
|
||||
}
|
||||
|
||||
/* ValidStringConfigTest */
|
||||
class ValidStringConfigTest : public testing::Test, public Utils<std::string> {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(ValidStringConfigTest, init_load_update_get_test) {
|
||||
auto validate = std::bind(&ValidStringConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
|
||||
auto update = std::bind(&ValidStringConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3);
|
||||
|
||||
std::string string_value;
|
||||
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", validate, update);
|
||||
ASSERT_EQ(string_value, "");
|
||||
ASSERT_EQ(string_config->modifiable_, true);
|
||||
|
||||
string_config->Init();
|
||||
ASSERT_EQ(string_value, "Magic");
|
||||
ASSERT_EQ(string_config->Get(), "Magic");
|
||||
|
||||
{
|
||||
// now `string_value` is `Magic`, calling Set(update=false) to set it to `cigaM`, but not notify update_fn()
|
||||
validate_value = "";
|
||||
new_value = "";
|
||||
prev_value = "";
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = string_config->Set("cigaM", false);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(string_value, "cigaM");
|
||||
EXPECT_EQ(string_config->Get(), "cigaM");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, "cigaM");
|
||||
// expect not change
|
||||
EXPECT_EQ(new_value, "");
|
||||
EXPECT_EQ(prev_value, "");
|
||||
}
|
||||
|
||||
{
|
||||
// now `string_value` is `cigaM`, calling Set(update=true) to set it to `Check`, will notify update_fn()
|
||||
validate_value = "";
|
||||
new_value = "";
|
||||
prev_value = "";
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = string_config->Set("Check", true);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(string_value, "Check");
|
||||
EXPECT_EQ(string_config->Get(), "Check");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, "Check");
|
||||
EXPECT_EQ(new_value, "Check");
|
||||
EXPECT_EQ(prev_value, "cigaM");
|
||||
}
|
||||
}
|
||||
|
||||
/* ValidIntegerConfigTest */
|
||||
class ValidIntegerConfigTest : public testing::Test, public Utils<int64_t> {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(ValidIntegerConfigTest, init_load_update_get_test) {
|
||||
auto validate = std::bind(&ValidIntegerConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
|
||||
auto update = std::bind(&ValidIntegerConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3);
|
||||
|
||||
int64_t integer_value = 0;
|
||||
auto integer_config = CreateIntegerConfig("i", _MODIFIABLE, -100, 100, &integer_value, 42, validate, update);
|
||||
ASSERT_EQ(integer_value, 0);
|
||||
ASSERT_EQ(integer_config->modifiable_, true);
|
||||
|
||||
integer_config->Init();
|
||||
ASSERT_EQ(integer_value, 42);
|
||||
ASSERT_EQ(integer_config->Get(), "42");
|
||||
|
||||
{
|
||||
// now `integer_value` is `42`, calling Set(update=false) to set it to `24`, but not notify update_fn()
|
||||
validate_value = 0;
|
||||
new_value = 0;
|
||||
prev_value = 0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = integer_config->Set("24", false);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(integer_value, 24);
|
||||
EXPECT_EQ(integer_config->Get(), "24");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, 24);
|
||||
// expect not change
|
||||
EXPECT_EQ(new_value, 0);
|
||||
EXPECT_EQ(prev_value, 0);
|
||||
}
|
||||
|
||||
{
|
||||
// now `integer_value` is `24`, calling Set(update=true) to set it to `36`, will notify update_fn()
|
||||
validate_value = 0;
|
||||
new_value = 0;
|
||||
prev_value = 0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = integer_config->Set("36", true);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(integer_value, 36);
|
||||
EXPECT_EQ(integer_config->Get(), "36");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, 36);
|
||||
EXPECT_EQ(new_value, 36);
|
||||
EXPECT_EQ(prev_value, 24);
|
||||
}
|
||||
}
|
||||
|
||||
/* ValidFloatingConfigTest */
|
||||
class ValidFloatingConfigTest : public testing::Test, public Utils<double> {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(ValidFloatingConfigTest, init_load_update_get_test) {
|
||||
auto validate =
|
||||
std::bind(&ValidFloatingConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
|
||||
auto update = std::bind(&ValidFloatingConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3);
|
||||
|
||||
double floating_value = 0.0;
|
||||
auto floating_config = CreateFloatingConfig("f", _MODIFIABLE, -10.0, 10.0, &floating_value, 3.14, validate, update);
|
||||
ASSERT_FLOAT_EQ(floating_value, 0.0);
|
||||
ASSERT_EQ(floating_config->modifiable_, true);
|
||||
|
||||
floating_config->Init();
|
||||
ASSERT_FLOAT_EQ(floating_value, 3.14);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 3.14);
|
||||
|
||||
{
|
||||
// now `floating_value` is `3.14`, calling Set(update=false) to set it to `6.22`, but not notify update_fn()
|
||||
validate_value = 0.0;
|
||||
new_value = 0.0;
|
||||
prev_value = 0.0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = floating_config->Set("6.22", false);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
ASSERT_FLOAT_EQ(floating_value, 6.22);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 6.22);
|
||||
|
||||
// expect change
|
||||
ASSERT_FLOAT_EQ(validate_value, 6.22);
|
||||
// expect not change
|
||||
ASSERT_FLOAT_EQ(new_value, 0.0);
|
||||
ASSERT_FLOAT_EQ(prev_value, 0.0);
|
||||
}
|
||||
|
||||
{
|
||||
// now `integer_value` is `6.22`, calling Set(update=true) to set it to `-3.14`, will notify update_fn()
|
||||
validate_value = 0.0;
|
||||
new_value = 0.0;
|
||||
prev_value = 0.0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = floating_config->Set("-3.14", true);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
ASSERT_FLOAT_EQ(floating_value, -3.14);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), -3.14);
|
||||
|
||||
// expect change
|
||||
ASSERT_FLOAT_EQ(validate_value, -3.14);
|
||||
ASSERT_FLOAT_EQ(new_value, -3.14);
|
||||
ASSERT_FLOAT_EQ(prev_value, 6.22);
|
||||
}
|
||||
}
|
||||
|
||||
/* ValidEnumConfigTest */
|
||||
class ValidEnumConfigTest : public testing::Test, public Utils<int64_t> {
|
||||
protected:
|
||||
};
|
||||
|
||||
// template <>
|
||||
// int64_t Utils<int64_t>::validate_value = 0;
|
||||
// template <>
|
||||
// int64_t Utils<int64_t>::new_value = 0;
|
||||
// template <>
|
||||
// int64_t Utils<int64_t>::prev_value = 0;
|
||||
|
||||
TEST_F(ValidEnumConfigTest, init_load_update_get_test) {
|
||||
auto validate = std::bind(&ValidEnumConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
|
||||
auto update = std::bind(&ValidEnumConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3);
|
||||
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
{"b", 2},
|
||||
{"c", 3},
|
||||
};
|
||||
int64_t enum_value = 0;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, validate, update);
|
||||
ASSERT_EQ(enum_value, 0);
|
||||
ASSERT_EQ(enum_config->modifiable_, true);
|
||||
|
||||
enum_config->Init();
|
||||
ASSERT_EQ(enum_value, 1);
|
||||
ASSERT_EQ(enum_config->Get(), "a");
|
||||
|
||||
{
|
||||
// now `enum_value` is `a`, calling Set(update=false) to set it to `b`, but not notify update_fn()
|
||||
validate_value = 0;
|
||||
new_value = 0;
|
||||
prev_value = 0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = enum_config->Set("b", false);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
ASSERT_EQ(enum_value, 2);
|
||||
ASSERT_EQ(enum_config->Get(), "b");
|
||||
|
||||
// expect change
|
||||
ASSERT_EQ(validate_value, 2);
|
||||
// expect not change
|
||||
ASSERT_EQ(new_value, 0);
|
||||
ASSERT_EQ(prev_value, 0);
|
||||
}
|
||||
|
||||
{
|
||||
// now `enum_value` is `b`, calling Set(update=true) to set it to `c`, will notify update_fn()
|
||||
validate_value = 0;
|
||||
new_value = 0;
|
||||
prev_value = 0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = enum_config->Set("c", true);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
ASSERT_EQ(enum_value, 3);
|
||||
ASSERT_EQ(enum_config->Get(), "c");
|
||||
|
||||
// expect change
|
||||
ASSERT_EQ(validate_value, 3);
|
||||
ASSERT_EQ(new_value, 3);
|
||||
ASSERT_EQ(prev_value, 2);
|
||||
}
|
||||
}
|
||||
|
||||
/* ValidSizeConfigTest */
|
||||
class ValidSizeConfigTest : public testing::Test, public Utils<int64_t> {
|
||||
protected:
|
||||
};
|
||||
|
||||
// template <>
|
||||
// int64_t Utils<int64_t>::validate_value = 0;
|
||||
// template <>
|
||||
// int64_t Utils<int64_t>::new_value = 0;
|
||||
// template <>
|
||||
// int64_t Utils<int64_t>::prev_value = 0;
|
||||
|
||||
TEST_F(ValidSizeConfigTest, init_load_update_get_test) {
|
||||
auto validate = std::bind(&ValidSizeConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
|
||||
auto update = std::bind(&ValidSizeConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3);
|
||||
|
||||
int64_t size_value = 0;
|
||||
auto size_config = CreateSizeConfig("i", _MODIFIABLE, 0, 1024 * 1024, &size_value, 1024, validate, update);
|
||||
ASSERT_EQ(size_value, 0);
|
||||
ASSERT_EQ(size_config->modifiable_, true);
|
||||
|
||||
size_config->Init();
|
||||
ASSERT_EQ(size_value, 1024);
|
||||
ASSERT_EQ(size_config->Get(), "1024");
|
||||
|
||||
{
|
||||
// now `size_value` is `1024`, calling Set(update=false) to set it to `4096`, but not notify update_fn()
|
||||
validate_value = 0;
|
||||
new_value = 0;
|
||||
prev_value = 0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = size_config->Set("4096", false);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(size_value, 4096);
|
||||
EXPECT_EQ(size_config->Get(), "4096");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, 4096);
|
||||
// expect not change
|
||||
EXPECT_EQ(new_value, 0);
|
||||
EXPECT_EQ(prev_value, 0);
|
||||
}
|
||||
|
||||
{
|
||||
// now `size_value` is `4096`, calling Set(update=true) to set it to `256kb`, will notify update_fn()
|
||||
validate_value = 0;
|
||||
new_value = 0;
|
||||
prev_value = 0;
|
||||
|
||||
ConfigStatus status(SetReturn::SUCCESS, "");
|
||||
status = size_config->Set("256kb", true);
|
||||
|
||||
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
|
||||
EXPECT_EQ(size_value, 256 * 1024);
|
||||
EXPECT_EQ(size_config->Get(), "262144");
|
||||
|
||||
// expect change
|
||||
EXPECT_EQ(validate_value, 262144);
|
||||
EXPECT_EQ(new_value, 262144);
|
||||
EXPECT_EQ(prev_value, 4096);
|
||||
}
|
||||
}
|
||||
|
||||
class ValidTest : public testing::Test {
|
||||
protected:
|
||||
configEnum family{
|
||||
{"ipv4", 1},
|
||||
{"ipv6", 2},
|
||||
};
|
||||
|
||||
struct Server {
|
||||
bool running = true;
|
||||
std::string hostname;
|
||||
int64_t family = 0;
|
||||
int64_t port = 0;
|
||||
double uptime = 0;
|
||||
};
|
||||
|
||||
Server server;
|
||||
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
config_list = {
|
||||
CreateBoolConfig("running", true, &server.running, true, nullptr, nullptr),
|
||||
CreateStringConfig("hostname", true, &server.hostname, "Magic", nullptr, nullptr),
|
||||
CreateEnumConfig("socket_family", false, &family, &server.family, 2, nullptr, nullptr),
|
||||
CreateIntegerConfig("port", true, 1024, 65535, &server.port, 19530, nullptr, nullptr),
|
||||
CreateFloatingConfig("uptime", true, 0, 9999.0, &server.uptime, 0, nullptr, nullptr),
|
||||
};
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
}
|
||||
|
||||
protected:
|
||||
void
|
||||
Init() {
|
||||
for (auto& config : config_list) {
|
||||
config->Init();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
Load() {
|
||||
std::unordered_map<std::string, std::string> config_file{
|
||||
{"running", "false"},
|
||||
};
|
||||
|
||||
for (auto& c : config_file) Set(c.first, c.second, false);
|
||||
}
|
||||
|
||||
void
|
||||
Set(const std::string& name, const std::string& value, bool update = true) {
|
||||
for (auto& config : config_list) {
|
||||
if (std::strcmp(name.c_str(), config->name_) == 0) {
|
||||
config->Set(value, update);
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw "Config " + name + " not found.";
|
||||
}
|
||||
|
||||
std::string
|
||||
Get(const std::string& name) {
|
||||
for (auto& config : config_list) {
|
||||
if (std::strcmp(name.c_str(), config->name_) == 0) {
|
||||
return config->Get();
|
||||
}
|
||||
}
|
||||
throw "Config " + name + " not found.";
|
||||
}
|
||||
|
||||
std::vector<BaseConfigPtr> config_list;
|
||||
};
|
||||
|
||||
} // namespace milvus
|
|
@ -1,861 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "config/ServerConfig.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
#define _MODIFIABLE (true)
|
||||
#define _IMMUTABLE (false)
|
||||
|
||||
template <typename T>
|
||||
class Utils {
|
||||
public:
|
||||
static bool
|
||||
valid_check_failure(const T& value, std::string& err) {
|
||||
err = "Value is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool
|
||||
update_failure(const T& value, const T& prev, std::string& err) {
|
||||
err = "Update is failure";
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool
|
||||
valid_check_raise_string(const T& value, std::string& err) {
|
||||
throw "string exception";
|
||||
}
|
||||
|
||||
static bool
|
||||
valid_check_raise_exception(const T& value, std::string& err) {
|
||||
throw std::bad_alloc();
|
||||
}
|
||||
};
|
||||
|
||||
/* BoolConfigTest */
|
||||
class BoolConfigTest : public testing::Test, public Utils<bool> {};
|
||||
|
||||
TEST_F(BoolConfigTest, nullptr_init_test) {
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, nullptr, true, nullptr, nullptr);
|
||||
ASSERT_DEATH(bool_config->Init(), "nullptr");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, init_twice_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
|
||||
ASSERT_DEATH(
|
||||
{
|
||||
bool_config->Init();
|
||||
bool_config->Init();
|
||||
},
|
||||
"initialized");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, non_init_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
|
||||
ASSERT_DEATH(bool_config->Set("false", true), "uninitialized");
|
||||
ASSERT_DEATH(bool_config->Get(), "uninitialized");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, immutable_update_test) {
|
||||
bool bool_value = false;
|
||||
auto bool_config = CreateBoolConfig("b", _IMMUTABLE, &bool_value, true, nullptr, nullptr);
|
||||
bool_config->Init();
|
||||
ASSERT_EQ(bool_value, true);
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = bool_config->Set("false", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
|
||||
ASSERT_EQ(bool_value, true);
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, set_invalid_value_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
|
||||
bool_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = bool_config->Set(" false", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
|
||||
status = bool_config->Set("false ", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
|
||||
status = bool_config->Set("afalse", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
|
||||
status = bool_config->Set("falsee", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
|
||||
status = bool_config->Set("abcdefg", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
|
||||
status = bool_config->Set("123456", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
|
||||
status = bool_config->Set("", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, valid_check_fail_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_failure, nullptr);
|
||||
bool_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = bool_config->Set("123456", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, update_fail_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, update_failure);
|
||||
bool_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = bool_config->Set("false", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, string_exception_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_string, nullptr);
|
||||
bool_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = bool_config->Set("false", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
}
|
||||
|
||||
TEST_F(BoolConfigTest, standard_exception_test) {
|
||||
bool bool_value;
|
||||
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_exception, nullptr);
|
||||
bool_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = bool_config->Set("false", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
|
||||
ASSERT_EQ(bool_config->Get(), "true");
|
||||
}
|
||||
|
||||
/* StringConfigTest */
|
||||
class StringConfigTest : public testing::Test, public Utils<std::string> {};
|
||||
|
||||
TEST_F(StringConfigTest, nullptr_init_test) {
|
||||
auto string_config = CreateStringConfig("s", true, nullptr, "Magic", nullptr, nullptr);
|
||||
ASSERT_DEATH(string_config->Init(), "nullptr");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, init_twice_test) {
|
||||
std::string string_value;
|
||||
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr);
|
||||
ASSERT_DEATH(
|
||||
{
|
||||
string_config->Init();
|
||||
string_config->Init();
|
||||
},
|
||||
"initialized");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, non_init_test) {
|
||||
std::string string_value;
|
||||
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr);
|
||||
ASSERT_DEATH(string_config->Set("value", true), "uninitialized");
|
||||
ASSERT_DEATH(string_config->Get(), "uninitialized");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, immutable_update_test) {
|
||||
std::string string_value;
|
||||
auto string_config = CreateStringConfig("s", _IMMUTABLE, &string_value, "Magic", nullptr, nullptr);
|
||||
string_config->Init();
|
||||
ASSERT_EQ(string_value, "Magic");
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = string_config->Set("cigaM", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
|
||||
ASSERT_EQ(string_value, "Magic");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, valid_check_fail_test) {
|
||||
std::string string_value;
|
||||
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_failure, nullptr);
|
||||
string_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = string_config->Set("123456", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(string_config->Get(), "Magic");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, update_fail_test) {
|
||||
std::string string_value;
|
||||
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, update_failure);
|
||||
string_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = string_config->Set("Mi", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
|
||||
ASSERT_EQ(string_config->Get(), "Magic");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, string_exception_test) {
|
||||
std::string string_value;
|
||||
auto string_config =
|
||||
CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_string, nullptr);
|
||||
string_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = string_config->Set("any", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
|
||||
ASSERT_EQ(string_config->Get(), "Magic");
|
||||
}
|
||||
|
||||
TEST_F(StringConfigTest, standard_exception_test) {
|
||||
std::string string_value;
|
||||
auto string_config =
|
||||
CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_exception, nullptr);
|
||||
string_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = string_config->Set("any", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
|
||||
ASSERT_EQ(string_config->Get(), "Magic");
|
||||
}
|
||||
|
||||
/* IntegerConfigTest */
|
||||
class IntegerConfigTest : public testing::Test, public Utils<int64_t> {};
|
||||
|
||||
TEST_F(IntegerConfigTest, nullptr_init_test) {
|
||||
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, nullptr, 19530, nullptr, nullptr);
|
||||
ASSERT_DEATH(integer_config->Init(), "nullptr");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, init_twice_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
|
||||
ASSERT_DEATH(
|
||||
{
|
||||
integer_config->Init();
|
||||
integer_config->Init();
|
||||
},
|
||||
"initialized");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, non_init_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
|
||||
ASSERT_DEATH(integer_config->Set("42", true), "uninitialized");
|
||||
ASSERT_DEATH(integer_config->Get(), "uninitialized");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, immutable_update_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", _IMMUTABLE, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
|
||||
integer_config->Init();
|
||||
ASSERT_EQ(integer_value, 19530);
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("2048", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
|
||||
ASSERT_EQ(integer_value, 19530);
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, set_invalid_value_test) {
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, valid_check_fail_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config =
|
||||
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_failure, nullptr);
|
||||
integer_config->Init();
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("2048", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "19530");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, update_fail_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, update_failure);
|
||||
integer_config->Init();
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("2048", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
|
||||
ASSERT_EQ(integer_config->Get(), "19530");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, string_exception_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config =
|
||||
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_string, nullptr);
|
||||
integer_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("2048", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
|
||||
ASSERT_EQ(integer_config->Get(), "19530");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, standard_exception_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config =
|
||||
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_exception, nullptr);
|
||||
integer_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("2048", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
|
||||
ASSERT_EQ(integer_config->Get(), "19530");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, out_of_range_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
|
||||
integer_config->Init();
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("1023", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_EQ(integer_config->Get(), "19530");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("65536", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_EQ(integer_config->Get(), "19530");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, invalid_bound_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", true, 100, 0, &integer_value, 50, nullptr, nullptr);
|
||||
integer_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("30", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
|
||||
TEST_F(IntegerConfigTest, invalid_format_test) {
|
||||
int64_t integer_value;
|
||||
auto integer_config = CreateIntegerConfig("i", true, 0, 100, &integer_value, 50, nullptr, nullptr);
|
||||
integer_config->Init();
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("3-0", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("30-", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("+30", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("a30", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("30a", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = integer_config->Set("3a0", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(integer_config->Get(), "50");
|
||||
}
|
||||
}
|
||||
|
||||
/* FloatingConfigTest */
|
||||
class FloatingConfigTest : public testing::Test, public Utils<double> {};
|
||||
|
||||
TEST_F(FloatingConfigTest, nullptr_init_test) {
|
||||
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, nullptr, 4.5, nullptr, nullptr);
|
||||
ASSERT_DEATH(floating_config->Init(), "nullptr");
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, init_twice_test) {
|
||||
double floating_value;
|
||||
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
|
||||
ASSERT_DEATH(
|
||||
{
|
||||
floating_config->Init();
|
||||
floating_config->Init();
|
||||
},
|
||||
"initialized");
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, non_init_test) {
|
||||
double floating_value;
|
||||
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
|
||||
ASSERT_DEATH(floating_config->Set("3.14", true), "uninitialized");
|
||||
ASSERT_DEATH(floating_config->Get(), "uninitialized");
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, immutable_update_test) {
|
||||
double floating_value;
|
||||
auto floating_config = CreateFloatingConfig("f", _IMMUTABLE, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
|
||||
floating_config->Init();
|
||||
ASSERT_FLOAT_EQ(floating_value, 4.5);
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("1.23", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, set_invalid_value_test) {
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, valid_check_fail_test) {
|
||||
double floating_value;
|
||||
auto floating_config =
|
||||
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_failure, nullptr);
|
||||
floating_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("1.23", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, update_fail_test) {
|
||||
double floating_value;
|
||||
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, update_failure);
|
||||
floating_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("1.23", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, string_exception_test) {
|
||||
double floating_value;
|
||||
auto floating_config =
|
||||
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_string, nullptr);
|
||||
floating_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("1.23", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, standard_exception_test) {
|
||||
double floating_value;
|
||||
auto floating_config =
|
||||
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr);
|
||||
floating_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("1.23", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, out_of_range_test) {
|
||||
double floating_value;
|
||||
auto floating_config =
|
||||
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr);
|
||||
floating_config->Init();
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("0.99", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("10.00", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, invalid_bound_test) {
|
||||
double floating_value;
|
||||
auto floating_config =
|
||||
CreateFloatingConfig("f", true, 9.9, 1.0, &floating_value, 4.5, valid_check_raise_exception, nullptr);
|
||||
floating_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("6.0", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
TEST_F(FloatingConfigTest, DISABLED_invalid_format_test) {
|
||||
double floating_value;
|
||||
auto floating_config = CreateFloatingConfig("f", true, 1.0, 100.0, &floating_value, 4.5, nullptr, nullptr);
|
||||
floating_config->Init();
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("6.0.1", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = floating_config->Set("6a0", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
|
||||
}
|
||||
}
|
||||
|
||||
/* EnumConfigTest */
|
||||
class EnumConfigTest : public testing::Test, public Utils<int64_t> {};
|
||||
|
||||
TEST_F(EnumConfigTest, nullptr_init_test) {
|
||||
configEnum testEnum{
|
||||
{"e", 1},
|
||||
};
|
||||
int64_t testEnumValue;
|
||||
auto enum_config_1 = CreateEnumConfig("e", _MODIFIABLE, &testEnum, nullptr, 2, nullptr, nullptr);
|
||||
ASSERT_DEATH(enum_config_1->Init(), "nullptr");
|
||||
|
||||
auto enum_config_2 = CreateEnumConfig("e", _MODIFIABLE, nullptr, &testEnumValue, 2, nullptr, nullptr);
|
||||
ASSERT_DEATH(enum_config_2->Init(), "nullptr");
|
||||
|
||||
auto enum_config_3 = CreateEnumConfig("e", _MODIFIABLE, nullptr, nullptr, 2, nullptr, nullptr);
|
||||
ASSERT_DEATH(enum_config_3->Init(), "nullptr");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, init_twice_test) {
|
||||
configEnum testEnum{
|
||||
{"e", 1},
|
||||
};
|
||||
int64_t enum_value;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr);
|
||||
ASSERT_DEATH(
|
||||
{
|
||||
enum_config->Init();
|
||||
enum_config->Init();
|
||||
},
|
||||
"initialized");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, non_init_test) {
|
||||
configEnum testEnum{
|
||||
{"e", 1},
|
||||
};
|
||||
int64_t enum_value;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr);
|
||||
ASSERT_DEATH(enum_config->Set("e", true), "uninitialized");
|
||||
ASSERT_DEATH(enum_config->Get(), "uninitialized");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, immutable_update_test) {
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
{"b", 2},
|
||||
{"c", 3},
|
||||
};
|
||||
int64_t enum_value = 0;
|
||||
auto enum_config = CreateEnumConfig("e", _IMMUTABLE, &testEnum, &enum_value, 1, nullptr, nullptr);
|
||||
enum_config->Init();
|
||||
ASSERT_EQ(enum_value, 1);
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = enum_config->Set("b", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
|
||||
ASSERT_EQ(enum_value, 1);
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, set_invalid_value_check) {
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
};
|
||||
int64_t enum_value = 0;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, nullptr, nullptr);
|
||||
enum_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = enum_config->Set("b", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::ENUM_VALUE_NOTFOUND);
|
||||
ASSERT_EQ(enum_config->Get(), "a");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, empty_enum_test) {
|
||||
configEnum testEnum{};
|
||||
int64_t enum_value;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr);
|
||||
ASSERT_DEATH(enum_config->Init(), "empty");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, valid_check_fail_test) {
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
{"b", 2},
|
||||
{"c", 3},
|
||||
};
|
||||
int64_t enum_value;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_failure, nullptr);
|
||||
enum_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = enum_config->Set("b", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(enum_config->Get(), "a");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, update_fail_test) {
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
{"b", 2},
|
||||
{"c", 3},
|
||||
};
|
||||
int64_t enum_value;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, nullptr, update_failure);
|
||||
enum_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = enum_config->Set("b", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
|
||||
ASSERT_EQ(enum_config->Get(), "a");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, string_exception_test) {
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
{"b", 2},
|
||||
{"c", 3},
|
||||
};
|
||||
int64_t enum_value;
|
||||
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_raise_string, nullptr);
|
||||
enum_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = enum_config->Set("b", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
|
||||
ASSERT_EQ(enum_config->Get(), "a");
|
||||
}
|
||||
|
||||
TEST_F(EnumConfigTest, standard_exception_test) {
|
||||
configEnum testEnum{
|
||||
{"a", 1},
|
||||
{"b", 2},
|
||||
{"c", 3},
|
||||
};
|
||||
int64_t enum_value;
|
||||
auto enum_config =
|
||||
CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_raise_exception, nullptr);
|
||||
enum_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = enum_config->Set("b", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
|
||||
ASSERT_EQ(enum_config->Get(), "a");
|
||||
}
|
||||
|
||||
/* SizeConfigTest */
|
||||
class SizeConfigTest : public testing::Test, public Utils<int64_t> {};
|
||||
|
||||
TEST_F(SizeConfigTest, nullptr_init_test) {
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, nullptr, 2048, nullptr, nullptr);
|
||||
ASSERT_DEATH(size_config->Init(), "nullptr");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, init_twice_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
ASSERT_DEATH(
|
||||
{
|
||||
size_config->Init();
|
||||
size_config->Init();
|
||||
},
|
||||
"initialized");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, non_init_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
ASSERT_DEATH(size_config->Set("3000", true), "uninitialized");
|
||||
ASSERT_DEATH(size_config->Get(), "uninitialized");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, immutable_update_test) {
|
||||
int64_t size_value = 0;
|
||||
auto size_config = CreateSizeConfig("i", _IMMUTABLE, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
size_config->Init();
|
||||
ASSERT_EQ(size_value, 2048);
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("3000", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
|
||||
ASSERT_EQ(size_value, 2048);
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, set_invalid_value_test) {
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, valid_check_fail_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_failure, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("3000", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, update_fail_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, update_failure);
|
||||
size_config->Init();
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("3000", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, string_exception_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_raise_string, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("3000", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, standard_exception_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_raise_exception, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("3000", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, out_of_range_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("1023", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("4097", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, negative_integer_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
size_config->Init();
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("-3KB", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, invalid_bound_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 100, 0, &size_value, 50, nullptr, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("30", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
|
||||
ASSERT_EQ(size_config->Get(), "50");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, invalid_unit_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("1 TB", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
TEST_F(SizeConfigTest, invalid_format_test) {
|
||||
int64_t size_value;
|
||||
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
|
||||
size_config->Init();
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("a10GB", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("200*0", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
|
||||
{
|
||||
ConfigStatus status(SUCCESS, "");
|
||||
status = size_config->Set("10AB", true);
|
||||
ASSERT_EQ(status.set_return, SetReturn::INVALID);
|
||||
ASSERT_EQ(size_config->Get(), "2048");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus
|
|
@ -1,110 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "config/ConfigType.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
extern std::mutex&
|
||||
GetConfigMutex();
|
||||
|
||||
template <typename T>
|
||||
class ConfigValue {
|
||||
public:
|
||||
explicit ConfigValue(T init_value) : value(std::move(init_value)) {
|
||||
}
|
||||
|
||||
const T&
|
||||
operator()() {
|
||||
std::lock_guard<std::mutex> lock(GetConfigMutex());
|
||||
return value;
|
||||
}
|
||||
|
||||
public:
|
||||
T value;
|
||||
};
|
||||
|
||||
enum ClusterRole {
|
||||
RW = 1,
|
||||
RO,
|
||||
};
|
||||
|
||||
enum SimdType {
|
||||
AUTO = 1,
|
||||
SSE,
|
||||
AVX2,
|
||||
AVX512,
|
||||
};
|
||||
|
||||
const configEnum SimdMap{
|
||||
{"auto", SimdType::AUTO},
|
||||
{"sse", SimdType::SSE},
|
||||
{"avx2", SimdType::AVX2},
|
||||
{"avx512", SimdType::AVX512},
|
||||
};
|
||||
|
||||
struct ServerConfig {
|
||||
using String = ConfigValue<std::string>;
|
||||
using Bool = ConfigValue<bool>;
|
||||
using Integer = ConfigValue<int64_t>;
|
||||
using Floating = ConfigValue<double>;
|
||||
|
||||
String timezone{"unknown"};
|
||||
|
||||
struct Network {
|
||||
String address{"unknown"};
|
||||
Integer port{0};
|
||||
} network;
|
||||
|
||||
struct Pulsar {
|
||||
String address{"localhost"};
|
||||
Integer port{6650};
|
||||
} pulsar;
|
||||
|
||||
struct Engine {
|
||||
Integer build_index_threshold{4096};
|
||||
Integer search_combine_nq{0};
|
||||
Integer use_blas_threshold{0};
|
||||
Integer omp_thread_num{0};
|
||||
Integer simd_type{0};
|
||||
} engine;
|
||||
|
||||
struct Tracing {
|
||||
String json_config_path{"unknown"};
|
||||
} tracing;
|
||||
|
||||
struct Logs {
|
||||
String level{"unknown"};
|
||||
struct Trace {
|
||||
Bool enable{false};
|
||||
} trace;
|
||||
String path{"unknown"};
|
||||
Integer max_log_file_size{0};
|
||||
Integer log_rotate_num{0};
|
||||
} logs;
|
||||
};
|
||||
|
||||
extern ServerConfig config;
|
||||
extern std::mutex _config_mutex;
|
||||
|
||||
std::vector<std::string>
|
||||
ParsePreloadCollection(const std::string&);
|
||||
|
||||
std::vector<int64_t>
|
||||
ParseGPUDevices(const std::string&);
|
||||
} // namespace milvus
|
|
@ -51,6 +51,7 @@ include(DefineOptionsCore)
|
|||
include(BuildUtilsCore)
|
||||
|
||||
using_ccache_if_defined( KNOWHERE_USE_CCACHE )
|
||||
set_directory_properties(PROPERTIES RULE_LAUNCH_COMPILE "")
|
||||
|
||||
if (MILVUS_GPU_VERSION)
|
||||
message(STATUS "Building Knowhere GPU version")
|
||||
|
|
|
@ -16,12 +16,19 @@
|
|||
#include <faiss/Clustering.h>
|
||||
#include <faiss/utils/distances.h>
|
||||
|
||||
#include "config/ServerConfig.h"
|
||||
#include "NGT/lib/NGT/defines.h"
|
||||
#include "faiss/FaissHook.h"
|
||||
#include "faiss/common.h"
|
||||
#include "faiss/utils/utils.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/IndexHNSW.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "scheduler/Utils.h"
|
||||
#include "utils/ConfigUtils.h"
|
||||
#include "utils/Error.h"
|
||||
#include "utils/Log.h"
|
||||
#include "value/config/ServerConfig.h"
|
||||
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <map>
|
||||
|
@ -63,20 +70,6 @@ KnowhereResource::Initialize() {
|
|||
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
|
||||
}
|
||||
|
||||
// engine config
|
||||
int64_t omp_thread = config.engine.omp_thread_num();
|
||||
|
||||
if (omp_thread > 0) {
|
||||
omp_set_num_threads(omp_thread);
|
||||
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
|
||||
} else {
|
||||
int64_t sys_thread_cnt = 8;
|
||||
if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) {
|
||||
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
|
||||
omp_set_num_threads(omp_thread);
|
||||
}
|
||||
}
|
||||
|
||||
// init faiss global variable
|
||||
int64_t use_blas_threshold = config.engine.use_blas_threshold();
|
||||
faiss::distance_compute_blas_threshold = use_blas_threshold;
|
||||
|
@ -95,40 +88,49 @@ KnowhereResource::Initialize() {
|
|||
#ifdef MILVUS_GPU_VERSION
|
||||
bool enable_gpu = config.gpu.enable();
|
||||
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
|
||||
if (!enable_gpu) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (enable_gpu) {
|
||||
struct GpuResourceSetting {
|
||||
int64_t pinned_memory = 256 * M_BYTE;
|
||||
int64_t temp_memory = 256 * M_BYTE;
|
||||
int64_t resource_num = 2;
|
||||
};
|
||||
using GpuResourcesArray = std::map<int64_t, GpuResourceSetting>;
|
||||
GpuResourcesArray gpu_resources;
|
||||
|
||||
struct GpuResourceSetting {
|
||||
int64_t pinned_memory = 256 * M_BYTE;
|
||||
int64_t temp_memory = 256 * M_BYTE;
|
||||
int64_t resource_num = 2;
|
||||
};
|
||||
using GpuResourcesArray = std::map<int64_t, GpuResourceSetting>;
|
||||
GpuResourcesArray gpu_resources;
|
||||
// get build index gpu resource
|
||||
std::vector<int64_t> build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices());
|
||||
|
||||
// get build index gpu resource
|
||||
std::vector<int64_t> build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices());
|
||||
for (auto gpu_id : build_index_gpus) {
|
||||
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
|
||||
}
|
||||
|
||||
for (auto gpu_id : build_index_gpus) {
|
||||
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
|
||||
}
|
||||
// get search gpu resource
|
||||
std::vector<int64_t> search_gpus = ParseGPUDevices(config.gpu.search_devices());
|
||||
|
||||
// get search gpu resource
|
||||
std::vector<int64_t> search_gpus = ParseGPUDevices(config.gpu.search_devices());
|
||||
for (auto& gpu_id : search_gpus) {
|
||||
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
|
||||
}
|
||||
|
||||
for (auto& gpu_id : search_gpus) {
|
||||
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
|
||||
}
|
||||
|
||||
// init gpu resources
|
||||
for (auto& gpu_resource : gpu_resources) {
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(gpu_resource.first, gpu_resource.second.pinned_memory,
|
||||
gpu_resource.second.temp_memory,
|
||||
gpu_resource.second.resource_num);
|
||||
// init gpu resources
|
||||
for (auto& gpu_resource : gpu_resources) {
|
||||
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(
|
||||
gpu_resource.first, gpu_resource.second.pinned_memory, gpu_resource.second.temp_memory,
|
||||
gpu_resource.second.resource_num);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
faiss::LOG_ERROR_ = &knowhere::log_error_;
|
||||
faiss::LOG_WARNING_ = &knowhere::log_warning_;
|
||||
// faiss::LOG_DEBUG_ = &knowhere::log_debug_;
|
||||
NGT_LOG_ERROR_ = &knowhere::log_error_;
|
||||
NGT_LOG_WARNING_ = &knowhere::log_warning_;
|
||||
// NGT_LOG_DEBUG_ = &knowhere::log_debug_;
|
||||
|
||||
auto stat_level = config.engine.statistics_level();
|
||||
milvus::knowhere::STATISTICS_LEVEL = stat_level;
|
||||
faiss::STATISTICS_LEVEL = stat_level;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -39,11 +39,21 @@ endif ()
|
|||
|
||||
set(external_srcs
|
||||
knowhere/common/Exception.cpp
|
||||
knowhere/common/Log.cpp
|
||||
knowhere/common/Timer.cpp
|
||||
knowhere/common/Utils.cpp
|
||||
)
|
||||
|
||||
set (LOG_SRC
|
||||
knowhere/common/Log.cpp
|
||||
${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc
|
||||
)
|
||||
add_library(index_log STATIC ${LOG_SRC})
|
||||
set_target_properties(index_log PROPERTIES RULE_LAUNCH_COMPILE "")
|
||||
set_target_properties(index_log PROPERTIES RULE_LAUNCH_LINK "")
|
||||
include_directories(${MILVUS_THIRDPARTY_SRC})
|
||||
|
||||
set(vector_index_srcs
|
||||
knowhere/index/IndexType.cpp
|
||||
knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||
knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
|
@ -61,8 +71,6 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/IndexIVF.cpp
|
||||
knowhere/index/vector_index/IndexIVFPQ.cpp
|
||||
knowhere/index/vector_index/IndexIVFSQ.cpp
|
||||
knowhere/index/IndexType.cpp
|
||||
knowhere/index/vector_index/VecIndexFactory.cpp
|
||||
knowhere/index/vector_index/IndexAnnoy.cpp
|
||||
knowhere/index/vector_index/IndexRHNSW.cpp
|
||||
knowhere/index/vector_index/IndexHNSW.cpp
|
||||
|
@ -72,6 +80,8 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/IndexNGT.cpp
|
||||
knowhere/index/vector_index/IndexNGTPANNG.cpp
|
||||
knowhere/index/vector_index/IndexNGTONNG.cpp
|
||||
knowhere/index/vector_index/Statistics.cpp
|
||||
knowhere/index/vector_index/VecIndexFactory.cpp
|
||||
)
|
||||
|
||||
set(vector_offset_index_srcs
|
||||
|
@ -96,6 +106,7 @@ set(depend_libs
|
|||
pthread
|
||||
fiu
|
||||
ngt
|
||||
index_log
|
||||
)
|
||||
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
|
@ -143,7 +154,6 @@ endif ()
|
|||
|
||||
target_link_libraries(
|
||||
knowhere
|
||||
milvus_utils
|
||||
${depend_libs}
|
||||
)
|
||||
|
||||
|
|
|
@ -79,6 +79,11 @@ class BinarySet {
|
|||
binary_map_.clear();
|
||||
}
|
||||
|
||||
bool
|
||||
Contains(const std::string& key) const {
|
||||
return binary_map_.find(key) != binary_map_.end();
|
||||
}
|
||||
|
||||
public:
|
||||
std::map<std::string, BinaryPtr> binary_map_;
|
||||
};
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
using Config = milvus::Json;
|
||||
using Config = milvus::json;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/common/Utils.h"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
const char* INDEX_FILE_SLICE_SIZE_IN_MEGABYTE = "SLICE_SIZE";
|
||||
const char* INDEX_FILE_SLICE_META = "SLICE_META";
|
||||
|
||||
static const char* META = "meta";
|
||||
static const char* NAME = "name";
|
||||
static const char* SLICE_NUM = "slice_num";
|
||||
static const char* TOTAL_LEN = "total_len";
|
||||
|
||||
void
|
||||
Slice(const std::string& prefix,
|
||||
const BinaryPtr& data_src,
|
||||
const int64_t& slice_len,
|
||||
BinarySet& binarySet,
|
||||
milvus::json& ret) {
|
||||
if (!data_src) {
|
||||
return;
|
||||
}
|
||||
|
||||
int slice_num = 0;
|
||||
for (int64_t i = 0; i < data_src->size; ++slice_num) {
|
||||
int64_t ri = std::min(i + slice_len, data_src->size);
|
||||
auto size = static_cast<size_t>(ri - i);
|
||||
auto slice_i = reinterpret_cast<uint8_t*>(malloc(size));
|
||||
memcpy(slice_i, data_src->data.get() + i, size);
|
||||
std::shared_ptr<uint8_t[]> slice_i_sp(slice_i, std::default_delete<uint8_t[]>());
|
||||
binarySet.Append(prefix + "_" + std::to_string(slice_num), slice_i_sp, ri - i);
|
||||
i = ri;
|
||||
}
|
||||
ret[NAME] = prefix;
|
||||
ret[SLICE_NUM] = slice_num;
|
||||
ret[TOTAL_LEN] = data_src->size;
|
||||
}
|
||||
|
||||
void
|
||||
Assemble(BinarySet& binarySet) {
|
||||
auto slice_meta = binarySet.Erase(INDEX_FILE_SLICE_META);
|
||||
if (slice_meta == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
milvus::json meta_data =
|
||||
milvus::json::parse(std::string(reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
|
||||
|
||||
for (auto& item : meta_data[META]) {
|
||||
std::string prefix = item[NAME];
|
||||
int slice_num = item[SLICE_NUM];
|
||||
auto total_len = static_cast<size_t>(item[TOTAL_LEN]);
|
||||
auto p_data = reinterpret_cast<uint8_t*>(malloc(total_len));
|
||||
int64_t pos = 0;
|
||||
for (auto i = 0; i < slice_num; ++i) {
|
||||
auto slice_i_sp = binarySet.Erase(prefix + "_" + std::to_string(i));
|
||||
memcpy(p_data + pos, slice_i_sp->data.get(), static_cast<size_t>(slice_i_sp->size));
|
||||
pos += slice_i_sp->size;
|
||||
}
|
||||
std::shared_ptr<uint8_t[]> integral_data(p_data, std::default_delete<uint8_t[]>());
|
||||
binarySet.Append(prefix, integral_data, total_len);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
Disassemble(const int64_t& slice_size_in_byte, BinarySet& binarySet) {
|
||||
milvus::json meta_info;
|
||||
std::vector<std::string> slice_key_list;
|
||||
for (auto& kv : binarySet.binary_map_) {
|
||||
if (kv.second->size > slice_size_in_byte) {
|
||||
slice_key_list.push_back(kv.first);
|
||||
}
|
||||
}
|
||||
for (auto& key : slice_key_list) {
|
||||
milvus::json slice_i;
|
||||
Slice(key, binarySet.Erase(key), slice_size_in_byte, binarySet, slice_i);
|
||||
meta_info[META].emplace_back(slice_i);
|
||||
}
|
||||
if (!slice_key_list.empty()) {
|
||||
auto meta_str = meta_info.dump();
|
||||
std::shared_ptr<uint8_t[]> meta_data(new uint8_t[meta_str.length() + 1], std::default_delete<uint8_t[]>());
|
||||
memcpy(meta_data.get(), meta_str.data(), meta_str.length());
|
||||
meta_data.get()[meta_str.length()] = 0;
|
||||
binarySet.Append(INDEX_FILE_SLICE_META, meta_data, meta_str.length() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "BinarySet.h"
|
||||
#include "Config.h"
|
||||
#include "Exception.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
extern const char* INDEX_FILE_SLICE_SIZE_IN_MEGABYTE;
|
||||
extern const char* INDEX_FILE_SLICE_META;
|
||||
|
||||
void
|
||||
Assemble(BinarySet& binarySet);
|
||||
|
||||
void
|
||||
Disassemble(const int64_t& slice_size_in_byte, BinarySet& binarySet);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -38,6 +38,8 @@ enum class OldIndexType {
|
|||
RHNSW_FLAT,
|
||||
RHNSW_PQ,
|
||||
RHNSW_SQ,
|
||||
NGTPANNG,
|
||||
NGTONNG,
|
||||
FAISS_BIN_IDMAP = 100,
|
||||
FAISS_BIN_IVFLAT_CPU = 101,
|
||||
};
|
||||
|
|
|
@ -20,13 +20,11 @@
|
|||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
enum class OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 };
|
||||
enum OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 };
|
||||
|
||||
static std::map<std::string, OperatorType> s_map_operator_type = {
|
||||
{"LT", OperatorType::LT},
|
||||
{"LE", OperatorType::LE},
|
||||
{"GT", OperatorType::GT},
|
||||
{"GE", OperatorType::GE},
|
||||
{"LT", OperatorType::LT}, {"LTE", OperatorType::LE}, {"GT", OperatorType::GT}, {"GTE", OperatorType::GE},
|
||||
{"lt", OperatorType::LT}, {"lte", OperatorType::LE}, {"gt", OperatorType::GT}, {"gte", OperatorType::GE},
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/knowhere/common/Log.h"
|
||||
#include "knowhere/index/structured_index/StructuredIndexFlat.h"
|
||||
|
||||
namespace milvus {
|
||||
|
@ -31,6 +31,16 @@ template <typename T>
|
|||
StructuredIndexFlat<T>::~StructuredIndexFlat() {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BinarySet
|
||||
StructuredIndexFlat<T>::Serialize(const milvus::knowhere::Config& config) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
StructuredIndexFlat<T>::Load(const milvus::knowhere::BinarySet& index_binary) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
StructuredIndexFlat<T>::Build(const size_t n, const T* values) {
|
||||
|
@ -45,14 +55,11 @@ StructuredIndexFlat<T>::Build(const size_t n, const T* values) {
|
|||
template <typename T>
|
||||
const faiss::ConcurrentBitsetPtr
|
||||
StructuredIndexFlat<T>::In(const size_t n, const T* values) {
|
||||
if (!is_built_) {
|
||||
build();
|
||||
}
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
for (const auto& index : data_) {
|
||||
if (index->a_ == *(values + i)) {
|
||||
bitset->set(index->idx_);
|
||||
if (index.a_ == *(values + i)) {
|
||||
bitset->set(index.idx_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -62,14 +69,11 @@ StructuredIndexFlat<T>::In(const size_t n, const T* values) {
|
|||
template <typename T>
|
||||
const faiss::ConcurrentBitsetPtr
|
||||
StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
|
||||
if (!is_built_) {
|
||||
build();
|
||||
}
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size(), 0xff);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
for (const auto& index : data_) {
|
||||
if (index->a_ == *(values + i)) {
|
||||
bitset->clear(index->idx_);
|
||||
if (index.a_ == *(values + i)) {
|
||||
bitset->clear(index.idx_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -79,31 +83,28 @@ StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
|
|||
template <typename T>
|
||||
const faiss::ConcurrentBitsetPtr
|
||||
StructuredIndexFlat<T>::Range(const T value, const OperatorType op) {
|
||||
if (!is_built_) {
|
||||
build();
|
||||
}
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
|
||||
auto lb = data_.begin();
|
||||
auto ub = data_.end();
|
||||
for (; lb <= ub; lb++) {
|
||||
switch (op) {
|
||||
case OperatorType::LT:
|
||||
if (lb < IndexStructure<T>(value)) {
|
||||
if (*lb < IndexStructure<T>(value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
break;
|
||||
case OperatorType::LE:
|
||||
if (lb <= IndexStructure<T>(value)) {
|
||||
if (*lb <= IndexStructure<T>(value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
break;
|
||||
case OperatorType::GT:
|
||||
if (lb > IndexStructure<T>(value)) {
|
||||
if (*lb > IndexStructure<T>(value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
break;
|
||||
case OperatorType::GE:
|
||||
if (lb >= IndexStructure<T>(value)) {
|
||||
if (*lb >= IndexStructure<T>(value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
break;
|
||||
|
@ -117,9 +118,6 @@ StructuredIndexFlat<T>::Range(const T value, const OperatorType op) {
|
|||
template <typename T>
|
||||
const faiss::ConcurrentBitsetPtr
|
||||
StructuredIndexFlat<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
|
||||
if (!is_built_) {
|
||||
build();
|
||||
}
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
|
||||
if (lower_bound_value > upper_bound_value) {
|
||||
std::swap(lower_bound_value, upper_bound_value);
|
||||
|
@ -129,19 +127,19 @@ StructuredIndexFlat<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bo
|
|||
auto ub = data_.end();
|
||||
for (; lb <= ub; ++lb) {
|
||||
if (lb_inclusive && ub_inclusive) {
|
||||
if (lb >= IndexStructure<T>(lower_bound_value) && lb <= IndexStructure<T>(upper_bound_value)) {
|
||||
if (*lb >= IndexStructure<T>(lower_bound_value) && *lb <= IndexStructure<T>(upper_bound_value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
} else if (lb_inclusive && !ub_inclusive) {
|
||||
if (lb >= IndexStructure<T>(lower_bound_value) && lb < IndexStructure<T>(upper_bound_value)) {
|
||||
if (*lb >= IndexStructure<T>(lower_bound_value) && *lb < IndexStructure<T>(upper_bound_value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
} else if (!lb_inclusive && ub_inclusive) {
|
||||
if (lb > IndexStructure<T>(lower_bound_value) && lb <= IndexStructure<T>(upper_bound_value)) {
|
||||
if (*lb > IndexStructure<T>(lower_bound_value) && *lb <= IndexStructure<T>(upper_bound_value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
} else {
|
||||
if (lb > IndexStructure<T>(lower_bound_value) && lb < IndexStructure<T>(upper_bound_value)) {
|
||||
if (*lb > IndexStructure<T>(lower_bound_value) && *lb < IndexStructure<T>(upper_bound_value)) {
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,9 +37,6 @@ class StructuredIndexFlat : public StructuredIndex<T> {
|
|||
void
|
||||
Build(const size_t n, const T* values) override;
|
||||
|
||||
void
|
||||
build();
|
||||
|
||||
const faiss::ConcurrentBitsetPtr
|
||||
In(const size_t n, const T* values) override;
|
||||
|
||||
|
@ -59,7 +56,7 @@ class StructuredIndexFlat : public StructuredIndex<T> {
|
|||
|
||||
int64_t
|
||||
Size() override {
|
||||
return (int64_t)data_.size();
|
||||
return (int64_t)data_.size() * sizeof(IndexStructure<T>);
|
||||
}
|
||||
|
||||
bool
|
||||
|
|
|
@ -59,7 +59,7 @@ class StructuredIndexSort : public StructuredIndex<T> {
|
|||
|
||||
int64_t
|
||||
Size() override {
|
||||
return (int64_t)data_.size();
|
||||
return (int64_t)data_.size() * sizeof(IndexStructure<T>);
|
||||
}
|
||||
|
||||
bool
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
@ -24,14 +25,14 @@
|
|||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
static const int64_t MIN_NBITS = 1;
|
||||
static const int64_t MAX_NBITS = 16;
|
||||
static const int64_t MIN_NLIST = 1;
|
||||
static const int64_t MAX_NLIST = 65536;
|
||||
static const int64_t MIN_NPROBE = 1;
|
||||
static const int64_t MAX_NPROBE = MAX_NLIST;
|
||||
static const int64_t DEFAULT_MIN_DIM = 1;
|
||||
static const int64_t DEFAULT_MAX_DIM = 32768;
|
||||
static const int64_t DEFAULT_MIN_ROWS = 1; // minimum size for build index
|
||||
static const int64_t DEFAULT_MAX_ROWS = 50000000;
|
||||
static const int64_t NGT_MIN_EDGE_SIZE = 1;
|
||||
static const int64_t NGT_MAX_EDGE_SIZE = 200;
|
||||
static const int64_t HNSW_MIN_EFCONSTRUCTION = 8;
|
||||
|
@ -47,6 +48,12 @@ static const std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Me
|
|||
return false; \
|
||||
}
|
||||
|
||||
#define CheckFloatByRange(key, min, max) \
|
||||
if (!oricfg.contains(key) || !oricfg[key].is_number_float() || oricfg[key].get<float>() > max || \
|
||||
oricfg[key].get<float>() < min) { \
|
||||
return false; \
|
||||
}
|
||||
|
||||
#define CheckIntByValues(key, container) \
|
||||
if (!oricfg.contains(key) || !oricfg[key].is_number_integer()) { \
|
||||
return false; \
|
||||
|
@ -84,33 +91,42 @@ ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode m
|
|||
|
||||
int64_t
|
||||
MatchNlist(int64_t size, int64_t nlist) {
|
||||
const int64_t TYPICAL_COUNT = 1000000;
|
||||
const int64_t PER_NLIST = 16384;
|
||||
const int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
|
||||
if (nlist * TYPICAL_COUNT > size * PER_NLIST) {
|
||||
if (nlist * MIN_POINTS_PER_CENTROID > size) {
|
||||
// nlist is too large, adjust to a proper value
|
||||
nlist = std::max(1L, size * PER_NLIST / TYPICAL_COUNT);
|
||||
nlist = std::max(1L, size / MIN_POINTS_PER_CENTROID);
|
||||
LOG_KNOWHERE_WARNING_ << "Row num " << size << " match nlist " << nlist;
|
||||
}
|
||||
return nlist;
|
||||
}
|
||||
|
||||
int64_t
|
||||
MatchNbits(int64_t size, int64_t nbits) {
|
||||
if (size < (1 << nbits)) {
|
||||
// nbits is too large, adjust to a proper value
|
||||
if (size >= (1 << 8)) {
|
||||
nbits = 8;
|
||||
} else if (size >= (1 << 4)) {
|
||||
nbits = 4;
|
||||
} else if (size >= (1 << 2)) {
|
||||
nbits = 2;
|
||||
} else {
|
||||
nbits = 1;
|
||||
}
|
||||
LOG_KNOWHERE_WARNING_ << "Row num " << size << " match nbits " << nbits;
|
||||
}
|
||||
return nbits;
|
||||
}
|
||||
|
||||
bool
|
||||
IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
|
||||
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
|
||||
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
// auto tune params
|
||||
auto nq = oricfg[knowhere::meta::ROWS].get<int64_t>();
|
||||
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
|
||||
auto nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist);
|
||||
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(rows, nlist);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
@ -138,49 +154,38 @@ IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
|||
|
||||
bool
|
||||
IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
const int64_t DEFAULT_NBITS = 8;
|
||||
if (!IVFConfAdapter::CheckTrain(oricfg, mode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
|
||||
CheckIntByRange(knowhere::IndexParams::nbits, MIN_NBITS, MAX_NBITS);
|
||||
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
|
||||
auto nbits = oricfg[knowhere::IndexParams::nbits].get<int64_t>();
|
||||
oricfg[knowhere::IndexParams::nbits] = MatchNbits(rows, nbits);
|
||||
|
||||
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
|
||||
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
// auto tune params
|
||||
oricfg[knowhere::IndexParams::nlist] =
|
||||
MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());
|
||||
auto m = oricfg[knowhere::IndexParams::m].get<int64_t>();
|
||||
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
|
||||
/*std::vector<int64_t> resset;
|
||||
IVFPQConfAdapter::GetValidCPUM(dimension, resset);*/
|
||||
IndexMode ivfpq_mode = mode;
|
||||
return GetValidM(dimension, m, ivfpq_mode);
|
||||
return CheckPQParams(dimension, m, nbits, ivfpq_mode);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFPQConfAdapter::GetValidM(int64_t dimension, int64_t m, IndexMode& mode) {
|
||||
IVFPQConfAdapter::CheckPQParams(int64_t dimension, int64_t m, int64_t nbits, IndexMode& mode) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::GetValidGPUM(dimension, m)) {
|
||||
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::CheckGPUPQParams(dimension, m, nbits)) {
|
||||
mode = knowhere::IndexMode::MODE_CPU;
|
||||
}
|
||||
#endif
|
||||
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::GetValidCPUM(dimension, m)) {
|
||||
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::CheckCPUPQParams(dimension, m)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
|
||||
IVFPQConfAdapter::CheckGPUPQParams(int64_t dimension, int64_t m, int64_t nbits) {
|
||||
/*
|
||||
* Faiss 1.6
|
||||
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
|
||||
|
@ -193,22 +198,12 @@ IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
|
|||
return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) !=
|
||||
support_subquantizer.end()) &&
|
||||
(std::find(std::begin(support_dim_per_subquantizer), std::end(support_dim_per_subquantizer), sub_dim) !=
|
||||
support_dim_per_subquantizer.end());
|
||||
|
||||
/*resset.clear();
|
||||
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
|
||||
if (!(dimension % dimperquantizer)) {
|
||||
auto subquantzier_num = dimension / dimperquantizer;
|
||||
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);
|
||||
if (finder != support_subquantizer.end()) {
|
||||
resset.push_back(subquantzier_num);
|
||||
}
|
||||
}
|
||||
}*/
|
||||
support_dim_per_subquantizer.end()) &&
|
||||
(nbits == 8);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFPQConfAdapter::GetValidCPUM(int64_t dimension, int64_t m) {
|
||||
IVFPQConfAdapter::CheckCPUPQParams(int64_t dimension, int64_t m) {
|
||||
return (dimension % m == 0);
|
||||
}
|
||||
|
||||
|
@ -224,7 +219,6 @@ NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
|||
const int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
|
||||
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::knng, MIN_KNNG, MAX_KNNG);
|
||||
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
|
||||
CheckIntByRange(knowhere::IndexParams::out_degree, MIN_OUT_DEGREE, MAX_OUT_DEGREE);
|
||||
|
@ -251,7 +245,6 @@ NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod
|
|||
|
||||
bool
|
||||
HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
|
@ -267,7 +260,6 @@ HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMo
|
|||
|
||||
bool
|
||||
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
|
@ -283,13 +275,12 @@ RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const In
|
|||
|
||||
bool
|
||||
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
|
||||
|
||||
IVFPQConfAdapter::GetValidCPUM(dimension, oricfg[knowhere::IndexParams::PQM].get<int64_t>());
|
||||
IVFPQConfAdapter::CheckCPUPQParams(dimension, oricfg[knowhere::IndexParams::PQM].get<int64_t>());
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
@ -303,7 +294,6 @@ RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const Inde
|
|||
|
||||
bool
|
||||
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
|
@ -334,18 +324,14 @@ BinIVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
|||
static const std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
|
||||
knowhere::Metric::TANIMOTO};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
int64_t nlist = oricfg[knowhere::IndexParams::nlist];
|
||||
CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
// auto tune params
|
||||
auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
|
||||
auto nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(rows, nlist);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -370,35 +356,37 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM
|
|||
|
||||
bool
|
||||
NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
CheckIntByRange(knowhere::IndexParams::forcedly_pruned_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
CheckIntByRange(knowhere::IndexParams::selectively_pruned_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
if (oricfg[knowhere::IndexParams::selectively_pruned_edge_size].get<int64_t>() >=
|
||||
oricfg[knowhere::IndexParams::forcedly_pruned_edge_size].get<int64_t>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::IndexParams::max_search_edges, -1, NGT_MAX_EDGE_SIZE);
|
||||
CheckFloatByRange(knowhere::IndexParams::epsilon, -1.0, 1.0);
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
CheckIntByRange(knowhere::IndexParams::outgoing_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
CheckIntByRange(knowhere::IndexParams::incoming_edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
|
||||
return true;
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
CheckIntByRange(knowhere::IndexParams::max_search_edges, -1, NGT_MAX_EDGE_SIZE);
|
||||
CheckFloatByRange(knowhere::IndexParams::epsilon, -1.0, 1.0);
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
|
|
|
@ -52,13 +52,13 @@ class IVFPQConfAdapter : public IVFConfAdapter {
|
|||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
static bool
|
||||
GetValidM(int64_t dimension, int64_t m, IndexMode& mode);
|
||||
CheckPQParams(int64_t dimension, int64_t m, int64_t nbits, IndexMode& mode);
|
||||
|
||||
static bool
|
||||
GetValidGPUM(int64_t dimension, int64_t m);
|
||||
CheckGPUPQParams(int64_t dimension, int64_t m, int64_t nbits);
|
||||
|
||||
static bool
|
||||
GetValidCPUM(int64_t dimension, int64_t m);
|
||||
CheckCPUPQParams(int64_t dimension, int64_t m);
|
||||
};
|
||||
|
||||
class NSGConfAdapter : public IVFConfAdapter {
|
||||
|
|
|
@ -48,11 +48,15 @@ IndexAnnoy::Serialize(const Config& config) {
|
|||
res_set.Append("annoy_metric_type", metric_type, metric_type_length);
|
||||
res_set.Append("annoy_dim", dim_data, sizeof(uint64_t));
|
||||
res_set.Append("annoy_index_data", index_data, index_length);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
}
|
||||
|
||||
void
|
||||
IndexAnnoy::Load(const BinarySet& index_binary) {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
auto metric_type = index_binary.GetByName("annoy_metric_type");
|
||||
metric_type_.resize(static_cast<size_t>(metric_type->size));
|
||||
memcpy(metric_type_.data(), metric_type->data.get(), static_cast<size_t>(metric_type->size));
|
||||
|
@ -105,7 +109,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ class IndexAnnoy : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -30,17 +30,23 @@ BinaryIDMAP::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl(index_type_);
|
||||
// return SerializeImpl(index_type_);
|
||||
auto ret = SerializeImpl(index_type_);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Load(const BinarySet& index_binary) {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary, index_type_);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -147,7 +153,7 @@ BinaryIDMAP::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
// assign the metric type
|
||||
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
|
||||
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
|
|
@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -75,7 +75,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset);
|
||||
const faiss::BitsetView& bitset);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
|
|
@ -33,17 +33,27 @@ BinaryIVF::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl(index_type_);
|
||||
auto ret = SerializeImpl(index_type_);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Load(const BinarySet& index_binary) {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary, index_type_);
|
||||
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
auto ivf_index = static_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe_statistics.resize(ivf_index->nlist, 0);
|
||||
}
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -104,6 +114,31 @@ BinaryIVF::UpdateIndexSize() {
|
|||
index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size;
|
||||
}
|
||||
|
||||
StatisticsPtr
|
||||
BinaryIVF::GetStatistics() {
|
||||
if (!STATISTICS_LEVEL) {
|
||||
return stats;
|
||||
}
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
auto lock = ivf_stats->Lock();
|
||||
ivf_stats->update_ivf_access_stats(ivf_index->nprobe_statistics);
|
||||
return ivf_stats;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::ClearStatistics() {
|
||||
if (!STATISTICS_LEVEL) {
|
||||
return;
|
||||
}
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->clear_nprobe_statistics();
|
||||
ivf_index->index_ivf_stats.reset();
|
||||
auto lock = ivf_stats->Lock();
|
||||
ivf_stats->clear();
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR(dataset_ptr)
|
||||
|
@ -112,6 +147,7 @@ BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, metric_type);
|
||||
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, nlist, metric_type);
|
||||
index->own_fields = true;
|
||||
index->train(rows, static_cast<const uint8_t*>(p_data));
|
||||
index->add_with_ids(rows, static_cast<const uint8_t*>(p_data), p_ids);
|
||||
index_ = index;
|
||||
|
@ -132,22 +168,37 @@ BinaryIVF::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
|
||||
stdclock::time_point before = stdclock::now();
|
||||
auto i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
|
||||
index_->search(n, data, k, i_distances, labels, bitset);
|
||||
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost
|
||||
<< ", quantization cost: " << ivf_index->index_ivf_stats.quantization_time
|
||||
<< ", data search cost: " << ivf_index->index_ivf_stats.search_time;
|
||||
|
||||
if (STATISTICS_LEVEL) {
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
auto lock = ivf_stats->Lock();
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
ivf_stats->update_nq(n);
|
||||
ivf_stats->count_nprobe(ivf_index->nprobe);
|
||||
ivf_stats->update_total_query_time(ivf_index->index_ivf_stats.quantization_time +
|
||||
ivf_index->index_ivf_stats.search_time);
|
||||
ivf_index->index_ivf_stats.quantization_time = 0;
|
||||
ivf_index->index_ivf_stats.search_time = 0;
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 2) {
|
||||
ivf_stats->update_filter_percentage(bitset);
|
||||
}
|
||||
}
|
||||
|
||||
// if hamming, it need transform int32 to float
|
||||
if (ivf_index->metric_type == faiss::METRIC_Hamming) {
|
||||
|
|
|
@ -29,10 +29,12 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
public:
|
||||
BinaryIVF() : FaissBaseBinaryIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
explicit BinaryIVF(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
BinarySet
|
||||
|
@ -60,7 +62,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -71,6 +73,12 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
StatisticsPtr
|
||||
GetStatistics() override;
|
||||
|
||||
void
|
||||
ClearStatistics() override;
|
||||
|
||||
protected:
|
||||
virtual std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config& config);
|
||||
|
@ -82,7 +90,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset);
|
||||
const faiss::BitsetView& bitset);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -20,6 +21,7 @@
|
|||
|
||||
#include "faiss/BuilderSuspend.h"
|
||||
#include "hnswlib/hnswalg.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
#include "hnswlib/space_ip.h"
|
||||
#include "hnswlib/space_l2.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -51,6 +53,9 @@ IndexHNSW::Serialize(const Config& config) {
|
|||
|
||||
BinarySet res_set;
|
||||
res_set.Append("HNSW", data, writer.rp);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -60,6 +65,7 @@ IndexHNSW::Serialize(const Config& config) {
|
|||
void
|
||||
IndexHNSW::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
auto binary = index_binary.GetByName("HNSW");
|
||||
|
||||
MemoryIOReader reader;
|
||||
|
@ -68,7 +74,15 @@ IndexHNSW::Load(const BinarySet& index_binary) {
|
|||
|
||||
hnswlib::SpaceInterface<float>* space = nullptr;
|
||||
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space);
|
||||
index_->stats_enable = (STATISTICS_LEVEL >= 3);
|
||||
index_->loadIndex(reader);
|
||||
auto hnsw_stats = std::static_pointer_cast<LibHNSWStatistics>(stats);
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
auto lock = hnsw_stats->Lock();
|
||||
hnsw_stats->update_level_distribution(index_->maxlevel_, index_->level_stats_);
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Load finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << hnsw_stats->ToString();
|
||||
|
||||
normalize = index_->metric_type_ == 1; // 1 == InnerProduct
|
||||
} catch (std::exception& e) {
|
||||
|
@ -94,6 +108,7 @@ IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
|
||||
config[IndexParams::efConstruction].get<int64_t>());
|
||||
index_->stats_enable = (STATISTICS_LEVEL >= 3);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
|
@ -133,10 +148,17 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
faiss::BuilderSuspend::check_wait();
|
||||
index_->addPoint((reinterpret_cast<const float*>(p_data) + Dim() * i), p_ids[i]);
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
auto hnsw_stats = std::static_pointer_cast<LibHNSWStatistics>(stats);
|
||||
auto lock = hnsw_stats->Lock();
|
||||
hnsw_stats->update_level_distribution(index_->maxlevel_, index_->level_stats_);
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Train finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -147,12 +169,23 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
|
|||
size_t dist_size = sizeof(float) * k;
|
||||
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
|
||||
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
|
||||
std::vector<hnswlib::StatisticsInfo> query_stats;
|
||||
auto hnsw_stats = std::dynamic_pointer_cast<LibHNSWStatistics>(stats);
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
query_stats.resize(rows);
|
||||
for (auto i = 0; i < rows; ++i) {
|
||||
query_stats[i].target_level = hnsw_stats->target_level;
|
||||
}
|
||||
}
|
||||
|
||||
index_->setEf(config[IndexParams::ef]);
|
||||
index_->setEf(config[IndexParams::ef].get<int64_t>());
|
||||
|
||||
using P = std::pair<float, int64_t>;
|
||||
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
|
||||
|
||||
std::chrono::high_resolution_clock::time_point query_start, query_end;
|
||||
query_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
std::vector<P> ret;
|
||||
|
@ -165,7 +198,12 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
|
|||
// } else {
|
||||
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
|
||||
// }
|
||||
ret = index_->searchKnn(single_query, k, compare, bitset);
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
ret = index_->searchKnn(single_query, k, compare, bitset, query_stats[i]);
|
||||
} else {
|
||||
auto dummy_stat = hnswlib::StatisticsInfo();
|
||||
ret = index_->searchKnn(single_query, k, compare, bitset, dummy_stat);
|
||||
}
|
||||
|
||||
while (ret.size() < k) {
|
||||
ret.emplace_back(std::make_pair(-1, -1));
|
||||
|
@ -186,6 +224,33 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
|
|||
memcpy(p_dist + i * k, dist.data(), dist_size);
|
||||
memcpy(p_id + i * k, ids.data(), id_size);
|
||||
}
|
||||
query_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (STATISTICS_LEVEL) {
|
||||
auto lock = hnsw_stats->Lock();
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
hnsw_stats->update_nq(rows);
|
||||
hnsw_stats->update_ef_sum(index_->ef_ * rows);
|
||||
hnsw_stats->update_total_query_time(
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(query_end - query_start).count());
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 2) {
|
||||
hnsw_stats->update_filter_percentage(bitset);
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
for (auto i = 0; i < rows; ++i) {
|
||||
for (auto j = 0; j < query_stats[i].accessed_points.size(); ++j) {
|
||||
auto tgt = hnsw_stats->access_cnt_map.find(query_stats[i].accessed_points[j]);
|
||||
if (tgt == hnsw_stats->access_cnt_map.end())
|
||||
hnsw_stats->access_cnt_map[query_stats[i].accessed_points[j]] = 1;
|
||||
else
|
||||
tgt->second += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Query finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
@ -217,5 +282,14 @@ IndexHNSW::UpdateIndexSize() {
|
|||
index_size_ = index_->cal_size();
|
||||
}
|
||||
|
||||
void
|
||||
IndexHNSW::ClearStatistics() {
|
||||
if (!STATISTICS_LEVEL)
|
||||
return;
|
||||
auto hnsw_stats = std::static_pointer_cast<LibHNSWStatistics>(stats);
|
||||
auto lock = hnsw_stats->Lock();
|
||||
hnsw_stats->clear();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -26,6 +26,7 @@ class IndexHNSW : public VecIndex {
|
|||
public:
|
||||
IndexHNSW() {
|
||||
index_type_ = IndexEnum::INDEX_HNSW;
|
||||
stats = std::make_shared<milvus::knowhere::LibHNSWStatistics>(index_type_);
|
||||
}
|
||||
|
||||
BinarySet
|
||||
|
@ -46,7 +47,7 @@ class IndexHNSW : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -57,6 +58,9 @@ class IndexHNSW : public VecIndex {
|
|||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
void
|
||||
ClearStatistics() override;
|
||||
|
||||
private:
|
||||
bool normalize = false;
|
||||
std::mutex mutex_;
|
||||
|
|
|
@ -43,11 +43,16 @@ IDMAP::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl(index_type_);
|
||||
auto ret = SerializeImpl(index_type_);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::Load(const BinarySet& binary_set) {
|
||||
Assemble(const_cast<BinarySet&>(binary_set));
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(binary_set, index_type_);
|
||||
}
|
||||
|
@ -95,7 +100,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -229,7 +234,7 @@ IDMAP::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
// assign the metric type
|
||||
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
|
||||
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
|
|
@ -46,7 +46,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -80,7 +80,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
|
||||
protected:
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView&);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
|
|
@ -54,24 +54,37 @@ IVF::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl(index_type_);
|
||||
auto ret = SerializeImpl(index_type_);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
IVF::Load(const BinarySet& binary_set) {
|
||||
Assemble(const_cast<BinarySet&>(binary_set));
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
LoadImpl(binary_set, index_type_);
|
||||
|
||||
if (IndexMode() == IndexMode::MODE_CPU && STATISTICS_LEVEL >= 3) {
|
||||
auto ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
|
||||
ivf_index->nprobe_statistics.resize(ivf_index->nlist, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
|
||||
int64_t nlist = config[IndexParams::nlist].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
auto nlist = config[IndexParams::nlist].get<int64_t>();
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type));
|
||||
index_->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type);
|
||||
index->own_fields = true;
|
||||
index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
index_ = index;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -97,7 +110,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -245,7 +258,7 @@ IVF::UpdateIndexSize() {
|
|||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVFFlat*>(index_.get());
|
||||
auto ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
|
||||
auto nb = ivf_index->invlists->compute_ntotal();
|
||||
auto nlist = ivf_index->nlist;
|
||||
auto code_size = ivf_index->code_size;
|
||||
|
@ -324,7 +337,7 @@ IVF::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
|
||||
|
@ -334,14 +347,30 @@ IVF::QueryImpl(int64_t n,
|
|||
} else {
|
||||
ivf_index->parallel_mode = 0;
|
||||
}
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
ivf_index->search(n, data, k, distances, labels, bitset);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
if (STATISTICS_LEVEL) {
|
||||
auto lock = ivf_stats->Lock();
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
ivf_stats->update_nq(n);
|
||||
ivf_stats->count_nprobe(ivf_index->nprobe);
|
||||
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << ivf_index->index_ivf_stats.quantization_time
|
||||
<< ", data search cost: " << ivf_index->index_ivf_stats.search_time;
|
||||
ivf_stats->update_total_query_time(ivf_index->index_ivf_stats.quantization_time +
|
||||
ivf_index->index_ivf_stats.search_time);
|
||||
ivf_index->index_ivf_stats.quantization_time = 0;
|
||||
ivf_index->index_ivf_stats.search_time = 0;
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 2) {
|
||||
ivf_stats->update_filter_percentage(bitset);
|
||||
}
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexIVF::QueryImpl finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -355,5 +384,30 @@ IVF::SealImpl() {
|
|||
#endif
|
||||
}
|
||||
|
||||
StatisticsPtr
|
||||
IVF::GetStatistics() {
|
||||
if (IndexMode() != IndexMode::MODE_CPU || !STATISTICS_LEVEL) {
|
||||
return stats;
|
||||
}
|
||||
auto ivf_stats = std::static_pointer_cast<IVFStatistics>(stats);
|
||||
auto ivf_index = static_cast<faiss::IndexIVF*>(index_.get());
|
||||
auto lock = ivf_stats->Lock();
|
||||
ivf_stats->update_ivf_access_stats(ivf_index->nprobe_statistics);
|
||||
return ivf_stats;
|
||||
}
|
||||
|
||||
void
|
||||
IVF::ClearStatistics() {
|
||||
if (IndexMode() != IndexMode::MODE_CPU || !STATISTICS_LEVEL) {
|
||||
return;
|
||||
}
|
||||
auto ivf_stats = std::static_pointer_cast<IVFStatistics>(stats);
|
||||
auto ivf_index = static_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->clear_nprobe_statistics();
|
||||
ivf_index->index_ivf_stats.reset();
|
||||
auto lock = ivf_stats->Lock();
|
||||
ivf_stats->clear();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -29,10 +29,12 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
public:
|
||||
IVF() : FaissBaseIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
BinarySet
|
||||
|
@ -51,7 +53,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -67,6 +69,12 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
StatisticsPtr
|
||||
GetStatistics() override;
|
||||
|
||||
void
|
||||
ClearStatistics() override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
@ -86,7 +94,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView&);
|
||||
|
||||
void
|
||||
SealImpl() override;
|
||||
|
|
|
@ -38,11 +38,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFPQ(
|
||||
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m].get<int64_t>(),
|
||||
config[IndexParams::nbits].get<int64_t>(), metric_type));
|
||||
|
||||
index_->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
config[IndexParams::m].get<int64_t>(),
|
||||
config[IndexParams::nbits].get<int64_t>(), metric_type);
|
||||
index->own_fields = true;
|
||||
index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
index_ = index;
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
|
@ -51,7 +52,8 @@ IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
|||
auto ivfpq_index = dynamic_cast<faiss::IndexIVFPQ*>(index_.get());
|
||||
int64_t dim = ivfpq_index->d;
|
||||
int64_t m = ivfpq_index->pq.M;
|
||||
if (!IVFPQConfAdapter::GetValidGPUM(dim, m)) {
|
||||
int64_t nbits = ivfpq_index->pq.nbits;
|
||||
if (!IVFPQConfAdapter::CheckGPUPQParams(dim, m, nbits)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
|
|
|
@ -23,10 +23,12 @@ class IVFPQ : public IVF {
|
|||
public:
|
||||
IVFPQ() : IVF() {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -37,18 +37,13 @@ void
|
|||
IVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
|
||||
// std::stringstream index_type;
|
||||
// index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
// << "SQ" << config[IndexParams::nbits];
|
||||
// index_ = std::shared_ptr<faiss::Index>(
|
||||
// faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>())));
|
||||
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFScalarQuantizer(
|
||||
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit, metric_type));
|
||||
|
||||
index_->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
auto index = std::make_shared<faiss::IndexIVFScalarQuantizer>(
|
||||
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit, metric_type);
|
||||
index->own_fields = true;
|
||||
index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
index_ = index;
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
|
|
|
@ -23,10 +23,12 @@ class IVFSQ : public IVF {
|
|||
public:
|
||||
IVFSQ() : IVF() {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -52,11 +52,15 @@ IndexNGT::Serialize(const Config& config) {
|
|||
res_set.Append("ngt_grp_data", grp_data, grp_size);
|
||||
res_set.Append("ngt_prf_data", prf_data, prf_size);
|
||||
res_set.Append("ngt_tre_data", tre_data, tre_size);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::Load(const BinarySet& index_binary) {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
auto obj_data = index_binary.GetByName("ngt_obj_data");
|
||||
std::string obj_str(reinterpret_cast<char*>(obj_data->data.get()), obj_data->size);
|
||||
|
||||
|
@ -118,13 +122,18 @@ IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
#endif
|
||||
|
||||
DatasetPtr
|
||||
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr);
|
||||
|
||||
size_t k = config[meta::TOPK].get<int64_t>();
|
||||
int k = config[meta::TOPK].get<int>();
|
||||
auto epsilon = config[IndexParams::epsilon].get<float>();
|
||||
auto edge_size = config[IndexParams::max_search_edges].get<int>();
|
||||
if (edge_size == -1) { // pass -1
|
||||
edge_size--;
|
||||
}
|
||||
size_t id_size = sizeof(int64_t) * k;
|
||||
size_t dist_size = sizeof(float) * k;
|
||||
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
|
||||
|
@ -140,11 +149,11 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
|
|||
NGT::Object* object = index_->allocateObject(single_query, Dim());
|
||||
NGT::SearchContainer sc(*object);
|
||||
|
||||
double epsilon = sp.beginOfEpsilon;
|
||||
// double epsilon = sp.beginOfEpsilon;
|
||||
|
||||
NGT::ObjectDistances res;
|
||||
sc.setResults(&res);
|
||||
sc.setSize(sp.size);
|
||||
sc.setSize(static_cast<size_t>(sp.size));
|
||||
sc.setRadius(sp.radius);
|
||||
|
||||
if (sp.accuracy > 0.0) {
|
||||
|
@ -152,7 +161,8 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
|
|||
} else {
|
||||
sc.setEpsilon(epsilon);
|
||||
}
|
||||
sc.setEdgeSize(sp.edgeSize);
|
||||
// sc.setEdgeSize(sp.edgeSize);
|
||||
sc.setEdgeSize(edge_size);
|
||||
|
||||
try {
|
||||
index_->search(sc, bitset);
|
||||
|
@ -164,9 +174,13 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
|
|||
auto local_dist = p_dist + i * k;
|
||||
|
||||
int64_t res_num = res.size();
|
||||
float dis_coefficient = 1.0;
|
||||
if (index_->getObjectSpace().getDistanceType() == NGT::ObjectSpace::DistanceType::DistanceTypeIP) {
|
||||
dis_coefficient = -1.0;
|
||||
}
|
||||
for (int64_t idx = 0; idx < res_num; ++idx) {
|
||||
*(local_id + idx) = res[idx].id - 1;
|
||||
*(local_dist + idx) = res[idx].distance;
|
||||
*(local_dist + idx) = res[idx].distance * dis_coefficient;
|
||||
}
|
||||
while (res_num < static_cast<int64_t>(k)) {
|
||||
*(local_id + res_num) = -1;
|
||||
|
@ -197,5 +211,10 @@ IndexNGT::Dim() {
|
|||
return index_->getDimension();
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::UpdateIndexSize() {
|
||||
KNOWHERE_THROW_MSG("IndexNGT has no implementation of UpdateIndexSize, please use IndexNGT(PANNG/ONNG) instead!");
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -54,7 +54,7 @@ class IndexNGT : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -62,6 +62,9 @@ class IndexNGT : public VecIndex {
|
|||
int64_t
|
||||
Dim() override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<NGT::Index> index_ = nullptr;
|
||||
};
|
||||
|
|
|
@ -38,10 +38,8 @@ IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
|
||||
if (metric_type == Metric::L2) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
} else if (metric_type == Metric::HAMMING) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
} else if (metric_type == Metric::JACCARD) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
} else if (metric_type == Metric::IP) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeIP;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
}
|
||||
|
@ -50,7 +48,7 @@ IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
|
||||
// reconstruct graph
|
||||
NGT::GraphOptimizer graphOptimizer(true);
|
||||
NGT::GraphOptimizer graphOptimizer(false);
|
||||
|
||||
auto number_of_outgoing_edges = config[IndexParams::outgoing_edge_size].get<size_t>();
|
||||
auto number_of_incoming_edges = config[IndexParams::incoming_edge_size].get<size_t>();
|
||||
|
@ -67,5 +65,13 @@ IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
graphOptimizer.execute(*index_);
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGTONNG::UpdateIndexSize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
index_size_ = index_->memSize();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -24,6 +24,9 @@ class IndexNGTONNG : public IndexNGT {
|
|||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -27,16 +27,14 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
prop.dimension = dim;
|
||||
|
||||
auto edge_size = config[IndexParams::edge_size].get<int64_t>();
|
||||
prop.edgeSizeLimitForCreation = edge_size;
|
||||
prop.edgeSizeForCreation = edge_size;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
} else if (metric_type == Metric::HAMMING) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
} else if (metric_type == Metric::JACCARD) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
} else if (metric_type == Metric::IP) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeIP;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
}
|
||||
|
@ -48,6 +46,8 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto selectively_pruned_edge_size = config[IndexParams::selectively_pruned_edge_size].get<int64_t>();
|
||||
|
||||
if (!forcedly_pruned_edge_size && !selectively_pruned_edge_size) {
|
||||
KNOWHERE_THROW_MSG(
|
||||
"a lack of parameters forcedly_pruned_edge_size and selectively_pruned_edge_size 4 index NGTPANNG");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -56,11 +56,23 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("Selectively pruned edge size should less than remaining edge size");
|
||||
}
|
||||
|
||||
// std::map<size_t, size_t> stats;
|
||||
// size_t max_len = 0;
|
||||
|
||||
// prune
|
||||
auto& graph = dynamic_cast<NGT::GraphIndex&>(index_->getIndex());
|
||||
for (size_t id = 1; id < graph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode& node = *graph.getNode(id);
|
||||
// auto sz = node.size();
|
||||
// if (max_len < sz)
|
||||
// max_len = sz;
|
||||
// auto fd = stats.find(sz);
|
||||
// if (fd != stats.end()) {
|
||||
// fd->second ++;
|
||||
// } else {
|
||||
// stats[sz] = 1;
|
||||
// }
|
||||
if (node.size() >= forcedly_pruned_edge_size) {
|
||||
node.resize(forcedly_pruned_edge_size);
|
||||
}
|
||||
|
@ -73,7 +85,7 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
if (t1 >= selectively_pruned_edge_size) {
|
||||
break;
|
||||
}
|
||||
if (rank == t1) {
|
||||
if (rank == t1) { // can't reach here
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode& node2 = *graph.getNode(node[t1].id);
|
||||
|
@ -101,6 +113,25 @@ IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
continue;
|
||||
}
|
||||
}
|
||||
/*
|
||||
std::vector<size_t> cnt(max_len, 0);
|
||||
for (auto &pr : stats) {
|
||||
cnt[pr.first] = pr.second;
|
||||
}
|
||||
for (auto i = 0; i < cnt.size(); ++ i) {
|
||||
if (cnt[i]) {
|
||||
std::cout << "len = " << i << ", cnt = " << cnt[i] << std::endl;
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGTPANNG::UpdateIndexSize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
index_size_ = index_->memSize();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -24,6 +24,9 @@ class IndexNGTPANNG : public IndexNGT {
|
|||
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -48,6 +48,9 @@ NSG::Serialize(const Config& config) {
|
|||
|
||||
BinarySet res_set;
|
||||
res_set.Append("NSG", data, writer.rp);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -57,6 +60,7 @@ NSG::Serialize(const Config& config) {
|
|||
void
|
||||
NSG::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
fiu_do_on("NSG.Load.throw_exception", throw std::exception());
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto binary = index_binary.GetByName("NSG");
|
||||
|
@ -73,7 +77,7 @@ NSG::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ class NSG : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -49,6 +50,7 @@ IndexRHNSW::Serialize(const Config& config) {
|
|||
void
|
||||
IndexRHNSW::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
MemoryIOReader reader;
|
||||
reader.name = this->index_type() + "_Index";
|
||||
auto binary = index_binary.GetByName(reader.name);
|
||||
|
@ -57,6 +59,15 @@ IndexRHNSW::Load(const BinarySet& index_binary) {
|
|||
reader.data_ = binary->data.get();
|
||||
|
||||
auto idx = faiss::read_index(&reader);
|
||||
auto hnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
auto real_idx = static_cast<faiss::IndexRHNSW*>(idx);
|
||||
auto lock = hnsw_stats->Lock();
|
||||
hnsw_stats->update_level_distribution(real_idx->hnsw.max_level, real_idx->hnsw.level_stats);
|
||||
real_idx->set_target_level(hnsw_stats->target_level);
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << hnsw_stats->ToString();
|
||||
}
|
||||
index_.reset(idx);
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -76,10 +87,19 @@ IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
index_->add(rows, reinterpret_cast<const float*>(p_data));
|
||||
auto hnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
auto real_idx = static_cast<faiss::IndexRHNSW*>(index_.get());
|
||||
auto lock = hnsw_stats->Lock();
|
||||
hnsw_stats->update_level_distribution(real_idx->hnsw.max_level, real_idx->hnsw.level_stats);
|
||||
real_idx->set_target_level(hnsw_stats->target_level);
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -90,6 +110,7 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fai
|
|||
int64_t dist_size = sizeof(float) * k;
|
||||
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
|
||||
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
|
||||
auto hnsw_stats = std::dynamic_pointer_cast<RHNSWStatistics>(stats);
|
||||
for (auto i = 0; i < k * rows; ++i) {
|
||||
p_id[i] = -1;
|
||||
p_dist[i] = -1;
|
||||
|
@ -97,8 +118,26 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fai
|
|||
|
||||
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
|
||||
|
||||
real_index->hnsw.efSearch = (config[IndexParams::ef]);
|
||||
real_index->hnsw.efSearch = (config[IndexParams::ef].get<int64_t>());
|
||||
|
||||
std::chrono::high_resolution_clock::time_point query_start, query_end;
|
||||
query_start = std::chrono::high_resolution_clock::now();
|
||||
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, bitset);
|
||||
query_end = std::chrono::high_resolution_clock::now();
|
||||
if (STATISTICS_LEVEL) {
|
||||
auto lock = hnsw_stats->Lock();
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
hnsw_stats->update_nq(rows);
|
||||
hnsw_stats->update_ef_sum(real_index->hnsw.efSearch * rows);
|
||||
hnsw_stats->update_total_query_time(
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(query_end - query_start).count());
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 2) {
|
||||
hnsw_stats->update_filter_percentage(bitset);
|
||||
}
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
@ -122,6 +161,30 @@ IndexRHNSW::Dim() {
|
|||
return index_->d;
|
||||
}
|
||||
|
||||
StatisticsPtr
|
||||
IndexRHNSW::GetStatistics() {
|
||||
if (!STATISTICS_LEVEL) {
|
||||
return stats;
|
||||
}
|
||||
auto hnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
|
||||
auto real_index = static_cast<faiss::IndexRHNSW*>(index_.get());
|
||||
auto lock = hnsw_stats->Lock();
|
||||
real_index->get_sorted_access_counts(hnsw_stats->access_cnt, hnsw_stats->access_total);
|
||||
return hnsw_stats;
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSW::ClearStatistics() {
|
||||
if (!STATISTICS_LEVEL) {
|
||||
return;
|
||||
}
|
||||
auto hnsw_stats = std::static_pointer_cast<RHNSWStatistics>(stats);
|
||||
auto real_index = static_cast<faiss::IndexRHNSW*>(index_.get());
|
||||
real_index->clear_stats();
|
||||
auto lock = hnsw_stats->Lock();
|
||||
hnsw_stats->clear();
|
||||
}
|
||||
|
||||
void
|
||||
IndexRHNSW::UpdateIndexSize() {
|
||||
KNOWHERE_THROW_MSG(
|
||||
|
|
|
@ -30,10 +30,12 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
|
|||
public:
|
||||
IndexRHNSW() : FaissBaseIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INVALID;
|
||||
stats = std::make_shared<milvus::knowhere::RHNSWStatistics>(index_type_);
|
||||
}
|
||||
|
||||
explicit IndexRHNSW(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INVALID;
|
||||
stats = std::make_shared<milvus::knowhere::RHNSWStatistics>(index_type_);
|
||||
}
|
||||
|
||||
BinarySet
|
||||
|
@ -52,7 +54,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -62,6 +64,12 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
|
|||
|
||||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
StatisticsPtr
|
||||
GetStatistics() override;
|
||||
|
||||
void
|
||||
ClearStatistics() override;
|
||||
};
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -52,6 +52,9 @@ IndexRHNSWFlat::Serialize(const Config& config) {
|
|||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -61,6 +64,7 @@ IndexRHNSWFlat::Serialize(const Config& config) {
|
|||
void
|
||||
IndexRHNSWFlat::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
IndexRHNSW::Load(index_binary);
|
||||
MemoryIOReader reader;
|
||||
reader.name = this->index_type() + "_Data";
|
||||
|
|
|
@ -48,6 +48,9 @@ IndexRHNSWPQ::Serialize(const Config& config) {
|
|||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -57,6 +60,7 @@ IndexRHNSWPQ::Serialize(const Config& config) {
|
|||
void
|
||||
IndexRHNSWPQ::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
IndexRHNSW::Load(index_binary);
|
||||
MemoryIOReader reader;
|
||||
reader.name = QUANTIZATION_DATA;
|
||||
|
|
|
@ -51,6 +51,9 @@ IndexRHNSWSQ::Serialize(const Config& config) {
|
|||
std::shared_ptr<uint8_t[]> data(writer.data_);
|
||||
|
||||
res_set.Append(writer.name, data, writer.rp);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -60,6 +63,7 @@ IndexRHNSWSQ::Serialize(const Config& config) {
|
|||
void
|
||||
IndexRHNSWSQ::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
IndexRHNSW::Load(index_binary);
|
||||
MemoryIOReader reader;
|
||||
reader.name = QUANTIZATION_DATA;
|
||||
|
|
|
@ -83,11 +83,15 @@ CPUSPTAGRNG::Serialize(const Config& config) {
|
|||
binary_set.Append("config", x_cfg, length);
|
||||
binary_set.Append("graph", graph, index_blobs[2].Length());
|
||||
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, binary_set);
|
||||
}
|
||||
return binary_set;
|
||||
}
|
||||
|
||||
void
|
||||
CPUSPTAGRNG::Load(const BinarySet& binary_set) {
|
||||
Assemble(const_cast<BinarySet&>(binary_set));
|
||||
std::string index_config;
|
||||
std::vector<SPTAG::ByteArray> index_blobs;
|
||||
|
||||
|
@ -176,7 +180,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
SetParameters(config);
|
||||
|
||||
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);
|
||||
|
|
|
@ -52,7 +52,7 @@ class CPUSPTAGRNG : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -0,0 +1,208 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/Statistics.h"
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
int32_t STATISTICS_LEVEL = 0;
|
||||
|
||||
std::string
|
||||
Statistics::ToString() {
|
||||
std::ostringstream ret;
|
||||
|
||||
if (STATISTICS_LEVEL == 0) {
|
||||
ret << "There is nothing because configuration STATISTICS_LEVEL = 0" << std::endl;
|
||||
return ret.str();
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
ret << "Total batches: " << batch_cnt << std::endl;
|
||||
ret << "Total queries: " << nq_cnt << std::endl;
|
||||
ret << "Qps: " << Qps() << std::endl;
|
||||
|
||||
ret << "The frequency distribution of the num of queries:" << std::endl;
|
||||
size_t left = 1, right = 1;
|
||||
for (size_t i = 0; i < NQ_Histogram_Slices - 1; i++) {
|
||||
ret << "[" << left << ", " << right << "].count = " << nq_stat[i] << std::endl;
|
||||
left = right + 1;
|
||||
right <<= 1;
|
||||
}
|
||||
ret << "[" << left << ", +00).count = " << nq_stat.back() << std::endl;
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 2) {
|
||||
ret << "The frequency distribution of filter: " << std::endl;
|
||||
for (auto i = 0; i < 20; ++i) {
|
||||
ret << "[" << i * 5 << "%, " << i * 5 + 5 << "%).count = " << filter_stat[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return ret.str();
|
||||
}
|
||||
|
||||
std::string
|
||||
HNSWStatistics::ToString() {
|
||||
std::ostringstream ret;
|
||||
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
ret << "Avg Ef: " << AvgSearchEf() << std::endl;
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
std::vector<size_t> axis_x = {5, 10, 20, 40};
|
||||
std::vector<double> access_cdf = AccessCDF(axis_x);
|
||||
ret << "There are " << access_total << " times point-access at level " << target_level << std::endl;
|
||||
ret << "The CDF at level " << target_level << ":" << std::endl;
|
||||
for (auto i = 0; i < axis_x.size(); ++i) {
|
||||
ret << "(" << axis_x[i] << "," << access_cdf[i] << ") ";
|
||||
}
|
||||
ret << std::endl;
|
||||
ret << "Level distribution: " << std::endl;
|
||||
size_t point_cnt = 0;
|
||||
for (int i = distribution.size() - 1; i >= 0; i--) {
|
||||
point_cnt += distribution[i];
|
||||
ret << "Level " << i << " has " << point_cnt << " points" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return Statistics::ToString() + ret.str();
|
||||
}
|
||||
|
||||
std::vector<size_t>
|
||||
GenSplitIndex(size_t size, const std::vector<size_t>& axis_x) {
|
||||
// Gen split index
|
||||
std::vector<size_t> split_idx(axis_x.size());
|
||||
for (size_t i = 0; i < axis_x.size(); i++) {
|
||||
if (axis_x[i] >= 100) {
|
||||
// for safe, not to let idx be larger than size
|
||||
split_idx[i] = size;
|
||||
} else {
|
||||
split_idx[i] = (axis_x[i] * size + 50) / 100;
|
||||
}
|
||||
}
|
||||
return split_idx;
|
||||
}
|
||||
|
||||
std::vector<double>
|
||||
CaculateCDF(size_t access_total, const std::vector<size_t>& access_cnt, const std::vector<size_t>& axis_x) {
|
||||
auto split_idx = GenSplitIndex(access_cnt.size(), axis_x);
|
||||
|
||||
// count cdf
|
||||
std::vector<double> access_cdf;
|
||||
access_cdf.resize(split_idx.size(), 0.0);
|
||||
|
||||
size_t idx = 0;
|
||||
size_t tmp_cnt = 0;
|
||||
for (size_t i = 0; i < split_idx.size(); ++i) {
|
||||
if (i != 0 && split_idx[i] < split_idx[i - 1]) {
|
||||
// wrong split_idx
|
||||
// Todo: log output
|
||||
access_cdf[i] = 0;
|
||||
} else {
|
||||
while (idx < split_idx[i]) {
|
||||
tmp_cnt += access_cnt[idx];
|
||||
idx++;
|
||||
}
|
||||
access_cdf[i] = static_cast<double>(tmp_cnt) / static_cast<double>(access_total);
|
||||
}
|
||||
}
|
||||
|
||||
return access_cdf;
|
||||
}
|
||||
|
||||
std::vector<double>
|
||||
LibHNSWStatistics::AccessCDF(const std::vector<size_t>& axis_x) {
|
||||
// copy from std::map to std::vector
|
||||
std::vector<size_t> access_cnt;
|
||||
access_cnt.reserve(access_cnt_map.size());
|
||||
access_total = 0;
|
||||
for (auto& elem : access_cnt_map) {
|
||||
access_cnt.push_back(elem.second);
|
||||
access_total += elem.second;
|
||||
}
|
||||
std::sort(access_cnt.begin(), access_cnt.end(), std::greater<>());
|
||||
|
||||
return CaculateCDF(access_total, access_cnt, axis_x);
|
||||
}
|
||||
|
||||
std::vector<double>
|
||||
RHNSWStatistics::AccessCDF(const std::vector<size_t>& axis_x) {
|
||||
return CaculateCDF(access_total, access_cnt, axis_x);
|
||||
}
|
||||
|
||||
std::string
|
||||
IVFStatistics::ToString() {
|
||||
std::ostringstream ret;
|
||||
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
ret << "nlist " << Nlist() << std::endl;
|
||||
ret << "(nprobe, count): " << std::endl;
|
||||
auto nprobe = SearchNprobe();
|
||||
for (auto& it : nprobe) {
|
||||
ret << "(" << it.first << ", " << it.second << ") ";
|
||||
}
|
||||
ret << std::endl;
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
std::vector<size_t> axis_x = {5, 10, 20, 40};
|
||||
ret << "Bucket CDF " << std::endl;
|
||||
auto output = AccessCDF(axis_x);
|
||||
for (int i = 0; i < output.size(); i++) {
|
||||
ret << "Top " << axis_x[i] << "% access count " << output[i] << std::endl;
|
||||
}
|
||||
ret << std::endl;
|
||||
}
|
||||
return Statistics::ToString() + ret.str();
|
||||
}
|
||||
|
||||
void
|
||||
IVFStatistics::count_nprobe(const int64_t nprobe) {
|
||||
// nprobe count
|
||||
auto it = nprobe_count.find(nprobe);
|
||||
if (it == nprobe_count.end()) {
|
||||
nprobe_count[nprobe] = 1;
|
||||
} else {
|
||||
it->second++;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
IVFStatistics::update_ivf_access_stats(const std::vector<size_t>& nprobe_statistics) {
|
||||
nlist = nprobe_statistics.size();
|
||||
access_total = 0;
|
||||
access_cnt = nprobe_statistics;
|
||||
|
||||
std::sort(access_cnt.begin(), access_cnt.end(), std::greater<>());
|
||||
// access total
|
||||
for (auto& cnt : access_cnt) {
|
||||
access_total += cnt;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<double>
|
||||
IVFStatistics::AccessCDF(const std::vector<size_t>& axis_x) {
|
||||
return CaculateCDF(access_total, access_cnt, axis_x);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,377 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
extern int32_t STATISTICS_LEVEL;
|
||||
|
||||
inline uint64_t
|
||||
upper_bound_of_pow2(uint64_t x) {
|
||||
--x;
|
||||
x |= (x >> 1);
|
||||
x |= (x >> 2);
|
||||
x |= (x >> 4);
|
||||
x |= (x >> 8);
|
||||
x |= (x >> 16);
|
||||
x |= (x >> 32);
|
||||
return x + 1;
|
||||
}
|
||||
|
||||
inline int
|
||||
len_of_pow2(uint64_t x) {
|
||||
return __builtin_popcountl(x - 1);
|
||||
}
|
||||
|
||||
/*
|
||||
* class: Statistics
|
||||
*/
|
||||
class Statistics {
|
||||
public:
|
||||
static const size_t NQ_Histogram_Slices = 13;
|
||||
static const size_t Filter_Histogram_Slices = 21;
|
||||
|
||||
explicit Statistics(std::string& idx_t)
|
||||
: index_type(idx_t),
|
||||
nq_cnt(0),
|
||||
batch_cnt(0),
|
||||
total_query_time(0.0),
|
||||
nq_stat(NQ_Histogram_Slices, 0),
|
||||
filter_stat(Filter_Histogram_Slices, 0),
|
||||
update_lock() {
|
||||
}
|
||||
|
||||
/*
|
||||
* Get index type
|
||||
* @retval: index type in string
|
||||
*/
|
||||
const std::string&
|
||||
IndexType() {
|
||||
return index_type;
|
||||
}
|
||||
|
||||
/*
|
||||
* To string (may be for log output)
|
||||
* @retval: string output
|
||||
*/
|
||||
virtual std::string
|
||||
ToString();
|
||||
|
||||
virtual ~Statistics() = default;
|
||||
|
||||
/*
|
||||
* Get batch count of the queries (Level 1)
|
||||
* @retval: query batch count
|
||||
*/
|
||||
size_t
|
||||
BatchCount() {
|
||||
return batch_cnt;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the statistics of the nq (Level 1)
|
||||
* @retval: count nq 1, 2, 3~4, 5~8, 9~16,…, 1024~2048, larger than 2048 (13 slices)
|
||||
*/
|
||||
const std::vector<size_t>&
|
||||
NQHistogram() {
|
||||
return nq_stat;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get query response per-second (Level 1)
|
||||
* @retval: Qps
|
||||
*/
|
||||
double
|
||||
Qps() {
|
||||
// ms -> s
|
||||
return total_query_time ? (nq_cnt * 1000.0 / total_query_time) : 0.0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the statistics of the filter for each batch (Level 2)
|
||||
* @retval: count 0~5%, 5~10%, 10~15%, ...95~100%, 100% (21 slices)
|
||||
*/
|
||||
const std::vector<size_t>&
|
||||
FilterHistograms() {
|
||||
return filter_stat;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex>
|
||||
Lock() {
|
||||
return std::unique_lock<std::mutex>(update_lock);
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
update_nq(const int64_t nq) {
|
||||
// batch
|
||||
batch_cnt++;
|
||||
|
||||
// nq_cnt
|
||||
nq_cnt += static_cast<size_t>(nq);
|
||||
|
||||
// nq_stat
|
||||
if (nq > 2048) {
|
||||
nq_stat[12]++;
|
||||
} else {
|
||||
nq_stat[len_of_pow2(upper_bound_of_pow2(static_cast<size_t>(nq)))]++;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
update_total_query_time(const double query_time) {
|
||||
total_query_time += query_time;
|
||||
}
|
||||
|
||||
void
|
||||
update_filter_percentage(const faiss::BitsetView& bitset) {
|
||||
double fps = !bitset.empty() ? static_cast<double>(bitset.count_1()) / bitset.size() : 0.0;
|
||||
filter_stat[static_cast<int>(fps * 100) / 5] += 1;
|
||||
}
|
||||
|
||||
virtual void
|
||||
clear() {
|
||||
total_query_time = 0.0;
|
||||
nq_cnt = 0;
|
||||
batch_cnt = 0;
|
||||
nq_stat.resize(NQ_Histogram_Slices, 0);
|
||||
filter_stat.resize(Filter_Histogram_Slices, 0);
|
||||
}
|
||||
|
||||
public:
|
||||
std::string& index_type;
|
||||
size_t batch_cnt; // updated in query
|
||||
size_t nq_cnt; // updated in query
|
||||
double total_query_time; // updated in query (unit: ms)
|
||||
std::vector<size_t> nq_stat; // updated in query
|
||||
std::vector<size_t> filter_stat; // updated in query
|
||||
std::mutex update_lock;
|
||||
};
|
||||
using StatisticsPtr = std::shared_ptr<Statistics>;
|
||||
|
||||
/*
|
||||
* class: HNSWStatistics
|
||||
*/
|
||||
class HNSWStatistics : public Statistics {
|
||||
public:
|
||||
explicit HNSWStatistics(std::string& idx_t)
|
||||
: Statistics(idx_t), distribution(), target_level(1), access_total(0), ef_sum(0) {
|
||||
}
|
||||
|
||||
~HNSWStatistics() override = default;
|
||||
|
||||
/*
|
||||
* To string (may be for log output)
|
||||
* @retval: string output
|
||||
*/
|
||||
std::string
|
||||
ToString() override;
|
||||
|
||||
/*
|
||||
* Get nodes count in each level
|
||||
* @retval: none
|
||||
*/
|
||||
const std::vector<size_t>&
|
||||
LevelNodesNum() {
|
||||
return distribution;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get average search parameter ‘ef’ (average for batches) (Level 1)
|
||||
* @retval: avg Ef
|
||||
*/
|
||||
double
|
||||
AvgSearchEf() {
|
||||
return nq_cnt ? ef_sum / nq_cnt : 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Cumulative distribution function of nodes access (Level 3)
|
||||
* @param: none (axis_x = {5,10,15,20,...100} by default)
|
||||
* @retval: Access CDF
|
||||
*/
|
||||
virtual std::vector<double>
|
||||
AccessCDF() {
|
||||
std::vector<size_t> axis_x(20);
|
||||
for (size_t i = 0; i < 20; ++i) {
|
||||
axis_x[i] = (i + 1) * 5;
|
||||
}
|
||||
|
||||
return AccessCDF(axis_x);
|
||||
}
|
||||
|
||||
/*
|
||||
* Cumulative distribution function of nodes access
|
||||
* @param: axis_x[in] specified by users and should be in ascending order
|
||||
* @retval: Access CDF
|
||||
*/
|
||||
virtual std::vector<double>
|
||||
AccessCDF(const std::vector<size_t>& axis_x) = 0;
|
||||
|
||||
public:
|
||||
void
|
||||
update_ef_sum(const int64_t ef) {
|
||||
ef_sum += ef;
|
||||
}
|
||||
|
||||
void
|
||||
update_level_distribution(const int max_level, const std::vector<int>& levels) {
|
||||
distribution.resize(max_level + 1);
|
||||
for (auto i = 0; i <= max_level; ++i) {
|
||||
distribution[i] = levels[i];
|
||||
if (distribution[i] >= 1000 && distribution[i] < 10000) {
|
||||
target_level = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
clear() override {
|
||||
Statistics::clear();
|
||||
access_total = 0;
|
||||
ef_sum = 0;
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<size_t> distribution;
|
||||
size_t target_level;
|
||||
size_t access_total; // depend on subclass type
|
||||
size_t ef_sum; // updated in query
|
||||
};
|
||||
|
||||
/*
|
||||
* class: LibHNSWStatistics
|
||||
* for index: HNSW
|
||||
*/
|
||||
class LibHNSWStatistics : public HNSWStatistics {
|
||||
public:
|
||||
explicit LibHNSWStatistics(std::string& idx_t) : HNSWStatistics(idx_t), access_cnt_map() {
|
||||
}
|
||||
|
||||
~LibHNSWStatistics() override = default;
|
||||
|
||||
std::vector<double>
|
||||
AccessCDF(const std::vector<size_t>& axis_x) override;
|
||||
|
||||
public:
|
||||
void
|
||||
clear() override {
|
||||
HNSWStatistics::clear();
|
||||
access_cnt_map.clear();
|
||||
}
|
||||
|
||||
public:
|
||||
std::unordered_map<int64_t, size_t> access_cnt_map; // updated in query
|
||||
};
|
||||
|
||||
/*
|
||||
* class: RHNSWStatistics
|
||||
* for index: RHNSW_FLAT, RHNSE_SQ, RHNSW_PQ
|
||||
*/
|
||||
class RHNSWStatistics : public HNSWStatistics {
|
||||
public:
|
||||
explicit RHNSWStatistics(std::string& idx_t) : HNSWStatistics(idx_t), access_cnt() {
|
||||
}
|
||||
|
||||
~RHNSWStatistics() override = default;
|
||||
|
||||
std::vector<double>
|
||||
AccessCDF(const std::vector<size_t>& axis_x) override;
|
||||
|
||||
public:
|
||||
std::vector<size_t> access_cnt; // prepared in GetStatistics
|
||||
};
|
||||
|
||||
/*
|
||||
* class: IVFStatistics
|
||||
* for index: IVF_FLAT, IVF_PQ, IVF_SQ8
|
||||
*/
|
||||
class IVFStatistics : public Statistics {
|
||||
public:
|
||||
explicit IVFStatistics(std::string& idx_t) : Statistics(idx_t), nprobe_count(), access_cnt(), nlist(0) {
|
||||
}
|
||||
|
||||
~IVFStatistics() override = default;
|
||||
|
||||
/*
|
||||
* To string (may be for log output)
|
||||
* @retval: string output
|
||||
*/
|
||||
std::string
|
||||
ToString() override;
|
||||
|
||||
/*
|
||||
* Get the statistics of the search parameter nprboe (count of batches) (Level 1)
|
||||
* @retval: nprobe
|
||||
*/
|
||||
int64_t
|
||||
Nlist() {
|
||||
return nlist;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the statistics of the search parameter nprboe (count of batches) (Level 1)
|
||||
* @retval: <nprobe, count>
|
||||
*/
|
||||
std::unordered_map<int64_t, size_t>
|
||||
SearchNprobe() {
|
||||
auto lock = Lock();
|
||||
auto rst = nprobe_count;
|
||||
lock.unlock();
|
||||
return rst;
|
||||
}
|
||||
|
||||
/*
|
||||
* Cumulative distribution function of bucket access (Level 3)
|
||||
* @param: axis_x[in] specified by users and should be in ascending order
|
||||
* @retval: Access CDF
|
||||
*/
|
||||
std::vector<double>
|
||||
AccessCDF(const std::vector<size_t>& axis_x);
|
||||
|
||||
public:
|
||||
void
|
||||
count_nprobe(const int64_t nprobe);
|
||||
|
||||
void
|
||||
update_ivf_access_stats(const std::vector<size_t>& nprobe_statistics);
|
||||
|
||||
void
|
||||
clear() override {
|
||||
Statistics::clear();
|
||||
nprobe_count.clear();
|
||||
access_total = 0;
|
||||
}
|
||||
|
||||
public:
|
||||
std::unordered_map<int64_t, size_t> nprobe_count; // updated in query
|
||||
std::vector<size_t> access_cnt; // prepared in GetStatistics
|
||||
size_t access_total; // updated in query
|
||||
size_t nlist;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/utils/BitsetView.h>
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
@ -19,8 +20,10 @@
|
|||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Typedef.h"
|
||||
#include "knowhere/common/Utils.h"
|
||||
#include "knowhere/index/Index.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/Statistics.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -46,7 +49,7 @@ class VecIndex : public Index {
|
|||
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual DatasetPtr
|
||||
Query(const DatasetPtr& dataset, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) = 0;
|
||||
Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset) = 0;
|
||||
|
||||
#if 0
|
||||
virtual DatasetPtr
|
||||
|
@ -67,6 +70,15 @@ class VecIndex : public Index {
|
|||
virtual int64_t
|
||||
Count() = 0;
|
||||
|
||||
virtual StatisticsPtr
|
||||
GetStatistics() {
|
||||
return stats;
|
||||
}
|
||||
|
||||
virtual void
|
||||
ClearStatistics() {
|
||||
}
|
||||
|
||||
virtual IndexType
|
||||
index_type() const {
|
||||
return index_type_;
|
||||
|
@ -84,39 +96,19 @@ class VecIndex : public Index {
|
|||
}
|
||||
#endif
|
||||
|
||||
faiss::ConcurrentBitsetPtr
|
||||
GetBlacklist() {
|
||||
return bitset_;
|
||||
}
|
||||
|
||||
void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr bitset_ptr) {
|
||||
bitset_ = std::move(bitset_ptr);
|
||||
}
|
||||
|
||||
const std::vector<IDType>&
|
||||
std::shared_ptr<std::vector<IDType>>
|
||||
GetUids() const {
|
||||
return uids_;
|
||||
}
|
||||
|
||||
void
|
||||
SetUids(std::vector<IDType>& uids) {
|
||||
uids_.clear();
|
||||
uids_.swap(uids);
|
||||
}
|
||||
|
||||
size_t
|
||||
BlacklistSize() {
|
||||
if (bitset_) {
|
||||
return bitset_->u8size() * sizeof(uint8_t);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
SetUids(std::shared_ptr<std::vector<IDType>> uids) {
|
||||
uids_ = uids;
|
||||
}
|
||||
|
||||
size_t
|
||||
UidsSize() {
|
||||
return uids_.size() * sizeof(IDType);
|
||||
return uids_ ? uids_->size() * sizeof(IDType) : 0;
|
||||
}
|
||||
|
||||
virtual int64_t
|
||||
|
@ -138,17 +130,15 @@ class VecIndex : public Index {
|
|||
|
||||
int64_t
|
||||
Size() override {
|
||||
return BlacklistSize() + UidsSize() + IndexSize();
|
||||
return UidsSize() + IndexSize();
|
||||
}
|
||||
|
||||
protected:
|
||||
IndexType index_type_ = "";
|
||||
IndexMode index_mode_ = IndexMode::MODE_CPU;
|
||||
std::vector<IDType> uids_;
|
||||
std::shared_ptr<std::vector<IDType>> uids_ = nullptr;
|
||||
int64_t index_size_ = -1;
|
||||
|
||||
private:
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
StatisticsPtr stats = nullptr;
|
||||
};
|
||||
|
||||
using VecIndexPtr = std::shared_ptr<VecIndex>;
|
||||
|
|
|
@ -110,7 +110,7 @@ GPUIDMAP::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
|
||||
// assign the metric type
|
||||
|
|
|
@ -55,8 +55,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
};
|
||||
|
||||
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
|
||||
|
|
|
@ -143,7 +143,7 @@ GPUIVF::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
|
|
|
@ -51,8 +51,8 @@ class GPUIVF : public IVF, public GPUIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
};
|
||||
|
||||
using GPUIVFPtr = std::shared_ptr<GPUIVF>;
|
||||
|
|
|
@ -247,7 +247,7 @@ IVFSQHybrid::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
if (gpu_mode_ == 2) {
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset);
|
||||
// index_->search(n, (float*)data, k, distances, labels);
|
||||
|
|
|
@ -88,8 +88,8 @@ class IVFSQHybrid : public GPUIVFSQ {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
|
||||
protected:
|
||||
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU
|
||||
|
|
|
@ -27,10 +27,7 @@ namespace cloner {
|
|||
|
||||
void
|
||||
CopyIndexData(const VecIndexPtr& dst_index, const VecIndexPtr& src_index) {
|
||||
/* do real copy */
|
||||
auto uids = src_index->GetUids();
|
||||
dst_index->SetUids(uids);
|
||||
dst_index->SetBlacklist(src_index->GetBlacklist());
|
||||
dst_index->SetUids(src_index->GetUids());
|
||||
dst_index->SetIndexSize(src_index->IndexSize());
|
||||
}
|
||||
|
||||
|
|
|
@ -63,10 +63,5 @@ MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) {
|
|||
return nitems;
|
||||
}
|
||||
|
||||
void
|
||||
enable_faiss_logging() {
|
||||
faiss::LOG_DEBUG_ = &log_debug_;
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <faiss/impl/io.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -47,8 +46,5 @@ struct MemoryIOReader : public faiss::IOReader {
|
|||
}
|
||||
};
|
||||
|
||||
void
|
||||
enable_faiss_logging();
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -54,6 +54,9 @@ constexpr const char* PQM = "PQM";
|
|||
|
||||
// NGT Params
|
||||
constexpr const char* edge_size = "edge_size";
|
||||
// NGT Search Params
|
||||
constexpr const char* epsilon = "epsilon";
|
||||
constexpr const char* max_search_edges = "max_search_edges";
|
||||
// NGT_PANNG Params
|
||||
constexpr const char* forcedly_pruned_edge_size = "forcedly_pruned_edge_size";
|
||||
constexpr const char* selectively_pruned_edge_size = "selectively_pruned_edge_size";
|
||||
|
|
|
@ -864,7 +864,7 @@ NsgIndex::Search(const float* query,
|
|||
float* dist,
|
||||
int64_t* ids,
|
||||
SearchParams& params,
|
||||
faiss::ConcurrentBitsetPtr bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
std::vector<std::vector<Neighbor>> resset(nq);
|
||||
|
||||
TimeRecorder rc("NsgIndex::search", 1);
|
||||
|
@ -886,7 +886,7 @@ NsgIndex::Search(const float* query,
|
|||
if (pos >= k) {
|
||||
break; // already top k
|
||||
}
|
||||
if (!bitset || !bitset->test(node.id)) {
|
||||
if (!bitset || !bitset.test(node.id)) {
|
||||
ids[i * k + pos] = ids_[node.id];
|
||||
dist[i * k + pos] = is_ip ? -node.distance : node.distance;
|
||||
++pos;
|
||||
|
|
|
@ -91,7 +91,7 @@ class NsgIndex {
|
|||
float* dist,
|
||||
int64_t* ids,
|
||||
SearchParams& params,
|
||||
faiss::ConcurrentBitsetPtr bitset = nullptr);
|
||||
const faiss::BitsetView& bitset = nullptr);
|
||||
|
||||
int64_t
|
||||
GetSize();
|
||||
|
|
|
@ -53,23 +53,32 @@ IVF_NM::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl(index_type_);
|
||||
auto ret = SerializeImpl(index_type_);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
IVF_NM::Load(const BinarySet& binary_set) {
|
||||
Assemble(const_cast<BinarySet&>(binary_set));
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(binary_set, index_type_);
|
||||
|
||||
// Construct arranged data from original data
|
||||
auto binary = binary_set.GetByName(RAW_DATA);
|
||||
auto original_data = reinterpret_cast<const float*>(binary->data.get());
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
auto ivf_index = static_cast<faiss::IndexIVF*>(index_.get());
|
||||
auto invlists = ivf_index->invlists;
|
||||
auto d = ivf_index->d;
|
||||
prefix_sum.resize(invlists->nlist);
|
||||
size_t curr_index = 0;
|
||||
|
||||
if (STATISTICS_LEVEL >= 3) {
|
||||
ivf_index->nprobe_statistics.resize(invlists->nlist, 0);
|
||||
}
|
||||
|
||||
#ifndef MILVUS_GPU_VERSION
|
||||
auto ails = dynamic_cast<faiss::ArrayInvertedLists*>(invlists);
|
||||
size_t nb = binary->size / invlists->code_size;
|
||||
|
@ -102,17 +111,22 @@ IVF_NM::Load(const BinarySet& binary_set) {
|
|||
ro_codes = rol->pin_readonly_codes;
|
||||
data_ = nullptr;
|
||||
#endif
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexIVF_FLAT::Load finished, show statistics:";
|
||||
// auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
// LOG_KNOWHERE_DEBUG_ << ivf_stats->ToString();
|
||||
}
|
||||
|
||||
void
|
||||
IVF_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr)
|
||||
|
||||
int64_t nlist = config[IndexParams::nlist].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
auto nlist = config[IndexParams::nlist].get<int64_t>();
|
||||
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type));
|
||||
index_->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
auto coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type);
|
||||
index->own_fields = true;
|
||||
index->train(rows, reinterpret_cast<const float*>(p_data));
|
||||
index_ = index;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -138,7 +152,7 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -311,7 +325,7 @@ IVF_NM::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
|
@ -328,16 +342,31 @@ IVF_NM::QueryImpl(int64_t n,
|
|||
#else
|
||||
auto data = static_cast<const uint8_t*>(ro_codes->data);
|
||||
#endif
|
||||
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
ivf_index->search_without_codes(n, reinterpret_cast<const float*>(query), data, prefix_sum, is_sq8, k, distances,
|
||||
labels, bitset);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
<< ", quantization cost: " << ivf_index->index_ivf_stats.quantization_time
|
||||
<< ", data search cost: " << ivf_index->index_ivf_stats.search_time;
|
||||
|
||||
if (STATISTICS_LEVEL) {
|
||||
auto lock = ivf_stats->Lock();
|
||||
if (STATISTICS_LEVEL >= 1) {
|
||||
ivf_stats->update_nq(n);
|
||||
ivf_stats->count_nprobe(ivf_index->nprobe);
|
||||
ivf_stats->update_total_query_time(ivf_index->index_ivf_stats.quantization_time +
|
||||
ivf_index->index_ivf_stats.search_time);
|
||||
ivf_index->index_ivf_stats.quantization_time = 0;
|
||||
ivf_index->index_ivf_stats.search_time = 0;
|
||||
}
|
||||
if (STATISTICS_LEVEL >= 2) {
|
||||
ivf_stats->update_filter_percentage(bitset);
|
||||
}
|
||||
}
|
||||
// LOG_KNOWHERE_DEBUG_ << "IndexIVF_FLAT::QueryImpl finished, show statistics:";
|
||||
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -380,5 +409,30 @@ IVF_NM::UpdateIndexSize() {
|
|||
index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size;
|
||||
}
|
||||
|
||||
StatisticsPtr
|
||||
IVF_NM::GetStatistics() {
|
||||
if (!STATISTICS_LEVEL) {
|
||||
return stats;
|
||||
}
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
auto lock = ivf_stats->Lock();
|
||||
ivf_stats->update_ivf_access_stats(ivf_index->nprobe_statistics);
|
||||
return ivf_stats;
|
||||
}
|
||||
|
||||
void
|
||||
IVF_NM::ClearStatistics() {
|
||||
if (!STATISTICS_LEVEL) {
|
||||
return;
|
||||
}
|
||||
auto ivf_stats = std::dynamic_pointer_cast<IVFStatistics>(stats);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->clear_nprobe_statistics();
|
||||
ivf_index->index_ivf_stats.reset();
|
||||
auto lock = ivf_stats->Lock();
|
||||
ivf_stats->clear();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -29,10 +29,12 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
|
|||
public:
|
||||
IVF_NM() : OffsetBaseIndex(nullptr) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
explicit IVF_NM(std::shared_ptr<faiss::Index> index) : OffsetBaseIndex(std::move(index)) {
|
||||
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
|
||||
stats = std::make_shared<milvus::knowhere::IVFStatistics>(index_type_);
|
||||
}
|
||||
|
||||
BinarySet
|
||||
|
@ -51,7 +53,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -67,6 +69,12 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
|
|||
void
|
||||
UpdateIndexSize() override;
|
||||
|
||||
StatisticsPtr
|
||||
GetStatistics() override;
|
||||
|
||||
void
|
||||
ClearStatistics() override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
@ -86,8 +94,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
|
|||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset);
|
||||
|
||||
void
|
||||
SealImpl() override;
|
||||
|
|
|
@ -47,6 +47,9 @@ NSG_NM::Serialize(const Config& config) {
|
|||
|
||||
BinarySet res_set;
|
||||
res_set.Append("NSG_NM", data, writer.rp);
|
||||
if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
|
||||
Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>() * 1024 * 1024, res_set);
|
||||
}
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -56,6 +59,7 @@ NSG_NM::Serialize(const Config& config) {
|
|||
void
|
||||
NSG_NM::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
Assemble(const_cast<BinarySet&>(index_binary));
|
||||
fiu_do_on("NSG_NM.Load.throw_exception", throw std::exception());
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto binary = index_binary.GetByName("NSG_NM");
|
||||
|
@ -74,7 +78,7 @@ NSG_NM::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ class NSG_NM : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -124,7 +124,7 @@ GPUIVF_NM::QueryImpl(int64_t n,
|
|||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
const faiss::BitsetView& bitset) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
|
|
|
@ -51,8 +51,8 @@ class GPUIVF_NM : public IVF, public GPUIndex {
|
|||
SerializeImpl(const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
QueryImpl(
|
||||
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::BitsetView& bitset) override;
|
||||
|
||||
protected:
|
||||
uint8_t* arranged_data;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "NGT/Index.h"
|
||||
#include "defines.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -185,8 +186,10 @@ class Clustering {
|
|||
}
|
||||
}
|
||||
if ((numberOfClusters != 0) && (clusters.size() < numberOfClusters)) {
|
||||
std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
|
||||
<< std::endl;
|
||||
// std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
|
||||
// << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("initial cluster data are not enough. " + std::to_string(clusters.size()) + ":" + std::to_string(numberOfClusters));
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
@ -365,7 +368,9 @@ class Clustering {
|
|||
for (auto soi = sortedObjects.rbegin(); soi != sortedObjects.rend();) {
|
||||
Entry& entry = *soi;
|
||||
if (entry.centroidID >= clusters.size()) {
|
||||
std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
|
||||
// std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Something wrong. " + std::to_string(entry.centroidID) + ":" + std::to_string(clusters.size()));
|
||||
soi++;
|
||||
continue;
|
||||
}
|
||||
|
@ -547,7 +552,9 @@ class Clustering {
|
|||
distance += distanceL2((*it).centroid, mean);
|
||||
(*it).centroid = mean;
|
||||
} else {
|
||||
cerr << "Clustering: Fatal Error. No member!" << endl;
|
||||
// cerr << "Clustering: Fatal Error. No member!" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Clustering: Fatal Error. No member!");
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
@ -579,7 +586,9 @@ class Clustering {
|
|||
|
||||
double diff = 0;
|
||||
for (size_t i = 0; i < maximumIteration; i++) {
|
||||
std::cerr << "iteration=" << i << std::endl;
|
||||
// std::cerr << "iteration=" << i << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("iteration=" + std::to_string(i));
|
||||
assign(vectors, clusters, clusterSize);
|
||||
// centroid is recomputed.
|
||||
// diff is distance between the current centroids and the previous centroids.
|
||||
|
@ -610,7 +619,9 @@ class Clustering {
|
|||
std::vector<Cluster> prevClusters = clusters;
|
||||
diff = calculateCentroid(vectors, clusters);
|
||||
timer.stop();
|
||||
std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
|
||||
// std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("iteration=" + std::to_string(i) + " time=" + std::to_string(timer.time)+ " diff=" + std::to_string(diff));
|
||||
timer.start();
|
||||
diffHistory.push_back(diff);
|
||||
|
||||
|
@ -664,15 +675,21 @@ class Clustering {
|
|||
try {
|
||||
os.getObject(idx, vectors[idx - 1]);
|
||||
} catch (...) {
|
||||
cerr << "Cannot get object " << idx << endl;
|
||||
// cerr << "Cannot get object " << idx << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Cannot get object " + std::to_string(idx));
|
||||
}
|
||||
}
|
||||
cerr << "# of data for clustering=" << vectors.size() << endl;
|
||||
// cerr << "# of data for clustering=" << vectors.size() << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("# of data for clustering=" + std::to_string(vectors.size()));
|
||||
double diff = DBL_MAX;
|
||||
clusters.clear();
|
||||
setupInitialClusters(vectors, numberOfClusters, clusters);
|
||||
for (float epsilon = epsilonFrom; epsilon <= epsilonTo; epsilon += epsilonStep) {
|
||||
cerr << "epsilon=" << epsilon << endl;
|
||||
// cerr << "epsilon=" << epsilon << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("epsilon=" + std::to_string(epsilon));
|
||||
diff = kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilon);
|
||||
if (diff == 0.0) {
|
||||
return diff;
|
||||
|
@ -748,7 +765,9 @@ class Clustering {
|
|||
}
|
||||
}
|
||||
if (vectors.size() != count) {
|
||||
std::cerr << "Warning! vectors.size() != count" << std::endl;
|
||||
// std::cerr << "Warning! vectors.size() != count" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning! vectors.size() != count");
|
||||
}
|
||||
|
||||
return d / (double)vectors.size();
|
||||
|
@ -787,7 +806,9 @@ class Clustering {
|
|||
break;
|
||||
}
|
||||
default:
|
||||
std::cerr << "proper initMode is not specified." << std::endl;
|
||||
// std::cerr << "proper initMode is not specified." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("proper initMode is not specified.");
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
@ -805,7 +826,9 @@ class Clustering {
|
|||
return kmeansWithNGT(vectors, numberOfClusters, clusters);
|
||||
break;
|
||||
default:
|
||||
cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
|
||||
// cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("kmeans::fatal error!. invalid clustering type. " + std::to_string(clusteringType));
|
||||
abort();
|
||||
break;
|
||||
}
|
||||
|
@ -817,16 +840,16 @@ class Clustering {
|
|||
size_t clusterSize = std::numeric_limits<size_t>::max();
|
||||
assign(vectors, clusters, clusterSize);
|
||||
|
||||
std::cout << "The number of vectors=" << vectors.size() << std::endl;
|
||||
std::cout << "The number of centroids=" << clusters.size() << std::endl;
|
||||
// std::cout << "The number of vectors=" << vectors.size() << std::endl;
|
||||
// std::cout << "The number of centroids=" << clusters.size() << std::endl;
|
||||
if (centroidIds.size() == 0) {
|
||||
switch (mode) {
|
||||
case 'e':
|
||||
std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
|
||||
// std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
|
||||
break;
|
||||
case '2':
|
||||
default:
|
||||
std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
|
||||
// std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
|
@ -835,8 +858,8 @@ class Clustering {
|
|||
break;
|
||||
case '2':
|
||||
default:
|
||||
std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
|
||||
<< std::endl;
|
||||
// std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
|
||||
// << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "NGT/lib/NGT/Common.h"
|
||||
#include "NGT/lib/NGT/ObjectSpace.h"
|
||||
|
||||
int64_t
|
||||
NGT::SearchContainer::memSize() {
|
||||
auto workres_size = workingResult.size() == 0 ? 0 : workingResult.size() * workingResult.top().memSize();
|
||||
return sizeof(size_t) * 3 + sizeof(float) * 3 + result->memSize() + 1 + workres_size + Container::memSize();
|
||||
}
|
|
@ -120,7 +120,9 @@ namespace NGT {
|
|||
}
|
||||
auto status = insert(std::make_pair(key,value));
|
||||
if (!status.second) {
|
||||
std::cerr << "Args: Duplicated options. [" << opt << "]" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Args: Duplicated options. [" + opt + "]");
|
||||
// std::cerr << "Args: Duplicated options. [" << opt << "]" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -305,7 +307,9 @@ namespace NGT {
|
|||
logFD = open(logFilePath.c_str(), O_CREAT|O_WRONLY|O_APPEND, mode);
|
||||
}
|
||||
if (logFD < 0) {
|
||||
std::cerr << "Logger: Cannot begin logging." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Logger: Cannot begin logging.");
|
||||
// std::cerr << "Logger: Cannot begin logging." << std::endl;
|
||||
logFD = -1;
|
||||
return;
|
||||
}
|
||||
|
@ -323,6 +327,8 @@ namespace NGT {
|
|||
savedFdNo = -1;
|
||||
}
|
||||
|
||||
int64_t memSize() { return sizeof(*this); }
|
||||
|
||||
std::string logFilePath;
|
||||
mode_t mode;
|
||||
int logFD;
|
||||
|
@ -479,7 +485,9 @@ namespace NGT {
|
|||
uint64_t size = vectorSize;
|
||||
size <<= 1;
|
||||
if (size > 0xffff) {
|
||||
std::cerr << "CompactVector is too big. " << size << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("CompactVector is too big. " + std::to_string(size));
|
||||
// std::cerr << "CompactVector is too big. " << size << std::endl;
|
||||
abort();
|
||||
}
|
||||
reserve(size);
|
||||
|
@ -487,6 +495,8 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
|
||||
virtual int64_t memSize() { return vector->memSize() * vectorSize + sizeof(vectorSize) * 2; }
|
||||
|
||||
TYPE *vector;
|
||||
uint16_t vectorSize;
|
||||
uint16_t allocatedSize;
|
||||
|
@ -540,6 +550,8 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
|
||||
virtual int64_t memSize() { return size(); }
|
||||
|
||||
char *vector;
|
||||
};
|
||||
|
||||
|
@ -563,6 +575,7 @@ namespace NGT {
|
|||
inline void reset(size_t i) {
|
||||
getEntry(i) &= ~getBitString(i);
|
||||
}
|
||||
inline int64_t memSize() { return size * sizeof(uint64_t); }
|
||||
std::vector<uint64_t> bitvec;
|
||||
uint64_t size;
|
||||
};
|
||||
|
@ -602,7 +615,9 @@ namespace NGT {
|
|||
char *e = 0;
|
||||
float val = strtof(it->second.c_str(), &e);
|
||||
if (*e != 0) {
|
||||
std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning: Illegal property. " + key + ":" + it->second + " (" + e + ")");
|
||||
// std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
|
||||
return defvalue;
|
||||
}
|
||||
return val;
|
||||
|
@ -620,7 +635,9 @@ namespace NGT {
|
|||
char *e = 0;
|
||||
float val = strtol(it->second.c_str(), &e, 10);
|
||||
if (*e != 0) {
|
||||
std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning: Illegal property. " + key + ":" + it->second + " (" + e + ")");
|
||||
// std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl;
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
@ -668,7 +685,9 @@ namespace NGT {
|
|||
NGT::Common::tokenize(line, tokens, "\t");
|
||||
if (tokens.size() != 2)
|
||||
{
|
||||
std::cerr << "Property file is illegal. " << line << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Property file is illegal. " + line);
|
||||
// std::cerr << "Property file is illegal. " << line << std::endl;
|
||||
continue;
|
||||
}
|
||||
set(tokens[0], tokens[1]);
|
||||
|
@ -681,7 +700,9 @@ namespace NGT {
|
|||
std::vector<std::string> tokens;
|
||||
NGT::Common::tokenize(line, tokens, "\t");
|
||||
if (tokens.size() != 2) {
|
||||
std::cerr << "Property file is illegal. " << line << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Property file is illegal. " + line);
|
||||
// std::cerr << "Property file is illegal. " << line << std::endl;
|
||||
continue;
|
||||
}
|
||||
set(tokens[0], tokens[1]);
|
||||
|
@ -719,7 +740,9 @@ namespace NGT {
|
|||
unsigned int tmp;
|
||||
is >> tmp;
|
||||
if (tmp > 255) {
|
||||
std::cerr << "Error! Invalid. " << tmp << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Error! Invalid. " + std::to_string(tmp));
|
||||
// std::cerr << "Error! Invalid. " << tmp << std::endl;
|
||||
}
|
||||
v = (TYPE)tmp;
|
||||
} else {
|
||||
|
@ -819,7 +842,9 @@ namespace NGT {
|
|||
unsigned int size;
|
||||
is >> size;
|
||||
if (s != size) {
|
||||
std::cerr << "readAsText: something wrong. " << size << ":" << s << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("readAsText: something wrong. " + std::to_string(size) + ":" + std::to_string(s));
|
||||
// std::cerr << "readAsText: something wrong. " << size << ":" << s << std::endl;
|
||||
return;
|
||||
}
|
||||
for (unsigned int i = 0; i < s; i++) {
|
||||
|
@ -1010,7 +1035,9 @@ namespace NGT {
|
|||
size <<= 1;
|
||||
} while (size <= idx);
|
||||
if (size > 0xffffffff) {
|
||||
std::cerr << "Vector is too big. " << size << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Vector is too big. " + std::to_string(size));
|
||||
// std::cerr << "Vector is too big. " << size << std::endl;
|
||||
abort();
|
||||
}
|
||||
reserve(size, allocator);
|
||||
|
@ -1092,7 +1119,9 @@ namespace NGT {
|
|||
Vector<size_t>::iterator rmi
|
||||
= std::lower_bound(removedList->begin(allocator), removedList->end(allocator), id, std::greater<size_t>());
|
||||
if ((rmi != removedList->end(allocator)) && ((*rmi) == id)) {
|
||||
std::cerr << "removedListPush: already existed! continue... ID=" << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("removedListPush: already existed! continue... ID=" + std::to_string(id));
|
||||
// std::cerr << "removedListPush: already existed! continue... ID=" << id << std::endl;
|
||||
return;
|
||||
}
|
||||
removedList->insert(rmi, id, allocator);
|
||||
|
@ -1271,7 +1300,9 @@ namespace NGT {
|
|||
size_t idx;
|
||||
NGT::Serializer::readAsText(is, idx);
|
||||
if (i != idx) {
|
||||
std::cerr << "PersistentRepository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("PersistentRepository: Error. index of a specified import file is invalid. " + std::to_string(idx) + ":" + std::to_string(i));
|
||||
// std::cerr << "PersistentRepository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
|
||||
}
|
||||
char type;
|
||||
NGT::Serializer::readAsText(is, type);
|
||||
|
@ -1599,7 +1630,9 @@ namespace NGT {
|
|||
size_t idx;
|
||||
NGT::Serializer::readAsText(is, idx);
|
||||
if (i != idx) {
|
||||
std::cerr << "Repository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Repository: Error. index of a specified import file is invalid. " + std::to_string(idx) + ":" + std::to_string(i));
|
||||
// std::cerr << "Repository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl;
|
||||
}
|
||||
char type;
|
||||
NGT::Serializer::readAsText(is, type);
|
||||
|
@ -1655,6 +1688,7 @@ namespace NGT {
|
|||
|
||||
#ifdef ADVANCED_USE_REMOVED_LIST
|
||||
size_t count() { return std::vector<TYPE*>::size() == 0 ? 0 : std::vector<TYPE*>::size() - removedList.size() - 1; }
|
||||
virtual int64_t memSize() { return std::vector<TYPE*>::size() == 0 ? 0 : (*this)[1]->memSize() * std::vector<TYPE*>::size() + removedList.size() * sizeof(size_t); }
|
||||
protected:
|
||||
std::priority_queue<size_t, std::vector<size_t>, std::greater<size_t> > removedList;
|
||||
#endif
|
||||
|
@ -1724,6 +1758,7 @@ namespace NGT {
|
|||
is >> o.distance;
|
||||
return is;
|
||||
}
|
||||
int64_t memSize() const { return sizeof(id) + sizeof(distance); }
|
||||
uint32_t id;
|
||||
float distance;
|
||||
};
|
||||
|
@ -1737,6 +1772,7 @@ namespace NGT {
|
|||
public:
|
||||
Container(Object &o, ObjectID i):object(o), id(i) {}
|
||||
Container(Container &c):object(c.object), id(c.id) {}
|
||||
virtual int64_t memSize() { return sizeof(ObjectID); }
|
||||
Object &object;
|
||||
ObjectID id;
|
||||
};
|
||||
|
@ -1791,6 +1827,11 @@ namespace NGT {
|
|||
|
||||
ResultPriorityQueue &getWorkingResult() { return workingResult; }
|
||||
|
||||
virtual int64_t memSize();
|
||||
// virtual int64_t memSize() {
|
||||
// auto workres_size = workingResult.size() == 0 ? 0 : workingResult.size() * workingResult.top().memSize();
|
||||
// return sizeof(size_t) * 3 + sizeof(float) * 3 + result->memSize() + 1 + workres_size + Container::memSize();
|
||||
// }
|
||||
|
||||
size_t size;
|
||||
Distance radius;
|
||||
|
@ -1828,6 +1869,7 @@ namespace NGT {
|
|||
}
|
||||
void *getQuery() { return query; }
|
||||
const std::type_info &getQueryType() { return *queryType; }
|
||||
virtual int64_t memSize() { return std::strlen((char*)getQuery()) + sizeof(getQueryType()) + SearchContainer::memSize(); }
|
||||
private:
|
||||
void deleteQuery() {
|
||||
if (query == 0) {
|
||||
|
|
|
@ -599,7 +599,7 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
|
|||
}
|
||||
|
||||
// for milvus
|
||||
void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset)
|
||||
void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView& bitset)
|
||||
{
|
||||
if (sc.explorationCoefficient == 0.0)
|
||||
{
|
||||
|
@ -710,7 +710,7 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
|
|||
distanceChecked.insert(neighbor.id);
|
||||
|
||||
// judge if id in blacklist
|
||||
if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)neighbor.id - 1)) {
|
||||
if (!bitset.empty() && bitset.test((faiss::ConcurrentBitset::id_type_t)neighbor.id - 1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "NGT/ObjectSpaceRepository.h"
|
||||
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include "faiss/utils/BitsetView.h"
|
||||
|
||||
#include "NGT/HashBasedBooleanSet.h"
|
||||
|
||||
|
@ -187,6 +188,14 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
|
||||
virtual int64_t memSize() {
|
||||
int64_t ret = prevsize->size() * sizeof(unsigned short);
|
||||
for (size_t i = 1; i < this->size(); ++ i) {
|
||||
ret += (*this)[i]->memSize();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
Vector<unsigned short> *prevsize;
|
||||
|
@ -211,6 +220,7 @@ namespace NGT {
|
|||
usedSize++;
|
||||
}
|
||||
size_t size() { return usedSize; }
|
||||
virtual int64_t memSize() { return reservedSize * (sizeof(uint64_t) + (*this)[0].second->memSize()); }
|
||||
size_t reservedSize;
|
||||
size_t usedSize;
|
||||
};
|
||||
|
@ -219,6 +229,13 @@ namespace NGT {
|
|||
public:
|
||||
SearchGraphRepository() {}
|
||||
bool isEmpty(size_t idx) { return (*this)[idx].empty(); }
|
||||
virtual int64_t memSize() {
|
||||
int64_t ret = 0;
|
||||
for (size_t i = 1; i < this->size(); ++ i) {
|
||||
ret += (*this)[i].memSize();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void deserialize(std::ifstream &is, ObjectRepository &objectRepository) {
|
||||
if (!is.is_open()) {
|
||||
|
@ -496,6 +513,8 @@ namespace NGT {
|
|||
return os;
|
||||
}
|
||||
|
||||
int64_t memSize() { return sizeof(*this); }
|
||||
|
||||
int16_t truncationThreshold;
|
||||
int16_t edgeSizeForCreation;
|
||||
int16_t edgeSizeForSearch;
|
||||
|
@ -679,7 +698,7 @@ namespace NGT {
|
|||
|
||||
void search(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
// for milvus
|
||||
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset);
|
||||
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView&bitset);
|
||||
|
||||
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
|
||||
template <typename COMPARATOR, typename CHECK_LIST> void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds);
|
||||
|
@ -933,6 +952,7 @@ namespace NGT {
|
|||
|
||||
public:
|
||||
|
||||
virtual int64_t memSize() { return repository.memSize() + searchRepository.memSize() + property.memSize() + objectSpace->memSize(); }
|
||||
GraphRepository repository;
|
||||
ObjectSpace *objectSpace;
|
||||
|
||||
|
|
|
@ -120,8 +120,10 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
if (nQueries > ids.size()) {
|
||||
std::cerr << "# of Queries is not enough." << std::endl;
|
||||
return DBL_MAX;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("# of Queries is not enough.");
|
||||
// std::cerr << "# of Queries is not enough." << std::endl;
|
||||
return DBL_MAX;
|
||||
}
|
||||
|
||||
NGT::Timer timer;
|
||||
|
@ -233,14 +235,18 @@ namespace NGT {
|
|||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: adjusting outgoing and incoming edges...");
|
||||
// std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
|
||||
}
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
std::vector<NGT::ObjectDistances> graph;
|
||||
try
|
||||
{
|
||||
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Optimizer::execute: Extract the graph data.");
|
||||
// std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
|
||||
// extract only edges from the index to reduce the memory usage.
|
||||
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
|
||||
NeighborhoodGraph::Property & prop = graphIndex.getGraphProperty();
|
||||
|
@ -250,7 +256,9 @@ namespace NGT {
|
|||
}
|
||||
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Optimizer::execute: Graph reconstruction time=" + std::to_string(timer.time) + " (sec) ");
|
||||
// std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
|
||||
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
|
@ -263,7 +271,9 @@ namespace NGT {
|
|||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: redusing shortcut edges...");
|
||||
// std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
|
||||
}
|
||||
try
|
||||
{
|
||||
|
@ -271,7 +281,9 @@ namespace NGT {
|
|||
timer.start();
|
||||
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Optimizer::execute: Path adjustment time=" + std::to_string(timer.time) + " (sec) ");
|
||||
// std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
|
||||
}
|
||||
catch (NGT::Exception & err)
|
||||
{
|
||||
|
@ -286,7 +298,9 @@ namespace NGT {
|
|||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing search parameters...");
|
||||
// std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
|
||||
}
|
||||
NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(outIndex.getIndex());
|
||||
NGT::Optimizer optimizer(outIndex);
|
||||
|
@ -321,7 +335,9 @@ namespace NGT {
|
|||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing prefetch parameters...");
|
||||
// std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
|
||||
}
|
||||
try
|
||||
{
|
||||
|
@ -343,7 +359,9 @@ namespace NGT {
|
|||
{
|
||||
if (!logDisabled)
|
||||
{
|
||||
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: generating the accuracy table...");
|
||||
// std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
|
||||
}
|
||||
try
|
||||
{
|
||||
|
@ -387,14 +405,18 @@ namespace NGT {
|
|||
NGT::GraphIndex graphIndex(outIndexPath, false);
|
||||
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: adjusting outgoing and incoming edges...");
|
||||
// std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
|
||||
}
|
||||
redirector.begin();
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
std::vector<NGT::ObjectDistances> graph;
|
||||
try {
|
||||
std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Optimizer::execute: Extract the graph data.");
|
||||
// std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
|
||||
// extract only edges from the index to reduce the memory usage.
|
||||
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
|
||||
NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
|
||||
|
@ -403,7 +425,9 @@ namespace NGT {
|
|||
}
|
||||
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Optimizer::execute: Graph reconstruction time=" + std::to_string(timer.time) + " (sec) ");
|
||||
// std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
|
||||
graphIndex.saveGraph(outIndexPath);
|
||||
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
|
||||
graphIndex.saveProperty(outIndexPath);
|
||||
|
@ -415,14 +439,18 @@ namespace NGT {
|
|||
|
||||
if (shortcutReduction) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: redusing shortcut edges...");
|
||||
// std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
|
||||
}
|
||||
try {
|
||||
NGT::Timer timer;
|
||||
timer.start();
|
||||
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
|
||||
timer.stop();
|
||||
std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("optimizer::execute: path adjustment time=" + std::to_string(timer.time) + " (sec) ");
|
||||
// std::cerr << "optimizer::execute: path adjustment time=" << timer.time << " (sec) " << std::endl;
|
||||
graphIndex.saveGraph(outIndexPath);
|
||||
} catch (NGT::Exception &err) {
|
||||
redirector.end();
|
||||
|
@ -441,7 +469,9 @@ namespace NGT {
|
|||
|
||||
if (searchParameterOptimization) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing search parameters...");
|
||||
// std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
|
||||
}
|
||||
NGT::Index outIndex(outIndexPath);
|
||||
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
|
@ -472,7 +502,9 @@ namespace NGT {
|
|||
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
if (prefetchParameterOptimization) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing prefetch parameters...");
|
||||
// std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
|
||||
}
|
||||
try {
|
||||
auto prefetch = adjustPrefetchParameters(outIndex);
|
||||
|
@ -491,7 +523,9 @@ namespace NGT {
|
|||
}
|
||||
if (accuracyTableGeneration) {
|
||||
if (!logDisabled) {
|
||||
std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphOptimizer: generating the accuracy table...");
|
||||
// std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
|
||||
}
|
||||
try {
|
||||
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include "defines.h"
|
||||
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
|
@ -34,7 +35,9 @@ class GraphReconstructor {
|
|||
graph.reserve(graphIndex.repository.size());
|
||||
for (size_t id = 1; id < graphIndex.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor::extractGraph: Processed " + std::to_string(id) + " objects.");
|
||||
// std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
|
||||
}
|
||||
try {
|
||||
NGT::GraphNode &node = *graphIndex.getNode(id);
|
||||
|
@ -49,7 +52,9 @@ class GraphReconstructor {
|
|||
graph.push_back(node);
|
||||
#endif
|
||||
if (graph.back().size() != graph.back().capacity()) {
|
||||
std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " + std::to_string(id));
|
||||
// std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
graph.push_back(NGT::ObjectDistances());
|
||||
|
@ -66,7 +71,9 @@ class GraphReconstructor {
|
|||
adjustPaths(NGT::Index &outIndex)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "construct index is not implemented." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("construct index is not implemented.");
|
||||
// std::cerr << "construct index is not implemented." << std::endl;
|
||||
exit(1);
|
||||
#else
|
||||
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
|
||||
|
@ -93,8 +100,12 @@ class GraphReconstructor {
|
|||
}
|
||||
edge = true;
|
||||
if (rank >= 1 && node[rank - 1].distance > node[rank].distance) {
|
||||
std::cerr << "distance order is wrong!" << std::endl;
|
||||
std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
|
||||
// std::cerr << "distance order is wrong!" << std::endl;
|
||||
// std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("distance order is wrong!");
|
||||
(*NGT_LOG_DEBUG_)(std::to_string(id) + ":" + std::to_string(rank) + ":" + std::to_string(node[rank - 1].id) + ":" + std::to_string(node[rank].id));
|
||||
}
|
||||
}
|
||||
NGT::GraphNode &tn = *outGraph.getNode(id);
|
||||
volatile bool found = false;
|
||||
|
@ -141,7 +152,9 @@ class GraphReconstructor {
|
|||
removeCount++;
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
it++;
|
||||
continue;
|
||||
}
|
||||
|
@ -210,7 +223,9 @@ class GraphReconstructor {
|
|||
node.clear();
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
tmpGraph.push_back(NGT::GraphNode(outGraph.repository.allocator));
|
||||
#else
|
||||
|
@ -224,7 +239,9 @@ class GraphReconstructor {
|
|||
NGTThrowException(msg);
|
||||
}
|
||||
timer.stop();
|
||||
std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
|
||||
// std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor::adjustPaths: graph preparing time=" + std::to_string(timer.time));
|
||||
timer.reset();
|
||||
timer.start();
|
||||
|
||||
|
@ -287,12 +304,16 @@ class GraphReconstructor {
|
|||
removeCandidates[id - 1].push_back(candidates[i].second);
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
timer.stop();
|
||||
std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
|
||||
// std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor::adjustPaths: extracting removed edge candidates time=" + std::to_string(timer.time));
|
||||
timer.reset();
|
||||
timer.start();
|
||||
|
||||
|
@ -311,7 +332,9 @@ class GraphReconstructor {
|
|||
NGT::GraphNode &srcNode = tmpGraph[idx];
|
||||
if (rank >= srcNode.size()) {
|
||||
if (!removeCandidates[idx].empty()) {
|
||||
std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
|
||||
// std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Something wrong! ID=" + std::to_string(id) + " # of remaining candidates=" + std::to_string(removeCandidates[idx].size()));
|
||||
abort();
|
||||
}
|
||||
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
|
@ -365,7 +388,9 @@ class GraphReconstructor {
|
|||
insert(outSrcNode, srcNode[rank].id, srcNode[rank].distance);
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
it++;
|
||||
continue;
|
||||
}
|
||||
|
@ -389,10 +414,14 @@ class GraphReconstructor {
|
|||
void convertToANNG(std::vector<NGT::ObjectDistances> &graph)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("convertToANNG is not implemented for shared memory.");
|
||||
// std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
|
||||
return;
|
||||
#else
|
||||
std::cerr << "convertToANNG begin" << std::endl;
|
||||
// std::cerr << "convertToANNG begin" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("convertToANNG begin");
|
||||
for (size_t idx = 0; idx < graph.size(); idx++) {
|
||||
NGT::GraphNode &node = graph[idx];
|
||||
for (auto ni = node.begin(); ni != node.end(); ++ni) {
|
||||
|
@ -417,7 +446,9 @@ class GraphReconstructor {
|
|||
NGT::GraphNode tmp = node;
|
||||
node.swap(tmp);
|
||||
}
|
||||
std::cerr << "convertToANNG end" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("convertToANNG end");
|
||||
// std::cerr << "convertToANNG end" << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -425,7 +456,9 @@ class GraphReconstructor {
|
|||
void reconstructGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph, size_t originalEdgeSize, size_t reverseEdgeSize)
|
||||
{
|
||||
if (reverseEdgeSize > 10000) {
|
||||
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
|
||||
// std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("something wrong. Edge size=" + std::to_string(reverseEdgeSize));
|
||||
exit(1);
|
||||
}
|
||||
|
||||
|
@ -445,7 +478,9 @@ class GraphReconstructor {
|
|||
} else {
|
||||
NGT::ObjectDistances n = graph[id - 1];
|
||||
if (n.size() < originalEdgeSize) {
|
||||
std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. The edges are too few. " + std::to_string(n.size()) + ":" + std::to_string(originalEdgeSize) + " for " + std::to_string(id));
|
||||
continue;
|
||||
}
|
||||
n.resize(originalEdgeSize);
|
||||
|
@ -456,7 +491,9 @@ class GraphReconstructor {
|
|||
#endif
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -485,13 +522,17 @@ class GraphReconstructor {
|
|||
} catch(...) {}
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
reverseEdgeTimer.stop();
|
||||
if (insufficientNodeCount != 0) {
|
||||
std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
|
||||
// std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("# of the nodes edges of which are in short = " + std::to_string(insufficientNodeCount));
|
||||
}
|
||||
|
||||
normalizeEdgeTimer.start();
|
||||
|
@ -499,7 +540,9 @@ class GraphReconstructor {
|
|||
try {
|
||||
NGT::GraphNode &n = *outGraph.getNode(id);
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "Processed " << id << " nodes" << std::endl;
|
||||
// std::cerr << "Processed " << id << " nodes" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id) + " nodes");
|
||||
}
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::sort(n.begin(outGraph.repository.allocator), n.end(outGraph.repository.allocator));
|
||||
|
@ -528,13 +571,18 @@ class GraphReconstructor {
|
|||
n.swap(tmp);
|
||||
#endif
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
normalizeEdgeTimer.stop();
|
||||
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
|
||||
<< ":" << normalizeEdgeTimer.time << std::endl;
|
||||
// std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
|
||||
// << ":" << normalizeEdgeTimer.time << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Reconstruction time=" + std::to_string(originalEdgeTimer.time) + ":" + std::to_string(reverseEdgeTimer.time)
|
||||
+ ":" + std::to_string(normalizeEdgeTimer.time));
|
||||
|
||||
NGT::Property prop;
|
||||
outGraph.getProperty().get(prop);
|
||||
|
@ -550,20 +598,26 @@ class GraphReconstructor {
|
|||
char mode = 'a')
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("reconstructGraphWithConstraint is not implemented.");
|
||||
// std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
|
||||
abort();
|
||||
#else
|
||||
|
||||
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
|
||||
|
||||
if (reverseEdgeSize > 10000) {
|
||||
std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
|
||||
// std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("something wrong. Edge size=" + std::to_string(reverseEdgeSize));
|
||||
exit(1);
|
||||
}
|
||||
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "Processed " << id << std::endl;
|
||||
// std::cerr << "Processed " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
|
||||
}
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
|
@ -574,7 +628,9 @@ class GraphReconstructor {
|
|||
NGT::GraphNode empty;
|
||||
node.swap(empty);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -585,13 +641,17 @@ class GraphReconstructor {
|
|||
try {
|
||||
NGT::GraphNode &node = graph[id - 1];
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "Processed (summing up) " << id << std::endl;
|
||||
// std::cerr << "Processed (summing up) " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed (summing up) " + std::to_string(id));
|
||||
}
|
||||
for (size_t rank = 0; rank < node.size(); rank++) {
|
||||
reverse[node[rank].id].push_back(ObjectDistance(id, node[rank].distance));
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -628,7 +688,9 @@ class GraphReconstructor {
|
|||
}
|
||||
}
|
||||
reverseEdgeTimer.stop();
|
||||
std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
|
||||
// std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("The number of nodes with zero outdegree by reverse edges=" + std::to_string(zeroCount));
|
||||
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
|
||||
|
||||
normalizeEdgeTimer.start();
|
||||
|
@ -636,7 +698,9 @@ class GraphReconstructor {
|
|||
try {
|
||||
NGT::GraphNode &n = *outGraph.getNode(id);
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "Processed " << id << std::endl;
|
||||
// std::cerr << "Processed " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
|
||||
}
|
||||
std::sort(n.begin(), n.end());
|
||||
NGT::ObjectID prev = 0;
|
||||
|
@ -651,7 +715,9 @@ class GraphReconstructor {
|
|||
NGT::GraphNode tmp = n;
|
||||
n.swap(tmp);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -661,7 +727,9 @@ class GraphReconstructor {
|
|||
originalEdgeTimer.start();
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "Processed " << id << std::endl;
|
||||
// std::cerr << "Processed " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
|
||||
}
|
||||
NGT::GraphNode &node = graph[id - 1];
|
||||
try {
|
||||
|
@ -683,15 +751,20 @@ class GraphReconstructor {
|
|||
outGraph.addEdge(id, nodeID, distance, false);
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
originalEdgeTimer.stop();
|
||||
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
|
||||
|
||||
std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
|
||||
<< ":" << normalizeEdgeTimer.time << std::endl;
|
||||
// std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
|
||||
// << ":" << normalizeEdgeTimer.time << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Reconstruction time=" + std::to_string(originalEdgeTimer.time) + ":" + std::to_string(reverseEdgeTimer.time)
|
||||
+ ":" + std::to_string(normalizeEdgeTimer.time));
|
||||
|
||||
#endif
|
||||
}
|
||||
|
@ -703,7 +776,9 @@ class GraphReconstructor {
|
|||
void reconstructANNGFromANNG(std::vector<NGT::ObjectDistances> &graph, NGT::Index &index, size_t edgeSize)
|
||||
{
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("reconstructANNGFromANNG is not implemented.");
|
||||
// std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
|
||||
abort();
|
||||
#else
|
||||
|
||||
|
@ -712,7 +787,9 @@ class GraphReconstructor {
|
|||
// remove all edges in the index.
|
||||
for (size_t id = 1; id < outGraph.repository.size(); id++) {
|
||||
if (id % 1000000 == 0) {
|
||||
std::cerr << "Processed " << id << " nodes." << std::endl;
|
||||
// std::cerr << "Processed " << id << " nodes." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id) + " nodes.");
|
||||
}
|
||||
try {
|
||||
NGT::GraphNode &node = *outGraph.getNode(id);
|
||||
|
@ -810,7 +887,9 @@ class GraphReconstructor {
|
|||
for (size_t idx = 0; idx < batchSize; idx++) {
|
||||
size_t id = bid + idx;
|
||||
if (id % 100000 == 0) {
|
||||
std::cerr << "# of processed objects=" << id << std::endl;
|
||||
// std::cerr << "# of processed objects=" << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("# of processed objects=" + std::to_string(id));
|
||||
}
|
||||
if (objectRepository.isEmpty(id)) {
|
||||
continue;
|
||||
|
@ -876,7 +955,9 @@ class GraphReconstructor {
|
|||
for (size_t idx = 0; idx < batchSize; idx++) {
|
||||
size_t id = bid + idx;
|
||||
if (id % 10000 == 0) {
|
||||
std::cerr << "# of processed objects=" << id << std::endl;
|
||||
// std::cerr << "# of processed objects=" << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("# of processed objects=" + std::to_string(id));
|
||||
}
|
||||
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
|
||||
if ((*i).id != id) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "defines.h"
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
#include <stdint.h>
|
||||
|
@ -53,7 +54,9 @@ class HashBasedBooleanSet{
|
|||
_mask = _tableSize - 1;
|
||||
const uint32_t checkValue = _hash1(tableSize);
|
||||
if(checkValue != 0){
|
||||
std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("[WARN] table size is not 2^N : " + std::to_string(tableSize));
|
||||
// std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
|
||||
}
|
||||
|
||||
_table = new uint32_t[tableSize];
|
||||
|
|
|
@ -266,7 +266,9 @@ void NGT::Index::loadRawDataAndCreateIndex(NGT::Index * index_, const float * ro
|
|||
timer.start();
|
||||
index_->createIndex(threadSize);
|
||||
timer.stop();
|
||||
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -280,17 +282,23 @@ NGT::Index::loadAndCreateIndex(Index &index, const string &database, const strin
|
|||
return;
|
||||
}
|
||||
timer.stop();
|
||||
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (index.getObjectRepositorySize() == 0) {
|
||||
NGTThrowException("Index::create: Data file is empty.");
|
||||
}
|
||||
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
|
||||
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
timer.reset();
|
||||
timer.start();
|
||||
index.createIndex(threadSize);
|
||||
timer.stop();
|
||||
index.saveIndex(database);
|
||||
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
}
|
||||
|
||||
// For milvus
|
||||
|
@ -307,13 +315,21 @@ void NGT::Index::append(NGT::Index * index_, const float * data, size_t dataSize
|
|||
NGTThrowException("Index::append: No data.");
|
||||
}
|
||||
timer.stop();
|
||||
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
cerr << "# of objects=" << index_->getObjectRepositorySize() - 1 << endl;
|
||||
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "# of objects=" << index_->getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)(
|
||||
"Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0)
|
||||
+ " (msec)");
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index_->getObjectRepositorySize() - 1));
|
||||
}
|
||||
timer.reset();
|
||||
timer.start();
|
||||
index_->createIndex(threadSize);
|
||||
timer.stop();
|
||||
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -326,14 +342,20 @@ NGT::Index::append(const string &database, const string &dataFile, size_t thread
|
|||
index.append(dataFile, dataSize);
|
||||
}
|
||||
timer.stop();
|
||||
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
|
||||
}
|
||||
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
timer.reset();
|
||||
timer.start();
|
||||
index.createIndex(threadSize);
|
||||
timer.stop();
|
||||
index.saveIndex(database);
|
||||
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -348,14 +370,20 @@ NGT::Index::append(const string &database, const float *data, size_t dataSize, s
|
|||
NGTThrowException("Index::append: No data.");
|
||||
}
|
||||
timer.stop();
|
||||
cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
// cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Data loading time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
|
||||
}
|
||||
timer.reset();
|
||||
timer.start();
|
||||
index.createIndex(threadSize);
|
||||
timer.stop();
|
||||
index.saveIndex(database);
|
||||
cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Index creation time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -368,13 +396,19 @@ NGT::Index::remove(const string &database, vector<ObjectID> &objects, bool force
|
|||
try {
|
||||
index.remove(*i, force);
|
||||
} catch (Exception &err) {
|
||||
cerr << "Warning: Cannot remove the node. ID=" << *i << " : " << err.what() << endl;
|
||||
// cerr << "Warning: Cannot remove the node. ID=" << *i << " : " << err.what() << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning: Cannot remove the node. ID=" + std::to_string(*i) + " : " + err.what());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
timer.stop();
|
||||
cerr << "Data removing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
// cerr << "Data removing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Data removing time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(index.getObjectRepositorySize() - 1));
|
||||
}
|
||||
index.saveIndex(database);
|
||||
return;
|
||||
}
|
||||
|
@ -411,8 +445,12 @@ NGT::Index::importIndex(const string &database, const string &file) {
|
|||
}
|
||||
idx->importIndex(file);
|
||||
timer.stop();
|
||||
cerr << "Data importing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
cerr << "# of objects=" << idx->getObjectRepositorySize() - 1 << endl;
|
||||
// cerr << "Data importing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "# of objects=" << idx->getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Data importing time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(idx->getObjectRepositorySize() - 1));
|
||||
}
|
||||
idx->saveIndex(database);
|
||||
delete idx;
|
||||
}
|
||||
|
@ -424,8 +462,12 @@ NGT::Index::exportIndex(const string &database, const string &file) {
|
|||
timer.start();
|
||||
idx.exportIndex(file);
|
||||
timer.stop();
|
||||
cerr << "Data exporting time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
cerr << "# of objects=" << idx.getObjectRepositorySize() - 1 << endl;
|
||||
// cerr << "Data exporting time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl;
|
||||
// cerr << "# of objects=" << idx.getObjectRepositorySize() - 1 << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Data exporting time=" + std::to_string(timer.time) + " (sec) " + std::to_string(timer.time * 1000.0) + " (msec)");
|
||||
(*NGT_LOG_DEBUG_)("# of objects=" + std::to_string(idx.getObjectRepositorySize() - 1));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float>
|
||||
|
@ -535,7 +577,9 @@ CreateIndexThread::run() {
|
|||
} catch(NGT::ThreadTerminationException &err) {
|
||||
break;
|
||||
} catch(NGT::Exception &err) {
|
||||
cerr << "CreateIndex::search:Error! popFront " << err.what() << endl;
|
||||
// cerr << "CreateIndex::search:Error! popFront " << err.what() << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("CreateIndex::search:Error! popFront " + std::string(err.what()));
|
||||
break;
|
||||
}
|
||||
ObjectDistances *rs = new ObjectDistances;
|
||||
|
@ -547,7 +591,9 @@ CreateIndexThread::run() {
|
|||
graphIndex.searchForNNGInsertion(obj, *rs);
|
||||
}
|
||||
} catch(NGT::Exception &err) {
|
||||
cerr << "CreateIndex::search:Fatal error! ID=" << job.id << " " << err.what() << endl;
|
||||
// cerr << "CreateIndex::search:Fatal error! ID=" << job.id << " " << err.what() << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("CreateIndex::search:Fatal error! ID=" + std::to_string(job.id) + " " + err.what());
|
||||
abort();
|
||||
}
|
||||
job.results = rs;
|
||||
|
@ -684,7 +730,9 @@ GraphAndTreeIndex::createTreeIndex()
|
|||
ObjectRepository &fr = GraphIndex::objectSpace->getRepository();
|
||||
for (size_t id = 0; id < fr.size(); id++){
|
||||
if (id % 100000 == 0) {
|
||||
cerr << " Processed id=" << id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(" Processed id=" + std::to_string(id));
|
||||
// cerr << " Processed id=" << id << endl;
|
||||
}
|
||||
if (fr.isEmpty(id)) {
|
||||
continue;
|
||||
|
@ -698,8 +746,12 @@ GraphAndTreeIndex::createTreeIndex()
|
|||
try {
|
||||
DVPTree::insert(tiobj);
|
||||
} catch (Exception &err) {
|
||||
cerr << "GraphAndTreeIndex::createTreeIndex: Warning. ID=" << id << ":";
|
||||
cerr << err.what() << " continue.." << endl;
|
||||
// cerr << "GraphAndTreeIndex::createTreeIndex: Warning. ID=" << id << ":";
|
||||
// cerr << err.what() << " continue.." << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("GraphAndTreeIndex::createTreeIndex: Warning. ID=" + std::to_string(id) + ":");
|
||||
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue..");
|
||||
}
|
||||
}
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
GraphIndex::objectSpace->deleteObject(f);
|
||||
|
@ -840,10 +892,16 @@ insertMultipleSearchResults(GraphIndex &neighborhoodGraph,
|
|||
}
|
||||
if (static_cast<int>(gr.id) > neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation &&
|
||||
static_cast<int>(gr.results->size()) < neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation) {
|
||||
cerr << "createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set." << endl;
|
||||
cerr << " The node id=" << gr.id << endl;
|
||||
cerr << " The number of edges for the node=" << gr.results->size() << endl;
|
||||
cerr << " The pruned parameter (edgeSizeForSearch [-S])=" << neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch << endl;
|
||||
// cerr << "createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set." << endl;
|
||||
// cerr << " The node id=" << gr.id << endl;
|
||||
// cerr << " The number of edges for the node=" << gr.results->size() << endl;
|
||||
// cerr << " The pruned parameter (edgeSizeForSearch [-S])=" << neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set.");
|
||||
(*NGT_LOG_DEBUG_)(" The node id=" + std::to_string(gr.id));
|
||||
(*NGT_LOG_DEBUG_)(" The number of edges for the node=" + std::to_string(gr.results->size()));
|
||||
(*NGT_LOG_DEBUG_)(" The pruned parameter (edgeSizeForSearch [-S])=" + std::to_string(neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch));
|
||||
}
|
||||
}
|
||||
neighborhoodGraph.insertNode(gr.id, *gr.results);
|
||||
}
|
||||
|
@ -886,7 +944,9 @@ GraphIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
|
|||
// wait for the completion of the search
|
||||
threads.waitForFinish();
|
||||
if (output.size() != cnt) {
|
||||
cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
|
||||
// cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong.");
|
||||
cnt = output.size();
|
||||
}
|
||||
// insertion
|
||||
|
@ -903,7 +963,9 @@ GraphIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
|
|||
count += cnt;
|
||||
if (timerCount <= count) {
|
||||
timer.stop();
|
||||
cerr << "Processed " << timerCount << " time= " << timer << endl;
|
||||
// cerr << "Processed " << timerCount << " time= " << timer << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(timerCount )+ " time= " + std::to_string(timer.time));
|
||||
timerCount += timerInterval;
|
||||
timer.start();
|
||||
}
|
||||
|
@ -956,13 +1018,17 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
try {
|
||||
node = outGraph.getNode(id);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "ngt info: Error. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "ngt info: Error. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ngt info: Error. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
valid = false;
|
||||
continue;
|
||||
}
|
||||
numberOfNodes++;
|
||||
if (numberOfNodes % 1000000 == 0) {
|
||||
std::cerr << "Processed " << numberOfNodes << std::endl;
|
||||
// std::cerr << "Processed " << numberOfNodes << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(numberOfNodes));
|
||||
}
|
||||
size_t esize = node->size() > edgeSize ? edgeSize : node->size();
|
||||
if (esize == 0) {
|
||||
|
@ -988,7 +1054,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
NGT::ObjectDistance &n = (*node)[i];
|
||||
#endif
|
||||
if (n.id == 0) {
|
||||
std::cerr << "ngt info: Warning. id is zero." << std::endl;
|
||||
// std::cerr << "ngt info: Warning. id is zero." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ngt info: Warning. id is zero.");
|
||||
valid = false;
|
||||
}
|
||||
indegreeCount[n.id]++;
|
||||
|
@ -1029,10 +1097,15 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
} catch(NGT::Exception &err) {
|
||||
count++;
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no object. "
|
||||
<< node.at(i, graph.allocator).id << std::endl;
|
||||
// std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no object. "
|
||||
// << node.at(i, graph.allocator).id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node.at(i, graph.allocator).id) + " no object. "
|
||||
+ std::to_string(node.at(i, graph.allocator).id));
|
||||
#else
|
||||
std::cerr << "Directed edge! " << id << "->" << node[i].id << " no object. " << node[i].id << std::endl;
|
||||
// std::cerr << "Directed edge! " << id << "->" << node[i].id << " no object. " << node[i].id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node[i].id) + " no object. " + std::to_string(node[i].id));
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
|
@ -1050,16 +1123,23 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
}
|
||||
if (!found) {
|
||||
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
|
||||
std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no edge. "
|
||||
<< node.at(i, graph.allocator).id << "->" << id << std::endl;
|
||||
// std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no edge. "
|
||||
// << node.at(i, graph.allocator).id << "->" << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node.at(i, graph.allocator).id) + " no edge. "
|
||||
+ std::to_string(node.at(i, graph.allocator).id) + "->" + std::to_string(id));
|
||||
#else
|
||||
std::cerr << "Directed edge! " << id << "->" << node[i].id << " no edge. " << node[i].id << "->" << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Directed edge! " + std::to_string(id) + "->" + std::to_string(node[i].id) + " no edge. " + std::to_string(node[i].id) + "->" + std::to_string(id));
|
||||
// std::cerr << "Directed edge! " << id << "->" << node[i].id << " no edge. " << node[i].id << "->" << id << std::endl;
|
||||
#endif
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cerr << "The number of directed edges=" << count << std::endl;
|
||||
// std::cerr << "The number of directed edges=" << count << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("The number of directed edges=" + std::to_string(count));
|
||||
}
|
||||
|
||||
// calculate outdegree distance 10
|
||||
|
@ -1075,7 +1155,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
try {
|
||||
n = outGraph.getNode(id);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ngt info: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode &node = *n;
|
||||
|
@ -1131,7 +1213,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
try {
|
||||
node = outGraph.getNode(id);
|
||||
} catch(NGT::Exception &err) {
|
||||
std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
// std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ngt info: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
|
||||
continue;
|
||||
}
|
||||
size_t esize = node->size();
|
||||
|
@ -1148,7 +1232,9 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
}
|
||||
if (indegreeCount[id] == 0) {
|
||||
numberOfNodesWithoutIndegree++;
|
||||
std::cerr << "Error! The node without incoming edges. " << id << std::endl;
|
||||
// std::cerr << "Error! The node without incoming edges. " << id << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Error! The node without incoming edges. " + std::to_string(id));
|
||||
valid = false;
|
||||
}
|
||||
if (indegreeCount[id] > static_cast<int>(maxNumberOfIndegree)) {
|
||||
|
@ -1229,6 +1315,47 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
c5 /= (double)numberOfNodes * 0.05;
|
||||
c1 /= (double)numberOfNodes * 0.01;
|
||||
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("The size of the object repository (not the number of the objects):\t" + std::to_string(repo.size() - 1));
|
||||
(*NGT_LOG_DEBUG_)("The number of the removed objects:\t" + std::to_string(removedObjectCount) + "/" + std::to_string(repo.size() - 1));
|
||||
(*NGT_LOG_DEBUG_)("The number of the nodes:\t" + std::to_string(numberOfNodes));
|
||||
(*NGT_LOG_DEBUG_)("The number of the edges:\t" + std::to_string(numberOfOutdegree));
|
||||
(*NGT_LOG_DEBUG_)("The mean of the edge lengths:\t" + std::to_string(distance / (double)numberOfOutdegree));
|
||||
(*NGT_LOG_DEBUG_)("The mean of the number of the edges per node:\t" + std::to_string((double)numberOfOutdegree / (double)numberOfNodes));
|
||||
(*NGT_LOG_DEBUG_)("The number of the nodes without edges:\t" + std::to_string(numberOfNodesWithoutEdges));
|
||||
(*NGT_LOG_DEBUG_)("The maximum of the outdegrees:\t" + std::to_string(maxNumberOfOutdegree));
|
||||
if (minNumberOfOutdegree == SIZE_MAX) {
|
||||
(*NGT_LOG_DEBUG_)("The minimum of the outdegrees:\t-NA-");
|
||||
} else {
|
||||
(*NGT_LOG_DEBUG_)("The minimum of the outdegrees:\t" + std::to_string(minNumberOfOutdegree));
|
||||
}
|
||||
(*NGT_LOG_DEBUG_)("The number of the nodes where indegree is 0:\t" + std::to_string(numberOfNodesWithoutIndegree));
|
||||
(*NGT_LOG_DEBUG_)("The maximum of the indegrees:\t" + std::to_string(maxNumberOfIndegree));
|
||||
if (minNumberOfIndegree == INT64_MAX) {
|
||||
(*NGT_LOG_DEBUG_)("The minimum of the indegrees:\t-NA-");
|
||||
} else {
|
||||
(*NGT_LOG_DEBUG_)("The minimum of the indegrees:\t" + std::to_string(minNumberOfIndegree));
|
||||
}
|
||||
(*NGT_LOG_DEBUG_)("#-nodes,#-edges,#-no-indegree,avg-edges,avg-dist,max-out,min-out,v-out,max-in,min-in,v-in,med-out,"
|
||||
"med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:"
|
||||
+ std::to_string(numberOfNodes) + ":" + std::to_string(numberOfOutdegree) + ":" + std::to_string(numberOfNodesWithoutIndegree) + ":"
|
||||
+ std::to_string((double)numberOfOutdegree / (double)numberOfNodes) + ":"
|
||||
+ std::to_string(distance / (double)numberOfOutdegree) + ":"
|
||||
+ std::to_string(maxNumberOfOutdegree) + ":" + std::to_string(minNumberOfOutdegree) + ":" + std::to_string(sumOfSquareOfOutdegree / (double)numberOfOutdegree) + ":"
|
||||
+ std::to_string(maxNumberOfIndegree) + ":" + std::to_string(minNumberOfIndegree) + ":" + std::to_string(sumOfSquareOfIndegree / (double)numberOfOutdegree) + ":"
|
||||
+ std::to_string(medianOutdegree) + ":" + std::to_string(medianIndegree) + ":" + std::to_string(modeOutdegree) + ":" + std::to_string(modeIndegree)
|
||||
+ ":" + std::to_string(c95) + ":" + std::to_string(c5) + ":" + std::to_string(c99) + ":" + std::to_string(c1) + ":" + std::to_string(distance10) + ":" + std::to_string(d10SkipCount) + ":"
|
||||
+ std::to_string(indegreeDistance10) + ":" + std::to_string(ind10SkipCount));
|
||||
if (mode == 'h') {
|
||||
(*NGT_LOG_DEBUG_)("#\tout\tin");
|
||||
for (size_t i = 0; i < outdegreeHistogram.size() || i < indegreeHistogram.size(); i++) {
|
||||
size_t out = outdegreeHistogram.size() <= i ? 0 : outdegreeHistogram[i];
|
||||
size_t in = indegreeHistogram.size() <= i ? 0 : indegreeHistogram[i];
|
||||
(*NGT_LOG_DEBUG_)(std::to_string(i) + "\t" + std::to_string(out) + "\t" + std::to_string(in));
|
||||
}
|
||||
}
|
||||
}
|
||||
/*
|
||||
std::cerr << "The size of the object repository (not the number of the objects):\t" << repo.size() - 1 << std::endl;
|
||||
std::cerr << "The number of the removed objects:\t" << removedObjectCount << "/" << repo.size() - 1 << std::endl;
|
||||
std::cerr << "The number of the nodes:\t" << numberOfNodes << std::endl;
|
||||
|
@ -1250,13 +1377,13 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
std::cerr << "The minimum of the indegrees:\t" << minNumberOfIndegree << std::endl;
|
||||
}
|
||||
std::cerr << "#-nodes,#-edges,#-no-indegree,avg-edges,avg-dist,max-out,min-out,v-out,max-in,min-in,v-in,med-out,"
|
||||
"med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:"
|
||||
<< numberOfNodes << ":" << numberOfOutdegree << ":" << numberOfNodesWithoutIndegree << ":"
|
||||
"med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:"
|
||||
<< numberOfNodes << ":" << numberOfOutdegree << ":" << numberOfNodesWithoutIndegree << ":"
|
||||
<< std::setprecision(10) << (double)numberOfOutdegree / (double)numberOfNodes << ":"
|
||||
<< distance / (double)numberOfOutdegree << ":"
|
||||
<< maxNumberOfOutdegree << ":" << minNumberOfOutdegree << ":" << sumOfSquareOfOutdegree / (double)numberOfOutdegree<< ":"
|
||||
<< maxNumberOfIndegree << ":" << minNumberOfIndegree << ":" << sumOfSquareOfIndegree / (double)numberOfOutdegree << ":"
|
||||
<< medianOutdegree << ":" << medianIndegree << ":" << modeOutdegree << ":" << modeIndegree
|
||||
<< medianOutdegree << ":" << medianIndegree << ":" << modeOutdegree << ":" << modeIndegree
|
||||
<< ":" << c95 << ":" << c5 << ":" << c99 << ":" << c1 << ":" << distance10 << ":" << d10SkipCount << ":"
|
||||
<< indegreeDistance10 << ":" << ind10SkipCount << std::endl;
|
||||
if (mode == 'h') {
|
||||
|
@ -1267,6 +1394,7 @@ NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, siz
|
|||
std::cerr << i << "\t" << out << "\t" << in << std::endl;
|
||||
}
|
||||
}
|
||||
*/
|
||||
return valid;
|
||||
}
|
||||
|
||||
|
@ -1309,7 +1437,9 @@ GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
|
|||
threads.waitForFinish();
|
||||
|
||||
if (output.size() != cnt) {
|
||||
cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
|
||||
// cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong.");
|
||||
cnt = output.size();
|
||||
}
|
||||
|
||||
|
@ -1331,12 +1461,16 @@ GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
|
|||
GraphIndex::objectSpace->deleteObject(f);
|
||||
#endif
|
||||
} catch (Exception &err) {
|
||||
cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":";
|
||||
// cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":";
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("NGT::createIndex: Fatal error. ID=" + std::to_string(job.id) + ":");
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
GraphIndex::objectSpace->deleteObject(f);
|
||||
#endif
|
||||
if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) {
|
||||
cerr << err.what() << " continue.." << endl;
|
||||
// cerr << err.what() << " continue.." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue..");
|
||||
} else {
|
||||
throw err;
|
||||
}
|
||||
|
@ -1354,10 +1488,14 @@ GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository)
|
|||
|
||||
count += cnt;
|
||||
if (timerCount <= count) {
|
||||
timer.stop();
|
||||
cerr << "Processed " << timerCount << " objects. time= " << timer << endl;
|
||||
timerCount += timerInterval;
|
||||
timer.start();
|
||||
timer.stop();
|
||||
// cerr << "Processed " << timerCount << " objects. time= " << timer << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)(
|
||||
"Processed " + std::to_string(timerCount) + " objects. time= " + std::to_string(timer.time));
|
||||
}
|
||||
timerCount += timerInterval;
|
||||
timer.start();
|
||||
}
|
||||
buildTimeController.adjustEdgeSize(count);
|
||||
if (pathAdjustCount > 0 && pathAdjustCount <= count) {
|
||||
|
@ -1384,7 +1522,9 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
|
|||
size_t count = 0;
|
||||
timer.start();
|
||||
if (threadPoolSize <= 0) {
|
||||
cerr << "Not implemented!!" << endl;
|
||||
// cerr << "Not implemented!!" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Not implemented!!");
|
||||
abort();
|
||||
} else {
|
||||
CreateIndexThreadPool threads(threadPoolSize);
|
||||
|
@ -1422,7 +1562,9 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
|
|||
}
|
||||
threads.waitForFinish();
|
||||
if (output.size() != cnt) {
|
||||
cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
|
||||
// cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong.");
|
||||
cnt = output.size();
|
||||
}
|
||||
{
|
||||
|
@ -1488,19 +1630,25 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
|
|||
GraphIndex::objectSpace->deleteObject(f);
|
||||
#endif
|
||||
} catch (Exception &err) {
|
||||
cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":" << err.what();
|
||||
// cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":" << err.what();
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("NGT::createIndex: Fatal error. ID=" + std::to_string(job.id) + ":" + err.what());
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
GraphIndex::objectSpace->deleteObject(f);
|
||||
#endif
|
||||
if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) {
|
||||
cerr << err.what() << " continue.." << endl;
|
||||
// cerr << err.what() << " continue.." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue..");
|
||||
} else {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (((*job.results).size() == 0) && (job.id != 1)) {
|
||||
cerr << "insert warning!! No searched nodes!. If the first time, no problem. " << job.id << endl;
|
||||
// cerr << "insert warning!! No searched nodes!. If the first time, no problem. " << job.id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("insert warning!! No searched nodes!. If the first time, no problem. " + std::to_string(job.id));
|
||||
}
|
||||
GraphIndex::insertNode(job.id, *job.results);
|
||||
}
|
||||
|
@ -1513,13 +1661,17 @@ GraphAndTreeIndex::createIndex(const vector<pair<NGT::Object*, size_t> > &object
|
|||
count += cnt;
|
||||
if (timerCount <= count) {
|
||||
timer.stop();
|
||||
cerr << "Processed " << timerCount << " time= " << timer << endl;
|
||||
// cerr << "Processed " << timerCount << " time= " << timer << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(timerCount) + " time= " + std::to_string(timer.time));
|
||||
timerCount += timerInterval;
|
||||
timer.start();
|
||||
}
|
||||
}
|
||||
} catch(Exception &err) {
|
||||
cerr << "thread terminate!" << endl;
|
||||
// cerr << "thread terminate!" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("thread terminate!");
|
||||
threads.terminate();
|
||||
throw err;
|
||||
}
|
||||
|
@ -1560,18 +1712,26 @@ bool
|
|||
GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
||||
bool valid = GraphIndex::verify(status, info);
|
||||
if (!valid) {
|
||||
cerr << "The graph or object is invalid!" << endl;
|
||||
// cerr << "The graph or object is invalid!" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("The graph or object is invalid!");
|
||||
}
|
||||
bool treeValid = DVPTree::verify(GraphIndex::objectSpace->getRepository().size(), status);
|
||||
if (!treeValid) {
|
||||
cerr << "The tree is invalid" << endl;
|
||||
// cerr << "The tree is invalid" << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("The tree is invalid");
|
||||
}
|
||||
valid = valid && treeValid;
|
||||
// status: tree|graph|object
|
||||
cerr << "Started checking consistency..." << endl;
|
||||
// cerr << "Started checking consistency..." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Started checking consistency...");
|
||||
for (size_t id = 1; id < status.size(); id++) {
|
||||
if (id % 100000 == 0) {
|
||||
cerr << "The number of processed objects=" << id << endl;
|
||||
// cerr << "The number of processed objects=" << id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("The number of processed objects=" + std::to_string(id));
|
||||
}
|
||||
if (status[id] != 0x00 && status[id] != 0x07) {
|
||||
if (status[id] == 0x03) {
|
||||
|
@ -1596,7 +1756,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
GraphIndex::objectSpace->deleteObject(po);
|
||||
#endif
|
||||
cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
|
||||
// cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Fatal Error!: Cannot search! " + std::string(err.what()));
|
||||
objects.clear();
|
||||
}
|
||||
size_t n = 0;
|
||||
|
@ -1609,7 +1771,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
}
|
||||
if (!registeredIdenticalObject) {
|
||||
if (info) {
|
||||
cerr << "info: not found the registered same objects. id=" << id << " size=" << objects.size() << endl;
|
||||
// cerr << "info: not found the registered same objects. id=" << id << " size=" << objects.size() << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("info: not found the registered same objects. id=" + std::to_string(id) + " size=" + std::to_string(objects.size()));
|
||||
}
|
||||
sc.id = 0;
|
||||
sc.radius = FLT_MAX;
|
||||
|
@ -1622,7 +1786,10 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
try {
|
||||
GraphIndex::search(sc, seeds);
|
||||
} catch(Exception &err) {
|
||||
cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
|
||||
// cerr << "Fatal Error!: Cannot search! " << err.what() << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Fatal Error!: Cannot search! " + std::string(err.what()));
|
||||
}
|
||||
objects.clear();
|
||||
}
|
||||
registeredIdenticalObject = false;
|
||||
|
@ -1631,7 +1798,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
if (objects[n].id != id && status[objects[n].id] == 0x07) {
|
||||
registeredIdenticalObject = true;
|
||||
if (info) {
|
||||
cerr << "info: found by using mode accurate search. " << objects[n].id << endl;
|
||||
// cerr << "info: found by using mode accurate search. " << objects[n].id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("info: found by using mode accurate search. " + std::to_string(objects[n].id));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -1639,7 +1808,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
}
|
||||
if (!registeredIdenticalObject && mode != 's') {
|
||||
if (info) {
|
||||
cerr << "info: not found by using more accurate search." << endl;
|
||||
// cerr << "info: not found by using more accurate search." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("info: not found by using more accurate search.");
|
||||
}
|
||||
sc.id = 0;
|
||||
sc.radius = 0.0;
|
||||
|
@ -1655,7 +1826,9 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
if (objects[n].id != id && status[objects[n].id] == 0x07) {
|
||||
registeredIdenticalObject = true;
|
||||
if (info) {
|
||||
cerr << "info: found by using linear search. " << objects[n].id << endl;
|
||||
// cerr << "info: found by using linear search. " << objects[n].id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("info: found by using linear search. " + std::to_string(objects[n].id));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -1666,8 +1839,12 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
#endif
|
||||
if (registeredIdenticalObject) {
|
||||
if (info) {
|
||||
cerr << "Info ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
cerr << " found the valid same objects. " << objects[n].id << endl;
|
||||
// cerr << "Info ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
// cerr << " found the valid same objects. " << objects[n].id << endl;
|
||||
if (NGT_LOG_DEBUG_){
|
||||
(*NGT_LOG_DEBUG_)("Info ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
|
||||
(*NGT_LOG_DEBUG_)(" found the valid same objects. " + std::to_string(objects[n].id));
|
||||
}
|
||||
}
|
||||
GraphNode &fromNode = *GraphIndex::getNode(id);
|
||||
bool fromFound = false;
|
||||
|
@ -1694,40 +1871,67 @@ GraphAndTreeIndex::verify(vector<uint8_t> &status, bool info, char mode) {
|
|||
if (!fromFound || !toFound) {
|
||||
if (info) {
|
||||
if (!fromFound && !toFound) {
|
||||
cerr << "Warning no undirected edge between " << id << "(" << fromNode.size() << ") and "
|
||||
<< objects[n].id << "(" << toNode.size() << ")." << endl;
|
||||
// cerr << "Warning no undirected edge between " << id << "(" << fromNode.size() << ") and "
|
||||
// << objects[n].id << "(" << toNode.size() << ")." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning no undirected edge between " + std::to_string(id) + "(" + std::to_string(fromNode.size()) + ") and "
|
||||
+ std::to_string(objects[n].id) + "(" + std::to_string(toNode.size()) + ").");
|
||||
} else if (!fromFound) {
|
||||
cerr << "Warning no directed edge from " << id << "(" << fromNode.size() << ") to "
|
||||
<< objects[n].id << "(" << toNode.size() << ")." << endl;
|
||||
// cerr << "Warning no directed edge from " << id << "(" << fromNode.size() << ") to "
|
||||
// << objects[n].id << "(" << toNode.size() << ")." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning no directed edge from " + std::to_string(id) + "(" + std::to_string(fromNode.size()) + ") to "
|
||||
+ std::to_string(objects[n].id) + "(" + std::to_string(toNode.size()) + ").");
|
||||
} else if (!toFound) {
|
||||
cerr << "Warning no reverse directed edge from " << id << "(" << fromNode.size() << ") to "
|
||||
<< objects[n].id << "(" << toNode.size() << ")." << endl;
|
||||
// cerr << "Warning no reverse directed edge from " << id << "(" << fromNode.size() << ") to "
|
||||
// << objects[n].id << "(" << toNode.size() << ")." << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning no reverse directed edge from " + std::to_string(id) + "(" + std::to_string(fromNode.size()) + ") to "
|
||||
+ std::to_string(objects[n].id) + "(" + std::to_string(toNode.size()) + ").");
|
||||
}
|
||||
}
|
||||
if (!findPathAmongIdenticalObjects(*this, id, objects[n].id)) {
|
||||
cerr << "Warning no path from " << id << " to " << objects[n].id << endl;
|
||||
// cerr << "Warning no path from " << id << " to " << objects[n].id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning no path from " + std::to_string(id) + " to " + std::to_string(objects[n].id));
|
||||
}
|
||||
if (!findPathAmongIdenticalObjects(*this, objects[n].id, id)) {
|
||||
cerr << "Warning no reverse path from " << id << " to " << objects[n].id << endl;
|
||||
// cerr << "Warning no reverse path from " << id << " to " << objects[n].id << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Warning no reverse path from " + std::to_string(id) + " to " + std::to_string(objects[n].id));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (mode == 's') {
|
||||
cerr << "Warning: not found the valid same object, but not try to use linear search." << endl;
|
||||
cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
// cerr << "Warning: not found the valid same object, but not try to use linear search." << endl;
|
||||
// cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Warning: not found the valid same object, but not try to use linear search.");
|
||||
(*NGT_LOG_DEBUG_)("Error! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
|
||||
}
|
||||
} else {
|
||||
cerr << "Warning: not found the valid same object even by using linear search." << endl;
|
||||
cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
// cerr << "Warning: not found the valid same object even by using linear search." << endl;
|
||||
// cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Warning: not found the valid same object even by using linear search.");
|
||||
(*NGT_LOG_DEBUG_)("Error! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
|
||||
}
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
} else if (status[id] == 0x01) {
|
||||
if (info) {
|
||||
cerr << "Warning! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
cerr << " not inserted into the indexes" << endl;
|
||||
// cerr << "Warning! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
// cerr << " not inserted into the indexes" << endl;
|
||||
if (NGT_LOG_DEBUG_) {
|
||||
(*NGT_LOG_DEBUG_)("Warning! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
|
||||
(*NGT_LOG_DEBUG_)(" not inserted into the indexes");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
// cerr << "Error! ID=" << id << ":" << static_cast<int>(status[id]) << endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Error! ID=" + std::to_string(id) + ":" + std::to_string(static_cast<int>(status[id])));
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -140,6 +140,9 @@ public:
|
|||
case DistanceType::DistanceTypeL2:
|
||||
p.set("DistanceType", "L2");
|
||||
break;
|
||||
case DistanceType::DistanceTypeIP:
|
||||
p.set("DistanceType", "IP");
|
||||
break;
|
||||
case DistanceType::DistanceTypeHamming:
|
||||
p.set("DistanceType", "Hamming");
|
||||
break;
|
||||
|
@ -255,6 +258,10 @@ public:
|
|||
{
|
||||
distanceType = DistanceType::DistanceTypeL2;
|
||||
}
|
||||
else if (it->second == "IP")
|
||||
{
|
||||
distanceType = DistanceType::DistanceTypeIP;
|
||||
}
|
||||
else if (it->second == "Hamming")
|
||||
{
|
||||
distanceType = DistanceType::DistanceTypeHamming;
|
||||
|
@ -379,6 +386,7 @@ public:
|
|||
|
||||
void set(NGT::Property & prop);
|
||||
void get(NGT::Property & prop);
|
||||
int64_t memSize() { return sizeof(*this); }
|
||||
int dimension;
|
||||
int threadPoolSize;
|
||||
ObjectSpace::ObjectType objectType;
|
||||
|
@ -403,6 +411,7 @@ public:
|
|||
public:
|
||||
InsertionResult() : id(0), identical(false), distance(0.0) {}
|
||||
InsertionResult(size_t i, bool tf, Distance d) : id(i), identical(tf), distance(d) {}
|
||||
int64_t memSize() { return sizeof(*this); }
|
||||
size_t id;
|
||||
bool identical;
|
||||
Distance distance; // the distance between the centroid and the inserted object.
|
||||
|
@ -415,6 +424,7 @@ public:
|
|||
AccuracyTable(std::vector<std::pair<float, double>> & t) { set(t); }
|
||||
AccuracyTable(std::string str) { set(str); }
|
||||
void set(std::vector<std::pair<float, double>> & t) { table = t; }
|
||||
int64_t memSize() { return sizeof(*this) + table.capacity() * sizeof(table[0]); }
|
||||
void set(std::string str)
|
||||
{
|
||||
std::vector<std::string> tokens;
|
||||
|
@ -645,7 +655,7 @@ public:
|
|||
virtual void linearSearch(NGT::SearchContainer & sc) { getIndex().linearSearch(sc); }
|
||||
virtual void linearSearch(NGT::SearchQuery & sc) { getIndex().linearSearch(sc); }
|
||||
// for milvus
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset) { getIndex().search(sc, bitset); }
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset) { getIndex().search(sc, bitset); }
|
||||
virtual void search(NGT::SearchContainer & sc) { getIndex().search(sc); }
|
||||
virtual void search(NGT::SearchQuery & sc) { getIndex().search(sc); }
|
||||
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds) { getIndex().search(sc, seeds); }
|
||||
|
@ -711,6 +721,8 @@ public:
|
|||
static std::string getVersion();
|
||||
std::string getPath() { return path; }
|
||||
|
||||
virtual int64_t memSize() { return redirector.memSize() + sizeof(path) + (index ? getIndex().memSize() : 0); }
|
||||
|
||||
protected:
|
||||
Object * allocateObject(void * vec, const std::type_info & objectType)
|
||||
{
|
||||
|
@ -1046,7 +1058,7 @@ public:
|
|||
}
|
||||
|
||||
// for milvus
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset)
|
||||
virtual void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset)
|
||||
{
|
||||
sc.distanceComputationCount = 0;
|
||||
sc.visitCount = 0;
|
||||
|
@ -1574,7 +1586,7 @@ protected:
|
|||
}
|
||||
|
||||
// for milvus
|
||||
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset)
|
||||
virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView&bitset)
|
||||
{
|
||||
if (sc.size == 0)
|
||||
{
|
||||
|
@ -1628,6 +1640,8 @@ protected:
|
|||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
virtual int64_t memSize() { return property.memSize() + sizeof(readOnly) + accuracyTable.memSize() + Index::memSize() + NeighborhoodGraph::memSize(); }
|
||||
Index::Property property;
|
||||
|
||||
bool readOnly;
|
||||
|
@ -1641,6 +1655,9 @@ protected:
|
|||
class GraphAndTreeIndex : public GraphIndex, public DVPTree
|
||||
{
|
||||
public:
|
||||
|
||||
virtual int64_t memSize() { return GraphIndex::memSize() + DVPTree::memSize(); }
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
GraphAndTreeIndex(const std::string & allocator, bool rdOnly = false) : GraphIndex(allocator, false) { initialize(allocator, 0); }
|
||||
GraphAndTreeIndex(const std::string & allocator, NGT::Property & prop);
|
||||
|
@ -2130,7 +2147,7 @@ public:
|
|||
|
||||
// for milvus
|
||||
void
|
||||
getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::BitsetView& bitset) {
|
||||
DVPTree::SearchContainer tso(sc.object);
|
||||
tso.mode = DVPTree::SearchContainer::SearchLeaf;
|
||||
tso.radius = 0.0;
|
||||
|
@ -2187,7 +2204,7 @@ public:
|
|||
}
|
||||
|
||||
// for milvus
|
||||
void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset)
|
||||
void search(NGT::SearchContainer & sc, const faiss::BitsetView&bitset)
|
||||
{
|
||||
sc.distanceComputationCount = 0;
|
||||
sc.visitCount = 0;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "defines.h"
|
||||
#include "MmapManagerDefs.h"
|
||||
#include "MmapManagerException.h"
|
||||
|
||||
|
@ -132,7 +133,9 @@ namespace MemoryManager{
|
|||
const off_t old_file_size = mmapCntlHead->base_size * mmapCntlHead->unit_num;
|
||||
|
||||
if(new_unit_num >= MMAP_MAX_UNIT_NUM){
|
||||
std::cerr << "over max unit num" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("over max unit num");
|
||||
// std::cerr << "over max unit num" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -151,10 +154,14 @@ namespace MemoryManager{
|
|||
throw MmapManagerException("truncate error" + err_str);
|
||||
}
|
||||
|
||||
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
|
||||
if(close(fd) == -1 && NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(filePath + "[WARN] : filedescript cannot close");
|
||||
// std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException("mmap error" + err_str);
|
||||
}
|
||||
if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
|
||||
if(close(fd) == -1 && NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(filePath + "[WARN] : filedescript cannot close");
|
||||
// std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
|
||||
|
||||
mmapDataAddr[mmapCntlHead->unit_num] = new_area;
|
||||
|
||||
|
@ -179,14 +186,18 @@ namespace MemoryManager{
|
|||
if(lseek(fd, (off_t)size-1, SEEK_SET) < 0){
|
||||
std::stringstream ss;
|
||||
ss << "[ERR] Cannot seek the file. " << targetFile << " " << getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
if(close(fd) == -1 && NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(targetFile + "[WARN] : filedescript cannot close");
|
||||
// std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException(ss.str());
|
||||
}
|
||||
errno = 0;
|
||||
if(write(fd, &c, sizeof(char)) == -1){
|
||||
std::stringstream ss;
|
||||
ss << "[ERR] Cannot write the file. Check the disk space. " << targetFile << " " << getErrorStr(errno);
|
||||
if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
if(close(fd) == -1 && NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(targetFile + "[WARN] : filedescript cannot close");
|
||||
// std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
|
||||
throw MmapManagerException(ss.str());
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "NGT/Index.h"
|
||||
#include "NGT/ArrayFile.h"
|
||||
#include "NGT/Clustering.h"
|
||||
#include "NGT/defines.h"
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ namespace NGT {
|
|||
void deserialize(std::stringstream & is) { NGT::Serializer::read(is, id); }
|
||||
void serializeAsText(std::ofstream &os) { NGT::Serializer::writeAsText(os, id); }
|
||||
void deserializeAsText(std::ifstream &is) { NGT::Serializer::readAsText(is, id); }
|
||||
virtual int64_t memSize() { return sizeof(id); }
|
||||
protected:
|
||||
NodeID id;
|
||||
};
|
||||
|
@ -69,6 +70,7 @@ namespace NGT {
|
|||
public:
|
||||
Object():object(0) {}
|
||||
bool operator<(const Object &o) const { return distance < o.distance; }
|
||||
virtual int64_t memSize() { return sizeof(*this) + object->memSize(); } // size of object cannot be decided accurately
|
||||
static const double Pivot;
|
||||
ObjectID id;
|
||||
PersistentObject *object;
|
||||
|
@ -126,6 +128,8 @@ namespace NGT {
|
|||
parent.deserializeAsText(is);
|
||||
}
|
||||
|
||||
virtual int64_t memSize() { return id.memSize() * 2 + pivot->memSize(); }
|
||||
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
void setPivot(PersistentObject &f, ObjectSpace &os, SharedMemoryAllocator &allocator) {
|
||||
if (pivot == 0) {
|
||||
|
@ -435,6 +439,8 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
|
||||
virtual int64_t memSize() { return sizeof(childrenSize) + children->memSize() + childrenSize * sizeof(Distance) + Node::memSize(); }
|
||||
|
||||
void show() {
|
||||
std::cout << "Show internal node " << childrenSize << ":";
|
||||
for (size_t i = 0; i < childrenSize; i++) {
|
||||
|
@ -747,6 +753,7 @@ namespace NGT {
|
|||
bool verify(size_t nobjs, std::vector<uint8_t> &status);
|
||||
#endif
|
||||
|
||||
virtual int64_t memSize() { return sizeof(objectSize) + objectIDs->memSize() * objectSize + Node::memSize(); }
|
||||
|
||||
#ifdef NGT_NODE_USE_VECTOR
|
||||
size_t getObjectSize() { return objectIDs.size(); }
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#pragma once
|
||||
#include <sstream>
|
||||
#include "defines.h"
|
||||
|
||||
namespace NGT {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
@ -104,22 +105,30 @@ namespace NGT {
|
|||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
|
||||
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!");
|
||||
abort();
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
|
||||
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!");
|
||||
abort();
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
|
||||
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!");
|
||||
abort();
|
||||
}
|
||||
|
||||
virtual PersistentObject *allocateNormalizedPersistentObject(const float *obj, size_t size) {
|
||||
std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
|
||||
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!");
|
||||
abort();
|
||||
}
|
||||
|
||||
|
@ -141,8 +150,11 @@ namespace NGT {
|
|||
while (getline(is, line)) {
|
||||
lineNo++;
|
||||
if (dataSize > 0 && (dataSize <= size() - prevDataSize)) {
|
||||
std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
|
||||
<< dataSize << std::endl;
|
||||
// std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
|
||||
// << dataSize << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("The size of data reached the specified size. The remaining data in the file are not inserted. "
|
||||
+ std::to_string(dataSize));
|
||||
break;
|
||||
}
|
||||
std::vector<double> object;
|
||||
|
@ -152,12 +164,16 @@ namespace NGT {
|
|||
try {
|
||||
obj = allocateNormalizedPersistentObject(object);
|
||||
} catch (Exception &err) {
|
||||
std::cerr << err.what() << " continue..." << std::endl;
|
||||
// std::cerr << err.what() << " continue..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue...");
|
||||
obj = allocatePersistentObject(object);
|
||||
}
|
||||
push_back(obj);
|
||||
} catch (Exception &err) {
|
||||
std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
|
||||
// std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Invalid line. [" + line + "] Skip the line " + std::to_string(lineNo) + " and continue.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -186,13 +202,17 @@ namespace NGT {
|
|||
try {
|
||||
obj = allocateNormalizedPersistentObject(object);
|
||||
} catch (Exception &err) {
|
||||
std::cerr << err.what() << " continue..." << std::endl;
|
||||
// std::cerr << err.what() << " continue..." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue...");
|
||||
obj = allocatePersistentObject(object);
|
||||
}
|
||||
push_back(obj);
|
||||
|
||||
} catch (Exception &err) {
|
||||
std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
|
||||
// std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Invalid data. Skip the data no. " + std::to_string(idx) + " and continue.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -231,7 +251,9 @@ namespace NGT {
|
|||
char *e;
|
||||
object[idx] = strtod(tokens[idx].c_str(), &e);
|
||||
if (*e != 0) {
|
||||
std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
|
||||
// std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Not numerical value. [" + std::string(e) + "]");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -245,8 +267,11 @@ namespace NGT {
|
|||
osize = osize < vsize ? vsize : osize;
|
||||
} else {
|
||||
if (dimension != size) {
|
||||
std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
|
||||
<< dimension << " The specified object=" << size << std::endl;
|
||||
// std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
|
||||
// << dimension << " The specified object=" << size << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
|
||||
+ std::to_string(dimension) + " The specified object=" + std::to_string(size));
|
||||
assert(dimension == size);
|
||||
}
|
||||
}
|
||||
|
@ -263,7 +288,9 @@ namespace NGT {
|
|||
obj[i] = static_cast<float>(o[i]);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
|
||||
abort();
|
||||
}
|
||||
return po;
|
||||
|
@ -283,7 +310,9 @@ namespace NGT {
|
|||
} else if (type == typeid(float)) {
|
||||
cpsize *= sizeof(float);
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
|
||||
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
|
||||
|
@ -315,7 +344,9 @@ namespace NGT {
|
|||
obj[i] = static_cast<float>(o[i]);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
|
||||
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
return po;
|
||||
|
@ -361,7 +392,9 @@ namespace NGT {
|
|||
d.push_back(obj[i]);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include "PrimitiveComparator.h"
|
||||
|
||||
class ObjectSpace;
|
||||
|
@ -94,6 +95,15 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
|
||||
int64_t memSize() const {
|
||||
// auto obj = (std::vector<ObjectDistance>)(*this);
|
||||
if (this->size() == 0)
|
||||
return 0;
|
||||
else {
|
||||
return (*this)[0].memSize() * this->size();
|
||||
}
|
||||
// return this->size() == 0 ? 0 : (*this)[0].memSize() * (this->size());
|
||||
}
|
||||
ObjectDistances &operator=(PersistentObjectDistances &objs);
|
||||
};
|
||||
|
||||
|
@ -169,6 +179,7 @@ namespace NGT {
|
|||
SharedMemoryAllocator &allocator;
|
||||
#endif
|
||||
virtual ~Comparator(){}
|
||||
int64_t memSize() { return sizeof(size_t); }
|
||||
};
|
||||
enum DistanceType {
|
||||
DistanceTypeNone = -1,
|
||||
|
@ -180,7 +191,8 @@ namespace NGT {
|
|||
DistanceTypeNormalizedAngle = 5,
|
||||
DistanceTypeNormalizedCosine = 6,
|
||||
DistanceTypeJaccard = 7,
|
||||
DistanceTypeSparseJaccard = 8
|
||||
DistanceTypeSparseJaccard = 8,
|
||||
DistanceTypeIP = 9
|
||||
};
|
||||
|
||||
enum ObjectType {
|
||||
|
@ -248,6 +260,7 @@ namespace NGT {
|
|||
virtual ObjectRepository &getRepository() = 0;
|
||||
|
||||
virtual void setDistanceType(DistanceType t) = 0;
|
||||
virtual DistanceType getDistanceType() = 0;
|
||||
|
||||
virtual void *getObject(size_t idx) = 0;
|
||||
virtual void getObject(size_t idx, std::vector<float> &v) = 0;
|
||||
|
@ -255,6 +268,7 @@ namespace NGT {
|
|||
|
||||
size_t getDimension() { return dimension; }
|
||||
size_t getPaddedDimension() { return ((dimension - 1) / 16 + 1) * 16; }
|
||||
virtual int64_t memSize() { return sizeof(dimension) + sizeof(distanceType) + sizeof(prefetchOffset) * 2 + sizeof(normalization) + comparator->memSize(); };
|
||||
|
||||
template <typename T>
|
||||
void normalize(T *data, size_t dim) {
|
||||
|
@ -384,6 +398,8 @@ namespace NGT {
|
|||
void *getPointer(size_t idx = 0) const { return vector + idx; }
|
||||
|
||||
static Object *allocate(ObjectSpace &objectspace) { return new Object(&objectspace); }
|
||||
|
||||
virtual int64_t memSize() { return std::strlen((char*)vector); }
|
||||
private:
|
||||
void clear() {
|
||||
if (vector != 0) {
|
||||
|
|
|
@ -73,6 +73,27 @@ namespace NGT {
|
|||
#endif
|
||||
};
|
||||
|
||||
class ComparatorIP : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
ComparatorIP(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareIP((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
double operator()(Object &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareIP((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
|
||||
return PrimitiveComparator::compareIP((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
|
||||
}
|
||||
#else
|
||||
ComparatorIP(size_t d) : Comparator(d) {}
|
||||
double operator()(Object &objecta, Object &objectb) {
|
||||
return PrimitiveComparator::compareInnerProduct((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class ComparatorHammingDistance : public Comparator {
|
||||
public:
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
|
@ -281,6 +302,8 @@ namespace NGT {
|
|||
objecta.copy(objectb, getByteSizeOfObject());
|
||||
}
|
||||
|
||||
DistanceType getDistanceType() { return distanceType; }
|
||||
|
||||
void setDistanceType(DistanceType t) {
|
||||
if (comparator != 0) {
|
||||
delete comparator;
|
||||
|
@ -326,6 +349,9 @@ namespace NGT {
|
|||
case DistanceTypeL2:
|
||||
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeIP:
|
||||
comparator = new ObjectSpaceRepository::ComparatorIP(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
case DistanceTypeHamming:
|
||||
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension());
|
||||
break;
|
||||
|
@ -352,7 +378,9 @@ namespace NGT {
|
|||
break;
|
||||
#endif
|
||||
default:
|
||||
std::cerr << "Distance type is not specified" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Distance type is not specified");
|
||||
// std::cerr << "Distance type is not specified" << std::endl;
|
||||
assert(distanceType != DistanceTypeNone);
|
||||
abort();
|
||||
}
|
||||
|
@ -587,7 +615,9 @@ namespace NGT {
|
|||
} else if (t == typeid(uint32_t)) {
|
||||
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
|
||||
} else {
|
||||
std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("ObjectT::serializeAsText: not supported data type. [" + std::to_string(t.name()) + "]");
|
||||
// std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
|
@ -610,7 +640,9 @@ namespace NGT {
|
|||
} else if (t == typeid(uint32_t)) {
|
||||
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
|
||||
} else {
|
||||
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
if (NGT_LOG_DEBUG_)
|
||||
(*NGT_LOG_DEBUG_)("Object::deserializeAsText: not supported data type. [" + std::to_string(t.name()) + "]");
|
||||
// std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -195,7 +195,8 @@ class PrimitiveComparator {
|
|||
diff0 = (COMPARE_TYPE)(*a++ - *b++);
|
||||
d += diff0 * diff0;
|
||||
}
|
||||
return sqrt((double)d);
|
||||
// return sqrt((double)d);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline static double
|
||||
|
@ -264,7 +265,8 @@ class PrimitiveComparator {
|
|||
_mm_store_ps(f, sum128);
|
||||
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
return sqrt(s);
|
||||
// return sqrt(s);
|
||||
return s;
|
||||
}
|
||||
|
||||
inline static double
|
||||
|
@ -290,7 +292,8 @@ class PrimitiveComparator {
|
|||
int d = (int)*a++ - (int)*b++;
|
||||
s += d * d;
|
||||
}
|
||||
return sqrt(s);
|
||||
// return sqrt(s);
|
||||
return s;
|
||||
}
|
||||
#endif
|
||||
#if defined(NGT_NO_AVX)
|
||||
|
@ -498,6 +501,17 @@ class PrimitiveComparator {
|
|||
return sum;
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareInnerProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
double sum = 0.0;
|
||||
for (size_t loc = 0; loc < size; loc++) {
|
||||
sum += (double)a[loc] * (double)b[loc];
|
||||
// sum += a[loc] * b[loc];
|
||||
}
|
||||
return -sum;
|
||||
}
|
||||
|
||||
template <typename OBJECT_TYPE>
|
||||
inline static double
|
||||
compareCosine(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
|
||||
|
@ -551,6 +565,42 @@ class PrimitiveComparator {
|
|||
return s;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareInnerProduct(const float* a, const float* b, size_t size) {
|
||||
const float* last = a + size;
|
||||
#if defined(NGT_AVX512)
|
||||
__m512 sum512 = _mm512_setzero_ps();
|
||||
while (a < last) {
|
||||
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)));
|
||||
a += 16;
|
||||
b += 16;
|
||||
}
|
||||
|
||||
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
|
||||
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
#elif defined(NGT_AVX2)
|
||||
__m256 sum256 = _mm256_setzero_ps();
|
||||
while (a < last) {
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)));
|
||||
a += 8;
|
||||
b += 8;
|
||||
}
|
||||
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
#else
|
||||
__m128 sum128 = _mm_setzero_ps();
|
||||
while (a < last) {
|
||||
sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
#endif
|
||||
__attribute__((aligned(32))) float f[4];
|
||||
_mm_store_ps(f, sum128);
|
||||
|
||||
double s = f[0] + f[1] + f[2] + f[3];
|
||||
return -s;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareDotProduct(const unsigned char* a, const unsigned char* b, size_t size) {
|
||||
double sum = 0.0;
|
||||
|
@ -560,6 +610,15 @@ class PrimitiveComparator {
|
|||
return sum;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareInnerProduct(const unsigned char* a, const unsigned char* b, size_t size) {
|
||||
double sum = 0.0;
|
||||
for (size_t loc = 0; loc < size; loc++) {
|
||||
sum += (double)a[loc] * (double)b[loc];
|
||||
}
|
||||
return -sum;
|
||||
}
|
||||
|
||||
inline static double
|
||||
compareCosine(const float* a, const float* b, size_t size) {
|
||||
const float* last = a + size;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "NGT/Node.h"
|
||||
#include "NGT/defines.h"
|
||||
#include "faiss/utils/ConcurrentBitset.h"
|
||||
#include "faiss/utils/BitsetView.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
@ -262,7 +263,7 @@ namespace NGT {
|
|||
|
||||
// for milvus
|
||||
void
|
||||
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::BitsetView& bitset) {
|
||||
LeafNode& ln = *(LeafNode*)getNode(nid);
|
||||
rl.clear();
|
||||
ObjectDistance r;
|
||||
|
@ -274,7 +275,7 @@ namespace NGT {
|
|||
r.id = ln.getObjectIDs()[i].id;
|
||||
r.distance = ln.getObjectIDs()[i].distance;
|
||||
#endif
|
||||
if (bitset != nullptr && bitset->test(r.id - 1)) {
|
||||
if (!bitset.empty() && bitset.test(r.id - 1)) {
|
||||
continue;
|
||||
}
|
||||
rl.push_back(r);
|
||||
|
@ -487,6 +488,8 @@ namespace NGT {
|
|||
}
|
||||
}
|
||||
|
||||
virtual int64_t memSize() { return sizeof(size_t) * 2 + sizeof(splitMode) + name.size() + leafNodes.memSize() + internalNodes.memSize(); }
|
||||
|
||||
public:
|
||||
size_t internalChildrenSize;
|
||||
size_t leafObjectsSize;
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue