Refactor cmake and build script and add timed benchmark

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-10-23 18:01:24 +08:00 committed by yefu.chen
parent 9d2ebe7632
commit e84b0180c9
208 changed files with 29340 additions and 1748 deletions

.clang-format Normal file
View File

@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
---
# Below is copied from milvus project
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 120
IndentWidth: 4
AccessModifierOffset: -3
AlwaysBreakAfterReturnType: All
AllowShortBlocksOnASingleLine: false
AllowShortFunctionsOnASingleLine: false
AllowShortIfStatementsOnASingleLine: false
AlignTrailingComments: true
# Appended Options
SortIncludes: false
Standard: Latest
AlignAfterOpenBracket: Align
BinPackParameters: false

.gitignore vendored
View File

@ -1,17 +1,10 @@
# CLion generated files
core/cmake-build-debug/
core/cmake-build-debug/*
core/cmake-build-release/
core/cmake-build-release/*
core/cmake_build/
core/cmake_build/*
core/build/
core/build/*
core/.idea/
.idea/
.idea/*
pulsar/client-cpp/cmake-build-debug/
pulsar/client-cpp/cmake-build-debug/*
**/cmake-build-debug/*
**/cmake_build/*
**/cmake-build-release/*
internal/core/output/*
internal/core/build/*
**/.idea/*
pulsar/client-cpp/build/
pulsar/client-cpp/build/*

View File

@ -11,12 +11,12 @@ done
SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
MILVUS_CORE_DIR="${SCRIPTS_DIR}/../../internal/core"
CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/milvus"
CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/output"
UNITTEST_DIRS=("${CORE_INSTALL_PREFIX}/unittest")
# Currently core will install target lib to "core/lib"
if [ -d "${MILVUS_CORE_DIR}/lib" ]; then
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${MILVUS_CORE_DIR}/lib
# Currently core will install target lib to "core/output/lib"
if [ -d "${CORE_INSTALL_PREFIX}/lib" ]; then
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${CORE_INSTALL_PREFIX}/lib
fi
# run unittest

View File

@ -35,10 +35,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib"
# Install Go
ENV GOPATH /go
ENV GOROOT /usr/local/go
RUN mkdir -p /usr/local/go && wget -qO- "https://golang.org/dl/go1.15.2.linux-amd64.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \
mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH"
ENV GO111MODULE on
ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH
RUN mkdir -p /usr/local/go && wget -qO- "https://golang.org/dl/go1.15.2.linux-amd64.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \
mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH" && \
go get github.com/golang/protobuf/protoc-gen-go@v1.3.2
# Set permissions on /etc/passwd and /home to allow arbitrary users to write
COPY --chown=0:0 docker/build_env/entrypoint.sh /

View File

@ -176,8 +176,6 @@ config_summary()
add_subdirectory( thirdparty )
add_subdirectory( src )
# Unittest lib
if ( BUILD_UNIT_TEST STREQUAL "ON" )
if ( BUILD_COVERAGE STREQUAL "ON" )
@ -189,7 +187,7 @@ if ( BUILD_UNIT_TEST STREQUAL "ON" )
endif ()
append_flags( CMAKE_CXX_FLAGS FLAGS "-DELPP_DISABLE_LOGS")
add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/unittest )
add_subdirectory(unittest)
endif ()
@ -206,9 +204,9 @@ set( GPU_ENABLE "false" )
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include
DESTINATION include
FILES_MATCHING PATTERN "*_c.h"
)
install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
DESTINATION lib)

View File

@ -8,7 +8,7 @@ fi
BUILD_OUTPUT_DIR="cmake_build"
BUILD_TYPE="Release"
BUILD_UNITTEST="OFF"
INSTALL_PREFIX=$(pwd)/milvus
INSTALL_PREFIX=$(pwd)/output
MAKE_CLEAN="OFF"
BUILD_COVERAGE="OFF"
DB_PATH="/tmp/milvus"
@ -20,7 +20,7 @@ WITH_PROMETHEUS="ON"
CUDA_ARCH="DEFAULT"
CUSTOM_THIRDPARTY_PATH=""
while getopts "p:d:t:s:f:ulrcghzme" arg; do
while getopts "p:d:t:s:f:o:ulrcghzme" arg; do
case $arg in
f)
CUSTOM_THIRDPARTY_PATH=$OPTARG
@ -28,6 +28,9 @@ while getopts "p:d:t:s:f:ulrcghzme" arg; do
p)
INSTALL_PREFIX=$OPTARG
;;
o)
BUILD_OUTPUT_DIR=$OPTARG
;;
d)
DB_PATH=$OPTARG
;;

View File

@ -64,16 +64,12 @@ define_option(MILVUS_VERBOSE_THIRDPARTY_BUILD
define_option(MILVUS_WITH_EASYLOGGINGPP "Build with Easylogging++ library" ON)
define_option(MILVUS_WITH_GRPC "Build with GRPC" OFF)
define_option(MILVUS_WITH_ZLIB "Build with zlib compression" ON)
define_option(MILVUS_WITH_OPENTRACING "Build with Opentracing" ON)
define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON)
define_option(MILVUS_WITH_PULSAR "Build with pulsar-client" ON)
#----------------------------------------------------------------------
set_option_category("Test and benchmark")

View File

@ -1,18 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif
typedef void* CCollection;
CCollection
NewCollection(const char* collection_name, const char* schema_conf);
void
DeleteCollection(CCollection collection);
void
UpdateIndexes(CCollection c_collection, const char *index_string);
#ifdef __cplusplus
}
#endif

View File

@ -1,17 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif
#include "collection_c.h"
typedef void* CPartition;
CPartition
NewPartition(CCollection collection, const char* partition_name);
void
DeletePartition(CPartition partition);
#ifdef __cplusplus
}
#endif

View File

@ -1,89 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
#include "partition_c.h"
typedef void* CSegmentBase;
typedef struct CQueryInfo {
long int num_queries;
int topK;
const char* field_name;
} CQueryInfo;
CSegmentBase
NewSegment(CPartition partition, unsigned long segment_id);
void
DeleteSegment(CSegmentBase segment);
//////////////////////////////////////////////////////////////////
int
Insert(CSegmentBase c_segment,
long int reserved_offset,
signed long int size,
const long* primary_keys,
const unsigned long* timestamps,
void* raw_data,
int sizeof_per_row,
signed long int count);
long int
PreInsert(CSegmentBase c_segment, long int size);
int
Delete(CSegmentBase c_segment,
long int reserved_offset,
long size,
const long* primary_keys,
const unsigned long* timestamps);
long int
PreDelete(CSegmentBase c_segment, long int size);
//int
//Search(CSegmentBase c_segment,
// const char* query_json,
// unsigned long timestamp,
// float* query_raw_data,
// int num_of_query_raw_data,
// long int* result_ids,
// float* result_distances);
int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
long int* result_ids,
float* result_distances);
//////////////////////////////////////////////////////////////////
int
Close(CSegmentBase c_segment);
int
BuildIndex(CCollection c_collection, CSegmentBase c_segment);
bool
IsOpened(CSegmentBase c_segment);
long int
GetMemoryUsageInBytes(CSegmentBase c_segment);
//////////////////////////////////////////////////////////////////
long int
GetRowCount(CSegmentBase c_segment);
long int
GetDeletedCount(CSegmentBase c_segment);
#ifdef __cplusplus
}
#endif
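
For orientation, below is a minimal C++ sketch of driving this C API end to end (this change moves the header rather than removing it). The include path, collection/partition names, and data sizes are illustrative; the 68-byte row assumes the default schema (16-dim float "fakevec" plus int32 "age") that is used when schema_conf is empty.

// Hypothetical usage sketch of the collection/partition/segment C API above.
#include <cstdint>
#include <vector>
#include "segment_c.h"   // include path assumed; pulls in partition_c.h / collection_c.h

int main() {
    CCollection col = NewCollection("demo_collection", "");    // empty conf => default schema
    CPartition part = NewPartition(col, "default_partition");
    CSegmentBase seg = NewSegment(part, /*segment_id=*/1);

    const long N = 100;                                         // illustrative row count
    const int row_size = 16 * sizeof(float) + sizeof(int32_t);  // 68 bytes per row
    std::vector<long> ids(N);
    std::vector<unsigned long> ts(N);
    std::vector<char> rows(N * row_size, 0);
    for (long i = 0; i < N; ++i) { ids[i] = i; ts[i] = i; }

    long offset = PreInsert(seg, N);                            // reserve N slots
    Insert(seg, offset, N, ids.data(), ts.data(), rows.data(), row_size, N);

    CQueryInfo info{/*num_queries=*/1, /*topK=*/10, "fakevec"};
    std::vector<long> result_ids(10);
    std::vector<float> result_distances(10);
    std::vector<float> query(16, 0.0f);                         // one 16-dim query vector
    Search(seg, info, /*timestamp=*/N, query.data(), 16, result_ids.data(), result_distances.data());

    DeleteSegment(seg);
    DeletePartition(part);
    DeleteCollection(col);
    return 0;
}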

View File

@ -13,7 +13,7 @@
#include <cstring>
#include <limits>
#include <unordered_map>
#include<iostream>
#include <iostream>
#include "config/ConfigMgr.h"
#include "config/ServerConfig.h"
@ -70,22 +70,19 @@ ConfigMgr::ConfigMgr() {
config_list_ = {
/* general */
{"timezone",
CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},
{"timezone", CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},
/* network */
{"network.address", CreateStringConfig("network.address", false, &config.network.address.value,
"0.0.0.0", nullptr, nullptr)},
{"network.port", CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value,
19530, nullptr, nullptr)},
{"network.address",
CreateStringConfig("network.address", false, &config.network.address.value, "0.0.0.0", nullptr, nullptr)},
{"network.port",
CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, 19530, nullptr, nullptr)},
/* pulsar */
{"pulsar.address", CreateStringConfig("pulsar.address", false, &config.pulsar.address.value,
"localhost", nullptr, nullptr)},
{"pulsar.port", CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value,
6650, nullptr, nullptr)},
{"pulsar.address",
CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, "localhost", nullptr, nullptr)},
{"pulsar.port",
CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, 6650, nullptr, nullptr)},
/* log */
{"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)},
@ -147,9 +144,9 @@ ConfigMgr::Load(const std::string& path) {
void
ConfigMgr::Set(const std::string& name, const std::string& value, bool update) {
std::cout<<"InSet Config "<< name <<std::endl;
if (config_list_.find(name) == config_list_.end()){
std::cout<<"Config "<< name << " not found!"<<std::endl;
std::cout << "InSet Config " << name << std::endl;
if (config_list_.find(name) == config_list_.end()) {
std::cout << "Config " << name << " not found!" << std::endl;
return;
}
try {

View File

@ -142,7 +142,11 @@ BaseConfig::Init() {
inited_ = true;
}
BoolConfig::BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
BoolConfig::BoolConfig(const char* name,
const char* alias,
bool modifiable,
bool* config,
bool default_value,
std::function<bool(bool val, std::string& err)> is_valid_fn,
std::function<bool(bool val, bool prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
@ -199,7 +203,11 @@ BoolConfig::Get() {
}
StringConfig::StringConfig(
const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value,
const char* name,
const char* alias,
bool modifiable,
std::string* config,
const char* default_value,
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
@ -251,8 +259,13 @@ StringConfig::Get() {
return *config_;
}
EnumConfig::EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
EnumConfig::EnumConfig(const char* name,
const char* alias,
bool modifiable,
configEnum* enumd,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
@ -324,8 +337,13 @@ EnumConfig::Get() {
return "unknown";
}
IntegerConfig::IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound,
int64_t upper_bound, int64_t* config, int64_t default_value,
IntegerConfig::IntegerConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
@ -393,8 +411,13 @@ IntegerConfig::Get() {
return std::to_string(*config_);
}
FloatingConfig::FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound,
double upper_bound, double* config, double default_value,
FloatingConfig::FloatingConfig(const char* name,
const char* alias,
bool modifiable,
double lower_bound,
double upper_bound,
double* config,
double default_value,
std::function<bool(double val, std::string& err)> is_valid_fn,
std::function<bool(double val, double prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
@ -457,8 +480,13 @@ FloatingConfig::Get() {
return std::to_string(*config_);
}
SizeConfig::SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value,
SizeConfig::SizeConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),

View File

@ -67,7 +67,11 @@ using BaseConfigPtr = std::shared_ptr<BaseConfig>;
class BoolConfig : public BaseConfig {
public:
BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
BoolConfig(const char* name,
const char* alias,
bool modifiable,
bool* config,
bool default_value,
std::function<bool(bool val, std::string& err)> is_valid_fn,
std::function<bool(bool val, bool prev, std::string& err)> update_fn);
@ -90,7 +94,11 @@ class BoolConfig : public BaseConfig {
class StringConfig : public BaseConfig {
public:
StringConfig(const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value,
StringConfig(const char* name,
const char* alias,
bool modifiable,
std::string* config,
const char* default_value,
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn);
@ -113,8 +121,13 @@ class StringConfig : public BaseConfig {
class EnumConfig : public BaseConfig {
public:
EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
EnumConfig(const char* name,
const char* alias,
bool modifiable,
configEnum* enumd,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
private:
@ -137,8 +150,13 @@ class EnumConfig : public BaseConfig {
class IntegerConfig : public BaseConfig {
public:
IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value,
IntegerConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
@ -163,8 +181,14 @@ class IntegerConfig : public BaseConfig {
class FloatingConfig : public BaseConfig {
public:
FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, double upper_bound,
double* config, double default_value, std::function<bool(double val, std::string& err)> is_valid_fn,
FloatingConfig(const char* name,
const char* alias,
bool modifiable,
double lower_bound,
double upper_bound,
double* config,
double default_value,
std::function<bool(double val, std::string& err)> is_valid_fn,
std::function<bool(double val, double prev, std::string& err)> update_fn);
private:
@ -188,8 +212,14 @@ class FloatingConfig : public BaseConfig {
class SizeConfig : public BaseConfig {
public:
SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
SizeConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
private:

View File

@ -71,11 +71,10 @@ struct ServerConfig {
Integer port{0};
} network;
struct Pulsar{
struct Pulsar {
String address{"localhost"};
Integer port{6650};
}pulsar;
} pulsar;
struct Engine {
Integer build_index_threshold{4096};
@ -89,7 +88,6 @@ struct ServerConfig {
String json_config_path{"unknown"};
} tracing;
struct Logs {
String level{"unknown"};
struct Trace {

View File

@ -11,13 +11,13 @@ class AckResponder {
std::lock_guard lck(mutex_);
fetch_and_flip(seg_end);
auto old_begin = fetch_and_flip(seg_begin);
if(old_begin) {
if (old_begin) {
minimal = *acks_.begin();
}
}
int64_t
GetAck() const{
GetAck() const {
return minimal;
}
@ -38,4 +38,4 @@ class AckResponder {
std::set<int64_t> acks_ = {0};
std::atomic<int64_t> minimal = 0;
};
}
} // namespace milvus::dog_segment

View File

@ -11,7 +11,7 @@ set(DOG_SEGMENT_FILES
partition_c.cpp
segment_c.cpp
EasyAssert.cpp
${PB_SRC_FILES}
${PB_SRC_FILES}
)
add_library(milvus_dog_segment SHARED
${DOG_SEGMENT_FILES}
@ -20,5 +20,9 @@ add_library(milvus_dog_segment SHARED
#add_dependencies( segment sqlite mysqlpp )
target_link_libraries(milvus_dog_segment tbb utils pthread knowhere log libprotobuf dl backtrace
)
target_link_libraries(milvus_dog_segment
tbb utils pthread knowhere log libprotobuf
dl backtrace
milvus_query
)

View File

@ -6,17 +6,14 @@
namespace milvus::dog_segment {
Collection::Collection(std::string &collection_name, std::string &schema):
collection_name_(collection_name), schema_json_(schema) {
Collection::Collection(std::string& collection_name, std::string& schema)
: collection_name_(collection_name), schema_json_(schema) {
parse();
index_ = nullptr;
}
void
Collection::AddIndex(const grpc::IndexParam& index_param) {
auto& index_name = index_param.index_name();
auto& field_name = index_param.field_name();
@ -32,7 +29,7 @@ Collection::AddIndex(const grpc::IndexParam& index_param) {
bool found_index_conf = false;
auto extra_params = index_param.extra_params();
for (auto& extra_param: extra_params) {
for (auto& extra_param : extra_params) {
if (extra_param.key() == "index_type") {
index_type = extra_param.value().data();
found_index_type = true;
@ -67,21 +64,18 @@ Collection::AddIndex(const grpc::IndexParam& index_param) {
if (!found_index_conf) {
int dim = 0;
for (auto& field: schema_->get_fields()) {
for (auto& field : schema_->get_fields()) {
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
dim = field.get_dim();
dim = field.get_dim();
}
}
Assert(dim != 0);
index_conf = milvus::knowhere::Config{
{knowhere::meta::DIM, dim},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::IndexParams::m, 4},
{knowhere::IndexParams::nbits, 8},
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::meta::DEVICEID, 0},
{knowhere::meta::DIM, dim}, {knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4}, {knowhere::IndexParams::m, 4},
{knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::meta::DEVICEID, 0},
};
std::cout << "WARN: Not specify index config, use default index config" << std::endl;
}
@ -89,11 +83,9 @@ Collection::AddIndex(const grpc::IndexParam& index_param) {
index_->AddEntry(index_name, field_name, index_type, index_mode, index_conf);
}
void
Collection::CreateIndex(std::string &index_config) {
if(index_config.empty()) {
Collection::CreateIndex(std::string& index_config) {
if (index_config.empty()) {
index_ = nullptr;
std::cout << "null index config when create index" << std::endl;
return;
@ -108,18 +100,16 @@ Collection::CreateIndex(std::string &index_config) {
index_ = std::make_shared<IndexMeta>(schema_);
for (const auto &index: collection.indexes()){
std::cout << "add index, index name =" << index.index_name()
<< ", field_name = " << index.field_name()
for (const auto& index : collection.indexes()) {
std::cout << "add index, index name =" << index.index_name() << ", field_name = " << index.field_name()
<< std::endl;
AddIndex(index);
}
}
void
Collection::parse() {
if(schema_json_.empty()) {
if (schema_json_.empty()) {
std::cout << "WARN: Use default schema" << std::endl;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
@ -131,22 +121,20 @@ Collection::parse() {
masterpb::Collection collection;
auto suc = google::protobuf::TextFormat::ParseFromString(schema_json_, &collection);
if (!suc) {
std::cerr << "unmarshal schema string failed" << std::endl;
}
auto schema = std::make_shared<Schema>();
for (const milvus::grpc::FieldMeta & child: collection.schema().field_metas()){
std::cout<<"add Field, name :" << child.field_name() << ", datatype :" << child.type() << ", dim :" << int(child.dim()) << std::endl;
schema->AddField(std::string_view(child.field_name()), DataType {child.type()}, int(child.dim()));
for (const milvus::grpc::FieldMeta& child : collection.schema().field_metas()) {
std::cout << "add Field, name :" << child.field_name() << ", datatype :" << child.type()
<< ", dim :" << int(child.dim()) << std::endl;
schema->AddField(std::string_view(child.field_name()), DataType{child.type()}, int(child.dim()));
}
/*
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
*/
schema_ = schema;
}
}
} // namespace milvus::dog_segment

View File

@ -7,29 +7,35 @@
namespace milvus::dog_segment {
class Collection {
public:
explicit Collection(std::string &collection_name, std::string &schema);
public:
explicit Collection(std::string& collection_name, std::string& schema);
void AddIndex(const grpc::IndexParam &index_param);
void
AddIndex(const grpc::IndexParam& index_param);
void CreateIndex(std::string &index_config);
void
CreateIndex(std::string& index_config);
void parse();
void
parse();
public:
SchemaPtr& get_schema() {
return schema_;
public:
SchemaPtr&
get_schema() {
return schema_;
}
IndexMetaPtr& get_index() {
return index_;
IndexMetaPtr&
get_index() {
return index_;
}
std::string& get_collection_name() {
return collection_name_;
std::string&
get_collection_name() {
return collection_name_;
}
private:
private:
IndexMetaPtr index_;
std::string collection_name_;
std::string schema_json_;
@ -38,4 +44,4 @@ private:
using CollectionPtr = std::unique_ptr<Collection>;
}
} // namespace milvus::dog_segment

View File

@ -2,7 +2,4 @@
#include <iostream>
#include "dog_segment/ConcurrentVector.h"
namespace milvus::dog_segment {
}
namespace milvus::dog_segment {}

View File

@ -90,7 +90,8 @@ class VectorBase {
virtual void
grow_to_at_least(int64_t element_count) = 0;
virtual void set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
virtual void
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
};
template <typename Type, bool is_scalar = false, ssize_t ElementsPerChunk = DefaultElementPerChunk>
@ -101,10 +102,12 @@ class ConcurrentVector : public VectorBase {
ConcurrentVector(ConcurrentVector&&) = delete;
ConcurrentVector(const ConcurrentVector&) = delete;
ConcurrentVector& operator=(ConcurrentVector&&) = delete;
ConcurrentVector& operator=(const ConcurrentVector&) = delete;
public:
ConcurrentVector&
operator=(ConcurrentVector&&) = delete;
ConcurrentVector&
operator=(const ConcurrentVector&) = delete;
public:
explicit ConcurrentVector(ssize_t dim = 1) : Dim(is_scalar ? 1 : dim), SizePerChunk(Dim * ElementsPerChunk) {
Assert(is_scalar ? dim == 1 : dim != 1);
}
@ -185,8 +188,8 @@ class ConcurrentVector : public VectorBase {
private:
void
fill_chunk(ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source,
ssize_t source_offset) {
fill_chunk(
ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source, ssize_t source_offset) {
if (element_count <= 0) {
return;
}
@ -199,6 +202,7 @@ class ConcurrentVector : public VectorBase {
const ssize_t Dim;
const ssize_t SizePerChunk;
private:
ThreadSafeVector<Chunk> chunks_;
};
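
A short sketch of how this chunked container is filled and read back, based on the usage visible in SegmentNaive further below; the dimension and row counts are made up.

// Illustrative only: fill a chunked vector column and read one row back.
#include <vector>
#include "dog_segment/ConcurrentVector.h"

void concurrent_vector_demo() {
    using milvus::dog_segment::ConcurrentVector;
    const int dim = 16;
    ConcurrentVector<float> column(dim);          // non-scalar: one element = 16 floats

    std::vector<float> rows(100 * dim, 0.5f);     // 100 rows of raw data
    column.grow_to_at_least(100);                 // make sure enough chunks exist
    column.set_data_raw(0, rows.data(), 100);     // copy rows starting at element offset 0

    auto row = column.get_element(42);            // behaves like a pointer to the 43rd row
    (void)row;
}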

View File

@ -13,22 +13,25 @@ struct DeletedRecord {
int64_t del_barrier = 0;
faiss::ConcurrentBitsetPtr bitmap_ptr;
std::shared_ptr<TmpBitmap> clone(int64_t capacity);
std::shared_ptr<TmpBitmap>
clone(int64_t capacity);
};
DeletedRecord() : lru_(std::make_shared<TmpBitmap>()) {
lru_->bitmap_ptr = std::make_shared<faiss::ConcurrentBitset>(0);
}
auto get_lru_entry() {
auto
get_lru_entry() {
std::shared_lock lck(shared_mutex_);
return lru_;
}
void insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry, bool force = false) {
void
insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry, bool force = false) {
std::lock_guard lck(shared_mutex_);
if (new_entry->del_barrier <= lru_->del_barrier) {
if (!force || new_entry->bitmap_ptr->capacity() <= lru_->bitmap_ptr->capacity()) {
if (!force || new_entry->bitmap_ptr->count() <= lru_->bitmap_ptr->count()) {
// DO NOTHING
return;
}
@ -36,18 +39,19 @@ struct DeletedRecord {
lru_ = std::move(new_entry);
}
public:
public:
std::atomic<int64_t> reserved = 0;
AckResponder ack_responder_;
ConcurrentVector<Timestamp, true> timestamps_;
ConcurrentVector<idx_t, true> uids_;
private:
private:
std::shared_ptr<TmpBitmap> lru_;
std::shared_mutex shared_mutex_;
};
auto DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBitmap> {
auto
DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBitmap> {
auto res = std::make_shared<TmpBitmap>();
res->del_barrier = this->del_barrier;
res->bitmap_ptr = std::make_shared<faiss::ConcurrentBitset>(capacity);
@ -56,4 +60,4 @@ auto DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBit
return res;
}
}
} // namespace milvus::dog_segment
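
A simplified sketch of the copy-on-write flow around this structure, as used by SegmentNaive::get_deleted_bitmap further below; the header path and barrier values are assumed, and the actual marking of deleted offsets is omitted.

// Illustrative only: refresh the cached deletion bitmap up to a new barrier.
#include <cstdint>
#include "dog_segment/DeletedRecord.h"   // header path assumed

void refresh_deleted_bitmap(milvus::dog_segment::DeletedRecord& deleted_record,
                            int64_t del_barrier, int64_t insert_barrier) {
    auto old_entry = deleted_record.get_lru_entry();   // shared snapshot of the cached bitmap
    if (old_entry->del_barrier == del_barrier) {
        return;                                        // cache already reflects this barrier
    }
    auto current = old_entry->clone(insert_barrier);   // resize bitmap to the insert barrier
    current->del_barrier = del_barrier;
    // ... replay delete log entries in (old_entry->del_barrier, del_barrier] here ...
    deleted_record.insert_lru_entry(current);          // publish unless a newer entry already won
}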

View File

@ -5,15 +5,15 @@
#define BOOST_STACKTRACE_USE_BACKTRACE
#include <boost/stacktrace.hpp>
namespace milvus::impl {
void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view filename, int lineno,
std::string_view extra_info) {
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info) {
if (!value) {
std::string info;
info += "Assert \"" + std::string(expr_str) + "\"";
info += " at " + std::string(filename) + ":" + std::to_string(lineno) + "\n";
if(!extra_info.empty()) {
if (!extra_info.empty()) {
info += " => " + std::string(extra_info);
}
auto fuck = boost::stacktrace::stacktrace();
@ -23,4 +23,4 @@ void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view file
throw std::runtime_error(info);
}
}
}
} // namespace milvus::impl

View File

@ -6,8 +6,9 @@
/* Paste this on the file you want to debug. */
namespace milvus::impl {
void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view filename, int lineno,
std::string_view extra_info);
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
}
#define AssertInfo(expr, info) impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))

View File

@ -4,15 +4,9 @@
namespace milvus::dog_segment {
Status
IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
IndexConfig config) {
Entry entry{
index_name,
field_name,
type,
mode,
std::move(config)
};
IndexMeta::AddEntry(
const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, IndexConfig config) {
Entry entry{index_name, field_name, type, mode, std::move(config)};
VerifyEntry(entry);
if (entries_.count(index_name)) {
@ -30,22 +24,23 @@ Status
IndexMeta::DropEntry(const std::string& index_name) {
Assert(entries_.count(index_name));
auto entry = std::move(entries_[index_name]);
if(lookups_[entry.field_name] == index_name) {
if (lookups_[entry.field_name] == index_name) {
lookups_.erase(entry.field_name);
}
return Status::OK();
}
void IndexMeta::VerifyEntry(const Entry &entry) {
void
IndexMeta::VerifyEntry(const Entry& entry) {
auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode);
if(!is_mode_valid) {
if (!is_mode_valid) {
throw std::invalid_argument("invalid mode");
}
auto& schema = *schema_;
auto& field_meta = schema[entry.field_name];
// TODO checking
if(field_meta.is_vector()) {
if (field_meta.is_vector()) {
Assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ);
} else {
Assert(false);

View File

@ -29,7 +29,10 @@ class IndexMeta {
};
Status
AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
AddEntry(const std::string& index_name,
const std::string& field_name,
IndexType type,
IndexMode mode,
IndexConfig config);
Status
@ -40,12 +43,14 @@ class IndexMeta {
return entries_;
}
const Entry& lookup_by_field(const std::string& field_name) {
const Entry&
lookup_by_field(const std::string& field_name) {
AssertInfo(lookups_.count(field_name), field_name);
auto index_name = lookups_.at(field_name);
AssertInfo(entries_.count(index_name), index_name);
return entries_.at(index_name);
}
private:
void
VerifyEntry(const Entry& entry);
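
A hypothetical sketch of registering an entry through AddEntry, mirroring the IVF_PQ defaults used elsewhere in this change; header paths and the "fakeindex"/"fakevec" names are illustrative.

// Illustrative only: register a CPU IVF_PQ index for a 16-dim float field.
#include <memory>
#include "dog_segment/IndexMeta.h"      // header paths assumed
#include "dog_segment/SegmentDefs.h"

void register_default_index(milvus::dog_segment::SchemaPtr schema) {
    namespace kn = milvus::knowhere;
    auto meta = std::make_shared<milvus::dog_segment::IndexMeta>(schema);
    auto conf = kn::Config{
        {kn::meta::DIM, 16},          {kn::IndexParams::nlist, 100},
        {kn::IndexParams::nprobe, 4}, {kn::IndexParams::m, 4},
        {kn::IndexParams::nbits, 8},  {kn::Metric::TYPE, kn::Metric::L2},
        {kn::meta::DEVICEID, 0},
    };
    meta->AddEntry("fakeindex", "fakevec", kn::IndexEnum::INDEX_FAISS_IVFPQ,
                   kn::IndexMode::MODE_CPU, conf);
}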

View File

@ -2,7 +2,8 @@
namespace milvus::dog_segment {
Partition::Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index):
partition_name_(partition_name), schema_(schema), index_(index) {}
Partition::Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index)
: partition_name_(partition_name), schema_(schema), index_(index) {
}
} // namespace milvus::dog_segment

View File

@ -5,23 +5,26 @@
namespace milvus::dog_segment {
class Partition {
public:
public:
explicit Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index);
public:
SchemaPtr& get_schema() {
return schema_;
public:
SchemaPtr&
get_schema() {
return schema_;
}
IndexMetaPtr& get_index() {
return index_;
IndexMetaPtr&
get_index() {
return index_;
}
std::string& get_partition_name() {
return partition_name_;
std::string&
get_partition_name() {
return partition_name_;
}
private:
private:
std::string partition_name_;
SchemaPtr schema_;
IndexMetaPtr index_;
@ -29,4 +32,4 @@ private:
using PartitionPtr = std::unique_ptr<Partition>;
}
} // namespace milvus::dog_segment

View File

@ -32,12 +32,18 @@ class SegmentBase {
virtual ~SegmentBase() = default;
// SegmentBase(std::shared_ptr<FieldsInfo> collection);
virtual int64_t PreInsert(int64_t size) = 0;
virtual int64_t
PreInsert(int64_t size) = 0;
virtual Status
Insert(int64_t reserved_offset, int64_t size, const int64_t* primary_keys, const Timestamp* timestamps, const DogDataChunk& values) = 0;
Insert(int64_t reserved_offset,
int64_t size,
const int64_t* primary_keys,
const Timestamp* timestamps,
const DogDataChunk& values) = 0;
virtual int64_t PreDelete(int64_t size) = 0;
virtual int64_t
PreDelete(int64_t size) = 0;
// TODO: add id into delete log, possibly bitmap
virtual Status

View File

@ -152,20 +152,23 @@ class Schema {
return total_sizeof_;
}
const std::vector<int>& get_sizeof_infos() {
const std::vector<int>&
get_sizeof_infos() {
return sizeof_infos_;
}
std::optional<int> get_offset(const std::string& field_name) {
if(!offsets_.count(field_name)) {
std::optional<int>
get_offset(const std::string& field_name) {
if (!offsets_.count(field_name)) {
return std::nullopt;
} else {
return offsets_[field_name];
}
}
const std::vector<FieldMeta>& get_fields() {
return fields_;
const std::vector<FieldMeta>&
get_fields() {
return fields_;
}
const FieldMeta&
@ -175,6 +178,7 @@ class Schema {
auto offset = offset_iter->second;
return (*this)[offset];
}
private:
// this is where data holds
std::vector<FieldMeta> fields_;
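
For reference, a small sketch of building the default schema used throughout this change and querying it through the accessors above; the header path is assumed.

// Illustrative only: the default two-field schema (16-dim float vector + int32).
#include <iostream>
#include <memory>
#include "dog_segment/SegmentDefs.h"   // assumed location of Schema/FieldMeta

void schema_demo() {
    using namespace milvus::dog_segment;
    auto schema = std::make_shared<Schema>();
    schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
    schema->AddField("age", DataType::INT32);

    // One row is 16 * sizeof(float) + sizeof(int32_t) = 68 bytes.
    std::cout << "row size: " << schema->get_total_sizeof() << std::endl;

    if (auto offset = schema->get_offset("fakevec")) {
        const auto& field = (*schema)[*offset];        // FieldMeta looked up by offset
        std::cout << field.get_name() << " dim=" << field.get_dim() << std::endl;
    }
}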

View File

@ -21,8 +21,8 @@ CreateSegment(SchemaPtr schema) {
return segment;
}
SegmentNaive::Record::Record(const Schema &schema) : uids_(1), timestamps_(1) {
for (auto &field : schema) {
SegmentNaive::Record::Record(const Schema& schema) : uids_(1), timestamps_(1) {
for (auto& field : schema) {
if (field.is_vector()) {
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<float>>(field.get_dim()));
@ -45,17 +45,17 @@ SegmentNaive::PreDelete(int64_t size) {
return reserved_begin;
}
auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp,
int64_t insert_barrier, bool force) -> std::shared_ptr<DeletedRecord::TmpBitmap> {
auto
SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force)
-> std::shared_ptr<DeletedRecord::TmpBitmap> {
auto old = deleted_record_.get_lru_entry();
if (!force || old->bitmap_ptr->capacity() == insert_barrier) {
if (!force || old->bitmap_ptr->count() == insert_barrier) {
if (old->del_barrier == del_barrier) {
return old;
}
}
auto current = old->clone(insert_barrier);
current->del_barrier = del_barrier;
@ -67,7 +67,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
// map uid to corresponding offsets, select the max one, which should be the target
// the max one should be closest to query_timestamp, so the delete log should refer to it
int64_t the_offset = -1;
auto[iter_b, iter_e] = uid2offset_.equal_range(uid);
auto [iter_b, iter_e] = uid2offset_.equal_range(uid);
for (auto iter = iter_b; iter != iter_e; ++iter) {
auto offset = iter->second;
if (record_.timestamps_[offset] < query_timestamp) {
@ -90,7 +90,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
// map uid to corresponding offsets, select the max one, which should be the target
// the max one should be closest to query_timestamp, so the delete log should refer to it
int64_t the_offset = -1;
auto[iter_b, iter_e] = uid2offset_.equal_range(uid);
auto [iter_b, iter_e] = uid2offset_.equal_range(uid);
for (auto iter = iter_b; iter != iter_e; ++iter) {
auto offset = iter->second;
if (offset >= insert_barrier) {
@ -116,16 +116,19 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
}
Status
SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, const Timestamp *timestamps_raw,
const DogDataChunk &entities_raw) {
SegmentNaive::Insert(int64_t reserved_begin,
int64_t size,
const int64_t* uids_raw,
const Timestamp* timestamps_raw,
const DogDataChunk& entities_raw) {
Assert(entities_raw.count == size);
if (entities_raw.sizeof_per_row != schema_->get_total_sizeof()) {
std::string msg = "entity length = " + std::to_string(entities_raw.sizeof_per_row) +
", schema length = " + std::to_string(schema_->get_total_sizeof());
std::string msg = "entity length = " + std::to_string(entities_raw.sizeof_per_row) +
", schema length = " + std::to_string(schema_->get_total_sizeof());
throw std::runtime_error(msg);
}
auto raw_data = reinterpret_cast<const char *>(entities_raw.raw_data);
auto raw_data = reinterpret_cast<const char*>(entities_raw.raw_data);
// std::vector<char> entities(raw_data, raw_data + size * len_per_row);
auto len_per_row = entities_raw.sizeof_per_row;
@ -150,7 +153,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_r
std::vector<Timestamp> timestamps(size);
// #pragma omp parallel for
for (int index = 0; index < size; ++index) {
auto[t, uid, order_index] = ordering[index];
auto [t, uid, order_index] = ordering[index];
timestamps[index] = t;
uids[index] = uid;
for (int fid = 0; fid < schema_->size(); ++fid) {
@ -209,8 +212,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_r
}
Status
SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_raw,
const Timestamp *timestamps_raw) {
SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw) {
std::vector<std::tuple<Timestamp, idx_t>> ordering;
ordering.resize(size);
// #pragma omp parallel for
@ -222,7 +224,7 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_r
std::vector<Timestamp> timestamps(size);
// #pragma omp parallel for
for (int index = 0; index < size; ++index) {
auto[t, uid] = ordering[index];
auto [t, uid] = ordering[index];
timestamps[index] = t;
uids[index] = uid;
}
@ -238,9 +240,10 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_r
// return Status::OK();
}
template<typename RecordType>
int64_t get_barrier(const RecordType &record, Timestamp timestamp) {
auto &vec = record.timestamps_;
template <typename RecordType>
int64_t
get_barrier(const RecordType& record, Timestamp timestamp) {
auto& vec = record.timestamps_;
int64_t beg = 0;
int64_t end = record.ack_responder_.GetAck();
while (beg < end) {
@ -255,15 +258,15 @@ int64_t get_barrier(const RecordType &record, Timestamp timestamp) {
}
Status
SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {
SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier, true);
Assert(bitmap_holder);
Assert(bitmap_holder->bitmap_ptr->capacity() == ins_barrier);
Assert(bitmap_holder->bitmap_ptr->count() == ins_barrier);
auto field_offset = schema_->get_offset(query_info->field_name);
auto &field = schema_->operator[](query_info->field_name);
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
@ -280,7 +283,7 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
conf[milvus::knowhere::meta::TOPK] = query_info->topK;
{
auto count = 0;
for (int i = 0; i < bitmap->capacity(); ++i) {
for (int i = 0; i < bitmap->count(); ++i) {
if (bitmap->test(i)) {
++count;
}
@ -291,10 +294,10 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
auto indexing = std::static_pointer_cast<knowhere::VecIndex>(indexings_[index_entry.index_name]);
indexing->SetBlacklist(bitmap);
auto ds = knowhere::GenDataset(query_info->num_queries, dim, query_info->query_raw_data.data());
auto final = indexing->Query(ds, conf);
auto final = indexing->Query(ds, conf, bitmap);
auto ids = final->Get<idx_t *>(knowhere::meta::IDS);
auto distances = final->Get<float *>(knowhere::meta::DISTANCE);
auto ids = final->Get<idx_t*>(knowhere::meta::IDS);
auto distances = final->Get<float*>(knowhere::meta::DISTANCE);
auto total_num = num_queries * topK;
result.result_ids_.resize(total_num);
@ -307,7 +310,7 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
std::copy_n(ids, total_num, result.result_ids_.data());
std::copy_n(distances, total_num, result.result_distances_.data());
for (auto &id: result.result_ids_) {
for (auto& id : result.result_ids_) {
id = record_.uids_[id];
}
@ -315,8 +318,13 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
}
void
merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const float *new_distances, const int64_t *new_uids) {
for(int64_t qn = 0; qn < queries; ++qn) {
merge_into(int64_t queries,
int64_t topk,
float* distances,
int64_t* uids,
const float* new_distances,
const int64_t* new_uids) {
for (int64_t qn = 0; qn < queries; ++qn) {
auto base = qn * topk;
auto src2_dis = distances + base;
auto src2_uids = uids + base;
@ -330,8 +338,8 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const
auto it1 = 0;
auto it2 = 0;
for(auto buf = 0; buf < topk; ++buf){
if(src1_dis[it1] <= src2_dis[it2]) {
for (auto buf = 0; buf < topk; ++buf) {
if (src1_dis[it1] <= src2_dis[it2]) {
buf_dis[buf] = src1_dis[it1];
buf_uids[buf] = src1_uids[it1];
++it1;
@ -347,13 +355,13 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const
}
Status
SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) {
SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) {
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
Assert(bitmap_holder);
auto &field = schema_->operator[](query_info->field_name);
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto bitmap = bitmap_holder->bitmap_ptr;
@ -375,15 +383,15 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
faiss::float_maxheap_array_t buf = {
(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
auto src_data = vec_ptr->get_chunk(chunk_id).data();
auto nsize = chunk_id != max_chunk - 1? DefaultElementPerChunk: ins_barrier - chunk_id * DefaultElementPerChunk;
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
auto offset = chunk_id * DefaultElementPerChunk;
faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf, bitmap, offset);
if(chunk_id == 0) {
if (chunk_id == 0) {
final_uids = buf_uids;
final_dis = buf_dis;
} else {
@ -391,8 +399,7 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
}
}
for(auto& id: final_uids) {
for (auto& id : final_uids) {
id = record_.uids_[id];
}
@ -402,20 +409,18 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
results.num_queries_ = num_queries;
results.row_num_ = total_count;
// throw std::runtime_error("unimplemented");
// throw std::runtime_error("unimplemented");
return Status::OK();
}
Status
SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {
SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
Assert(bitmap_holder);
auto &field = schema_->operator[](query_info->field_name);
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto bitmap = bitmap_holder->bitmap_ptr;
@ -428,7 +433,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_.at(the_offset_opt.value()));
std::vector<std::priority_queue<std::pair<float, int>>> records(num_queries);
auto get_L2_distance = [dim](const float *a, const float *b) {
auto get_L2_distance = [dim](const float* a, const float* b) {
float L2_distance = 0;
for (auto i = 0; i < dim; ++i) {
auto d = a[i] - b[i];
@ -438,14 +443,14 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
};
for (int64_t i = 0; i < ins_barrier; ++i) {
if (i < bitmap->capacity() && bitmap->test(i)) {
if (i < bitmap->count() && bitmap->test(i)) {
continue;
}
auto element = vec_ptr->get_element(i);
for (auto query_id = 0; query_id < num_queries; ++query_id) {
auto query_blob = query_info->query_raw_data.data() + query_id * dim;
auto dis = get_L2_distance(query_blob, element);
auto &record = records[query_id];
auto& record = records[query_id];
if (record.size() < topK) {
record.emplace(dis, i);
} else if (record.top().first > dis) {
@ -455,7 +460,6 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
}
}
result.num_queries_ = num_queries;
result.topK_ = topK;
auto row_num = topK * num_queries;
@ -468,7 +472,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
// reverse
for (int i = 0; i < topK; ++i) {
auto dst_id = topK - 1 - i + q_id * topK;
auto[dis, offset] = records[q_id].top();
auto [dis, offset] = records[q_id].top();
records[q_id].pop();
result.result_ids_[dst_id] = record_.uids_[offset];
result.result_distances_[dst_id] = dis;
@ -479,7 +483,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
}
Status
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
// TODO: enable delete
// TODO: enable index
// TODO: remove mock
@ -493,7 +497,7 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
std::default_random_engine e(42);
std::uniform_real_distribution<> dis(0.0, 1.0);
query_info->query_raw_data.resize(query_info->num_queries * dim);
for (auto &x: query_info->query_raw_data) {
for (auto& x : query_info->query_raw_data) {
x = dis(e);
}
}
@ -517,8 +521,9 @@ SegmentNaive::Close() {
return Status::OK();
}
template<typename Type>
knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry) {
template <typename Type>
knowhere::IndexPtr
SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto offset_opt = schema_->get_offset(entry.field_name);
Assert(offset_opt.has_value());
auto offset = offset_opt.value();
@ -528,7 +533,7 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
auto chunk_size = record_.uids_.chunk_size();
auto &uids = record_.uids_;
auto& uids = record_.uids_;
auto entities = record_.get_vec_entity<float>(offset);
std::vector<knowhere::DatasetPtr> datasets;
@ -538,10 +543,10 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto &ds: datasets) {
for (auto& ds : datasets) {
indexing->Train(ds, entry.config);
}
for (auto &ds: datasets) {
for (auto& ds : datasets) {
indexing->AddWithoutIds(ds, entry.config);
}
return indexing;
@ -555,7 +560,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
int dim = 0;
std::string index_field_name;
for (auto& field: schema_->get_fields()) {
for (auto& field : schema_->get_fields()) {
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
dim = field.get_dim();
index_field_name = field.get_name();
@ -569,28 +574,24 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
// TODO: this is merge of query conf and insert conf
// TODO: should be split into multiple configs
auto conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0},
{milvus::knowhere::meta::DIM, dim}, {milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0},
};
index_meta->AddEntry("fakeindex", index_field_name, knowhere::IndexEnum::INDEX_FAISS_IVFPQ,
knowhere::IndexMode::MODE_CPU, conf);
remote_index_meta = index_meta;
}
if(record_.ack_responder_.GetAck() < 1024 * 4) {
if (record_.ack_responder_.GetAck() < 1024 * 4) {
return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
}
index_meta_ = remote_index_meta;
for (auto&[index_name, entry]: index_meta_->get_entries()) {
for (auto& [index_name, entry] : index_meta_->get_entries()) {
Assert(entry.index_name == index_name);
const auto &field = (*schema_)[entry.field_name];
const auto& field = (*schema_)[entry.field_name];
if (field.is_vector()) {
Assert(field.get_data_type() == engine::DataType::VECTOR_FLOAT);
@ -608,9 +609,9 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
int64_t
SegmentNaive::GetMemoryUsageInBytes() {
int64_t total_bytes = 0;
if(index_ready_) {
if (index_ready_) {
auto& index_entries = index_meta_->get_entries();
for(auto [index_name, entry]: index_entries) {
for (auto [index_name, entry] : index_entries) {
Assert(schema_->operator[](entry.field_name).is_vector());
auto vec_ptr = std::static_pointer_cast<knowhere::VecIndex>(indexings_[index_name]);
total_bytes += vec_ptr->IndexSize();
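
To make the brute-force merge step above easier to follow, here is a self-contained, single-query version of the two-way top-K merge that merge_into performs; it is a simplified illustration, not the production code.

// Illustrative only: merge two ascending (distance, uid) lists of length topk,
// keeping the best topk entries in the first pair of vectors.
#include <cstdint>
#include <vector>

void merge_topk(int64_t topk,
                std::vector<float>& dis, std::vector<int64_t>& uids,
                const std::vector<float>& new_dis, const std::vector<int64_t>& new_uids) {
    std::vector<float> buf_dis(topk);
    std::vector<int64_t> buf_uids(topk);
    int64_t it1 = 0, it2 = 0;
    for (int64_t buf = 0; buf < topk; ++buf) {
        if (dis[it1] <= new_dis[it2]) {   // take the smaller distance first
            buf_dis[buf] = dis[it1];
            buf_uids[buf] = uids[it1];
            ++it1;
        } else {
            buf_dis[buf] = new_dis[it2];
            buf_uids[buf] = new_uids[it2];
            ++it2;
        }
    }
    dis = buf_dis;                        // write the merged top-k back
    uids = buf_uids;
}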

View File

@ -21,12 +21,12 @@ struct ColumnBasedDataChunk {
std::vector<std::vector<float>> entity_vecs;
static ColumnBasedDataChunk
from(const DogDataChunk &source, const Schema &schema) {
from(const DogDataChunk& source, const Schema& schema) {
ColumnBasedDataChunk dest;
auto count = source.count;
auto raw_data = reinterpret_cast<const char *>(source.raw_data);
auto raw_data = reinterpret_cast<const char*>(source.raw_data);
auto align = source.sizeof_per_row;
for (auto &field : schema) {
for (auto& field : schema) {
auto len = field.get_sizeof();
Assert(len % sizeof(float) == 0);
std::vector<float> new_col(len * count / sizeof(float));
@ -42,28 +42,33 @@ struct ColumnBasedDataChunk {
};
class SegmentNaive : public SegmentBase {
public:
public:
virtual ~SegmentNaive() = default;
// SegmentBase(std::shared_ptr<FieldsInfo> collection);
int64_t PreInsert(int64_t size) override;
int64_t
PreInsert(int64_t size) override;
// TODO: originally, id should be put into data_chunk
// TODO: Is it ok to put them the other side?
Status
Insert(int64_t reserverd_offset, int64_t size, const int64_t *primary_keys, const Timestamp *timestamps,
const DogDataChunk &values) override;
Insert(int64_t reserverd_offset,
int64_t size,
const int64_t* primary_keys,
const Timestamp* timestamps,
const DogDataChunk& values) override;
int64_t PreDelete(int64_t size) override;
int64_t
PreDelete(int64_t size) override;
// TODO: add id into delete log, possibly bitmap
Status
Delete(int64_t reserverd_offset, int64_t size, const int64_t *primary_keys, const Timestamp *timestamps) override;
Delete(int64_t reserverd_offset, int64_t size, const int64_t* primary_keys, const Timestamp* timestamps) override;
// query contains metadata of
Status
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) override;
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
// stop receive insert requests
// will move data to immutable vector or something
@ -87,7 +92,7 @@ public:
}
Status
LoadRawData(std::string_view field_name, const char *blob, int64_t blob_size) override {
LoadRawData(std::string_view field_name, const char* blob, int64_t blob_size) override {
// TODO: NO-OP
return Status::OK();
}
@ -95,7 +100,7 @@ public:
int64_t
GetMemoryUsageInBytes() override;
public:
public:
ssize_t
get_row_count() const override {
return record_.ack_responder_.GetAck();
@ -111,23 +116,22 @@ public:
return 0;
}
public:
public:
friend std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema);
explicit SegmentNaive(SchemaPtr schema)
: schema_(schema), record_(*schema) {
explicit SegmentNaive(SchemaPtr schema) : schema_(schema), record_(*schema) {
}
private:
// struct MutableRecord {
// ConcurrentVector<uint64_t> uids_;
// tbb::concurrent_vector<Timestamp> timestamps_;
// std::vector<tbb::concurrent_vector<float>> entity_vecs_;
//
// MutableRecord(int entity_size) : entity_vecs_(entity_size) {
// }
// };
private:
// struct MutableRecord {
// ConcurrentVector<uint64_t> uids_;
// tbb::concurrent_vector<Timestamp> timestamps_;
// std::vector<tbb::concurrent_vector<float>> entity_vecs_;
//
// MutableRecord(int entity_size) : entity_vecs_(entity_size) {
// }
// };
struct Record {
std::atomic<int64_t> reserved = 0;
@ -136,31 +140,32 @@ private:
ConcurrentVector<idx_t, true> uids_;
std::vector<std::shared_ptr<VectorBase>> entity_vec_;
Record(const Schema &schema);
Record(const Schema& schema);
template<typename Type>
auto get_vec_entity(int offset) {
template <typename Type>
auto
get_vec_entity(int offset) {
return std::static_pointer_cast<ConcurrentVector<Type>>(entity_vec_[offset]);
}
};
std::shared_ptr<DeletedRecord::TmpBitmap>
get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force = false);
Status
QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results);
QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
Status
QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results);
QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
Status
QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results);
QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
template<typename Type>
knowhere::IndexPtr BuildVecIndexImpl(const IndexMeta::Entry &entry);
template <typename Type>
knowhere::IndexPtr
BuildVecIndexImpl(const IndexMeta::Entry& entry);
private:
private:
SchemaPtr schema_;
std::atomic<SegmentState> state_ = SegmentState::Open;
Record record_;
@ -168,7 +173,7 @@ private:
std::atomic<bool> index_ready_ = false;
IndexMetaPtr index_meta_;
std::unordered_map<std::string, knowhere::IndexPtr> indexings_; // index_name => indexing
std::unordered_map<std::string, knowhere::IndexPtr> indexings_; // index_name => indexing
tbb::concurrent_unordered_multimap<idx_t, int64_t> uid2offset_;
};
} // namespace milvus::dog_segment

View File

@ -3,28 +3,28 @@
CCollection
NewCollection(const char* collection_name, const char* schema_conf) {
auto name = std::string(collection_name);
auto conf = std::string(schema_conf);
auto name = std::string(collection_name);
auto conf = std::string(schema_conf);
auto collection = std::make_unique<milvus::dog_segment::Collection>(name, conf);
auto collection = std::make_unique<milvus::dog_segment::Collection>(name, conf);
// TODO: delete print
std::cout << "create collection " << collection_name << std::endl;
return (void*)collection.release();
// TODO: delete print
std::cout << "create collection " << collection_name << std::endl;
return (void*)collection.release();
}
void
DeleteCollection(CCollection collection) {
auto col = (milvus::dog_segment::Collection*)collection;
auto col = (milvus::dog_segment::Collection*)collection;
// TODO: delete print
std::cout << "delete collection " << col->get_collection_name() << std::endl;
delete col;
// TODO: delete print
std::cout << "delete collection " << col->get_collection_name() << std::endl;
delete col;
}
void
UpdateIndexes(CCollection c_collection, const char *index_string) {
auto c = (milvus::dog_segment::Collection*)c_collection;
std::string s(index_string);
c->CreateIndex(s);
UpdateIndexes(CCollection c_collection, const char* index_string) {
auto c = (milvus::dog_segment::Collection*)c_collection;
std::string s(index_string);
c->CreateIndex(s);
}

View File

@ -11,7 +11,7 @@ void
DeleteCollection(CCollection collection);
void
UpdateIndexes(CCollection c_collection, const char *index_string);
UpdateIndexes(CCollection c_collection, const char* index_string);
#ifdef __cplusplus
}

View File

@ -4,26 +4,26 @@
CPartition
NewPartition(CCollection collection, const char* partition_name) {
auto c = (milvus::dog_segment::Collection*)collection;
auto c = (milvus::dog_segment::Collection*)collection;
auto name = std::string(partition_name);
auto name = std::string(partition_name);
auto schema = c->get_schema();
auto schema = c->get_schema();
auto index = c->get_index();
auto index = c->get_index();
auto partition = std::make_unique<milvus::dog_segment::Partition>(name, schema, index);
auto partition = std::make_unique<milvus::dog_segment::Partition>(name, schema, index);
// TODO: delete print
std::cout << "create partition " << name << std::endl;
return (void*)partition.release();
// TODO: delete print
std::cout << "create partition " << name << std::endl;
return (void*)partition.release();
}
void
DeletePartition(CPartition partition) {
auto p = (milvus::dog_segment::Partition*)partition;
auto p = (milvus::dog_segment::Partition*)partition;
// TODO: delete print
std::cout << "delete partition " << p->get_partition_name() <<std::endl;
delete p;
// TODO: delete print
std::cout << "delete partition " << p->get_partition_name() << std::endl;
delete p;
}

View File

@ -8,89 +8,83 @@
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
CSegmentBase
NewSegment(CPartition partition, unsigned long segment_id) {
auto p = (milvus::dog_segment::Partition*)partition;
auto p = (milvus::dog_segment::Partition*)partition;
auto segment = milvus::dog_segment::CreateSegment(p->get_schema());
auto segment = milvus::dog_segment::CreateSegment(p->get_schema());
// TODO: delete print
std::cout << "create segment " << segment_id << std::endl;
return (void*)segment.release();
// TODO: delete print
std::cout << "create segment " << segment_id << std::endl;
return (void*)segment.release();
}
void
DeleteSegment(CSegmentBase segment) {
auto s = (milvus::dog_segment::SegmentBase*)segment;
auto s = (milvus::dog_segment::SegmentBase*)segment;
// TODO: delete print
std::cout << "delete segment " << std::endl;
delete s;
// TODO: delete print
std::cout << "delete segment " << std::endl;
delete s;
}
//////////////////////////////////////////////////////////////////
int
Insert(CSegmentBase c_segment,
long int reserved_offset,
signed long int size,
const long* primary_keys,
const unsigned long* timestamps,
void* raw_data,
int sizeof_per_row,
signed long int count) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
milvus::dog_segment::DogDataChunk dataChunk{};
long int reserved_offset,
signed long int size,
const long* primary_keys,
const unsigned long* timestamps,
void* raw_data,
int sizeof_per_row,
signed long int count) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
milvus::dog_segment::DogDataChunk dataChunk{};
dataChunk.raw_data = raw_data;
dataChunk.sizeof_per_row = sizeof_per_row;
dataChunk.count = count;
dataChunk.raw_data = raw_data;
dataChunk.sizeof_per_row = sizeof_per_row;
dataChunk.count = count;
auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk);
auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk);
// TODO: delete print
// std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
return res.code();
// TODO: delete print
// std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
return res.code();
}
long int
PreInsert(CSegmentBase c_segment, long int size) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
// TODO: delete print
// std::cout << "PreInsert segment " << std::endl;
return segment->PreInsert(size);
// TODO: delete print
// std::cout << "PreInsert segment " << std::endl;
return segment->PreInsert(size);
}
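PreInsert reserves a contiguous range of row offsets and Insert then fills exactly that range, so the two calls are expected to be paired. A hedged sketch of the intended call order, assuming row_count rows of sizeof_per_row bytes packed back to back in raw_data, with one primary key and one timestamp per row:
long reserved = PreInsert(segment, row_count);        // reserve offsets first
int rc = Insert(segment, reserved, row_count,
                primary_keys, timestamps,
                raw_data, sizeof_per_row, row_count);  // rc is the status code from SegmentBase::Insert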
int
Delete(CSegmentBase c_segment,
long int reserved_offset,
long size,
const long* primary_keys,
const unsigned long* timestamps) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
long int reserved_offset,
long size,
const long* primary_keys,
const unsigned long* timestamps) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto res = segment->Delete(reserved_offset, size, primary_keys, timestamps);
return res.code();
auto res = segment->Delete(reserved_offset, size, primary_keys, timestamps);
return res.code();
}
long int
PreDelete(CSegmentBase c_segment, long int size) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
// TODO: delete print
// std::cout << "PreDelete segment " << std::endl;
return segment->PreDelete(size);
// TODO: delete print
// std::cout << "PreDelete segment " << std::endl;
return segment->PreDelete(size);
}
//int
//Search(CSegmentBase c_segment,
// int
// Search(CSegmentBase c_segment,
// const char* query_json,
// unsigned long timestamp,
// float* query_raw_data,
@ -125,41 +119,42 @@ PreDelete(CSegmentBase c_segment, long int size) {
int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
long int* result_ids,
float* result_distances) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
milvus::dog_segment::QueryResult query_result;
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
long int* result_ids,
float* result_distances) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
milvus::dog_segment::QueryResult query_result;
// construct QueryPtr
auto query_ptr = std::make_shared<milvus::query::Query>();
query_ptr->num_queries = c_query_info.num_queries;
query_ptr->topK = c_query_info.topK;
query_ptr->field_name = c_query_info.field_name;
// construct QueryPtr
auto query_ptr = std::make_shared<milvus::query::Query>();
query_ptr->query_raw_data.resize(num_of_query_raw_data);
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
query_ptr->num_queries = c_query_info.num_queries;
query_ptr->topK = c_query_info.topK;
query_ptr->field_name = c_query_info.field_name;
auto res = segment->Query(query_ptr, timestamp, query_result);
query_ptr->query_raw_data.resize(num_of_query_raw_data);
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.
memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int));
memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float));
auto res = segment->Query(query_ptr, timestamp, query_result);
return res.code();
// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.
memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int));
memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float));
return res.code();
}
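Search now takes a CQueryInfo struct instead of a raw JSON string, and the caller owns the result buffers. A hedged caller sketch, assuming segment, timestamp and query_vectors come from context and that num_of_query_raw_data is the total number of floats across all queries (as the memcpy above implies):
CQueryInfo info{};
info.num_queries = 1;
info.topK = 10;
info.field_name = "fakevec";   // assumed vector field name
std::vector<long> ids(info.num_queries * info.topK);
std::vector<float> distances(info.num_queries * info.topK);
int rc = Search(segment, info, timestamp,
                query_vectors.data(), static_cast<int>(query_vectors.size()),
                ids.data(), distances.data());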
//////////////////////////////////////////////////////////////////
int
Close(CSegmentBase c_segment) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto status = segment->Close();
return status.code();
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto status = segment->Close();
return status.code();
}
int
@ -171,34 +166,32 @@ BuildIndex(CCollection c_collection, CSegmentBase c_segment) {
return status.code();
}
bool
IsOpened(CSegmentBase c_segment) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto status = segment->get_state();
return status == milvus::dog_segment::SegmentBase::SegmentState::Open;
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto status = segment->get_state();
return status == milvus::dog_segment::SegmentBase::SegmentState::Open;
}
long int
GetMemoryUsageInBytes(CSegmentBase c_segment) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto mem_size = segment->GetMemoryUsageInBytes();
return mem_size;
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto mem_size = segment->GetMemoryUsageInBytes();
return mem_size;
}
//////////////////////////////////////////////////////////////////
long int
GetRowCount(CSegmentBase c_segment) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto row_count = segment->get_row_count();
return row_count;
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto row_count = segment->get_row_count();
return row_count;
}
long int
GetDeletedCount(CSegmentBase c_segment) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto deleted_count = segment->get_deleted_count();
return deleted_count;
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto deleted_count = segment->get_deleted_count();
return deleted_count;
}
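The remaining accessors are thin read-only wrappers, useful for monitoring a segment before deciding to close it. A small sketch under the same assumptions as above, with kMaxSegmentBytes as a hypothetical threshold:
if (IsOpened(segment)) {
    long rows    = GetRowCount(segment);
    long deleted = GetDeletedCount(segment);
    long bytes   = GetMemoryUsageInBytes(segment);
    if (bytes > kMaxSegmentBytes) {   // hypothetical sealing policy
        Close(segment);
    }
}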

View File

@ -23,29 +23,29 @@ DeleteSegment(CSegmentBase segment);
int
Insert(CSegmentBase c_segment,
long int reserved_offset,
signed long int size,
const long* primary_keys,
const unsigned long* timestamps,
void* raw_data,
int sizeof_per_row,
signed long int count);
long int reserved_offset,
signed long int size,
const long* primary_keys,
const unsigned long* timestamps,
void* raw_data,
int sizeof_per_row,
signed long int count);
long int
PreInsert(CSegmentBase c_segment, long int size);
int
Delete(CSegmentBase c_segment,
long int reserved_offset,
long size,
const long* primary_keys,
const unsigned long* timestamps);
long int reserved_offset,
long size,
const long* primary_keys,
const unsigned long* timestamps);
long int
PreDelete(CSegmentBase c_segment, long int size);
//int
//Search(CSegmentBase c_segment,
// int
// Search(CSegmentBase c_segment,
// const char* query_json,
// unsigned long timestamp,
// float* query_raw_data,
@ -55,7 +55,7 @@ PreDelete(CSegmentBase c_segment, long int size);
int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,

View File

@ -52,7 +52,18 @@ include(BuildUtilsCore)
using_ccache_if_defined( KNOWHERE_USE_CCACHE )
message(STATUS "Building Knowhere CPU version")
if (MILVUS_GPU_VERSION)
message(STATUS "Building Knowhere GPU version")
add_compile_definitions("MILVUS_GPU_VERSION")
enable_language(CUDA)
find_package(CUDA 10 REQUIRED)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -std=c++11 -D_FORCE_INLINES --expt-extended-lambda")
if ( CCACHE_FOUND )
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_FOUND}")
endif()
else ()
message(STATUS "Building Knowhere CPU version")
endif ()
if (MILVUS_SUPPORT_SPTAG)
message(STATUS "Building Knowhere with SPTAG supported")
@ -63,8 +74,14 @@ include(ThirdPartyPackagesCore)
if (CMAKE_BUILD_TYPE STREQUAL "Release")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp")
if (MILVUS_GPU_VERSION)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3")
endif ()
else ()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp")
if (MILVUS_GPU_VERSION)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g")
endif ()
endif ()
add_subdirectory(knowhere)
@ -75,10 +92,9 @@ endif ()
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
#if (KNOWHERE_BUILD_TESTS)
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
# add_subdirectory(unittest)
#endif ()
if (KNOWHERE_BUILD_TESTS)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
add_subdirectory(unittest)
endif ()
config_summary()

View File

@ -13,14 +13,17 @@
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h"
#include "faiss/FaissHook.h"
// #include "scheduler/Utils.h"
#include "scheduler/Utils.h"
#include "utils/ConfigUtils.h"
#include "utils/Error.h"
#include "utils/Log.h"
// #include <fiu/fiu-local.h>
#include <fiu/fiu-local.h>
#include <map>
#include <set>
#include <string>
@ -60,9 +63,38 @@ KnowhereResource::Initialize() {
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
}
// engine config
int64_t omp_thread = config.engine.omp_thread_num();
if (omp_thread > 0) {
omp_set_num_threads(omp_thread);
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
} else {
int64_t sys_thread_cnt = 8;
if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) {
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
omp_set_num_threads(omp_thread);
}
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
#ifdef MILVUS_GPU_VERSION
bool enable_gpu = config.gpu.enable();
// fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
if (!enable_gpu) {
return Status::OK();
}
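The block above wires three engine settings into OpenMP and faiss: an explicit OpenMP thread count (falling back to half the available hardware threads), the batch-size threshold at which faiss switches its distance kernels to BLAS, and the clustering variant. A minimal standalone sketch of the first two effects, assuming only stock faiss and OpenMP:
#include <omp.h>
#include <faiss/utils/distances.h>

void ConfigureEngine(int64_t omp_thread, int64_t use_blas_threshold) {
    if (omp_thread > 0) {
        omp_set_num_threads(omp_thread);   // explicit override from config
    }
    // query batches of at least this size use BLAS-based distance computation
    faiss::distance_compute_blas_threshold = static_cast<int>(use_blas_threshold);
}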

View File

@ -64,7 +64,7 @@ define_option_string(KNOWHERE_DEPENDENCY_SOURCE
"BUNDLED"
"SYSTEM")
define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" OFF)
define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" ON)
define_option(KNOWHERE_VERBOSE_THIRDPARTY_BUILD
"Show output from ExternalProjects rather than just logging to files" ON)
@ -82,7 +82,7 @@ define_option(KNOWHERE_WITH_OPENBLAS "Build with OpenBLAS library" ON)
define_option(KNOWHERE_WITH_FAISS "Build with FAISS library" ON)
define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" OFF)
define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" ON)
define_option(FAISS_WITH_MKL "Build FAISS with MKL" OFF)

View File

@ -32,8 +32,7 @@ macro(build_dependency DEPENDENCY_NAME)
if ("${DEPENDENCY_NAME}" STREQUAL "Arrow")
build_arrow()
elseif ("${DEPENDENCY_NAME}" STREQUAL "GTest")
# build_gtest()
# find_package(GTest REQUIRED)
find_package(GTest REQUIRED)
elseif ("${DEPENDENCY_NAME}" STREQUAL "OpenBLAS")
build_openblas()
elseif ("${DEPENDENCY_NAME}" STREQUAL "FAISS")
@ -216,12 +215,12 @@ else ()
)
endif ()
if (DEFINED ENV{KNOWHERE_GTEST_URL})
set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}")
else ()
set(GTEST_SOURCE_URL
"https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz")
endif ()
# if (DEFINED ENV{KNOWHERE_GTEST_URL})
# set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}")
# else ()
# set(GTEST_SOURCE_URL
# "https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz")
# endif ()
if (DEFINED ENV{KNOWHERE_OPENBLAS_URL})
set(OPENBLAS_SOURCE_URL "$ENV{KNOWHERE_OPENBLAS_URL}")
@ -387,77 +386,77 @@ endif()
# ----------------------------------------------------------------------
# Google gtest
#macro(build_gtest)
# message(STATUS "Building gtest-${GTEST_VERSION} from source")
# set(GTEST_VENDORED TRUE)
# set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}")
#
# if (APPLE)
# set(GTEST_CMAKE_CXX_FLAGS
# ${GTEST_CMAKE_CXX_FLAGS}
# -DGTEST_USE_OWN_TR1_TUPLE=1
# -Wno-unused-value
# -Wno-ignored-attributes)
# endif ()
#
# set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep")
# set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include")
# set(GTEST_STATIC_LIB
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}")
# set(GTEST_MAIN_STATIC_LIB
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}")
#
# set(GTEST_CMAKE_ARGS
# ${EP_COMMON_CMAKE_ARGS}
# "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}"
# "-DCMAKE_INSTALL_LIBDIR=lib"
# -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS}
# -DCMAKE_BUILD_TYPE=Release)
#
# set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include")
# set(GMOCK_STATIC_LIB
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}"
# )
#
# ExternalProject_Add(googletest_ep
# URL
# ${GTEST_SOURCE_URL}
# BUILD_COMMAND
# ${MAKE}
# ${MAKE_BUILD_ARGS}
# BUILD_BYPRODUCTS
# ${GTEST_STATIC_LIB}
# ${GTEST_MAIN_STATIC_LIB}
# ${GMOCK_STATIC_LIB}
# CMAKE_ARGS
# ${GTEST_CMAKE_ARGS}
# ${EP_LOG_OPTIONS})
#
# # The include directory must exist before it is referenced by a target.
# file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}")
#
# add_library(gtest STATIC IMPORTED)
# set_target_properties(gtest
# PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}"
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
#
# add_library(gtest_main STATIC IMPORTED)
# set_target_properties(gtest_main
# PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}"
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
#
# add_library(gmock STATIC IMPORTED)
# set_target_properties(gmock
# PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}"
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
#
# add_dependencies(gtest googletest_ep)
# add_dependencies(gtest_main googletest_ep)
# add_dependencies(gmock googletest_ep)
#
#endmacro()
# macro(build_gtest)
# message(STATUS "Building gtest-${GTEST_VERSION} from source")
# set(GTEST_VENDORED TRUE)
# set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}")
#
# if (APPLE)
# set(GTEST_CMAKE_CXX_FLAGS
# ${GTEST_CMAKE_CXX_FLAGS}
# -DGTEST_USE_OWN_TR1_TUPLE=1
# -Wno-unused-value
# -Wno-ignored-attributes)
# endif ()
#
# set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep")
# set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include")
# set(GTEST_STATIC_LIB
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}")
# set(GTEST_MAIN_STATIC_LIB
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}")
#
# set(GTEST_CMAKE_ARGS
# ${EP_COMMON_CMAKE_ARGS}
# "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}"
# "-DCMAKE_INSTALL_LIBDIR=lib"
# -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS}
# -DCMAKE_BUILD_TYPE=Release)
#
# set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include")
# set(GMOCK_STATIC_LIB
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}"
# )
#
# ExternalProject_Add(googletest_ep
# URL
# ${GTEST_SOURCE_URL}
# BUILD_COMMAND
# ${MAKE}
# ${MAKE_BUILD_ARGS}
# BUILD_BYPRODUCTS
# ${GTEST_STATIC_LIB}
# ${GTEST_MAIN_STATIC_LIB}
# ${GMOCK_STATIC_LIB}
# CMAKE_ARGS
# ${GTEST_CMAKE_ARGS}
# ${EP_LOG_OPTIONS})
#
# # The include directory must exist before it is referenced by a target.
# file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}")
#
# add_library(gtest STATIC IMPORTED)
# set_target_properties(gtest
# PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}"
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
#
# add_library(gtest_main STATIC IMPORTED)
# set_target_properties(gtest_main
# PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}"
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
#
# add_library(gmock STATIC IMPORTED)
# set_target_properties(gmock
# PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}"
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
#
# add_dependencies(gtest googletest_ep)
# add_dependencies(gtest_main googletest_ep)
# add_dependencies(gmock googletest_ep)
#
# endmacro()
# if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep)
## if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep)
#if ( NOT TARGET gtest AND KNOWHERE_BUILD_TESTS )
# resolve_dependency(GTest)
#
@ -654,3 +653,5 @@ if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep)
include_directories(SYSTEM "${FAISS_INCLUDE_DIR}")
link_directories(SYSTEM ${FAISS_PREFIX}/lib/)
endif ()
add_subdirectory(thirdparty/NGT)

View File

@ -13,6 +13,7 @@
include_directories(${INDEX_SOURCE_DIR}/knowhere)
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib)
if (MILVUS_SUPPORT_SPTAG)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
@ -68,6 +69,9 @@ set(vector_index_srcs
knowhere/index/vector_index/IndexRHNSWFlat.cpp
knowhere/index/vector_index/IndexRHNSWSQ.cpp
knowhere/index/vector_index/IndexRHNSWPQ.cpp
knowhere/index/vector_index/IndexNGT.cpp
knowhere/index/vector_index/IndexNGTPANNG.cpp
knowhere/index/vector_index/IndexNGTONNG.cpp
)
set(vector_offset_index_srcs
@ -90,6 +94,8 @@ set(depend_libs
gomp
gfortran
pthread
fiu
ngt
)
if (MILVUS_SUPPORT_SPTAG)
@ -100,6 +106,32 @@ if (MILVUS_SUPPORT_SPTAG)
endif ()
if (MILVUS_GPU_VERSION)
include_directories(${CUDA_INCLUDE_DIRS})
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
set(cuda_lib
cudart
cublas
)
set(depend_libs ${depend_libs}
${cuda_lib}
)
set(vector_index_srcs ${vector_index_srcs}
knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp
knowhere/index/vector_index/gpu/IndexGPUIVF.cpp
knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp
knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp
knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp
knowhere/index/vector_index/helpers/Cloner.cpp
knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp
)
set(vector_offset_index_srcs ${vector_offset_index_srcs}
knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp
)
endif ()
if (NOT TARGET knowhere)
add_library(
knowhere STATIC
@ -130,11 +162,3 @@ if (MILVUS_SUPPORT_SPTAG)
endif ()
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
# **************************** Get&Print Include Directories ****************************
get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES )
foreach ( dir ${dirs} )
message( STATUS "Knowhere Current Include DIRS: " ${dir} )
endforeach ()

View File

@ -37,6 +37,8 @@ const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
const char* INDEX_ANNOY = "ANNOY";
const char* INDEX_NGTPANNG = "NGT_PANNG";
const char* INDEX_NGTONNG = "NGT_ONNG";
} // namespace IndexEnum
} // namespace knowhere

View File

@ -64,6 +64,8 @@ extern const char* INDEX_RHNSWFlat;
extern const char* INDEX_RHNSWPQ;
extern const char* INDEX_RHNSWSQ;
extern const char* INDEX_ANNOY;
extern const char* INDEX_NGTPANNG;
extern const char* INDEX_NGTONNG;
} // namespace IndexEnum
enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 };

View File

@ -25,13 +25,20 @@ namespace milvus {
namespace knowhere {
static const int64_t MIN_NLIST = 1;
static const int64_t MAX_NLIST = 1LL << 20;
static const int64_t MAX_NLIST = 65536;
static const int64_t MIN_NPROBE = 1;
static const int64_t MAX_NPROBE = MAX_NLIST;
static const int64_t DEFAULT_MIN_DIM = 1;
static const int64_t DEFAULT_MAX_DIM = 32768;
static const int64_t DEFAULT_MIN_ROWS = 1; // minimum size for build index
static const int64_t DEFAULT_MAX_ROWS = 50000000;
static const int64_t NGT_MIN_EDGE_SIZE = 1;
static const int64_t NGT_MAX_EDGE_SIZE = 200;
static const int64_t HNSW_MIN_EFCONSTRUCTION = 8;
static const int64_t HNSW_MAX_EFCONSTRUCTION = 512;
static const int64_t HNSW_MIN_M = 4;
static const int64_t HNSW_MAX_M = 64;
static const int64_t HNSW_MAX_EF = 32768;
static const std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
#define CheckIntByRange(key, min, max) \
@ -146,24 +153,34 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
// auto tune params
oricfg[knowhere::IndexParams::nlist] =
MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());
auto m = oricfg[knowhere::IndexParams::m].get<int64_t>();
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);
CheckIntByValues(knowhere::IndexParams::m, resset);
/*std::vector<int64_t> resset;
IVFPQConfAdapter::GetValidCPUM(dimension, resset);*/
IndexMode ivfpq_mode = mode;
return GetValidM(dimension, m, ivfpq_mode);
}
bool
IVFPQConfAdapter::GetValidM(int64_t dimension, int64_t m, IndexMode& mode) {
#ifdef MILVUS_GPU_VERSION
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::GetValidGPUM(dimension, m)) {
mode = knowhere::IndexMode::MODE_CPU;
}
#endif
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::GetValidCPUM(dimension, m)) {
return false;
}
return true;
}
void
IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset) {
resset.clear();
bool
IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
/*
* Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
@ -172,7 +189,14 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
static const std::vector<int64_t> support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
static const std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
int64_t sub_dim = dimension / m;
return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) !=
support_subquantizer.end()) &&
(std::find(std::begin(support_dim_per_subquantizer), std::end(support_dim_per_subquantizer), sub_dim) !=
support_dim_per_subquantizer.end());
/*resset.clear();
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
if (!(dimension % dimperquantizer)) {
auto subquantzier_num = dimension / dimperquantizer;
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);
@ -180,7 +204,12 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
resset.push_back(subquantzier_num);
}
}
}
}*/
}
bool
IVFPQConfAdapter::GetValidCPUM(int64_t dimension, int64_t m) {
return (dimension % m == 0);
}
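GetValidM splits the PQ check in two: on CPU, m only has to divide the vector dimension; on GPU, faiss additionally restricts both m and dim/m to a fixed list of supported values, and the adapter silently demotes the request to CPU mode when the GPU constraint fails. A hedged sketch of the same decision, with the supported-value lists taken from the comment above:
#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors GetValidM: returns true if PQ params are usable, clears *use_gpu on GPU fallback.
bool PqParamsUsable(int64_t dim, int64_t m, bool* use_gpu) {
    if (*use_gpu) {
        const std::vector<int64_t> sub_dims{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
        const std::vector<int64_t> ms{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
        bool gpu_ok = (dim % m == 0) &&
                      std::count(ms.begin(), ms.end(), m) > 0 &&
                      std::count(sub_dims.begin(), sub_dims.end(), dim / m) > 0;
        if (!gpu_ok) {
            *use_gpu = false;          // fall back to CPU, as GetValidM does
        }
    }
    return dim % m == 0;               // CPU only needs m to divide the dimension
}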
bool
@ -222,97 +251,68 @@ NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod
bool
HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
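With the bounds hoisted into the shared HNSW_* constants, a training config passes CheckTrain only if efConstruction falls in [8, 512] and M in [4, 64], and a search config only if ef lies in [topk, 32768]. A hedged example of configs that satisfy both, assuming the JSON-backed knowhere::Config and the key names used in this diff:
#include "knowhere/common/Config.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"

knowhere::Config train_cfg{
    {knowhere::meta::DIM, 128},
    {knowhere::meta::ROWS, 100000},
    {knowhere::IndexParams::M, 16},                  // 4 <= M <= 64
    {knowhere::IndexParams::efConstruction, 200},    // 8 <= efConstruction <= 512
    {knowhere::Metric::TYPE, knowhere::Metric::L2},
};
knowhere::Config search_cfg{
    {knowhere::meta::TOPK, 10},
    {knowhere::IndexParams::ef, 64},                 // topk <= ef <= 32768
};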
bool
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);
CheckIntByValues(knowhere::IndexParams::PQM, resset);
IVFPQConfAdapter::GetValidCPUM(dimension, oricfg[knowhere::IndexParams::PQM].get<int64_t>());
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
@ -368,5 +368,39 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
return true;
}
bool
NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
return true;
}
bool
NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
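Both NGT adapters accept L2, HAMMING and JACCARD metrics and only bound edge_size to [1, 200] on top of the usual ROWS/DIM limits; search-time checks are deferred to the generic adapter. A hedged example of a training config these adapters would accept, reusing the headers from the sketch above:
knowhere::Config ngt_train_cfg{
    {knowhere::meta::DIM, 128},
    {knowhere::meta::ROWS, 100000},
    {knowhere::Metric::TYPE, knowhere::Metric::L2},
    {knowhere::IndexParams::edge_size, 40},   // NGT_MIN_EDGE_SIZE <= edge_size <= NGT_MAX_EDGE_SIZE
};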
} // namespace knowhere
} // namespace milvus

View File

@ -51,8 +51,14 @@ class IVFPQConfAdapter : public IVFConfAdapter {
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
static void
GetValidMList(int64_t dimension, std::vector<int64_t>& resset);
static bool
GetValidM(int64_t dimension, int64_t m, IndexMode& mode);
static bool
GetValidGPUM(int64_t dimension, int64_t m);
static bool
GetValidCPUM(int64_t dimension, int64_t m);
};
class NSGConfAdapter : public IVFConfAdapter {
@ -120,5 +126,24 @@ class RHNSWSQConfAdapter : public ConfAdapter {
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class NGTPANNGConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class NGTONNGConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -42,7 +42,7 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter);
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter);
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
REGISTER_CONF_ADAPTER(BinIVFConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter);
#ifdef MILVUS_SUPPORT_SPTAG
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_KDT_RNT, sptag_kdt_adapter);
@ -53,6 +53,8 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
REGISTER_CONF_ADAPTER(NGTPANNGConfAdapter, IndexEnum::INDEX_NGTPANNG, ngtpanng_adapter);
REGISTER_CONF_ADAPTER(NGTONNGConfAdapter, IndexEnum::INDEX_NGTONNG, ngtonng_adapter);
}
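Registration makes the adapters discoverable by index type at runtime, so callers can run the parameter checks before building an index. A hedged lookup sketch, assuming AdapterMgr exposes a GetAdapter(index_type) accessor returning a ConfAdapter pointer:
auto& mgr = milvus::knowhere::AdapterMgr::GetInstance();
auto adapter = mgr.GetAdapter(milvus::knowhere::IndexEnum::INDEX_NGTPANNG);
if (!adapter->CheckTrain(train_cfg, milvus::knowhere::IndexMode::MODE_CPU)) {
    // reject the request before any index is built
}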
} // namespace knowhere

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <faiss/index_io.h>
#include <fiu/fiu-local.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"
@ -22,6 +23,7 @@ namespace knowhere {
BinarySet
FaissBaseIndex::SerializeImpl(const IndexType& type) {
try {
fiu_do_on("FaissBaseIndex.SerializeImpl.throw_exception", throw std::exception());
faiss::Index* index = index_.get();
MemoryIOWriter writer;

View File

@ -105,7 +105,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -116,7 +116,6 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto all_num = rows * k;
auto p_id = static_cast<int64_t*>(malloc(all_num * sizeof(int64_t)));
auto p_dist = static_cast<float*>(malloc(all_num * sizeof(float)));
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
@ -125,7 +124,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
std::vector<float> distances;
distances.reserve(k);
index_->get_nns_by_vector(static_cast<const float*>(p_data) + i * dim, k, search_k, &result, &distances,
blacklist);
bitset);
int64_t result_num = result.size();
auto local_p_id = p_id + k * i;

View File

@ -54,7 +54,7 @@ class IndexAnnoy : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -40,7 +40,7 @@ BinaryIDMAP::Load(const BinarySet& index_binary) {
}
DatasetPtr
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -141,14 +141,19 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
}
void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
BinaryIDMAP::QueryImpl(int64_t n,
const uint8_t* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
// assign the metric type
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto i_distances = reinterpret_cast<int32_t*>(distances);
bin_flat_index->search(n, data, k, i_distances, labels, bitset_);
bin_flat_index->search(n, data, k, i_distances, labels, bitset);
// if hamming, it need transform int32 to float
if (bin_flat_index->metric_type == faiss::METRIC_Hamming) {

View File

@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;
@ -69,7 +69,13 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
protected:
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
QueryImpl(int64_t n,
const uint8_t* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);
protected:
std::mutex mutex_;

View File

@ -43,7 +43,7 @@ BinaryIVF::Load(const BinarySet& index_binary) {
}
DatasetPtr
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -59,7 +59,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
@ -126,15 +126,20 @@ BinaryIVF::GenParams(const Config& config) {
}
void
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
BinaryIVF::QueryImpl(int64_t n,
const uint8_t* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
stdclock::time_point before = stdclock::now();
auto i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, data, k, i_distances, labels, bitset_);
index_->search(n, data, k, i_distances, labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();

View File

@ -60,7 +60,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;
@ -76,7 +76,13 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
GenParams(const Config& config);
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
QueryImpl(int64_t n,
const uint8_t* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);
protected:
std::mutex mutex_;

View File

@ -136,7 +136,7 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -153,7 +153,6 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<P> ret;
@ -166,7 +165,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
// } else {
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn(single_query, k, compare, blacklist);
ret = index_->searchKnn(single_query, k, compare, bitset);
while (ret.size() < k) {
ret.emplace_back(std::make_pair(-1, -1));

View File

@ -46,7 +46,7 @@ class IndexHNSW : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -95,7 +95,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
@ -108,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -223,11 +223,17 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
#endif
void
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
IDMAP::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
// assign the metric type
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, data, k, distances, labels, bitset_);
index_->search(n, data, k, distances, labels, bitset);
}
} // namespace knowhere
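Threading the ConcurrentBitsetPtr through Query/QueryImpl replaces the per-index bitset_ member with a caller-supplied filter: a set bit marks an id to exclude (for example a deleted row), and nullptr disables filtering. A hedged usage sketch, assuming the ConcurrentBitset type from Milvus's faiss fork with an element-count constructor and set():
// Exclude rows 3 and 7 from the search.
auto bitset = std::make_shared<faiss::ConcurrentBitset>(index->Count());
bitset->set(3);
bitset->set(7);
auto result = index->Query(query_dataset, search_cfg, bitset);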

View File

@ -46,7 +46,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
#if 0
DatasetPtr
@ -80,7 +80,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
protected:
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
protected:
std::mutex mutex_;

View File

@ -23,6 +23,8 @@
#include <faiss/gpu/GpuCloner.h>
#endif
#include <fiu/fiu-local.h>
#include <algorithm>
#include <chrono>
#include <memory>
#include <string>
@ -95,7 +97,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -103,6 +105,8 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA(dataset_ptr)
try {
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
@ -111,7 +115,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
@ -292,7 +296,7 @@ IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];
@ -314,17 +318,23 @@ IVF::GenParams(const Config& config) {
}
void
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
IVF::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
stdclock::time_point before = stdclock::now();
if (params->nprobe > 1 && n <= 4) {
ivf_index->parallel_mode = 1;
} else {
ivf_index->parallel_mode = 0;
}
ivf_index->search(n, data, k, distances, labels, bitset_);
ivf_index->search(n, data, k, distances, labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost

View File

@ -51,7 +51,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
#if 0
DatasetPtr
@ -86,7 +86,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
GenParams(const Config&);
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
void
SealImpl() override;

View File

@ -24,6 +24,7 @@
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/ConfAdapter.h"
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
#endif
@ -47,6 +48,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
VecIndexPtr
IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
auto ivfpq_index = dynamic_cast<faiss::IndexIVFPQ*>(index_.get());
int64_t dim = ivfpq_index->d;
int64_t m = ivfpq_index->pq.M;
if (!IVFPQConfAdapter::GetValidGPUM(dim, m)) {
return nullptr;
}
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());

View File

@ -0,0 +1,201 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNGT.h"
#include <omp.h>
#include <sstream>
#include <string>
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace milvus {
namespace knowhere {
BinarySet
IndexNGT::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
std::stringstream obj, grp, prf, tre;
index_->saveIndex(obj, grp, prf, tre);
auto obj_str = obj.str();
auto grp_str = grp.str();
auto prf_str = prf.str();
auto tre_str = tre.str();
uint64_t obj_size = obj_str.size();
uint64_t grp_size = grp_str.size();
uint64_t prf_size = prf_str.size();
uint64_t tre_size = tre_str.size();
std::shared_ptr<uint8_t[]> obj_data(new uint8_t[obj_size]);
memcpy(obj_data.get(), obj_str.data(), obj_size);
std::shared_ptr<uint8_t[]> grp_data(new uint8_t[grp_size]);
memcpy(grp_data.get(), grp_str.data(), grp_size);
std::shared_ptr<uint8_t[]> prf_data(new uint8_t[prf_size]);
memcpy(prf_data.get(), prf_str.data(), prf_size);
std::shared_ptr<uint8_t[]> tre_data(new uint8_t[tre_size]);
memcpy(tre_data.get(), tre_str.data(), tre_size);
BinarySet res_set;
res_set.Append("ngt_obj_data", obj_data, obj_size);
res_set.Append("ngt_grp_data", grp_data, grp_size);
res_set.Append("ngt_prf_data", prf_data, prf_size);
res_set.Append("ngt_tre_data", tre_data, tre_size);
return res_set;
}
void
IndexNGT::Load(const BinarySet& index_binary) {
auto obj_data = index_binary.GetByName("ngt_obj_data");
std::string obj_str(reinterpret_cast<char*>(obj_data->data.get()), obj_data->size);
auto grp_data = index_binary.GetByName("ngt_grp_data");
std::string grp_str(reinterpret_cast<char*>(grp_data->data.get()), grp_data->size);
auto prf_data = index_binary.GetByName("ngt_prf_data");
std::string prf_str(reinterpret_cast<char*>(prf_data->data.get()), prf_data->size);
auto tre_data = index_binary.GetByName("ngt_tre_data");
std::string tre_str(reinterpret_cast<char*>(tre_data->data.get()), tre_data->size);
std::stringstream obj(obj_str);
std::stringstream grp(grp_str);
std::stringstream prf(prf_str);
std::stringstream tre(tre_str);
index_ = std::shared_ptr<NGT::Index>(NGT::Index::loadIndex(obj, grp, prf, tre));
}
void
IndexNGT::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("IndexNGT has no implementation of BuildAll, please use IndexNGT(PANNG/ONNG) instead!");
}
#if 0
void
IndexNGT::Train(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("IndexNGT has no implementation of Train, please use IndexNGT(PANNG/ONNG) instead!");
GET_TENSOR_DATA_DIM(dataset_ptr);
NGT::Property prop;
prop.setDefaultForCreateIndex();
prop.dimension = dim;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2)
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
else if (metric_type == Metric::HAMMING)
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
else if (metric_type == Metric::JACCARD)
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
else
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
index_ =
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
}
void
IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr);
index_->append(reinterpret_cast<const float*>(p_data), rows);
}
#endif
DatasetPtr
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr);
size_t k = config[meta::TOPK].get<int64_t>();
size_t id_size = sizeof(int64_t) * k;
size_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
NGT::Command::SearchParameter sp;
sp.size = k;
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
const float* single_query = reinterpret_cast<float*>(const_cast<void*>(p_data)) + i * Dim();
NGT::Object* object = index_->allocateObject(single_query, Dim());
NGT::SearchContainer sc(*object);
double epsilon = sp.beginOfEpsilon;
NGT::ObjectDistances res;
sc.setResults(&res);
sc.setSize(sp.size);
sc.setRadius(sp.radius);
if (sp.accuracy > 0.0) {
sc.setExpectedAccuracy(sp.accuracy);
} else {
sc.setEpsilon(epsilon);
}
sc.setEdgeSize(sp.edgeSize);
try {
index_->search(sc, bitset);
} catch (NGT::Exception& err) {
KNOWHERE_THROW_MSG("Query failed");
}
auto local_id = p_id + i * k;
auto local_dist = p_dist + i * k;
int64_t res_num = res.size();
for (int64_t idx = 0; idx < res_num; ++idx) {
*(local_id + idx) = res[idx].id - 1;
*(local_dist + idx) = res[idx].distance;
}
while (res_num < static_cast<int64_t>(k)) {
*(local_id + res_num) = -1;
*(local_dist + res_num) = 1.0 / 0.0;
res_num++;
}
index_->deleteObject(object);
}
auto res_ds = std::make_shared<Dataset>();
res_ds->Set(meta::IDS, p_id);
res_ds->Set(meta::DISTANCE, p_dist);
return res_ds;
}
int64_t
IndexNGT::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->getNumberOfVectors();
}
int64_t
IndexNGT::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->getDimension();
}
} // namespace knowhere
} // namespace milvus
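Serialize flattens the four NGT streams (object data, graph, profile, tree) into named binaries and Load rebuilds the index from the same four entries, so the two are exact inverses. A hedged round-trip sketch over an index that has already been built:
milvus::knowhere::BinarySet blob = ngt_index->Serialize(milvus::knowhere::Config());
auto restored = std::make_shared<milvus::knowhere::IndexNGTPANNG>();   // any concrete IndexNGT subclass
restored->Load(blob);
// restored->Count() and restored->Dim() now match the original index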

View File

@ -0,0 +1,70 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <NGT/lib/NGT/Command.h>
#include <NGT/lib/NGT/Common.h>
#include <NGT/lib/NGT/Index.h>
#include <knowhere/common/Exception.h>
#include <knowhere/index/IndexType.h>
#include <knowhere/index/vector_index/VecIndex.h>
#include <memory>
namespace milvus {
namespace knowhere {
class IndexNGT : public VecIndex {
public:
IndexNGT() {
index_type_ = IndexEnum::INVALID;
}
BinarySet
Serialize(const Config& config) override;
void
Load(const BinarySet& index_binary) override;
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
}
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
}
void
AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("Incremental index is not supported");
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;
int64_t
Dim() override;
protected:
std::shared_ptr<NGT::Index> index_ = nullptr;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,71 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNGTONNG.h"
#include "NGT/lib/NGT/GraphOptimizer.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include <cstddef>
#include <memory>
namespace milvus {
namespace knowhere {
void
IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr);
NGT::Property prop;
prop.setDefaultForCreateIndex();
prop.dimension = dim;
auto edge_size = config[IndexParams::edge_size].get<int64_t>();
prop.edgeSizeForCreation = edge_size;
prop.insertionRadiusCoefficient = 1.0;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
} else if (metric_type == Metric::HAMMING) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
} else if (metric_type == Metric::JACCARD) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
index_ =
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
// reconstruct graph
NGT::GraphOptimizer graphOptimizer(true);
auto number_of_outgoing_edges = config[IndexParams::outgoing_edge_size].get<size_t>();
auto number_of_incoming_edges = config[IndexParams::incoming_edge_size].get<size_t>();
graphOptimizer.shortcutReduction = true;
graphOptimizer.searchParameterOptimization = false;
graphOptimizer.prefetchParameterOptimization = false;
graphOptimizer.accuracyTableGeneration = false;
graphOptimizer.margin = 0.2;
graphOptimizer.gtEpsilon = 0.1;
graphOptimizer.set(number_of_outgoing_edges, number_of_incoming_edges, 1000, 20);
graphOptimizer.execute(*index_);
}
} // namespace knowhere
} // namespace milvus
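For reference, a hedged sketch of the configuration this BuildAll reads; the key names are the IndexParams constants added in this change, while the values are illustrative assumptions.
// Illustrative (assumed) config for IndexNGTONNG::BuildAll.
milvus::knowhere::Config onng_conf{
    {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
    {milvus::knowhere::IndexParams::edge_size, 20},            // prop.edgeSizeForCreation
    {milvus::knowhere::IndexParams::outgoing_edge_size, 10},   // graph optimizer: outgoing edges
    {milvus::knowhere::IndexParams::incoming_edge_size, 120},  // graph optimizer: incoming edges
};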

View File

@ -7,27 +7,24 @@
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include "knowhere/common/Config.h"
#include "knowhere/index/vector_index/IndexNGT.h"
namespace milvus {
namespace knowhere {
struct Quantizer {
virtual ~Quantizer() = default;
class IndexNGTONNG : public IndexNGT {
public:
IndexNGTONNG() {
index_type_ = IndexEnum::INDEX_NGTONNG;
}
int64_t size = -1;
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
};
using QuantizerPtr = std::shared_ptr<Quantizer>;
// struct QuantizerCfg : Cfg {
// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
// };
// using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,107 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include <memory>
namespace milvus {
namespace knowhere {
void
IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr);
NGT::Property prop;
prop.setDefaultForCreateIndex();
prop.dimension = dim;
auto edge_size = config[IndexParams::edge_size].get<int64_t>();
prop.edgeSizeLimitForCreation = edge_size;
MetricType metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
} else if (metric_type == Metric::HAMMING) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
} else if (metric_type == Metric::JACCARD) {
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
index_ =
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
auto forcedly_pruned_edge_size = config[IndexParams::forcedly_pruned_edge_size].get<int64_t>();
auto selectively_pruned_edge_size = config[IndexParams::selectively_pruned_edge_size].get<int64_t>();
if (!forcedly_pruned_edge_size && !selectively_pruned_edge_size) {
return;
}
if (forcedly_pruned_edge_size && selectively_pruned_edge_size &&
selectively_pruned_edge_size >= forcedly_pruned_edge_size) {
KNOWHERE_THROW_MSG("Selectively pruned edge size should less than remaining edge size");
}
// prune
auto& graph = dynamic_cast<NGT::GraphIndex&>(index_->getIndex());
for (size_t id = 1; id < graph.repository.size(); id++) {
try {
NGT::GraphNode& node = *graph.getNode(id);
if (node.size() >= forcedly_pruned_edge_size) {
node.resize(forcedly_pruned_edge_size);
}
if (node.size() >= selectively_pruned_edge_size) {
size_t rank = 0;
for (auto i = node.begin(); i != node.end(); ++rank) {
if (rank >= selectively_pruned_edge_size) {
bool found = false;
for (size_t t1 = 0; t1 < node.size() && found == false; ++t1) {
if (t1 >= selectively_pruned_edge_size) {
break;
}
if (rank == t1) {
continue;
}
NGT::GraphNode& node2 = *graph.getNode(node[t1].id);
for (size_t t2 = 0; t2 < node2.size(); ++t2) {
if (t2 >= selectively_pruned_edge_size) {
break;
}
if (node2[t2].id == (*i).id) {
found = true;
break;
}
} // for
} // for
if (found) {
// remove
i = node.erase(i);
continue;
}
}
i++;
} // for
}
} catch (NGT::Exception& err) {
std::cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
continue;
}
}
}
} // namespace knowhere
} // namespace milvus
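A hedged reading of the pruning pass above, since the nested loops are dense:
// - If neither forcedly_pruned_edge_size nor selectively_pruned_edge_size is set, no pruning runs.
// - forcedly_pruned_edge_size is a hard cap: a node with that many or more edges is truncated to it.
// - For edges ranked at or beyond selectively_pruned_edge_size, an edge to neighbor X is dropped when
//   one of the edges kept in the first selectively_pruned_edge_size slots leads to a node that already
//   links to X within its own first selectively_pruned_edge_size edges, i.e. X stays reachable through
//   a retained two-hop path.
// - When both sizes are set, selectively_pruned_edge_size must be smaller than
//   forcedly_pruned_edge_size, or BuildAll throws.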

View File

@ -0,0 +1,30 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "knowhere/index/vector_index/IndexNGT.h"
namespace milvus {
namespace knowhere {
class IndexNGTPANNG : public IndexNGT {
public:
IndexNGTPANNG() {
index_type_ = IndexEnum::INDEX_NGTPANNG;
}
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <fiu/fiu-local.h>
#include <string>
#include "knowhere/common/Exception.h"
@ -37,6 +38,7 @@ NSG::Serialize(const Config& config) {
}
try {
fiu_do_on("NSG.Serialize.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
impl::NsgIndex* index = index_.get();
@ -55,6 +57,7 @@ NSG::Serialize(const Config& config) {
void
NSG::Load(const BinarySet& index_binary) {
try {
fiu_do_on("NSG.Load.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
auto binary = index_binary.GetByName("NSG");
@ -70,7 +73,7 @@ NSG::Load(const BinarySet& index_binary) {
}
DatasetPtr
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -84,15 +87,13 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
impl::SearchParams s_params;
s_params.search_length = config[IndexParams::search_length];
s_params.k = config[meta::TOPK];
{
std::lock_guard<std::mutex> lk(mutex_);
index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id,
s_params, blacklist);
s_params, bitset);
}
auto ret_ds = std::make_shared<Dataset>();

View File

@ -59,7 +59,7 @@ class NSG : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
int64_t
Count() override;

View File

@ -79,7 +79,7 @@ IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -96,10 +96,9 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
}
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
real_index->hnsw.efSearch = (config[IndexParams::ef]);
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, blacklist);
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);

View File

@ -52,7 +52,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -176,7 +176,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
}
DatasetPtr
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
SetParameters(config);
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);

View File

@ -52,7 +52,7 @@ class CPUSPTAGRNG : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -46,7 +46,7 @@ class VecIndex : public Index {
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
virtual DatasetPtr
Query(const DatasetPtr& dataset, const Config& config) = 0;
Query(const DatasetPtr& dataset, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) = 0;
#if 0
virtual DatasetPtr
@ -144,9 +144,11 @@ class VecIndex : public Index {
protected:
IndexType index_type_ = "";
IndexMode index_mode_ = IndexMode::MODE_CPU;
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
std::vector<IDType> uids_;
int64_t index_size_ = -1;
private:
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
};
using VecIndexPtr = std::shared_ptr<VecIndex>;
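With this change every index takes a faiss::ConcurrentBitsetPtr at query time instead of consulting an internally stored blacklist. A hedged sketch of how a caller could mask deleted row offsets, assuming the ConcurrentBitset(capacity) constructor and set(offset) from the Milvus faiss fork:
// Hedged sketch: a set bit marks a row offset that search results must skip.
#include <memory>
#include <vector>
#include <faiss/utils/ConcurrentBitset.h>
faiss::ConcurrentBitsetPtr
MakeDeletedMask(int64_t row_count, const std::vector<int64_t>& deleted_offsets) {
    auto bitset = std::make_shared<faiss::ConcurrentBitset>(row_count);
    for (auto offset : deleted_offsets) {
        bitset->set(offset);  // filter this row out of every Query() that receives the bitset
    }
    return bitset;
}
// usage: index->Query(query_dataset, conf, MakeDeletedMask(nb, deleted_offsets));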

View File

@ -21,6 +21,8 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexNGTONNG.h"
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"
@ -99,6 +101,10 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
return std::make_shared<knowhere::IndexRHNSWPQ>();
} else if (type == IndexEnum::INDEX_RHNSWSQ) {
return std::make_shared<knowhere::IndexRHNSWSQ>();
} else if (type == IndexEnum::INDEX_NGTPANNG) {
return std::make_shared<knowhere::IndexNGTPANNG>();
} else if (type == IndexEnum::INDEX_NGTONNG) {
return std::make_shared<knowhere::IndexNGTONNG>();
} else {
return nullptr;
}

View File

@ -16,6 +16,7 @@
#ifdef MILVUS_GPU_VERSION
#include <faiss/gpu/GpuCloner.h>
#endif
#include <fiu/fiu-local.h>
#include <string>
#include "knowhere/common/Exception.h"
@ -43,6 +44,7 @@ GPUIDMAP::CopyGpuToCpu(const Config& config) {
BinarySet
GPUIDMAP::SerializeImpl(const IndexType& type) {
try {
fiu_do_on("GPUIDMP.SerializeImpl.throw_exception", throw std::exception());
MemoryIOWriter writer;
{
faiss::Index* index = index_.get();
@ -102,13 +104,19 @@ GPUIDMAP::GetRawIds() {
}
void
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIDMAP::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
ResScope rs(res_, gpu_id_);
// assign the metric type
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, data, k, distances, labels, bitset_);
index_->search(n, data, k, distances, labels, bitset);
}
void
@ -132,7 +140,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const C
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];

View File

@ -55,7 +55,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
};
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;

View File

@ -9,12 +9,14 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <algorithm>
#include <memory>
#include <faiss/gpu/GpuCloner.h>
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/index_io.h>
#include <fiu/fiu-local.h>
#include <string>
#include "knowhere/common/Exception.h"
@ -91,6 +93,7 @@ GPUIVF::SerializeImpl(const IndexType& type) {
}
try {
fiu_do_on("GPUIVF.SerializeImpl.throw_exception", throw std::exception());
MemoryIOWriter writer;
{
faiss::Index* index = index_.get();
@ -134,12 +137,19 @@ GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
}
void
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIVF::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
if (device_index) {
device_index->nprobe = config[IndexParams::nprobe];
device_index->nprobe = std::min(static_cast<int>(config[IndexParams::nprobe]), device_index->nlist);
ResScope rs(res_, gpu_id_);
// if query size > 2048 we search by blocks to avoid malloc issue
@ -148,7 +158,7 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
for (int64_t i = 0; i < n; i += block_size) {
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
device_index->search(search_size, reinterpret_cast<const float*>(data) + i * dim, k, distances + i * k,
labels + i * k, bitset_);
labels + i * k, bitset);
}
} else {
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");

View File

@ -51,7 +51,8 @@ class GPUIVF : public IVF, public GPUIndex {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
};
using GPUIVFPtr = std::shared_ptr<GPUIVF>;

View File

@ -14,6 +14,7 @@
#include <faiss/gpu/GpuCloner.h>
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/index_factory.h>
#include <fiu/fiu-local.h>
#include <string>
#include <utility>
@ -93,7 +94,7 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t device_id, const Config& config) {
}
}
std::pair<VecIndexPtr, QuantizerPtr>
std::pair<VecIndexPtr, FaissIVFQuantizerPtr>
IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
@ -122,7 +123,7 @@ IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& co
}
VecIndexPtr
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config& config) {
IVFSQHybrid::LoadData(const FaissIVFQuantizerPtr& quantizer_ptr, const Config& config) {
int64_t gpu_id = config[knowhere::meta::DEVICEID];
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
@ -150,7 +151,7 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config&
}
}
QuantizerPtr
FaissIVFQuantizerPtr
IVFSQHybrid::LoadQuantizer(const Config& config) {
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
@ -173,8 +174,6 @@ IVFSQHybrid::LoadQuantizer(const Config& config) {
q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float);
q->quantizer = q_ptr;
q->gpu_id = gpu_id;
res_ = res;
gpu_mode_ = 1;
return q;
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
@ -182,20 +181,17 @@ IVFSQHybrid::LoadQuantizer(const Config& config) {
}
void
IVFSQHybrid::SetQuantizer(const QuantizerPtr& quantizer_ptr) {
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(quantizer_ptr);
if (ivf_quantizer == nullptr) {
KNOWHERE_THROW_MSG("Quantizer type error");
IVFSQHybrid::SetQuantizer(const FaissIVFQuantizerPtr& quantizer_ptr) {
faiss::IndexIVF* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
if (ivf_index == nullptr) {
KNOWHERE_THROW_MSG("Index type error");
}
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
// Once SetQuantizer() is called, make sure UnsetQuantizer() is also called before destruction.
// Otherwise, ivf_index->quantizer will be freed twice.
auto is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer);
if (is_gpu_flat_index == nullptr) {
// delete ivf_index->quantizer;
ivf_index->quantizer = ivf_quantizer->quantizer;
}
quantizer_gpu_id_ = ivf_quantizer->gpu_id;
quantizer_ = quantizer_ptr;
ivf_index->quantizer = quantizer_->quantizer;
gpu_mode_ = 1;
}
@ -206,8 +202,10 @@ IVFSQHybrid::UnsetQuantizer() {
KNOWHERE_THROW_MSG("Index type error");
}
ivf_index->quantizer = nullptr;
quantizer_gpu_id_ = -1;
// set back to cpu mode
ivf_index->restore_quantizer();
quantizer_ = nullptr;
gpu_mode_ = 0;
}
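The comment above implies a strict pairing for hybrid mode. A hedged sketch of the intended lifecycle (fragment; assumes the IVFSQHybrid header is included and that quantizer loading succeeds):
// Hedged sketch: borrow a GPU quantizer, query in hybrid mode, then detach it with
// UnsetQuantizer() before the index is destroyed so the quantizer is not freed twice.
void
HybridQuerySketch(milvus::knowhere::IVFSQHybrid& index,
                  const milvus::knowhere::DatasetPtr& query_dataset,
                  const milvus::knowhere::Config& conf) {
    auto quantizer = index.LoadQuantizer(conf);  // copies the coarse quantizer to the configured GPU
    index.SetQuantizer(quantizer);               // gpu_mode_ becomes 1 (hybrid)
    auto result = index.Query(query_dataset, conf, nullptr);
    index.UnsetQuantizer();                      // restores the CPU quantizer, gpu_mode_ back to 0
}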
BinarySet
@ -216,6 +214,7 @@ IVFSQHybrid::SerializeImpl(const IndexType& type) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
fiu_do_on("IVFSQHybrid.SerializeImpl.zero_gpu_mode", gpu_mode_ = 0);
if (gpu_mode_ == 0) {
MemoryIOWriter writer;
faiss::write_index(index_.get(), &writer);
@ -242,20 +241,26 @@ IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
}
void
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
IVFSQHybrid::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
if (gpu_mode_ == 2) {
GPUIVF::QueryImpl(n, data, k, distances, labels, config);
GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset);
// index_->search(n, (float*)data, k, distances, labels);
} else if (gpu_mode_ == 1) { // hybrid
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(quantizer_gpu_id_)) {
ResScope rs(res, quantizer_gpu_id_, true);
IVF::QueryImpl(n, data, k, distances, labels, config);
auto gpu_id = quantizer_->gpu_id;
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
ResScope rs(res, gpu_id, true);
IVF::QueryImpl(n, data, k, distances, labels, config, bitset);
} else {
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(quantizer_gpu_id_) + "resource");
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
}
} else if (gpu_mode_ == 0) {
IVF::QueryImpl(n, data, k, distances, labels, config);
IVF::QueryImpl(n, data, k, distances, labels, config, bitset);
}
}
@ -278,7 +283,6 @@ FaissIVFQuantizer::~FaissIVFQuantizer() {
delete quantizer;
quantizer = nullptr;
}
// else do nothing
}
#endif

View File

@ -18,18 +18,18 @@
#include <utility>
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/gpu/Quantizer.h"
namespace milvus {
namespace knowhere {
#ifdef MILVUS_GPU_VERSION
struct FaissIVFQuantizer : public Quantizer {
struct FaissIVFQuantizer {
faiss::gpu::GpuIndexFlat* quantizer = nullptr;
int64_t gpu_id;
int64_t size = -1;
~FaissIVFQuantizer() override;
~FaissIVFQuantizer();
};
using FaissIVFQuantizerPtr = std::shared_ptr<FaissIVFQuantizer>;
@ -62,17 +62,17 @@ class IVFSQHybrid : public GPUIVFSQ {
VecIndexPtr
CopyCpuToGpu(const int64_t, const Config&) override;
std::pair<VecIndexPtr, QuantizerPtr>
std::pair<VecIndexPtr, FaissIVFQuantizerPtr>
CopyCpuToGpuWithQuantizer(const int64_t, const Config&);
VecIndexPtr
LoadData(const knowhere::QuantizerPtr&, const Config&);
LoadData(const FaissIVFQuantizerPtr&, const Config&);
QuantizerPtr
FaissIVFQuantizerPtr
LoadQuantizer(const Config& conf);
void
SetQuantizer(const QuantizerPtr& q);
SetQuantizer(const FaissIVFQuantizerPtr& q);
void
UnsetQuantizer();
@ -88,11 +88,12 @@ class IVFSQHybrid : public GPUIVFSQ {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
protected:
int64_t gpu_mode_ = 0; // 0,1,2
int64_t quantizer_gpu_id_ = -1;
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU
FaissIVFQuantizerPtr quantizer_ = nullptr;
};
using IVFSQHybridPtr = std::shared_ptr<IVFSQHybrid>;

View File

@ -65,8 +65,9 @@ CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& co
} else {
KNOWHERE_THROW_MSG("this index type not support transfer to gpu");
}
CopyIndexData(result, index);
if (result != nullptr) {
CopyIndexData(result, index);
}
return result;
}

View File

@ -12,6 +12,7 @@
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#include "knowhere/common/Log.h"
#include <fiu/fiu-local.h>
#include <utility>
namespace milvus {
@ -82,6 +83,7 @@ FaissGpuResourceMgr::InitResource() {
ResPtr
FaissGpuResourceMgr::GetRes(const int64_t device_id, const int64_t alloc_size) {
fiu_return_on("FaissGpuResourceMgr.GetRes.ret_null", nullptr);
InitResource();
auto finder = idle_map_.find(device_id);

View File

@ -51,6 +51,15 @@ constexpr const char* search_k = "search_k";
// PQ Params
constexpr const char* PQM = "PQM";
// NGT Params
constexpr const char* edge_size = "edge_size";
// NGT_PANNG Params
constexpr const char* forcedly_pruned_edge_size = "forcedly_pruned_edge_size";
constexpr const char* selectively_pruned_edge_size = "selectively_pruned_edge_size";
// NGT_ONNG Params
constexpr const char* outgoing_edge_size = "outgoing_edge_size";
constexpr const char* incoming_edge_size = "incoming_edge_size";
} // namespace IndexParams
namespace Metric {

View File

@ -124,7 +124,10 @@ NsgIndex::InitNavigationPoint(float* data) {
// Specify Link
void
NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset,
NsgIndex::GetNeighbors(const float* query,
float* data,
std::vector<Neighbor>& resset,
std::vector<Neighbor>& fullset,
boost::dynamic_bitset<>& has_calculated_dist) {
auto& graph = knng;
size_t buffer_size = search_length;
@ -331,8 +334,8 @@ NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& r
}
void
NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph,
SearchParams* params) {
NsgIndex::GetNeighbors(
const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph, SearchParams* params) {
size_t buffer_size = params ? params->search_length : search_length;
if (buffer_size > ntotal) {
@ -482,7 +485,10 @@ NsgIndex::Link(float* data) {
}
void
NsgIndex::SyncPrune(float* data, size_t n, std::vector<Neighbor>& pool, boost::dynamic_bitset<>& has_calculated,
NsgIndex::SyncPrune(float* data,
size_t n,
std::vector<Neighbor>& pool,
boost::dynamic_bitset<>& has_calculated,
float* cut_graph_dist) {
// avoid losing the nearest neighbor in knng
for (size_t i = 0; i < knng[n].size(); ++i) {
@ -597,8 +603,8 @@ NsgIndex::InterInsert(float* data, unsigned n, std::vector<std::mutex>& mutex_ve
}
void
NsgIndex::SelectEdge(float* data, unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result,
bool limit) {
NsgIndex::SelectEdge(
float* data, unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result, bool limit) {
auto& pool = sort_pool;
/*
@ -850,8 +856,15 @@ NsgIndex::FindUnconnectedNode(float* data, boost::dynamic_bitset<>& has_linked,
// }
void
NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k,
float* dist, int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset) {
NsgIndex::Search(const float* query,
float* data,
const unsigned& nq,
const unsigned& dim,
const unsigned& k,
float* dist,
int64_t* ids,
SearchParams& params,
faiss::ConcurrentBitsetPtr bitset) {
std::vector<std::vector<Neighbor>> resset(nq);
TimeRecorder rc("NsgIndex::search", 1);

View File

@ -83,8 +83,15 @@ class NsgIndex {
Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters);
void
Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset = nullptr);
Search(const float* query,
float* data,
const unsigned& nq,
const unsigned& dim,
const unsigned& k,
float* dist,
int64_t* ids,
SearchParams& params,
faiss::ConcurrentBitsetPtr bitset = nullptr);
int64_t
GetSize();
@ -108,7 +115,10 @@ class NsgIndex {
// link specify
void
GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset,
GetNeighbors(const float* query,
float* data,
std::vector<Neighbor>& resset,
std::vector<Neighbor>& fullset,
boost::dynamic_bitset<>& has_calculated_dist);
// FindUnconnectedNode
@ -117,8 +127,8 @@ class NsgIndex {
// navigation-point
void
GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph,
SearchParams* param = nullptr);
GetNeighbors(
const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph, SearchParams* param = nullptr);
// only for search
// void
@ -128,11 +138,17 @@ class NsgIndex {
Link(float* data);
void
SyncPrune(float* data, size_t q, std::vector<Neighbor>& pool, boost::dynamic_bitset<>& has_calculated,
SyncPrune(float* data,
size_t q,
std::vector<Neighbor>& pool,
boost::dynamic_bitset<>& has_calculated,
float* cut_graph_dist);
void
SelectEdge(float* data, unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result,
SelectEdge(float* data,
unsigned& cursor,
std::vector<Neighbor>& sort_pool,
std::vector<Neighbor>& result,
bool limit = false);
void

View File

@ -23,6 +23,7 @@
#include <faiss/gpu/GpuCloner.h>
#endif
#include <fiu/fiu-local.h>
#include <chrono>
#include <memory>
#include <string>
@ -66,13 +67,13 @@ IVF_NM::Load(const BinarySet& binary_set) {
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
auto invlists = ivf_index->invlists;
auto d = ivf_index->d;
size_t nb = binary->size / invlists->code_size;
auto arranged_data = new float[d * nb];
prefix_sum.resize(invlists->nlist);
size_t curr_index = 0;
#ifndef MILVUS_GPU_VERSION
auto ails = dynamic_cast<faiss::ArrayInvertedLists*>(invlists);
size_t nb = binary->size / invlists->code_size;
auto arranged_data = new float[d * nb];
for (size_t i = 0; i < invlists->nlist; i++) {
auto list_size = ails->ids[i].size();
for (size_t j = 0; j < list_size; j++) {
@ -81,8 +82,10 @@ IVF_NM::Load(const BinarySet& binary_set) {
prefix_sum[i] = curr_index;
curr_index += list_size;
}
data_ = std::shared_ptr<uint8_t[]>(reinterpret_cast<uint8_t*>(arranged_data));
#else
auto rol = dynamic_cast<faiss::ReadOnlyArrayInvertedLists*>(invlists);
auto arranged_data = reinterpret_cast<float*>(rol->pin_readonly_codes->data);
auto lengths = rol->readonly_length;
auto rol_ids = reinterpret_cast<const int64_t*>(rol->pin_readonly_ids->data);
for (size_t i = 0; i < invlists->nlist; i++) {
@ -94,8 +97,11 @@ IVF_NM::Load(const BinarySet& binary_set) {
prefix_sum[i] = curr_index;
curr_index += list_size;
}
/* hold codes shared pointer */
ro_codes = rol->pin_readonly_codes;
data_ = nullptr;
#endif
data_ = std::shared_ptr<uint8_t[]>(reinterpret_cast<uint8_t*>(arranged_data));
}
void
@ -132,7 +138,7 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -140,6 +146,8 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA(dataset_ptr)
try {
fiu_do_on("IVF_NM.Search.throw_std_exception", throw std::exception());
fiu_do_on("IVF_NM.Search.throw_faiss_exception", throw faiss::FaissException(""));
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
@ -148,7 +156,7 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -236,8 +244,8 @@ IVF_NM::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index =
faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), data_.get());
auto gpu_index = faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(),
static_cast<const uint8_t*>(ro_codes->data));
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
@ -275,7 +283,7 @@ IVF_NM::GenGraph(const float* data, const int64_t k, GraphType& graph, const Con
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];
@ -297,7 +305,13 @@ IVF_NM::GenParams(const Config& config) {
}
void
IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
IVF_NM::QueryImpl(int64_t n,
const float* query,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
@ -308,8 +322,15 @@ IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
ivf_index->parallel_mode = 0;
}
bool is_sq8 = (index_type_ == IndexEnum::INDEX_FAISS_IVFSQ8) ? true : false;
ivf_index->search_without_codes(n, reinterpret_cast<const float*>(data), data_.get(), prefix_sum, is_sq8, k,
distances, labels, bitset_);
#ifndef MILVUS_GPU_VERSION
auto data = static_cast<const uint8_t*>(data_.get());
#else
auto data = static_cast<const uint8_t*>(ro_codes->data);
#endif
ivf_index->search_without_codes(n, reinterpret_cast<const float*>(query), data, prefix_sum, is_sq8, k, distances,
labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost

View File

@ -51,7 +51,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
#if 0
DatasetPtr
@ -86,15 +86,21 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
GenParams(const Config&);
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset);
void
SealImpl() override;
protected:
std::mutex mutex_;
std::shared_ptr<uint8_t[]> data_ = nullptr;
std::vector<size_t> prefix_sum;
// data_: if CPU, malloc memory while loading data
// ro_codes: if GPU, hold a ptr of read only codes so that
// destruction won't be done twice
std::shared_ptr<uint8_t[]> data_ = nullptr;
faiss::PageLockMemoryPtr ro_codes = nullptr;
};
using IVFNMPtr = std::shared_ptr<IVF_NM>;

View File

@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <fiu/fiu-local.h>
#include <string>
#include "knowhere/common/Exception.h"
@ -36,6 +37,7 @@ NSG_NM::Serialize(const Config& config) {
}
try {
fiu_do_on("NSG_NM.Serialize.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
impl::NsgIndex* index = index_.get();
@ -54,6 +56,7 @@ NSG_NM::Serialize(const Config& config) {
void
NSG_NM::Load(const BinarySet& index_binary) {
try {
fiu_do_on("NSG_NM.Load.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
auto binary = index_binary.GetByName("NSG_NM");
@ -71,7 +74,7 @@ NSG_NM::Load(const BinarySet& index_binary) {
}
DatasetPtr
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -86,8 +89,6 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
impl::SearchParams s_params;
s_params.search_length = config[IndexParams::search_length];
s_params.k = config[meta::TOPK];
@ -95,7 +96,7 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
// index_->ori_data_ = (float*) data_.get();
index_->Search(reinterpret_cast<const float*>(p_data), reinterpret_cast<float*>(data_.get()), rows, dim,
topK, p_dist, p_id, s_params, blacklist);
topK, p_dist, p_id, s_params, bitset);
}
auto ret_ds = std::make_shared<Dataset>();

View File

@ -59,7 +59,7 @@ class NSG_NM : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <faiss/index_io.h>
#include <fiu/fiu-local.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"
@ -22,6 +23,7 @@ namespace knowhere {
BinarySet
OffsetBaseIndex::SerializeImpl(const IndexType& type) {
try {
fiu_do_on("OffsetBaseIndex.SerializeImpl.throw_exception", throw std::exception());
faiss::Index* index = index_.get();
MemoryIOWriter writer;

View File

@ -15,6 +15,7 @@
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/index_io.h>
#include <fiu/fiu-local.h>
#include <string>
#include "knowhere/common/Exception.h"
@ -97,6 +98,7 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) {
}
try {
fiu_do_on("GPUIVF_NM.SerializeImpl.throw_exception", throw std::exception());
MemoryIOWriter writer;
{
faiss::Index* index = index_.get();
@ -116,10 +118,17 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) {
}
void
GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIVF_NM::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
fiu_do_on("GPUIVF_NM.search_impl.invald_index", device_index = nullptr);
if (device_index) {
device_index->nprobe = config[IndexParams::nprobe];
ResScope rs(res_, gpu_id_);
@ -129,7 +138,7 @@ GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances,
int64_t dim = device_index->d;
for (int64_t i = 0; i < n; i += block_size) {
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset_);
device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset);
}
} else {
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");

View File

@ -51,7 +51,8 @@ class GPUIVF_NM : public IVF, public GPUIndex {
SerializeImpl(const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;
protected:
uint8_t* arranged_data;

View File

@ -0,0 +1,79 @@
if(APPLE)
cmake_minimum_required(VERSION 3.0)
else()
cmake_minimum_required(VERSION 2.8)
endif()
project(ngt)
file(STRINGS "VERSION" ngt_VERSION)
message(STATUS "VERSION: ${ngt_VERSION}")
string(REGEX MATCH "^[0-9]+" ngt_VERSION_MAJOR ${ngt_VERSION})
set(ngt_VERSION ${ngt_VERSION})
set(ngt_SOVERSION ${ngt_VERSION_MAJOR})
if (NOT CMAKE_BUILD_TYPE)
set (CMAKE_BUILD_TYPE "Release")
endif (NOT CMAKE_BUILD_TYPE)
string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER)
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
message(STATUS "CMAKE_BUILD_TYPE_LOWER: ${CMAKE_BUILD_TYPE_LOWER}")
if(${UNIX})
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
if(CMAKE_VERSION VERSION_LESS 3.1)
set(BASE_OPTIONS "-Wall -std=gnu++0x -lrt")
if(${NGT_AVX_DISABLED})
message(STATUS "AVX will not be used to compute distances.")
endif()
if(${NGT_OPENMP_DISABLED})
message(STATUS "OpenMP is disabled.")
else()
set(BASE_OPTIONS "${BASE_OPTIONS} -fopenmp")
endif()
set(CMAKE_CXX_FLAGS_DEBUG "-g ${BASE_OPTIONS}")
if(${NGT_MARCH_NATIVE_DISABLED})
message(STATUS "Compile option -march=native is disabled.")
set(CMAKE_CXX_FLAGS_RELEASE "-O2 ${BASE_OPTIONS}")
else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native ${BASE_OPTIONS}")
endif()
else()
if (CMAKE_BUILD_TYPE_LOWER STREQUAL "release")
set(CMAKE_CXX_FLAGS_RELEASE "")
if(${NGT_MARCH_NATIVE_DISABLED})
message(STATUS "Compile option -march=native is disabled.")
add_compile_options(-O2 -DNDEBUG)
else()
add_compile_options(-Ofast -march=native -DNDEBUG)
endif()
endif()
add_compile_options(-Wall)
if(${NGT_AVX_DISABLED})
message(STATUS "AVX will not be used to compute distances.")
endif()
if(${NGT_OPENMP_DISABLED})
message(STATUS "OpenMP is disabled.")
else()
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "8.1.0")
message(FATAL_ERROR "Insufficient AppleClang version")
endif()
cmake_minimum_required(VERSION 3.16)
endif()
find_package(OpenMP REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
set(CMAKE_CXX_STANDARD 11) # for std::unordered_set, std::unique_ptr
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(Threads REQUIRED)
endif()
add_subdirectory("${PROJECT_SOURCE_DIR}/lib")
endif( ${UNIX} )

View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1 @@
1.12.0

View File

@ -0,0 +1,3 @@
if( ${UNIX} )
add_subdirectory(${PROJECT_SOURCE_DIR}/lib/NGT)
endif()

View File

@ -0,0 +1,89 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "ArrayFile.h"
#include <iostream>
#include <assert.h>
class ItemID {
public:
void serialize(std::ostream &os, NGT::ObjectSpace *ospace = 0) {
os.write((char*)&value, sizeof(value));
}
void deserialize(std::istream &is, NGT::ObjectSpace *ospace = 0) {
is.read((char*)&value, sizeof(value));
}
static size_t getSerializedDataSize() {
return sizeof(uint64_t);
}
uint64_t value;
};
void
sampleForUsage() {
{
ArrayFile<ItemID> itemIDFile;
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
itemIDFile.open("test.data");
ItemID itemID;
size_t id;
id = 1;
itemID.value = 4910002490100;
itemIDFile.put(id, itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490100);
id = 2;
itemID.value = 4910002490101;
itemIDFile.put(id, itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490101);
itemID.value = 4910002490102;
id = itemIDFile.insert(itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490102);
itemIDFile.close();
}
{
ArrayFile<ItemID> itemIDFile;
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
itemIDFile.open("test.data");
ItemID itemID;
size_t id;
id = 10;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490100);
id = 20;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490101);
}
}

View File

@ -0,0 +1,220 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <fstream>
#include <string>
#include <cstddef>
#include <stdint.h>
#include <iostream>
#include <stdexcept>
#include <cerrno>
#include <cstring>
namespace NGT {
class ObjectSpace;
};
template <class TYPE>
class ArrayFile {
private:
struct FileHeadStruct {
size_t recordSize;
uint64_t extraData; // reserve
};
struct RecordStruct {
bool deleteFlag;
uint64_t extraData; // reserve
};
bool _isOpen;
std::fstream _stream;
FileHeadStruct _fileHead;
bool _readFileHead();
pthread_mutex_t _mutex;
public:
ArrayFile();
~ArrayFile();
bool create(const std::string &file, size_t recordSize);
bool open(const std::string &file);
void close();
size_t insert(TYPE &data, NGT::ObjectSpace *objectSpace = 0);
void put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
bool get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
void remove(const size_t id);
bool isOpen() const;
size_t size();
size_t getRecordSize() { return _fileHead.recordSize; }
};
// constructor
template <class TYPE>
ArrayFile<TYPE>::ArrayFile()
: _isOpen(false), _mutex((pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER){
if(pthread_mutex_init(&_mutex, NULL) < 0) throw std::runtime_error("pthread init error.");
}
// destructor
template <class TYPE>
ArrayFile<TYPE>::~ArrayFile() {
pthread_mutex_destroy(&_mutex);
close();
}
template <class TYPE>
bool ArrayFile<TYPE>::create(const std::string &file, size_t recordSize) {
std::fstream tmpstream;
tmpstream.open(file.c_str());
if(tmpstream){
return false;
}
tmpstream.open(file.c_str(), std::ios::out);
tmpstream.seekp(0, std::ios::beg);
FileHeadStruct fileHead = {recordSize, 0};
tmpstream.write((char *)(&fileHead), sizeof(FileHeadStruct));
tmpstream.close();
return true;
}
template <class TYPE>
bool ArrayFile<TYPE>::open(const std::string &file) {
_stream.open(file.c_str(), std::ios::in | std::ios::out);
if(!_stream){
_isOpen = false;
return false;
}
_isOpen = true;
bool ret = _readFileHead();
return ret;
}
template <class TYPE>
void ArrayFile<TYPE>::close(){
_stream.close();
_isOpen = false;
}
template <class TYPE>
size_t ArrayFile<TYPE>::insert(TYPE &data, NGT::ObjectSpace *objectSpace) {
_stream.seekp(sizeof(RecordStruct), std::ios::end);
int64_t write_pos = _stream.tellg();
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
_stream.seekp(write_pos, std::ios::beg);
data.serialize(_stream, objectSpace);
int64_t offset_pos = _stream.tellg();
offset_pos -= sizeof(FileHeadStruct);
size_t id = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
if(offset_pos % (sizeof(RecordStruct) + _fileHead.recordSize) == 0){
id -= 1;
}
return id;
}
template <class TYPE>
void ArrayFile<TYPE>::put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
offset_pos += sizeof(RecordStruct);
_stream.seekp(offset_pos, std::ios::beg);
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
_stream.seekp(offset_pos, std::ios::beg);
data.serialize(_stream, objectSpace);
}
template <class TYPE>
bool ArrayFile<TYPE>::get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
pthread_mutex_lock(&_mutex);
if( size() <= id ){
pthread_mutex_unlock(&_mutex);
return false;
}
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
offset_pos += sizeof(RecordStruct);
_stream.seekg(offset_pos, std::ios::beg);
if (!_stream.fail()) {
data.deserialize(_stream, objectSpace);
}
if (_stream.fail()) {
const int trialCount = 10;
for (int tc = 0; tc < trialCount; tc++) {
_stream.clear();
_stream.seekg(offset_pos, std::ios::beg);
if (_stream.fail()) {
continue;
}
data.deserialize(_stream, objectSpace);
if (_stream.fail()) {
continue;
} else {
break;
}
}
if (_stream.fail()) {
throw std::runtime_error("ArrayFile::get: Error!");
}
}
pthread_mutex_unlock(&_mutex);
return true;
}
template <class TYPE>
void ArrayFile<TYPE>::remove(const size_t id) {
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
_stream.seekp(offset_pos, std::ios::beg);
RecordStruct recordHead = {1, 0};
_stream.write((char *)(&recordHead), sizeof(RecordStruct));
}
template <class TYPE>
bool ArrayFile<TYPE>::isOpen() const
{
return _isOpen;
}
template <class TYPE>
size_t ArrayFile<TYPE>::size()
{
_stream.seekg(0, std::ios::end);
int64_t offset_pos = _stream.tellg();
offset_pos -= sizeof(FileHeadStruct);
size_t num = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
return num;
}
template <class TYPE>
bool ArrayFile<TYPE>::_readFileHead() {
_stream.seekg(0, std::ios::beg);
_stream.read((char *)(&_fileHead), sizeof(FileHeadStruct));
if(_stream.bad()){
return false;
}
return true;
}
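For orientation, here is a minimal sketch of how ArrayFile might be exercised, modeled on the ItemID test snippet at the top of this section. The ItemID struct and the arrayFileSketch wrapper are illustrative assumptions, not part of the diff; the only contract ArrayFile imposes on TYPE is a serialize/deserialize pair that moves exactly recordSize bytes.
// Illustrative record type (assumption, modeled on the test snippet above).
struct ItemID {
  uint64_t value = 0;
  void serialize(std::ostream &os, NGT::ObjectSpace * = 0) { os.write((const char *)&value, sizeof(value)); }
  void deserialize(std::istream &is, NGT::ObjectSpace * = 0) { is.read((char *)&value, sizeof(value)); }
};
void arrayFileSketch() {
  ArrayFile<ItemID> itemIDFile;
  itemIDFile.create("test.data", sizeof(ItemID)); // returns false (and does nothing) if the file already exists
  itemIDFile.open("test.data");
  ItemID itemID;
  itemID.value = 4910002490100ULL;
  size_t id = itemIDFile.insert(itemID);          // appends a record and returns its index
  ItemID readBack;
  itemIDFile.get(id, readBack);                   // reads the record back by index
  itemIDFile.close();
}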

View File

@ -0,0 +1,40 @@
if( ${UNIX} )
option(NGT_SHARED_MEMORY_ALLOCATOR "enable shared memory" OFF)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/defines.h.in ${CMAKE_CURRENT_BINARY_DIR}/defines.h)
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/lib" "${PROJECT_BINARY_DIR}/lib/")
include_directories("${PROJECT_SOURCE_DIR}/../")
file(GLOB NGT_SOURCES *.cpp)
file(GLOB HEADER_FILES *.h *.hpp)
file(GLOB NGTQ_HEADER_FILES NGTQ/*.h NGTQ/*.hpp)
add_library(ngtstatic STATIC ${NGT_SOURCES})
set_target_properties(ngtstatic PROPERTIES OUTPUT_NAME ngt)
set_target_properties(ngtstatic PROPERTIES COMPILE_FLAGS "-fPIC")
target_link_libraries(ngtstatic)
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
target_link_libraries(ngtstatic OpenMP::OpenMP_CXX)
endif()
add_library(ngt SHARED ${NGT_SOURCES})
set_target_properties(ngt PROPERTIES VERSION ${ngt_VERSION})
set_target_properties(ngt PROPERTIES SOVERSION ${ngt_SOVERSION})
add_dependencies(ngt ngtstatic)
if(${APPLE})
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
target_link_libraries(ngt OpenMP::OpenMP_CXX)
else()
target_link_libraries(ngt gomp)
endif()
else(${APPLE})
target_link_libraries(ngt gomp rt)
endif(${APPLE})
install(TARGETS
ngt
ngtstatic
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
endif()

View File

@ -0,0 +1,988 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <string>
#include <iostream>
#include <sstream>
#include "NGT/Index.h"
#include "NGT/GraphOptimizer.h"
#include "Capi.h"
static bool operate_error_string_(const std::stringstream &ss, NGTError error){
if(error != NULL){
try{
std::string *error_str = static_cast<std::string*>(error);
*error_str = ss.str();
}catch(std::exception &err){
std::cerr << ss.str() << " > " << err.what() << std::endl;
return false;
}
}else{
std::cerr << ss.str() << std::endl;
}
return true;
}
NGTIndex ngt_open_index(const char *index_path, NGTError error) {
try{
std::string index_path_str(index_path);
NGT::Index *index = new NGT::Index(index_path_str);
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
NGTIndex ngt_create_graph_and_tree(const char *database, NGTProperty prop, NGTError error) {
NGT::Index *index = NULL;
try{
std::string database_str(database);
NGT::Property prop_i = *(static_cast<NGT::Property*>(prop));
NGT::Index::createGraphAndTree(database_str, prop_i, true);
index = new NGT::Index(database_str);
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
delete index;
return NULL;
}
}
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty prop, NGTError error) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << __FUNCTION__ << " is unavailable for shared-memory-type NGT.";
operate_error_string_(ss, error);
return NULL;
#else
try{
NGT::Index *index = new NGT::GraphAndTreeIndex(*(static_cast<NGT::Property*>(prop)));
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
#endif
}
NGTProperty ngt_create_property(NGTError error) {
try{
return static_cast<NGTProperty>(new NGT::Property());
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
bool ngt_save_index(const NGTIndex index, const char *database, NGTError error) {
try{
std::string database_str(database);
(static_cast<NGT::Index*>(index))->saveIndex(database_str);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_get_property(NGTIndex index, NGTProperty prop, NGTError error) {
if(index == NULL || prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " prop = " << prop;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->getProperty(*(static_cast<NGT::Property*>(prop)));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
int32_t ngt_get_property_dimension(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).dimension;
}
bool ngt_set_property_dimension(NGTProperty prop, int32_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).dimension = value;
return true;
}
bool ngt_set_property_edge_size_for_creation(NGTProperty prop, int16_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).edgeSizeForCreation = value;
return true;
}
bool ngt_set_property_edge_size_for_search(NGTProperty prop, int16_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).edgeSizeForSearch = value;
return true;
}
int32_t ngt_get_property_object_type(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).objectType;
}
bool ngt_is_property_object_type_float(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Float);
}
bool ngt_is_property_object_type_integer(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Uint8);
}
bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Float;
return true;
}
bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Uint8;
return true;
}
bool ngt_set_property_distance_type_l1(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1;
return true;
}
bool ngt_set_property_distance_type_l2(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
return true;
}
bool ngt_set_property_distance_type_angle(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle;
return true;
}
bool ngt_set_property_distance_type_hamming(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
return true;
}
bool ngt_set_property_distance_type_jaccard(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
return true;
}
bool ngt_set_property_distance_type_cosine(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine;
return true;
}
bool ngt_set_property_distance_type_normalized_angle(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle;
return true;
}
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine;
return true;
}
NGTObjectDistances ngt_create_empty_results(NGTError error) {
try{
return static_cast<NGTObjectDistances>(new NGT::ObjectDistances());
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
static bool ngt_search_index_(NGT::Index* pindex, NGT::Object *ngtquery, size_t size, float epsilon, float radius, NGTObjectDistances results, int edge_size = INT_MIN) {
// set search parameters.
NGT::SearchContainer sc(*ngtquery); // search parameter container.
sc.setResults(static_cast<NGT::ObjectDistances*>(results)); // set the result set.
sc.setSize(size); // the number of resultant objects.
sc.setRadius(radius); // search radius.
sc.setEpsilon(epsilon); // set exploration coefficient.
if (edge_size != INT_MIN) {
sc.setEdgeSize(edge_size);// set # of edges for each node
}
pindex->search(sc);
// delete the query object.
pindex->deleteObject(ngtquery);
return true;
}
bool ngt_search_index(NGTIndex index, double *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::Object *ngtquery = NULL;
if(radius < 0.0){
radius = FLT_MAX;
}
try{
std::vector<double> vquery(&query[0], &query[query_dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
bool ngt_search_index_as_float(NGTIndex index, float *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::Object *ngtquery = NULL;
if(radius < 0.0){
radius = FLT_MAX;
}
try{
std::vector<float> vquery(&query[0], &query[query_dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
bool ngt_search_index_with_query(NGTIndex index, NGTQuery query, NGTObjectDistances results, NGTError error) {
if(index == NULL || query.query == NULL || results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query.query << " results = " << results;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
int32_t dim = pindex->getObjectSpace().getDimension();
NGT::Object *ngtquery = NULL;
if(query.radius < 0.0){
query.radius = FLT_MAX;
}
try{
std::vector<float> vquery(&query.query[0], &query.query[dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, query.size, query.epsilon, query.radius, results, query.edge_size);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
// * deprecated *
int32_t ngt_get_size(NGTObjectDistances results, NGTError error) {
if(results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
operate_error_string_(ss, error);
return -1;
}
return (static_cast<NGT::ObjectDistances*>(results))->size();
}
uint32_t ngt_get_result_size(NGTObjectDistances results, NGTError error) {
if(results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
operate_error_string_(ss, error);
return 0;
}
return (static_cast<NGT::ObjectDistances*>(results))->size();
}
NGTObjectDistance ngt_get_result(const NGTObjectDistances results, const uint32_t i, NGTError error) {
try{
NGT::ObjectDistances objects = *(static_cast<NGT::ObjectDistances*>(results));
NGTObjectDistance ret_val = {objects[i].id, objects[i].distance};
return ret_val;
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
NGTObjectDistance err_val = {0};
return err_val;
}
}
ObjectID ngt_insert_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
return pindex->insert(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_append_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_insert_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
return pindex->insert(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_append_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
bool ngt_batch_append_index(NGTIndex index, float *obj, uint32_t data_count, NGTError error) {
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
pindex->append(obj, data_count);
return true;
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
}
bool ngt_batch_insert_index(NGTIndex index, float *obj, uint32_t data_count, uint32_t *ids, NGTError error) {
NGT::Index* pindex = static_cast<NGT::Index*>(index);
int32_t dim = pindex->getObjectSpace().getDimension();
bool status = true;
float *objptr = obj;
for (size_t idx = 0; idx < data_count; idx++, objptr += dim) {
try{
std::vector<double> vobj(objptr, objptr + dim);
ids[idx] = pindex->insert(vobj);
}catch(std::exception &err) {
status = false;
ids[idx] = 0;
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
}
}
return status;
}
bool ngt_create_index(NGTIndex index, uint32_t pool_size, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->createIndex(pool_size);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_remove_index(NGTIndex index, ObjectID id, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->remove(id);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
NGTObjectSpace ngt_get_object_space(NGTIndex index, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<NGTObjectSpace>(&(static_cast<NGT::Index*>(index))->getObjectSpace());
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
float* ngt_get_object_as_float(NGTObjectSpace object_space, ObjectID id, NGTError error) {
if(object_space == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<float*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
uint8_t* ngt_get_object_as_integer(NGTObjectSpace object_space, ObjectID id, NGTError error) {
if(object_space == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<uint8_t*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
void ngt_destroy_results(NGTObjectDistances results) {
if(results == NULL) return;
delete(static_cast<NGT::ObjectDistances*>(results));
}
void ngt_destroy_property(NGTProperty prop) {
if(prop == NULL) return;
delete(static_cast<NGT::Property*>(prop));
}
void ngt_close_index(NGTIndex index) {
if(index == NULL) return;
(static_cast<NGT::Index*>(index))->close();
delete(static_cast<NGT::Index*>(index));
}
int16_t ngt_get_property_edge_size_for_creation(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).edgeSizeForCreation;
}
int16_t ngt_get_property_edge_size_for_search(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).edgeSizeForSearch;
}
int32_t ngt_get_property_distance_type(NGTProperty prop, NGTError error){
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).distanceType;
}
NGTError ngt_create_error_object()
{
try{
std::string *error_str = new std::string();
return static_cast<NGTError>(error_str);
}catch(std::exception &err){
std::cerr << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
return NULL;
}
}
const char *ngt_get_error_string(const NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
return error_str->c_str();
}
void ngt_clear_error_string(NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
*error_str = "";
}
void ngt_destroy_error_object(NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
delete error_str;
}
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError error)
{
try{
return static_cast<NGTOptimizer>(new NGT::GraphOptimizer(logDisabled));
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer optimizer, const char *index, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->adjustSearchCoefficients(std::string(index));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_execute(NGTOptimizer optimizer, const char *inIndex, const char *outIndex, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->execute(std::string(inIndex), std::string(outIndex));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
// obsolete because it lacks a parameter
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
int nofqs, int nofrs, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, nofrs);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->setExtension(baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
bool prefetchParameter, bool accuracyTable, NGTError error)
{
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
(static_cast<NGT::GraphOptimizer*>(optimizer))->setProcessingModes(searchParameter, prefetchParameter,
accuracyTable);
return true;
}
void ngt_destroy_optimizer(NGTOptimizer optimizer)
{
if(optimizer == NULL) return;
delete(static_cast<NGT::GraphOptimizer*>(optimizer));
}
bool ngt_refine_anng(NGTIndex index, float epsilon, float accuracy, int noOfEdges, int exploreEdgeSize, size_t batchSize, NGTError error)
{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
try {
NGT::GraphReconstructor::refineANNG(*pindex, true, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
} catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error)
{
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(pindex->getIndex());
try {
NGT::ObjectDistances &objects = *static_cast<NGT::ObjectDistances*>(edges);
objects = *graph.getNode(id);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error)
{
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return false;
}
NGT::Index& pindex = *static_cast<NGT::Index*>(index);
return pindex.getObjectRepositorySize();
}
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter()
{
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter gp;
NGTAnngEdgeOptimizationParameter parameter;
parameter.no_of_queries = gp.noOfQueries;
parameter.no_of_results = gp.noOfResults;
parameter.no_of_threads = gp.noOfThreads;
parameter.target_accuracy = gp.targetAccuracy;
parameter.target_no_of_objects = gp.targetNoOfObjects;
parameter.no_of_sample_objects = gp.noOfSampleObjects;
parameter.max_of_no_of_edges = gp.maxNoOfEdges;
parameter.log = false;
return parameter;
}
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error)
{
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter p;
p.noOfQueries = parameter.no_of_queries;
p.noOfResults = parameter.no_of_results;
p.noOfThreads = parameter.no_of_threads;
p.targetAccuracy = parameter.target_accuracy;
p.targetNoOfObjects = parameter.target_no_of_objects;
p.noOfSampleObjects = parameter.no_of_sample_objects;
p.maxNoOfEdges = parameter.max_of_no_of_edges;
try {
NGT::GraphOptimizer graphOptimizer(!parameter.log); // the constructor takes logDisabled, hence the negation
std::string path(indexPath);
auto edge = graphOptimizer.optimizeNumberOfEdgesForANNG(path, p);
if (parameter.log) {
std::cerr << "the optimized number of edges is" << edge.first << "(" << edge.second << ")" << std::endl;
}
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
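Taken together, the functions above follow a property, index, insert, build, search lifecycle. The sketch below shows that call order end to end; it is an illustration assembled from the signatures in this file, with the header path, index path, dimension, and sample vectors chosen arbitrarily, and the bool/NULL error checks omitted for brevity.
// End-to-end usage sketch of the C API defined above (paths and values are placeholders).
#include <stdio.h>
#include <stdint.h>
#include "Capi.h"
int main() {
  NGTError err = ngt_create_error_object();
  NGTProperty prop = ngt_create_property(err);
  ngt_set_property_dimension(prop, 3, err);
  ngt_set_property_object_type_float(prop, err);
  ngt_set_property_distance_type_l2(prop, err);
  // Creates the on-disk index at the given path and returns an opened handle.
  NGTIndex index = ngt_create_graph_and_tree("/tmp/ngt_capi_demo", prop, err);
  float vectors[4][3] = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 1, 0}};
  for (int i = 0; i < 4; i++) {
    ngt_insert_index_as_float(index, vectors[i], 3, err);
  }
  ngt_create_index(index, 16, err);  // pool_size: build parallelism; the value is chosen arbitrarily
  NGTObjectDistances results = ngt_create_empty_results(err);
  float query[3] = {1, 0, 0};
  // size = 10 nearest neighbors, epsilon = 0.1; a negative radius is treated as "no limit" (FLT_MAX).
  ngt_search_index_as_float(index, query, 3, 10, 0.1f, -1.0f, results, err);
  for (uint32_t i = 0; i < ngt_get_result_size(results, err); i++) {
    NGTObjectDistance hit = ngt_get_result(results, i, err);
    printf("rank=%u id=%u distance=%f\n", i, hit.id, hit.distance);
  }
  ngt_destroy_results(results);
  ngt_save_index(index, "/tmp/ngt_capi_demo", err);
  ngt_close_index(index);
  ngt_destroy_property(prop);
  ngt_destroy_error_object(err);
  return 0;
}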

Some files were not shown because too many files have changed in this diff.