mirror of https://github.com/milvus-io/milvus.git
Refactor cmake and build script and add timed benchmark
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
Branch: pull/4973/head^2
parent 9d2ebe7632
commit e84b0180c9
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
---
# Below is copied from milvus project
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 120
IndentWidth: 4
AccessModifierOffset: -3
AlwaysBreakAfterReturnType: All
AllowShortBlocksOnASingleLine: false
AllowShortFunctionsOnASingleLine: false
AllowShortIfStatementsOnASingleLine: false
AlignTrailingComments: true

# Appended Options
SortIncludes: false
Standard: Latest
AlignAfterOpenBracket: Align
BinPackParameters: false
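Note: the two options that drive most of the C++ reformatting in the rest of this commit are AlwaysBreakAfterReturnType: All and BinPackParameters: false. A small illustration follows; the declaration is invented for the example, loosely modeled on IndexMeta::AddEntry later in this diff.

#include <cstdint>
#include <string>

// With the Google base style alone this declaration would be bin-packed onto one
// or two lines. With the options above, clang-format always puts the return type
// on its own line and, once the 120-column limit is exceeded, one parameter per line:
int64_t
ComputeIndexSize(const std::string& index_name,
                 const std::string& field_name,
                 int64_t lower_bound,
                 int64_t upper_bound);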
@@ -1,17 +1,10 @@
# CLion generated files
core/cmake-build-debug/
core/cmake-build-debug/*
core/cmake-build-release/
core/cmake-build-release/*
core/cmake_build/
core/cmake_build/*
core/build/
core/build/*
core/.idea/
.idea/
.idea/*
pulsar/client-cpp/cmake-build-debug/
pulsar/client-cpp/cmake-build-debug/*
**/cmake-build-debug/*
**/cmake_build/*
**/cmake-build-release/*
internal/core/output/*
internal/core/build/*
**/.idea/*
pulsar/client-cpp/build/
pulsar/client-cpp/build/*
@@ -11,12 +11,12 @@ done
SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"

MILVUS_CORE_DIR="${SCRIPTS_DIR}/../../internal/core"
CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/milvus"
CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/output"
UNITTEST_DIRS=("${CORE_INSTALL_PREFIX}/unittest")

# Currently core will install target lib to "core/lib"
if [ -d "${MILVUS_CORE_DIR}/lib" ]; then
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${MILVUS_CORE_DIR}/lib
# Currently core will install target lib to "core/output/lib"
if [ -d "${CORE_INSTALL_PREFIX}/lib" ]; then
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${CORE_INSTALL_PREFIX}/lib
fi

# run unittest
@@ -35,10 +35,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib"
# Install Go
ENV GOPATH /go
ENV GOROOT /usr/local/go
RUN mkdir -p /usr/local/go && wget -qO- "https://golang.org/dl/go1.15.2.linux-amd64.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \
    mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH"
ENV GO111MODULE on
ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH
RUN mkdir -p /usr/local/go && wget -qO- "https://golang.org/dl/go1.15.2.linux-amd64.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \
    mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH" && \
    go get github.com/golang/protobuf/protoc-gen-go@v1.3.2

# Set permissions on /etc/passwd and /home to allow arbitrary users to write
COPY --chown=0:0 docker/build_env/entrypoint.sh /
@@ -176,8 +176,6 @@ config_summary()
add_subdirectory( thirdparty )
add_subdirectory( src )


# Unittest lib
if ( BUILD_UNIT_TEST STREQUAL "ON" )
if ( BUILD_COVERAGE STREQUAL "ON" )

@@ -189,7 +187,7 @@ if ( BUILD_UNIT_TEST STREQUAL "ON" )
endif ()
append_flags( CMAKE_CXX_FLAGS FLAGS "-DELPP_DISABLE_LOGS")

add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/unittest )
add_subdirectory(unittest)
endif ()

@@ -206,9 +204,9 @@ set( GPU_ENABLE "false" )

install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include
DESTINATION include
FILES_MATCHING PATTERN "*_c.h"
)

install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
DESTINATION lib)
@@ -8,7 +8,7 @@ fi
BUILD_OUTPUT_DIR="cmake_build"
BUILD_TYPE="Release"
BUILD_UNITTEST="OFF"
INSTALL_PREFIX=$(pwd)/milvus
INSTALL_PREFIX=$(pwd)/output
MAKE_CLEAN="OFF"
BUILD_COVERAGE="OFF"
DB_PATH="/tmp/milvus"

@@ -20,7 +20,7 @@ WITH_PROMETHEUS="ON"
CUDA_ARCH="DEFAULT"
CUSTOM_THIRDPARTY_PATH=""

while getopts "p:d:t:s:f:ulrcghzme" arg; do
while getopts "p:d:t:s:f:o:ulrcghzme" arg; do
case $arg in
f)
CUSTOM_THIRDPARTY_PATH=$OPTARG

@@ -28,6 +28,9 @@ while getopts "p:d:t:s:f:ulrcghzme" arg; do
p)
INSTALL_PREFIX=$OPTARG
;;
o)
BUILD_OUTPUT_DIR=$OPTARG
;;
d)
DB_PATH=$OPTARG
;;
@@ -64,16 +64,12 @@ define_option(MILVUS_VERBOSE_THIRDPARTY_BUILD
define_option(MILVUS_WITH_EASYLOGGINGPP "Build with Easylogging++ library" ON)

define_option(MILVUS_WITH_GRPC "Build with GRPC" OFF)

define_option(MILVUS_WITH_ZLIB "Build with zlib compression" ON)

define_option(MILVUS_WITH_OPENTRACING "Build with Opentracing" ON)

define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON)

define_option(MILVUS_WITH_PULSAR "Build with pulsar-client" ON)

#----------------------------------------------------------------------
set_option_category("Test and benchmark")
@@ -1,18 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif

typedef void* CCollection;

CCollection
NewCollection(const char* collection_name, const char* schema_conf);

void
DeleteCollection(CCollection collection);

void
UpdateIndexes(CCollection c_collection, const char *index_string);

#ifdef __cplusplus
}
#endif
@@ -1,17 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif

#include "collection_c.h"

typedef void* CPartition;

CPartition
NewPartition(CCollection collection, const char* partition_name);

void
DeletePartition(CPartition partition);

#ifdef __cplusplus
}
#endif
@@ -1,89 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif

#include <stdbool.h>
#include "partition_c.h"

typedef void* CSegmentBase;

typedef struct CQueryInfo {
long int num_queries;
int topK;
const char* field_name;
} CQueryInfo;

CSegmentBase
NewSegment(CPartition partition, unsigned long segment_id);

void
DeleteSegment(CSegmentBase segment);

//////////////////////////////////////////////////////////////////

int
Insert(CSegmentBase c_segment,
long int reserved_offset,
signed long int size,
const long* primary_keys,
const unsigned long* timestamps,
void* raw_data,
int sizeof_per_row,
signed long int count);

long int
PreInsert(CSegmentBase c_segment, long int size);

int
Delete(CSegmentBase c_segment,
long int reserved_offset,
long size,
const long* primary_keys,
const unsigned long* timestamps);

long int
PreDelete(CSegmentBase c_segment, long int size);

//int
//Search(CSegmentBase c_segment,
// const char* query_json,
// unsigned long timestamp,
// float* query_raw_data,
// int num_of_query_raw_data,
// long int* result_ids,
// float* result_distances);

int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
long int* result_ids,
float* result_distances);

//////////////////////////////////////////////////////////////////

int
Close(CSegmentBase c_segment);

int
BuildIndex(CCollection c_collection, CSegmentBase c_segment);

bool
IsOpened(CSegmentBase c_segment);

long int
GetMemoryUsageInBytes(CSegmentBase c_segment);

//////////////////////////////////////////////////////////////////

long int
GetRowCount(CSegmentBase c_segment);

long int
GetDeletedCount(CSegmentBase c_segment);

#ifdef __cplusplus
}
#endif
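Note: the three C headers removed above (collection_c.h, partition_c.h and this segment header) formed the segment C API; this commit removes them from this location, likely as part of the relocation to internal/core. A hedged usage sketch of that API as declared: the data-carrying calls are left as comments because the row layout depends on the collection schema, and the empty schema string falling back to a default "fakevec"/"age" schema comes from Collection::parse() later in this diff.

#include "segment_c.h"  // as it existed before this commit; also pulls in partition_c.h / collection_c.h

int
main() {
    CCollection collection = NewCollection("demo_collection", "");  // empty conf -> default schema
    CPartition partition = NewPartition(collection, "demo_partition");
    CSegmentBase segment = NewSegment(partition, 0);

    long int row_count = 4;
    long int reserved_offset = PreInsert(segment, row_count);
    (void)reserved_offset;
    // Insert(segment, reserved_offset, row_count, primary_keys, timestamps,
    //        raw_data, sizeof_per_row, row_count);

    CQueryInfo query_info;
    query_info.num_queries = 1;
    query_info.topK = 10;
    query_info.field_name = "fakevec";
    // Search(segment, query_info, timestamp, query_raw_data,
    //        num_of_query_raw_data, result_ids, result_distances);

    DeleteSegment(segment);
    DeletePartition(partition);
    DeleteCollection(collection);
    return 0;
}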
@@ -13,7 +13,7 @@
#include <cstring>
#include <limits>
#include <unordered_map>
#include<iostream>
#include <iostream>
#include "config/ConfigMgr.h"
#include "config/ServerConfig.h"
@@ -70,22 +70,19 @@ ConfigMgr::ConfigMgr() {
config_list_ = {

/* general */
{"timezone",
CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},
{"timezone", CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},

/* network */
{"network.address", CreateStringConfig("network.address", false, &config.network.address.value,
"0.0.0.0", nullptr, nullptr)},
{"network.port", CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value,
19530, nullptr, nullptr)},
{"network.address",
CreateStringConfig("network.address", false, &config.network.address.value, "0.0.0.0", nullptr, nullptr)},
{"network.port",
CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, 19530, nullptr, nullptr)},

/* pulsar */
{"pulsar.address", CreateStringConfig("pulsar.address", false, &config.pulsar.address.value,
"localhost", nullptr, nullptr)},
{"pulsar.port", CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value,
6650, nullptr, nullptr)},

{"pulsar.address",
CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, "localhost", nullptr, nullptr)},
{"pulsar.port",
CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, 6650, nullptr, nullptr)},

/* log */
{"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)},

@@ -147,9 +144,9 @@ ConfigMgr::Load(const std::string& path) {
void
ConfigMgr::Set(const std::string& name, const std::string& value, bool update) {
std::cout<<"InSet Config "<< name <<std::endl;
if (config_list_.find(name) == config_list_.end()){
std::cout<<"Config "<< name << " not found!"<<std::endl;
std::cout << "InSet Config " << name << std::endl;
if (config_list_.find(name) == config_list_.end()) {
std::cout << "Config " << name << " not found!" << std::endl;
return;
}
try {
@@ -142,7 +142,11 @@ BaseConfig::Init() {
inited_ = true;
}

BoolConfig::BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
BoolConfig::BoolConfig(const char* name,
const char* alias,
bool modifiable,
bool* config,
bool default_value,
std::function<bool(bool val, std::string& err)> is_valid_fn,
std::function<bool(bool val, bool prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),

@@ -199,7 +203,11 @@ BoolConfig::Get() {
}

StringConfig::StringConfig(
const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value,
const char* name,
const char* alias,
bool modifiable,
std::string* config,
const char* default_value,
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),

@@ -251,8 +259,13 @@ StringConfig::Get() {
return *config_;
}

EnumConfig::EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
EnumConfig::EnumConfig(const char* name,
const char* alias,
bool modifiable,
configEnum* enumd,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),

@@ -324,8 +337,13 @@ EnumConfig::Get() {
return "unknown";
}

IntegerConfig::IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound,
int64_t upper_bound, int64_t* config, int64_t default_value,
IntegerConfig::IntegerConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),

@@ -393,8 +411,13 @@ IntegerConfig::Get() {
return std::to_string(*config_);
}

FloatingConfig::FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound,
double upper_bound, double* config, double default_value,
FloatingConfig::FloatingConfig(const char* name,
const char* alias,
bool modifiable,
double lower_bound,
double upper_bound,
double* config,
double default_value,
std::function<bool(double val, std::string& err)> is_valid_fn,
std::function<bool(double val, double prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),

@@ -457,8 +480,13 @@ FloatingConfig::Get() {
return std::to_string(*config_);
}

SizeConfig::SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value,
SizeConfig::SizeConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
@@ -67,7 +67,11 @@ using BaseConfigPtr = std::shared_ptr<BaseConfig>;
class BoolConfig : public BaseConfig {
public:
BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
BoolConfig(const char* name,
const char* alias,
bool modifiable,
bool* config,
bool default_value,
std::function<bool(bool val, std::string& err)> is_valid_fn,
std::function<bool(bool val, bool prev, std::string& err)> update_fn);

@@ -90,7 +94,11 @@ class BoolConfig : public BaseConfig {
class StringConfig : public BaseConfig {
public:
StringConfig(const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value,
StringConfig(const char* name,
const char* alias,
bool modifiable,
std::string* config,
const char* default_value,
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn);

@@ -113,8 +121,13 @@ class StringConfig : public BaseConfig {
class EnumConfig : public BaseConfig {
public:
EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
EnumConfig(const char* name,
const char* alias,
bool modifiable,
configEnum* enumd,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);

private:

@@ -137,8 +150,13 @@ class EnumConfig : public BaseConfig {
class IntegerConfig : public BaseConfig {
public:
IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value,
IntegerConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);

@@ -163,8 +181,14 @@ class IntegerConfig : public BaseConfig {
class FloatingConfig : public BaseConfig {
public:
FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, double upper_bound,
double* config, double default_value, std::function<bool(double val, std::string& err)> is_valid_fn,
FloatingConfig(const char* name,
const char* alias,
bool modifiable,
double lower_bound,
double upper_bound,
double* config,
double default_value,
std::function<bool(double val, std::string& err)> is_valid_fn,
std::function<bool(double val, double prev, std::string& err)> update_fn);

private:

@@ -188,8 +212,14 @@ class FloatingConfig : public BaseConfig {
class SizeConfig : public BaseConfig {
public:
SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
SizeConfig(const char* name,
const char* alias,
bool modifiable,
int64_t lower_bound,
int64_t upper_bound,
int64_t* config,
int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);

private:
@@ -71,11 +71,10 @@ struct ServerConfig {
Integer port{0};
} network;

struct Pulsar{
struct Pulsar {
String address{"localhost"};
Integer port{6650};
}pulsar;

} pulsar;

struct Engine {
Integer build_index_threshold{4096};

@@ -89,7 +88,6 @@ struct ServerConfig {
String json_config_path{"unknown"};
} tracing;


struct Logs {
String level{"unknown"};
struct Trace {
@@ -11,13 +11,13 @@ class AckResponder {
std::lock_guard lck(mutex_);
fetch_and_flip(seg_end);
auto old_begin = fetch_and_flip(seg_begin);
if(old_begin) {
if (old_begin) {
minimal = *acks_.begin();
}
}

int64_t
GetAck() const{
GetAck() const {
return minimal;
}

@@ -38,4 +38,4 @@ class AckResponder {
std::set<int64_t> acks_ = {0};
std::atomic<int64_t> minimal = 0;
};
}
} // namespace milvus::dog_segment
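Note: a self-contained sketch of the bookkeeping AckResponder above appears to implement, reconstructed from the fragment rather than copied from the repository: writers acknowledge [seg_begin, seg_end) ranges out of order, and GetAck() reports the end of the longest contiguous acknowledged prefix.

#include <atomic>
#include <cstdint>
#include <mutex>
#include <set>

class AckResponderSketch {
 public:
    void
    AddSegment(int64_t seg_begin, int64_t seg_end) {
        std::lock_guard<std::mutex> lck(mutex_);
        // A boundary present in the set is "open"; flipping it twice cancels it,
        // so adjacent acknowledged ranges merge automatically.
        fetch_and_flip(seg_end);
        bool old_begin = fetch_and_flip(seg_begin);
        if (old_begin) {
            // seg_begin closed the previously open range; the smallest remaining
            // boundary is the new contiguous ack point.
            minimal_ = *acks_.begin();
        }
    }

    int64_t
    GetAck() const {
        return minimal_;
    }

 private:
    // Returns true if the boundary was already present (and removes it), false otherwise.
    bool
    fetch_and_flip(int64_t endpoint) {
        if (acks_.count(endpoint)) {
            acks_.erase(endpoint);
            return true;
        }
        acks_.insert(endpoint);
        return false;
    }

    std::mutex mutex_;
    std::set<int64_t> acks_ = {0};
    std::atomic<int64_t> minimal_ = 0;
};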
@@ -11,7 +11,7 @@ set(DOG_SEGMENT_FILES
partition_c.cpp
segment_c.cpp
EasyAssert.cpp
${PB_SRC_FILES}
${PB_SRC_FILES}
)
add_library(milvus_dog_segment SHARED
${DOG_SEGMENT_FILES}

@@ -20,5 +20,9 @@ add_library(milvus_dog_segment SHARED

#add_dependencies( segment sqlite mysqlpp )

target_link_libraries(milvus_dog_segment tbb utils pthread knowhere log libprotobuf dl backtrace
)
target_link_libraries(milvus_dog_segment
tbb utils pthread knowhere log libprotobuf
dl backtrace
milvus_query
)
@@ -6,17 +6,14 @@

namespace milvus::dog_segment {

Collection::Collection(std::string &collection_name, std::string &schema):
collection_name_(collection_name), schema_json_(schema) {
Collection::Collection(std::string& collection_name, std::string& schema)
: collection_name_(collection_name), schema_json_(schema) {
parse();
index_ = nullptr;
}

void
Collection::AddIndex(const grpc::IndexParam& index_param) {

auto& index_name = index_param.index_name();
auto& field_name = index_param.field_name();

@@ -32,7 +29,7 @@ Collection::AddIndex(const grpc::IndexParam& index_param) {
bool found_index_conf = false;

auto extra_params = index_param.extra_params();
for (auto& extra_param: extra_params) {
for (auto& extra_param : extra_params) {
if (extra_param.key() == "index_type") {
index_type = extra_param.value().data();
found_index_type = true;

@@ -67,21 +64,18 @@ Collection::AddIndex(const grpc::IndexParam& index_param) {
if (!found_index_conf) {
int dim = 0;

for (auto& field: schema_->get_fields()) {
for (auto& field : schema_->get_fields()) {
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
dim = field.get_dim();
dim = field.get_dim();
}
}
Assert(dim != 0);

index_conf = milvus::knowhere::Config{
{knowhere::meta::DIM, dim},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::IndexParams::m, 4},
{knowhere::IndexParams::nbits, 8},
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::meta::DEVICEID, 0},
{knowhere::meta::DIM, dim}, {knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4}, {knowhere::IndexParams::m, 4},
{knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::meta::DEVICEID, 0},
};
std::cout << "WARN: Not specify index config, use default index config" << std::endl;
}

@@ -89,11 +83,9 @@ Collection::AddIndex(const grpc::IndexParam& index_param) {
index_->AddEntry(index_name, field_name, index_type, index_mode, index_conf);
}

void
Collection::CreateIndex(std::string &index_config) {

if(index_config.empty()) {
Collection::CreateIndex(std::string& index_config) {
if (index_config.empty()) {
index_ = nullptr;
std::cout << "null index config when create index" << std::endl;
return;

@@ -108,18 +100,16 @@ Collection::CreateIndex(std::string &index_config) {

index_ = std::make_shared<IndexMeta>(schema_);

for (const auto &index: collection.indexes()){
std::cout << "add index, index name =" << index.index_name()
<< ", field_name = " << index.field_name()
for (const auto& index : collection.indexes()) {
std::cout << "add index, index name =" << index.index_name() << ", field_name = " << index.field_name()
<< std::endl;
AddIndex(index);
}
}

void
Collection::parse() {
if(schema_json_.empty()) {
if (schema_json_.empty()) {
std::cout << "WARN: Use default schema" << std::endl;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);

@@ -131,22 +121,20 @@ Collection::parse() {
masterpb::Collection collection;
auto suc = google::protobuf::TextFormat::ParseFromString(schema_json_, &collection);

if (!suc) {
std::cerr << "unmarshal schema string failed" << std::endl;
}
auto schema = std::make_shared<Schema>();
for (const milvus::grpc::FieldMeta & child: collection.schema().field_metas()){
std::cout<<"add Field, name :" << child.field_name() << ", datatype :" << child.type() << ", dim :" << int(child.dim()) << std::endl;
schema->AddField(std::string_view(child.field_name()), DataType {child.type()}, int(child.dim()));
for (const milvus::grpc::FieldMeta& child : collection.schema().field_metas()) {
std::cout << "add Field, name :" << child.field_name() << ", datatype :" << child.type()
<< ", dim :" << int(child.dim()) << std::endl;
schema->AddField(std::string_view(child.field_name()), DataType{child.type()}, int(child.dim()));
}
/*
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
*/
schema_ = schema;

}

}

} // namespace milvus::dog_segment
@@ -7,29 +7,35 @@
namespace milvus::dog_segment {

class Collection {
public:
explicit Collection(std::string &collection_name, std::string &schema);
public:
explicit Collection(std::string& collection_name, std::string& schema);

void AddIndex(const grpc::IndexParam &index_param);
void
AddIndex(const grpc::IndexParam& index_param);

void CreateIndex(std::string &index_config);
void
CreateIndex(std::string& index_config);

void parse();
void
parse();

public:
SchemaPtr& get_schema() {
return schema_;
public:
SchemaPtr&
get_schema() {
return schema_;
}

IndexMetaPtr& get_index() {
return index_;
IndexMetaPtr&
get_index() {
return index_;
}

std::string& get_collection_name() {
return collection_name_;
std::string&
get_collection_name() {
return collection_name_;
}

private:
private:
IndexMetaPtr index_;
std::string collection_name_;
std::string schema_json_;

@@ -38,4 +44,4 @@ private:

using CollectionPtr = std::unique_ptr<Collection>;

}
} // namespace milvus::dog_segment
@@ -2,7 +2,4 @@
#include <iostream>
#include "dog_segment/ConcurrentVector.h"

namespace milvus::dog_segment {

}

namespace milvus::dog_segment {}
@@ -90,7 +90,8 @@ class VectorBase {
virtual void
grow_to_at_least(int64_t element_count) = 0;

virtual void set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
virtual void
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
};

template <typename Type, bool is_scalar = false, ssize_t ElementsPerChunk = DefaultElementPerChunk>

@@ -101,10 +102,12 @@ class ConcurrentVector : public VectorBase {
ConcurrentVector(ConcurrentVector&&) = delete;
ConcurrentVector(const ConcurrentVector&) = delete;

ConcurrentVector& operator=(ConcurrentVector&&) = delete;
ConcurrentVector& operator=(const ConcurrentVector&) = delete;
public:
ConcurrentVector&
operator=(ConcurrentVector&&) = delete;
ConcurrentVector&
operator=(const ConcurrentVector&) = delete;

public:
explicit ConcurrentVector(ssize_t dim = 1) : Dim(is_scalar ? 1 : dim), SizePerChunk(Dim * ElementsPerChunk) {
Assert(is_scalar ? dim == 1 : dim != 1);
}

@@ -185,8 +188,8 @@ class ConcurrentVector : public VectorBase {

private:
void
fill_chunk(ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source,
ssize_t source_offset) {
fill_chunk(
ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source, ssize_t source_offset) {
if (element_count <= 0) {
return;
}

@@ -199,6 +202,7 @@ class ConcurrentVector : public VectorBase {

const ssize_t Dim;
const ssize_t SizePerChunk;

private:
ThreadSafeVector<Chunk> chunks_;
};
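Note: a minimal sketch of the chunking arithmetic implied by ConcurrentVector::set_data_raw above; the chunk size, container types and helper name below are illustrative stand-ins, not the repository's.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int64_t kElementsPerChunk = 32 * 1024;  // stand-in for DefaultElementPerChunk

// Copy `element_count` elements of width `dim` starting at logical offset
// `element_offset` into fixed-size chunks, never crossing a chunk boundary
// in a single memcpy.
void
set_data_raw_sketch(std::vector<std::vector<float>>& chunks, int64_t dim,
                    int64_t element_offset, const float* source, int64_t element_count) {
    while (element_count > 0) {
        int64_t chunk_id = element_offset / kElementsPerChunk;
        int64_t chunk_offset = element_offset % kElementsPerChunk;
        int64_t n = std::min(element_count, kElementsPerChunk - chunk_offset);

        if (chunk_id >= static_cast<int64_t>(chunks.size())) {
            chunks.resize(chunk_id + 1, std::vector<float>(kElementsPerChunk * dim));
        }
        std::memcpy(chunks[chunk_id].data() + chunk_offset * dim, source, n * dim * sizeof(float));

        source += n * dim;
        element_offset += n;
        element_count -= n;
    }
}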
@@ -13,22 +13,25 @@ struct DeletedRecord {
int64_t del_barrier = 0;
faiss::ConcurrentBitsetPtr bitmap_ptr;

std::shared_ptr<TmpBitmap> clone(int64_t capacity);
std::shared_ptr<TmpBitmap>
clone(int64_t capacity);
};

DeletedRecord() : lru_(std::make_shared<TmpBitmap>()) {
lru_->bitmap_ptr = std::make_shared<faiss::ConcurrentBitset>(0);
}

auto get_lru_entry() {
auto
get_lru_entry() {
std::shared_lock lck(shared_mutex_);
return lru_;
}

void insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry, bool force = false) {
void
insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry, bool force = false) {
std::lock_guard lck(shared_mutex_);
if (new_entry->del_barrier <= lru_->del_barrier) {
if (!force || new_entry->bitmap_ptr->capacity() <= lru_->bitmap_ptr->capacity()) {
if (!force || new_entry->bitmap_ptr->count() <= lru_->bitmap_ptr->count()) {
// DO NOTHING
return;
}

@@ -36,18 +39,19 @@ struct DeletedRecord {
lru_ = std::move(new_entry);
}

public:
public:
std::atomic<int64_t> reserved = 0;
AckResponder ack_responder_;
ConcurrentVector<Timestamp, true> timestamps_;
ConcurrentVector<idx_t, true> uids_;
private:

private:
std::shared_ptr<TmpBitmap> lru_;
std::shared_mutex shared_mutex_;

};

auto DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBitmap> {
auto
DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBitmap> {
auto res = std::make_shared<TmpBitmap>();
res->del_barrier = this->del_barrier;
res->bitmap_ptr = std::make_shared<faiss::ConcurrentBitset>(capacity);

@@ -56,4 +60,4 @@ auto DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBit
return res;
}

}
} // namespace milvus::dog_segment
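Note: a simplified sketch of the copy-on-write caching pattern DeletedRecord uses above: readers fetch the current bitmap snapshot under a shared lock, build an updated copy outside the lock, and publish it under an exclusive lock only if it covers more deletes. The types below are stand-ins (a std::vector<bool> instead of faiss::ConcurrentBitset).

#include <cstdint>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <vector>

struct Snapshot {
    int64_t del_barrier = 0;    // how many delete records are folded in
    std::vector<bool> deleted;  // stand-in for the delete bitmap
};

class SnapshotCache {
 public:
    std::shared_ptr<Snapshot>
    get() {
        std::shared_lock lck(mutex_);
        return current_;
    }

    void
    publish(std::shared_ptr<Snapshot> candidate) {
        std::lock_guard lck(mutex_);
        // Keep the cached entry if it already covers at least as many deletes.
        if (candidate->del_barrier <= current_->del_barrier) {
            return;
        }
        current_ = std::move(candidate);
    }

 private:
    std::shared_ptr<Snapshot> current_ = std::make_shared<Snapshot>();
    std::shared_mutex mutex_;
};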
@@ -5,15 +5,15 @@
#define BOOST_STACKTRACE_USE_BACKTRACE
#include <boost/stacktrace.hpp>

namespace milvus::impl {
void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view filename, int lineno,
std::string_view extra_info) {
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info) {
if (!value) {
std::string info;
info += "Assert \"" + std::string(expr_str) + "\"";
info += " at " + std::string(filename) + ":" + std::to_string(lineno) + "\n";
if(!extra_info.empty()) {
if (!extra_info.empty()) {
info += " => " + std::string(extra_info);
}
auto fuck = boost::stacktrace::stacktrace();

@@ -23,4 +23,4 @@ void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view file
throw std::runtime_error(info);
}
}
}
} // namespace milvus::impl
@@ -6,8 +6,9 @@
/* Paste this on the file you want to debug. */

namespace milvus::impl {
void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view filename, int lineno,
std::string_view extra_info);
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
}

#define AssertInfo(expr, info) impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
@@ -4,15 +4,9 @@
namespace milvus::dog_segment {

Status
IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
IndexConfig config) {
Entry entry{
index_name,
field_name,
type,
mode,
std::move(config)
};
IndexMeta::AddEntry(
const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, IndexConfig config) {
Entry entry{index_name, field_name, type, mode, std::move(config)};
VerifyEntry(entry);

if (entries_.count(index_name)) {

@@ -30,22 +24,23 @@ Status
IndexMeta::DropEntry(const std::string& index_name) {
Assert(entries_.count(index_name));
auto entry = std::move(entries_[index_name]);
if(lookups_[entry.field_name] == index_name) {
if (lookups_[entry.field_name] == index_name) {
lookups_.erase(entry.field_name);
}
return Status::OK();
}

void IndexMeta::VerifyEntry(const Entry &entry) {
void
IndexMeta::VerifyEntry(const Entry& entry) {
auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode);
if(!is_mode_valid) {
if (!is_mode_valid) {
throw std::invalid_argument("invalid mode");
}

auto& schema = *schema_;
auto& field_meta = schema[entry.field_name];
// TODO checking
if(field_meta.is_vector()) {
if (field_meta.is_vector()) {
Assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ);
} else {
Assert(false);
@@ -29,7 +29,10 @@ class IndexMeta {
};

Status
AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
AddEntry(const std::string& index_name,
const std::string& field_name,
IndexType type,
IndexMode mode,
IndexConfig config);

Status

@@ -40,12 +43,14 @@ class IndexMeta {
return entries_;
}

const Entry& lookup_by_field(const std::string& field_name) {
const Entry&
lookup_by_field(const std::string& field_name) {
AssertInfo(lookups_.count(field_name), field_name);
auto index_name = lookups_.at(field_name);
AssertInfo(entries_.count(index_name), index_name);
return entries_.at(index_name);
}

private:
void
VerifyEntry(const Entry& entry);
@@ -2,7 +2,8 @@

namespace milvus::dog_segment {

Partition::Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index):
partition_name_(partition_name), schema_(schema), index_(index) {}

Partition::Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index)
: partition_name_(partition_name), schema_(schema), index_(index) {
}

} // namespace milvus::dog_segment
@@ -5,23 +5,26 @@
namespace milvus::dog_segment {

class Partition {
public:
public:
explicit Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index);

public:
SchemaPtr& get_schema() {
return schema_;
public:
SchemaPtr&
get_schema() {
return schema_;
}

IndexMetaPtr& get_index() {
return index_;
IndexMetaPtr&
get_index() {
return index_;
}

std::string& get_partition_name() {
return partition_name_;
std::string&
get_partition_name() {
return partition_name_;
}

private:
private:
std::string partition_name_;
SchemaPtr schema_;
IndexMetaPtr index_;

@@ -29,4 +32,4 @@ private:

using PartitionPtr = std::unique_ptr<Partition>;

}
} // namespace milvus::dog_segment
@@ -32,12 +32,18 @@ class SegmentBase {
virtual ~SegmentBase() = default;
// SegmentBase(std::shared_ptr<FieldsInfo> collection);

virtual int64_t PreInsert(int64_t size) = 0;
virtual int64_t
PreInsert(int64_t size) = 0;

virtual Status
Insert(int64_t reserved_offset, int64_t size, const int64_t* primary_keys, const Timestamp* timestamps, const DogDataChunk& values) = 0;
Insert(int64_t reserved_offset,
int64_t size,
const int64_t* primary_keys,
const Timestamp* timestamps,
const DogDataChunk& values) = 0;

virtual int64_t PreDelete(int64_t size) = 0;
virtual int64_t
PreDelete(int64_t size) = 0;
// TODO: add id into delete log, possibly bitmap

virtual Status
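Note: PreInsert/Insert (and PreDelete/Delete) form a two-phase write protocol: the Pre* call reserves a contiguous range of row offsets, and the second call fills the range and acknowledges it so readers only see rows up to the largest contiguous acknowledged offset (see AckResponder above). A sketch of the reservation step, assuming the atomic counter visible in DeletedRecord/SegmentNaive below:

#include <atomic>
#include <cstdint>

struct ReserveCounter {
    std::atomic<int64_t> reserved{0};

    // fetch_add hands each writer a private, non-overlapping [begin, begin + size) range.
    int64_t
    PreInsert(int64_t size) {
        return reserved.fetch_add(size);
    }
};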
@@ -152,20 +152,23 @@ class Schema {
return total_sizeof_;
}

const std::vector<int>& get_sizeof_infos() {
const std::vector<int>&
get_sizeof_infos() {
return sizeof_infos_;
}

std::optional<int> get_offset(const std::string& field_name) {
if(!offsets_.count(field_name)) {
std::optional<int>
get_offset(const std::string& field_name) {
if (!offsets_.count(field_name)) {
return std::nullopt;
} else {
return offsets_[field_name];
}
}

const std::vector<FieldMeta>& get_fields() {
return fields_;
const std::vector<FieldMeta>&
get_fields() {
return fields_;
}

const FieldMeta&

@@ -175,6 +178,7 @@ class Schema {
auto offset = offset_iter->second;
return (*this)[offset];
}

private:
// this is where data holds
std::vector<FieldMeta> fields_;
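Note: a hypothetical usage of the Schema accessors above, mirroring the default schema built by Collection::parse() earlier in this diff; it assumes the project's headers and is not a standalone program.

void
schema_usage_sketch() {
    auto schema = std::make_shared<Schema>();
    schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
    schema->AddField("age", DataType::INT32);

    auto offset = schema->get_offset("fakevec");  // std::optional<int>
    if (offset.has_value()) {
        const FieldMeta& field = (*schema)[offset.value()];
        Assert(field.is_vector());  // get_dim() == 16 for this field
    }
}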
@@ -21,8 +21,8 @@ CreateSegment(SchemaPtr schema) {
return segment;
}

SegmentNaive::Record::Record(const Schema &schema) : uids_(1), timestamps_(1) {
for (auto &field : schema) {
SegmentNaive::Record::Record(const Schema& schema) : uids_(1), timestamps_(1) {
for (auto& field : schema) {
if (field.is_vector()) {
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<float>>(field.get_dim()));

@@ -45,17 +45,17 @@ SegmentNaive::PreDelete(int64_t size) {
return reserved_begin;
}

auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp,
int64_t insert_barrier, bool force) -> std::shared_ptr<DeletedRecord::TmpBitmap> {
auto
SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force)
-> std::shared_ptr<DeletedRecord::TmpBitmap> {
auto old = deleted_record_.get_lru_entry();

if (!force || old->bitmap_ptr->capacity() == insert_barrier) {
if (!force || old->bitmap_ptr->count() == insert_barrier) {
if (old->del_barrier == del_barrier) {
return old;
}
}

auto current = old->clone(insert_barrier);
current->del_barrier = del_barrier;

@@ -67,7 +67,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
// map uid to corrensponding offsets, select the max one, which should be the target
// the max one should be closest to query_timestamp, so the delete log should refer to it
int64_t the_offset = -1;
auto[iter_b, iter_e] = uid2offset_.equal_range(uid);
auto [iter_b, iter_e] = uid2offset_.equal_range(uid);
for (auto iter = iter_b; iter != iter_e; ++iter) {
auto offset = iter->second;
if (record_.timestamps_[offset] < query_timestamp) {

@@ -90,7 +90,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
// map uid to corrensponding offsets, select the max one, which should be the target
// the max one should be closest to query_timestamp, so the delete log should refer to it
int64_t the_offset = -1;
auto[iter_b, iter_e] = uid2offset_.equal_range(uid);
auto [iter_b, iter_e] = uid2offset_.equal_range(uid);
for (auto iter = iter_b; iter != iter_e; ++iter) {
auto offset = iter->second;
if (offset >= insert_barrier) {

@@ -116,16 +116,19 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
}

Status
SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, const Timestamp *timestamps_raw,
const DogDataChunk &entities_raw) {
SegmentNaive::Insert(int64_t reserved_begin,
int64_t size,
const int64_t* uids_raw,
const Timestamp* timestamps_raw,
const DogDataChunk& entities_raw) {
Assert(entities_raw.count == size);
if (entities_raw.sizeof_per_row != schema_->get_total_sizeof()) {
std::string msg = "entity length = " + std::to_string(entities_raw.sizeof_per_row) +
", schema length = " + std::to_string(schema_->get_total_sizeof());
std::string msg = "entity length = " + std::to_string(entities_raw.sizeof_per_row) +
", schema length = " + std::to_string(schema_->get_total_sizeof());
throw std::runtime_error(msg);
}

auto raw_data = reinterpret_cast<const char *>(entities_raw.raw_data);

auto raw_data = reinterpret_cast<const char*>(entities_raw.raw_data);
// std::vector<char> entities(raw_data, raw_data + size * len_per_row);

auto len_per_row = entities_raw.sizeof_per_row;

@@ -150,7 +153,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_r
std::vector<Timestamp> timestamps(size);
// #pragma omp parallel for
for (int index = 0; index < size; ++index) {
auto[t, uid, order_index] = ordering[index];
auto [t, uid, order_index] = ordering[index];
timestamps[index] = t;
uids[index] = uid;
for (int fid = 0; fid < schema_->size(); ++fid) {

@@ -209,8 +212,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_r
}

Status
SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_raw,
const Timestamp *timestamps_raw) {
SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw) {
std::vector<std::tuple<Timestamp, idx_t>> ordering;
ordering.resize(size);
// #pragma omp parallel for

@@ -222,7 +224,7 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_r
std::vector<Timestamp> timestamps(size);
// #pragma omp parallel for
for (int index = 0; index < size; ++index) {
auto[t, uid] = ordering[index];
auto [t, uid] = ordering[index];
timestamps[index] = t;
uids[index] = uid;
}

@@ -238,9 +240,10 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_r
// return Status::OK();
}

template<typename RecordType>
int64_t get_barrier(const RecordType &record, Timestamp timestamp) {
auto &vec = record.timestamps_;
template <typename RecordType>
int64_t
get_barrier(const RecordType& record, Timestamp timestamp) {
auto& vec = record.timestamps_;
int64_t beg = 0;
int64_t end = record.ack_responder_.GetAck();
while (beg < end) {
|
|||
}
|
||||
|
||||
Status
|
||||
SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {
|
||||
SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||
auto ins_barrier = get_barrier(record_, timestamp);
|
||||
auto del_barrier = get_barrier(deleted_record_, timestamp);
|
||||
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier, true);
|
||||
Assert(bitmap_holder);
|
||||
Assert(bitmap_holder->bitmap_ptr->capacity() == ins_barrier);
|
||||
Assert(bitmap_holder->bitmap_ptr->count() == ins_barrier);
|
||||
|
||||
auto field_offset = schema_->get_offset(query_info->field_name);
|
||||
auto &field = schema_->operator[](query_info->field_name);
|
||||
auto& field = schema_->operator[](query_info->field_name);
|
||||
|
||||
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
||||
auto dim = field.get_dim();
|
||||
|
@ -280,7 +283,7 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
|
|||
conf[milvus::knowhere::meta::TOPK] = query_info->topK;
|
||||
{
|
||||
auto count = 0;
|
||||
for (int i = 0; i < bitmap->capacity(); ++i) {
|
||||
for (int i = 0; i < bitmap->count(); ++i) {
|
||||
if (bitmap->test(i)) {
|
||||
++count;
|
||||
}
|
||||
|
@ -291,10 +294,10 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
|
|||
auto indexing = std::static_pointer_cast<knowhere::VecIndex>(indexings_[index_entry.index_name]);
|
||||
indexing->SetBlacklist(bitmap);
|
||||
auto ds = knowhere::GenDataset(query_info->num_queries, dim, query_info->query_raw_data.data());
|
||||
auto final = indexing->Query(ds, conf);
|
||||
auto final = indexing->Query(ds, conf, bitmap);
|
||||
|
||||
auto ids = final->Get<idx_t *>(knowhere::meta::IDS);
|
||||
auto distances = final->Get<float *>(knowhere::meta::DISTANCE);
|
||||
auto ids = final->Get<idx_t*>(knowhere::meta::IDS);
|
||||
auto distances = final->Get<float*>(knowhere::meta::DISTANCE);
|
||||
|
||||
auto total_num = num_queries * topK;
|
||||
result.result_ids_.resize(total_num);
|
||||
|
@ -307,7 +310,7 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
|
|||
std::copy_n(ids, total_num, result.result_ids_.data());
|
||||
std::copy_n(distances, total_num, result.result_distances_.data());
|
||||
|
||||
for (auto &id: result.result_ids_) {
|
||||
for (auto& id : result.result_ids_) {
|
||||
id = record_.uids_[id];
|
||||
}
|
||||
|
||||
|
@@ -315,8 +318,13 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
}

void
merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const float *new_distances, const int64_t *new_uids) {
for(int64_t qn = 0; qn < queries; ++qn) {
merge_into(int64_t queries,
int64_t topk,
float* distances,
int64_t* uids,
const float* new_distances,
const int64_t* new_uids) {
for (int64_t qn = 0; qn < queries; ++qn) {
auto base = qn * topk;
auto src2_dis = distances + base;
auto src2_uids = uids + base;

@@ -330,8 +338,8 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const
auto it1 = 0;
auto it2 = 0;

for(auto buf = 0; buf < topk; ++buf){
if(src1_dis[it1] <= src2_dis[it2]) {
for (auto buf = 0; buf < topk; ++buf) {
if (src1_dis[it1] <= src2_dis[it2]) {
buf_dis[buf] = src1_dis[it1];
buf_uids[buf] = src1_uids[it1];
++it1;
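Note: a small, self-contained illustration of what merge_into above is for: each query owns a sorted top-k block, and the results of the next chunk are merged into it in place, keeping the best topk distances. The call itself is commented out because the function lives in SegmentNaive.cpp; the numbers are made up.

#include <cstdint>
#include <vector>

int
main() {
    const int64_t queries = 1, topk = 3;
    std::vector<float> distances = {0.1f, 0.5f, 0.9f};      // current best for query 0
    std::vector<int64_t> uids = {7, 3, 11};
    std::vector<float> new_distances = {0.2f, 0.6f, 1.5f};  // results from the next chunk
    std::vector<int64_t> new_uids = {21, 42, 8};

    // merge_into(queries, topk, distances.data(), uids.data(),
    //            new_distances.data(), new_uids.data());
    // afterwards: distances == {0.1, 0.2, 0.5}, uids == {7, 21, 3}
    return 0;
}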
@@ -347,13 +355,13 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const
}

Status
SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) {
SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) {
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
Assert(bitmap_holder);

auto &field = schema_->operator[](query_info->field_name);
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto bitmap = bitmap_holder->bitmap_ptr;

@@ -375,15 +383,15 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());

faiss::float_maxheap_array_t buf = {
(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};

auto src_data = vec_ptr->get_chunk(chunk_id).data();
auto nsize = chunk_id != max_chunk - 1? DefaultElementPerChunk: ins_barrier - chunk_id * DefaultElementPerChunk;
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
auto offset = chunk_id * DefaultElementPerChunk;

faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf, bitmap, offset);
if(chunk_id == 0) {
if (chunk_id == 0) {
final_uids = buf_uids;
final_dis = buf_dis;
} else {

@@ -391,8 +399,7 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
}
}

for(auto& id: final_uids) {
for (auto& id : final_uids) {
id = record_.uids_[id];
}
@@ -402,20 +409,18 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
results.num_queries_ = num_queries;
results.row_num_ = total_count;

// throw std::runtime_error("unimplemented");
// throw std::runtime_error("unimplemented");
return Status::OK();
}

Status
SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {

SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
Assert(bitmap_holder);

auto &field = schema_->operator[](query_info->field_name);
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto bitmap = bitmap_holder->bitmap_ptr;

@@ -428,7 +433,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_.at(the_offset_opt.value()));
std::vector<std::priority_queue<std::pair<float, int>>> records(num_queries);

auto get_L2_distance = [dim](const float *a, const float *b) {
auto get_L2_distance = [dim](const float* a, const float* b) {
float L2_distance = 0;
for (auto i = 0; i < dim; ++i) {
auto d = a[i] - b[i];

@@ -438,14 +443,14 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
};

for (int64_t i = 0; i < ins_barrier; ++i) {
if (i < bitmap->capacity() && bitmap->test(i)) {
if (i < bitmap->count() && bitmap->test(i)) {
continue;
}
auto element = vec_ptr->get_element(i);
for (auto query_id = 0; query_id < num_queries; ++query_id) {
auto query_blob = query_info->query_raw_data.data() + query_id * dim;
auto dis = get_L2_distance(query_blob, element);
auto &record = records[query_id];
auto& record = records[query_id];
if (record.size() < topK) {
record.emplace(dis, i);
} else if (record.top().first > dis) {

@@ -455,7 +460,6 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
}
}

result.num_queries_ = num_queries;
result.topK_ = topK;
auto row_num = topK * num_queries;

@@ -468,7 +472,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
// reverse
for (int i = 0; i < topK; ++i) {
auto dst_id = topK - 1 - i + q_id * topK;
auto[dis, offset] = records[q_id].top();
auto [dis, offset] = records[q_id].top();
records[q_id].pop();
result.result_ids_[dst_id] = record_.uids_[offset];
result.result_distances_[dst_id] = dis;

@@ -479,7 +483,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
}

Status
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
// TODO: enable delete
// TODO: enable index
// TODO: remove mock

@@ -493,7 +497,7 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
std::default_random_engine e(42);
std::uniform_real_distribution<> dis(0.0, 1.0);
query_info->query_raw_data.resize(query_info->num_queries * dim);
for (auto &x: query_info->query_raw_data) {
for (auto& x : query_info->query_raw_data) {
x = dis(e);
}
}
@@ -517,8 +521,9 @@ SegmentNaive::Close() {
return Status::OK();
}

template<typename Type>
knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry) {
template <typename Type>
knowhere::IndexPtr
SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto offset_opt = schema_->get_offset(entry.field_name);
Assert(offset_opt.has_value());
auto offset = offset_opt.value();

@@ -528,7 +533,7 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
auto chunk_size = record_.uids_.chunk_size();

auto &uids = record_.uids_;
auto& uids = record_.uids_;
auto entities = record_.get_vec_entity<float>(offset);

std::vector<knowhere::DatasetPtr> datasets;

@@ -538,10 +543,10 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto &ds: datasets) {
for (auto& ds : datasets) {
indexing->Train(ds, entry.config);
}
for (auto &ds: datasets) {
for (auto& ds : datasets) {
indexing->AddWithoutIds(ds, entry.config);
}
return indexing;

@@ -555,7 +560,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
int dim = 0;
std::string index_field_name;

for (auto& field: schema_->get_fields()) {
for (auto& field : schema_->get_fields()) {
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
dim = field.get_dim();
index_field_name = field.get_name();

@@ -569,28 +574,24 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
// TODO: this is merge of query conf and insert conf
// TODO: should be splitted into multiple configs
auto conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0},
{milvus::knowhere::meta::DIM, dim}, {milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0},
};
index_meta->AddEntry("fakeindex", index_field_name, knowhere::IndexEnum::INDEX_FAISS_IVFPQ,
knowhere::IndexMode::MODE_CPU, conf);
remote_index_meta = index_meta;
}

if(record_.ack_responder_.GetAck() < 1024 * 4) {
if (record_.ack_responder_.GetAck() < 1024 * 4) {
return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
}

index_meta_ = remote_index_meta;
for (auto&[index_name, entry]: index_meta_->get_entries()) {
for (auto& [index_name, entry] : index_meta_->get_entries()) {
Assert(entry.index_name == index_name);
const auto &field = (*schema_)[entry.field_name];
const auto& field = (*schema_)[entry.field_name];

if (field.is_vector()) {
Assert(field.get_data_type() == engine::DataType::VECTOR_FLOAT);

@@ -608,9 +609,9 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
int64_t
SegmentNaive::GetMemoryUsageInBytes() {
int64_t total_bytes = 0;
if(index_ready_) {
if (index_ready_) {
auto& index_entries = index_meta_->get_entries();
for(auto [index_name, entry]: index_entries) {
for (auto [index_name, entry] : index_entries) {
Assert(schema_->operator[](entry.field_name).is_vector());
auto vec_ptr = std::static_pointer_cast<knowhere::VecIndex>(indexings_[index_name]);
total_bytes += vec_ptr->IndexSize();
|
@ -21,12 +21,12 @@ struct ColumnBasedDataChunk {
|
|||
std::vector<std::vector<float>> entity_vecs;
|
||||
|
||||
static ColumnBasedDataChunk
|
||||
from(const DogDataChunk &source, const Schema &schema) {
|
||||
from(const DogDataChunk& source, const Schema& schema) {
|
||||
ColumnBasedDataChunk dest;
|
||||
auto count = source.count;
|
||||
auto raw_data = reinterpret_cast<const char *>(source.raw_data);
|
||||
auto raw_data = reinterpret_cast<const char*>(source.raw_data);
|
||||
auto align = source.sizeof_per_row;
|
||||
for (auto &field : schema) {
|
||||
for (auto& field : schema) {
|
||||
auto len = field.get_sizeof();
|
||||
Assert(len % sizeof(float) == 0);
|
||||
std::vector<float> new_col(len * count / sizeof(float));
|
||||
|
@ -42,28 +42,33 @@ struct ColumnBasedDataChunk {
|
|||
};
|
||||
|
||||
class SegmentNaive : public SegmentBase {
|
||||
public:
|
||||
public:
|
||||
virtual ~SegmentNaive() = default;
|
||||
|
||||
// SegmentBase(std::shared_ptr<FieldsInfo> collection);
|
||||
|
||||
int64_t PreInsert(int64_t size) override;
|
||||
int64_t
|
||||
PreInsert(int64_t size) override;
|
||||
|
||||
// TODO: originally, id should be put into data_chunk
|
||||
// TODO: Is it ok to put them the other side?
|
||||
Status
|
||||
Insert(int64_t reserverd_offset, int64_t size, const int64_t *primary_keys, const Timestamp *timestamps,
|
||||
const DogDataChunk &values) override;
|
||||
Insert(int64_t reserverd_offset,
|
||||
int64_t size,
|
||||
const int64_t* primary_keys,
|
||||
const Timestamp* timestamps,
|
||||
const DogDataChunk& values) override;
|
||||
|
||||
int64_t PreDelete(int64_t size) override;
|
||||
int64_t
|
||||
PreDelete(int64_t size) override;
|
||||
|
||||
// TODO: add id into delete log, possibly bitmap
|
||||
Status
|
||||
Delete(int64_t reserverd_offset, int64_t size, const int64_t *primary_keys, const Timestamp *timestamps) override;
|
||||
Delete(int64_t reserverd_offset, int64_t size, const int64_t* primary_keys, const Timestamp* timestamps) override;
|
||||
|
||||
// query contains metadata of
|
||||
Status
|
||||
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) override;
|
||||
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||
|
||||
// stop receive insert requests
|
||||
// will move data to immutable vector or something
|
||||
|
@ -87,7 +92,7 @@ public:
|
|||
}
|
||||
|
||||
Status
|
||||
LoadRawData(std::string_view field_name, const char *blob, int64_t blob_size) override {
|
||||
LoadRawData(std::string_view field_name, const char* blob, int64_t blob_size) override {
|
||||
// TODO: NO-OP
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -95,7 +100,7 @@ public:
|
|||
int64_t
|
||||
GetMemoryUsageInBytes() override;
|
||||
|
||||
public:
|
||||
public:
|
||||
ssize_t
|
||||
get_row_count() const override {
|
||||
return record_.ack_responder_.GetAck();
|
||||
|
@ -111,23 +116,22 @@ public:
|
|||
return 0;
|
||||
}
|
||||
|
||||
public:
|
||||
public:
|
||||
friend std::unique_ptr<SegmentBase>
|
||||
CreateSegment(SchemaPtr schema);
|
||||
|
||||
explicit SegmentNaive(SchemaPtr schema)
|
||||
: schema_(schema), record_(*schema) {
|
||||
explicit SegmentNaive(SchemaPtr schema) : schema_(schema), record_(*schema) {
|
||||
}
|
||||
|
||||
private:
|
||||
// struct MutableRecord {
|
||||
// ConcurrentVector<uint64_t> uids_;
|
||||
// tbb::concurrent_vector<Timestamp> timestamps_;
|
||||
// std::vector<tbb::concurrent_vector<float>> entity_vecs_;
|
||||
//
|
||||
// MutableRecord(int entity_size) : entity_vecs_(entity_size) {
|
||||
// }
|
||||
// };
|
||||
private:
|
||||
// struct MutableRecord {
|
||||
// ConcurrentVector<uint64_t> uids_;
|
||||
// tbb::concurrent_vector<Timestamp> timestamps_;
|
||||
// std::vector<tbb::concurrent_vector<float>> entity_vecs_;
|
||||
//
|
||||
// MutableRecord(int entity_size) : entity_vecs_(entity_size) {
|
||||
// }
|
||||
// };
|
||||
|
||||
struct Record {
|
||||
std::atomic<int64_t> reserved = 0;
|
||||
|
@ -136,31 +140,32 @@ private:
|
|||
ConcurrentVector<idx_t, true> uids_;
|
||||
std::vector<std::shared_ptr<VectorBase>> entity_vec_;
|
||||
|
||||
Record(const Schema &schema);
|
||||
Record(const Schema& schema);
|
||||
|
||||
template<typename Type>
|
||||
auto get_vec_entity(int offset) {
|
||||
template <typename Type>
|
||||
auto
|
||||
get_vec_entity(int offset) {
|
||||
return std::static_pointer_cast<ConcurrentVector<Type>>(entity_vec_[offset]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
std::shared_ptr<DeletedRecord::TmpBitmap>
|
||||
get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force = false);
|
||||
|
||||
Status
|
||||
QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results);
|
||||
QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
|
||||
|
||||
Status
|
||||
QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results);
|
||||
QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
|
||||
|
||||
Status
|
||||
QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results);
|
||||
QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
|
||||
|
||||
template<typename Type>
|
||||
knowhere::IndexPtr BuildVecIndexImpl(const IndexMeta::Entry &entry);
|
||||
template <typename Type>
|
||||
knowhere::IndexPtr
|
||||
BuildVecIndexImpl(const IndexMeta::Entry& entry);
|
||||
|
||||
private:
|
||||
private:
|
||||
SchemaPtr schema_;
|
||||
std::atomic<SegmentState> state_ = SegmentState::Open;
|
||||
Record record_;
|
||||
|
@ -168,7 +173,7 @@ private:
|
|||
|
||||
std::atomic<bool> index_ready_ = false;
|
||||
IndexMetaPtr index_meta_;
|
||||
std::unordered_map<std::string, knowhere::IndexPtr> indexings_; // index_name => indexing
|
||||
std::unordered_map<std::string, knowhere::IndexPtr> indexings_; // index_name => indexing
|
||||
tbb::concurrent_unordered_multimap<idx_t, int64_t> uid2offset_;
|
||||
};
|
||||
} // namespace milvus::dog_segment
|
||||
|
|
|
@ -3,28 +3,28 @@
|
|||
|
||||
CCollection
|
||||
NewCollection(const char* collection_name, const char* schema_conf) {
|
||||
auto name = std::string(collection_name);
|
||||
auto conf = std::string(schema_conf);
|
||||
auto name = std::string(collection_name);
|
||||
auto conf = std::string(schema_conf);
|
||||
|
||||
auto collection = std::make_unique<milvus::dog_segment::Collection>(name, conf);
|
||||
auto collection = std::make_unique<milvus::dog_segment::Collection>(name, conf);
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "create collection " << collection_name << std::endl;
|
||||
return (void*)collection.release();
|
||||
// TODO: delete print
|
||||
std::cout << "create collection " << collection_name << std::endl;
|
||||
return (void*)collection.release();
|
||||
}
|
||||
|
||||
void
|
||||
DeleteCollection(CCollection collection) {
|
||||
auto col = (milvus::dog_segment::Collection*)collection;
|
||||
auto col = (milvus::dog_segment::Collection*)collection;
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "delete collection " << col->get_collection_name() << std::endl;
|
||||
delete col;
|
||||
// TODO: delete print
|
||||
std::cout << "delete collection " << col->get_collection_name() << std::endl;
|
||||
delete col;
|
||||
}
|
||||
|
||||
void
|
||||
UpdateIndexes(CCollection c_collection, const char *index_string) {
|
||||
auto c = (milvus::dog_segment::Collection*)c_collection;
|
||||
std::string s(index_string);
|
||||
c->CreateIndex(s);
|
||||
UpdateIndexes(CCollection c_collection, const char* index_string) {
|
||||
auto c = (milvus::dog_segment::Collection*)c_collection;
|
||||
std::string s(index_string);
|
||||
c->CreateIndex(s);
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ void
|
|||
DeleteCollection(CCollection collection);
|
||||
|
||||
void
|
||||
UpdateIndexes(CCollection c_collection, const char *index_string);
|
||||
UpdateIndexes(CCollection c_collection, const char* index_string);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -4,26 +4,26 @@
|
|||
|
||||
CPartition
|
||||
NewPartition(CCollection collection, const char* partition_name) {
|
||||
auto c = (milvus::dog_segment::Collection*)collection;
|
||||
auto c = (milvus::dog_segment::Collection*)collection;
|
||||
|
||||
auto name = std::string(partition_name);
|
||||
auto name = std::string(partition_name);
|
||||
|
||||
auto schema = c->get_schema();
|
||||
auto schema = c->get_schema();
|
||||
|
||||
auto index = c->get_index();
|
||||
auto index = c->get_index();
|
||||
|
||||
auto partition = std::make_unique<milvus::dog_segment::Partition>(name, schema, index);
|
||||
auto partition = std::make_unique<milvus::dog_segment::Partition>(name, schema, index);
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "create partition " << name << std::endl;
|
||||
return (void*)partition.release();
|
||||
// TODO: delete print
|
||||
std::cout << "create partition " << name << std::endl;
|
||||
return (void*)partition.release();
|
||||
}
|
||||
|
||||
void
|
||||
DeletePartition(CPartition partition) {
|
||||
auto p = (milvus::dog_segment::Partition*)partition;
|
||||
auto p = (milvus::dog_segment::Partition*)partition;
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "delete partition " << p->get_partition_name() <<std::endl;
|
||||
delete p;
|
||||
// TODO: delete print
|
||||
std::cout << "delete partition " << p->get_partition_name() << std::endl;
|
||||
delete p;
|
||||
}
|
||||
|
|
|
@ -8,89 +8,83 @@
|
|||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <knowhere/index/vector_index/VecIndexFactory.h>
|
||||
|
||||
|
||||
CSegmentBase
|
||||
NewSegment(CPartition partition, unsigned long segment_id) {
|
||||
auto p = (milvus::dog_segment::Partition*)partition;
|
||||
auto p = (milvus::dog_segment::Partition*)partition;
|
||||
|
||||
auto segment = milvus::dog_segment::CreateSegment(p->get_schema());
|
||||
auto segment = milvus::dog_segment::CreateSegment(p->get_schema());
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "create segment " << segment_id << std::endl;
|
||||
return (void*)segment.release();
|
||||
// TODO: delete print
|
||||
std::cout << "create segment " << segment_id << std::endl;
|
||||
return (void*)segment.release();
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
DeleteSegment(CSegmentBase segment) {
|
||||
auto s = (milvus::dog_segment::SegmentBase*)segment;
|
||||
auto s = (milvus::dog_segment::SegmentBase*)segment;
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "delete segment " << std::endl;
|
||||
delete s;
|
||||
// TODO: delete print
|
||||
std::cout << "delete segment " << std::endl;
|
||||
delete s;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
int
|
||||
Insert(CSegmentBase c_segment,
|
||||
long int reserved_offset,
|
||||
signed long int size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps,
|
||||
void* raw_data,
|
||||
int sizeof_per_row,
|
||||
signed long int count) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
milvus::dog_segment::DogDataChunk dataChunk{};
|
||||
long int reserved_offset,
|
||||
signed long int size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps,
|
||||
void* raw_data,
|
||||
int sizeof_per_row,
|
||||
signed long int count) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
milvus::dog_segment::DogDataChunk dataChunk{};
|
||||
|
||||
dataChunk.raw_data = raw_data;
|
||||
dataChunk.sizeof_per_row = sizeof_per_row;
|
||||
dataChunk.count = count;
|
||||
dataChunk.raw_data = raw_data;
|
||||
dataChunk.sizeof_per_row = sizeof_per_row;
|
||||
dataChunk.count = count;
|
||||
|
||||
auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk);
|
||||
auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk);
|
||||
|
||||
// TODO: delete print
|
||||
// std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
|
||||
return res.code();
|
||||
// TODO: delete print
|
||||
// std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
|
||||
return res.code();
|
||||
}
|
||||
|
||||
|
||||
long int
|
||||
PreInsert(CSegmentBase c_segment, long int size) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
|
||||
// TODO: delete print
|
||||
// std::cout << "PreInsert segment " << std::endl;
|
||||
return segment->PreInsert(size);
|
||||
// TODO: delete print
|
||||
// std::cout << "PreInsert segment " << std::endl;
|
||||
return segment->PreInsert(size);
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
Delete(CSegmentBase c_segment,
|
||||
long int reserved_offset,
|
||||
long size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
long int reserved_offset,
|
||||
long size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
|
||||
auto res = segment->Delete(reserved_offset, size, primary_keys, timestamps);
|
||||
return res.code();
|
||||
auto res = segment->Delete(reserved_offset, size, primary_keys, timestamps);
|
||||
return res.code();
|
||||
}
|
||||
|
||||
|
||||
long int
|
||||
PreDelete(CSegmentBase c_segment, long int size) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
|
||||
// TODO: delete print
|
||||
// std::cout << "PreDelete segment " << std::endl;
|
||||
return segment->PreDelete(size);
|
||||
// TODO: delete print
|
||||
// std::cout << "PreDelete segment " << std::endl;
|
||||
return segment->PreDelete(size);
|
||||
}
|
||||
|
||||
|
||||
//int
|
||||
//Search(CSegmentBase c_segment,
|
||||
// int
|
||||
// Search(CSegmentBase c_segment,
|
||||
// const char* query_json,
|
||||
// unsigned long timestamp,
|
||||
// float* query_raw_data,
|
||||
|
@@ -125,41 +119,42 @@ PreDelete(CSegmentBase c_segment, long int size) {

int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
long int* result_ids,
float* result_distances) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
milvus::dog_segment::QueryResult query_result;
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
long int* result_ids,
float* result_distances) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
milvus::dog_segment::QueryResult query_result;

// construct QueryPtr
auto query_ptr = std::make_shared<milvus::query::Query>();
query_ptr->num_queries = c_query_info.num_queries;
query_ptr->topK = c_query_info.topK;
query_ptr->field_name = c_query_info.field_name;
// construct QueryPtr
auto query_ptr = std::make_shared<milvus::query::Query>();

query_ptr->query_raw_data.resize(num_of_query_raw_data);
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
query_ptr->num_queries = c_query_info.num_queries;
query_ptr->topK = c_query_info.topK;
query_ptr->field_name = c_query_info.field_name;

auto res = segment->Query(query_ptr, timestamp, query_result);
query_ptr->query_raw_data.resize(num_of_query_raw_data);
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));

// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.
memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int));
memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float));
auto res = segment->Query(query_ptr, timestamp, query_result);

return res.code();
// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.
memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int));
memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float));

return res.code();
}
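[Editor's note, not part of the diff] A caller-side sketch of the Search() C API above, assuming segment_c.h is included. The id/distance buffers must be allocated by the caller with num_queries * topK slots (the Go layer does this in practice); the field name and topK below are purely illustrative.

#include <vector>

int
SearchOneQuery(CSegmentBase segment, std::vector<float>& query_vec, unsigned long timestamp) {
    CQueryInfo info{};
    info.num_queries = 1;
    info.topK = 10;
    info.field_name = "fakevec";  // hypothetical vector field name

    std::vector<long int> ids(info.num_queries * info.topK);
    std::vector<float> distances(info.num_queries * info.topK);

    // query_vec holds one row-major query vector; its length is passed as the raw-data count
    return Search(segment, info, timestamp, query_vec.data(),
                  static_cast<int>(query_vec.size()), ids.data(), distances.data());
}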
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
int
|
||||
Close(CSegmentBase c_segment) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto status = segment->Close();
|
||||
return status.code();
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto status = segment->Close();
|
||||
return status.code();
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -171,34 +166,32 @@ BuildIndex(CCollection c_collection, CSegmentBase c_segment) {
|
|||
return status.code();
|
||||
}
|
||||
|
||||
|
||||
bool
|
||||
IsOpened(CSegmentBase c_segment) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto status = segment->get_state();
|
||||
return status == milvus::dog_segment::SegmentBase::SegmentState::Open;
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto status = segment->get_state();
|
||||
return status == milvus::dog_segment::SegmentBase::SegmentState::Open;
|
||||
}
|
||||
|
||||
long int
|
||||
GetMemoryUsageInBytes(CSegmentBase c_segment) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto mem_size = segment->GetMemoryUsageInBytes();
|
||||
return mem_size;
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto mem_size = segment->GetMemoryUsageInBytes();
|
||||
return mem_size;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
long int
|
||||
GetRowCount(CSegmentBase c_segment) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto row_count = segment->get_row_count();
|
||||
return row_count;
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto row_count = segment->get_row_count();
|
||||
return row_count;
|
||||
}
|
||||
|
||||
|
||||
long int
|
||||
GetDeletedCount(CSegmentBase c_segment) {
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto deleted_count = segment->get_deleted_count();
|
||||
return deleted_count;
|
||||
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
|
||||
auto deleted_count = segment->get_deleted_count();
|
||||
return deleted_count;
|
||||
}
|
||||
|
|
|
@ -23,29 +23,29 @@ DeleteSegment(CSegmentBase segment);
|
|||
|
||||
int
|
||||
Insert(CSegmentBase c_segment,
|
||||
long int reserved_offset,
|
||||
signed long int size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps,
|
||||
void* raw_data,
|
||||
int sizeof_per_row,
|
||||
signed long int count);
|
||||
long int reserved_offset,
|
||||
signed long int size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps,
|
||||
void* raw_data,
|
||||
int sizeof_per_row,
|
||||
signed long int count);
|
||||
|
||||
long int
|
||||
PreInsert(CSegmentBase c_segment, long int size);
|
||||
|
||||
int
|
||||
Delete(CSegmentBase c_segment,
|
||||
long int reserved_offset,
|
||||
long size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps);
|
||||
long int reserved_offset,
|
||||
long size,
|
||||
const long* primary_keys,
|
||||
const unsigned long* timestamps);
|
||||
|
||||
long int
|
||||
PreDelete(CSegmentBase c_segment, long int size);
|
||||
|
||||
//int
|
||||
//Search(CSegmentBase c_segment,
|
||||
// int
|
||||
// Search(CSegmentBase c_segment,
|
||||
// const char* query_json,
|
||||
// unsigned long timestamp,
|
||||
// float* query_raw_data,
|
||||
|
@ -55,7 +55,7 @@ PreDelete(CSegmentBase c_segment, long int size);
|
|||
|
||||
int
|
||||
Search(CSegmentBase c_segment,
|
||||
CQueryInfo c_query_info,
|
||||
CQueryInfo c_query_info,
|
||||
unsigned long timestamp,
|
||||
float* query_raw_data,
|
||||
int num_of_query_raw_data,
|
||||
|
|
|
@ -52,7 +52,18 @@ include(BuildUtilsCore)
|
|||
|
||||
using_ccache_if_defined( KNOWHERE_USE_CCACHE )
|
||||
|
||||
message(STATUS "Building Knowhere CPU version")
|
||||
if (MILVUS_GPU_VERSION)
|
||||
message(STATUS "Building Knowhere GPU version")
|
||||
add_compile_definitions("MILVUS_GPU_VERSION")
|
||||
enable_language(CUDA)
|
||||
find_package(CUDA 10 REQUIRED)
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -std=c++11 -D_FORCE_INLINES --expt-extended-lambda")
|
||||
if ( CCACHE_FOUND )
|
||||
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_FOUND}")
|
||||
endif()
|
||||
else ()
|
||||
message(STATUS "Building Knowhere CPU version")
|
||||
endif ()
|
||||
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
message(STATUS "Building Knowhere with SPTAG supported")
|
||||
|
@ -63,8 +74,14 @@ include(ThirdPartyPackagesCore)
|
|||
|
||||
if (CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp")
|
||||
if (MILVUS_GPU_VERSION)
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3")
|
||||
endif ()
|
||||
else ()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp")
|
||||
if (MILVUS_GPU_VERSION)
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g")
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
add_subdirectory(knowhere)
|
||||
|
@ -75,10 +92,9 @@ endif ()
|
|||
|
||||
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
|
||||
#if (KNOWHERE_BUILD_TESTS)
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
|
||||
# add_subdirectory(unittest)
|
||||
#endif ()
|
||||
if (KNOWHERE_BUILD_TESTS)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
|
||||
add_subdirectory(unittest)
|
||||
endif ()
|
||||
|
||||
config_summary()
|
||||
|
||||
|
|
|
@ -13,14 +13,17 @@
|
|||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#endif
|
||||
#include <faiss/Clustering.h>
|
||||
#include <faiss/utils/distances.h>
|
||||
|
||||
#include "config/ServerConfig.h"
|
||||
#include "faiss/FaissHook.h"
|
||||
// #include "scheduler/Utils.h"
|
||||
#include "scheduler/Utils.h"
|
||||
#include "utils/ConfigUtils.h"
|
||||
#include "utils/Error.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
// #include <fiu/fiu-local.h>
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
@ -60,9 +63,38 @@ KnowhereResource::Initialize() {
|
|||
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
|
||||
}
|
||||
|
||||
// engine config
|
||||
int64_t omp_thread = config.engine.omp_thread_num();
|
||||
|
||||
if (omp_thread > 0) {
|
||||
omp_set_num_threads(omp_thread);
|
||||
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
|
||||
} else {
|
||||
int64_t sys_thread_cnt = 8;
|
||||
if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) {
|
||||
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
|
||||
omp_set_num_threads(omp_thread);
|
||||
}
|
||||
}
|
||||
|
||||
// init faiss global variable
|
||||
int64_t use_blas_threshold = config.engine.use_blas_threshold();
|
||||
faiss::distance_compute_blas_threshold = use_blas_threshold;
|
||||
|
||||
int64_t clustering_type = config.engine.clustering_type();
|
||||
switch (clustering_type) {
|
||||
case ClusteringType::K_MEANS:
|
||||
default:
|
||||
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
|
||||
break;
|
||||
case ClusteringType::K_MEANS_PLUS_PLUS:
|
||||
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
|
||||
break;
|
||||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
bool enable_gpu = config.gpu.enable();
|
||||
// fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
|
||||
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
|
||||
if (!enable_gpu) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ define_option_string(KNOWHERE_DEPENDENCY_SOURCE
|
|||
"BUNDLED"
|
||||
"SYSTEM")
|
||||
|
||||
define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" OFF)
|
||||
define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" ON)
|
||||
|
||||
define_option(KNOWHERE_VERBOSE_THIRDPARTY_BUILD
|
||||
"Show output from ExternalProjects rather than just logging to files" ON)
|
||||
|
@ -82,7 +82,7 @@ define_option(KNOWHERE_WITH_OPENBLAS "Build with OpenBLAS library" ON)
|
|||
|
||||
define_option(KNOWHERE_WITH_FAISS "Build with FAISS library" ON)
|
||||
|
||||
define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" OFF)
|
||||
define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" ON)
|
||||
|
||||
define_option(FAISS_WITH_MKL "Build FAISS with MKL" OFF)
|
||||
|
||||
|
|
|
@ -32,8 +32,7 @@ macro(build_dependency DEPENDENCY_NAME)
|
|||
if ("${DEPENDENCY_NAME}" STREQUAL "Arrow")
|
||||
build_arrow()
|
||||
elseif ("${DEPENDENCY_NAME}" STREQUAL "GTest")
|
||||
# build_gtest()
|
||||
# find_package(GTest REQUIRED)
|
||||
find_package(GTest REQUIRED)
|
||||
elseif ("${DEPENDENCY_NAME}" STREQUAL "OpenBLAS")
|
||||
build_openblas()
|
||||
elseif ("${DEPENDENCY_NAME}" STREQUAL "FAISS")
|
||||
|
@ -216,12 +215,12 @@ else ()
|
|||
)
|
||||
endif ()
|
||||
|
||||
if (DEFINED ENV{KNOWHERE_GTEST_URL})
|
||||
set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}")
|
||||
else ()
|
||||
set(GTEST_SOURCE_URL
|
||||
"https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz")
|
||||
endif ()
|
||||
# if (DEFINED ENV{KNOWHERE_GTEST_URL})
|
||||
# set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}")
|
||||
# else ()
|
||||
# set(GTEST_SOURCE_URL
|
||||
# "https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz")
|
||||
# endif ()
|
||||
|
||||
if (DEFINED ENV{KNOWHERE_OPENBLAS_URL})
|
||||
set(OPENBLAS_SOURCE_URL "$ENV{KNOWHERE_OPENBLAS_URL}")
|
||||
|
@ -387,77 +386,77 @@ endif()
|
|||
# ----------------------------------------------------------------------
|
||||
# Google gtest
|
||||
|
||||
#macro(build_gtest)
|
||||
# message(STATUS "Building gtest-${GTEST_VERSION} from source")
|
||||
# set(GTEST_VENDORED TRUE)
|
||||
# set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}")
|
||||
#
|
||||
# if (APPLE)
|
||||
# set(GTEST_CMAKE_CXX_FLAGS
|
||||
# ${GTEST_CMAKE_CXX_FLAGS}
|
||||
# -DGTEST_USE_OWN_TR1_TUPLE=1
|
||||
# -Wno-unused-value
|
||||
# -Wno-ignored-attributes)
|
||||
# endif ()
|
||||
#
|
||||
# set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep")
|
||||
# set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include")
|
||||
# set(GTEST_STATIC_LIB
|
||||
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
# set(GTEST_MAIN_STATIC_LIB
|
||||
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
#
|
||||
# set(GTEST_CMAKE_ARGS
|
||||
# ${EP_COMMON_CMAKE_ARGS}
|
||||
# "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}"
|
||||
# "-DCMAKE_INSTALL_LIBDIR=lib"
|
||||
# -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS}
|
||||
# -DCMAKE_BUILD_TYPE=Release)
|
||||
#
|
||||
# set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include")
|
||||
# set(GMOCK_STATIC_LIB
|
||||
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}"
|
||||
# )
|
||||
#
|
||||
# ExternalProject_Add(googletest_ep
|
||||
# URL
|
||||
# ${GTEST_SOURCE_URL}
|
||||
# BUILD_COMMAND
|
||||
# ${MAKE}
|
||||
# ${MAKE_BUILD_ARGS}
|
||||
# BUILD_BYPRODUCTS
|
||||
# ${GTEST_STATIC_LIB}
|
||||
# ${GTEST_MAIN_STATIC_LIB}
|
||||
# ${GMOCK_STATIC_LIB}
|
||||
# CMAKE_ARGS
|
||||
# ${GTEST_CMAKE_ARGS}
|
||||
# ${EP_LOG_OPTIONS})
|
||||
#
|
||||
# # The include directory must exist before it is referenced by a target.
|
||||
# file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_library(gtest STATIC IMPORTED)
|
||||
# set_target_properties(gtest
|
||||
# PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}"
|
||||
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_library(gtest_main STATIC IMPORTED)
|
||||
# set_target_properties(gtest_main
|
||||
# PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}"
|
||||
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_library(gmock STATIC IMPORTED)
|
||||
# set_target_properties(gmock
|
||||
# PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}"
|
||||
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_dependencies(gtest googletest_ep)
|
||||
# add_dependencies(gtest_main googletest_ep)
|
||||
# add_dependencies(gmock googletest_ep)
|
||||
#
|
||||
#endmacro()
|
||||
# macro(build_gtest)
|
||||
# message(STATUS "Building gtest-${GTEST_VERSION} from source")
|
||||
# set(GTEST_VENDORED TRUE)
|
||||
# set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}")
|
||||
#
|
||||
# if (APPLE)
|
||||
# set(GTEST_CMAKE_CXX_FLAGS
|
||||
# ${GTEST_CMAKE_CXX_FLAGS}
|
||||
# -DGTEST_USE_OWN_TR1_TUPLE=1
|
||||
# -Wno-unused-value
|
||||
# -Wno-ignored-attributes)
|
||||
# endif ()
|
||||
#
|
||||
# set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep")
|
||||
# set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include")
|
||||
# set(GTEST_STATIC_LIB
|
||||
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
# set(GTEST_MAIN_STATIC_LIB
|
||||
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
#
|
||||
# set(GTEST_CMAKE_ARGS
|
||||
# ${EP_COMMON_CMAKE_ARGS}
|
||||
# "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}"
|
||||
# "-DCMAKE_INSTALL_LIBDIR=lib"
|
||||
# -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS}
|
||||
# -DCMAKE_BUILD_TYPE=Release)
|
||||
#
|
||||
# set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include")
|
||||
# set(GMOCK_STATIC_LIB
|
||||
# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}"
|
||||
# )
|
||||
#
|
||||
# ExternalProject_Add(googletest_ep
|
||||
# URL
|
||||
# ${GTEST_SOURCE_URL}
|
||||
# BUILD_COMMAND
|
||||
# ${MAKE}
|
||||
# ${MAKE_BUILD_ARGS}
|
||||
# BUILD_BYPRODUCTS
|
||||
# ${GTEST_STATIC_LIB}
|
||||
# ${GTEST_MAIN_STATIC_LIB}
|
||||
# ${GMOCK_STATIC_LIB}
|
||||
# CMAKE_ARGS
|
||||
# ${GTEST_CMAKE_ARGS}
|
||||
# ${EP_LOG_OPTIONS})
|
||||
#
|
||||
# # The include directory must exist before it is referenced by a target.
|
||||
# file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_library(gtest STATIC IMPORTED)
|
||||
# set_target_properties(gtest
|
||||
# PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}"
|
||||
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_library(gtest_main STATIC IMPORTED)
|
||||
# set_target_properties(gtest_main
|
||||
# PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}"
|
||||
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_library(gmock STATIC IMPORTED)
|
||||
# set_target_properties(gmock
|
||||
# PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}"
|
||||
# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
|
||||
#
|
||||
# add_dependencies(gtest googletest_ep)
|
||||
# add_dependencies(gtest_main googletest_ep)
|
||||
# add_dependencies(gmock googletest_ep)
|
||||
#
|
||||
# endmacro()
|
||||
|
||||
# if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep)
|
||||
## if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep)
|
||||
#if ( NOT TARGET gtest AND KNOWHERE_BUILD_TESTS )
|
||||
# resolve_dependency(GTest)
|
||||
#
|
||||
|
@ -654,3 +653,5 @@ if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep)
|
|||
include_directories(SYSTEM "${FAISS_INCLUDE_DIR}")
|
||||
link_directories(SYSTEM ${FAISS_PREFIX}/lib/)
|
||||
endif ()
|
||||
|
||||
add_subdirectory(thirdparty/NGT)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
include_directories(${INDEX_SOURCE_DIR}/knowhere)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib)
|
||||
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
|
||||
|
@ -68,6 +69,9 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/IndexRHNSWFlat.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWSQ.cpp
|
||||
knowhere/index/vector_index/IndexRHNSWPQ.cpp
|
||||
knowhere/index/vector_index/IndexNGT.cpp
|
||||
knowhere/index/vector_index/IndexNGTPANNG.cpp
|
||||
knowhere/index/vector_index/IndexNGTONNG.cpp
|
||||
)
|
||||
|
||||
set(vector_offset_index_srcs
|
||||
|
@ -90,6 +94,8 @@ set(depend_libs
|
|||
gomp
|
||||
gfortran
|
||||
pthread
|
||||
fiu
|
||||
ngt
|
||||
)
|
||||
|
||||
if (MILVUS_SUPPORT_SPTAG)
|
||||
|
@ -100,6 +106,32 @@ if (MILVUS_SUPPORT_SPTAG)
|
|||
endif ()
|
||||
|
||||
|
||||
if (MILVUS_GPU_VERSION)
|
||||
include_directories(${CUDA_INCLUDE_DIRS})
|
||||
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
|
||||
set(cuda_lib
|
||||
cudart
|
||||
cublas
|
||||
)
|
||||
set(depend_libs ${depend_libs}
|
||||
${cuda_lib}
|
||||
)
|
||||
|
||||
set(vector_index_srcs ${vector_index_srcs}
|
||||
knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIVF.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp
|
||||
knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp
|
||||
knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp
|
||||
knowhere/index/vector_index/helpers/Cloner.cpp
|
||||
knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp
|
||||
)
|
||||
|
||||
set(vector_offset_index_srcs ${vector_offset_index_srcs}
|
||||
knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (NOT TARGET knowhere)
|
||||
add_library(
|
||||
knowhere STATIC
|
||||
|
@ -130,11 +162,3 @@ if (MILVUS_SUPPORT_SPTAG)
|
|||
endif ()
|
||||
|
||||
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
|
||||
# **************************** Get&Print Include Directories ****************************
|
||||
get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES )
|
||||
|
||||
foreach ( dir ${dirs} )
|
||||
message( STATUS "Knowhere Current Include DIRS: " ${dir} )
|
||||
endforeach ()
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
|
|||
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
|
||||
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
|
||||
const char* INDEX_ANNOY = "ANNOY";
|
||||
const char* INDEX_NGTPANNG = "NGT_PANNG";
|
||||
const char* INDEX_NGTONNG = "NGT_ONNG";
|
||||
} // namespace IndexEnum
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -64,6 +64,8 @@ extern const char* INDEX_RHNSWFlat;
|
|||
extern const char* INDEX_RHNSWPQ;
|
||||
extern const char* INDEX_RHNSWSQ;
|
||||
extern const char* INDEX_ANNOY;
|
||||
extern const char* INDEX_NGTPANNG;
|
||||
extern const char* INDEX_NGTONNG;
|
||||
} // namespace IndexEnum
|
||||
|
||||
enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 };
|
||||
|
|
|
@@ -25,13 +25,20 @@ namespace milvus {
namespace knowhere {

static const int64_t MIN_NLIST = 1;
static const int64_t MAX_NLIST = 1LL << 20;
static const int64_t MAX_NLIST = 65536;
static const int64_t MIN_NPROBE = 1;
static const int64_t MAX_NPROBE = MAX_NLIST;
static const int64_t DEFAULT_MIN_DIM = 1;
static const int64_t DEFAULT_MAX_DIM = 32768;
static const int64_t DEFAULT_MIN_ROWS = 1;  // minimum size for build index
static const int64_t DEFAULT_MAX_ROWS = 50000000;
static const int64_t NGT_MIN_EDGE_SIZE = 1;
static const int64_t NGT_MAX_EDGE_SIZE = 200;
static const int64_t HNSW_MIN_EFCONSTRUCTION = 8;
static const int64_t HNSW_MAX_EFCONSTRUCTION = 512;
static const int64_t HNSW_MIN_M = 4;
static const int64_t HNSW_MAX_M = 64;
static const int64_t HNSW_MAX_EF = 32768;
static const std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};

#define CheckIntByRange(key, min, max) \

@@ -146,24 +153,34 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
// auto tune params
oricfg[knowhere::IndexParams::nlist] =
MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());

auto m = oricfg[knowhere::IndexParams::m].get<int64_t>();
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);

std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);

CheckIntByValues(knowhere::IndexParams::m, resset);
/*std::vector<int64_t> resset;
IVFPQConfAdapter::GetValidCPUM(dimension, resset);*/
IndexMode ivfpq_mode = mode;
return GetValidM(dimension, m, ivfpq_mode);
}

bool
IVFPQConfAdapter::GetValidM(int64_t dimension, int64_t m, IndexMode& mode) {
#ifdef MILVUS_GPU_VERSION
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::GetValidGPUM(dimension, m)) {
mode = knowhere::IndexMode::MODE_CPU;
}
#endif
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::GetValidCPUM(dimension, m)) {
return false;
}
return true;
}

void
IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset) {
resset.clear();
bool
IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
/*
* Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with

@@ -172,7 +189,14 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
static const std::vector<int64_t> support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
static const std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};

for (const auto& dimperquantizer : support_dim_per_subquantizer) {
int64_t sub_dim = dimension / m;
return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) !=
support_subquantizer.end()) &&
(std::find(std::begin(support_dim_per_subquantizer), std::end(support_dim_per_subquantizer), sub_dim) !=
support_dim_per_subquantizer.end());

/*resset.clear();
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
if (!(dimension % dimperquantizer)) {
auto subquantzier_num = dimension / dimperquantizer;
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);

@@ -180,7 +204,12 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
resset.push_back(subquantzier_num);
}
}
}
}*/
}

bool
IVFPQConfAdapter::GetValidCPUM(int64_t dimension, int64_t m) {
return (dimension % m == 0);
}

bool
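[Editor's note, not part of the diff] How the m/dimension rules above combine, as a small standalone check: on CPU it is enough that m divides the dimension, while the GPU path additionally requires m and dim/m to come from faiss's supported lists. This only mirrors GetValidGPUM/GetValidCPUM as a sketch; the example numbers are illustrative.

#include <cstdint>
#include <iostream>

bool
IsValidIvfPqM(int64_t dim, int64_t m, bool gpu) {
    if (m <= 0 || dim % m != 0) {
        return false;  // CPU rule, and a precondition for the GPU rule below
    }
    if (!gpu) {
        return true;
    }
    // GPU rule: both m and dim/m must appear in the supported lists used by the adapter
    static const int64_t kSubQuantizers[] = {96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
    static const int64_t kDimsPerSubQuantizer[] = {32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
    bool m_ok = false;
    bool sub_dim_ok = false;
    for (auto v : kSubQuantizers) m_ok = m_ok || (v == m);
    for (auto v : kDimsPerSubQuantizer) sub_dim_ok = sub_dim_ok || (v == dim / m);
    return m_ok && sub_dim_ok;
}

int
main() {
    std::cout << IsValidIvfPqM(128, 4, /*gpu=*/true) << "\n";  // 1: 4 sub-quantizers of 32 dims each
    std::cout << IsValidIvfPqM(100, 3, /*gpu=*/true) << "\n";  // 0: 100 is not divisible by 3
}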
@@ -222,97 +251,68 @@ NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod

bool
HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;

CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);

return ConfAdapter::CheckTrain(oricfg, mode);
}

bool
HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;

CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);

return ConfAdapter::CheckSearch(oricfg, type, mode);
}
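[Editor's note, not part of the diff] For reference, a knowhere::Config that satisfies the HNSW checks above: efConstruction within [HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION], M within [HNSW_MIN_M, HNSW_MAX_M], a row count within the DEFAULT_MIN/MAX_ROWS range, and a search ef of at least topk. The concrete numbers are only an example, and the knowhere headers defining Config, meta and IndexParams are assumed to be included.

#include <algorithm>
#include <cstdint>

// Example train-time parameters that should pass HNSWConfAdapter::CheckTrain.
auto
MakeHnswTrainConf(int64_t dim, int64_t rows) {
    return milvus::knowhere::Config{
        {milvus::knowhere::meta::DIM, dim},
        {milvus::knowhere::meta::ROWS, rows},
        {milvus::knowhere::IndexParams::efConstruction, 200},  // 8 <= 200 <= 512
        {milvus::knowhere::IndexParams::M, 16},                // 4 <= 16 <= 64
        {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
    };
}

// Example search-time parameters that should pass HNSWConfAdapter::CheckSearch (ef >= topk).
auto
MakeHnswSearchConf(int64_t topk) {
    return milvus::knowhere::Config{
        {milvus::knowhere::meta::TOPK, topk},
        {milvus::knowhere::IndexParams::ef, std::max<int64_t>(topk, 64)},
    };
}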
|
||||
|
||||
bool
|
||||
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 8;
|
||||
static int64_t MAX_EFCONSTRUCTION = 512;
|
||||
static int64_t MIN_M = 4;
|
||||
static int64_t MAX_M = 64;
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 8;
|
||||
static int64_t MAX_EFCONSTRUCTION = 512;
|
||||
static int64_t MIN_M = 4;
|
||||
static int64_t MAX_M = 64;
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
std::vector<int64_t> resset;
|
||||
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
|
||||
IVFPQConfAdapter::GetValidMList(dimension, resset);
|
||||
|
||||
CheckIntByValues(knowhere::IndexParams::PQM, resset);
|
||||
IVFPQConfAdapter::GetValidCPUM(dimension, oricfg[knowhere::IndexParams::PQM].get<int64_t>());
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 8;
|
||||
static int64_t MAX_EFCONSTRUCTION = 512;
|
||||
static int64_t MIN_M = 4;
|
||||
static int64_t MAX_M = 64;
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
@ -368,5 +368,39 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM
|
|||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
bool
|
||||
NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD};
|
||||
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
|
||||
return ConfAdapter::CheckSearch(oricfg, type, mode);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -51,8 +51,14 @@ class IVFPQConfAdapter : public IVFConfAdapter {
|
|||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
static void
|
||||
GetValidMList(int64_t dimension, std::vector<int64_t>& resset);
|
||||
static bool
|
||||
GetValidM(int64_t dimension, int64_t m, IndexMode& mode);
|
||||
|
||||
static bool
|
||||
GetValidGPUM(int64_t dimension, int64_t m);
|
||||
|
||||
static bool
|
||||
GetValidCPUM(int64_t dimension, int64_t m);
|
||||
};
|
||||
|
||||
class NSGConfAdapter : public IVFConfAdapter {
|
||||
|
@ -120,5 +126,24 @@ class RHNSWSQConfAdapter : public ConfAdapter {
|
|||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class NGTPANNGConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
class NGTONNGConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
bool
|
||||
CheckTrain(Config& oricfg, const IndexMode mode) override;
|
||||
|
||||
bool
|
||||
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
|
|
@ -42,7 +42,7 @@ AdapterMgr::RegisterAdapter() {
|
|||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter);
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(BinIVFConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
|
||||
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter);
|
||||
#ifdef MILVUS_SUPPORT_SPTAG
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_KDT_RNT, sptag_kdt_adapter);
|
||||
|
@ -53,6 +53,8 @@ AdapterMgr::RegisterAdapter() {
|
|||
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
|
||||
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
|
||||
REGISTER_CONF_ADAPTER(NGTPANNGConfAdapter, IndexEnum::INDEX_NGTPANNG, ngtpanng_adapter);
|
||||
REGISTER_CONF_ADAPTER(NGTONNGConfAdapter, IndexEnum::INDEX_NGTONNG, ngtonng_adapter);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/index_io.h>
|
||||
#include <fiu/fiu-local.h>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
|
@ -22,6 +23,7 @@ namespace knowhere {
|
|||
BinarySet
|
||||
FaissBaseIndex::SerializeImpl(const IndexType& type) {
|
||||
try {
|
||||
fiu_do_on("FaissBaseIndex.SerializeImpl.throw_exception", throw std::exception());
|
||||
faiss::Index* index = index_.get();
|
||||
|
||||
MemoryIOWriter writer;
|
||||
|
|
|
@ -105,7 +105,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -116,7 +116,6 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto all_num = rows * k;
|
||||
auto p_id = static_cast<int64_t*>(malloc(all_num * sizeof(int64_t)));
|
||||
auto p_dist = static_cast<float*>(malloc(all_num * sizeof(float)));
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
|
@ -125,7 +124,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
std::vector<float> distances;
|
||||
distances.reserve(k);
|
||||
index_->get_nns_by_vector(static_cast<const float*>(p_data) + i * dim, k, search_k, &result, &distances,
|
||||
blacklist);
|
||||
bitset);
|
||||
|
||||
int64_t result_num = result.size();
|
||||
auto local_p_id = p_id + k * i;
|
||||
|
|
|
@ -54,7 +54,7 @@ class IndexAnnoy : public VecIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@ -40,7 +40,7 @@ BinaryIDMAP::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
auto p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
|
||||
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
|
||||
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
@ -141,14 +141,19 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
|
|||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
BinaryIDMAP::QueryImpl(int64_t n,
|
||||
const uint8_t* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
// assign the metric type
|
||||
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
|
||||
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
||||
auto i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
bin_flat_index->search(n, data, k, i_distances, labels, bitset_);
|
||||
bin_flat_index->search(n, data, k, i_distances, labels, bitset);
|
||||
|
||||
// if hamming, it need transform int32 to float
|
||||
if (bin_flat_index->metric_type == faiss::METRIC_Hamming) {
|
||||
|
|
|
@@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;

int64_t
Count() override;

@@ -69,7 +69,13 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {

protected:
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
QueryImpl(int64_t n,
const uint8_t* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);

protected:
std::mutex mutex_;

@ -43,7 +43,7 @@ BinaryIVF::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
auto p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
|
||||
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
|
||||
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
|
||||
|
@ -126,15 +126,20 @@ BinaryIVF::GenParams(const Config& config) {
|
|||
}
|
||||
|
||||
void
|
||||
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
BinaryIVF::QueryImpl(int64_t n,
|
||||
const uint8_t* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
|
||||
stdclock::time_point before = stdclock::now();
|
||||
auto i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
index_->search(n, data, k, i_distances, labels, bitset_);
|
||||
index_->search(n, data, k, i_distances, labels, bitset);
|
||||
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
|
|
|
@ -60,7 +60,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -76,7 +76,13 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
|||
GenParams(const Config& config);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
|
||||
QueryImpl(int64_t n,
|
||||
const uint8_t* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
|
|
@ -136,7 +136,7 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -153,7 +153,6 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
using P = std::pair<float, int64_t>;
|
||||
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
|
||||
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
std::vector<P> ret;
|
||||
|
@ -166,7 +165,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
// } else {
|
||||
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
|
||||
// }
|
||||
ret = index_->searchKnn(single_query, k, compare, blacklist);
|
||||
ret = index_->searchKnn(single_query, k, compare, bitset);
|
||||
|
||||
while (ret.size() < k) {
|
||||
ret.emplace_back(std::make_pair(-1, -1));
|
||||
|
|
|
@@ -46,7 +46,7 @@ class IndexHNSW : public VecIndex {
}

DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;

int64_t
Count() override;

@ -95,7 +95,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
auto p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
|
||||
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
|
||||
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
@ -223,11 +223,17 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
#endif
|
||||
|
||||
void
|
||||
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
IDMAP::QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
// assign the metric type
|
||||
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
|
||||
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, data, k, distances, labels, bitset_);
|
||||
index_->search(n, data, k, distances, labels, bitset);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -46,7 +46,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -80,7 +80,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
|
||||
protected:
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include <faiss/gpu/GpuCloner.h>
|
||||
#endif
|
||||
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -95,7 +97,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -103,6 +105,8 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
GET_TENSOR_DATA(dataset_ptr)
|
||||
|
||||
try {
|
||||
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
|
||||
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
|
||||
auto k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
|
||||
|
@ -111,7 +115,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
|
||||
auto p_dist = static_cast<float*>(malloc(p_dist_size));
|
||||
|
||||
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
|
||||
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
|
||||
|
||||
// std::stringstream ss_res_id, ss_res_dist;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
|
@ -292,7 +296,7 @@ IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config
|
|||
res.resize(K * b_size);
|
||||
|
||||
const float* xq = data + batch_size * dim * i;
|
||||
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
|
||||
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
|
||||
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto& node = graph[batch_size * i + j];
|
||||
|
@ -314,17 +318,23 @@ IVF::GenParams(const Config& config) {
|
|||
}
|
||||
|
||||
void
|
||||
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
IVF::QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
|
||||
stdclock::time_point before = stdclock::now();
|
||||
if (params->nprobe > 1 && n <= 4) {
|
||||
ivf_index->parallel_mode = 1;
|
||||
} else {
|
||||
ivf_index->parallel_mode = 0;
|
||||
}
|
||||
ivf_index->search(n, data, k, distances, labels, bitset_);
|
||||
ivf_index->search(n, data, k, distances, labels, bitset);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
|
||||
|
|
|
@ -51,7 +51,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
|
@ -86,7 +86,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
|||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
|
||||
|
||||
void
|
||||
SealImpl() override;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include "knowhere/index/vector_index/ConfAdapter.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
|
||||
#endif
|
||||
|
@ -47,6 +48,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
VecIndexPtr
|
||||
IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
auto ivfpq_index = dynamic_cast<faiss::IndexIVFPQ*>(index_.get());
|
||||
int64_t dim = ivfpq_index->d;
|
||||
int64_t m = ivfpq_index->pq.M;
|
||||
if (!IVFPQConfAdapter::GetValidGPUM(dim, m)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGT.h"
|
||||
|
||||
#include <omp.h>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
IndexNGT::Serialize(const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
std::stringstream obj, grp, prf, tre;
|
||||
index_->saveIndex(obj, grp, prf, tre);
|
||||
|
||||
auto obj_str = obj.str();
|
||||
auto grp_str = grp.str();
|
||||
auto prf_str = prf.str();
|
||||
auto tre_str = tre.str();
|
||||
uint64_t obj_size = obj_str.size();
|
||||
uint64_t grp_size = grp_str.size();
|
||||
uint64_t prf_size = prf_str.size();
|
||||
uint64_t tre_size = tre_str.size();
|
||||
|
||||
std::shared_ptr<uint8_t[]> obj_data(new uint8_t[obj_size]);
|
||||
memcpy(obj_data.get(), obj_str.data(), obj_size);
|
||||
std::shared_ptr<uint8_t[]> grp_data(new uint8_t[grp_size]);
|
||||
memcpy(grp_data.get(), grp_str.data(), grp_size);
|
||||
std::shared_ptr<uint8_t[]> prf_data(new uint8_t[prf_size]);
|
||||
memcpy(prf_data.get(), prf_str.data(), prf_size);
|
||||
std::shared_ptr<uint8_t[]> tre_data(new uint8_t[tre_size]);
|
||||
memcpy(tre_data.get(), tre_str.data(), tre_size);
|
||||
|
||||
BinarySet res_set;
|
||||
res_set.Append("ngt_obj_data", obj_data, obj_size);
|
||||
res_set.Append("ngt_grp_data", grp_data, grp_size);
|
||||
res_set.Append("ngt_prf_data", prf_data, prf_size);
|
||||
res_set.Append("ngt_tre_data", tre_data, tre_size);
|
||||
return res_set;
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::Load(const BinarySet& index_binary) {
|
||||
auto obj_data = index_binary.GetByName("ngt_obj_data");
|
||||
std::string obj_str(reinterpret_cast<char*>(obj_data->data.get()), obj_data->size);
|
||||
|
||||
auto grp_data = index_binary.GetByName("ngt_grp_data");
|
||||
std::string grp_str(reinterpret_cast<char*>(grp_data->data.get()), grp_data->size);
|
||||
|
||||
auto prf_data = index_binary.GetByName("ngt_prf_data");
|
||||
std::string prf_str(reinterpret_cast<char*>(prf_data->data.get()), prf_data->size);
|
||||
|
||||
auto tre_data = index_binary.GetByName("ngt_tre_data");
|
||||
std::string tre_str(reinterpret_cast<char*>(tre_data->data.get()), tre_data->size);
|
||||
|
||||
std::stringstream obj(obj_str);
|
||||
std::stringstream grp(grp_str);
|
||||
std::stringstream prf(prf_str);
|
||||
std::stringstream tre(tre_str);
|
||||
|
||||
index_ = std::shared_ptr<NGT::Index>(NGT::Index::loadIndex(obj, grp, prf, tre));
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("IndexNGT has no implementation of BuildAll, please use IndexNGT(PANNG/ONNG) instead!");
|
||||
}
|
||||
|
||||
#if 0
|
||||
void
|
||||
IndexNGT::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("IndexNGT has no implementation of Train, please use IndexNGT(PANNG/ONNG) instead!");
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr);
|
||||
|
||||
NGT::Property prop;
|
||||
prop.setDefaultForCreateIndex();
|
||||
prop.dimension = dim;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2)
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
else if (metric_type == Metric::HAMMING)
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
else if (metric_type == Metric::JACCARD)
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
else
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
index_ =
|
||||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
}
|
||||
|
||||
void
|
||||
IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr);
|
||||
|
||||
index_->append(reinterpret_cast<const float*>(p_data), rows);
|
||||
}
|
||||
#endif
|
||||
|
||||
DatasetPtr
|
||||
IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GET_TENSOR_DATA(dataset_ptr);
|
||||
|
||||
size_t k = config[meta::TOPK].get<int64_t>();
|
||||
size_t id_size = sizeof(int64_t) * k;
|
||||
size_t dist_size = sizeof(float) * k;
|
||||
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
|
||||
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
|
||||
|
||||
NGT::Command::SearchParameter sp;
|
||||
sp.size = k;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
const float* single_query = reinterpret_cast<float*>(const_cast<void*>(p_data)) + i * Dim();
|
||||
|
||||
NGT::Object* object = index_->allocateObject(single_query, Dim());
|
||||
NGT::SearchContainer sc(*object);
|
||||
|
||||
double epsilon = sp.beginOfEpsilon;
|
||||
|
||||
NGT::ObjectDistances res;
|
||||
sc.setResults(&res);
|
||||
sc.setSize(sp.size);
|
||||
sc.setRadius(sp.radius);
|
||||
|
||||
if (sp.accuracy > 0.0) {
|
||||
sc.setExpectedAccuracy(sp.accuracy);
|
||||
} else {
|
||||
sc.setEpsilon(epsilon);
|
||||
}
|
||||
sc.setEdgeSize(sp.edgeSize);
|
||||
|
||||
try {
|
||||
index_->search(sc, bitset);
|
||||
} catch (NGT::Exception& err) {
|
||||
KNOWHERE_THROW_MSG("Query failed");
|
||||
}
|
||||
|
||||
auto local_id = p_id + i * k;
|
||||
auto local_dist = p_dist + i * k;
|
||||
|
||||
int64_t res_num = res.size();
|
||||
for (int64_t idx = 0; idx < res_num; ++idx) {
|
||||
*(local_id + idx) = res[idx].id - 1;
|
||||
*(local_dist + idx) = res[idx].distance;
|
||||
}
|
||||
while (res_num < static_cast<int64_t>(k)) {
|
||||
*(local_id + res_num) = -1;
|
||||
*(local_dist + res_num) = 1.0 / 0.0;
|
||||
}
|
||||
index_->deleteObject(object);
|
||||
}
|
||||
|
||||
auto res_ds = std::make_shared<Dataset>();
|
||||
res_ds->Set(meta::IDS, p_id);
|
||||
res_ds->Set(meta::DISTANCE, p_dist);
|
||||
return res_ds;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexNGT::Count() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
return index_->getNumberOfVectors();
|
||||
}
|
||||
|
||||
int64_t
|
||||
IndexNGT::Dim() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
return index_->getDimension();
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@@ -0,0 +1,70 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#pragma once

#include <NGT/lib/NGT/Command.h>
#include <NGT/lib/NGT/Common.h>
#include <NGT/lib/NGT/Index.h>

#include <knowhere/common/Exception.h>
#include <knowhere/index/IndexType.h>
#include <knowhere/index/vector_index/VecIndex.h>
#include <memory>

namespace milvus {
namespace knowhere {

class IndexNGT : public VecIndex {
 public:
    IndexNGT() {
        index_type_ = IndexEnum::INVALID;
    }

    BinarySet
    Serialize(const Config& config) override;

    void
    Load(const BinarySet& index_binary) override;

    void
    BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;

    void
    Train(const DatasetPtr& dataset_ptr, const Config& config) override {
        KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
    }

    void
    Add(const DatasetPtr& dataset_ptr, const Config& config) override {
        KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface.");
    }

    void
    AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) override {
        KNOWHERE_THROW_MSG("Incremental index is not supported");
    }

    DatasetPtr
    Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;

    int64_t
    Count() override;

    int64_t
    Dim() override;

 protected:
    std::shared_ptr<NGT::Index> index_ = nullptr;
};

}  // namespace knowhere
}  // namespace milvus

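A serialize/load round trip may make the interface above easier to follow. The sketch below is not part of this patch; it only assumes that the four BinarySet keys written by IndexNGT::Serialize() in this diff (ngt_obj_data, ngt_grp_data, ngt_prf_data, ngt_tre_data) are the complete persisted state, and that built_index, query_dataset and conf stand for objects prepared as in the surrounding code. IndexNGTPANNG is the concrete subclass added later in this patch.

// Hypothetical round trip: persist an NGT-based index and restore it into a fresh object.
milvus::knowhere::BinarySet blob = built_index->Serialize(milvus::knowhere::Config());
auto restored = std::make_shared<milvus::knowhere::IndexNGTPANNG>();
restored->Load(blob);  // reconstructs NGT::Index from the obj/grp/prf/tre streams
// restored->Count() and restored->Dim() should now match built_index.
auto result = restored->Query(query_dataset, conf, nullptr);
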
@ -0,0 +1,71 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGTONNG.h"
|
||||
|
||||
#include "NGT/lib/NGT/GraphOptimizer.h"
|
||||
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
void
|
||||
IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr);
|
||||
|
||||
NGT::Property prop;
|
||||
prop.setDefaultForCreateIndex();
|
||||
prop.dimension = dim;
|
||||
|
||||
auto edge_size = config[IndexParams::edge_size].get<int64_t>();
|
||||
prop.edgeSizeForCreation = edge_size;
|
||||
prop.insertionRadiusCoefficient = 1.0;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
} else if (metric_type == Metric::HAMMING) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
} else if (metric_type == Metric::JACCARD) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
}
|
||||
|
||||
index_ =
|
||||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
|
||||
// reconstruct graph
|
||||
NGT::GraphOptimizer graphOptimizer(true);
|
||||
|
||||
auto number_of_outgoing_edges = config[IndexParams::outgoing_edge_size].get<size_t>();
|
||||
auto number_of_incoming_edges = config[IndexParams::incoming_edge_size].get<size_t>();
|
||||
|
||||
graphOptimizer.shortcutReduction = true;
|
||||
graphOptimizer.searchParameterOptimization = false;
|
||||
graphOptimizer.prefetchParameterOptimization = false;
|
||||
graphOptimizer.accuracyTableGeneration = false;
|
||||
graphOptimizer.margin = 0.2;
|
||||
graphOptimizer.gtEpsilon = 0.1;
|
||||
|
||||
graphOptimizer.set(number_of_outgoing_edges, number_of_incoming_edges, 1000, 20);
|
||||
|
||||
graphOptimizer.execute(*index_);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -7,27 +7,24 @@
|
|||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/index/vector_index/IndexNGT.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
struct Quantizer {
|
||||
virtual ~Quantizer() = default;
|
||||
class IndexNGTONNG : public IndexNGT {
|
||||
public:
|
||||
IndexNGTONNG() {
|
||||
index_type_ = IndexEnum::INDEX_NGTONNG;
|
||||
}
|
||||
|
||||
int64_t size = -1;
|
||||
void
|
||||
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
};
|
||||
using QuantizerPtr = std::shared_ptr<Quantizer>;
|
||||
|
||||
// struct QuantizerCfg : Cfg {
|
||||
// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
|
||||
// };
|
||||
// using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
|
||||
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
void
|
||||
IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
GET_TENSOR_DATA_DIM(dataset_ptr);
|
||||
|
||||
NGT::Property prop;
|
||||
prop.setDefaultForCreateIndex();
|
||||
prop.dimension = dim;
|
||||
|
||||
auto edge_size = config[IndexParams::edge_size].get<int64_t>();
|
||||
prop.edgeSizeLimitForCreation = edge_size;
|
||||
|
||||
MetricType metric_type = config[Metric::TYPE];
|
||||
|
||||
if (metric_type == Metric::L2) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
} else if (metric_type == Metric::HAMMING) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
} else if (metric_type == Metric::JACCARD) {
|
||||
prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
|
||||
}
|
||||
|
||||
index_ =
|
||||
std::shared_ptr<NGT::Index>(NGT::Index::createGraphAndTree(reinterpret_cast<const float*>(p_data), prop, rows));
|
||||
|
||||
auto forcedly_pruned_edge_size = config[IndexParams::forcedly_pruned_edge_size].get<int64_t>();
|
||||
auto selectively_pruned_edge_size = config[IndexParams::selectively_pruned_edge_size].get<int64_t>();
|
||||
|
||||
if (!forcedly_pruned_edge_size && !selectively_pruned_edge_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (forcedly_pruned_edge_size && selectively_pruned_edge_size &&
|
||||
selectively_pruned_edge_size >= forcedly_pruned_edge_size) {
|
||||
KNOWHERE_THROW_MSG("Selectively pruned edge size should less than remaining edge size");
|
||||
}
|
||||
|
||||
// prune
|
||||
auto& graph = dynamic_cast<NGT::GraphIndex&>(index_->getIndex());
|
||||
for (size_t id = 1; id < graph.repository.size(); id++) {
|
||||
try {
|
||||
NGT::GraphNode& node = *graph.getNode(id);
|
||||
if (node.size() >= forcedly_pruned_edge_size) {
|
||||
node.resize(forcedly_pruned_edge_size);
|
||||
}
|
||||
if (node.size() >= selectively_pruned_edge_size) {
|
||||
size_t rank = 0;
|
||||
for (auto i = node.begin(); i != node.end(); ++rank) {
|
||||
if (rank >= selectively_pruned_edge_size) {
|
||||
bool found = false;
|
||||
for (size_t t1 = 0; t1 < node.size() && found == false; ++t1) {
|
||||
if (t1 >= selectively_pruned_edge_size) {
|
||||
break;
|
||||
}
|
||||
if (rank == t1) {
|
||||
continue;
|
||||
}
|
||||
NGT::GraphNode& node2 = *graph.getNode(node[t1].id);
|
||||
for (size_t t2 = 0; t2 < node2.size(); ++t2) {
|
||||
if (t2 >= selectively_pruned_edge_size) {
|
||||
break;
|
||||
}
|
||||
if (node2[t2].id == (*i).id) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
} // for
|
||||
} // for
|
||||
if (found) {
|
||||
// remove
|
||||
i = node.erase(i);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
} // for
|
||||
}
|
||||
} catch (NGT::Exception& err) {
|
||||
std::cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@@ -0,0 +1,30 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#pragma once

#include "knowhere/index/vector_index/IndexNGT.h"

namespace milvus {
namespace knowhere {

class IndexNGTPANNG : public IndexNGT {
 public:
    IndexNGTPANNG() {
        index_type_ = IndexEnum::INDEX_NGTPANNG;
    }

    void
    BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
};

}  // namespace knowhere
}  // namespace milvus

@ -9,6 +9,7 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -37,6 +38,7 @@ NSG::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
try {
|
||||
fiu_do_on("NSG.Serialize.throw_exception", throw std::exception());
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
impl::NsgIndex* index = index_.get();
|
||||
|
||||
|
@ -55,6 +57,7 @@ NSG::Serialize(const Config& config) {
|
|||
void
|
||||
NSG::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
fiu_do_on("NSG.Load.throw_exception", throw std::exception());
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
auto binary = index_binary.GetByName("NSG");
|
||||
|
||||
|
@ -70,7 +73,7 @@ NSG::Load(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -84,15 +87,13 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
impl::SearchParams s_params;
|
||||
s_params.search_length = config[IndexParams::search_length];
|
||||
s_params.k = config[meta::TOPK];
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id,
|
||||
s_params, blacklist);
|
||||
s_params, bitset);
|
||||
}
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
|
|
|
@@ -59,7 +59,7 @@ class NSG : public VecIndex {
}

DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;

int64_t
Count() override;

@ -79,7 +79,7 @@ IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -96,10 +96,9 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
}
|
||||
|
||||
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
real_index->hnsw.efSearch = (config[IndexParams::ef]);
|
||||
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, blacklist);
|
||||
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, bitset);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
|
|
@ -52,7 +52,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
|
|||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
|
|
@@ -176,7 +176,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
}

DatasetPtr
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
SetParameters(config);

float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);

@@ -52,7 +52,7 @@ class CPUSPTAGRNG : public VecIndex {
}

DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;

int64_t
Count() override;

@@ -46,7 +46,7 @@ class VecIndex : public Index {
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;

virtual DatasetPtr
Query(const DatasetPtr& dataset, const Config& config) = 0;
Query(const DatasetPtr& dataset, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) = 0;

#if 0
virtual DatasetPtr

@@ -144,9 +144,11 @@ class VecIndex : public Index {
protected:
IndexType index_type_ = "";
IndexMode index_mode_ = IndexMode::MODE_CPU;
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
std::vector<IDType> uids_;
int64_t index_size_ = -1;

private:
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
};

using VecIndexPtr = std::shared_ptr<VecIndex>;

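The signature change above moves per-query filtering out of the index-owned bitset_ member and into the call itself. Below is a minimal caller-side sketch, not part of this patch; the Dataset keys meta::ROWS and meta::DIM, the faiss::ConcurrentBitset constructor/set() interface, and the header paths are assumptions, while meta::TENSOR, meta::TOPK, Metric::TYPE and Dataset::Set() appear elsewhere in this diff.

// Hypothetical caller: mark "deleted" row ids in a bitset and pass it per query.
#include <faiss/utils/ConcurrentBitset.h>
#include "knowhere/index/vector_index/VecIndex.h"

milvus::knowhere::DatasetPtr
SearchWithBlacklist(milvus::knowhere::VecIndex& index, const float* xq, int64_t nq, int64_t dim,
                    const std::vector<int64_t>& deleted_ids) {
    // Build the query dataset; the ROWS/DIM key names are assumed here.
    auto dataset = std::make_shared<milvus::knowhere::Dataset>();
    dataset->Set(milvus::knowhere::meta::ROWS, nq);
    dataset->Set(milvus::knowhere::meta::DIM, dim);
    dataset->Set(milvus::knowhere::meta::TENSOR, static_cast<const void*>(xq));

    milvus::knowhere::Config conf{{milvus::knowhere::meta::TOPK, 10},
                                  {milvus::knowhere::Metric::TYPE, "L2"}};

    // One bit per stored vector; a set bit means "skip this id during search".
    auto bitset = std::make_shared<faiss::ConcurrentBitset>(index.Count());
    for (auto id : deleted_ids) {
        bitset->set(id);
    }

    // The bitset now travels with the call instead of living in VecIndex::bitset_.
    return index.Query(dataset, conf, bitset);
}
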
@@ -21,6 +21,8 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexNGTONNG.h"
#include "knowhere/index/vector_index/IndexNGTPANNG.h"
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
#include "knowhere/index/vector_index/IndexRHNSWPQ.h"
#include "knowhere/index/vector_index/IndexRHNSWSQ.h"

@@ -99,6 +101,10 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) {
return std::make_shared<knowhere::IndexRHNSWPQ>();
} else if (type == IndexEnum::INDEX_RHNSWSQ) {
return std::make_shared<knowhere::IndexRHNSWSQ>();
} else if (type == IndexEnum::INDEX_NGTPANNG) {
return std::make_shared<knowhere::IndexNGTPANNG>();
} else if (type == IndexEnum::INDEX_NGTONNG) {
return std::make_shared<knowhere::IndexNGTONNG>();
} else {
return nullptr;
}

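A short usage sketch, not part of this patch, showing the two new NGT entries being reachable through the factory. The singleton accessor VecIndexFactory::GetInstance() is assumed from existing knowhere usage; CreateVecIndex(), IndexEnum::INDEX_NGTPANNG and IndexMode::MODE_CPU come from this diff.

// Hypothetical helper: build an NGT-PANNG index in CPU mode and run one bitset-free query.
milvus::knowhere::DatasetPtr
BuildAndSearchPanng(const milvus::knowhere::DatasetPtr& base_dataset,
                    const milvus::knowhere::DatasetPtr& query_dataset,
                    const milvus::knowhere::Config& conf) {
    auto index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(
        milvus::knowhere::IndexEnum::INDEX_NGTPANNG, milvus::knowhere::IndexMode::MODE_CPU);
    if (index == nullptr) {
        KNOWHERE_THROW_MSG("index type not registered in this build");
    }
    index->BuildAll(base_dataset, conf);                 // conf carries edge_size and the pruning sizes
    return index->Query(query_dataset, conf, nullptr);   // nullptr bitset: nothing is filtered out
}
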
@ -16,6 +16,7 @@
|
|||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#endif
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -43,6 +44,7 @@ GPUIDMAP::CopyGpuToCpu(const Config& config) {
|
|||
BinarySet
|
||||
GPUIDMAP::SerializeImpl(const IndexType& type) {
|
||||
try {
|
||||
fiu_do_on("GPUIDMP.SerializeImpl.throw_exception", throw std::exception());
|
||||
MemoryIOWriter writer;
|
||||
{
|
||||
faiss::Index* index = index_.get();
|
||||
|
@ -102,13 +104,19 @@ GPUIDMAP::GetRawIds() {
|
|||
}
|
||||
|
||||
void
|
||||
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
GPUIDMAP::QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
|
||||
// assign the metric type
|
||||
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
|
||||
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, data, k, distances, labels, bitset_);
|
||||
index_->search(n, data, k, distances, labels, bitset);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -132,7 +140,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const C
|
|||
res.resize(K * b_size);
|
||||
|
||||
const float* xq = data + batch_size * dim * i;
|
||||
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
|
||||
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
|
||||
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto& node = graph[batch_size * i + j];
|
||||
|
|
|
@ -55,7 +55,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
};
|
||||
|
||||
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
|
||||
|
|
|
@ -9,12 +9,14 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
#include <faiss/gpu/GpuIndexIVFFlat.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -91,6 +93,7 @@ GPUIVF::SerializeImpl(const IndexType& type) {
|
|||
}
|
||||
|
||||
try {
|
||||
fiu_do_on("GPUIVF.SerializeImpl.throw_exception", throw std::exception());
|
||||
MemoryIOWriter writer;
|
||||
{
|
||||
faiss::Index* index = index_.get();
|
||||
|
@ -134,12 +137,19 @@ GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
|||
}
|
||||
|
||||
void
|
||||
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
GPUIVF::QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
|
||||
if (device_index) {
|
||||
device_index->nprobe = config[IndexParams::nprobe];
|
||||
device_index->nprobe = std::min(static_cast<int>(config[IndexParams::nprobe]), device_index->nlist);
|
||||
ResScope rs(res_, gpu_id_);
|
||||
|
||||
// if query size > 2048 we search by blocks to avoid malloc issue
|
||||
|
@ -148,7 +158,7 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
|
|||
for (int64_t i = 0; i < n; i += block_size) {
|
||||
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
|
||||
device_index->search(search_size, reinterpret_cast<const float*>(data) + i * dim, k, distances + i * k,
|
||||
labels + i * k, bitset_);
|
||||
labels + i * k, bitset);
|
||||
}
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");
|
||||
|
|
|
@ -51,7 +51,8 @@ class GPUIVF : public IVF, public GPUIndex {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
};
|
||||
|
||||
using GPUIVFPtr = std::shared_ptr<GPUIVF>;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
#include <faiss/index_factory.h>
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
|
@ -93,7 +94,7 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t device_id, const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
std::pair<VecIndexPtr, QuantizerPtr>
|
||||
std::pair<VecIndexPtr, FaissIVFQuantizerPtr>
|
||||
IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
|
@ -122,7 +123,7 @@ IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& co
|
|||
}
|
||||
|
||||
VecIndexPtr
|
||||
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config& config) {
|
||||
IVFSQHybrid::LoadData(const FaissIVFQuantizerPtr& quantizer_ptr, const Config& config) {
|
||||
int64_t gpu_id = config[knowhere::meta::DEVICEID];
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
|
@ -150,7 +151,7 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config&
|
|||
}
|
||||
}
|
||||
|
||||
QuantizerPtr
|
||||
FaissIVFQuantizerPtr
|
||||
IVFSQHybrid::LoadQuantizer(const Config& config) {
|
||||
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
|
||||
|
||||
|
@ -173,8 +174,6 @@ IVFSQHybrid::LoadQuantizer(const Config& config) {
|
|||
q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float);
|
||||
q->quantizer = q_ptr;
|
||||
q->gpu_id = gpu_id;
|
||||
res_ = res;
|
||||
gpu_mode_ = 1;
|
||||
return q;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
|
@ -182,20 +181,17 @@ IVFSQHybrid::LoadQuantizer(const Config& config) {
|
|||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::SetQuantizer(const QuantizerPtr& quantizer_ptr) {
|
||||
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(quantizer_ptr);
|
||||
if (ivf_quantizer == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Quantizer type error");
|
||||
IVFSQHybrid::SetQuantizer(const FaissIVFQuantizerPtr& quantizer_ptr) {
|
||||
faiss::IndexIVF* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
if (ivf_index == nullptr) {
|
||||
KNOWHERE_THROW_MSG("Index type error");
|
||||
}
|
||||
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
// Once SetQuantizer() is called, make sure UnsetQuantizer() is also called before destructuring.
|
||||
// Otherwise, ivf_index->quantizer will be double free.
|
||||
|
||||
auto is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer);
|
||||
if (is_gpu_flat_index == nullptr) {
|
||||
// delete ivf_index->quantizer;
|
||||
ivf_index->quantizer = ivf_quantizer->quantizer;
|
||||
}
|
||||
quantizer_gpu_id_ = ivf_quantizer->gpu_id;
|
||||
quantizer_ = quantizer_ptr;
|
||||
ivf_index->quantizer = quantizer_->quantizer;
|
||||
gpu_mode_ = 1;
|
||||
}
|
||||
|
||||
|
@ -206,8 +202,10 @@ IVFSQHybrid::UnsetQuantizer() {
|
|||
KNOWHERE_THROW_MSG("Index type error");
|
||||
}
|
||||
|
||||
ivf_index->quantizer = nullptr;
|
||||
quantizer_gpu_id_ = -1;
|
||||
// set back to cpu mode
|
||||
ivf_index->restore_quantizer();
|
||||
quantizer_ = nullptr;
|
||||
gpu_mode_ = 0;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
|
@ -216,6 +214,7 @@ IVFSQHybrid::SerializeImpl(const IndexType& type) {
|
|||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
fiu_do_on("IVFSQHybrid.SerializeImpl.zero_gpu_mode", gpu_mode_ = 0);
|
||||
if (gpu_mode_ == 0) {
|
||||
MemoryIOWriter writer;
|
||||
faiss::write_index(index_.get(), &writer);
|
||||
|
@ -242,20 +241,26 @@ IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
|||
}
|
||||
|
||||
void
|
||||
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
IVFSQHybrid::QueryImpl(int64_t n,
|
||||
const float* data,
|
||||
int64_t k,
|
||||
float* distances,
|
||||
int64_t* labels,
|
||||
const Config& config,
|
||||
const faiss::ConcurrentBitsetPtr& bitset) {
|
||||
if (gpu_mode_ == 2) {
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset);
|
||||
// index_->search(n, (float*)data, k, distances, labels);
|
||||
} else if (gpu_mode_ == 1) { // hybrid
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(quantizer_gpu_id_)) {
|
||||
ResScope rs(res, quantizer_gpu_id_, true);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
auto gpu_id = quantizer_->gpu_id;
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, true);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config, bitset);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(quantizer_gpu_id_) + "resource");
|
||||
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
}
|
||||
} else if (gpu_mode_ == 0) {
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config, bitset);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -278,7 +283,6 @@ FaissIVFQuantizer::~FaissIVFQuantizer() {
|
|||
delete quantizer;
|
||||
quantizer = nullptr;
|
||||
}
|
||||
// else do nothing
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -18,18 +18,18 @@
|
|||
#include <utility>
|
||||
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/gpu/Quantizer.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
struct FaissIVFQuantizer : public Quantizer {
|
||||
struct FaissIVFQuantizer {
|
||||
faiss::gpu::GpuIndexFlat* quantizer = nullptr;
|
||||
int64_t gpu_id;
|
||||
int64_t size = -1;
|
||||
|
||||
~FaissIVFQuantizer() override;
|
||||
~FaissIVFQuantizer();
|
||||
};
|
||||
using FaissIVFQuantizerPtr = std::shared_ptr<FaissIVFQuantizer>;
|
||||
|
||||
|
@ -62,17 +62,17 @@ class IVFSQHybrid : public GPUIVFSQ {
|
|||
VecIndexPtr
|
||||
CopyCpuToGpu(const int64_t, const Config&) override;
|
||||
|
||||
std::pair<VecIndexPtr, QuantizerPtr>
|
||||
std::pair<VecIndexPtr, FaissIVFQuantizerPtr>
|
||||
CopyCpuToGpuWithQuantizer(const int64_t, const Config&);
|
||||
|
||||
VecIndexPtr
|
||||
LoadData(const knowhere::QuantizerPtr&, const Config&);
|
||||
LoadData(const FaissIVFQuantizerPtr&, const Config&);
|
||||
|
||||
QuantizerPtr
|
||||
FaissIVFQuantizerPtr
|
||||
LoadQuantizer(const Config& conf);
|
||||
|
||||
void
|
||||
SetQuantizer(const QuantizerPtr& q);
|
||||
SetQuantizer(const FaissIVFQuantizerPtr& q);
|
||||
|
||||
void
|
||||
UnsetQuantizer();
|
||||
|
@ -88,11 +88,12 @@ class IVFSQHybrid : public GPUIVFSQ {
|
|||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
|
||||
override;
|
||||
|
||||
protected:
|
||||
int64_t gpu_mode_ = 0; // 0,1,2
|
||||
int64_t quantizer_gpu_id_ = -1;
|
||||
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU
|
||||
FaissIVFQuantizerPtr quantizer_ = nullptr;
|
||||
};
|
||||
|
||||
using IVFSQHybridPtr = std::shared_ptr<IVFSQHybrid>;
|
||||
|
|
|
@@ -65,8 +65,9 @@ CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& co
} else {
KNOWHERE_THROW_MSG("this index type not support transfer to gpu");
}

CopyIndexData(result, index);
if (result != nullptr) {
CopyIndexData(result, index);
}
return result;
}

@ -12,6 +12,7 @@
|
|||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
|
||||
#include <fiu/fiu-local.h>
|
||||
#include <utility>
|
||||
|
||||
namespace milvus {
|
||||
|
@ -82,6 +83,7 @@ FaissGpuResourceMgr::InitResource() {
|
|||
|
||||
ResPtr
|
||||
FaissGpuResourceMgr::GetRes(const int64_t device_id, const int64_t alloc_size) {
|
||||
fiu_return_on("FaissGpuResourceMgr.GetRes.ret_null", nullptr);
|
||||
InitResource();
|
||||
|
||||
auto finder = idle_map_.find(device_id);
|
||||
|
|
|
@@ -51,6 +51,15 @@ constexpr const char* search_k = "search_k";

// PQ Params
constexpr const char* PQM = "PQM";

// NGT Params
constexpr const char* edge_size = "edge_size";
// NGT_PANNG Params
constexpr const char* forcedly_pruned_edge_size = "forcedly_pruned_edge_size";
constexpr const char* selectively_pruned_edge_size = "selectively_pruned_edge_size";
// NGT_ONNG Params
constexpr const char* outgoing_edge_size = "outgoing_edge_size";
constexpr const char* incoming_edge_size = "incoming_edge_size";
}  // namespace IndexParams

namespace Metric {

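The new keys map onto the NGT build options read by IndexNGTPANNG/IndexNGTONNG elsewhere in this patch. A configuration sketch, assuming the usual JSON-style knowhere::Config and the pre-existing meta::TOPK / Metric::TYPE keys:

// Hypothetical PANNG config: 100 edges at creation time, pruned to 40 forced / 30 selective edges.
// IndexNGTPANNG::BuildAll() requires selectively_pruned_edge_size < forcedly_pruned_edge_size.
milvus::knowhere::Config ngt_conf{
    {milvus::knowhere::meta::TOPK, 10},
    {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
    {milvus::knowhere::IndexParams::edge_size, 100},
    {milvus::knowhere::IndexParams::forcedly_pruned_edge_size, 40},
    {milvus::knowhere::IndexParams::selectively_pruned_edge_size, 30},
};
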
@ -124,7 +124,10 @@ NsgIndex::InitNavigationPoint(float* data) {
|
|||
|
||||
// Specify Link
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset,
|
||||
NsgIndex::GetNeighbors(const float* query,
|
||||
float* data,
|
||||
std::vector<Neighbor>& resset,
|
||||
std::vector<Neighbor>& fullset,
|
||||
boost::dynamic_bitset<>& has_calculated_dist) {
|
||||
auto& graph = knng;
|
||||
size_t buffer_size = search_length;
|
||||
|
@ -331,8 +334,8 @@ NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& r
|
|||
}
|
||||
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph,
|
||||
SearchParams* params) {
|
||||
NsgIndex::GetNeighbors(
|
||||
const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph, SearchParams* params) {
|
||||
size_t buffer_size = params ? params->search_length : search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
|
@ -482,7 +485,10 @@ NsgIndex::Link(float* data) {
|
|||
}
|
||||
|
||||
void
|
||||
NsgIndex::SyncPrune(float* data, size_t n, std::vector<Neighbor>& pool, boost::dynamic_bitset<>& has_calculated,
|
||||
NsgIndex::SyncPrune(float* data,
|
||||
size_t n,
|
||||
std::vector<Neighbor>& pool,
|
||||
boost::dynamic_bitset<>& has_calculated,
|
||||
float* cut_graph_dist) {
|
||||
// avoid lose nearest neighbor in knng
|
||||
for (size_t i = 0; i < knng[n].size(); ++i) {
|
||||
|
@ -597,8 +603,8 @@ NsgIndex::InterInsert(float* data, unsigned n, std::vector<std::mutex>& mutex_ve
|
|||
}
|
||||
|
||||
void
|
||||
NsgIndex::SelectEdge(float* data, unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result,
|
||||
bool limit) {
|
||||
NsgIndex::SelectEdge(
|
||||
float* data, unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result, bool limit) {
|
||||
auto& pool = sort_pool;
|
||||
|
||||
/*
|
||||
|
@ -850,8 +856,15 @@ NsgIndex::FindUnconnectedNode(float* data, boost::dynamic_bitset<>& has_linked,
|
|||
// }
|
||||
|
||||
void
|
||||
NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k,
|
||||
float* dist, int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset) {
|
||||
NsgIndex::Search(const float* query,
|
||||
float* data,
|
||||
const unsigned& nq,
|
||||
const unsigned& dim,
|
||||
const unsigned& k,
|
||||
float* dist,
|
||||
int64_t* ids,
|
||||
SearchParams& params,
|
||||
faiss::ConcurrentBitsetPtr bitset) {
|
||||
std::vector<std::vector<Neighbor>> resset(nq);
|
||||
|
||||
TimeRecorder rc("NsgIndex::search", 1);
|
||||
|
|
|
@@ -83,8 +83,15 @@ class NsgIndex {
Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters);

void
Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset = nullptr);
Search(const float* query,
float* data,
const unsigned& nq,
const unsigned& dim,
const unsigned& k,
float* dist,
int64_t* ids,
SearchParams& params,
faiss::ConcurrentBitsetPtr bitset = nullptr);

int64_t
GetSize();

@@ -108,7 +115,10 @@ class NsgIndex {

// link specify
void
GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset,
GetNeighbors(const float* query,
float* data,
std::vector<Neighbor>& resset,
std::vector<Neighbor>& fullset,
boost::dynamic_bitset<>& has_calculated_dist);

// FindUnconnectedNode

@@ -117,8 +127,8 @@ class NsgIndex {

// navigation-point
void
GetNeighbors(const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph,
SearchParams* param = nullptr);
GetNeighbors(
const float* query, float* data, std::vector<Neighbor>& resset, Graph& graph, SearchParams* param = nullptr);

// only for search
// void

@@ -128,11 +138,17 @@ class NsgIndex {
Link(float* data);

void
SyncPrune(float* data, size_t q, std::vector<Neighbor>& pool, boost::dynamic_bitset<>& has_calculated,
SyncPrune(float* data,
size_t q,
std::vector<Neighbor>& pool,
boost::dynamic_bitset<>& has_calculated,
float* cut_graph_dist);

void
SelectEdge(float* data, unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result,
SelectEdge(float* data,
unsigned& cursor,
std::vector<Neighbor>& sort_pool,
std::vector<Neighbor>& result,
bool limit = false);

void

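A hedged sketch of calling the reformatted Search declaration above: the caller owns the nq * k output buffers, and the trailing faiss::ConcurrentBitsetPtr defaults to nullptr, so callers that do not filter anything are unaffected. Every name below (index, queries, raw_data, and the parameter values) is a placeholder for illustration, not something taken from this diff.

// Sketch; assumes <vector> and <cstdint> plus the NsgIndex header are included.
void
RunNsgSearch(NsgIndex& index, const float* queries, float* raw_data, unsigned nq, unsigned dim, unsigned k) {
    std::vector<float> dist(size_t(nq) * k);
    std::vector<int64_t> ids(size_t(nq) * k);

    SearchParams params;
    params.search_length = 45;  // illustrative value only
    params.k = k;

    // nullptr (the default) searches everything; a ConcurrentBitsetPtr with bits
    // set for deleted rows would exclude those ids from the result.
    index.Search(queries, raw_data, nq, dim, k, dist.data(), ids.data(), params, nullptr);
}
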
@@ -23,6 +23,7 @@
#include <faiss/gpu/GpuCloner.h>
#endif

#include <fiu/fiu-local.h>
#include <chrono>
#include <memory>
#include <string>

@@ -66,13 +67,13 @@ IVF_NM::Load(const BinarySet& binary_set) {
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
auto invlists = ivf_index->invlists;
auto d = ivf_index->d;
size_t nb = binary->size / invlists->code_size;
auto arranged_data = new float[d * nb];
prefix_sum.resize(invlists->nlist);
size_t curr_index = 0;

#ifndef MILVUS_GPU_VERSION
auto ails = dynamic_cast<faiss::ArrayInvertedLists*>(invlists);
size_t nb = binary->size / invlists->code_size;
auto arranged_data = new float[d * nb];
for (size_t i = 0; i < invlists->nlist; i++) {
auto list_size = ails->ids[i].size();
for (size_t j = 0; j < list_size; j++) {

@@ -81,8 +82,10 @@ IVF_NM::Load(const BinarySet& binary_set) {
prefix_sum[i] = curr_index;
curr_index += list_size;
}
data_ = std::shared_ptr<uint8_t[]>(reinterpret_cast<uint8_t*>(arranged_data));
#else
auto rol = dynamic_cast<faiss::ReadOnlyArrayInvertedLists*>(invlists);
auto arranged_data = reinterpret_cast<float*>(rol->pin_readonly_codes->data);
auto lengths = rol->readonly_length;
auto rol_ids = reinterpret_cast<const int64_t*>(rol->pin_readonly_ids->data);
for (size_t i = 0; i < invlists->nlist; i++) {

@@ -94,8 +97,11 @@ IVF_NM::Load(const BinarySet& binary_set) {
prefix_sum[i] = curr_index;
curr_index += list_size;
}

/* hold codes shared pointer */
ro_codes = rol->pin_readonly_codes;
data_ = nullptr;
#endif
data_ = std::shared_ptr<uint8_t[]>(reinterpret_cast<uint8_t*>(arranged_data));
}

void

@@ -132,7 +138,7 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}

DatasetPtr
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

@@ -140,6 +146,8 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA(dataset_ptr)

try {
fiu_do_on("IVF_NM.Search.throw_std_exception", throw std::exception());
fiu_do_on("IVF_NM.Search.throw_faiss_exception", throw faiss::FaissException(""));
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;

@@ -148,7 +156,7 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));

QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);

@@ -236,8 +244,8 @@ IVF_NM::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index =
faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), data_.get());
auto gpu_index = faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(),
static_cast<const uint8_t*>(ro_codes->data));

std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);

@@ -275,7 +283,7 @@ IVF_NM::GenGraph(const float* data, const int64_t k, GraphType& graph, const Con
res.resize(K * b_size);

const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);

for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];

@@ -297,7 +305,13 @@ IVF_NM::GenParams(const Config& config) {
}

void
IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
IVF_NM::QueryImpl(int64_t n,
const float* query,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;

@@ -308,8 +322,15 @@ IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
ivf_index->parallel_mode = 0;
}
bool is_sq8 = (index_type_ == IndexEnum::INDEX_FAISS_IVFSQ8) ? true : false;
ivf_index->search_without_codes(n, reinterpret_cast<const float*>(data), data_.get(), prefix_sum, is_sq8, k,
distances, labels, bitset_);

#ifndef MILVUS_GPU_VERSION
auto data = static_cast<const uint8_t*>(data_.get());
#else
auto data = static_cast<const uint8_t*>(ro_codes->data);
#endif

ivf_index->search_without_codes(n, reinterpret_cast<const float*>(query), data, prefix_sum, is_sq8, k, distances,
labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost

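The prefix_sum vector built in Load above records, for every inverted list, the offset of its first vector inside the flat arranged-data buffer, so QueryImpl can hand faiss one contiguous code array plus per-list offsets. A small sketch of that bookkeeping under the CPU branch's assumptions (list sizes come from a plain std::vector here; the real code reads them from faiss::ArrayInvertedLists):

#include <cstddef>
#include <vector>

// Sketch: per-list starting offsets (counted in vectors, not bytes) for a flat
// buffer that stores all inverted lists back to back, as IVF_NM::Load does.
std::vector<size_t>
BuildPrefixSum(const std::vector<size_t>& list_sizes) {
    std::vector<size_t> prefix_sum(list_sizes.size());
    size_t curr_index = 0;
    for (size_t i = 0; i < list_sizes.size(); ++i) {
        prefix_sum[i] = curr_index;   // first vector of list i
        curr_index += list_sizes[i];  // advance past list i
    }
    return prefix_sum;
}

// With dimension d, the j-th vector of list i would then start at
// arranged_data + (prefix_sum[i] + j) * d -- an inference from the copy loop
// above, not a line taken from it.
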
@@ -51,7 +51,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;

#if 0
DatasetPtr

@@ -86,15 +86,21 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
GenParams(const Config&);

virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(
int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset);

void
SealImpl() override;

protected:
std::mutex mutex_;
std::shared_ptr<uint8_t[]> data_ = nullptr;
std::vector<size_t> prefix_sum;

// data_: if CPU, malloc memory while loading data
// ro_codes: if GPU, hold a ptr of read only codes so that
// destruction won't be done twice
std::shared_ptr<uint8_t[]> data_ = nullptr;
faiss::PageLockMemoryPtr ro_codes = nullptr;
};

using IVFNMPtr = std::shared_ptr<IVF_NM>;

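A hedged usage sketch for the new Query overload: the caller builds a faiss::ConcurrentBitset sized to the row count, sets the bits of rows it wants excluded, and passes the shared pointer through Query. The names ivf_nm_index, query_dataset, config, row_count and deleted_offsets are placeholders, and the ConcurrentBitset constructor and set() call are assumptions about the bundled faiss fork rather than lines from this diff.

// Hypothetical caller of IVF_NM::Query with deletion filtering.
auto bitset = std::make_shared<faiss::ConcurrentBitset>(row_count);  // assumed ctor
for (auto offset : deleted_offsets) {
    bitset->set(offset);  // assumed API: rows whose bit is set are skipped during search
}
auto filtered = ivf_nm_index->Query(query_dataset, config, bitset);

// The parameter has no default, so unfiltered callers now pass nullptr explicitly.
auto unfiltered = ivf_nm_index->Query(query_dataset, config, nullptr);
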
@@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

#include <fiu/fiu-local.h>
#include <string>

#include "knowhere/common/Exception.h"

@@ -36,6 +37,7 @@ NSG_NM::Serialize(const Config& config) {
}

try {
fiu_do_on("NSG_NM.Serialize.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
impl::NsgIndex* index = index_.get();

@@ -54,6 +56,7 @@ NSG_NM::Serialize(const Config& config) {
void
NSG_NM::Load(const BinarySet& index_binary) {
try {
fiu_do_on("NSG_NM.Load.throw_exception", throw std::exception());
std::lock_guard<std::mutex> lk(mutex_);
auto binary = index_binary.GetByName("NSG_NM");

@@ -71,7 +74,7 @@ NSG_NM::Load(const BinarySet& index_binary) {
}

DatasetPtr
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

@@ -86,8 +89,6 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));

faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();

impl::SearchParams s_params;
s_params.search_length = config[IndexParams::search_length];
s_params.k = config[meta::TOPK];

@@ -95,7 +96,7 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
// index_->ori_data_ = (float*) data_.get();
index_->Search(reinterpret_cast<const float*>(p_data), reinterpret_cast<float*>(data_.get()), rows, dim,
topK, p_dist, p_id, s_params, blacklist);
topK, p_dist, p_id, s_params, bitset);
}

auto ret_ds = std::make_shared<Dataset>();

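For completeness, a sketch of consuming the DatasetPtr that Query returns. The ids buffer is stored under meta::IDS as shown above; the distance key is assumed to be meta::DISTANCE by analogy, and the Get<>() accessor shape is an assumption about the knowhere Dataset API rather than something visible in this hunk. nsg_nm_index, query_dataset, config and topK are placeholders.

// Hypothetical consumption of the DatasetPtr returned by NSG_NM::Query.
auto result = nsg_nm_index->Query(query_dataset, config, nullptr);
auto ids = result->Get<int64_t*>(meta::IDS);           // row-major: topK entries per query
auto distances = result->Get<float*>(meta::DISTANCE);  // assumed key and accessor
for (int64_t i = 0; i < topK; ++i) {
    std::cout << ids[i] << " " << distances[i] << std::endl;
}
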
@@ -59,7 +59,7 @@ class NSG_NM : public VecIndex {
}

DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;

int64_t
Count() override;

@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#include <faiss/index_io.h>
#include <fiu/fiu-local.h>

#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"

@@ -22,6 +23,7 @@ namespace knowhere {
BinarySet
OffsetBaseIndex::SerializeImpl(const IndexType& type) {
try {
fiu_do_on("OffsetBaseIndex.SerializeImpl.throw_exception", throw std::exception());
faiss::Index* index = index_.get();

MemoryIOWriter writer;

@@ -15,6 +15,7 @@
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/index_io.h>
#include <fiu/fiu-local.h>
#include <string>

#include "knowhere/common/Exception.h"

@@ -97,6 +98,7 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) {
}

try {
fiu_do_on("GPUIVF_NM.SerializeImpl.throw_exception", throw std::exception());
MemoryIOWriter writer;
{
faiss::Index* index = index_.get();

@@ -116,10 +118,17 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) {
}

void
GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIVF_NM::QueryImpl(int64_t n,
const float* data,
int64_t k,
float* distances,
int64_t* labels,
const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
std::lock_guard<std::mutex> lk(mutex_);

auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
fiu_do_on("GPUIVF_NM.search_impl.invald_index", device_index = nullptr);
if (device_index) {
device_index->nprobe = config[IndexParams::nprobe];
ResScope rs(res_, gpu_id_);

@@ -129,7 +138,7 @@ GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances,
int64_t dim = device_index->d;
for (int64_t i = 0; i < n; i += block_size) {
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset_);
device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset);
}
} else {
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");

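QueryImpl above walks the queries in fixed-size blocks instead of handing all n of them to the GPU at once, which bounds the temporary device memory used per search call. A standalone sketch of that batching pattern; the search callable stands in for device_index->search, and block_size is whatever the surrounding code (not shown in this hunk) chose.

#include <cstdint>
#include <functional>

// Sketch of the block-wise query loop used in GPUIVF_NM::QueryImpl.
void
BlockedSearch(int64_t n, int64_t dim, int64_t k, int64_t block_size, const float* data, float* distances,
              int64_t* labels, const std::function<void(int64_t, const float*, float*, int64_t*)>& search) {
    for (int64_t i = 0; i < n; i += block_size) {
        // The last block may be shorter than block_size.
        int64_t search_size = (n - i > block_size) ? block_size : (n - i);
        // Queries advance by dim floats per row, results by k entries per row.
        search(search_size, data + i * dim, distances + i * k, labels + i * k);
    }
}
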
@@ -51,7 +51,8 @@ class GPUIVF_NM : public IVF, public GPUIndex {
SerializeImpl(const IndexType&) override;

void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset)
override;

protected:
uint8_t* arranged_data;

@ -0,0 +1,79 @@
|
|||
if(APPLE)
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
else()
|
||||
cmake_minimum_required(VERSION 2.8)
|
||||
endif()
|
||||
|
||||
project(ngt)
|
||||
|
||||
file(STRINGS "VERSION" ngt_VERSION)
|
||||
message(STATUS "VERSION: ${ngt_VERSION}")
|
||||
string(REGEX MATCH "^[0-9]+" ngt_VERSION_MAJOR ${ngt_VERSION})
|
||||
|
||||
set(ngt_VERSION ${ngt_VERSION})
|
||||
set(ngt_SOVERSION ${ngt_VERSION_MAJOR})
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE)
|
||||
set (CMAKE_BUILD_TYPE "Release")
|
||||
endif (NOT CMAKE_BUILD_TYPE)
|
||||
string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER)
|
||||
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS "CMAKE_BUILD_TYPE_LOWER: ${CMAKE_BUILD_TYPE_LOWER}")
|
||||
|
||||
if(${UNIX})
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
||||
|
||||
if(CMAKE_VERSION VERSION_LESS 3.1)
|
||||
set(BASE_OPTIONS "-Wall -std=gnu++0x -lrt")
|
||||
|
||||
if(${NGT_AVX_DISABLED})
|
||||
message(STATUS "AVX will not be used to compute distances.")
|
||||
endif()
|
||||
|
||||
if(${NGT_OPENMP_DISABLED})
|
||||
message(STATUS "OpenMP is disabled.")
|
||||
else()
|
||||
set(BASE_OPTIONS "${BASE_OPTIONS} -fopenmp")
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "-g ${BASE_OPTIONS}")
|
||||
|
||||
if(${NGT_MARCH_NATIVE_DISABLED})
|
||||
message(STATUS "Compile option -march=native is disabled.")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-O2 ${BASE_OPTIONS}")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native ${BASE_OPTIONS}")
|
||||
endif()
|
||||
else()
|
||||
if (CMAKE_BUILD_TYPE_LOWER STREQUAL "release")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "")
|
||||
if(${NGT_MARCH_NATIVE_DISABLED})
|
||||
message(STATUS "Compile option -march=native is disabled.")
|
||||
add_compile_options(-O2 -DNDEBUG)
|
||||
else()
|
||||
add_compile_options(-Ofast -march=native -DNDEBUG)
|
||||
endif()
|
||||
endif()
|
||||
add_compile_options(-Wall)
|
||||
if(${NGT_AVX_DISABLED})
|
||||
message(STATUS "AVX will not be used to compute distances.")
|
||||
endif()
|
||||
if(${NGT_OPENMP_DISABLED})
|
||||
message(STATUS "OpenMP is disabled.")
|
||||
else()
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "8.1.0")
|
||||
message(FATAL_ERROR "Insufficient AppleClang version")
|
||||
endif()
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
endif()
|
||||
find_package(OpenMP REQUIRED)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
set(CMAKE_CXX_STANDARD 11) # for std::unordered_set, std::unique_ptr
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
find_package(Threads REQUIRED)
|
||||
endif()
|
||||
|
||||
add_subdirectory("${PROJECT_SOURCE_DIR}/lib")
|
||||
endif( ${UNIX} )
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1 @@
|
|||
1.12.0
|
|
@ -0,0 +1,3 @@
|
|||
if( ${UNIX} )
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/lib/NGT)
|
||||
endif()
|
|
@ -0,0 +1,89 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include "ArrayFile.h"
|
||||
#include <iostream>
|
||||
#include <assert.h>
|
||||
|
||||
class ItemID {
|
||||
public:
|
||||
void serialize(std::ostream &os, NGT::ObjectSpace *ospace = 0) {
|
||||
os.write((char*)&value, sizeof(value));
|
||||
}
|
||||
void deserialize(std::istream &is, NGT::ObjectSpace *ospace = 0) {
|
||||
is.read((char*)&value, sizeof(value));
|
||||
}
|
||||
static size_t getSerializedDataSize() {
|
||||
return sizeof(uint64_t);
|
||||
}
|
||||
uint64_t value;
|
||||
};
|
||||
|
||||
void
|
||||
sampleForUsage() {
|
||||
{
|
||||
ArrayFile<ItemID> itemIDFile;
|
||||
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
|
||||
itemIDFile.open("test.data");
|
||||
ItemID itemID;
|
||||
size_t id;
|
||||
|
||||
id = 1;
|
||||
itemID.value = 4910002490100;
|
||||
itemIDFile.put(id, itemID);
|
||||
itemID.value = 0;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490100);
|
||||
|
||||
id = 2;
|
||||
itemID.value = 4910002490101;
|
||||
itemIDFile.put(id, itemID);
|
||||
itemID.value = 0;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490101);
|
||||
|
||||
itemID.value = 4910002490102;
|
||||
id = itemIDFile.insert(itemID);
|
||||
itemID.value = 0;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490102);
|
||||
|
||||
itemIDFile.close();
|
||||
}
|
||||
{
|
||||
ArrayFile<ItemID> itemIDFile;
|
||||
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
|
||||
itemIDFile.open("test.data");
|
||||
ItemID itemID;
|
||||
size_t id;
|
||||
|
||||
id = 10;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490100);
|
||||
|
||||
id = 20;
|
||||
itemIDFile.get(id, itemID);
|
||||
std::cerr << "value=" << itemID.value << std::endl;
|
||||
assert(itemID.value == 4910002490101);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <cstddef>
|
||||
#include <stdint.h>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
|
||||
namespace NGT {
|
||||
class ObjectSpace;
|
||||
};
|
||||
|
||||
template <class TYPE>
|
||||
class ArrayFile {
|
||||
private:
|
||||
struct FileHeadStruct {
|
||||
size_t recordSize;
|
||||
uint64_t extraData; // reserve
|
||||
};
|
||||
|
||||
struct RecordStruct {
|
||||
bool deleteFlag;
|
||||
uint64_t extraData; // reserve
|
||||
};
|
||||
|
||||
bool _isOpen;
|
||||
std::fstream _stream;
|
||||
FileHeadStruct _fileHead;
|
||||
|
||||
bool _readFileHead();
|
||||
pthread_mutex_t _mutex;
|
||||
|
||||
public:
|
||||
ArrayFile();
|
||||
~ArrayFile();
|
||||
bool create(const std::string &file, size_t recordSize);
|
||||
bool open(const std::string &file);
|
||||
void close();
|
||||
size_t insert(TYPE &data, NGT::ObjectSpace *objectSpace = 0);
|
||||
void put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
|
||||
bool get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
|
||||
void remove(const size_t id);
|
||||
bool isOpen() const;
|
||||
size_t size();
|
||||
size_t getRecordSize() { return _fileHead.recordSize; }
|
||||
};
|
||||
|
||||
|
||||
// constructor
|
||||
template <class TYPE>
|
||||
ArrayFile<TYPE>::ArrayFile()
|
||||
: _isOpen(false), _mutex((pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER){
|
||||
if(pthread_mutex_init(&_mutex, NULL) < 0) throw std::runtime_error("pthread init error.");
|
||||
}
|
||||
|
||||
// destructor
|
||||
template <class TYPE>
|
||||
ArrayFile<TYPE>::~ArrayFile() {
|
||||
pthread_mutex_destroy(&_mutex);
|
||||
close();
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::create(const std::string &file, size_t recordSize) {
|
||||
std::fstream tmpstream;
|
||||
tmpstream.open(file.c_str());
|
||||
if(tmpstream){
|
||||
return false;
|
||||
}
|
||||
|
||||
tmpstream.open(file.c_str(), std::ios::out);
|
||||
tmpstream.seekp(0, std::ios::beg);
|
||||
FileHeadStruct fileHead = {recordSize, 0};
|
||||
tmpstream.write((char *)(&fileHead), sizeof(FileHeadStruct));
|
||||
tmpstream.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::open(const std::string &file) {
|
||||
_stream.open(file.c_str(), std::ios::in | std::ios::out);
|
||||
if(!_stream){
|
||||
_isOpen = false;
|
||||
return false;
|
||||
}
|
||||
_isOpen = true;
|
||||
|
||||
bool ret = _readFileHead();
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
void ArrayFile<TYPE>::close(){
|
||||
_stream.close();
|
||||
_isOpen = false;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
size_t ArrayFile<TYPE>::insert(TYPE &data, NGT::ObjectSpace *objectSpace) {
|
||||
_stream.seekp(sizeof(RecordStruct), std::ios::end);
|
||||
int64_t write_pos = _stream.tellg();
|
||||
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
|
||||
_stream.seekp(write_pos, std::ios::beg);
|
||||
data.serialize(_stream, objectSpace);
|
||||
|
||||
int64_t offset_pos = _stream.tellg();
|
||||
offset_pos -= sizeof(FileHeadStruct);
|
||||
size_t id = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
|
||||
if(offset_pos % (sizeof(RecordStruct) + _fileHead.recordSize) == 0){
|
||||
id -= 1;
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
void ArrayFile<TYPE>::put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
|
||||
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
|
||||
offset_pos += sizeof(RecordStruct);
|
||||
_stream.seekp(offset_pos, std::ios::beg);
|
||||
|
||||
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
|
||||
_stream.seekp(offset_pos, std::ios::beg);
|
||||
data.serialize(_stream, objectSpace);
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
|
||||
pthread_mutex_lock(&_mutex);
|
||||
|
||||
if( size() <= id ){
|
||||
pthread_mutex_unlock(&_mutex);
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
|
||||
offset_pos += sizeof(RecordStruct);
|
||||
_stream.seekg(offset_pos, std::ios::beg);
|
||||
if (!_stream.fail()) {
|
||||
data.deserialize(_stream, objectSpace);
|
||||
}
|
||||
if (_stream.fail()) {
|
||||
const int trialCount = 10;
|
||||
for (int tc = 0; tc < trialCount; tc++) {
|
||||
_stream.clear();
|
||||
_stream.seekg(offset_pos, std::ios::beg);
|
||||
if (_stream.fail()) {
|
||||
continue;
|
||||
}
|
||||
data.deserialize(_stream, objectSpace);
|
||||
if (_stream.fail()) {
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (_stream.fail()) {
|
||||
throw std::runtime_error("ArrayFile::get: Error!");
|
||||
}
|
||||
}
|
||||
|
||||
pthread_mutex_unlock(&_mutex);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
void ArrayFile<TYPE>::remove(const size_t id) {
|
||||
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
|
||||
_stream.seekp(offset_pos, std::ios::beg);
|
||||
RecordStruct recordHead = {1, 0};
|
||||
_stream.write((char *)(&recordHead), sizeof(RecordStruct));
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::isOpen() const
|
||||
{
|
||||
return _isOpen;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
size_t ArrayFile<TYPE>::size()
|
||||
{
|
||||
_stream.seekp(0, std::ios::end);
|
||||
int64_t offset_pos = _stream.tellg();
|
||||
offset_pos -= sizeof(FileHeadStruct);
|
||||
size_t num = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
|
||||
|
||||
return num;
|
||||
}
|
||||
|
||||
template <class TYPE>
|
||||
bool ArrayFile<TYPE>::_readFileHead() {
|
||||
_stream.seekp(0, std::ios::beg);
|
||||
_stream.read((char *)(&_fileHead), sizeof(FileHeadStruct));
|
||||
if(_stream.bad()){
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
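The ArrayFile layout above is a fixed-size file header followed by fixed-size records, each prefixed by a RecordStruct; put() and get() both recompute a record's byte offset from its id. A hedged sketch of that offset arithmetic pulled out as a standalone helper; the struct shapes mirror the template above, and nothing here is part of the NGT sources.

#include <cstddef>
#include <cstdint>

// Mirrors ArrayFile<TYPE>'s on-disk layout: a FileHeadStruct, then for each id a
// RecordStruct (delete flag) immediately followed by recordSize payload bytes.
struct FileHeadStructSketch { size_t recordSize; uint64_t extraData; };
struct RecordStructSketch { bool deleteFlag; uint64_t extraData; };

// Byte offset of the payload for a given id, matching the math in put()/get() above.
inline uint64_t
PayloadOffset(size_t id, size_t recordSize) {
    uint64_t offset = id * (sizeof(RecordStructSketch) + recordSize) + sizeof(FileHeadStructSketch);
    return offset + sizeof(RecordStructSketch);  // skip the per-record header
}
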
@ -0,0 +1,40 @@
|
|||
if( ${UNIX} )
|
||||
option(NGT_SHARED_MEMORY_ALLOCATOR "enable shared memory" OFF)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/defines.h.in ${CMAKE_CURRENT_BINARY_DIR}/defines.h)
|
||||
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/lib" "${PROJECT_BINARY_DIR}/lib/")
|
||||
include_directories("${PROJECT_SOURCE_DIR}/../")
|
||||
|
||||
file(GLOB NGT_SOURCES *.cpp)
|
||||
file(GLOB HEADER_FILES *.h *.hpp)
|
||||
file(GLOB NGTQ_HEADER_FILES NGTQ/*.h NGTQ/*.hpp)
|
||||
|
||||
add_library(ngtstatic STATIC ${NGT_SOURCES})
|
||||
set_target_properties(ngtstatic PROPERTIES OUTPUT_NAME ngt)
|
||||
set_target_properties(ngtstatic PROPERTIES COMPILE_FLAGS "-fPIC")
|
||||
target_link_libraries(ngtstatic)
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
target_link_libraries(ngtstatic OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
add_library(ngt SHARED ${NGT_SOURCES})
|
||||
set_target_properties(ngt PROPERTIES VERSION ${ngt_VERSION})
|
||||
set_target_properties(ngt PROPERTIES SOVERSION ${ngt_SOVERSION})
|
||||
add_dependencies(ngt ngtstatic)
|
||||
if(${APPLE})
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
target_link_libraries(ngt OpenMP::OpenMP_CXX)
|
||||
else()
|
||||
target_link_libraries(ngt gomp)
|
||||
endif()
|
||||
else(${APPLE})
|
||||
target_link_libraries(ngt gomp rt)
|
||||
endif(${APPLE})
|
||||
|
||||
install(TARGETS
|
||||
ngt
|
||||
ngtstatic
|
||||
RUNTIME DESTINATION bin
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib)
|
||||
|
||||
endif()
|
|
@ -0,0 +1,988 @@
|
|||
//
|
||||
// Copyright (C) 2015-2020 Yahoo Japan Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "NGT/Index.h"
|
||||
#include "NGT/GraphOptimizer.h"
|
||||
#include "Capi.h"
|
||||
|
||||
static bool operate_error_string_(const std::stringstream &ss, NGTError error){
|
||||
if(error != NULL){
|
||||
try{
|
||||
std::string *error_str = static_cast<std::string*>(error);
|
||||
*error_str = ss.str();
|
||||
}catch(std::exception &err){
|
||||
std::cerr << ss.str() << " > " << err.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}else{
|
||||
std::cerr << ss.str() << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
NGTIndex ngt_open_index(const char *index_path, NGTError error) {
|
||||
try{
|
||||
std::string index_path_str(index_path);
|
||||
NGT::Index *index = new NGT::Index(index_path_str);
|
||||
index->disableLog();
|
||||
return static_cast<NGTIndex>(index);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
NGTIndex ngt_create_graph_and_tree(const char *database, NGTProperty prop, NGTError error) {
|
||||
NGT::Index *index = NULL;
|
||||
try{
|
||||
std::string database_str(database);
|
||||
NGT::Property prop_i = *(static_cast<NGT::Property*>(prop));
|
||||
NGT::Index::createGraphAndTree(database_str, prop_i, true);
|
||||
index = new NGT::Index(database_str);
|
||||
index->disableLog();
|
||||
return static_cast<NGTIndex>(index);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
delete index;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty prop, NGTError error) {
|
||||
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << __FUNCTION__ << " is unavailable for shared-memory-type NGT.";
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
#else
|
||||
try{
|
||||
NGT::Index *index = new NGT::GraphAndTreeIndex(*(static_cast<NGT::Property*>(prop)));
|
||||
index->disableLog();
|
||||
return static_cast<NGTIndex>(index);
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
NGTProperty ngt_create_property(NGTError error) {
|
||||
try{
|
||||
return static_cast<NGTProperty>(new NGT::Property());
|
||||
}catch(std::exception &err){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
bool ngt_save_index(const NGTIndex index, const char *database, NGTError error) {
|
||||
try{
|
||||
std::string database_str(database);
|
||||
(static_cast<NGT::Index*>(index))->saveIndex(database_str);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_get_property(NGTIndex index, NGTProperty prop, NGTError error) {
|
||||
if(index == NULL || prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
try{
|
||||
(static_cast<NGT::Index*>(index))->getProperty(*(static_cast<NGT::Property*>(prop)));
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t ngt_get_property_dimension(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).dimension;
|
||||
}
|
||||
|
||||
bool ngt_set_property_dimension(NGTProperty prop, int32_t value, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
(*static_cast<NGT::Property*>(prop)).dimension = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_edge_size_for_creation(NGTProperty prop, int16_t value, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
(*static_cast<NGT::Property*>(prop)).edgeSizeForCreation = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_edge_size_for_search(NGTProperty prop, int16_t value, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
(*static_cast<NGT::Property*>(prop)).edgeSizeForSearch = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t ngt_get_property_object_type(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return -1;
|
||||
}
|
||||
return (*static_cast<NGT::Property*>(prop)).objectType;
|
||||
}
|
||||
|
||||
bool ngt_is_property_object_type_float(int32_t object_type) {
|
||||
return (object_type == NGT::ObjectSpace::ObjectType::Float);
|
||||
}
|
||||
|
||||
bool ngt_is_property_object_type_integer(int32_t object_type) {
|
||||
return (object_type == NGT::ObjectSpace::ObjectType::Uint8);
|
||||
}
|
||||
|
||||
bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Float;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Uint8;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_l1(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_l2(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_angle(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_hamming(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_jaccard(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_cosine(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_normalized_angle(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty prop, NGTError error) {
|
||||
if(prop == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine;
|
||||
return true;
|
||||
}
|
||||
|
||||
NGTObjectDistances ngt_create_empty_results(NGTError error) {
|
||||
try{
|
||||
return static_cast<NGTObjectDistances>(new NGT::ObjectDistances());
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ngt_search_index_(NGT::Index* pindex, NGT::Object *ngtquery, size_t size, float epsilon, float radius, NGTObjectDistances results, int edge_size = INT_MIN) {
|
||||
// set search prameters.
|
||||
NGT::SearchContainer sc(*ngtquery); // search parametera container.
|
||||
|
||||
sc.setResults(static_cast<NGT::ObjectDistances*>(results)); // set the result set.
|
||||
sc.setSize(size); // the number of resultant objects.
|
||||
sc.setRadius(radius); // search radius.
|
||||
sc.setEpsilon(epsilon); // set exploration coefficient.
|
||||
if (edge_size != INT_MIN) {
|
||||
sc.setEdgeSize(edge_size);// set # of edges for each node
|
||||
}
|
||||
|
||||
pindex->search(sc);
|
||||
|
||||
// delete the query object.
|
||||
pindex->deleteObject(ngtquery);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_search_index(NGTIndex index, double *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
|
||||
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
NGT::Object *ngtquery = NULL;
|
||||
|
||||
if(radius < 0.0){
|
||||
radius = FLT_MAX;
|
||||
}
|
||||
|
||||
try{
|
||||
std::vector<double> vquery(&query[0], &query[query_dim]);
|
||||
ngtquery = pindex->allocateObject(vquery);
|
||||
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
if(ngtquery != NULL){
|
||||
pindex->deleteObject(ngtquery);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_search_index_as_float(NGTIndex index, float *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
|
||||
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
NGT::Object *ngtquery = NULL;
|
||||
|
||||
if(radius < 0.0){
|
||||
radius = FLT_MAX;
|
||||
}
|
||||
|
||||
try{
|
||||
std::vector<float> vquery(&query[0], &query[query_dim]);
|
||||
ngtquery = pindex->allocateObject(vquery);
|
||||
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
if(ngtquery != NULL){
|
||||
pindex->deleteObject(ngtquery);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ngt_search_index_with_query(NGTIndex index, NGTQuery query, NGTObjectDistances results, NGTError error) {
|
||||
if(index == NULL || query.query == NULL || results == NULL){
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query.query << " results = " << results;
|
||||
operate_error_string_(ss, error);
|
||||
return false;
|
||||
}
|
||||
|
||||
NGT::Index* pindex = static_cast<NGT::Index*>(index);
|
||||
int32_t dim = pindex->getObjectSpace().getDimension();
|
||||
|
||||
NGT::Object *ngtquery = NULL;
|
||||
|
||||
if(query.radius < 0.0){
|
||||
query.radius = FLT_MAX;
|
||||
}
|
||||
|
||||
try{
|
||||
std::vector<float> vquery(&query.query[0], &query.query[dim]);
|
||||
ngtquery = pindex->allocateObject(vquery);
|
||||
ngt_search_index_(pindex, ngtquery, query.size, query.epsilon, query.radius, results, query.edge_size);
|
||||
}catch(std::exception &err) {
|
||||
std::stringstream ss;
|
||||
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
|
||||
operate_error_string_(ss, error);
|
||||
if(ngtquery != NULL){
|
||||
pindex->deleteObject(ngtquery);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
// * deprecated *
int32_t ngt_get_size(NGTObjectDistances results, NGTError error) {
  if(results == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: results = " << results;
    operate_error_string_(ss, error);
    return -1;
  }

  return (static_cast<NGT::ObjectDistances*>(results))->size();
}

uint32_t ngt_get_result_size(NGTObjectDistances results, NGTError error) {
  if(results == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: results = " << results;
    operate_error_string_(ss, error);
    return 0;
  }

  return (static_cast<NGT::ObjectDistances*>(results))->size();
}

NGTObjectDistance ngt_get_result(const NGTObjectDistances results, const uint32_t i, NGTError error) {
  try{
    NGT::ObjectDistances objects = *(static_cast<NGT::ObjectDistances*>(results));
    NGTObjectDistance ret_val = {objects[i].id, objects[i].distance};
    return ret_val;
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);

    NGTObjectDistance err_val = {0};
    return err_val;
  }
}

ObjectID ngt_insert_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
  if(index == NULL || obj == NULL || obj_dim == 0){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
    operate_error_string_(ss, error);
    return 0;
  }

  try{
    NGT::Index* pindex = static_cast<NGT::Index*>(index);
    std::vector<double> vobj(&obj[0], &obj[obj_dim]);
    return pindex->insert(vobj);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return 0;
  }
}

ObjectID ngt_append_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
  if(index == NULL || obj == NULL || obj_dim == 0){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
    operate_error_string_(ss, error);
    return 0;
  }

  try{
    NGT::Index* pindex = static_cast<NGT::Index*>(index);
    std::vector<double> vobj(&obj[0], &obj[obj_dim]);
    return pindex->append(vobj);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return 0;
  }
}

ObjectID ngt_insert_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
  if(index == NULL || obj == NULL || obj_dim == 0){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
    operate_error_string_(ss, error);
    return 0;
  }

  try{
    NGT::Index* pindex = static_cast<NGT::Index*>(index);
    std::vector<float> vobj(&obj[0], &obj[obj_dim]);
    return pindex->insert(vobj);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return 0;
  }
}

ObjectID ngt_append_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
  if(index == NULL || obj == NULL || obj_dim == 0){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
    operate_error_string_(ss, error);
    return 0;
  }

  try{
    NGT::Index* pindex = static_cast<NGT::Index*>(index);
    std::vector<float> vobj(&obj[0], &obj[obj_dim]);
    return pindex->append(vobj);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return 0;
  }
}

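// Usage sketch (illustrative only, not part of the NGT C API): inserting a
// single float vector and keeping the returned ObjectID. The index handle is
// assumed to have been obtained elsewhere through the C API;
// example_insert_one() is a hypothetical helper.
ObjectID example_insert_one(NGTIndex index, float *vec, uint32_t dim, NGTError err) {
  ObjectID id = ngt_insert_index_as_float(index, vec, dim, err);
  if (id == 0) {  // the C API returns 0 on failure; the message is stored in err
    std::cerr << ngt_get_error_string(err) << std::endl;
    return 0;
  }
  // The id can later be passed to ngt_remove_index(index, id, err).
  return id;
}
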
bool ngt_batch_append_index(NGTIndex index, float *obj, uint32_t data_count, NGTError error) {
  try{
    NGT::Index* pindex = static_cast<NGT::Index*>(index);
    pindex->append(obj, data_count);
    return true;
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
}

bool ngt_batch_insert_index(NGTIndex index, float *obj, uint32_t data_count, uint32_t *ids, NGTError error) {
  NGT::Index* pindex = static_cast<NGT::Index*>(index);
  int32_t dim = pindex->getObjectSpace().getDimension();

  bool status = true;
  float *objptr = obj;
  for (size_t idx = 0; idx < data_count; idx++, objptr += dim) {
    try{
      std::vector<double> vobj(objptr, objptr + dim);
      ids[idx] = pindex->insert(vobj);
    }catch(std::exception &err) {
      status = false;
      ids[idx] = 0;
      std::stringstream ss;
      ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
      operate_error_string_(ss, error);
    }
  }
  return status;
}

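// Usage sketch (illustrative only, not part of the NGT C API): bulk-loading a
// contiguous, row-major buffer of data_count vectors with
// ngt_batch_append_index() and then building the graph with ngt_create_index().
// The pool size of 16 construction threads is an illustrative value;
// example_bulk_load() is a hypothetical helper.
bool example_bulk_load(NGTIndex index, float *vectors, uint32_t data_count, NGTError err) {
  // Append all objects to the object repository in one call.
  if (!ngt_batch_append_index(index, vectors, data_count, err)) {
    std::cerr << ngt_get_error_string(err) << std::endl;
    return false;
  }
  // Build the index over the appended objects.
  if (!ngt_create_index(index, 16, err)) {
    std::cerr << ngt_get_error_string(err) << std::endl;
    return false;
  }
  return true;
}
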
bool ngt_create_index(NGTIndex index, uint32_t pool_size, NGTError error) {
  if(index == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index;
    operate_error_string_(ss, error);
    return false;
  }

  try{
    (static_cast<NGT::Index*>(index))->createIndex(pool_size);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

bool ngt_remove_index(NGTIndex index, ObjectID id, NGTError error) {
  if(index == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index;
    operate_error_string_(ss, error);
    return false;
  }

  try{
    (static_cast<NGT::Index*>(index))->remove(id);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

NGTObjectSpace ngt_get_object_space(NGTIndex index, NGTError error) {
  if(index == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index;
    operate_error_string_(ss, error);
    return NULL;
  }

  try{
    return static_cast<NGTObjectSpace>(&(static_cast<NGT::Index*>(index))->getObjectSpace());
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return NULL;
  }
}

float* ngt_get_object_as_float(NGTObjectSpace object_space, ObjectID id, NGTError error) {
  if(object_space == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: object_space = " << object_space;
    operate_error_string_(ss, error);
    return NULL;
  }
  try{
    return static_cast<float*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return NULL;
  }
}

uint8_t* ngt_get_object_as_integer(NGTObjectSpace object_space, ObjectID id, NGTError error) {
  if(object_space == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: object_space = " << object_space;
    operate_error_string_(ss, error);
    return NULL;
  }
  try{
    return static_cast<uint8_t*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return NULL;
  }
}

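// Usage sketch (illustrative only, not part of the NGT C API): reading a stored
// vector back from the index through its object space. This assumes a
// float-typed index (a uint8-typed index would use ngt_get_object_as_integer())
// and that the caller knows the dimension; example_dump_object() is a
// hypothetical helper.
void example_dump_object(NGTIndex index, ObjectID id, uint32_t dim, NGTError err) {
  NGTObjectSpace ospace = ngt_get_object_space(index, err);
  if (ospace == NULL) {
    std::cerr << ngt_get_error_string(err) << std::endl;
    return;
  }
  float *vec = ngt_get_object_as_float(ospace, id, err);
  if (vec == NULL) {
    std::cerr << ngt_get_error_string(err) << std::endl;
    return;
  }
  for (uint32_t i = 0; i < dim; i++) {
    std::cout << vec[i] << (i + 1 < dim ? " " : "\n");
  }
  // The returned pointer refers to memory managed by the index; the caller
  // must not free it.
}
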
void ngt_destroy_results(NGTObjectDistances results) {
  if(results == NULL) return;
  delete(static_cast<NGT::ObjectDistances*>(results));
}

void ngt_destroy_property(NGTProperty prop) {
  if(prop == NULL) return;
  delete(static_cast<NGT::Property*>(prop));
}

void ngt_close_index(NGTIndex index) {
  if(index == NULL) return;
  (static_cast<NGT::Index*>(index))->close();
  delete(static_cast<NGT::Index*>(index));
}

int16_t ngt_get_property_edge_size_for_creation(NGTProperty prop, NGTError error) {
  if(prop == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: prop = " << prop;
    operate_error_string_(ss, error);
    return -1;
  }
  return (*static_cast<NGT::Property*>(prop)).edgeSizeForCreation;
}

int16_t ngt_get_property_edge_size_for_search(NGTProperty prop, NGTError error) {
  if(prop == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: prop = " << prop;
    operate_error_string_(ss, error);
    return -1;
  }
  return (*static_cast<NGT::Property*>(prop)).edgeSizeForSearch;
}

int32_t ngt_get_property_distance_type(NGTProperty prop, NGTError error){
  if(prop == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: prop = " << prop;
    operate_error_string_(ss, error);
    return -1;
  }
  return (*static_cast<NGT::Property*>(prop)).distanceType;
}

NGTError ngt_create_error_object()
{
  try{
    std::string *error_str = new std::string();
    return static_cast<NGTError>(error_str);
  }catch(std::exception &err){
    std::cerr << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    return NULL;
  }
}

const char *ngt_get_error_string(const NGTError error)
{
  std::string *error_str = static_cast<std::string*>(error);
  return error_str->c_str();
}

void ngt_clear_error_string(NGTError error)
{
  std::string *error_str = static_cast<std::string*>(error);
  *error_str = "";
}

void ngt_destroy_error_object(NGTError error)
{
  std::string *error_str = static_cast<std::string*>(error);
  delete error_str;
}

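// Usage sketch (illustrative only, not part of the NGT C API): the lifecycle of
// an NGTError handle. One error object can be reused across calls if it is
// cleared after a failure and destroyed once at the end;
// example_error_handling() is a hypothetical helper.
void example_error_handling(NGTIndex index, ObjectID id) {
  NGTError err = ngt_create_error_object();
  if (!ngt_remove_index(index, id, err)) {
    // The message set by operate_error_string_() is retrieved here.
    std::cerr << ngt_get_error_string(err) << std::endl;
    ngt_clear_error_string(err);   // reset before reusing the handle
  }
  ngt_destroy_error_object(err);   // frees the underlying std::string
}
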
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError error)
{
  try{
    return static_cast<NGTOptimizer>(new NGT::GraphOptimizer(logDisabled));
  }catch(std::exception &err){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return NULL;
  }
}

bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer optimizer, const char *index, NGTError error) {
  if(optimizer == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: optimizer = " << optimizer;
    operate_error_string_(ss, error);
    return false;
  }
  try{
    (static_cast<NGT::GraphOptimizer*>(optimizer))->adjustSearchCoefficients(std::string(index));
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

bool ngt_optimizer_execute(NGTOptimizer optimizer, const char *inIndex, const char *outIndex, NGTError error) {
  if(optimizer == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: optimizer = " << optimizer;
    operate_error_string_(ss, error);
    return false;
  }
  try{
    (static_cast<NGT::GraphOptimizer*>(optimizer))->execute(std::string(inIndex), std::string(outIndex));
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

// obsolete because of a lack of a parameter
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
                       float baseAccuracyFrom, float baseAccuracyTo,
                       float rateAccuracyFrom, float rateAccuracyTo,
                       double gte, double m, NGTError error) {
  if(optimizer == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: optimizer = " << optimizer;
    operate_error_string_(ss, error);
    return false;
  }
  try{
    (static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, baseAccuracyFrom, baseAccuracyTo,
                                                        rateAccuracyFrom, rateAccuracyTo, gte, m);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
                               int nofqs, int nofrs, NGTError error) {
  if(optimizer == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: optimizer = " << optimizer;
    operate_error_string_(ss, error);
    return false;
  }
  try{
    (static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, nofrs);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
                                 float baseAccuracyFrom, float baseAccuracyTo,
                                 float rateAccuracyFrom, float rateAccuracyTo,
                                 double gte, double m, NGTError error) {
  if(optimizer == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: optimizer = " << optimizer;
    operate_error_string_(ss, error);
    return false;
  }
  try{
    (static_cast<NGT::GraphOptimizer*>(optimizer))->setExtension(baseAccuracyFrom, baseAccuracyTo,
                                                                 rateAccuracyFrom, rateAccuracyTo, gte, m);
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
                                        bool prefetchParameter, bool accuracyTable, NGTError error)
{
  if(optimizer == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: optimizer = " << optimizer;
    operate_error_string_(ss, error);
    return false;
  }

  (static_cast<NGT::GraphOptimizer*>(optimizer))->setProcessingModes(searchParameter, prefetchParameter,
                                                                     accuracyTable);
  return true;
}

void ngt_destroy_optimizer(NGTOptimizer optimizer)
{
  if(optimizer == NULL) return;
  delete(static_cast<NGT::GraphOptimizer*>(optimizer));
}

bool ngt_refine_anng(NGTIndex index, float epsilon, float accuracy, int noOfEdges, int exploreEdgeSize, size_t batchSize, NGTError error)
{
  NGT::Index* pindex = static_cast<NGT::Index*>(index);
  try {
    NGT::GraphReconstructor::refineANNG(*pindex, true, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
  } catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}

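// Usage sketch (illustrative only, not part of the NGT C API): offline graph
// optimization with the NGTOptimizer wrapper. "anng-index" and "onng-index" are
// placeholder paths for an existing input index and the optimized output; the
// numeric arguments to ngt_optimizer_set_minimum() are illustrative values;
// example_optimize_graph() is a hypothetical helper.
bool example_optimize_graph(NGTError err) {
  NGTOptimizer optimizer = ngt_create_optimizer(true /* logDisabled */, err);
  if (optimizer == NULL) {
    return false;
  }
  bool ok = ngt_optimizer_set_minimum(optimizer, 10, 40, 100, 20, err) &&
            ngt_optimizer_execute(optimizer, "anng-index", "onng-index", err);
  if (!ok) {
    std::cerr << ngt_get_error_string(err) << std::endl;
  }
  ngt_destroy_optimizer(optimizer);
  return ok;
}
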
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error)
{
  if(index == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index;
    operate_error_string_(ss, error);
    return false;
  }

  NGT::Index* pindex = static_cast<NGT::Index*>(index);
  NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(pindex->getIndex());

  try {
    NGT::ObjectDistances &objects = *static_cast<NGT::ObjectDistances*>(edges);
    objects = *graph.getNode(id);
  }catch(std::exception &err){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }

  return true;
}

uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error)
{
  if(index == NULL){
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : parameter error: index = " << index;
    operate_error_string_(ss, error);
    return 0;
  }
  NGT::Index& pindex = *static_cast<NGT::Index*>(index);
  return pindex.getObjectRepositorySize();
}

NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter()
{
  NGT::GraphOptimizer::ANNGEdgeOptimizationParameter gp;
  NGTAnngEdgeOptimizationParameter parameter;

  parameter.no_of_queries = gp.noOfQueries;
  parameter.no_of_results = gp.noOfResults;
  parameter.no_of_threads = gp.noOfThreads;
  parameter.target_accuracy = gp.targetAccuracy;
  parameter.target_no_of_objects = gp.targetNoOfObjects;
  parameter.no_of_sample_objects = gp.noOfSampleObjects;
  parameter.max_of_no_of_edges = gp.maxNoOfEdges;
  parameter.log = false;

  return parameter;
}

bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error)
{
  NGT::GraphOptimizer::ANNGEdgeOptimizationParameter p;

  p.noOfQueries = parameter.no_of_queries;
  p.noOfResults = parameter.no_of_results;
  p.noOfThreads = parameter.no_of_threads;
  p.targetAccuracy = parameter.target_accuracy;
  p.targetNoOfObjects = parameter.target_no_of_objects;
  p.noOfSampleObjects = parameter.no_of_sample_objects;
  p.maxNoOfEdges = parameter.max_of_no_of_edges;

  try {
    NGT::GraphOptimizer graphOptimizer(!parameter.log);  // the constructor takes logDisabled, so passing false enables logging
    std::string path(indexPath);
    auto edge = graphOptimizer.optimizeNumberOfEdgesForANNG(path, p);
    if (parameter.log) {
      std::cerr << "the optimized number of edges is " << edge.first << " (" << edge.second << ")" << std::endl;
    }
  }catch(std::exception &err) {
    std::stringstream ss;
    ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
    operate_error_string_(ss, error);
    return false;
  }
  return true;
}
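
// Usage sketch (illustrative only, not part of the NGT C API): tuning the number
// of edges of an existing ANNG index on disk. "anng-index" is a placeholder
// path and the no_of_queries override is an illustrative value;
// example_optimize_edges() is a hypothetical helper.
bool example_optimize_edges(NGTError err) {
  // Start from the defaults captured by the parameter struct above, then
  // override only what is needed.
  NGTAnngEdgeOptimizationParameter param = ngt_get_anng_edge_optimization_parameter();
  param.no_of_queries = 200;
  param.log = true;   // prints the chosen edge count to stderr
  if (!ngt_optimize_number_of_edges("anng-index", param, err)) {
    std::cerr << ngt_get_error_string(err) << std::endl;
    return false;
  }
  return true;
}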