mirror of https://github.com/milvus-io/milvus.git
Merge branch 'branch-0.5.0' of http://192.168.1.105:6060/megasearch/milvus into branch-0.5.0
Former-commit-id: 0032fed31caf96d6abceb0cbbf4735075d23a94cpull/191/head
commit
c59b0a65ff
|
@ -18,3 +18,10 @@
|
|||
BasedOnStyle: Google
|
||||
DerivePointerAlignment: false
|
||||
ColumnLimit: 120
|
||||
IndentWidth: 4
|
||||
AccessModifierOffset: -3
|
||||
AlwaysBreakAfterReturnType: All
|
||||
AllowShortBlocksOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: false
|
||||
AllowShortIfStatementsOnASingleLine: false
|
||||
AlignTrailingComments: true
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
Checks: 'clang-diagnostic-*,clang-analyzer-*,-clang-analyzer-alpha*,google-*,modernize-*,readability-*'
|
||||
# produce HeaderFilterRegex from cpp/build-support/lint_exclusions.txt with:
|
||||
# echo -n '^('; sed -e 's/*/\.*/g' cpp/build-support/lint_exclusions.txt | tr '\n' '|'; echo ')$'
|
||||
HeaderFilterRegex: '^(.*cmake-build-debug.*|.*cmake-build-release.*|.*cmake_build.*|.*src/thirdparty.*|.*src/core/thirdparty.*|.*src/grpc.*|)$'
|
||||
HeaderFilterRegex: '^(.*cmake-build-debug.*|.*cmake-build-release.*|.*cmake_build.*|.*src/core/thirdparty.*|.*thirdparty.*|.*easylogging++.*|.*SqliteMetaImpl.cpp|.*src/grpc.*|.*src/core.*|.*src/wrapper.*)$'
|
||||
AnalyzeTemporaryDtors: true
|
||||
ChainedConditionalReturn: 1
|
||||
ChainedConditionalAssignment: 1
|
||||
CheckOptions:
|
||||
- key: google-readability-braces-around-statements.ShortStatementLines
|
||||
value: '1'
|
||||
|
|
|
@ -33,6 +33,7 @@ Please mark all change in change log and use the ticket from JIRA.
|
|||
- MS-575 - Add Clang-format & Clang-tidy & Cpplint
|
||||
- MS-586 - Remove BUILD_FAISS_WITH_MKL option
|
||||
- MS-590 - Refine cmake code to support cpplint
|
||||
- MS-600 - Reconstruct unittest code
|
||||
|
||||
# Milvus 0.4.0 (2019-09-12)
|
||||
|
||||
|
|
|
@ -189,7 +189,7 @@ add_custom_target(lint
|
|||
--exclude_globs
|
||||
${LINT_EXCLUSIONS_FILE}
|
||||
--source_dir
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${MILVUS_LINT_QUIET})
|
||||
|
||||
#
|
||||
|
|
|
@ -50,10 +50,9 @@ Centos7 :
|
|||
$ yum install gfortran qt4 flex bison
|
||||
$ yum install mysql-devel mysql
|
||||
|
||||
Ubuntu16.04 :
|
||||
Ubuntu 16.04 or 18.04:
|
||||
$ sudo apt-get install gfortran qt4-qmake flex bison
|
||||
$ sudo apt-get install libmysqlclient-dev mysql-client
|
||||
|
||||
```
|
||||
|
||||
Verify the existence of `libmysqlclient_r.so`:
|
||||
|
@ -66,6 +65,10 @@ $ locate libmysqlclient_r.so
|
|||
If not, you need to create a symbolic link:
|
||||
|
||||
```shell
|
||||
# Locate libmysqlclient.so
|
||||
$ sudo updatedb
|
||||
$ locate libmysqlclient.so
|
||||
|
||||
# Create symbolic link
|
||||
$ sudo ln -s /path/to/libmysqlclient.so /path/to/libmysqlclient_r.so
|
||||
```
|
||||
|
@ -90,7 +93,7 @@ please reinstall CMake with curl:
|
|||
```shell
|
||||
CentOS 7:
|
||||
$ yum install curl-devel
|
||||
Ubuntu 16.04:
|
||||
Ubuntu 16.04 or 18.04:
|
||||
$ sudo apt-get install libcurl4-openssl-dev
|
||||
```
|
||||
|
||||
|
@ -106,7 +109,7 @@ please reinstall CMake with curl:
|
|||
```shell
|
||||
CentOS 7:
|
||||
$ yum install clang
|
||||
Ubuntu 16.04:
|
||||
Ubuntu 16.04 or 18.04:
|
||||
$ sudo apt-get install clang-format clang-tidy
|
||||
|
||||
$ ./build.sh -l
|
||||
|
@ -123,7 +126,7 @@ $ ./build.sh -u
|
|||
```shell
|
||||
CentOS 7:
|
||||
$ yum install lcov
|
||||
Ubuntu 16.04:
|
||||
Ubuntu 16.04 or 18.04:
|
||||
$ sudo apt-get install lcov
|
||||
|
||||
$ ./build.sh -u -c
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
<code_scheme name="Default" version="173">
|
||||
<Objective-C>
|
||||
<option name="INDENT_NAMESPACE_MEMBERS" value="0" />
|
||||
<option name="INDENT_VISIBILITY_KEYWORDS" value="1" />
|
||||
<option name="KEEP_STRUCTURES_IN_ONE_LINE" value="true" />
|
||||
<option name="KEEP_CASE_EXPRESSIONS_IN_ONE_LINE" value="true" />
|
||||
<option name="FUNCTION_NON_TOP_AFTER_RETURN_TYPE_WRAP" value="0" />
|
||||
<option name="FUNCTION_TOP_AFTER_RETURN_TYPE_WRAP" value="2" />
|
||||
<option name="FUNCTION_PARAMETERS_WRAP" value="5" />
|
||||
<option name="FUNCTION_CALL_ARGUMENTS_WRAP" value="5" />
|
||||
<option name="TEMPLATE_CALL_ARGUMENTS_WRAP" value="5" />
|
||||
<option name="TEMPLATE_CALL_ARGUMENTS_ALIGN_MULTILINE" value="true" />
|
||||
<option name="CLASS_CONSTRUCTOR_INIT_LIST_WRAP" value="5" />
|
||||
<option name="ALIGN_INIT_LIST_IN_COLUMNS" value="false" />
|
||||
<option name="SPACE_BEFORE_PROTOCOLS_BRACKETS" value="false" />
|
||||
<option name="SPACE_BEFORE_POINTER_IN_DECLARATION" value="false" />
|
||||
<option name="SPACE_AFTER_POINTER_IN_DECLARATION" value="true" />
|
||||
<option name="SPACE_BEFORE_REFERENCE_IN_DECLARATION" value="false" />
|
||||
<option name="SPACE_AFTER_REFERENCE_IN_DECLARATION" value="true" />
|
||||
<option name="KEEP_BLANK_LINES_BEFORE_END" value="1" />
|
||||
</Objective-C>
|
||||
<codeStyleSettings language="ObjectiveC">
|
||||
<option name="KEEP_BLANK_LINES_IN_DECLARATIONS" value="1" />
|
||||
<option name="KEEP_BLANK_LINES_IN_CODE" value="1" />
|
||||
<option name="KEEP_BLANK_LINES_BEFORE_RBRACE" value="1" />
|
||||
<option name="BLANK_LINES_AROUND_CLASS" value="0" />
|
||||
<option name="BLANK_LINES_AROUND_METHOD_IN_INTERFACE" value="0" />
|
||||
<option name="BLANK_LINES_AFTER_CLASS_HEADER" value="1" />
|
||||
<option name="ALIGN_MULTILINE_BINARY_OPERATION" value="false" />
|
||||
<option name="SPACE_AFTER_TYPE_CAST" value="false" />
|
||||
<option name="BINARY_OPERATION_SIGN_ON_NEXT_LINE" value="true" />
|
||||
<option name="KEEP_SIMPLE_BLOCKS_IN_ONE_LINE" value="false" />
|
||||
<option name="FOR_STATEMENT_WRAP" value="1" />
|
||||
<option name="ASSIGNMENT_WRAP" value="1" />
|
||||
<indentOptions>
|
||||
<option name="CONTINUATION_INDENT_SIZE" value="4" />
|
||||
</indentOptions>
|
||||
</codeStyleSettings>
|
||||
</code_scheme>
|
|
@ -1,7 +1,8 @@
|
|||
*cmake-build-debug*
|
||||
*cmake-build-release*
|
||||
*cmake_build*
|
||||
*src/thirdparty*
|
||||
*src/core/thirdparty*
|
||||
*src/grpc*
|
||||
*easylogging++*
|
||||
*thirdparty*
|
||||
*easylogging++*
|
||||
*SqliteMetaImpl.cpp
|
||||
*src/grpc*
|
|
@ -99,21 +99,26 @@ if [[ ${RUN_CPPLINT} == "ON" ]]; then
|
|||
# cpplint check
|
||||
make lint
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "ERROR! cpplint check not pass"
|
||||
echo "ERROR! cpplint check failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "cpplint check passed!"
|
||||
|
||||
# clang-format check
|
||||
make check-clang-format
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "ERROR! clang-format check failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "clang-format check passed!"
|
||||
|
||||
# clang-tidy check
|
||||
make check-clang-tidy
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "ERROR! clang-tidy check failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "clang-tidy check passed!"
|
||||
else
|
||||
# compile and build
|
||||
make -j 4 || exit 1
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# All the following configurations are default values.
|
||||
# Default values are used when you make no changes to the following parameters.
|
||||
|
||||
server_config:
|
||||
address: 0.0.0.0 # milvus server ip address (IPv4)
|
||||
|
@ -11,25 +11,27 @@ db_config:
|
|||
secondary_path: # path used to store data only, split by semicolon
|
||||
|
||||
backend_url: sqlite://:@:/ # URI format: dialect://username:password@host:port/database
|
||||
# Keep 'dialect://:@:/', and replace other texts with real values.
|
||||
# Keep 'dialect://:@:/', and replace other texts with real values
|
||||
# Replace 'dialect' with 'mysql' or 'sqlite'
|
||||
|
||||
insert_buffer_size: 4 # GB, maximum insert buffer size allowed
|
||||
# sum of insert_buffer_size and cpu_cache_capacity cannot exceed total memory
|
||||
build_index_gpu: 0 # gpu id used for building index
|
||||
|
||||
metric_config:
|
||||
enable_monitor: false # enable monitoring or not
|
||||
collector: prometheus # prometheus
|
||||
prometheus_config:
|
||||
port: 8080 # port prometheus used to fetch metrics
|
||||
port: 8080 # port prometheus uses to fetch metrics
|
||||
|
||||
cache_config:
|
||||
cpu_mem_capacity: 16 # GB, CPU memory used for cache
|
||||
cpu_mem_threshold: 0.85 # percentage of data kept when cache cleanup triggered
|
||||
cache_insert_data: false # whether load inserted data into cache
|
||||
cpu_cache_capacity: 16 # GB, CPU memory used for cache
|
||||
cpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered
|
||||
cache_insert_data: false # whether to load inserted data into cache
|
||||
|
||||
engine_config:
|
||||
blas_threshold: 20
|
||||
use_blas_threshold: 20 # if nq < use_blas_threshold, use SSE, faster with fluctuated response times
|
||||
# if nq >= use_blas_threshold, use OpenBlas, slower with stable response times
|
||||
|
||||
resource_config:
|
||||
resource_pool:
|
||||
|
|
|
@ -101,6 +101,7 @@ ${LCOV_CMD} -r "${FILE_INFO_OUTPUT}" -o "${FILE_INFO_OUTPUT_NEW}" \
|
|||
"src/core/cmake_build*" \
|
||||
"src/core/thirdparty*" \
|
||||
"src/grpc*"\
|
||||
"src/metrics/MetricBase.h"\
|
||||
"src/server/Server.cpp"\
|
||||
"src/server/DBWrapper.cpp"\
|
||||
"src/server/grpc_impl/GrpcMilvusServer.cpp"\
|
||||
|
|
|
@ -20,29 +20,24 @@
|
|||
include_directories(${MILVUS_SOURCE_DIR})
|
||||
include_directories(${MILVUS_ENGINE_SRC})
|
||||
|
||||
add_subdirectory(core)
|
||||
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-status)
|
||||
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus)
|
||||
|
||||
#this statement must put here, since the CORE_INCLUDE_DIRS is defined in code/CMakeList.txt
|
||||
add_subdirectory(core)
|
||||
set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
foreach (dir ${CORE_INCLUDE_DIRS})
|
||||
include_directories(${dir})
|
||||
endforeach ()
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/cache cache_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/config config_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler db_scheduler_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/context db_scheduler_context_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/task db_scheduler_task_files)
|
||||
set(db_scheduler_files
|
||||
${db_scheduler_files}
|
||||
${db_scheduler_context_files}
|
||||
${db_scheduler_task_files}
|
||||
)
|
||||
|
||||
set(grpc_service_files
|
||||
${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.grpc.pb.cc
|
||||
|
@ -51,8 +46,6 @@ set(grpc_service_files
|
|||
${MILVUS_ENGINE_SRC}/grpc/gen-status/status.pb.cc
|
||||
)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler scheduler_main_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/action scheduler_action_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/event scheduler_event_files)
|
||||
|
@ -70,9 +63,7 @@ set(scheduler_files
|
|||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/server server_files)
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/server/grpc_impl grpc_server_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/utils utils_files)
|
||||
|
||||
aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper wrapper_files)
|
||||
|
||||
set(engine_files
|
||||
|
@ -82,16 +73,11 @@ set(engine_files
|
|||
${db_engine_files}
|
||||
${db_insert_files}
|
||||
${db_meta_files}
|
||||
${db_scheduler_files}
|
||||
${metrics_files}
|
||||
${utils_files}
|
||||
${wrapper_files}
|
||||
)
|
||||
|
||||
include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include")
|
||||
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-status)
|
||||
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus)
|
||||
|
||||
set(client_grpc_lib
|
||||
grpcpp_channelz
|
||||
grpc++
|
||||
|
@ -112,6 +98,12 @@ set(boost_lib
|
|||
boost_serialization_static
|
||||
)
|
||||
|
||||
set(cuda_lib
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
|
||||
cudart
|
||||
cublas
|
||||
)
|
||||
|
||||
set(third_party_libs
|
||||
sqlite
|
||||
${client_grpc_lib}
|
||||
|
@ -123,17 +115,15 @@ set(third_party_libs
|
|||
snappy
|
||||
zlib
|
||||
zstd
|
||||
cudart
|
||||
cublas
|
||||
${cuda_lib}
|
||||
mysqlpp
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
|
||||
cudart
|
||||
)
|
||||
|
||||
if (MILVUS_ENABLE_PROFILING STREQUAL "ON")
|
||||
set(third_party_libs ${third_party_libs}
|
||||
gperftools
|
||||
libunwind
|
||||
)
|
||||
gperftools
|
||||
libunwind
|
||||
)
|
||||
endif ()
|
||||
|
||||
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
|
||||
|
@ -141,7 +131,6 @@ set(engine_libs
|
|||
pthread
|
||||
libgomp.a
|
||||
libgfortran.a
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
|
||||
)
|
||||
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
|
@ -152,7 +141,11 @@ if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
|||
endif ()
|
||||
|
||||
cuda_add_library(milvus_engine STATIC ${engine_files})
|
||||
target_link_libraries(milvus_engine ${engine_libs} knowhere ${third_party_libs})
|
||||
target_link_libraries(milvus_engine
|
||||
knowhere
|
||||
${engine_libs}
|
||||
${third_party_libs}
|
||||
)
|
||||
|
||||
add_library(metrics STATIC ${metrics_files})
|
||||
|
||||
|
@ -180,7 +173,9 @@ add_executable(milvus_server
|
|||
${utils_files}
|
||||
)
|
||||
|
||||
target_link_libraries(milvus_server ${server_libs})
|
||||
target_link_libraries(milvus_server
|
||||
${server_libs}
|
||||
)
|
||||
|
||||
install(TARGETS milvus_server DESTINATION bin)
|
||||
|
||||
|
|
|
@ -15,47 +15,69 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "LRU.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
template<typename ItemObj>
|
||||
template <typename ItemObj>
|
||||
class Cache {
|
||||
public:
|
||||
//mem_capacity, units:GB
|
||||
public:
|
||||
// mem_capacity, units:GB
|
||||
Cache(int64_t capacity_gb, uint64_t cache_max_count);
|
||||
~Cache() = default;
|
||||
|
||||
int64_t usage() const { return usage_; }
|
||||
int64_t capacity() const { return capacity_; } //unit: BYTE
|
||||
void set_capacity(int64_t capacity); //unit: BYTE
|
||||
int64_t
|
||||
usage() const {
|
||||
return usage_;
|
||||
}
|
||||
|
||||
double freemem_percent() const { return freemem_percent_; };
|
||||
void set_freemem_percent(double percent) { freemem_percent_ = percent; }
|
||||
int64_t
|
||||
capacity() const {
|
||||
return capacity_;
|
||||
} // unit: BYTE
|
||||
void
|
||||
set_capacity(int64_t capacity); // unit: BYTE
|
||||
|
||||
size_t size() const;
|
||||
bool exists(const std::string& key);
|
||||
ItemObj get(const std::string& key);
|
||||
void insert(const std::string& key, const ItemObj& item);
|
||||
void erase(const std::string& key);
|
||||
void print();
|
||||
void clear();
|
||||
double
|
||||
freemem_percent() const {
|
||||
return freemem_percent_;
|
||||
}
|
||||
|
||||
private:
|
||||
void free_memory();
|
||||
void
|
||||
set_freemem_percent(double percent) {
|
||||
freemem_percent_ = percent;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t
|
||||
size() const;
|
||||
bool
|
||||
exists(const std::string& key);
|
||||
ItemObj
|
||||
get(const std::string& key);
|
||||
void
|
||||
insert(const std::string& key, const ItemObj& item);
|
||||
void
|
||||
erase(const std::string& key);
|
||||
void
|
||||
print();
|
||||
void
|
||||
clear();
|
||||
|
||||
private:
|
||||
void
|
||||
free_memory();
|
||||
|
||||
private:
|
||||
int64_t usage_;
|
||||
int64_t capacity_;
|
||||
double freemem_percent_;
|
||||
|
@ -64,8 +86,8 @@ private:
|
|||
mutable std::mutex mutex_;
|
||||
};
|
||||
|
||||
} // cache
|
||||
} // milvus
|
||||
} // zilliz
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
||||
#include "cache/Cache.inl"
|
||||
#include "cache/Cache.inl"
|
||||
|
|
|
@ -33,29 +33,33 @@ Cache<ItemObj>::Cache(int64_t capacity, uint64_t cache_max_count)
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void Cache<ItemObj>::set_capacity(int64_t capacity) {
|
||||
if(capacity > 0) {
|
||||
void
|
||||
Cache<ItemObj>::set_capacity(int64_t capacity) {
|
||||
if (capacity > 0) {
|
||||
capacity_ = capacity;
|
||||
free_memory();
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
size_t Cache<ItemObj>::size() const {
|
||||
size_t
|
||||
Cache<ItemObj>::size() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return lru_.size();
|
||||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
bool Cache<ItemObj>::exists(const std::string& key) {
|
||||
bool
|
||||
Cache<ItemObj>::exists(const std::string &key) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return lru_.exists(key);
|
||||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
ItemObj Cache<ItemObj>::get(const std::string& key) {
|
||||
ItemObj
|
||||
Cache<ItemObj>::get(const std::string &key) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if(!lru_.exists(key)){
|
||||
if (!lru_.exists(key)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -63,8 +67,9 @@ ItemObj Cache<ItemObj>::get(const std::string& key) {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
|
||||
if(item == nullptr) {
|
||||
void
|
||||
Cache<ItemObj>::insert(const std::string &key, const ItemObj &item) {
|
||||
if (item == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -80,7 +85,7 @@ void Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
|
|||
|
||||
//if key already exist, subtract old item size
|
||||
if (lru_.exists(key)) {
|
||||
const ItemObj& old_item = lru_.get(key);
|
||||
const ItemObj &old_item = lru_.get(key);
|
||||
usage_ -= old_item->size();
|
||||
}
|
||||
|
||||
|
@ -107,13 +112,14 @@ void Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void Cache<ItemObj>::erase(const std::string& key) {
|
||||
void
|
||||
Cache<ItemObj>::erase(const std::string &key) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if(!lru_.exists(key)){
|
||||
if (!lru_.exists(key)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ItemObj& old_item = lru_.get(key);
|
||||
const ItemObj &old_item = lru_.get(key);
|
||||
usage_ -= old_item->size();
|
||||
|
||||
SERVER_LOG_DEBUG << "Erase " << key << " size: " << old_item->size();
|
||||
|
@ -122,7 +128,8 @@ void Cache<ItemObj>::erase(const std::string& key) {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void Cache<ItemObj>::clear() {
|
||||
void
|
||||
Cache<ItemObj>::clear() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
lru_.clear();
|
||||
usage_ = 0;
|
||||
|
@ -131,12 +138,13 @@ void Cache<ItemObj>::clear() {
|
|||
|
||||
/* free memory space when CACHE occupation exceed its capacity */
|
||||
template<typename ItemObj>
|
||||
void Cache<ItemObj>::free_memory() {
|
||||
void
|
||||
Cache<ItemObj>::free_memory() {
|
||||
if (usage_ <= capacity_) return;
|
||||
|
||||
int64_t threshhold = capacity_ * freemem_percent_;
|
||||
int64_t delta_size = usage_ - threshhold;
|
||||
if(delta_size <= 0) {
|
||||
if (delta_size <= 0) {
|
||||
delta_size = 1;//ensure at least one item erased
|
||||
}
|
||||
|
||||
|
@ -148,8 +156,8 @@ void Cache<ItemObj>::free_memory() {
|
|||
|
||||
auto it = lru_.rbegin();
|
||||
while (it != lru_.rend() && released_size < delta_size) {
|
||||
auto& key = it->first;
|
||||
auto& obj_ptr = it->second;
|
||||
auto &key = it->first;
|
||||
auto &obj_ptr = it->second;
|
||||
|
||||
key_array.emplace(key);
|
||||
released_size += obj_ptr->size();
|
||||
|
@ -159,7 +167,7 @@ void Cache<ItemObj>::free_memory() {
|
|||
|
||||
SERVER_LOG_DEBUG << "to be released memory size: " << released_size;
|
||||
|
||||
for (auto& key : key_array) {
|
||||
for (auto &key : key_array) {
|
||||
erase(key);
|
||||
}
|
||||
|
||||
|
@ -167,7 +175,8 @@ void Cache<ItemObj>::free_memory() {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void Cache<ItemObj>::print() {
|
||||
void
|
||||
Cache<ItemObj>::print() {
|
||||
size_t cache_count = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
@ -179,7 +188,7 @@ void Cache<ItemObj>::print() {
|
|||
SERVER_LOG_DEBUG << "[Cache capacity]: " << capacity_ << " bytes";
|
||||
}
|
||||
|
||||
} // cache
|
||||
} // milvus
|
||||
} // zilliz
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
||||
|
|
|
@ -15,50 +15,61 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Cache.h"
|
||||
#include "utils/Log.h"
|
||||
#include "metrics/Metrics.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
template<typename ItemObj>
|
||||
template <typename ItemObj>
|
||||
class CacheMgr {
|
||||
public:
|
||||
virtual uint64_t ItemCount() const;
|
||||
public:
|
||||
virtual uint64_t
|
||||
ItemCount() const;
|
||||
|
||||
virtual bool ItemExists(const std::string& key);
|
||||
virtual bool
|
||||
ItemExists(const std::string& key);
|
||||
|
||||
virtual ItemObj GetItem(const std::string& key);
|
||||
virtual ItemObj
|
||||
GetItem(const std::string& key);
|
||||
|
||||
virtual void InsertItem(const std::string& key, const ItemObj& data);
|
||||
virtual void
|
||||
InsertItem(const std::string& key, const ItemObj& data);
|
||||
|
||||
virtual void EraseItem(const std::string& key);
|
||||
virtual void
|
||||
EraseItem(const std::string& key);
|
||||
|
||||
virtual void PrintInfo();
|
||||
virtual void
|
||||
PrintInfo();
|
||||
|
||||
virtual void ClearCache();
|
||||
virtual void
|
||||
ClearCache();
|
||||
|
||||
int64_t CacheUsage() const;
|
||||
int64_t CacheCapacity() const;
|
||||
void SetCapacity(int64_t capacity);
|
||||
int64_t
|
||||
CacheUsage() const;
|
||||
int64_t
|
||||
CacheCapacity() const;
|
||||
void
|
||||
SetCapacity(int64_t capacity);
|
||||
|
||||
protected:
|
||||
protected:
|
||||
CacheMgr();
|
||||
virtual ~CacheMgr();
|
||||
|
||||
protected:
|
||||
protected:
|
||||
using CachePtr = std::shared_ptr<Cache<ItemObj>>;
|
||||
CachePtr cache_;
|
||||
};
|
||||
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#include "cache/CacheMgr.inl"
|
||||
#include "cache/CacheMgr.inl"
|
||||
|
|
|
@ -30,18 +30,20 @@ CacheMgr<ItemObj>::~CacheMgr() {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
uint64_t CacheMgr<ItemObj>::ItemCount() const {
|
||||
if(cache_ == nullptr) {
|
||||
uint64_t
|
||||
CacheMgr<ItemObj>::ItemCount() const {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (uint64_t)(cache_->size());
|
||||
return (uint64_t) (cache_->size());
|
||||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
bool CacheMgr<ItemObj>::ItemExists(const std::string& key) {
|
||||
if(cache_ == nullptr) {
|
||||
bool
|
||||
CacheMgr<ItemObj>::ItemExists(const std::string &key) {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return false;
|
||||
}
|
||||
|
@ -50,8 +52,9 @@ bool CacheMgr<ItemObj>::ItemExists(const std::string& key) {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
ItemObj CacheMgr<ItemObj>::GetItem(const std::string& key) {
|
||||
if(cache_ == nullptr) {
|
||||
ItemObj
|
||||
CacheMgr<ItemObj>::GetItem(const std::string &key) {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -60,8 +63,9 @@ ItemObj CacheMgr<ItemObj>::GetItem(const std::string& key) {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data) {
|
||||
if(cache_ == nullptr) {
|
||||
void
|
||||
CacheMgr<ItemObj>::InsertItem(const std::string &key, const ItemObj &data) {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return;
|
||||
}
|
||||
|
@ -71,8 +75,9 @@ void CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data)
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void CacheMgr<ItemObj>::EraseItem(const std::string& key) {
|
||||
if(cache_ == nullptr) {
|
||||
void
|
||||
CacheMgr<ItemObj>::EraseItem(const std::string &key) {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return;
|
||||
}
|
||||
|
@ -82,8 +87,9 @@ void CacheMgr<ItemObj>::EraseItem(const std::string& key) {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void CacheMgr<ItemObj>::PrintInfo() {
|
||||
if(cache_ == nullptr) {
|
||||
void
|
||||
CacheMgr<ItemObj>::PrintInfo() {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return;
|
||||
}
|
||||
|
@ -92,8 +98,9 @@ void CacheMgr<ItemObj>::PrintInfo() {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void CacheMgr<ItemObj>::ClearCache() {
|
||||
if(cache_ == nullptr) {
|
||||
void
|
||||
CacheMgr<ItemObj>::ClearCache() {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return;
|
||||
}
|
||||
|
@ -102,8 +109,9 @@ void CacheMgr<ItemObj>::ClearCache() {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
int64_t CacheMgr<ItemObj>::CacheUsage() const {
|
||||
if(cache_ == nullptr) {
|
||||
int64_t
|
||||
CacheMgr<ItemObj>::CacheUsage() const {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return 0;
|
||||
}
|
||||
|
@ -112,8 +120,9 @@ int64_t CacheMgr<ItemObj>::CacheUsage() const {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
int64_t CacheMgr<ItemObj>::CacheCapacity() const {
|
||||
if(cache_ == nullptr) {
|
||||
int64_t
|
||||
CacheMgr<ItemObj>::CacheCapacity() const {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return 0;
|
||||
}
|
||||
|
@ -122,14 +131,15 @@ int64_t CacheMgr<ItemObj>::CacheCapacity() const {
|
|||
}
|
||||
|
||||
template<typename ItemObj>
|
||||
void CacheMgr<ItemObj>::SetCapacity(int64_t capacity) {
|
||||
if(cache_ == nullptr) {
|
||||
void
|
||||
CacheMgr<ItemObj>::SetCapacity(int64_t capacity) {
|
||||
if (cache_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Cache doesn't exist";
|
||||
return;
|
||||
}
|
||||
cache_->set_capacity(capacity);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,58 +15,61 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include "CpuCacheMgr.h"
|
||||
#include "cache/CpuCacheMgr.h"
|
||||
#include "server/Config.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
namespace {
|
||||
constexpr int64_t unit = 1024 * 1024 * 1024;
|
||||
constexpr int64_t unit = 1024 * 1024 * 1024;
|
||||
}
|
||||
|
||||
CpuCacheMgr::CpuCacheMgr() {
|
||||
server::Config& config = server::Config::GetInstance();
|
||||
Status s;
|
||||
|
||||
int32_t cpu_mem_cap;
|
||||
s = config.GetCacheConfigCpuMemCapacity(cpu_mem_cap);
|
||||
int32_t cpu_cache_cap;
|
||||
s = config.GetCacheConfigCpuCacheCapacity(cpu_cache_cap);
|
||||
if (!s.ok()) {
|
||||
SERVER_LOG_ERROR << s.message();
|
||||
}
|
||||
int64_t cap = cpu_mem_cap * unit;
|
||||
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL<<32);
|
||||
int64_t cap = cpu_cache_cap * unit;
|
||||
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
|
||||
|
||||
float cpu_mem_threshold;
|
||||
s = config.GetCacheConfigCpuMemThreshold(cpu_mem_threshold);
|
||||
float cpu_cache_threshold;
|
||||
s = config.GetCacheConfigCpuCacheThreshold(cpu_cache_threshold);
|
||||
if (!s.ok()) {
|
||||
SERVER_LOG_ERROR << s.message();
|
||||
}
|
||||
if (cpu_mem_threshold > 0.0 && cpu_mem_threshold <= 1.0) {
|
||||
cache_->set_freemem_percent(cpu_mem_threshold);
|
||||
if (cpu_cache_threshold > 0.0 && cpu_cache_threshold <= 1.0) {
|
||||
cache_->set_freemem_percent(cpu_cache_threshold);
|
||||
} else {
|
||||
SERVER_LOG_ERROR << "Invalid cpu_mem_threshold: " << cpu_mem_threshold
|
||||
<< ", by default set to " << cache_->freemem_percent();
|
||||
SERVER_LOG_ERROR << "Invalid cpu_cache_threshold: " << cpu_cache_threshold << ", by default set to "
|
||||
<< cache_->freemem_percent();
|
||||
}
|
||||
}
|
||||
|
||||
CpuCacheMgr* CpuCacheMgr::GetInstance() {
|
||||
CpuCacheMgr*
|
||||
CpuCacheMgr::GetInstance() {
|
||||
static CpuCacheMgr s_mgr;
|
||||
return &s_mgr;
|
||||
}
|
||||
|
||||
engine::VecIndexPtr CpuCacheMgr::GetIndex(const std::string& key) {
|
||||
engine::VecIndexPtr
|
||||
CpuCacheMgr::GetIndex(const std::string& key) {
|
||||
DataObjPtr obj = GetItem(key);
|
||||
if(obj != nullptr) {
|
||||
if (obj != nullptr) {
|
||||
return obj->data();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -20,21 +20,26 @@
|
|||
#include "CacheMgr.h"
|
||||
#include "DataObj.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
class CpuCacheMgr : public CacheMgr<DataObjPtr> {
|
||||
private:
|
||||
private:
|
||||
CpuCacheMgr();
|
||||
|
||||
public:
|
||||
//TODO: use smart pointer instead
|
||||
static CpuCacheMgr* GetInstance();
|
||||
public:
|
||||
// TODO(myh): use smart pointer instead
|
||||
static CpuCacheMgr*
|
||||
GetInstance();
|
||||
|
||||
engine::VecIndexPtr GetIndex(const std::string& key);
|
||||
engine::VecIndexPtr
|
||||
GetIndex(const std::string& key);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "src/wrapper/VecIndex.h"
|
||||
|
@ -27,38 +26,43 @@ namespace milvus {
|
|||
namespace cache {
|
||||
|
||||
class DataObj {
|
||||
public:
|
||||
DataObj(const engine::VecIndexPtr& index)
|
||||
: index_(index)
|
||||
{}
|
||||
public:
|
||||
explicit DataObj(const engine::VecIndexPtr& index) : index_(index) {
|
||||
}
|
||||
|
||||
DataObj(const engine::VecIndexPtr& index, int64_t size)
|
||||
: index_(index),
|
||||
size_(size)
|
||||
{}
|
||||
DataObj(const engine::VecIndexPtr& index, int64_t size) : index_(index), size_(size) {
|
||||
}
|
||||
|
||||
engine::VecIndexPtr data() { return index_; }
|
||||
const engine::VecIndexPtr& data() const { return index_; }
|
||||
engine::VecIndexPtr
|
||||
data() {
|
||||
return index_;
|
||||
}
|
||||
|
||||
int64_t size() const {
|
||||
if(index_ == nullptr) {
|
||||
const engine::VecIndexPtr&
|
||||
data() const {
|
||||
return index_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
size() const {
|
||||
if (index_ == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(size_ > 0) {
|
||||
if (size_ > 0) {
|
||||
return size_;
|
||||
}
|
||||
|
||||
return index_->Count() * index_->Dimension() * sizeof(float);
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
engine::VecIndexPtr index_ = nullptr;
|
||||
int64_t size_ = 0;
|
||||
};
|
||||
|
||||
using DataObjPtr = std::shared_ptr<DataObj>;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,11 +15,12 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "cache/GpuCacheMgr.h"
|
||||
#include "server/Config.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include <sstream>
|
||||
#include "utils/Log.h"
|
||||
#include "GpuCacheMgr.h"
|
||||
#include "server/Config.h"
|
||||
#include <utility>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
|
@ -29,35 +30,36 @@ std::mutex GpuCacheMgr::mutex_;
|
|||
std::unordered_map<uint64_t, GpuCacheMgrPtr> GpuCacheMgr::instance_;
|
||||
|
||||
namespace {
|
||||
constexpr int64_t G_BYTE = 1024 * 1024 * 1024;
|
||||
constexpr int64_t G_BYTE = 1024 * 1024 * 1024;
|
||||
}
|
||||
|
||||
GpuCacheMgr::GpuCacheMgr() {
|
||||
server::Config& config = server::Config::GetInstance();
|
||||
Status s;
|
||||
|
||||
int32_t gpu_mem_cap;
|
||||
s = config.GetCacheConfigGpuMemCapacity(gpu_mem_cap);
|
||||
int32_t gpu_cache_cap;
|
||||
s = config.GetCacheConfigGpuCacheCapacity(gpu_cache_cap);
|
||||
if (!s.ok()) {
|
||||
SERVER_LOG_ERROR << s.message();
|
||||
}
|
||||
int32_t cap = gpu_mem_cap * G_BYTE;
|
||||
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL<<32);
|
||||
int32_t cap = gpu_cache_cap * G_BYTE;
|
||||
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
|
||||
|
||||
float gpu_mem_threshold;
|
||||
s = config.GetCacheConfigGpuMemThreshold(gpu_mem_threshold);
|
||||
s = config.GetCacheConfigGpuCacheThreshold(gpu_mem_threshold);
|
||||
if (!s.ok()) {
|
||||
SERVER_LOG_ERROR << s.message();
|
||||
}
|
||||
if (gpu_mem_threshold > 0.0 && gpu_mem_threshold <= 1.0) {
|
||||
cache_->set_freemem_percent(gpu_mem_threshold);
|
||||
} else {
|
||||
SERVER_LOG_ERROR << "Invalid gpu_mem_threshold: " << gpu_mem_threshold
|
||||
<< ", by default set to " << cache_->freemem_percent();
|
||||
SERVER_LOG_ERROR << "Invalid gpu_mem_threshold: " << gpu_mem_threshold << ", by default set to "
|
||||
<< cache_->freemem_percent();
|
||||
}
|
||||
}
|
||||
|
||||
GpuCacheMgr* GpuCacheMgr::GetInstance(uint64_t gpu_id) {
|
||||
GpuCacheMgr*
|
||||
GpuCacheMgr::GetInstance(uint64_t gpu_id) {
|
||||
if (instance_.find(gpu_id) == instance_.end()) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (instance_.find(gpu_id) == instance_.end()) {
|
||||
|
@ -70,15 +72,16 @@ GpuCacheMgr* GpuCacheMgr::GetInstance(uint64_t gpu_id) {
|
|||
}
|
||||
}
|
||||
|
||||
engine::VecIndexPtr GpuCacheMgr::GetIndex(const std::string& key) {
|
||||
engine::VecIndexPtr
|
||||
GpuCacheMgr::GetIndex(const std::string& key) {
|
||||
DataObjPtr obj = GetItem(key);
|
||||
if(obj != nullptr) {
|
||||
if (obj != nullptr) {
|
||||
return obj->data();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include "CacheMgr.h"
|
||||
#include "DataObj.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
|
@ -30,18 +30,20 @@ class GpuCacheMgr;
|
|||
using GpuCacheMgrPtr = std::shared_ptr<GpuCacheMgr>;
|
||||
|
||||
class GpuCacheMgr : public CacheMgr<DataObjPtr> {
|
||||
public:
|
||||
public:
|
||||
GpuCacheMgr();
|
||||
|
||||
static GpuCacheMgr* GetInstance(uint64_t gpu_id);
|
||||
static GpuCacheMgr*
|
||||
GetInstance(uint64_t gpu_id);
|
||||
|
||||
engine::VecIndexPtr GetIndex(const std::string& key);
|
||||
engine::VecIndexPtr
|
||||
GetIndex(const std::string& key);
|
||||
|
||||
private:
|
||||
private:
|
||||
static std::mutex mutex_;
|
||||
static std::unordered_map<uint64_t, GpuCacheMgrPtr> instance_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,28 +15,30 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <list>
|
||||
#include <cstddef>
|
||||
#include <list>
|
||||
#include <stdexcept>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace cache {
|
||||
|
||||
template<typename key_t, typename value_t>
|
||||
template <typename key_t, typename value_t>
|
||||
class LRU {
|
||||
public:
|
||||
public:
|
||||
typedef typename std::pair<key_t, value_t> key_value_pair_t;
|
||||
typedef typename std::list<key_value_pair_t>::iterator list_iterator_t;
|
||||
typedef typename std::list<key_value_pair_t>::reverse_iterator reverse_list_iterator_t;
|
||||
|
||||
LRU(size_t max_size) : max_size_(max_size) {}
|
||||
explicit LRU(size_t max_size) : max_size_(max_size) {
|
||||
}
|
||||
|
||||
void put(const key_t& key, const value_t& value) {
|
||||
void
|
||||
put(const key_t& key, const value_t& value) {
|
||||
auto it = cache_items_map_.find(key);
|
||||
cache_items_list_.push_front(key_value_pair_t(key, value));
|
||||
if (it != cache_items_map_.end()) {
|
||||
|
@ -53,7 +55,8 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
const value_t& get(const key_t& key) {
|
||||
const value_t&
|
||||
get(const key_t& key) {
|
||||
auto it = cache_items_map_.find(key);
|
||||
if (it == cache_items_map_.end()) {
|
||||
throw std::range_error("There is no such key in cache");
|
||||
|
@ -63,7 +66,8 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
void erase(const key_t& key) {
|
||||
void
|
||||
erase(const key_t& key) {
|
||||
auto it = cache_items_map_.find(key);
|
||||
if (it != cache_items_map_.end()) {
|
||||
cache_items_list_.erase(it->second);
|
||||
|
@ -71,44 +75,50 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
bool exists(const key_t& key) const {
|
||||
bool
|
||||
exists(const key_t& key) const {
|
||||
return cache_items_map_.find(key) != cache_items_map_.end();
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
size_t
|
||||
size() const {
|
||||
return cache_items_map_.size();
|
||||
}
|
||||
|
||||
list_iterator_t begin() {
|
||||
list_iterator_t
|
||||
begin() {
|
||||
iter_ = cache_items_list_.begin();
|
||||
return iter_;
|
||||
}
|
||||
|
||||
list_iterator_t end() {
|
||||
list_iterator_t
|
||||
end() {
|
||||
return cache_items_list_.end();
|
||||
}
|
||||
|
||||
reverse_list_iterator_t rbegin() {
|
||||
reverse_list_iterator_t
|
||||
rbegin() {
|
||||
return cache_items_list_.rbegin();
|
||||
}
|
||||
|
||||
reverse_list_iterator_t rend() {
|
||||
reverse_list_iterator_t
|
||||
rend() {
|
||||
return cache_items_list_.rend();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
void
|
||||
clear() {
|
||||
cache_items_list_.clear();
|
||||
cache_items_map_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
std::list<key_value_pair_t> cache_items_list_;
|
||||
std::unordered_map<key_t, list_iterator_t> cache_items_map_;
|
||||
size_t max_size_;
|
||||
list_iterator_t iter_;
|
||||
};
|
||||
|
||||
} // cache
|
||||
} // milvus
|
||||
} // zilliz
|
||||
|
||||
} // namespace cache
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -22,12 +22,12 @@ namespace zilliz {
|
|||
namespace milvus {
|
||||
namespace server {
|
||||
|
||||
ConfigMgr *
|
||||
ConfigMgr*
|
||||
ConfigMgr::GetInstance() {
|
||||
static YamlConfigMgr mgr;
|
||||
return &mgr;
|
||||
}
|
||||
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "utils/Error.h"
|
||||
#include "ConfigNode.h"
|
||||
#include "utils/Error.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
|
@ -42,16 +42,22 @@ namespace server {
|
|||
|
||||
class ConfigMgr {
|
||||
public:
|
||||
static ConfigMgr *GetInstance();
|
||||
static ConfigMgr*
|
||||
GetInstance();
|
||||
|
||||
virtual ErrorCode LoadConfigFile(const std::string &filename) = 0;
|
||||
virtual void Print() const = 0;//will be deleted
|
||||
virtual std::string DumpString() const = 0;
|
||||
virtual ErrorCode
|
||||
LoadConfigFile(const std::string& filename) = 0;
|
||||
virtual void
|
||||
Print() const = 0; // will be deleted
|
||||
virtual std::string
|
||||
DumpString() const = 0;
|
||||
|
||||
virtual const ConfigNode &GetRootNode() const = 0;
|
||||
virtual ConfigNode &GetRootNode() = 0;
|
||||
virtual const ConfigNode&
|
||||
GetRootNode() const = 0;
|
||||
virtual ConfigNode&
|
||||
GetRootNode() = 0;
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -19,51 +19,51 @@
|
|||
#include "utils/Error.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace server {
|
||||
|
||||
void
|
||||
ConfigNode::Combine(const ConfigNode &target) {
|
||||
const std::map<std::string, std::string> &kv = target.GetConfig();
|
||||
ConfigNode::Combine(const ConfigNode& target) {
|
||||
const std::map<std::string, std::string>& kv = target.GetConfig();
|
||||
for (auto itr = kv.begin(); itr != kv.end(); ++itr) {
|
||||
config_[itr->first] = itr->second;
|
||||
}
|
||||
|
||||
const std::map<std::string, std::vector<std::string> > &sequences = target.GetSequences();
|
||||
const std::map<std::string, std::vector<std::string> >& sequences = target.GetSequences();
|
||||
for (auto itr = sequences.begin(); itr != sequences.end(); ++itr) {
|
||||
sequences_[itr->first] = itr->second;
|
||||
}
|
||||
|
||||
const std::map<std::string, ConfigNode> &children = target.GetChildren();
|
||||
const std::map<std::string, ConfigNode>& children = target.GetChildren();
|
||||
for (auto itr = children.begin(); itr != children.end(); ++itr) {
|
||||
children_[itr->first] = itr->second;
|
||||
}
|
||||
}
|
||||
|
||||
//key/value pair config
|
||||
// key/value pair config
|
||||
void
|
||||
ConfigNode::SetValue(const std::string &key, const std::string &value) {
|
||||
ConfigNode::SetValue(const std::string& key, const std::string& value) {
|
||||
config_[key] = value;
|
||||
}
|
||||
|
||||
std::string
|
||||
ConfigNode::GetValue(const std::string ¶m_key, const std::string &default_val) const {
|
||||
ConfigNode::GetValue(const std::string& param_key, const std::string& default_val) const {
|
||||
auto ref = config_.find(param_key);
|
||||
if (ref != config_.end()) {
|
||||
return ref->second;
|
||||
}
|
||||
|
||||
//THROW_UNEXPECTED_ERROR("Can't find parameter key: " + param_key);
|
||||
// THROW_UNEXPECTED_ERROR("Can't find parameter key: " + param_key);
|
||||
return default_val;
|
||||
}
|
||||
|
||||
bool
|
||||
ConfigNode::GetBoolValue(const std::string ¶m_key, bool default_val) const {
|
||||
ConfigNode::GetBoolValue(const std::string& param_key, bool default_val) const {
|
||||
std::string val = GetValue(param_key);
|
||||
if (!val.empty()) {
|
||||
std::transform(val.begin(), val.end(), val.begin(), ::tolower);
|
||||
|
@ -74,17 +74,17 @@ ConfigNode::GetBoolValue(const std::string ¶m_key, bool default_val) const {
|
|||
}
|
||||
|
||||
int32_t
|
||||
ConfigNode::GetInt32Value(const std::string ¶m_key, int32_t default_val) const {
|
||||
ConfigNode::GetInt32Value(const std::string& param_key, int32_t default_val) const {
|
||||
std::string val = GetValue(param_key);
|
||||
if (!val.empty()) {
|
||||
return (int32_t) std::strtol(val.c_str(), nullptr, 10);
|
||||
return (int32_t)std::strtol(val.c_str(), nullptr, 10);
|
||||
} else {
|
||||
return default_val;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t
|
||||
ConfigNode::GetInt64Value(const std::string ¶m_key, int64_t default_val) const {
|
||||
ConfigNode::GetInt64Value(const std::string& param_key, int64_t default_val) const {
|
||||
std::string val = GetValue(param_key);
|
||||
if (!val.empty()) {
|
||||
return std::strtol(val.c_str(), nullptr, 10);
|
||||
|
@ -94,7 +94,7 @@ ConfigNode::GetInt64Value(const std::string ¶m_key, int64_t default_val) con
|
|||
}
|
||||
|
||||
float
|
||||
ConfigNode::GetFloatValue(const std::string ¶m_key, float default_val) const {
|
||||
ConfigNode::GetFloatValue(const std::string& param_key, float default_val) const {
|
||||
std::string val = GetValue(param_key);
|
||||
if (!val.empty()) {
|
||||
return std::strtof(val.c_str(), nullptr);
|
||||
|
@ -104,7 +104,7 @@ ConfigNode::GetFloatValue(const std::string ¶m_key, float default_val) const
|
|||
}
|
||||
|
||||
double
|
||||
ConfigNode::GetDoubleValue(const std::string ¶m_key, double default_val) const {
|
||||
ConfigNode::GetDoubleValue(const std::string& param_key, double default_val) const {
|
||||
std::string val = GetValue(param_key);
|
||||
if (!val.empty()) {
|
||||
return std::strtod(val.c_str(), nullptr);
|
||||
|
@ -113,7 +113,7 @@ ConfigNode::GetDoubleValue(const std::string ¶m_key, double default_val) con
|
|||
}
|
||||
}
|
||||
|
||||
const std::map<std::string, std::string> &
|
||||
const std::map<std::string, std::string>&
|
||||
ConfigNode::GetConfig() const {
|
||||
return config_;
|
||||
}
|
||||
|
@ -123,14 +123,14 @@ ConfigNode::ClearConfig() {
|
|||
config_.clear();
|
||||
}
|
||||
|
||||
//key/object config
|
||||
// key/object config
|
||||
void
|
||||
ConfigNode::AddChild(const std::string &type_name, const ConfigNode &config) {
|
||||
ConfigNode::AddChild(const std::string& type_name, const ConfigNode& config) {
|
||||
children_[type_name] = config;
|
||||
}
|
||||
|
||||
ConfigNode
|
||||
ConfigNode::GetChild(const std::string &type_name) const {
|
||||
ConfigNode::GetChild(const std::string& type_name) const {
|
||||
auto ref = children_.find(type_name);
|
||||
if (ref != children_.end()) {
|
||||
return ref->second;
|
||||
|
@ -140,20 +140,20 @@ ConfigNode::GetChild(const std::string &type_name) const {
|
|||
return nc;
|
||||
}
|
||||
|
||||
ConfigNode &
|
||||
ConfigNode::GetChild(const std::string &type_name) {
|
||||
ConfigNode&
|
||||
ConfigNode::GetChild(const std::string& type_name) {
|
||||
return children_[type_name];
|
||||
}
|
||||
|
||||
void
|
||||
ConfigNode::GetChildren(ConfigNodeArr &arr) const {
|
||||
ConfigNode::GetChildren(ConfigNodeArr& arr) const {
|
||||
arr.clear();
|
||||
for (auto ref : children_) {
|
||||
arr.push_back(ref.second);
|
||||
}
|
||||
}
|
||||
|
||||
const std::map<std::string, ConfigNode> &
|
||||
const std::map<std::string, ConfigNode>&
|
||||
ConfigNode::GetChildren() const {
|
||||
return children_;
|
||||
}
|
||||
|
@ -163,14 +163,14 @@ ConfigNode::ClearChildren() {
|
|||
children_.clear();
|
||||
}
|
||||
|
||||
//key/sequence config
|
||||
// key/sequence config
|
||||
void
|
||||
ConfigNode::AddSequenceItem(const std::string &key, const std::string &item) {
|
||||
ConfigNode::AddSequenceItem(const std::string& key, const std::string& item) {
|
||||
sequences_[key].push_back(item);
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
ConfigNode::GetSequence(const std::string &key) const {
|
||||
ConfigNode::GetSequence(const std::string& key) const {
|
||||
auto itr = sequences_.find(key);
|
||||
if (itr != sequences_.end()) {
|
||||
return itr->second;
|
||||
|
@ -180,7 +180,7 @@ ConfigNode::GetSequence(const std::string &key) const {
|
|||
}
|
||||
}
|
||||
|
||||
const std::map<std::string, std::vector<std::string> > &
|
||||
const std::map<std::string, std::vector<std::string> >&
|
||||
ConfigNode::GetSequences() const {
|
||||
return sequences_;
|
||||
}
|
||||
|
@ -191,40 +191,40 @@ ConfigNode::ClearSequences() {
|
|||
}
|
||||
|
||||
void
|
||||
ConfigNode::PrintAll(const std::string &prefix) const {
|
||||
for (auto &elem : config_) {
|
||||
ConfigNode::PrintAll(const std::string& prefix) const {
|
||||
for (auto& elem : config_) {
|
||||
SERVER_LOG_INFO << prefix << elem.first + ": " << elem.second;
|
||||
}
|
||||
|
||||
for (auto &elem : sequences_) {
|
||||
for (auto& elem : sequences_) {
|
||||
SERVER_LOG_INFO << prefix << elem.first << ": ";
|
||||
for (auto &str : elem.second) {
|
||||
for (auto& str : elem.second) {
|
||||
SERVER_LOG_INFO << prefix << " - " << str;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &elem : children_) {
|
||||
for (auto& elem : children_) {
|
||||
SERVER_LOG_INFO << prefix << elem.first << ": ";
|
||||
elem.second.PrintAll(prefix + " ");
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
ConfigNode::DumpString(const std::string &prefix) const {
|
||||
ConfigNode::DumpString(const std::string& prefix) const {
|
||||
std::stringstream str_buffer;
|
||||
const std::string endl = "\n";
|
||||
for (auto &elem : config_) {
|
||||
for (auto& elem : config_) {
|
||||
str_buffer << prefix << elem.first << ": " << elem.second << endl;
|
||||
}
|
||||
|
||||
for (auto &elem : sequences_) {
|
||||
for (auto& elem : sequences_) {
|
||||
str_buffer << prefix << elem.first << ": " << endl;
|
||||
for (auto &str : elem.second) {
|
||||
for (auto& str : elem.second) {
|
||||
str_buffer << prefix + " - " << str << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &elem : children_) {
|
||||
for (auto& elem : children_) {
|
||||
str_buffer << prefix << elem.first << ": " << endl;
|
||||
str_buffer << elem.second.DumpString(prefix + " ") << endl;
|
||||
}
|
||||
|
@ -232,6 +232,6 @@ ConfigNode::DumpString(const std::string &prefix) const {
|
|||
return str_buffer.str();
|
||||
}
|
||||
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
|
@ -30,39 +30,61 @@ typedef std::vector<ConfigNode> ConfigNodeArr;
|
|||
|
||||
class ConfigNode {
|
||||
public:
|
||||
void Combine(const ConfigNode &target);
|
||||
void
|
||||
Combine(const ConfigNode& target);
|
||||
|
||||
//key/value pair config
|
||||
void SetValue(const std::string &key, const std::string &value);
|
||||
// key/value pair config
|
||||
void
|
||||
SetValue(const std::string& key, const std::string& value);
|
||||
|
||||
std::string GetValue(const std::string ¶m_key, const std::string &default_val = "") const;
|
||||
bool GetBoolValue(const std::string ¶m_key, bool default_val = false) const;
|
||||
int32_t GetInt32Value(const std::string ¶m_key, int32_t default_val = 0) const;
|
||||
int64_t GetInt64Value(const std::string ¶m_key, int64_t default_val = 0) const;
|
||||
float GetFloatValue(const std::string ¶m_key, float default_val = 0.0) const;
|
||||
double GetDoubleValue(const std::string ¶m_key, double default_val = 0.0) const;
|
||||
std::string
|
||||
GetValue(const std::string& param_key, const std::string& default_val = "") const;
|
||||
bool
|
||||
GetBoolValue(const std::string& param_key, bool default_val = false) const;
|
||||
int32_t
|
||||
GetInt32Value(const std::string& param_key, int32_t default_val = 0) const;
|
||||
int64_t
|
||||
GetInt64Value(const std::string& param_key, int64_t default_val = 0) const;
|
||||
float
|
||||
GetFloatValue(const std::string& param_key, float default_val = 0.0) const;
|
||||
double
|
||||
GetDoubleValue(const std::string& param_key, double default_val = 0.0) const;
|
||||
|
||||
const std::map<std::string, std::string> &GetConfig() const;
|
||||
void ClearConfig();
|
||||
const std::map<std::string, std::string>&
|
||||
GetConfig() const;
|
||||
void
|
||||
ClearConfig();
|
||||
|
||||
//key/object config
|
||||
void AddChild(const std::string &type_name, const ConfigNode &config);
|
||||
ConfigNode GetChild(const std::string &type_name) const;
|
||||
ConfigNode &GetChild(const std::string &type_name);
|
||||
void GetChildren(ConfigNodeArr &arr) const;
|
||||
// key/object config
|
||||
void
|
||||
AddChild(const std::string& type_name, const ConfigNode& config);
|
||||
ConfigNode
|
||||
GetChild(const std::string& type_name) const;
|
||||
ConfigNode&
|
||||
GetChild(const std::string& type_name);
|
||||
void
|
||||
GetChildren(ConfigNodeArr& arr) const;
|
||||
|
||||
const std::map<std::string, ConfigNode> &GetChildren() const;
|
||||
void ClearChildren();
|
||||
const std::map<std::string, ConfigNode>&
|
||||
GetChildren() const;
|
||||
void
|
||||
ClearChildren();
|
||||
|
||||
//key/sequence config
|
||||
void AddSequenceItem(const std::string &key, const std::string &item);
|
||||
std::vector<std::string> GetSequence(const std::string &key) const;
|
||||
// key/sequence config
|
||||
void
|
||||
AddSequenceItem(const std::string& key, const std::string& item);
|
||||
std::vector<std::string>
|
||||
GetSequence(const std::string& key) const;
|
||||
|
||||
const std::map<std::string, std::vector<std::string> > &GetSequences() const;
|
||||
void ClearSequences();
|
||||
const std::map<std::string, std::vector<std::string> >&
|
||||
GetSequences() const;
|
||||
void
|
||||
ClearSequences();
|
||||
|
||||
void PrintAll(const std::string &prefix = "") const;
|
||||
std::string DumpString(const std::string &prefix = "") const;
|
||||
void
|
||||
PrintAll(const std::string& prefix = "") const;
|
||||
std::string
|
||||
DumpString(const std::string& prefix = "") const;
|
||||
|
||||
private:
|
||||
std::map<std::string, std::string> config_;
|
||||
|
@ -70,6 +92,6 @@ class ConfigNode {
|
|||
std::map<std::string, std::vector<std::string> > sequences_;
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace milvus {
|
|||
namespace server {
|
||||
|
||||
ErrorCode
|
||||
YamlConfigMgr::LoadConfigFile(const std::string &filename) {
|
||||
YamlConfigMgr::LoadConfigFile(const std::string& filename) {
|
||||
struct stat directoryStat;
|
||||
int statOK = stat(filename.c_str(), &directoryStat);
|
||||
if (statOK != 0) {
|
||||
|
@ -36,8 +36,7 @@ YamlConfigMgr::LoadConfigFile(const std::string &filename) {
|
|||
try {
|
||||
node_ = YAML::LoadFile(filename);
|
||||
LoadConfigNode(node_, config_);
|
||||
}
|
||||
catch (YAML::Exception &e) {
|
||||
} catch (YAML::Exception& e) {
|
||||
SERVER_LOG_ERROR << "Failed to load config file: " << std::string(e.what());
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
@ -56,20 +55,18 @@ YamlConfigMgr::DumpString() const {
|
|||
return config_.DumpString("");
|
||||
}
|
||||
|
||||
const ConfigNode &
|
||||
const ConfigNode&
|
||||
YamlConfigMgr::GetRootNode() const {
|
||||
return config_;
|
||||
}
|
||||
|
||||
ConfigNode &
|
||||
ConfigNode&
|
||||
YamlConfigMgr::GetRootNode() {
|
||||
return config_;
|
||||
}
|
||||
|
||||
bool
|
||||
YamlConfigMgr::SetConfigValue(const YAML::Node &node,
|
||||
const std::string &key,
|
||||
ConfigNode &config) {
|
||||
YamlConfigMgr::SetConfigValue(const YAML::Node& node, const std::string& key, ConfigNode& config) {
|
||||
if (node[key].IsDefined()) {
|
||||
config.SetValue(key, node[key].as<std::string>());
|
||||
return true;
|
||||
|
@ -78,9 +75,7 @@ YamlConfigMgr::SetConfigValue(const YAML::Node &node,
|
|||
}
|
||||
|
||||
bool
|
||||
YamlConfigMgr::SetChildConfig(const YAML::Node &node,
|
||||
const std::string &child_name,
|
||||
ConfigNode &config) {
|
||||
YamlConfigMgr::SetChildConfig(const YAML::Node& node, const std::string& child_name, ConfigNode& config) {
|
||||
if (node[child_name].IsDefined()) {
|
||||
ConfigNode sub_config;
|
||||
LoadConfigNode(node[child_name], sub_config);
|
||||
|
@ -91,9 +86,7 @@ YamlConfigMgr::SetChildConfig(const YAML::Node &node,
|
|||
}
|
||||
|
||||
bool
|
||||
YamlConfigMgr::SetSequence(const YAML::Node &node,
|
||||
const std::string &child_name,
|
||||
ConfigNode &config) {
|
||||
YamlConfigMgr::SetSequence(const YAML::Node& node, const std::string& child_name, ConfigNode& config) {
|
||||
if (node[child_name].IsDefined()) {
|
||||
size_t cnt = node[child_name].size();
|
||||
for (size_t i = 0; i < cnt; i++) {
|
||||
|
@ -105,7 +98,7 @@ YamlConfigMgr::SetSequence(const YAML::Node &node,
|
|||
}
|
||||
|
||||
void
|
||||
YamlConfigMgr::LoadConfigNode(const YAML::Node &node, ConfigNode &config) {
|
||||
YamlConfigMgr::LoadConfigNode(const YAML::Node& node, ConfigNode& config) {
|
||||
std::string key;
|
||||
for (YAML::const_iterator it = node.begin(); it != node.end(); ++it) {
|
||||
if (!it->first.IsNull()) {
|
||||
|
@ -121,6 +114,6 @@ YamlConfigMgr::LoadConfigNode(const YAML::Node &node, ConfigNode &config) {
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
#include "ConfigNode.h"
|
||||
#include "utils/Error.h"
|
||||
|
||||
#include <string>
|
||||
#include <yaml-cpp/yaml.h>
|
||||
#include <string>
|
||||
|
||||
namespace zilliz {
|
||||
namespace milvus {
|
||||
|
@ -30,34 +30,36 @@ namespace server {
|
|||
|
||||
class YamlConfigMgr : public ConfigMgr {
|
||||
public:
|
||||
virtual ErrorCode LoadConfigFile(const std::string &filename);
|
||||
virtual void Print() const;
|
||||
virtual std::string DumpString() const;
|
||||
virtual ErrorCode
|
||||
LoadConfigFile(const std::string& filename);
|
||||
virtual void
|
||||
Print() const;
|
||||
virtual std::string
|
||||
DumpString() const;
|
||||
|
||||
virtual const ConfigNode &GetRootNode() const;
|
||||
virtual ConfigNode &GetRootNode();
|
||||
virtual const ConfigNode&
|
||||
GetRootNode() const;
|
||||
virtual ConfigNode&
|
||||
GetRootNode();
|
||||
|
||||
private:
|
||||
bool SetConfigValue(const YAML::Node &node,
|
||||
const std::string &key,
|
||||
ConfigNode &config);
|
||||
|
||||
bool SetChildConfig(const YAML::Node &node,
|
||||
const std::string &name,
|
||||
ConfigNode &config);
|
||||
bool
|
||||
SetConfigValue(const YAML::Node& node, const std::string& key, ConfigNode& config);
|
||||
|
||||
bool
|
||||
SetSequence(const YAML::Node &node,
|
||||
const std::string &child_name,
|
||||
ConfigNode &config);
|
||||
SetChildConfig(const YAML::Node& node, const std::string& child_name, ConfigNode& config);
|
||||
|
||||
void LoadConfigNode(const YAML::Node &node, ConfigNode &config);
|
||||
bool
|
||||
SetSequence(const YAML::Node& node, const std::string& child_name, ConfigNode& config);
|
||||
|
||||
void
|
||||
LoadConfigNode(const YAML::Node& node, ConfigNode& config);
|
||||
|
||||
private:
|
||||
YAML::Node node_;
|
||||
ConfigNode config_;
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -46,9 +46,11 @@ if(NOT CMAKE_BUILD_TYPE)
|
|||
endif(NOT CMAKE_BUILD_TYPE)
|
||||
|
||||
if(CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -fopenmp")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp")
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -fopenmp")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp")
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g")
|
||||
endif()
|
||||
MESSAGE(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
|
||||
|
||||
|
@ -93,7 +95,7 @@ endif()
|
|||
set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
|
||||
if(BUILD_UNIT_TEST STREQUAL "ON")
|
||||
# add_subdirectory(test)
|
||||
add_subdirectory(unittest)
|
||||
endif()
|
||||
|
||||
config_summary()
|
||||
|
|
|
@ -8,7 +8,6 @@ link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
|||
include_directories(${CORE_SOURCE_DIR}/knowhere)
|
||||
include_directories(${CORE_SOURCE_DIR}/thirdparty)
|
||||
include_directories(${CORE_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
|
||||
include_directories(${CORE_SOURCE_DIR}/thirdparty/jsoncons-0.126.0/include)
|
||||
|
||||
set(SPTAG_SOURCE_DIR ${CORE_SOURCE_DIR}/thirdparty/SPTAG)
|
||||
file(GLOB HDR_FILES
|
||||
|
@ -55,6 +54,7 @@ set(index_srcs
|
|||
knowhere/index/vector_index/IndexGPUIVFPQ.cpp
|
||||
knowhere/index/vector_index/FaissBaseIndex.cpp
|
||||
knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
)
|
||||
|
||||
set(depend_libs
|
||||
|
@ -117,7 +117,6 @@ set(CORE_INCLUDE_DIRS
|
|||
${CORE_SOURCE_DIR}/knowhere
|
||||
${CORE_SOURCE_DIR}/thirdparty
|
||||
${CORE_SOURCE_DIR}/thirdparty/SPTAG/AnnService
|
||||
${CORE_SOURCE_DIR}/thirdparty/jsoncons-0.126.0/include
|
||||
${ARROW_INCLUDE_DIR}
|
||||
${FAISS_INCLUDE_DIR}
|
||||
${OPENBLAS_INCLUDE_DIR}
|
||||
|
@ -129,8 +128,6 @@ set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
|
|||
|
||||
#INSTALL(DIRECTORY
|
||||
# ${CORE_SOURCE_DIR}/include/knowhere
|
||||
# ${CORE_SOURCE_DIR}/thirdparty/jsoncons-0.126.0/include/jsoncons
|
||||
# ${CORE_SOURCE_DIR}/thirdparty/jsoncons-0.126.0/include/jsoncons_ext
|
||||
# ${ARROW_INCLUDE_DIR}/arrow
|
||||
# ${FAISS_PREFIX}/include/faiss
|
||||
# ${OPENBLAS_INCLUDE_DIR}/
|
||||
|
|
|
@ -15,42 +15,41 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include "ArrowAdapter.h"
|
||||
#include "knowhere/adapter/ArrowAdapter.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
ArrayPtr
|
||||
CopyArray(const ArrayPtr &origin) {
|
||||
CopyArray(const ArrayPtr& origin) {
|
||||
ArrayPtr copy = nullptr;
|
||||
auto copy_data = origin->data()->Copy();
|
||||
switch (origin->type_id()) {
|
||||
#define DEFINE_TYPE(type, clazz) \
|
||||
case arrow::Type::type: { \
|
||||
copy = std::make_shared<arrow::clazz>(copy_data); \
|
||||
}
|
||||
#define DEFINE_TYPE(type, clazz) \
|
||||
case arrow::Type::type: { \
|
||||
copy = std::make_shared<arrow::clazz>(copy_data); \
|
||||
}
|
||||
DEFINE_TYPE(BOOL, BooleanArray)
|
||||
DEFINE_TYPE(BINARY, BinaryArray)
|
||||
DEFINE_TYPE(FIXED_SIZE_BINARY, FixedSizeBinaryArray)
|
||||
DEFINE_TYPE(DECIMAL, Decimal128Array)
|
||||
DEFINE_TYPE(FLOAT, NumericArray<arrow::FloatType>)
|
||||
DEFINE_TYPE(INT64, NumericArray<arrow::Int64Type>)
|
||||
default:break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return copy;
|
||||
}
|
||||
|
||||
SchemaPtr
|
||||
CopySchema(const SchemaPtr &origin) {
|
||||
CopySchema(const SchemaPtr& origin) {
|
||||
std::vector<std::shared_ptr<Field>> fields;
|
||||
for (auto &field : origin->fields()) {
|
||||
auto copy = std::make_shared<Field>(field->name(), field->type(),field->nullable(), nullptr);
|
||||
for (auto& field : origin->fields()) {
|
||||
auto copy = std::make_shared<Field>(field->name(), field->type(), field->nullable(), nullptr);
|
||||
fields.emplace_back(copy);
|
||||
}
|
||||
return std::make_shared<Schema>(std::move(fields));
|
||||
}
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,22 +15,22 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/common/Array.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
ArrayPtr
|
||||
CopyArray(const ArrayPtr &origin);
|
||||
CopyArray(const ArrayPtr& origin);
|
||||
|
||||
SchemaPtr
|
||||
CopySchema(const SchemaPtr &origin);
|
||||
CopySchema(const SchemaPtr& origin);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,36 +15,31 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "knowhere/adapter/SptagAdapter.h"
|
||||
#include "knowhere/adapter/Structure.h"
|
||||
#include "knowhere/index/vector_index/helpers/Definitions.h"
|
||||
#include "SptagAdapter.h"
|
||||
#include "Structure.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
std::shared_ptr<SPTAG::MetadataSet>
|
||||
ConvertToMetadataSet(const DatasetPtr &dataset) {
|
||||
ConvertToMetadataSet(const DatasetPtr& dataset) {
|
||||
auto array = dataset->array()[0];
|
||||
auto elems = array->length();
|
||||
|
||||
auto p_data = array->data()->GetValues<int64_t>(1, 0);
|
||||
auto p_offset = (int64_t *) malloc(sizeof(int64_t) * elems);
|
||||
for (auto i = 0; i <= elems; ++i)
|
||||
p_offset[i] = i * 8;
|
||||
|
||||
std::shared_ptr<SPTAG::MetadataSet> metaset(new SPTAG::MemMetadataSet(
|
||||
SPTAG::ByteArray((std::uint8_t *) p_data, elems * sizeof(int64_t), false),
|
||||
SPTAG::ByteArray((std::uint8_t *) p_offset, elems * sizeof(int64_t), true),
|
||||
elems));
|
||||
auto p_offset = (int64_t*)malloc(sizeof(int64_t) * elems);
|
||||
for (auto i = 0; i <= elems; ++i) p_offset[i] = i * 8;
|
||||
|
||||
std::shared_ptr<SPTAG::MetadataSet> metaset(
|
||||
new SPTAG::MemMetadataSet(SPTAG::ByteArray((std::uint8_t*)p_data, elems * sizeof(int64_t), false),
|
||||
SPTAG::ByteArray((std::uint8_t*)p_offset, elems * sizeof(int64_t), true), elems));
|
||||
|
||||
return metaset;
|
||||
}
|
||||
|
||||
std::shared_ptr<SPTAG::VectorSet>
|
||||
ConvertToVectorSet(const DatasetPtr &dataset) {
|
||||
ConvertToVectorSet(const DatasetPtr& dataset) {
|
||||
auto tensor = dataset->tensor()[0];
|
||||
|
||||
auto p_data = tensor->raw_mutable_data();
|
||||
|
@ -54,23 +49,20 @@ ConvertToVectorSet(const DatasetPtr &dataset) {
|
|||
|
||||
SPTAG::ByteArray byte_array(p_data, num_bytes, false);
|
||||
|
||||
auto vectorset = std::make_shared<SPTAG::BasicVectorSet>(byte_array,
|
||||
SPTAG::VectorValueType::Float,
|
||||
dimension,
|
||||
rows);
|
||||
auto vectorset =
|
||||
std::make_shared<SPTAG::BasicVectorSet>(byte_array, SPTAG::VectorValueType::Float, dimension, rows);
|
||||
return vectorset;
|
||||
}
|
||||
|
||||
std::vector<SPTAG::QueryResult>
|
||||
ConvertToQueryResult(const DatasetPtr &dataset, const Config &config) {
|
||||
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) {
|
||||
auto tensor = dataset->tensor()[0];
|
||||
|
||||
auto p_data = (float *) tensor->raw_mutable_data();
|
||||
auto p_data = (float*)tensor->raw_mutable_data();
|
||||
auto dimension = tensor->shape()[1];
|
||||
auto rows = tensor->shape()[0];
|
||||
|
||||
auto k = config[META_K].as<int64_t>();
|
||||
std::vector<SPTAG::QueryResult> query_results(rows, SPTAG::QueryResult(nullptr, k, true));
|
||||
std::vector<SPTAG::QueryResult> query_results(rows, SPTAG::QueryResult(nullptr, config->k, true));
|
||||
for (auto i = 0; i < rows; ++i) {
|
||||
query_results[i].SetTarget(&p_data[i * dimension]);
|
||||
}
|
||||
|
@ -83,23 +75,23 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
|
|||
auto k = query_results[0].GetResultNum();
|
||||
auto elems = query_results.size() * k;
|
||||
|
||||
auto p_id = (int64_t *) malloc(sizeof(int64_t) * elems);
|
||||
auto p_dist = (float *) malloc(sizeof(float) * elems);
|
||||
// TODO: throw if malloc failed.
|
||||
auto p_id = (int64_t*)malloc(sizeof(int64_t) * elems);
|
||||
auto p_dist = (float*)malloc(sizeof(float) * elems);
|
||||
// TODO: throw if malloc failed.
|
||||
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < query_results.size(); ++i) {
|
||||
auto results = query_results[i].GetResults();
|
||||
auto num_result = query_results[i].GetResultNum();
|
||||
for (auto j = 0; j < num_result; ++j) {
|
||||
// p_id[i * k + j] = results[j].VID;
|
||||
p_id[i * k + j] = *(int64_t *) query_results[i].GetMetadata(j).Data();
|
||||
// p_id[i * k + j] = results[j].VID;
|
||||
p_id[i * k + j] = *(int64_t*)query_results[i].GetMetadata(j).Data();
|
||||
p_dist[i * k + j] = results[j].Dist;
|
||||
}
|
||||
}
|
||||
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t *) p_id, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t *) p_dist, sizeof(float) * elems);
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t*)p_id, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t*)p_dist, sizeof(float) * elems);
|
||||
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
|
||||
|
@ -110,11 +102,11 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
|
|||
|
||||
auto id_array_data = arrow::ArrayData::Make(int64_type, elems, id_bufs);
|
||||
auto dist_array_data = arrow::ArrayData::Make(float_type, elems, dist_bufs);
|
||||
// auto id_array_data = std::make_shared<ArrayData>(int64_type, sizeof(int64_t) * elems, id_bufs);
|
||||
// auto dist_array_data = std::make_shared<ArrayData>(float_type, sizeof(float) * elems, dist_bufs);
|
||||
// auto id_array_data = std::make_shared<ArrayData>(int64_type, sizeof(int64_t) * elems, id_bufs);
|
||||
// auto dist_array_data = std::make_shared<ArrayData>(float_type, sizeof(float) * elems, dist_bufs);
|
||||
|
||||
// auto ids = ConstructInt64Array((uint8_t*)p_id, sizeof(int64_t) * elems);
|
||||
// auto dists = ConstructFloatArray((uint8_t*)p_dist, sizeof(float) * elems);
|
||||
// auto ids = ConstructInt64Array((uint8_t*)p_id, sizeof(int64_t) * elems);
|
||||
// auto dists = ConstructFloatArray((uint8_t*)p_dist, sizeof(float) * elems);
|
||||
|
||||
auto ids = std::make_shared<NumericArray<arrow::Int64Type>>(id_array_data);
|
||||
auto dists = std::make_shared<NumericArray<arrow::FloatType>>(dist_array_data);
|
||||
|
@ -128,5 +120,5 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
|
|||
return std::make_shared<Dataset>(array, schema);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,12 +15,11 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/common/Dataset.h"
|
||||
|
||||
|
@ -28,16 +27,16 @@ namespace zilliz {
|
|||
namespace knowhere {
|
||||
|
||||
std::shared_ptr<SPTAG::VectorSet>
|
||||
ConvertToVectorSet(const DatasetPtr &dataset);
|
||||
ConvertToVectorSet(const DatasetPtr& dataset);
|
||||
|
||||
std::shared_ptr<SPTAG::MetadataSet>
|
||||
ConvertToMetadataSet(const DatasetPtr &dataset);
|
||||
ConvertToMetadataSet(const DatasetPtr& dataset);
|
||||
|
||||
std::vector<SPTAG::QueryResult>
|
||||
ConvertToQueryResult(const DatasetPtr &dataset, const Config &config);
|
||||
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
DatasetPtr
|
||||
ConvertToDataset(std::vector<SPTAG::QueryResult> query_results);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,15 +15,16 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "knowhere/adapter/Structure.h"
|
||||
|
||||
#include "Structure.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
ArrayPtr
|
||||
ConstructInt64ArraySmart(uint8_t *data, int64_t size) {
|
||||
ConstructInt64ArraySmart(uint8_t* data, int64_t size) {
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBufferSmart(data, size)};
|
||||
auto type = std::make_shared<arrow::Int64Type>();
|
||||
|
@ -32,7 +33,7 @@ ConstructInt64ArraySmart(uint8_t *data, int64_t size) {
|
|||
}
|
||||
|
||||
ArrayPtr
|
||||
ConstructFloatArraySmart(uint8_t *data, int64_t size) {
|
||||
ConstructFloatArraySmart(uint8_t* data, int64_t size) {
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBufferSmart(data, size)};
|
||||
auto type = std::make_shared<arrow::FloatType>();
|
||||
|
@ -41,14 +42,14 @@ ConstructFloatArraySmart(uint8_t *data, int64_t size) {
|
|||
}
|
||||
|
||||
// Wraps `data` (`size` bytes) in a buffer from MakeMutableBufferSmart and
// returns a float tensor with the given shape on top of it.
TensorPtr
ConstructFloatTensorSmart(uint8_t* data, int64_t size, std::vector<int64_t> shape) {
    auto storage = MakeMutableBufferSmart(data, size);
    auto element_type = std::make_shared<arrow::FloatType>();

    auto tensor = std::make_shared<Tensor>(element_type, storage, shape);
    return tensor;
}
|
||||
|
||||
ArrayPtr
|
||||
ConstructInt64Array(uint8_t *data, int64_t size) {
|
||||
ConstructInt64Array(uint8_t* data, int64_t size) {
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBuffer(data, size)};
|
||||
auto type = std::make_shared<arrow::Int64Type>();
|
||||
|
@ -57,7 +58,7 @@ ConstructInt64Array(uint8_t *data, int64_t size) {
|
|||
}
|
||||
|
||||
ArrayPtr
|
||||
ConstructFloatArray(uint8_t *data, int64_t size) {
|
||||
ConstructFloatArray(uint8_t* data, int64_t size) {
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBuffer(data, size)};
|
||||
auto type = std::make_shared<arrow::FloatType>();
|
||||
|
@ -66,23 +67,23 @@ ConstructFloatArray(uint8_t *data, int64_t size) {
|
|||
}
|
||||
|
||||
// Wraps `data` (`size` bytes) in a buffer from MakeMutableBuffer and returns a
// float tensor with the given shape on top of it.
TensorPtr
ConstructFloatTensor(uint8_t* data, int64_t size, std::vector<int64_t> shape) {
    auto storage = MakeMutableBuffer(data, size);
    auto element_type = std::make_shared<arrow::FloatType>();

    auto tensor = std::make_shared<Tensor>(element_type, storage, shape);
    return tensor;
}
|
||||
|
||||
// Returns a schema field called `name` whose data type is arrow::Int64Type.
FieldPtr
ConstructInt64Field(const std::string& name) {
    return std::make_shared<Field>(name, std::make_shared<arrow::Int64Type>());
}
|
||||
|
||||
|
||||
// Returns a schema field called `name` whose data type is arrow::FloatType.
FieldPtr
ConstructFloatField(const std::string& name) {
    return std::make_shared<Field>(name, std::make_shared<arrow::FloatType>());
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,40 +15,40 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/common/Dataset.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
extern ArrayPtr
|
||||
ConstructInt64ArraySmart(uint8_t *data, int64_t size);
|
||||
ConstructInt64ArraySmart(uint8_t* data, int64_t size);
|
||||
|
||||
extern ArrayPtr
|
||||
ConstructFloatArraySmart(uint8_t *data, int64_t size);
|
||||
ConstructFloatArraySmart(uint8_t* data, int64_t size);
|
||||
|
||||
extern TensorPtr
|
||||
ConstructFloatTensorSmart(uint8_t *data, int64_t size, std::vector<int64_t> shape);
|
||||
ConstructFloatTensorSmart(uint8_t* data, int64_t size, std::vector<int64_t> shape);
|
||||
|
||||
extern ArrayPtr
|
||||
ConstructInt64Array(uint8_t *data, int64_t size);
|
||||
ConstructInt64Array(uint8_t* data, int64_t size);
|
||||
|
||||
extern ArrayPtr
|
||||
ConstructFloatArray(uint8_t *data, int64_t size);
|
||||
ConstructFloatArray(uint8_t* data, int64_t size);
|
||||
|
||||
extern TensorPtr
|
||||
ConstructFloatTensor(uint8_t *data, int64_t size, std::vector<int64_t> shape);
|
||||
ConstructFloatTensor(uint8_t* data, int64_t size, std::vector<int64_t> shape);
|
||||
|
||||
extern FieldPtr
|
||||
ConstructInt64Field(const std::string &name);
|
||||
ConstructInt64Field(const std::string& name);
|
||||
|
||||
extern FieldPtr
|
||||
ConstructFloatField(const std::string &name);
|
||||
ConstructFloatField(const std::string& name);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,18 +15,16 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
#define GETTENSOR(dataset) \
|
||||
auto tensor = dataset->tensor()[0]; \
|
||||
auto p_data = tensor->raw_data(); \
|
||||
auto dim = tensor->shape()[1]; \
|
||||
auto rows = tensor->shape()[0]; \
|
||||
#define GETTENSOR(dataset) \
|
||||
auto tensor = dataset->tensor()[0]; \
|
||||
auto p_data = tensor->raw_data(); \
|
||||
auto dim = tensor->shape()[1]; \
|
||||
auto rows = tensor->shape()[0];
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,14 +15,13 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <arrow/array.h>
|
||||
#include <memory>
|
||||
|
||||
#include "Schema.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -35,9 +34,9 @@ using ArrayPtr = std::shared_ptr<Array>;
|
|||
using BooleanArray = arrow::BooleanArray;
|
||||
using BooleanArrayPtr = std::shared_ptr<arrow::BooleanArray>;
|
||||
|
||||
template<typename DType>
|
||||
template <typename DType>
|
||||
using NumericArray = arrow::NumericArray<DType>;
|
||||
template<typename DType>
|
||||
template <typename DType>
|
||||
using NumericArrayPtr = std::shared_ptr<arrow::NumericArray<DType>>;
|
||||
|
||||
using BinaryArray = arrow::BinaryArray;
|
||||
|
@ -49,6 +48,5 @@ using FixedSizeBinaryArrayPtr = std::shared_ptr<arrow::FixedSizeBinaryArray>;
|
|||
using Decimal128Array = arrow::Decimal128Array;
|
||||
using Decimal128ArrayPtr = std::shared_ptr<arrow::Decimal128Array>;
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,21 +15,19 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Id.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
struct Binary {
|
||||
ID id;
|
||||
std::shared_ptr<uint8_t> data;
|
||||
|
@ -37,29 +35,28 @@ struct Binary {
|
|||
};
|
||||
using BinaryPtr = std::shared_ptr<Binary>;
|
||||
|
||||
|
||||
class BinarySet {
|
||||
public:
|
||||
BinaryPtr
|
||||
GetByName(const std::string &name) const {
|
||||
GetByName(const std::string& name) const {
|
||||
return binary_map_.at(name);
|
||||
}
|
||||
|
||||
void
|
||||
Append(const std::string &name, BinaryPtr binary) {
|
||||
Append(const std::string& name, BinaryPtr binary) {
|
||||
binary_map_[name] = std::move(binary);
|
||||
}
|
||||
|
||||
void
|
||||
Append(const std::string &name, std::shared_ptr<uint8_t> data, int64_t size) {
|
||||
Append(const std::string& name, std::shared_ptr<uint8_t> data, int64_t size) {
|
||||
auto binary = std::make_shared<Binary>();
|
||||
binary->data = data;
|
||||
binary->size = size;
|
||||
binary_map_[name] = std::move(binary);
|
||||
}
|
||||
|
||||
//void
|
||||
//Append(const std::string &name, void *data, int64_t size, ID id) {
|
||||
// void
|
||||
// Append(const std::string &name, void *data, int64_t size, ID id) {
|
||||
// Binary binary;
|
||||
// binary.data = data;
|
||||
// binary.size = size;
|
||||
|
@ -67,7 +64,8 @@ class BinarySet {
|
|||
// binary_map_[name] = binary;
|
||||
//}
|
||||
|
||||
void clear() {
|
||||
void
|
||||
clear() {
|
||||
binary_map_.clear();
|
||||
}
|
||||
|
||||
|
@ -75,6 +73,5 @@ class BinarySet {
|
|||
std::map<std::string, BinaryPtr> binary_map_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,14 +15,12 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <arrow/buffer.h>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -34,31 +32,32 @@ using MutableBufferPtr = std::shared_ptr<MutableBuffer>;
|
|||
namespace internal {
|
||||
|
||||
struct BufferDeleter {
|
||||
void operator()(Buffer *buffer) {
|
||||
free((void *) buffer->data());
|
||||
void
|
||||
operator()(Buffer* buffer) {
|
||||
free((void*)buffer->data());
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// Creates a Buffer over `data` (`size` bytes) whose shared pointer is released
// through internal::BufferDeleter.
inline BufferPtr
MakeBufferSmart(uint8_t* data, const int64_t size) {
    BufferPtr managed(new Buffer(data, size), internal::BufferDeleter());
    return managed;
}
|
||||
|
||||
// Creates a MutableBuffer over `data` (`size` bytes) whose shared pointer is
// released through internal::BufferDeleter.
inline MutableBufferPtr
MakeMutableBufferSmart(uint8_t* data, const int64_t size) {
    MutableBufferPtr managed(new MutableBuffer(data, size), internal::BufferDeleter());
    return managed;
}
|
||||
|
||||
// Creates a Buffer over `data` (`size` bytes) with the default shared_ptr
// deleter, i.e. without internal::BufferDeleter.
inline BufferPtr
MakeBuffer(uint8_t* data, const int64_t size) {
    auto plain = std::make_shared<Buffer>(data, size);
    return plain;
}
|
||||
|
||||
// Creates a MutableBuffer over `data` (`size` bytes) with the default
// shared_ptr deleter, i.e. without internal::BufferDeleter.
inline MutableBufferPtr
MakeMutableBuffer(uint8_t* data, const int64_t size) {
    auto plain = std::make_shared<MutableBuffer>(data, size);
    return plain;
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,18 +15,44 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <jsoncons/json.hpp>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
enum class METRICTYPE {
|
||||
INVALID = 0,
|
||||
L2 = 1,
|
||||
IP = 2,
|
||||
};
|
||||
|
||||
using Config = jsoncons::json;
|
||||
// General Config
|
||||
constexpr int64_t INVALID_VALUE = -1;
|
||||
constexpr int64_t DEFAULT_K = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_DIM = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_GPUID = INVALID_VALUE;
|
||||
constexpr METRICTYPE DEFAULT_TYPE = METRICTYPE::INVALID;
|
||||
|
||||
struct Cfg {
|
||||
METRICTYPE metric_type = DEFAULT_TYPE;
|
||||
int64_t k = DEFAULT_K;
|
||||
int64_t gpu_id = DEFAULT_GPUID;
|
||||
int64_t d = DEFAULT_DIM;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
Cfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, METRICTYPE type)
|
||||
: metric_type(type), k(k), gpu_id(gpu_id), d(dim) {
|
||||
}
|
||||
|
||||
Cfg() = default;
|
||||
|
||||
virtual bool
|
||||
CheckValid() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
using Config = std::shared_ptr<Cfg>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,20 +15,19 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Array.h"
|
||||
#include "Buffer.h"
|
||||
#include "Tensor.h"
|
||||
#include "Schema.h"
|
||||
#include "Config.h"
|
||||
#include "Schema.h"
|
||||
#include "Tensor.h"
|
||||
#include "knowhere/adapter/ArrowAdapter.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -40,34 +39,38 @@ class Dataset {
|
|||
public:
|
||||
Dataset() = default;
|
||||
|
||||
Dataset(std::vector<ArrayPtr> &&array, SchemaPtr array_schema,
|
||||
std::vector<TensorPtr> &&tensor, SchemaPtr tensor_schema)
|
||||
Dataset(std::vector<ArrayPtr>&& array, SchemaPtr array_schema, std::vector<TensorPtr>&& tensor,
|
||||
SchemaPtr tensor_schema)
|
||||
: array_(std::move(array)),
|
||||
array_schema_(std::move(array_schema)),
|
||||
tensor_(std::move(tensor)),
|
||||
tensor_schema_(std::move(tensor_schema)) {}
|
||||
tensor_schema_(std::move(tensor_schema)) {
|
||||
}
|
||||
|
||||
Dataset(std::vector<ArrayPtr> array, SchemaPtr array_schema)
|
||||
: array_(std::move(array)), array_schema_(std::move(array_schema)) {}
|
||||
: array_(std::move(array)), array_schema_(std::move(array_schema)) {
|
||||
}
|
||||
|
||||
Dataset(std::vector<TensorPtr> tensor, SchemaPtr tensor_schema)
|
||||
: tensor_(std::move(tensor)), tensor_schema_(std::move(tensor_schema)) {}
|
||||
: tensor_(std::move(tensor)), tensor_schema_(std::move(tensor_schema)) {
|
||||
}
|
||||
|
||||
Dataset(const Dataset &) = delete;
|
||||
Dataset &operator=(const Dataset &) = delete;
|
||||
Dataset(const Dataset&) = delete;
|
||||
Dataset&
|
||||
operator=(const Dataset&) = delete;
|
||||
|
||||
DatasetPtr
|
||||
Clone() {
|
||||
auto dataset = std::make_shared<Dataset>();
|
||||
|
||||
std::vector<ArrayPtr> clone_array;
|
||||
for (auto &array : array_) {
|
||||
for (auto& array : array_) {
|
||||
clone_array.emplace_back(CopyArray(array));
|
||||
}
|
||||
dataset->set_array(clone_array);
|
||||
|
||||
std::vector<TensorPtr> clone_tensor;
|
||||
for (auto &tensor : tensor_) {
|
||||
for (auto& tensor : tensor_) {
|
||||
auto buffer = tensor->data();
|
||||
std::shared_ptr<Buffer> copy_buffer;
|
||||
// TODO: checkout copy success;
|
||||
|
@ -86,16 +89,20 @@ class Dataset {
|
|||
}
|
||||
|
||||
public:
|
||||
const std::vector<ArrayPtr> &
|
||||
array() const { return array_; }
|
||||
const std::vector<ArrayPtr>&
|
||||
array() const {
|
||||
return array_;
|
||||
}
|
||||
|
||||
void
|
||||
set_array(std::vector<ArrayPtr> array) {
|
||||
array_ = std::move(array);
|
||||
}
|
||||
|
||||
const std::vector<TensorPtr> &
|
||||
tensor() const { return tensor_; }
|
||||
const std::vector<TensorPtr>&
|
||||
tensor() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
void
|
||||
set_tensor(std::vector<TensorPtr> tensor) {
|
||||
|
@ -103,7 +110,9 @@ class Dataset {
|
|||
}
|
||||
|
||||
SchemaConstPtr
|
||||
array_schema() const { return array_schema_; }
|
||||
array_schema() const {
|
||||
return array_schema_;
|
||||
}
|
||||
|
||||
void
|
||||
set_array_schema(SchemaPtr array_schema) {
|
||||
|
@ -111,31 +120,32 @@ class Dataset {
|
|||
}
|
||||
|
||||
SchemaConstPtr
|
||||
tensor_schema() const { return tensor_schema_; }
|
||||
tensor_schema() const {
|
||||
return tensor_schema_;
|
||||
}
|
||||
|
||||
void
|
||||
set_tensor_schema(SchemaPtr tensor_schema) {
|
||||
tensor_schema_ = std::move(tensor_schema);
|
||||
}
|
||||
|
||||
//const Config &
|
||||
//meta() const { return meta_; }
|
||||
// const Config &
|
||||
// meta() const { return meta_; }
|
||||
|
||||
//void
|
||||
//set_meta(Config meta) {
|
||||
// void
|
||||
// set_meta(Config meta) {
|
||||
// meta_ = std::move(meta);
|
||||
//}
|
||||
|
||||
private:
|
||||
SchemaPtr array_schema_;
|
||||
SchemaPtr tensor_schema_;
|
||||
std::vector<ArrayPtr> array_;
|
||||
SchemaPtr array_schema_;
|
||||
std::vector<TensorPtr> tensor_;
|
||||
//Config meta_;
|
||||
SchemaPtr tensor_schema_;
|
||||
// Config meta_;
|
||||
};
|
||||
|
||||
using DatasetPtr = std::shared_ptr<Dataset>;
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,41 +15,37 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
#include "Exception.h"
|
||||
#include "Log.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
KnowhereException::KnowhereException(const std::string& msg) : msg(msg) {
|
||||
}
|
||||
|
||||
KnowhereException::KnowhereException(const std::string &msg):msg(msg) {}
|
||||
|
||||
// Builds the exception text "Error in <funcName> at <file>:<line>: <m>".
//
// When DEBUG is defined the file is reported exactly as passed in; otherwise
// only the part after the last '/' is reported.
KnowhereException::KnowhereException(const std::string& m, const char* funcName, const char* file, int line) {
#ifdef DEBUG
    const char* location = file;
#else
    const std::string file_path(file);
    auto const pos = file_path.find_last_of('/');
    // Keep the substring in a named std::string: taking c_str() of the
    // temporary returned by substr() would leave a dangling pointer.
    const std::string filename = (pos == std::string::npos) ? file_path : file_path.substr(pos + 1);
    const char* location = filename.c_str();
#endif

    // First call only measures the formatted length (terminator excluded).
    const int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, location, line, m.c_str());
    if (size < 0) {
        // Formatting failed; fall back to the bare message.
        msg = m;
        return;
    }

    // snprintf needs room for the terminating NUL while writing ...
    msg.resize(size + 1);
    snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", funcName, location, line, m.c_str());
    // ... but the NUL is not part of the message, so drop it again.
    msg.resize(size);
}
|
||||
|
||||
const char *KnowhereException::what() const noexcept {
|
||||
const char*
|
||||
KnowhereException::what() const noexcept {
|
||||
return msg.c_str();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,46 +15,41 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class KnowhereException : public std::exception {
|
||||
public:
|
||||
explicit KnowhereException(const std::string &msg);
|
||||
explicit KnowhereException(const std::string& msg);
|
||||
|
||||
KnowhereException(const std::string &msg, const char *funName,
|
||||
const char *file, int line);
|
||||
KnowhereException(const std::string& msg, const char* funName, const char* file, int line);
|
||||
|
||||
const char *what() const noexcept override;
|
||||
const char*
|
||||
what() const noexcept override;
|
||||
|
||||
std::string msg;
|
||||
};
|
||||
|
||||
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
|
||||
|
||||
#define KNOHWERE_ERROR_MSG(MSG)\
|
||||
printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
|
||||
#define KNOWHERE_THROW_MSG(MSG) \
|
||||
do { \
|
||||
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
|
||||
} while (false)
|
||||
|
||||
#define KNOWHERE_THROW_MSG(MSG)\
|
||||
do {\
|
||||
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__);\
|
||||
} while (false)
|
||||
#define KNOHERE_THROW_FORMAT(FMT, ...) \
|
||||
do { \
|
||||
std::string __s; \
|
||||
int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \
|
||||
__s.resize(__size + 1); \
|
||||
snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \
|
||||
throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
|
||||
} while (false)
|
||||
|
||||
#define KNOHERE_THROW_FORMAT(FMT, ...)\
|
||||
do { \
|
||||
std::string __s;\
|
||||
int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__);\
|
||||
__s.resize(__size + 1);\
|
||||
snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__);\
|
||||
throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__);\
|
||||
} while (false)
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,11 +15,10 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
//#include "zcommon/id/id.h"
|
||||
//using ID = zilliz::common::ID;
|
||||
// using ID = zilliz::common::ID;
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
|
@ -27,18 +26,20 @@
|
|||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
|
||||
class ID {
|
||||
public:
|
||||
constexpr static int64_t kIDSize = 20;
|
||||
|
||||
public:
|
||||
const int32_t *
|
||||
data() const { return content_; }
|
||||
const int32_t*
|
||||
data() const {
|
||||
return content_;
|
||||
}
|
||||
|
||||
int32_t *
|
||||
mutable_data() { return content_; }
|
||||
int32_t*
|
||||
mutable_data() {
|
||||
return content_;
|
||||
}
|
||||
|
||||
bool
|
||||
IsValid() const;
|
||||
|
@ -47,14 +48,14 @@ class ID {
|
|||
ToString() const;
|
||||
|
||||
bool
|
||||
operator==(const ID &that) const;
|
||||
operator==(const ID& that) const;
|
||||
|
||||
bool
|
||||
operator<(const ID &that) const;
|
||||
operator<(const ID& that) const;
|
||||
|
||||
protected:
|
||||
int32_t content_[5] = {};
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils/easylogging++.h"
|
||||
|
@ -33,5 +32,5 @@ namespace knowhere {
|
|||
#define KNOWHERE_LOG_ERROR LOG(ERROR) << KNOWHERE_DOMAIN_NAME
|
||||
#define KNOWHERE_LOG_FATAL LOG(FATAL) << KNOWHERE_DOMAIN_NAME
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,18 +15,15 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <arrow/type.h>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
using DataType = arrow::DataType;
|
||||
using Field = arrow::Field;
|
||||
using FieldPtr = std::shared_ptr<arrow::Field>;
|
||||
|
@ -34,7 +31,5 @@ using Schema = arrow::Schema;
|
|||
using SchemaPtr = std::shared_ptr<Schema>;
|
||||
using SchemaConstPtr = std::shared_ptr<const Schema>;
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,21 +15,17 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <arrow/tensor.h>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
using Tensor = arrow::Tensor;
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,18 +15,14 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <iostream> // TODO(linxj): using Log instead
|
||||
|
||||
#include <iostream> // TODO(linxj): using Log instead
|
||||
|
||||
#include "Timer.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
// Creates a recorder labelled with `header`. Both the overall start mark and
// the "last section" mark are set to the same clock reading taken here.
TimeRecorder::TimeRecorder(const std::string& header, int64_t log_level) : header_(header), log_level_(log_level) {
    const stdclock::time_point now = stdclock::now();
    start_ = now;
    last_ = now;
}
|
||||
|
||||
|
@ -42,9 +38,10 @@ TimeRecorder::GetTimeSpanStr(double span) {
|
|||
}
|
||||
|
||||
void
|
||||
TimeRecorder::PrintTimeRecord(const std::string &msg, double span) {
|
||||
TimeRecorder::PrintTimeRecord(const std::string& msg, double span) {
|
||||
std::string str_log;
|
||||
if (!header_.empty()) str_log += header_ + ": ";
|
||||
if (!header_.empty())
|
||||
str_log += header_ + ": ";
|
||||
str_log += msg;
|
||||
str_log += " (";
|
||||
str_log += TimeRecorder::GetTimeSpanStr(span);
|
||||
|
@ -55,35 +52,35 @@ TimeRecorder::PrintTimeRecord(const std::string &msg, double span) {
|
|||
std::cout << str_log << std::endl;
|
||||
break;
|
||||
}
|
||||
//case 1: {
|
||||
// SERVER_LOG_DEBUG << str_log;
|
||||
// break;
|
||||
//}
|
||||
//case 2: {
|
||||
// SERVER_LOG_INFO << str_log;
|
||||
// break;
|
||||
//}
|
||||
//case 3: {
|
||||
// SERVER_LOG_WARNING << str_log;
|
||||
// break;
|
||||
//}
|
||||
//case 4: {
|
||||
// SERVER_LOG_ERROR << str_log;
|
||||
// break;
|
||||
//}
|
||||
//case 5: {
|
||||
// SERVER_LOG_FATAL << str_log;
|
||||
// break;
|
||||
//}
|
||||
//default: {
|
||||
// SERVER_LOG_INFO << str_log;
|
||||
// break;
|
||||
//}
|
||||
// case 1: {
|
||||
// SERVER_LOG_DEBUG << str_log;
|
||||
// break;
|
||||
//}
|
||||
// case 2: {
|
||||
// SERVER_LOG_INFO << str_log;
|
||||
// break;
|
||||
//}
|
||||
// case 3: {
|
||||
// SERVER_LOG_WARNING << str_log;
|
||||
// break;
|
||||
//}
|
||||
// case 4: {
|
||||
// SERVER_LOG_ERROR << str_log;
|
||||
// break;
|
||||
//}
|
||||
// case 5: {
|
||||
// SERVER_LOG_FATAL << str_log;
|
||||
// break;
|
||||
//}
|
||||
// default: {
|
||||
// SERVER_LOG_INFO << str_log;
|
||||
// break;
|
||||
//}
|
||||
}
|
||||
}
|
||||
|
||||
double
|
||||
TimeRecorder::RecordSection(const std::string &msg) {
|
||||
TimeRecorder::RecordSection(const std::string& msg) {
|
||||
stdclock::time_point curr = stdclock::now();
|
||||
double span = (std::chrono::duration<double, std::micro>(curr - last_)).count();
|
||||
last_ = curr;
|
||||
|
@ -93,7 +90,7 @@ TimeRecorder::RecordSection(const std::string &msg) {
|
|||
}
|
||||
|
||||
double
|
||||
TimeRecorder::ElapseFromBegin(const std::string &msg) {
|
||||
TimeRecorder::ElapseFromBegin(const std::string& msg) {
|
||||
stdclock::time_point curr = stdclock::now();
|
||||
double span = (std::chrono::duration<double, std::micro>(curr - start_)).count();
|
||||
|
||||
|
@ -101,5 +98,5 @@ TimeRecorder::ElapseFromBegin(const std::string &msg) {
|
|||
return span;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,11 +15,10 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
@ -28,19 +27,22 @@ class TimeRecorder {
|
|||
using stdclock = std::chrono::high_resolution_clock;
|
||||
|
||||
public:
|
||||
TimeRecorder(const std::string &header,
|
||||
int64_t log_level = 0);
|
||||
explicit TimeRecorder(const std::string& header, int64_t log_level = 0);
|
||||
|
||||
~TimeRecorder();//trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
|
||||
~TimeRecorder(); // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
|
||||
|
||||
double RecordSection(const std::string &msg);
|
||||
double
|
||||
RecordSection(const std::string& msg);
|
||||
|
||||
double ElapseFromBegin(const std::string &msg);
|
||||
double
|
||||
ElapseFromBegin(const std::string& msg);
|
||||
|
||||
static std::string GetTimeSpanStr(double span);
|
||||
static std::string
|
||||
GetTimeSpanStr(double span);
|
||||
|
||||
private:
|
||||
void PrintTimeRecord(const std::string &msg, double span);
|
||||
void
|
||||
PrintTimeRecord(const std::string& msg, double span);
|
||||
|
||||
private:
|
||||
std::string header_;
|
||||
|
@ -49,5 +51,5 @@ class TimeRecorder {
|
|||
int64_t log_level_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,54 +15,55 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "IndexModel.h"
|
||||
#include "IndexType.h"
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "IndexType.h"
|
||||
#include "IndexModel.h"
|
||||
#include "knowhere/index/preprocessor/Preprocessor.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
class Index {
|
||||
public:
|
||||
virtual BinarySet
|
||||
Serialize() = 0;
|
||||
|
||||
virtual void
|
||||
Load(const BinarySet &index_binary) = 0;
|
||||
Load(const BinarySet& index_binary) = 0;
|
||||
|
||||
// @throw
|
||||
virtual DatasetPtr
|
||||
Search(const DatasetPtr &dataset, const Config &config) = 0;
|
||||
Search(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
public:
|
||||
IndexType
|
||||
idx_type() const { return idx_type_; }
|
||||
idx_type() const {
|
||||
return idx_type_;
|
||||
}
|
||||
|
||||
void
|
||||
set_idx_type(IndexType idx_type) { idx_type_ = idx_type; }
|
||||
set_idx_type(IndexType idx_type) {
|
||||
idx_type_ = idx_type;
|
||||
}
|
||||
|
||||
virtual void
|
||||
set_preprocessor(PreprocessorPtr preprocessor) {}
|
||||
set_preprocessor(PreprocessorPtr preprocessor) {
|
||||
}
|
||||
|
||||
virtual void
|
||||
set_index_model(IndexModelPtr model) {}
|
||||
set_index_model(IndexModelPtr model) {
|
||||
}
|
||||
|
||||
private:
|
||||
IndexType idx_type_;
|
||||
};
|
||||
|
||||
|
||||
using IndexPtr = std::shared_ptr<Index>;
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
@ -24,19 +23,16 @@
|
|||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
class IndexModel {
|
||||
public:
|
||||
virtual BinarySet
|
||||
Serialize() = 0;
|
||||
|
||||
virtual void
|
||||
Load(const BinarySet &binary) = 0;
|
||||
Load(const BinarySet& binary) = 0;
|
||||
};
|
||||
|
||||
using IndexModelPtr = std::shared_ptr<IndexModel>;
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,14 +15,11 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
enum class IndexType {
|
||||
kUnknown = 0,
|
||||
kVecIdxBegin = 100,
|
||||
|
@ -30,6 +27,5 @@ enum class IndexType {
|
|||
kVecIdxEnd,
|
||||
};
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -1,14 +1,30 @@
|
|||
//// Licensed to the Apache Software Foundation (ASF) under one
|
||||
//// or more contributor license agreements. See the NOTICE file
|
||||
//// distributed with this work for additional information
|
||||
//// regarding copyright ownership. The ASF licenses this file
|
||||
//// to you under the Apache License, Version 2.0 (the
|
||||
//// "License"); you may not use this file except in compliance
|
||||
//// with the License. You may obtain a copy of the License at
|
||||
////
|
||||
//// http://www.apache.org/licenses/LICENSE-2.0
|
||||
////
|
||||
//// Unless required by applicable law or agreed to in writing,
|
||||
//// software distributed under the License is distributed on an
|
||||
//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
//// KIND, either express or implied. See the License for the
|
||||
//// specific language governing permissions and limitations
|
||||
//// under the License.
|
||||
//
|
||||
//#include "knowhere/index/vector_index/definitions.h"
|
||||
//#include "knowhere/common/config.h"
|
||||
//#include "knowhere/index/preprocessor/normalize.h"
|
||||
#include "knowhere/index/preprocessor/Normalize.h"
|
||||
//
|
||||
//
|
||||
//namespace zilliz {
|
||||
//namespace knowhere {
|
||||
// namespace zilliz {
|
||||
// namespace knowhere {
|
||||
//
|
||||
//DatasetPtr
|
||||
//NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
|
||||
// DatasetPtr
|
||||
// NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
|
||||
// // TODO: wrap dataset->tensor
|
||||
// auto tensor = dataset->tensor()[0];
|
||||
// auto p_data = (float *)tensor->raw_mutable_data();
|
||||
|
@ -21,8 +37,8 @@
|
|||
// }
|
||||
//}
|
||||
//
|
||||
//void
|
||||
//NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
|
||||
// void
|
||||
// NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
|
||||
// double vector_length = 0;
|
||||
// for (auto j = 0; j < dimension; j++) {
|
||||
// double val = arr[j];
|
||||
|
@ -39,4 +55,3 @@
|
|||
//
|
||||
//} // namespace knowhere
|
||||
//} // namespace zilliz
|
||||
|
||||
|
|
|
@ -1,13 +1,30 @@
|
|||
//// Licensed to the Apache Software Foundation (ASF) under one
|
||||
//// or more contributor license agreements. See the NOTICE file
|
||||
//// distributed with this work for additional information
|
||||
//// regarding copyright ownership. The ASF licenses this file
|
||||
//// to you under the Apache License, Version 2.0 (the
|
||||
//// "License"); you may not use this file except in compliance
|
||||
//// with the License. You may obtain a copy of the License at
|
||||
////
|
||||
//// http://www.apache.org/licenses/LICENSE-2.0
|
||||
////
|
||||
//// Unless required by applicable law or agreed to in writing,
|
||||
//// software distributed under the License is distributed on an
|
||||
//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
//// KIND, either express or implied. See the License for the
|
||||
//// specific language governing permissions and limitations
|
||||
//// under the License.
|
||||
//
|
||||
//#pragma once
|
||||
//
|
||||
//#include <memory>
|
||||
//#include "preprocessor.h"
|
||||
//
|
||||
//
|
||||
//namespace zilliz {
|
||||
//namespace knowhere {
|
||||
// namespace zilliz {
|
||||
// namespace knowhere {
|
||||
//
|
||||
//class NormalizePreprocessor : public Preprocessor {
|
||||
// class NormalizePreprocessor : public Preprocessor {
|
||||
// public:
|
||||
// DatasetPtr
|
||||
// Preprocess(const DatasetPtr &input) override;
|
||||
|
@ -19,7 +36,7 @@
|
|||
//};
|
||||
//
|
||||
//
|
||||
//using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
|
||||
// using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
|
||||
//
|
||||
//
|
||||
//} // namespace knowhere
|
||||
|
|
|
@ -15,27 +15,22 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "knowhere/common/Dataset.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
class Preprocessor {
|
||||
public:
|
||||
virtual DatasetPtr
|
||||
Preprocess(const DatasetPtr &input) = 0;
|
||||
Preprocess(const DatasetPtr& input) = 0;
|
||||
};
|
||||
|
||||
|
||||
using PreprocessorPtr = std::shared_ptr<Preprocessor>;
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,23 +15,24 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <faiss/index_io.h>
|
||||
#include <faiss/IndexIVF.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "FaissBaseIndex.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
// Takes shared ownership of an existing faiss index; the pointer is moved
// into index_ to avoid an extra reference-count round trip.
FaissBaseIndex::FaissBaseIndex(std::shared_ptr<faiss::Index> index) : index_(std::move(index)) {
}
|
||||
|
||||
BinarySet FaissBaseIndex::SerializeImpl() {
|
||||
BinarySet
|
||||
FaissBaseIndex::SerializeImpl() {
|
||||
try {
|
||||
faiss::Index *index = index_.get();
|
||||
faiss::Index* index = index_.get();
|
||||
|
||||
SealImpl();
|
||||
|
||||
|
@ -44,37 +45,38 @@ BinarySet FaissBaseIndex::SerializeImpl() {
|
|||
// TODO(linxj): use virtual func Name() instead of raw string.
|
||||
res_set.Append("IVF", data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void FaissBaseIndex::LoadImpl(const BinarySet &index_binary) {
|
||||
void
|
||||
FaissBaseIndex::LoadImpl(const BinarySet& index_binary) {
|
||||
auto binary = index_binary.GetByName("IVF");
|
||||
|
||||
MemoryIOReader reader;
|
||||
reader.total = binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
faiss::Index *index = faiss::read_index(&reader);
|
||||
faiss::Index* index = faiss::read_index(&reader);
|
||||
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
void FaissBaseIndex::SealImpl() {
|
||||
// TODO(linxj): enable
|
||||
//#ifdef ZILLIZ_FAISS
|
||||
faiss::Index *index = index_.get();
|
||||
auto idx = dynamic_cast<faiss::IndexIVF *>(index);
|
||||
void
|
||||
FaissBaseIndex::SealImpl() {
|
||||
// TODO(linxj): enable
|
||||
//#ifdef ZILLIZ_FAISS
|
||||
faiss::Index* index = index_.get();
|
||||
auto idx = dynamic_cast<faiss::IndexIVF*>(index);
|
||||
if (idx != nullptr) {
|
||||
idx->to_readonly();
|
||||
}
|
||||
//else {
|
||||
// else {
|
||||
// KNOHWERE_ERROR_MSG("Seal failed");
|
||||
//}
|
||||
//#endif
|
||||
//#endif
|
||||
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
@ -24,7 +23,6 @@
|
|||
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -36,7 +34,7 @@ class FaissBaseIndex {
|
|||
SerializeImpl();
|
||||
|
||||
virtual void
|
||||
LoadImpl(const BinarySet &index_binary);
|
||||
LoadImpl(const BinarySet& index_binary);
|
||||
|
||||
virtual void
|
||||
SealImpl();
|
||||
|
@ -45,8 +43,5 @@ class FaissBaseIndex {
|
|||
std::shared_ptr<faiss::Index> index_ = nullptr;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,31 +15,30 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/gpu/GpuIndexFlat.h>
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
#include <faiss/gpu/GpuIndexIVFFlat.h>
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <memory>
|
||||
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "IndexGPUIVF.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
auto nlist = config["nlist"].as<size_t>();
|
||||
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
IndexModelPtr
|
||||
GPUIVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
|
@ -48,8 +47,9 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
|
|||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
|
||||
idx_config.device = gpu_id_;
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
||||
device_index.train(rows, (float *) p_data);
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
|
||||
GetMetricType(build_cfg->metric_type), idx_config);
|
||||
device_index.train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
|
||||
|
@ -60,7 +60,8 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
|
|||
}
|
||||
}
|
||||
|
||||
void GPUIVF::set_index_model(IndexModelPtr model) {
|
||||
void
|
||||
GPUIVF::set_index_model(IndexModelPtr model) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto host_index = std::static_pointer_cast<IVFIndexModel>(model);
|
||||
|
@ -74,7 +75,8 @@ void GPUIVF::set_index_model(IndexModelPtr model) {
|
|||
}
|
||||
}
|
||||
|
||||
BinarySet GPUIVF::SerializeImpl() {
|
||||
BinarySet
|
||||
GPUIVF::SerializeImpl() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -82,8 +84,8 @@ BinarySet GPUIVF::SerializeImpl() {
|
|||
try {
|
||||
MemoryIOWriter writer;
|
||||
{
|
||||
faiss::Index *index = index_.get();
|
||||
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(index);
|
||||
faiss::Index* index = index_.get();
|
||||
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index);
|
||||
|
||||
SealImpl();
|
||||
|
||||
|
@ -97,19 +99,20 @@ BinarySet GPUIVF::SerializeImpl() {
|
|||
res_set.Append("IVF", data, writer.rp);
|
||||
|
||||
return res_set;
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void GPUIVF::LoadImpl(const BinarySet &index_binary) {
|
||||
void
|
||||
GPUIVF::LoadImpl(const BinarySet& index_binary) {
|
||||
auto binary = index_binary.GetByName("IVF");
|
||||
MemoryIOReader reader;
|
||||
{
|
||||
reader.total = binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
faiss::Index *index = faiss::read_index(&reader);
|
||||
faiss::Index* index = faiss::read_index(&reader);
|
||||
|
||||
if (auto temp_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
|
||||
ResScope rs(temp_res, gpu_id_, false);
|
||||
|
@ -124,75 +127,78 @@ void GPUIVF::LoadImpl(const BinarySet &index_binary) {
|
|||
}
|
||||
}
|
||||
|
||||
// Deprecated variant of CopyGpuToCpu: copies the device index to host memory
// under the index mutex and wraps the copy in a CPU IVF index.
IVFIndexPtr
GPUIVF::Copy_index_gpu_to_cpu() {
    std::lock_guard<std::mutex> lk(mutex_);

    // The shared_ptr takes ownership of the raw host index returned by faiss.
    std::shared_ptr<faiss::Index> host_copy(faiss::gpu::index_gpu_to_cpu(index_.get()));
    return std::make_shared<IVF>(host_copy);
}
|
||||
|
||||
void GPUIVF::search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg) {
|
||||
void
|
||||
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
// TODO(linxj): gpu index support GenParams
|
||||
if (auto device_index = std::static_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
|
||||
auto nprobe = cfg.get_with_default("nprobe", size_t(1));
|
||||
device_index->setNumProbes(nprobe);
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(cfg);
|
||||
device_index->setNumProbes(search_cfg->nprobe);
|
||||
|
||||
{
|
||||
// TODO(linxj): allocate mem
|
||||
ResScope rs(res_, gpu_id_);
|
||||
device_index->search(n, (float *) data, k, distances, labels);
|
||||
device_index->search(n, (float*)data, k, distances, labels);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copies the device index to host memory under the index mutex and returns it
// wrapped in a CPU IVF index. `config` is not consulted.
VectorIndexPtr
GPUIVF::CopyGpuToCpu(const Config& config) {
    std::lock_guard<std::mutex> lk(mutex_);

    // The shared_ptr takes ownership of the raw host index returned by faiss.
    std::shared_ptr<faiss::Index> host_copy(faiss::gpu::index_gpu_to_cpu(index_.get()));
    return std::make_shared<IVF>(host_copy);
}
|
||||
|
||||
// Duplicates this index by copying it to the host and then back onto the same
// GPU (gpu_id_), using default-constructed configs for both steps.
VectorIndexPtr
GPUIVF::Clone() {
    VectorIndexPtr host_side = CopyGpuToCpu(Config());
    return ::zilliz::knowhere::cloner::CopyCpuToGpu(host_side, gpu_id_, Config());
}
|
||||
|
||||
// Moves the index to another device by first copying it to the host and then
// uploading that host copy to `device_id`.
VectorIndexPtr
GPUIVF::CopyGpuToGpu(const int64_t& device_id, const Config& config) {
    const auto cpu_ivf = std::static_pointer_cast<IVF>(CopyGpuToCpu(config));
    return cpu_ivf->CopyCpuToGpu(device_id, config);
}
|
||||
|
||||
void GPUIVF::Add(const DatasetPtr &dataset, const Config &config) {
|
||||
void
|
||||
GPUIVF::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
if (auto spt = res_.lock()) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
IVF::Add(dataset, config);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Add IVF can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
void GPUIndex::SetGpuDevice(const int &gpu_id) {
|
||||
void
|
||||
GPUIndex::SetGpuDevice(const int& gpu_id) {
|
||||
gpu_id_ = gpu_id;
|
||||
}
|
||||
|
||||
const int64_t &GPUIndex::GetGpuDevice() {
|
||||
const int64_t&
|
||||
GPUIndex::GetGpuDevice() {
|
||||
return gpu_id_;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,84 +15,84 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIndex {
|
||||
public:
|
||||
explicit GPUIndex(const int &device_id) : gpu_id_(device_id) {}
|
||||
public:
|
||||
explicit GPUIndex(const int& device_id) : gpu_id_(device_id) {
|
||||
}
|
||||
|
||||
GPUIndex(const int& device_id, const ResPtr& resource): gpu_id_(device_id), res_(resource) {}
|
||||
GPUIndex(const int& device_id, const ResPtr& resource) : gpu_id_(device_id), res_(resource) {
|
||||
}
|
||||
|
||||
virtual VectorIndexPtr
|
||||
CopyGpuToCpu(const Config &config) = 0;
|
||||
virtual VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) = 0;
|
||||
|
||||
virtual VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t &device_id, const Config &config) = 0;
|
||||
virtual VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t& device_id, const Config& config) = 0;
|
||||
|
||||
void
|
||||
SetGpuDevice(const int &gpu_id);
|
||||
void
|
||||
SetGpuDevice(const int& gpu_id);
|
||||
|
||||
const int64_t &
|
||||
GetGpuDevice();
|
||||
const int64_t&
|
||||
GetGpuDevice();
|
||||
|
||||
protected:
|
||||
int64_t gpu_id_;
|
||||
ResWPtr res_;
|
||||
protected:
|
||||
int64_t gpu_id_;
|
||||
ResWPtr res_;
|
||||
};
|
||||
|
||||
class GPUIVF : public IVF, public GPUIndex {
|
||||
public:
|
||||
explicit GPUIVF(const int &device_id) : IVF(), GPUIndex(device_id) {}
|
||||
public:
|
||||
explicit GPUIVF(const int& device_id) : IVF(), GPUIndex(device_id) {
|
||||
}
|
||||
|
||||
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
|
||||
: IVF(std::move(index)), GPUIndex(device_id, resource) {};
|
||||
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
|
||||
: IVF(std::move(index)), GPUIndex(device_id, resource) {
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr &dataset, const Config &config) override;
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
set_index_model(IndexModelPtr model) override;
|
||||
void
|
||||
set_index_model(IndexModelPtr model) override;
|
||||
|
||||
//DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config &config) override;
|
||||
// DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t &device_id, const Config &config) override;
|
||||
VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
Clone() final;
|
||||
VectorIndexPtr
|
||||
Clone() final;
|
||||
|
||||
// TODO(linxj): Deprecated
|
||||
virtual IVFIndexPtr Copy_index_gpu_to_cpu();
|
||||
// TODO(linxj): Deprecated
|
||||
virtual IVFIndexPtr
|
||||
Copy_index_gpu_to_cpu();
|
||||
|
||||
protected:
|
||||
void
|
||||
search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg) override;
|
||||
protected:
|
||||
void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
|
||||
|
||||
BinarySet
|
||||
SerializeImpl() override;
|
||||
BinarySet
|
||||
SerializeImpl() override;
|
||||
|
||||
void
|
||||
LoadImpl(const BinarySet &index_binary) override;
|
||||
void
|
||||
LoadImpl(const BinarySet& index_binary) override;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,34 +15,34 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <faiss/gpu/GpuIndexIVFPQ.h>
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/gpu/GpuIndexIVFPQ.h>
|
||||
#include <memory>
|
||||
|
||||
#include "IndexGPUIVFPQ.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
auto nlist = config["nlist"].as<size_t>();
|
||||
auto M = config["M"].as<size_t>(); // number of subquantizers(subvectors)
|
||||
auto nbits = config["nbits"].as<size_t>();// number of bit per subvector index
|
||||
auto gpu_num = config.get_with_default("gpu_id", gpu_id_);
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_L2; // IP not support.
|
||||
IndexModelPtr
|
||||
GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
// TODO(linxj): set device here.
|
||||
// TODO(linxj): set gpu resource here.
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, nlist, M, nbits, metric_type);
|
||||
device_index.train(rows, (float *) p_data);
|
||||
faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits,
|
||||
GetMetricType(build_cfg->metric_type)); // IP not support
|
||||
device_index.train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
|
||||
|
@ -50,19 +50,22 @@ IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
|
|||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters> GPUIVFPQ::GenParams(const Config &config) {
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GPUIVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
params->nprobe = config.get_with_default("nprobe", size_t(1));
|
||||
//params->scan_table_threshold = 0;
|
||||
//params->polysemous_ht = 0;
|
||||
//params->max_codes = 0;
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->scan_table_threshold = conf->scan_table_threhold;
|
||||
// params->polysemous_ht = conf->polysemous_ht;
|
||||
// params->max_codes = conf->max_codes;
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
// Copying a GPU IVFPQ index to the host is not implemented; every call
// reports that through KNOWHERE_THROW_MSG.
VectorIndexPtr
GPUIVFPQ::CopyGpuToCpu(const Config& config) {
    KNOWHERE_THROW_MSG("not support yet");
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,33 +15,32 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "IndexGPUIVF.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIVFPQ : public GPUIVF {
|
||||
public:
|
||||
explicit GPUIVFPQ(const int &device_id) : GPUIVF(device_id) {}
|
||||
public:
|
||||
explicit GPUIVFPQ(const int& device_id) : GPUIVF(device_id) {
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
public:
|
||||
public:
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config &config) override;
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
|
||||
protected:
|
||||
protected:
|
||||
// TODO(linxj): remove GenParams.
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config &config) override;
|
||||
GenParams(const Config& config) override;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,60 +15,62 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "IndexGPUIVFSQ.h"
|
||||
#include "IndexIVFSQ.h"
|
||||
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
auto nlist = config["nlist"].as<size_t>();
|
||||
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16
|
||||
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << nlist << "," << "SQ" << nbits;
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
|
||||
device_index->train(rows, (float *) p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
|
||||
delete device_index;
|
||||
delete build_index;
|
||||
|
||||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");
|
||||
}
|
||||
IndexModelPtr
|
||||
GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
VectorIndexPtr GPUIVFSQ::CopyGpuToCpu(const Config &config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSOR(dataset)
|
||||
|
||||
faiss::Index *device_index = index_.get();
|
||||
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << build_cfg->nlist << ","
|
||||
<< "SQ" << build_cfg->nbits;
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
|
||||
|
||||
std::shared_ptr<faiss::Index> new_index;
|
||||
new_index.reset(host_index);
|
||||
return std::make_shared<IVFSQ>(new_index);
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
|
||||
device_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
|
||||
delete device_index;
|
||||
delete build_index;
|
||||
|
||||
return std::make_shared<IVFIndexModel>(host_index);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");
|
||||
}
|
||||
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
VectorIndexPtr
|
||||
GPUIVFSQ::CopyGpuToCpu(const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
faiss::Index* device_index = index_.get();
|
||||
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
|
||||
|
||||
std::shared_ptr<faiss::Index> new_index;
|
||||
new_index.reset(host_index);
|
||||
return std::make_shared<IVFSQ>(new_index);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,29 +15,31 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "IndexGPUIVF.h"
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexGPUIVF.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class GPUIVFSQ : public GPUIVF {
|
||||
public:
|
||||
explicit GPUIVFSQ(const int &device_id) : GPUIVF(device_id) {}
|
||||
public:
|
||||
explicit GPUIVFSQ(const int& device_id) : GPUIVF(device_id) {
|
||||
}
|
||||
|
||||
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
|
||||
: GPUIVF(std::move(index), device_id, resource) {};
|
||||
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
|
||||
: GPUIVF(std::move(index), device_id, resource) {
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config &config) override;
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,24 +15,23 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/MetaIndexes.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "IndexIDMAP.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet IDMAP::Serialize() {
|
||||
BinarySet
|
||||
IDMAP::Serialize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -41,34 +40,34 @@ BinarySet IDMAP::Serialize() {
|
|||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void IDMAP::Load(const BinarySet &index_binary) {
|
||||
void
|
||||
IDMAP::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
}
|
||||
|
||||
DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) {
|
||||
DatasetPtr
|
||||
IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
auto k = config["k"].as<size_t>();
|
||||
//auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
config->CheckValid();
|
||||
// auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
// faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
//index_->metric_type = metric_type;
|
||||
// index_->metric_type = metric_type;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
// TODO(linxj): handle malloc exception
|
||||
auto elems = rows * k;
|
||||
auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems);
|
||||
auto res_dis = (float *) malloc(sizeof(float) * elems);
|
||||
auto elems = rows * config->k;
|
||||
auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems);
|
||||
auto res_dis = (float*)malloc(sizeof(float) * elems);
|
||||
|
||||
search_impl(rows, (float *) p_data, k, res_dis, res_ids, Config());
|
||||
search_impl(rows, (float*)p_data, config->k, res_dis, res_ids, Config());
|
||||
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems);
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems);
|
||||
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
|
||||
std::vector<BufferPtr> dist_bufs{nullptr, dist_buf};
|
||||
|
||||
|
@ -85,12 +84,13 @@ DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) {
|
|||
return std::make_shared<Dataset>(array, nullptr);
|
||||
}
|
||||
|
||||
void IDMAP::search_impl(int64_t n, const float *data, int64_t k, float *distances, int64_t *labels, const Config &cfg) {
|
||||
index_->search(n, (float *) data, k, distances, labels);
|
||||
|
||||
void
|
||||
IDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
index_->search(n, (float*)data, k, distances, labels);
|
||||
}
|
||||
|
||||
void IDMAP::Add(const DatasetPtr &dataset, const Config &config) {
|
||||
void
|
||||
IDMAP::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
@ -100,51 +100,56 @@ void IDMAP::Add(const DatasetPtr &dataset, const Config &config) {
|
|||
|
||||
// TODO: magic here.
|
||||
auto array = dataset->array()[0];
|
||||
auto p_ids = array->data()->GetValues<long>(1, 0);
|
||||
auto p_ids = array->data()->GetValues<int64_t>(1, 0);
|
||||
|
||||
index_->add_with_ids(rows, (float *) p_data, p_ids);
|
||||
index_->add_with_ids(rows, (float*)p_data, p_ids);
|
||||
}
|
||||
|
||||
int64_t IDMAP::Count() {
|
||||
int64_t
|
||||
IDMAP::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t IDMAP::Dimension() {
|
||||
int64_t
|
||||
IDMAP::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
// TODO(linxj): return const pointer
|
||||
float *IDMAP::GetRawVectors() {
|
||||
float*
|
||||
IDMAP::GetRawVectors() {
|
||||
try {
|
||||
auto file_index = dynamic_cast<faiss::IndexIDMap *>(index_.get());
|
||||
auto file_index = dynamic_cast<faiss::IndexIDMap*>(index_.get());
|
||||
auto flat_index = dynamic_cast<faiss::IndexFlat*>(file_index->index);
|
||||
return flat_index->xb.data();
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(linxj): return const pointer
|
||||
int64_t *IDMAP::GetRawIds() {
|
||||
int64_t*
|
||||
IDMAP::GetRawIds() {
|
||||
try {
|
||||
auto file_index = dynamic_cast<faiss::IndexIDMap *>(index_.get());
|
||||
auto file_index = dynamic_cast<faiss::IndexIDMap*>(index_.get());
|
||||
return file_index->id_map.data();
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
const char* type = "IDMap,Flat";
|
||||
void IDMAP::Train(const Config &config) {
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
auto dim = config["dim"].as<size_t>();
|
||||
|
||||
auto index = faiss::index_factory(dim, type, metric_type);
|
||||
void
|
||||
IDMAP::Train(const Config& config) {
|
||||
config->CheckValid();
|
||||
|
||||
auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
VectorIndexPtr IDMAP::Clone() {
|
||||
VectorIndexPtr
|
||||
IDMAP::Clone() {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto clone_index = faiss::clone_index(index_.get());
|
||||
|
@ -153,8 +158,9 @@ VectorIndexPtr IDMAP::Clone() {
|
|||
return std::make_shared<IDMAP>(new_index);
|
||||
}
|
||||
|
||||
VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
|
||||
VectorIndexPtr
|
||||
IDMAP::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
|
||||
|
||||
|
@ -166,38 +172,41 @@ VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &confi
|
|||
}
|
||||
}
|
||||
|
||||
void IDMAP::Seal() {
|
||||
void
|
||||
IDMAP::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
VectorIndexPtr GPUIDMAP::CopyGpuToCpu(const Config &config) {
|
||||
VectorIndexPtr
|
||||
GPUIDMAP::CopyGpuToCpu(const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
faiss::Index *device_index = index_.get();
|
||||
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
|
||||
faiss::Index* device_index = index_.get();
|
||||
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
|
||||
|
||||
std::shared_ptr<faiss::Index> new_index;
|
||||
new_index.reset(host_index);
|
||||
return std::make_shared<IDMAP>(new_index);
|
||||
}
|
||||
|
||||
VectorIndexPtr GPUIDMAP::Clone() {
|
||||
VectorIndexPtr
|
||||
GPUIDMAP::Clone() {
|
||||
auto cpu_idx = CopyGpuToCpu(Config());
|
||||
|
||||
if (auto idmap = std::dynamic_pointer_cast<IDMAP>(cpu_idx)){
|
||||
if (auto idmap = std::dynamic_pointer_cast<IDMAP>(cpu_idx)) {
|
||||
return idmap->CopyCpuToGpu(gpu_id_, Config());
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("IndexType not Support GpuClone");
|
||||
}
|
||||
}
|
||||
|
||||
BinarySet GPUIDMAP::SerializeImpl() {
|
||||
BinarySet
|
||||
GPUIDMAP::SerializeImpl() {
|
||||
try {
|
||||
MemoryIOWriter writer;
|
||||
{
|
||||
faiss::Index *index = index_.get();
|
||||
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(index);
|
||||
faiss::Index* index = index_.get();
|
||||
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index);
|
||||
|
||||
faiss::write_index(host_index, &writer);
|
||||
delete host_index;
|
||||
|
@ -209,21 +218,22 @@ BinarySet GPUIDMAP::SerializeImpl() {
|
|||
res_set.Append("IVF", data, writer.rp);
|
||||
|
||||
return res_set;
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void GPUIDMAP::LoadImpl(const BinarySet &index_binary) {
|
||||
void
|
||||
GPUIDMAP::LoadImpl(const BinarySet& index_binary) {
|
||||
auto binary = index_binary.GetByName("IVF");
|
||||
MemoryIOReader reader;
|
||||
{
|
||||
reader.total = binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
faiss::Index *index = faiss::read_index(&reader);
|
||||
faiss::Index* index = faiss::read_index(&reader);
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_) ){
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
|
||||
ResScope rs(res, gpu_id_, false);
|
||||
auto device_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index);
|
||||
index_.reset(device_index);
|
||||
|
@ -236,28 +246,27 @@ void GPUIDMAP::LoadImpl(const BinarySet &index_binary) {
|
|||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr GPUIDMAP::CopyGpuToGpu(const int64_t &device_id, const Config &config) {
|
||||
VectorIndexPtr
|
||||
GPUIDMAP::CopyGpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
auto cpu_index = CopyGpuToCpu(config);
|
||||
return std::static_pointer_cast<IDMAP>(cpu_index)->CopyCpuToGpu(device_id, config);
|
||||
}
|
||||
|
||||
float *GPUIDMAP::GetRawVectors() {
|
||||
float*
|
||||
GPUIDMAP::GetRawVectors() {
|
||||
KNOWHERE_THROW_MSG("Not support");
|
||||
}
|
||||
|
||||
int64_t *GPUIDMAP::GetRawIds() {
|
||||
int64_t*
|
||||
GPUIDMAP::GetRawIds() {
|
||||
KNOWHERE_THROW_MSG("Not support");
|
||||
}
|
||||
|
||||
void GPUIDMAP::search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg) {
|
||||
void
|
||||
GPUIDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
index_->search(n, (float *) data, k, distances, labels);
|
||||
index_->search(n, (float*)data, k, distances, labels);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,41 +15,54 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "IndexIVF.h"
|
||||
#include "IndexGPUIVF.h"
|
||||
#include "IndexIVF.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class IDMAP : public VectorIndex, public FaissBaseIndex {
|
||||
public:
|
||||
IDMAP() : FaissBaseIndex(nullptr) {};
|
||||
explicit IDMAP(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {};
|
||||
BinarySet Serialize() override;
|
||||
void Load(const BinarySet &index_binary) override;
|
||||
void Train(const Config &config);
|
||||
DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
int64_t Count() override;
|
||||
VectorIndexPtr Clone() override;
|
||||
int64_t Dimension() override;
|
||||
void Add(const DatasetPtr &dataset, const Config &config) override;
|
||||
VectorIndexPtr CopyCpuToGpu(const int64_t &device_id, const Config &config);
|
||||
void Seal() override;
|
||||
IDMAP() : FaissBaseIndex(nullptr) {
|
||||
}
|
||||
|
||||
virtual float *GetRawVectors();
|
||||
virtual int64_t *GetRawIds();
|
||||
explicit IDMAP(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
void
|
||||
Train(const Config& config);
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
int64_t
|
||||
Count() override;
|
||||
VectorIndexPtr
|
||||
Clone() override;
|
||||
int64_t
|
||||
Dimension() override;
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config);
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
virtual float*
|
||||
GetRawVectors();
|
||||
virtual int64_t*
|
||||
GetRawIds();
|
||||
|
||||
protected:
|
||||
virtual void search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg);
|
||||
virtual void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
|
@ -57,27 +70,31 @@ using IDMAPPtr = std::shared_ptr<IDMAP>;
|
|||
|
||||
class GPUIDMAP : public IDMAP, public GPUIndex {
|
||||
public:
|
||||
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr& res)
|
||||
: IDMAP(std::move(index)), GPUIndex(device_id, res) {}
|
||||
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& res)
|
||||
: IDMAP(std::move(index)), GPUIndex(device_id, res) {
|
||||
}
|
||||
|
||||
VectorIndexPtr CopyGpuToCpu(const Config &config) override;
|
||||
float *GetRawVectors() override;
|
||||
int64_t *GetRawIds() override;
|
||||
VectorIndexPtr Clone() override;
|
||||
VectorIndexPtr CopyGpuToGpu(const int64_t &device_id, const Config &config) override;
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const Config& config) override;
|
||||
float*
|
||||
GetRawVectors() override;
|
||||
int64_t*
|
||||
GetRawIds() override;
|
||||
VectorIndexPtr
|
||||
Clone() override;
|
||||
VectorIndexPtr
|
||||
CopyGpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
protected:
|
||||
void search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg) override;
|
||||
BinarySet SerializeImpl() override;
|
||||
void LoadImpl(const BinarySet &index_binary) override;
|
||||
void
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
|
||||
BinarySet
|
||||
SerializeImpl() override;
|
||||
void
|
||||
LoadImpl(const BinarySet& index_binary) override;
|
||||
};
|
||||
|
||||
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,44 +15,47 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/AuxIndexStructures.h>
|
||||
#include <faiss/IVFlib.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexIVF.h>
|
||||
#include <faiss/IndexIVFFlat.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/IVFlib.h>
|
||||
#include <faiss/AuxIndexStructures.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "IndexIVF.h"
|
||||
#include "IndexGPUIVF.h"
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
IndexModelPtr IVF::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
auto nlist = config["nlist"].as<size_t>();
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
IndexModelPtr
|
||||
IVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
faiss::Index *coarse_quantizer = new faiss::IndexFlatL2(dim);
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type);
|
||||
index->train(rows, (float *) p_data);
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, build_cfg->nlist,
|
||||
GetMetricType(build_cfg->metric_type));
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
// TODO: override here. train return model or not.
|
||||
// TODO(linxj): override here. train return model or not.
|
||||
return std::make_shared<IVFIndexModel>(index);
|
||||
}
|
||||
|
||||
|
||||
void IVF::Add(const DatasetPtr &dataset, const Config &config) {
|
||||
void
|
||||
IVF::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -60,13 +63,13 @@ void IVF::Add(const DatasetPtr &dataset, const Config &config) {
|
|||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSOR(dataset)
|
||||
|
||||
// TODO: magic here.
|
||||
auto array = dataset->array()[0];
|
||||
auto p_ids = array->data()->GetValues<long>(1, 0);
|
||||
index_->add_with_ids(rows, (float *) p_data, p_ids);
|
||||
auto p_ids = array->data()->GetValues<int64_t>(1, 0);
|
||||
index_->add_with_ids(rows, (float*)p_data, p_ids);
|
||||
}
|
||||
|
||||
void IVF::AddWithoutIds(const DatasetPtr &dataset, const Config &config) {
|
||||
void
|
||||
IVF::AddWithoutIds(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -74,10 +77,11 @@ void IVF::AddWithoutIds(const DatasetPtr &dataset, const Config &config) {
|
|||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETTENSOR(dataset)
|
||||
|
||||
index_->add(rows, (float *) p_data);
|
||||
index_->add(rows, (float*)p_data);
|
||||
}
|
||||
|
||||
BinarySet IVF::Serialize() {
|
||||
BinarySet
|
||||
IVF::Serialize() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
@ -87,38 +91,34 @@ BinarySet IVF::Serialize() {
|
|||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void IVF::Load(const BinarySet &index_binary) {
|
||||
void
|
||||
IVF::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
}
|
||||
|
||||
DatasetPtr IVF::Search(const DatasetPtr &dataset, const Config &config) {
|
||||
DatasetPtr
|
||||
IVF::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto k = config["k"].as<size_t>();
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (search_cfg != nullptr) {
|
||||
search_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
// TODO(linxj): handle malloc exception
|
||||
auto elems = rows * k;
|
||||
auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems);
|
||||
auto res_dis = (float *) malloc(sizeof(float) * elems);
|
||||
auto elems = rows * search_cfg->k;
|
||||
auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems);
|
||||
auto res_dis = (float*)malloc(sizeof(float) * elems);
|
||||
|
||||
search_impl(rows, (float*) p_data, k, res_dis, res_ids, config);
|
||||
//faiss::ivflib::search_with_parameters(index_.get(),
|
||||
// rows,
|
||||
// (float *) p_data,
|
||||
// k,
|
||||
// res_dis,
|
||||
// res_ids,
|
||||
// params.get());
|
||||
search_impl(rows, (float*)p_data, search_cfg->k, res_dis, res_ids, config);
|
||||
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems);
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems);
|
||||
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
|
||||
std::vector<BufferPtr> dist_bufs{nullptr, dist_buf};
|
||||
|
||||
|
@ -135,7 +135,8 @@ DatasetPtr IVF::Search(const DatasetPtr &dataset, const Config &config) {
|
|||
return std::make_shared<Dataset>(array, nullptr);
|
||||
}
|
||||
|
||||
void IVF::set_index_model(IndexModelPtr model) {
|
||||
void
|
||||
IVF::set_index_model(IndexModelPtr model) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto rel_model = std::static_pointer_cast<IVFIndexModel>(model);
|
||||
|
@ -144,23 +145,29 @@ void IVF::set_index_model(IndexModelPtr model) {
|
|||
index_.reset(faiss::clone_index(rel_model->index_.get()));
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters> IVF::GenParams(const Config &config) {
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
params->nprobe = config.get_with_default("nprobe", size_t(1));
|
||||
//params->max_codes = config.get_with_default("max_codes", size_t(0));
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->max_codes = config.get_with_default("max_codes", size_t(0));
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
int64_t IVF::Count() {
|
||||
int64_t
|
||||
IVF::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t IVF::Dimension() {
|
||||
int64_t
|
||||
IVF::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config) {
|
||||
void
|
||||
IVF::GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset)
|
||||
|
||||
auto ntotal = Count();
|
||||
|
@ -176,7 +183,7 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co
|
|||
for (int i = 0; i < total_search_count; ++i) {
|
||||
auto b_size = i == total_search_count - 1 && tail_batch_size != 0 ? tail_batch_size : batch_size;
|
||||
|
||||
auto &res = res_vec[i];
|
||||
auto& res = res_vec[i];
|
||||
res.resize(k * b_size);
|
||||
|
||||
auto xq = p_data + batch_size * dim * i;
|
||||
|
@ -184,7 +191,7 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co
|
|||
|
||||
int tmp = 0;
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto &node = graph[batch_size * i + j];
|
||||
auto& node = graph[batch_size * i + j];
|
||||
node.resize(k);
|
||||
for (int m = 0; m < k && tmp < k * b_size; ++m, ++tmp) {
|
||||
// TODO(linxj): avoid memcopy here.
|
||||
|
@ -194,18 +201,15 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co
|
|||
}
|
||||
}
|
||||
|
||||
void IVF::search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg) {
|
||||
void
|
||||
IVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
auto params = GenParams(cfg);
|
||||
faiss::ivflib::search_with_parameters(index_.get(), n, (float *) data, k, distances, labels, params.get());
|
||||
faiss::ivflib::search_with_parameters(index_.get(), n, (float*)data, k, distances, labels, params.get());
|
||||
}
|
||||
|
||||
VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
|
||||
VectorIndexPtr
|
||||
IVF::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
|
||||
|
||||
|
@ -217,7 +221,8 @@ VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config)
|
|||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr IVF::Clone() {
|
||||
VectorIndexPtr
|
||||
IVF::Clone() {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto clone_index = faiss::clone_index(index_.get());
|
||||
|
@ -226,21 +231,24 @@ VectorIndexPtr IVF::Clone() {
|
|||
return Clone_impl(new_index);
|
||||
}
|
||||
|
||||
VectorIndexPtr IVF::Clone_impl(const std::shared_ptr<faiss::Index> &index) {
|
||||
VectorIndexPtr
|
||||
IVF::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
|
||||
return std::make_shared<IVF>(index);
|
||||
}
|
||||
|
||||
void IVF::Seal() {
|
||||
void
|
||||
IVF::Seal() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
SealImpl();
|
||||
}
|
||||
|
||||
IVFIndexModel::IVFIndexModel(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
}
|
||||
|
||||
IVFIndexModel::IVFIndexModel(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {}
|
||||
|
||||
BinarySet IVFIndexModel::Serialize() {
|
||||
BinarySet
|
||||
IVFIndexModel::Serialize() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("indexmodel not initialize or trained");
|
||||
}
|
||||
|
@ -248,18 +256,16 @@ BinarySet IVFIndexModel::Serialize() {
|
|||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void IVFIndexModel::Load(const BinarySet &binary_set) {
|
||||
void
|
||||
IVFIndexModel::Load(const BinarySet& binary_set) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(binary_set);
|
||||
}
|
||||
|
||||
void IVFIndexModel::SealImpl() {
|
||||
void
|
||||
IVFIndexModel::SealImpl() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,17 +15,17 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "VectorIndex.h"
|
||||
#include "FaissBaseIndex.h"
|
||||
#include "VectorIndex.h"
|
||||
#include "faiss/IndexIVF.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -33,36 +33,38 @@ using Graph = std::vector<std::vector<int64_t>>;
|
|||
|
||||
class IVF : public VectorIndex, protected FaissBaseIndex {
|
||||
public:
|
||||
IVF() : FaissBaseIndex(nullptr) {};
|
||||
IVF() : FaissBaseIndex(nullptr) {
|
||||
}
|
||||
|
||||
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {}
|
||||
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
Clone() override;;
|
||||
Clone() override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
set_index_model(IndexModelPtr model) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr &dataset, const Config &config) override;
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
AddWithoutIds(const DatasetPtr &dataset, const Config &config);
|
||||
AddWithoutIds(const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config);
|
||||
GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const Config& config);
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet &index_binary) override;
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
@ -74,23 +76,17 @@ class IVF : public VectorIndex, protected FaissBaseIndex {
|
|||
Seal() override;
|
||||
|
||||
virtual VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t &device_id, const Config &config);
|
||||
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config);
|
||||
|
||||
protected:
|
||||
virtual std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config &config);
|
||||
GenParams(const Config& config);
|
||||
|
||||
virtual VectorIndexPtr
|
||||
Clone_impl(const std::shared_ptr<faiss::Index> &index);
|
||||
Clone_impl(const std::shared_ptr<faiss::Index>& index);
|
||||
|
||||
virtual void
|
||||
search_impl(int64_t n,
|
||||
const float *data,
|
||||
int64_t k,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
const Config &cfg);
|
||||
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
|
@ -106,13 +102,14 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex {
|
|||
public:
|
||||
explicit IVFIndexModel(std::shared_ptr<faiss::Index> index);
|
||||
|
||||
IVFIndexModel() : FaissBaseIndex(nullptr) {};
|
||||
IVFIndexModel() : FaissBaseIndex(nullptr) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet &binary) override;
|
||||
Load(const BinarySet& binary) override;
|
||||
|
||||
protected:
|
||||
void
|
||||
|
@ -121,7 +118,8 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex {
|
|||
protected:
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
using IVFIndexModelPtr = std::shared_ptr<IVFIndexModel>;
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,46 +15,51 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexIVFPQ.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVFPQ.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr IVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
auto nlist = config["nlist"].as<size_t>();
|
||||
auto M = config["M"].as<size_t>(); // number of subquantizers(subvector)
|
||||
auto nbits = config["nbits"].as<size_t>();// number of bit per subvector index
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
IndexModelPtr
|
||||
IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
faiss::Index *coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
|
||||
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, nlist, M, nbits);
|
||||
index->train(rows, (float *) p_data);
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type));
|
||||
auto index =
|
||||
std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits);
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
return std::make_shared<IVFIndexModel>(index);
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters> IVFPQ::GenParams(const Config &config) {
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
params->nprobe = config.get_with_default("nprobe", size_t(1));
|
||||
//params->scan_table_threshold = 0;
|
||||
//params->polysemous_ht = 0;
|
||||
//params->max_codes = 0;
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->scan_table_threshold = conf->scan_table_threhold;
|
||||
// params->polysemous_ht = conf->polysemous_ht;
|
||||
// params->max_codes = conf->max_codes;
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
VectorIndexPtr IVFPQ::Clone_impl(const std::shared_ptr<faiss::Index> &index) {
|
||||
VectorIndexPtr
|
||||
IVFPQ::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
|
||||
return std::make_shared<IVFPQ>(index);
|
||||
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,33 +15,33 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class IVFPQ : public IVF {
|
||||
public:
|
||||
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {}
|
||||
public:
|
||||
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
}
|
||||
|
||||
IVFPQ() = default;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
protected:
|
||||
protected:
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config &config) override;
|
||||
GenParams(const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
Clone_impl(const std::shared_ptr<faiss::Index> &index) override;
|
||||
Clone_impl(const std::shared_ptr<faiss::Index>& index) override;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,43 +15,46 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <memory>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "IndexIVFSQ.h"
|
||||
#include "IndexGPUIVFSQ.h"
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
IndexModelPtr IVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
auto nlist = config["nlist"].as<size_t>();
|
||||
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): only support SQ4 SQ6 SQ8 SQ16
|
||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
IndexModelPtr
|
||||
IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << nlist << "," << "SQ" << nbits;
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
|
||||
build_index->train(rows, (float *) p_data);
|
||||
index_type << "IVF" << build_cfg->nlist << ","
|
||||
<< "SQ" << build_cfg->nbits;
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
|
||||
build_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> ret_index;
|
||||
ret_index.reset(build_index);
|
||||
return std::make_shared<IVFIndexModel>(ret_index);
|
||||
}
|
||||
|
||||
VectorIndexPtr IVFSQ::Clone_impl(const std::shared_ptr<faiss::Index> &index) {
|
||||
VectorIndexPtr
|
||||
IVFSQ::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
|
||||
return std::make_shared<IVFSQ>(index);
|
||||
}
|
||||
|
||||
VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
|
||||
VectorIndexPtr
|
||||
IVFSQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
|
||||
ResScope rs(res, device_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
option.allInGpu = true;
|
||||
|
@ -66,5 +69,5 @@ VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &confi
|
|||
}
|
||||
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,31 +15,33 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "IndexIVF.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class IVFSQ : public IVF {
|
||||
public:
|
||||
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {}
|
||||
public:
|
||||
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
|
||||
}
|
||||
|
||||
IVFSQ() = default;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const int64_t &device_id, const Config &config) override;
|
||||
CopyCpuToGpu(const int64_t& device_id, const Config& config) override;
|
||||
|
||||
protected:
|
||||
protected:
|
||||
VectorIndexPtr
|
||||
Clone_impl(const std::shared_ptr<faiss::Index> &index) override;
|
||||
Clone_impl(const std::shared_ptr<faiss::Index>& index) override;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,41 +15,39 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <sstream>
|
||||
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
|
||||
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
|
||||
#include <SPTAG/AnnService/inc/Core/Common.h>
|
||||
|
||||
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
|
||||
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#undef mkdir
|
||||
|
||||
#include "IndexKDT.h"
|
||||
#include "knowhere/index/vector_index/IndexKDT.h"
|
||||
#include "knowhere/index/vector_index/helpers/Definitions.h"
|
||||
//#include "knowhere/index/preprocessor/normalize.h"
|
||||
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
|
||||
#include "knowhere/adapter/SptagAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
CPUKDTRNG::Serialize() {
|
||||
std::vector<void *> index_blobs;
|
||||
std::vector<void*> index_blobs;
|
||||
std::vector<int64_t> index_len;
|
||||
index_ptr_->SaveIndexToMemory(index_blobs, index_len);
|
||||
BinarySet binary_set;
|
||||
|
||||
auto sample = std::make_shared<uint8_t>();
|
||||
sample.reset(static_cast<uint8_t *>(index_blobs[0]));
|
||||
sample.reset(static_cast<uint8_t*>(index_blobs[0]));
|
||||
auto tree = std::make_shared<uint8_t>();
|
||||
tree.reset(static_cast<uint8_t *>(index_blobs[1]));
|
||||
tree.reset(static_cast<uint8_t*>(index_blobs[1]));
|
||||
auto graph = std::make_shared<uint8_t>();
|
||||
graph.reset(static_cast<uint8_t *>(index_blobs[2]));
|
||||
graph.reset(static_cast<uint8_t*>(index_blobs[2]));
|
||||
auto metadata = std::make_shared<uint8_t>();
|
||||
metadata.reset(static_cast<uint8_t *>(index_blobs[3]));
|
||||
metadata.reset(static_cast<uint8_t*>(index_blobs[3]));
|
||||
|
||||
binary_set.Append("samples", sample, index_len[0]);
|
||||
binary_set.Append("tree", tree, index_len[1]);
|
||||
|
@ -59,8 +57,8 @@ CPUKDTRNG::Serialize() {
|
|||
}
|
||||
|
||||
void
|
||||
CPUKDTRNG::Load(const BinarySet &binary_set) {
|
||||
std::vector<void *> index_blobs;
|
||||
CPUKDTRNG::Load(const BinarySet& binary_set) {
|
||||
std::vector<void*> index_blobs;
|
||||
|
||||
auto samples = binary_set.GetByName("samples");
|
||||
index_blobs.push_back(samples->data.get());
|
||||
|
@ -77,17 +75,17 @@ CPUKDTRNG::Load(const BinarySet &binary_set) {
|
|||
index_ptr_->LoadIndexFromMemory(index_blobs);
|
||||
}
|
||||
|
||||
//PreprocessorPtr
|
||||
//CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
|
||||
// PreprocessorPtr
|
||||
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
|
||||
// return std::make_shared<NormalizePreprocessor>();
|
||||
//}
|
||||
|
||||
IndexModelPtr
|
||||
CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
|
||||
CPUKDTRNG::Train(const DatasetPtr& origin, const Config& train_config) {
|
||||
SetParameters(train_config);
|
||||
DatasetPtr dataset = origin->Clone();
|
||||
|
||||
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// && preprocessor_) {
|
||||
// preprocessor_->Preprocess(dataset);
|
||||
//}
|
||||
|
@ -101,11 +99,11 @@ CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
|
|||
}
|
||||
|
||||
void
|
||||
CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) {
|
||||
CPUKDTRNG::Add(const DatasetPtr& origin, const Config& add_config) {
|
||||
SetParameters(add_config);
|
||||
DatasetPtr dataset = origin->Clone();
|
||||
|
||||
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// && preprocessor_) {
|
||||
// preprocessor_->Preprocess(dataset);
|
||||
//}
|
||||
|
@ -116,18 +114,18 @@ CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) {
|
|||
}
|
||||
|
||||
void
|
||||
CPUKDTRNG::SetParameters(const Config &config) {
|
||||
for (auto ¶ : KDTParameterMgr::GetInstance().GetKDTParameters()) {
|
||||
auto value = config.get_with_default(para.first, para.second);
|
||||
index_ptr_->SetParameter(para.first, value);
|
||||
CPUKDTRNG::SetParameters(const Config& config) {
|
||||
for (auto& para : KDTParameterMgr::GetInstance().GetKDTParameters()) {
|
||||
// auto value = config.get_with_default(para.first, para.second);
|
||||
index_ptr_->SetParameter(para.first, para.second);
|
||||
}
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
CPUKDTRNG::Search(const DatasetPtr &dataset, const Config &config) {
|
||||
CPUKDTRNG::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
SetParameters(config);
|
||||
auto tensor = dataset->tensor()[0];
|
||||
auto p = (float *) tensor->raw_mutable_data();
|
||||
auto p = (float*)tensor->raw_mutable_data();
|
||||
for (auto i = 0; i < 10; ++i) {
|
||||
for (auto j = 0; j < 10; ++j) {
|
||||
std::cout << p[i * 10 + j] << " ";
|
||||
|
@ -138,7 +136,7 @@ CPUKDTRNG::Search(const DatasetPtr &dataset, const Config &config) {
|
|||
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < query_results.size(); ++i) {
|
||||
auto target = (float *) query_results[i].GetTarget();
|
||||
auto target = (float*)query_results[i].GetTarget();
|
||||
std::cout << target[0] << ", " << target[1] << ", " << target[2] << std::endl;
|
||||
index_ptr_->SearchIndex(query_results[i]);
|
||||
}
|
||||
|
@ -146,27 +144,34 @@ CPUKDTRNG::Search(const DatasetPtr &dataset, const Config &config) {
|
|||
return ConvertToDataset(query_results);
|
||||
}
|
||||
|
||||
int64_t CPUKDTRNG::Count() {
|
||||
int64_t
|
||||
CPUKDTRNG::Count() {
|
||||
index_ptr_->GetNumSamples();
|
||||
}
|
||||
int64_t CPUKDTRNG::Dimension() {
|
||||
|
||||
int64_t
|
||||
CPUKDTRNG::Dimension() {
|
||||
index_ptr_->GetFeatureDim();
|
||||
}
|
||||
|
||||
VectorIndexPtr CPUKDTRNG::Clone() {
|
||||
VectorIndexPtr
|
||||
CPUKDTRNG::Clone() {
|
||||
KNOWHERE_THROW_MSG("not support");
|
||||
}
|
||||
|
||||
void CPUKDTRNG::Seal() {
|
||||
void
|
||||
CPUKDTRNG::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// TODO(linxj):
|
||||
BinarySet
|
||||
CPUKDTRNGIndexModel::Serialize() {}
|
||||
CPUKDTRNGIndexModel::Serialize() {
|
||||
}
|
||||
|
||||
void
|
||||
CPUKDTRNGIndexModel::Load(const BinarySet &binary) {}
|
||||
CPUKDTRNGIndexModel::Load(const BinarySet& binary) {
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,53 +15,54 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include "VectorIndex.h"
|
||||
#include "knowhere/index/IndexModel.h"
|
||||
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
||||
class CPUKDTRNG : public VectorIndex {
|
||||
public:
|
||||
CPUKDTRNG() {
|
||||
index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT,
|
||||
SPTAG::VectorValueType::Float);
|
||||
index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT, SPTAG::VectorValueType::Float);
|
||||
index_ptr_->SetParameter("DistCalcMethod", "L2");
|
||||
}
|
||||
|
||||
public:
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
VectorIndexPtr Clone() override;
|
||||
VectorIndexPtr
|
||||
Clone() override;
|
||||
void
|
||||
Load(const BinarySet &index_array) override;
|
||||
Load(const BinarySet& index_array) override;
|
||||
|
||||
public:
|
||||
//PreprocessorPtr
|
||||
//BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
|
||||
int64_t Count() override;
|
||||
int64_t Dimension() override;
|
||||
// PreprocessorPtr
|
||||
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
|
||||
int64_t
|
||||
Count() override;
|
||||
int64_t
|
||||
Dimension() override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr &dataset, const Config &config) override;
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
void Seal() override;
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
private:
|
||||
void
|
||||
SetParameters(const Config &config);
|
||||
SetParameters(const Config& config);
|
||||
|
||||
private:
|
||||
PreprocessorPtr preprocessor_;
|
||||
|
@ -76,7 +77,7 @@ class CPUKDTRNGIndexModel : public IndexModel {
|
|||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet &binary) override;
|
||||
Load(const BinarySet& binary) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SPTAG::VectorIndex> index_;
|
||||
|
@ -84,5 +85,5 @@ class CPUKDTRNGIndexModel : public IndexModel {
|
|||
|
||||
using CPUKDTRNGIndexModelPtr = std::shared_ptr<CPUKDTRNGIndexModel>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,28 +15,27 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include "IndexNSG.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSGIO.h"
|
||||
#include "IndexIDMAP.h"
|
||||
#include "IndexIVF.h"
|
||||
#include "IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexNSG.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSGIO.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet NSG::Serialize() {
|
||||
BinarySet
|
||||
NSG::Serialize() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
try {
|
||||
algo::NsgIndex *index = index_.get();
|
||||
algo::NsgIndex* index = index_.get();
|
||||
|
||||
MemoryIOWriter writer;
|
||||
algo::write_index(index, writer);
|
||||
|
@ -46,12 +45,13 @@ BinarySet NSG::Serialize() {
|
|||
BinarySet res_set;
|
||||
res_set.Append("NSG", data, writer.total);
|
||||
return res_set;
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void NSG::Load(const BinarySet &index_binary) {
|
||||
void
|
||||
NSG::Load(const BinarySet& index_binary) {
|
||||
try {
|
||||
auto binary = index_binary.GetByName("NSG");
|
||||
|
||||
|
@ -61,36 +61,35 @@ void NSG::Load(const BinarySet &index_binary) {
|
|||
|
||||
auto index = algo::read_index(reader);
|
||||
index_.reset(index);
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
DatasetPtr NSG::Search(const DatasetPtr &dataset, const Config &config) {
|
||||
DatasetPtr
|
||||
NSG::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
// Required
|
||||
// if not found throw exception here.
|
||||
auto k = config["k"].as<size_t>();
|
||||
auto search_length = config.get_with_default("search_length", 30);
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
auto elems = rows * k;
|
||||
auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems);
|
||||
auto res_dis = (float *) malloc(sizeof(float) * elems);
|
||||
auto elems = rows * build_cfg->k;
|
||||
auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems);
|
||||
auto res_dis = (float*)malloc(sizeof(float) * elems);
|
||||
|
||||
// TODO(linxj): get from config
|
||||
algo::SearchParams s_params;
|
||||
s_params.search_length = search_length;
|
||||
index_->Search((float *) p_data, rows, dim, k, res_dis, res_ids, s_params);
|
||||
s_params.search_length = build_cfg->search_length;
|
||||
index_->Search((float*)p_data, rows, dim, build_cfg->k, res_dis, res_ids, s_params);
|
||||
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems);
|
||||
auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems);
|
||||
auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems);
|
||||
|
||||
// TODO: magic
|
||||
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
|
||||
std::vector<BufferPtr> dist_bufs{nullptr, dist_buf};
|
||||
|
||||
|
@ -107,63 +106,65 @@ DatasetPtr NSG::Search(const DatasetPtr &dataset, const Config &config) {
|
|||
return std::make_shared<Dataset>(array, nullptr);
|
||||
}
|
||||
|
||||
IndexModelPtr NSG::Train(const DatasetPtr &dataset, const Config &config) {
|
||||
TimeRecorder rc("Interface");
|
||||
IndexModelPtr
|
||||
NSG::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
auto metric_type = config["metric_type"].as_string();
|
||||
if (metric_type != "L2") { KNOWHERE_THROW_MSG("NSG not support this kind of metric type");}
|
||||
if (build_cfg->metric_type != METRICTYPE::L2) {
|
||||
KNOWHERE_THROW_MSG("NSG not support this kind of metric type");
|
||||
}
|
||||
|
||||
// TODO(linxj): dev IndexFactory, support more IndexType
|
||||
auto preprocess_index = std::make_shared<GPUIVF>(0);
|
||||
//auto preprocess_index = std::make_shared<IVF>();
|
||||
auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
|
||||
auto model = preprocess_index->Train(dataset, config);
|
||||
preprocess_index->set_index_model(model);
|
||||
preprocess_index->AddWithoutIds(dataset, config);
|
||||
rc.RecordSection("build ivf");
|
||||
|
||||
auto k = config["knng"].as<int64_t>();
|
||||
Graph knng;
|
||||
preprocess_index->GenGraph(k, knng, dataset, config);
|
||||
rc.RecordSection("build knng");
|
||||
preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
|
||||
|
||||
algo::BuildParams b_params;
|
||||
b_params.candidate_pool_size = build_cfg->candidate_pool_size;
|
||||
b_params.out_degree = build_cfg->out_degree;
|
||||
b_params.search_length = build_cfg->search_length;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
auto array = dataset->array()[0];
|
||||
auto p_ids = array->data()->GetValues<long>(1, 0);
|
||||
|
||||
algo::BuildParams b_params;
|
||||
b_params.candidate_pool_size = config["candidate_pool_size"].as<size_t>();
|
||||
b_params.out_degree = config["out_degree"].as<size_t>();
|
||||
b_params.search_length = config["search_length"].as<size_t>();
|
||||
auto p_ids = array->data()->GetValues<int64_t>(1, 0);
|
||||
|
||||
index_ = std::make_shared<algo::NsgIndex>(dim, rows);
|
||||
index_->SetKnnGraph(knng);
|
||||
index_->Build_with_ids(rows, (float *) p_data, (long *) p_ids, b_params);
|
||||
rc.RecordSection("build nsg");
|
||||
rc.ElapseFromBegin("total cost");
|
||||
return nullptr; // TODO(linxj): support serialize
|
||||
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
|
||||
return nullptr; // TODO(linxj): support serialize
|
||||
}
|
||||
|
||||
void NSG::Add(const DatasetPtr &dataset, const Config &config) {
|
||||
// TODO(linxj): support incremental index.
|
||||
|
||||
//KNOWHERE_THROW_MSG("Not support yet");
|
||||
}
|
||||
|
||||
int64_t NSG::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t NSG::Dimension() {
|
||||
return index_->dimension;
|
||||
}
|
||||
VectorIndexPtr NSG::Clone() {
|
||||
KNOWHERE_THROW_MSG("not support");
|
||||
}
|
||||
|
||||
void NSG::Seal() {
|
||||
void
|
||||
NSG::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
}
|
||||
int64_t
|
||||
NSG::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
NSG::Dimension() {
|
||||
return index_->dimension;
|
||||
}
|
||||
|
||||
VectorIndexPtr
|
||||
NSG::Clone() {
|
||||
KNOWHERE_THROW_MSG("not support");
|
||||
}
|
||||
|
||||
void
|
||||
NSG::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,11 +15,12 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "VectorIndex.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "VectorIndex.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
@ -30,18 +31,30 @@ class NsgIndex;
|
|||
|
||||
class NSG : public VectorIndex {
|
||||
public:
|
||||
explicit NSG(const int64_t& gpu_num):gpu_(gpu_num){}
|
||||
explicit NSG(const int64_t& gpu_num) : gpu_(gpu_num) {
|
||||
}
|
||||
|
||||
NSG() = default;
|
||||
|
||||
IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override;
|
||||
DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
|
||||
void Add(const DatasetPtr &dataset, const Config &config) override;
|
||||
BinarySet Serialize() override;
|
||||
void Load(const BinarySet &index_binary) override;
|
||||
int64_t Count() override;
|
||||
int64_t Dimension() override;
|
||||
VectorIndexPtr Clone() override;
|
||||
void Seal() override;
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
int64_t
|
||||
Count() override;
|
||||
int64_t
|
||||
Dimension() override;
|
||||
VectorIndexPtr
|
||||
Clone() override;
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<algo::NsgIndex> index_;
|
||||
int64_t gpu_;
|
||||
|
@ -49,5 +62,5 @@ class NSG : public VectorIndex {
|
|||
|
||||
// Shared-ownership handle to an NSG index.
// The previous spelling ended in "()", which declared a *function type* returning
// std::shared_ptr<NSG> instead of the pointer type itself.
using NSGIndexPtr = std::shared_ptr<NSG>;
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,17 +15,15 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/Index.h"
|
||||
#include "knowhere/index/preprocessor/Preprocessor.h"
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
@ -33,17 +31,20 @@ namespace knowhere {
|
|||
class VectorIndex;
|
||||
using VectorIndexPtr = std::shared_ptr<VectorIndex>;
|
||||
|
||||
|
||||
class VectorIndex : public Index {
|
||||
public:
|
||||
virtual PreprocessorPtr
|
||||
BuildPreprocessor(const DatasetPtr &dataset, const Config &config) { return nullptr; }
|
||||
BuildPreprocessor(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual IndexModelPtr
|
||||
Train(const DatasetPtr &dataset, const Config &config) { return nullptr; }
|
||||
Train(const DatasetPtr& dataset, const Config& config) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual void
|
||||
Add(const DatasetPtr &dataset, const Config &config) = 0;
|
||||
Add(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual void
|
||||
Seal() = 0;
|
||||
|
@ -58,7 +59,5 @@ class VectorIndex : public Index {
|
|||
Dimension() = 0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,21 +15,20 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexGPUIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "Cloner.h"
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFSQ.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace cloner {
|
||||
|
||||
VectorIndexPtr CopyGpuToCpu(const VectorIndexPtr &index, const Config &config) {
|
||||
VectorIndexPtr
|
||||
CopyGpuToCpu(const VectorIndexPtr& index, const Config& config) {
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIndex>(index)) {
|
||||
return device_index->CopyGpuToCpu(config);
|
||||
} else {
|
||||
|
@ -37,7 +36,8 @@ VectorIndexPtr CopyGpuToCpu(const VectorIndexPtr &index, const Config &config) {
|
|||
}
|
||||
}
|
||||
|
||||
VectorIndexPtr CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_id, const Config &config) {
|
||||
VectorIndexPtr
|
||||
CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config& config) {
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIndex>(index)) {
|
||||
return device_index->CopyGpuToGpu(device_id, config);
|
||||
}
|
||||
|
@ -55,6 +55,6 @@ VectorIndexPtr CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_i
|
|||
}
|
||||
}
|
||||
|
||||
} // cloner
|
||||
}
|
||||
}
|
||||
} // namespace cloner
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,23 +15,21 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "knowhere/index/vector_index/VectorIndex.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace cloner {
|
||||
|
||||
// TODO(linxj): rename CopyToGpu
|
||||
extern VectorIndexPtr
|
||||
CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_id, const Config &config);
|
||||
CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config& config);
|
||||
|
||||
extern VectorIndexPtr
|
||||
CopyGpuToCpu(const VectorIndexPtr &index, const Config &config);
|
||||
CopyGpuToCpu(const VectorIndexPtr& index, const Config& config);
|
||||
|
||||
} // cloner
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace cloner
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,10 +15,8 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace definition {
|
||||
|
@ -27,6 +25,6 @@ namespace definition {
|
|||
#define META_DIM ("dimension")
|
||||
#define META_K ("k")
|
||||
|
||||
} // definition
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace definition
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,25 +15,24 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
|
||||
#include "FaissGpuResourceMgr.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
FaissGpuResourceMgr &FaissGpuResourceMgr::GetInstance() {
|
||||
FaissGpuResourceMgr&
|
||||
FaissGpuResourceMgr::GetInstance() {
|
||||
static FaissGpuResourceMgr instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void FaissGpuResourceMgr::AllocateTempMem(ResPtr &resource,
|
||||
const int64_t &device_id,
|
||||
const int64_t &size) {
|
||||
void
|
||||
FaissGpuResourceMgr::AllocateTempMem(ResPtr& resource, const int64_t& device_id, const int64_t& size) {
|
||||
if (size) {
|
||||
resource->faiss_res->setTempMemory(size);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
auto search = devices_params_.find(device_id);
|
||||
if (search != devices_params_.end()) {
|
||||
resource->faiss_res->setTempMemory(search->second.temp_mem_size);
|
||||
|
@ -42,10 +41,8 @@ void FaissGpuResourceMgr::AllocateTempMem(ResPtr &resource,
|
|||
}
|
||||
}
|
||||
|
||||
void FaissGpuResourceMgr::InitDevice(int64_t device_id,
|
||||
int64_t pin_mem_size,
|
||||
int64_t temp_mem_size,
|
||||
int64_t res_num) {
|
||||
void
|
||||
FaissGpuResourceMgr::InitDevice(int64_t device_id, int64_t pin_mem_size, int64_t temp_mem_size, int64_t res_num) {
|
||||
DeviceParams params;
|
||||
params.pinned_mem_size = pin_mem_size;
|
||||
params.temp_mem_size = temp_mem_size;
|
||||
|
@ -54,23 +51,25 @@ void FaissGpuResourceMgr::InitDevice(int64_t device_id,
|
|||
devices_params_.emplace(device_id, params);
|
||||
}
|
||||
|
||||
void FaissGpuResourceMgr::InitResource() {
|
||||
if(is_init) return ;
|
||||
void
|
||||
FaissGpuResourceMgr::InitResource() {
|
||||
if (is_init)
|
||||
return;
|
||||
|
||||
is_init = true;
|
||||
|
||||
//std::cout << "InitResource" << std::endl;
|
||||
for(auto& device : devices_params_) {
|
||||
// std::cout << "InitResource" << std::endl;
|
||||
for (auto& device : devices_params_) {
|
||||
auto& device_id = device.first;
|
||||
|
||||
mutex_cache_.emplace(device_id, std::make_unique<std::mutex>());
|
||||
|
||||
//std::cout << "Device Id: " << device_id << std::endl;
|
||||
// std::cout << "Device Id: " << device_id << std::endl;
|
||||
auto& device_param = device.second;
|
||||
auto& bq = idle_map_[device_id];
|
||||
|
||||
for (int64_t i = 0; i < device_param.resource_num; ++i) {
|
||||
//std::cout << "Resource Id: " << i << std::endl;
|
||||
// std::cout << "Resource Id: " << i << std::endl;
|
||||
auto raw_resource = std::make_shared<faiss::gpu::StandardGpuResources>();
|
||||
|
||||
// TODO(linxj): enable set pinned memory
|
||||
|
@ -80,11 +79,11 @@ void FaissGpuResourceMgr::InitResource() {
|
|||
bq.Put(res_wrapper);
|
||||
}
|
||||
}
|
||||
//std::cout << "End initResource" << std::endl;
|
||||
// std::cout << "End initResource" << std::endl;
|
||||
}
|
||||
|
||||
ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id,
|
||||
const int64_t &alloc_size) {
|
||||
ResPtr
|
||||
FaissGpuResourceMgr::GetRes(const int64_t& device_id, const int64_t& alloc_size) {
|
||||
InitResource();
|
||||
|
||||
auto finder = idle_map_.find(device_id);
|
||||
|
@ -97,7 +96,8 @@ ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res) {
|
||||
void
|
||||
FaissGpuResourceMgr::MoveToIdle(const int64_t& device_id, const ResPtr& res) {
|
||||
auto finder = idle_map_.find(device_id);
|
||||
if (finder != idle_map_.end()) {
|
||||
auto& bq = finder->second;
|
||||
|
@ -105,8 +105,9 @@ void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res
|
|||
}
|
||||
}
|
||||
|
||||
void FaissGpuResourceMgr::Free() {
|
||||
for (auto &item : idle_map_) {
|
||||
void
|
||||
FaissGpuResourceMgr::Free() {
|
||||
for (auto& item : idle_map_) {
|
||||
auto& bq = item.second;
|
||||
while (!bq.Empty()) {
|
||||
bq.Take();
|
||||
|
@ -117,12 +118,11 @@ void FaissGpuResourceMgr::Free() {
|
|||
|
||||
void
|
||||
FaissGpuResourceMgr::Dump() {
|
||||
for (auto &item : idle_map_) {
|
||||
for (auto& item : idle_map_) {
|
||||
auto& bq = item.second;
|
||||
std::cout << "device_id: " << item.first
|
||||
<< ", resource count:" << bq.Size();
|
||||
std::cout << "device_id: " << item.first << ", resource count:" << bq.Size();
|
||||
}
|
||||
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include <faiss/gpu/StandardGpuResources.h>
|
||||
|
||||
|
@ -30,7 +30,7 @@ namespace zilliz {
|
|||
namespace knowhere {
|
||||
|
||||
struct Resource {
|
||||
explicit Resource(std::shared_ptr<faiss::gpu::StandardGpuResources> &r) : faiss_res(r) {
|
||||
explicit Resource(std::shared_ptr<faiss::gpu::StandardGpuResources>& r) : faiss_res(r) {
|
||||
static int64_t global_id = 0;
|
||||
id = global_id++;
|
||||
}
|
||||
|
@ -43,19 +43,19 @@ using ResPtr = std::shared_ptr<Resource>;
|
|||
using ResWPtr = std::weak_ptr<Resource>;
|
||||
|
||||
class FaissGpuResourceMgr {
|
||||
public:
|
||||
public:
|
||||
friend class ResScope;
|
||||
using ResBQ = zilliz::milvus::server::BlockingQueue<ResPtr>;
|
||||
|
||||
public:
|
||||
public:
|
||||
struct DeviceParams {
|
||||
int64_t temp_mem_size = 0;
|
||||
int64_t pinned_mem_size = 0;
|
||||
int64_t resource_num = 2;
|
||||
};
|
||||
|
||||
public:
|
||||
static FaissGpuResourceMgr &
|
||||
public:
|
||||
static FaissGpuResourceMgr&
|
||||
GetInstance();
|
||||
|
||||
// Free gpu resource, avoid cudaGetDevice error when deallocate.
|
||||
|
@ -64,67 +64,67 @@ public:
|
|||
Free();
|
||||
|
||||
void
|
||||
AllocateTempMem(ResPtr &resource, const int64_t& device_id, const int64_t& size);
|
||||
AllocateTempMem(ResPtr& resource, const int64_t& device_id, const int64_t& size);
|
||||
|
||||
void
|
||||
InitDevice(int64_t device_id,
|
||||
int64_t pin_mem_size = 0,
|
||||
int64_t temp_mem_size = 0,
|
||||
int64_t res_num = 2);
|
||||
InitDevice(int64_t device_id, int64_t pin_mem_size = 0, int64_t temp_mem_size = 0, int64_t res_num = 2);
|
||||
|
||||
void
|
||||
InitResource();
|
||||
|
||||
// allocate gpu memory invoke by build or copy_to_gpu
|
||||
ResPtr
|
||||
GetRes(const int64_t &device_id, const int64_t& alloc_size = 0);
|
||||
GetRes(const int64_t& device_id, const int64_t& alloc_size = 0);
|
||||
|
||||
void
|
||||
MoveToIdle(const int64_t &device_id, const ResPtr& res);
|
||||
MoveToIdle(const int64_t& device_id, const ResPtr& res);
|
||||
|
||||
void
|
||||
Dump();
|
||||
|
||||
protected:
|
||||
protected:
|
||||
bool is_init = false;
|
||||
|
||||
std::map<int64_t ,std::unique_ptr<std::mutex>> mutex_cache_;
|
||||
std::map<int64_t, std::unique_ptr<std::mutex>> mutex_cache_;
|
||||
std::map<int64_t, DeviceParams> devices_params_;
|
||||
std::map<int64_t, ResBQ> idle_map_;
|
||||
};
|
||||
|
||||
class ResScope {
|
||||
public:
|
||||
ResScope(ResPtr &res, const int64_t& device_id, const bool& isown)
|
||||
: resource(res), device_id(device_id), move(true), own(isown) {
|
||||
public:
|
||||
ResScope(ResPtr& res, const int64_t& device_id, const bool& isown)
|
||||
: resource(res), device_id(device_id), move(true), own(isown) {
|
||||
Lock();
|
||||
}
|
||||
|
||||
// specif for search
|
||||
// get the ownership of gpuresource and gpu
|
||||
ResScope(ResWPtr &res, const int64_t &device_id)
|
||||
:device_id(device_id),move(false),own(true) {
|
||||
ResScope(ResWPtr& res, const int64_t& device_id) : device_id(device_id), move(false), own(true) {
|
||||
resource = res.lock();
|
||||
Lock();
|
||||
}
|
||||
|
||||
void Lock() {
|
||||
if (own) FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->lock();
|
||||
void
|
||||
Lock() {
|
||||
if (own)
|
||||
FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->lock();
|
||||
resource->mutex.lock();
|
||||
}
|
||||
|
||||
~ResScope() {
|
||||
if (own) FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->unlock();
|
||||
if (move) FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource);
|
||||
if (own)
|
||||
FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->unlock();
|
||||
if (move)
|
||||
FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource);
|
||||
resource->mutex.unlock();
|
||||
}
|
||||
|
||||
private:
|
||||
ResPtr resource; // hold resource until deconstruct
|
||||
private:
|
||||
ResPtr resource; // hold resource until deconstruct
|
||||
int64_t device_id;
|
||||
bool move = true;
|
||||
bool own = false;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,51 +15,55 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "FaissIO.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
// TODO(linxj): Get From Config File
|
||||
static size_t magic_num = 2;
|
||||
// Appends `nitems` items of `size` bytes each from `ptr` to the in-memory buffer.
//
// The buffer is (re)allocated whenever it does not exist yet or is too small; the
// new capacity is `magic_num` times the number of bytes needed so that subsequent
// writes can usually be served without another allocation.
//
// Returns `nitems` (the write itself cannot be partial).
size_t
MemoryIOWriter::operator()(const void* ptr, size_t size, size_t nitems) {
    const size_t incoming = size * nitems;
    const size_t total_need = rp + incoming;

    if (data_ == nullptr || total_need > total) {
        const size_t new_total = total_need * magic_num;
        auto* grown = new uint8_t[new_total];
        if (data_ != nullptr) {
            // Carry over what was written so far, then release the old array.
            // The storage comes from new[], so it must be released with delete[].
            memcpy(grown, data_, rp);
            delete[] data_;
        }
        data_ = grown;
        total = new_total;
    }

    memcpy(data_ + rp, ptr, incoming);
    rp = total_need;

    return nitems;
}
|
||||
|
||||
// Copies up to `nitems` items of `size` bytes each from the buffer into `ptr`,
// starting at the current read position `rp`, and advances `rp` accordingly.
//
// Only whole items are copied: if fewer than `nitems` items remain before `total`,
// the count is clamped to what is left.
//
// Returns the number of items actually copied (0 when the buffer is exhausted or
// when `size` is 0).
size_t
MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) {
    // size == 0 would divide by zero below; like fread, report that nothing was read.
    if (size == 0 || rp >= total) {
        return 0;
    }

    const size_t nremain = (total - rp) / size;
    if (nremain < nitems) {
        nitems = nremain;
    }

    const size_t nbytes = size * nitems;
    memcpy(ptr, data_ + rp, nbytes);
    rp += nbytes;
    return nitems;
}
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/AuxIndexStructures.h>
|
||||
|
@ -24,25 +23,22 @@ namespace zilliz {
|
|||
namespace knowhere {
|
||||
|
||||
struct MemoryIOWriter : public faiss::IOWriter {
|
||||
uint8_t *data_ = nullptr;
|
||||
uint8_t* data_ = nullptr;
|
||||
size_t total = 0;
|
||||
size_t rp = 0;
|
||||
|
||||
size_t
|
||||
operator()(const void *ptr, size_t size, size_t nitems) override;
|
||||
operator()(const void* ptr, size_t size, size_t nitems) override;
|
||||
};
|
||||
|
||||
struct MemoryIOReader : public faiss::IOReader {
|
||||
uint8_t *data_;
|
||||
uint8_t* data_;
|
||||
size_t rp = 0;
|
||||
size_t total = 0;
|
||||
|
||||
size_t
|
||||
operator()(void *ptr, size_t size, size_t nitems) override;
|
||||
operator()(void* ptr, size_t size, size_t nitems) override;
|
||||
};
|
||||
|
||||
} // knowhere
|
||||
} // zilliz
|
||||
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
#include <faiss/Index.h>
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
// Maps a knowhere METRICTYPE onto the corresponding faiss::MetricType:
// L2 -> faiss::METRIC_L2, IP -> faiss::METRIC_INNER_PRODUCT.
// Any other value is rejected through KNOWHERE_THROW_MSG.
faiss::MetricType
GetMetricType(METRICTYPE& type) {
    const bool is_l2 = (type == METRICTYPE::L2);
    const bool is_ip = (type == METRICTYPE::IP);

    if (!is_l2 && !is_ip) {
        KNOWHERE_THROW_MSG("Metric type is invalid");
    }

    return is_l2 ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
|
@ -0,0 +1,135 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <memory>
|
||||
|
||||
#include "knowhere/common/Config.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
extern faiss::MetricType
|
||||
GetMetricType(METRICTYPE& type);
|
||||
|
||||
// IVF Config
|
||||
constexpr int64_t DEFAULT_NLIST = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NPROBE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NSUBVECTORS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NBITS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_SCAN_TABLE_THREHOLD = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_POLYSEMOUS_HT = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_MAX_CODES = INVALID_VALUE;
|
||||
|
||||
// NSG Config
|
||||
constexpr int64_t DEFAULT_SEARCH_LENGTH = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_CANDIDATE_SISE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NNG_K = INVALID_VALUE;
|
||||
|
||||
struct IVFCfg : public Cfg {
|
||||
int64_t nlist = DEFAULT_NLIST;
|
||||
int64_t nprobe = DEFAULT_NPROBE;
|
||||
|
||||
IVFCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
METRICTYPE type)
|
||||
: Cfg(dim, k, gpu_id, type), nlist(nlist), nprobe(nprobe) {
|
||||
}
|
||||
|
||||
IVFCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
};
|
||||
using IVFConfig = std::shared_ptr<IVFCfg>;
|
||||
|
||||
struct IVFSQCfg : public IVFCfg {
|
||||
// TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16
|
||||
int64_t nbits = DEFAULT_NBITS;
|
||||
|
||||
IVFSQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
const int64_t& nbits, METRICTYPE type)
|
||||
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), nbits(nbits) {
|
||||
}
|
||||
|
||||
IVFSQCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
};
|
||||
using IVFSQConfig = std::shared_ptr<IVFSQCfg>;
|
||||
|
||||
struct IVFPQCfg : public IVFCfg {
|
||||
int64_t m = DEFAULT_NSUBVECTORS; // number of subquantizers(subvector)
|
||||
int64_t nbits = DEFAULT_NBITS; // number of bit per subvector index
|
||||
|
||||
// TODO(linxj): not use yet
|
||||
int64_t scan_table_threhold = DEFAULT_SCAN_TABLE_THREHOLD;
|
||||
int64_t polysemous_ht = DEFAULT_POLYSEMOUS_HT;
|
||||
int64_t max_codes = DEFAULT_MAX_CODES;
|
||||
|
||||
IVFPQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
const int64_t& nbits, const int64_t& m, METRICTYPE type)
|
||||
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), m(m), nbits(nbits) {
|
||||
}
|
||||
|
||||
IVFPQCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
};
|
||||
using IVFPQConfig = std::shared_ptr<IVFPQCfg>;
|
||||
|
||||
struct NSGCfg : public IVFCfg {
|
||||
int64_t knng = DEFAULT_NNG_K;
|
||||
int64_t search_length = DEFAULT_SEARCH_LENGTH;
|
||||
int64_t out_degree = DEFAULT_OUT_DEGREE;
|
||||
int64_t candidate_pool_size = DEFAULT_CANDIDATE_SISE;
|
||||
|
||||
NSGCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
const int64_t& knng, const int64_t& search_length, const int64_t& out_degree, const int64_t& candidate_size,
|
||||
METRICTYPE type)
|
||||
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type),
|
||||
knng(knng),
|
||||
search_length(search_length),
|
||||
out_degree(out_degree),
|
||||
candidate_pool_size(candidate_size) {
|
||||
}
|
||||
|
||||
NSGCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
};
|
||||
using NSGConfig = std::shared_ptr<NSGCfg>;
|
||||
|
||||
struct KDTCfg : public Cfg {
|
||||
int64_t tptnubmber = -1;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
|
@ -15,16 +15,14 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#include "KDTParameterMgr.h"
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
const std::vector<KDTParameter> &
|
||||
const std::vector<KDTParameter>&
|
||||
KDTParameterMgr::GetKDTParameters() {
|
||||
return kdt_parameters_;
|
||||
}
|
||||
|
@ -35,7 +33,7 @@ KDTParameterMgr::KDTParameterMgr() {
|
|||
{"NumTopDimensionKDTSplit", "5"},
|
||||
{"NumSamplesKDTSplitConsideration", "100"},
|
||||
|
||||
{"TPTNumber", "32"},
|
||||
{"TPTNumber", "1"},
|
||||
{"TPTLeafSize", "2000"},
|
||||
{"NumTopDimensionTPTSplit", "5"},
|
||||
|
||||
|
@ -55,5 +53,5 @@ KDTParameterMgr::KDTParameterMgr() {
|
|||
};
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,13 +15,13 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -29,18 +29,20 @@ using KDTParameter = std::pair<std::string, std::string>;
|
|||
|
||||
class KDTParameterMgr {
|
||||
public:
|
||||
const std::vector<KDTParameter> &
|
||||
const std::vector<KDTParameter>&
|
||||
GetKDTParameters();
|
||||
|
||||
public:
|
||||
static KDTParameterMgr &
|
||||
static KDTParameterMgr&
|
||||
GetInstance() {
|
||||
static KDTParameterMgr instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
KDTParameterMgr(const KDTParameterMgr &) = delete;
|
||||
KDTParameterMgr &operator=(const KDTParameterMgr &) = delete;
|
||||
KDTParameterMgr(const KDTParameterMgr&) = delete;
|
||||
KDTParameterMgr&
|
||||
operator=(const KDTParameterMgr&) = delete;
|
||||
|
||||
private:
|
||||
KDTParameterMgr();
|
||||
|
||||
|
@ -48,5 +50,5 @@ class KDTParameterMgr {
|
|||
std::vector<KDTParameter> kdt_parameters_;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,28 +15,28 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <stack>
|
||||
#include <omp.h>
|
||||
#include <utility>
|
||||
|
||||
#include "NSG.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
#include "NSGHelper.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSGHelper.h"
|
||||
|
||||
// TODO: enable macro
|
||||
//#include <gperftools/profiler.h>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
||||
|
||||
NsgIndex::NsgIndex(const size_t &dimension, const size_t &n, MetricType metric)
|
||||
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, MetricType metric)
|
||||
: dimension(dimension), ntotal(n), metric_type(metric) {
|
||||
}
|
||||
|
||||
|
@ -45,16 +45,17 @@ NsgIndex::~NsgIndex() {
|
|||
delete[] ids_;
|
||||
}
|
||||
|
||||
//void NsgIndex::Build(size_t nb, const float *data, const BuildParam ¶meters) {
|
||||
// void NsgIndex::Build(size_t nb, const float *data, const BuildParam ¶meters) {
|
||||
//}
|
||||
void NsgIndex::Build_with_ids(size_t nb, const float *data, const long *ids, const BuildParams ¶meters) {
|
||||
void
|
||||
NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters) {
|
||||
TimeRecorder rc("NSG");
|
||||
|
||||
ntotal = nb;
|
||||
ori_data_ = new float[ntotal * dimension];
|
||||
ids_ = new long[ntotal];
|
||||
memcpy((void *) ori_data_, (void *) data, sizeof(float) * ntotal * dimension);
|
||||
memcpy((void *) ids_, (void *) ids, sizeof(long) * ntotal);
|
||||
ids_ = new int64_t[ntotal];
|
||||
memcpy((void*)ori_data_, (void*)data, sizeof(float) * ntotal * dimension);
|
||||
memcpy((void*)ids_, (void*)ids, sizeof(int64_t) * ntotal);
|
||||
|
||||
search_length = parameters.search_length;
|
||||
out_degree = parameters.out_degree;
|
||||
|
@ -68,8 +69,8 @@ void NsgIndex::Build_with_ids(size_t nb, const float *data, const long *ids, con
|
|||
|
||||
//>> Debug code
|
||||
/////
|
||||
//int count = 0;
|
||||
//for (int i = 0; i < ntotal; ++i) {
|
||||
// int count = 0;
|
||||
// for (int i = 0; i < ntotal; ++i) {
|
||||
// count += nsg[i].size();
|
||||
//}
|
||||
/////
|
||||
|
@ -80,17 +81,19 @@ void NsgIndex::Build_with_ids(size_t nb, const float *data, const long *ids, con
|
|||
//>> Debug code
|
||||
///
|
||||
int total_degree = 0;
|
||||
for (int i = 0; i < ntotal; ++i) {
|
||||
for (size_t i = 0; i < ntotal; ++i) {
|
||||
total_degree += nsg[i].size();
|
||||
}
|
||||
std::cout << "graph physical size: " << total_degree * sizeof(node_t) / 1024 / 1024;
|
||||
std::cout << "average degree: " << total_degree / ntotal;
|
||||
|
||||
KNOWHERE_LOG_DEBUG << "Graph physical size: " << total_degree * sizeof(node_t) / 1024 / 1024 << "m";
|
||||
KNOWHERE_LOG_DEBUG << "Average degree: " << total_degree / ntotal;
|
||||
/////
|
||||
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
void NsgIndex::InitNavigationPoint() {
|
||||
void
|
||||
NsgIndex::InitNavigationPoint() {
|
||||
// calculate the center of vectors
|
||||
auto center = new float[dimension];
|
||||
memset(center, 0, sizeof(float) * dimension);
|
||||
|
@ -106,11 +109,12 @@ void NsgIndex::InitNavigationPoint() {
|
|||
|
||||
// select navigation point
|
||||
std::vector<Neighbor> resset, fullset;
|
||||
navigation_point = rand() % ntotal; // random initialize navigating point
|
||||
unsigned int seed = 100;
|
||||
navigation_point = rand_r(&seed) % ntotal; // random initialize navigating point
|
||||
|
||||
//>> Debug code
|
||||
/////
|
||||
//navigation_point = drand48();
|
||||
// navigation_point = drand48();
|
||||
/////
|
||||
|
||||
GetNeighbors(center, resset, knng);
|
||||
|
@ -118,22 +122,21 @@ void NsgIndex::InitNavigationPoint() {
|
|||
|
||||
//>> Debug code
|
||||
/////
|
||||
//std::cout << "ep: " << navigation_point << std::endl;
|
||||
// std::cout << "ep: " << navigation_point << std::endl;
|
||||
/////
|
||||
|
||||
//>> Debug code
|
||||
/////
|
||||
//float r1 = calculate(center, ori_data_ + navigation_point * dimension, dimension);
|
||||
//assert(r1 == resset[0].distance);
|
||||
// float r1 = calculate(center, ori_data_ + navigation_point * dimension, dimension);
|
||||
// assert(r1 == resset[0].distance);
|
||||
/////
|
||||
}
|
||||
|
||||
// Specify Link
|
||||
void NsgIndex::GetNeighbors(const float *query,
|
||||
std::vector<Neighbor> &resset,
|
||||
std::vector<Neighbor> &fullset,
|
||||
boost::dynamic_bitset<> &has_calculated_dist) {
|
||||
auto &graph = knng;
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset,
|
||||
boost::dynamic_bitset<>& has_calculated_dist) {
|
||||
auto& graph = knng;
|
||||
size_t buffer_size = search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
|
@ -154,9 +157,12 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
has_calculated_dist[init_ids[i]] = true;
|
||||
++count;
|
||||
}
|
||||
|
||||
unsigned int seed = 100;
|
||||
while (count < buffer_size) {
|
||||
node_t id = rand() % ntotal;
|
||||
if (has_calculated_dist[id]) continue; // duplicate id
|
||||
node_t id = rand_r(&seed) % ntotal;
|
||||
if (has_calculated_dist[id])
|
||||
continue; // duplicate id
|
||||
init_ids.push_back(id);
|
||||
++count;
|
||||
has_calculated_dist[id] = true;
|
||||
|
@ -170,7 +176,7 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
for (size_t i = 0; i < init_ids.size(); ++i) {
|
||||
node_t id = init_ids[i];
|
||||
|
||||
if (id >= ntotal) {
|
||||
if (id >= static_cast<node_t>(ntotal)) {
|
||||
KNOWHERE_THROW_MSG("Build Index Error, id > ntotal");
|
||||
continue;
|
||||
}
|
||||
|
@ -182,9 +188,9 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
fullset.push_back(resset[i]);
|
||||
///////////////////////////////////////
|
||||
}
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
|
||||
//search nearest neighbor
|
||||
// search nearest neighbor
|
||||
size_t cursor = 0;
|
||||
while (cursor < buffer_size) {
|
||||
size_t nearest_updated_pos = buffer_size;
|
||||
|
@ -193,36 +199,42 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
resset[cursor].has_explored = true;
|
||||
|
||||
node_t start_pos = resset[cursor].id;
|
||||
auto &wait_for_search_node_vec = graph[start_pos];
|
||||
auto& wait_for_search_node_vec = graph[start_pos];
|
||||
for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
|
||||
node_t id = wait_for_search_node_vec[i];
|
||||
if (has_calculated_dist[id]) continue;
|
||||
if (has_calculated_dist[id])
|
||||
continue;
|
||||
has_calculated_dist[id] = true;
|
||||
|
||||
float
|
||||
dist = calculate(query, ori_data_ + dimension * id, dimension);
|
||||
float dist = calculate(query, ori_data_ + dimension * id, dimension);
|
||||
Neighbor nn(id, dist, false);
|
||||
fullset.push_back(nn);
|
||||
|
||||
if (dist >= resset[buffer_size - 1].distance) continue;
|
||||
if (dist >= resset[buffer_size - 1].distance)
|
||||
continue;
|
||||
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos) nearest_updated_pos = pos;
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos)
|
||||
nearest_updated_pos = pos;
|
||||
|
||||
//assert(buffer_size + 1 >= resset.size());
|
||||
if (buffer_size + 1 < resset.size()) ++buffer_size;
|
||||
// assert(buffer_size + 1 >= resset.size());
|
||||
if (buffer_size + 1 < resset.size())
|
||||
++buffer_size;
|
||||
}
|
||||
}
|
||||
if (cursor >= nearest_updated_pos) {
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else ++cursor;
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else {
|
||||
++cursor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FindUnconnectedNode
|
||||
void NsgIndex::GetNeighbors(const float *query, std::vector<Neighbor> &resset, std::vector<Neighbor> &fullset) {
|
||||
auto &graph = nsg;
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset) {
|
||||
auto& graph = nsg;
|
||||
size_t buffer_size = search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
|
@ -230,7 +242,7 @@ void NsgIndex::GetNeighbors(const float *query, std::vector<Neighbor> &resset, s
|
|||
}
|
||||
|
||||
std::vector<node_t> init_ids;
|
||||
boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; // TODO: ?
|
||||
boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; // TODO: ?
|
||||
|
||||
{
|
||||
/*
|
||||
|
@ -244,9 +256,11 @@ void NsgIndex::GetNeighbors(const float *query, std::vector<Neighbor> &resset, s
|
|||
has_calculated_dist[init_ids[i]] = true;
|
||||
++count;
|
||||
}
|
||||
unsigned int seed = 100;
|
||||
while (count < buffer_size) {
|
||||
node_t id = rand() % ntotal;
|
||||
if (has_calculated_dist[id]) continue; // duplicate id
|
||||
node_t id = rand_r(&seed) % ntotal;
|
||||
if (has_calculated_dist[id])
|
||||
continue; // duplicate id
|
||||
init_ids.push_back(id);
|
||||
++count;
|
||||
has_calculated_dist[id] = true;
|
||||
|
@ -260,7 +274,7 @@ void NsgIndex::GetNeighbors(const float *query, std::vector<Neighbor> &resset, s
|
|||
for (size_t i = 0; i < init_ids.size(); ++i) {
|
||||
node_t id = init_ids[i];
|
||||
|
||||
if (id >= ntotal) {
|
||||
if (id >= static_cast<node_t>(ntotal)) {
|
||||
KNOWHERE_THROW_MSG("Build Index Error, id > ntotal");
|
||||
continue;
|
||||
}
|
||||
|
@ -268,7 +282,7 @@ void NsgIndex::GetNeighbors(const float *query, std::vector<Neighbor> &resset, s
|
|||
float dist = calculate(ori_data_ + id * dimension, query, dimension);
|
||||
resset[i] = Neighbor(id, dist, false);
|
||||
}
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
|
||||
// search nearest neighbor
|
||||
size_t cursor = 0;
|
||||
|
@ -279,38 +293,41 @@ void NsgIndex::GetNeighbors(const float *query, std::vector<Neighbor> &resset, s
|
|||
resset[cursor].has_explored = true;
|
||||
|
||||
node_t start_pos = resset[cursor].id;
|
||||
auto &wait_for_search_node_vec = graph[start_pos];
|
||||
auto& wait_for_search_node_vec = graph[start_pos];
|
||||
for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
|
||||
node_t id = wait_for_search_node_vec[i];
|
||||
if (has_calculated_dist[id]) continue;
|
||||
if (has_calculated_dist[id])
|
||||
continue;
|
||||
has_calculated_dist[id] = true;
|
||||
|
||||
float
|
||||
dist = calculate(ori_data_ + dimension * id, query, dimension);
|
||||
float dist = calculate(ori_data_ + dimension * id, query, dimension);
|
||||
Neighbor nn(id, dist, false);
|
||||
fullset.push_back(nn);
|
||||
|
||||
if (dist >= resset[buffer_size - 1].distance) continue;
|
||||
if (dist >= resset[buffer_size - 1].distance)
|
||||
continue;
|
||||
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos) nearest_updated_pos = pos;
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos)
|
||||
nearest_updated_pos = pos;
|
||||
|
||||
//assert(buffer_size + 1 >= resset.size());
|
||||
if (buffer_size + 1 < resset.size()) ++buffer_size; // trick
|
||||
// assert(buffer_size + 1 >= resset.size());
|
||||
if (buffer_size + 1 < resset.size())
|
||||
++buffer_size; // trick
|
||||
}
|
||||
}
|
||||
if (cursor >= nearest_updated_pos) {
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else ++cursor;
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else {
|
||||
++cursor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NsgIndex::GetNeighbors(const float *query,
|
||||
std::vector<Neighbor> &resset,
|
||||
Graph &graph,
|
||||
SearchParams *params) {
|
||||
size_t &buffer_size = params ? params->search_length : search_length;
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph& graph, SearchParams* params) {
|
||||
size_t& buffer_size = params ? params->search_length : search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
// TODO: throw exception here.
|
||||
|
@ -331,9 +348,11 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
has_calculated_dist[init_ids[i]] = true;
|
||||
++count;
|
||||
}
|
||||
unsigned int seed = 100;
|
||||
while (count < buffer_size) {
|
||||
node_t id = rand() % ntotal;
|
||||
if (has_calculated_dist[id]) continue; // duplicate id
|
||||
node_t id = rand_r(&seed) % ntotal;
|
||||
if (has_calculated_dist[id])
|
||||
continue; // duplicate id
|
||||
init_ids.push_back(id);
|
||||
++count;
|
||||
has_calculated_dist[id] = true;
|
||||
|
@ -347,8 +366,8 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
for (size_t i = 0; i < init_ids.size(); ++i) {
|
||||
node_t id = init_ids[i];
|
||||
|
||||
//assert(id < ntotal);
|
||||
if (id >= ntotal) {
|
||||
// assert(id < ntotal);
|
||||
if (id >= static_cast<node_t>(ntotal)) {
|
||||
KNOWHERE_THROW_MSG("Build Index Error, id > ntotal");
|
||||
continue;
|
||||
}
|
||||
|
@ -356,11 +375,11 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
float dist = calculate(ori_data_ + id * dimension, query, dimension);
|
||||
resset[i] = Neighbor(id, dist, false);
|
||||
}
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
|
||||
//>> Debug code
|
||||
/////
|
||||
//for (int j = 0; j < buffer_size; ++j) {
|
||||
// for (int j = 0; j < buffer_size; ++j) {
|
||||
// std::cout << "resset_id: " << resset[j].id << ", resset_dist: " << resset[j].distance << std::endl;
|
||||
//}
|
||||
/////
|
||||
|
@ -374,41 +393,47 @@ void NsgIndex::GetNeighbors(const float *query,
|
|||
resset[cursor].has_explored = true;
|
||||
|
||||
node_t start_pos = resset[cursor].id;
|
||||
auto &wait_for_search_node_vec = graph[start_pos];
|
||||
auto& wait_for_search_node_vec = graph[start_pos];
|
||||
for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
|
||||
node_t id = wait_for_search_node_vec[i];
|
||||
if (has_calculated_dist[id]) continue;
|
||||
if (has_calculated_dist[id])
|
||||
continue;
|
||||
has_calculated_dist[id] = true;
|
||||
|
||||
float
|
||||
dist = calculate(query, ori_data_ + dimension * id, dimension);
|
||||
float dist = calculate(query, ori_data_ + dimension * id, dimension);
|
||||
|
||||
if (dist >= resset[buffer_size - 1].distance) continue;
|
||||
if (dist >= resset[buffer_size - 1].distance)
|
||||
continue;
|
||||
///////////// difference from other GetNeighbors ///////////////
|
||||
Neighbor nn(id, dist, false);
|
||||
///////////////////////////////////////
|
||||
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos) nearest_updated_pos = pos;
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos)
|
||||
nearest_updated_pos = pos;
|
||||
|
||||
//>> Debug code
|
||||
/////
|
||||
//std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " << nearest_updated_pos << std::endl;
|
||||
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
|
||||
// nearest_updated_pos << std::endl;
|
||||
/////
|
||||
|
||||
|
||||
// trick: avoid search query search_length < init_ids.size() ...
|
||||
if (buffer_size + 1 < resset.size()) ++buffer_size;
|
||||
if (buffer_size + 1 < resset.size())
|
||||
++buffer_size;
|
||||
}
|
||||
}
|
||||
if (cursor >= nearest_updated_pos) {
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else ++cursor;
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else {
|
||||
++cursor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NsgIndex::Link() {
|
||||
void
|
||||
NsgIndex::Link() {
|
||||
auto cut_graph_dist = new float[ntotal * out_degree];
|
||||
nsg.resize(ntotal);
|
||||
|
||||
|
@ -416,7 +441,7 @@ void NsgIndex::Link() {
|
|||
{
|
||||
std::vector<Neighbor> fullset;
|
||||
std::vector<Neighbor> temp;
|
||||
boost::dynamic_bitset<> flags{ntotal, 0}; // TODO: ?
|
||||
boost::dynamic_bitset<> flags{ntotal, 0}; // TODO: ?
|
||||
#pragma omp for schedule(dynamic, 100)
|
||||
for (size_t n = 0; n < ntotal; ++n) {
|
||||
fullset.clear();
|
||||
|
@ -425,8 +450,8 @@ void NsgIndex::Link() {
|
|||
|
||||
//>> Debug code
|
||||
/////
|
||||
//float r1 = calculate(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension);
|
||||
//assert(r1 == temp[0].distance);
|
||||
// float r1 = calculate(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension);
|
||||
// assert(r1 == temp[0].distance);
|
||||
/////
|
||||
SyncPrune(n, fullset, flags, cut_graph_dist);
|
||||
}
|
||||
|
@ -434,7 +459,7 @@ void NsgIndex::Link() {
|
|||
|
||||
//>> Debug code
|
||||
/////
|
||||
//auto bak_nsg = nsg;
|
||||
// auto bak_nsg = nsg;
|
||||
/////
|
||||
|
||||
knng.clear();
|
||||
|
@ -450,8 +475,8 @@ void NsgIndex::Link() {
|
|||
|
||||
//>> Debug code
|
||||
/////
|
||||
//int count = 0;
|
||||
//for (int i = 0; i < ntotal; ++i) {
|
||||
// int count = 0;
|
||||
// for (int i = 0; i < ntotal; ++i) {
|
||||
// if (bak_nsg[i].size() != nsg[i].size()) {
|
||||
// //count += nsg[i].size() - bak_nsg[i].size();
|
||||
// count += nsg[i].size();
|
||||
|
@ -459,21 +484,20 @@ void NsgIndex::Link() {
|
|||
//}
|
||||
/////
|
||||
|
||||
for (int i = 0; i < ntotal; ++i) {
|
||||
for (size_t i = 0; i < ntotal; ++i) {
|
||||
nsg[i].shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
void NsgIndex::SyncPrune(size_t n,
|
||||
std::vector<Neighbor> &pool,
|
||||
boost::dynamic_bitset<> &has_calculated,
|
||||
float *cut_graph_dist) {
|
||||
void
|
||||
NsgIndex::SyncPrune(size_t n, std::vector<Neighbor>& pool, boost::dynamic_bitset<>& has_calculated,
|
||||
float* cut_graph_dist) {
|
||||
// avoid lose nearest neighbor in knng
|
||||
for (size_t i = 0; i < knng[n].size(); ++i) {
|
||||
auto id = knng[n][i];
|
||||
if (has_calculated[id]) continue;
|
||||
float dist = calculate(ori_data_ + dimension * n,
|
||||
ori_data_ + dimension * id, dimension);
|
||||
if (has_calculated[id])
|
||||
continue;
|
||||
float dist = calculate(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension);
|
||||
pool.emplace_back(Neighbor(id, dist, true));
|
||||
}
|
||||
|
||||
|
@ -481,14 +505,16 @@ void NsgIndex::SyncPrune(size_t n,
|
|||
unsigned cursor = 0;
|
||||
std::sort(pool.begin(), pool.end());
|
||||
std::vector<Neighbor> result;
|
||||
if (pool[cursor].id == n) cursor++;
|
||||
result.push_back(pool[cursor]); // init result with nearest neighbor
|
||||
if (pool[cursor].id == static_cast<node_t>(n)) {
|
||||
cursor++;
|
||||
}
|
||||
result.push_back(pool[cursor]); // init result with nearest neighbor
|
||||
|
||||
SelectEdge(cursor, pool, result, true);
|
||||
|
||||
// filling the cut_graph
|
||||
auto &des_id_pool = nsg[n];
|
||||
float *des_dist_pool = cut_graph_dist + n * out_degree;
|
||||
auto& des_id_pool = nsg[n];
|
||||
float* des_dist_pool = cut_graph_dist + n * out_degree;
|
||||
for (size_t i = 0; i < result.size(); ++i) {
|
||||
des_id_pool.push_back(result[i].id);
|
||||
des_dist_pool[i] = result[i].distance;
|
||||
|
@ -500,24 +526,27 @@ void NsgIndex::SyncPrune(size_t n,
|
|||
}
|
||||
|
||||
//>> Optimize: remove read-lock
|
||||
void NsgIndex::InterInsert(unsigned n, std::vector<std::mutex> &mutex_vec, float *cut_graph_dist) {
|
||||
auto &current = n;
|
||||
void
|
||||
NsgIndex::InterInsert(unsigned n, std::vector<std::mutex>& mutex_vec, float* cut_graph_dist) {
|
||||
auto& current = n;
|
||||
|
||||
auto &neighbor_id_pool = nsg[current];
|
||||
float *neighbor_dist_pool = cut_graph_dist + current * out_degree;
|
||||
auto& neighbor_id_pool = nsg[current];
|
||||
float* neighbor_dist_pool = cut_graph_dist + current * out_degree;
|
||||
for (size_t i = 0; i < out_degree; ++i) {
|
||||
if (neighbor_dist_pool[i] == -1) break;
|
||||
if (neighbor_dist_pool[i] == -1)
|
||||
break;
|
||||
|
||||
size_t current_neighbor = neighbor_id_pool[i]; // center's neighbor id
|
||||
auto &nsn_id_pool = nsg[current_neighbor]; // nsn => neighbor's neighbor
|
||||
float *nsn_dist_pool = cut_graph_dist + current_neighbor * out_degree;
|
||||
size_t current_neighbor = neighbor_id_pool[i]; // center's neighbor id
|
||||
auto& nsn_id_pool = nsg[current_neighbor]; // nsn => neighbor's neighbor
|
||||
float* nsn_dist_pool = cut_graph_dist + current_neighbor * out_degree;
|
||||
|
||||
std::vector<Neighbor> wait_for_link_pool; // maintain candidate neighbor of the current neighbor.
|
||||
std::vector<Neighbor> wait_for_link_pool; // maintain candidate neighbor of the current neighbor.
|
||||
int duplicate = false;
|
||||
{
|
||||
LockGuard lk(mutex_vec[current_neighbor]);
|
||||
for (int j = 0; j < out_degree; ++j) {
|
||||
if (nsn_dist_pool[j] == -1) break;
|
||||
for (size_t j = 0; j < out_degree; ++j) {
|
||||
if (nsn_dist_pool[j] == -1)
|
||||
break;
|
||||
|
||||
// 保证至少有一条边能连回来
|
||||
if (n == nsn_id_pool[j]) {
|
||||
|
@ -529,7 +558,8 @@ void NsgIndex::InterInsert(unsigned n, std::vector<std::mutex> &mutex_vec, float
|
|||
wait_for_link_pool.push_back(nsn);
|
||||
}
|
||||
}
|
||||
if (duplicate) continue;
|
||||
if (duplicate)
|
||||
continue;
|
||||
|
||||
// original: (neighbor) <------- (current)
|
||||
// after: (neighbor) -------> (current)
|
||||
|
@ -549,31 +579,29 @@ void NsgIndex::InterInsert(unsigned n, std::vector<std::mutex> &mutex_vec, float
|
|||
|
||||
{
|
||||
LockGuard lk(mutex_vec[current_neighbor]);
|
||||
for (int j = 0; j < result.size(); ++j) {
|
||||
for (size_t j = 0; j < result.size(); ++j) {
|
||||
nsn_id_pool[j] = result[j].id;
|
||||
nsn_dist_pool[j] = result[j].distance;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LockGuard lk(mutex_vec[current_neighbor]);
|
||||
for (int j = 0; j < out_degree; ++j) {
|
||||
for (size_t j = 0; j < out_degree; ++j) {
|
||||
if (nsn_dist_pool[j] == -1) {
|
||||
nsn_id_pool.push_back(current_as_neighbor.id);
|
||||
nsn_dist_pool[j] = current_as_neighbor.distance;
|
||||
if (j + 1 < out_degree) nsn_dist_pool[j + 1] = -1;
|
||||
if (j + 1 < out_degree)
|
||||
nsn_dist_pool[j + 1] = -1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
void NsgIndex::SelectEdge(unsigned &cursor,
|
||||
std::vector<Neighbor> &sort_pool,
|
||||
std::vector<Neighbor> &result,
|
||||
bool limit) {
|
||||
auto &pool = sort_pool;
|
||||
void
|
||||
NsgIndex::SelectEdge(unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result, bool limit) {
|
||||
auto& pool = sort_pool;
|
||||
|
||||
/*
|
||||
* edge selection
|
||||
|
@ -583,55 +611,59 @@ void NsgIndex::SelectEdge(unsigned &cursor,
|
|||
*/
|
||||
size_t search_deepth = limit ? candidate_pool_size : pool.size();
|
||||
while (result.size() < out_degree && cursor < search_deepth && (++cursor) < pool.size()) {
|
||||
auto &p = pool[cursor];
|
||||
auto& p = pool[cursor];
|
||||
bool should_link = true;
|
||||
for (size_t t = 0; t < result.size(); ++t) {
|
||||
float dist = calculate(ori_data_ + dimension * result[t].id,
|
||||
ori_data_ + dimension * p.id, dimension);
|
||||
float dist = calculate(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension);
|
||||
|
||||
if (dist < p.distance) {
|
||||
should_link = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (should_link) result.push_back(p);
|
||||
if (should_link)
|
||||
result.push_back(p);
|
||||
}
|
||||
}
|
||||
|
||||
void NsgIndex::CheckConnectivity() {
|
||||
void
|
||||
NsgIndex::CheckConnectivity() {
|
||||
auto root = navigation_point;
|
||||
boost::dynamic_bitset<> has_linked{ntotal, 0};
|
||||
int64_t linked_count = 0;
|
||||
|
||||
while (linked_count < ntotal) {
|
||||
while (linked_count < static_cast<int64_t>(ntotal)) {
|
||||
DFS(root, has_linked, linked_count);
|
||||
if (linked_count >= ntotal) break;
|
||||
if (linked_count >= static_cast<int64_t>(ntotal)) {
|
||||
break;
|
||||
}
|
||||
FindUnconnectedNode(has_linked, root);
|
||||
}
|
||||
}
|
||||
|
||||
void NsgIndex::DFS(size_t root, boost::dynamic_bitset<> &has_linked, int64_t &linked_count) {
|
||||
void
|
||||
NsgIndex::DFS(size_t root, boost::dynamic_bitset<>& has_linked, int64_t& linked_count) {
|
||||
size_t start = root;
|
||||
std::stack<size_t> s;
|
||||
s.push(root);
|
||||
if (!has_linked[root]) {
|
||||
linked_count++; // not link
|
||||
has_linked[root] = true; // link start...
|
||||
linked_count++; // not link
|
||||
has_linked[root] = true; // link start...
|
||||
}
|
||||
|
||||
while (!s.empty()) {
|
||||
size_t next = ntotal + 1;
|
||||
|
||||
for (unsigned i = 0; i < nsg[start].size(); i++) {
|
||||
if (has_linked[nsg[start][i]] == false) // if not link
|
||||
{
|
||||
if (has_linked[nsg[start][i]] == false) { // if not link
|
||||
next = nsg[start][i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (next == (ntotal + 1)) {
|
||||
s.pop();
|
||||
if (s.empty()) break;
|
||||
if (s.empty())
|
||||
break;
|
||||
start = s.top();
|
||||
continue;
|
||||
}
|
||||
|
@ -642,17 +674,19 @@ void NsgIndex::DFS(size_t root, boost::dynamic_bitset<> &has_linked, int64_t &li
|
|||
}
|
||||
}
|
||||
|
||||
void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t &root) {
|
||||
void
|
||||
NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root) {
|
||||
// find any of unlinked-node
|
||||
size_t id = ntotal;
|
||||
for (size_t i = 0; i < ntotal; i++) { // find not link
|
||||
for (size_t i = 0; i < ntotal; i++) { // find not link
|
||||
if (has_linked[i] == false) {
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (id == ntotal) return; // No Unlinked Node
|
||||
if (id == ntotal)
|
||||
return; // No Unlinked Node
|
||||
|
||||
// search unlinked-node's neighbor
|
||||
std::vector<Neighbor> tmp, pool;
|
||||
|
@ -660,7 +694,7 @@ void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t
|
|||
std::sort(pool.begin(), pool.end());
|
||||
|
||||
size_t found = 0;
|
||||
for (size_t i = 0; i < pool.size(); i++) { // find nearest neighbor and add unlinked-node as its neighbor
|
||||
for (size_t i = 0; i < pool.size(); i++) { // find nearest neighbor and add unlinked-node as its neighbor
|
||||
if (has_linked[pool[i].id]) {
|
||||
root = pool[i].id;
|
||||
found = 1;
|
||||
|
@ -668,8 +702,9 @@ void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t
|
|||
}
|
||||
}
|
||||
if (found == 0) {
|
||||
while (true) { // random a linked-node and add unlinked-node as its neighbor
|
||||
size_t rid = rand() % ntotal;
|
||||
unsigned int seed = 100;
|
||||
while (true) { // random a linked-node and add unlinked-node as its neighbor
|
||||
size_t rid = rand_r(&seed) % ntotal;
|
||||
if (has_linked[rid]) {
|
||||
root = rid;
|
||||
break;
|
||||
|
@ -679,23 +714,18 @@ void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t
|
|||
nsg[root].push_back(id);
|
||||
}
|
||||
|
||||
|
||||
void NsgIndex::Search(const float *query,
|
||||
const unsigned &nq,
|
||||
const unsigned &dim,
|
||||
const unsigned &k,
|
||||
float *dist,
|
||||
long *ids,
|
||||
SearchParams &params) {
|
||||
void
|
||||
NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
|
||||
int64_t* ids, SearchParams& params) {
|
||||
std::vector<std::vector<Neighbor>> resset(nq);
|
||||
|
||||
TimeRecorder rc("search");
|
||||
if (nq == 1) {
|
||||
GetNeighbors(query, resset[0], nsg, &params);
|
||||
} else{
|
||||
//#pragma omp parallel for schedule(dynamic, 50)
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
} else {
|
||||
//#pragma omp parallel for schedule(dynamic, 50)
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < nq; ++i) {
|
||||
// TODO(linxj): when to use openmp
|
||||
auto single_query = query + i * dim;
|
||||
GetNeighbors(single_query, resset[i], nsg, &params);
|
||||
|
@ -703,9 +733,9 @@ void NsgIndex::Search(const float *query,
|
|||
}
|
||||
rc.ElapseFromBegin("cost");
|
||||
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
for (int j = 0; j < k; ++j) {
|
||||
//ids[i * k + j] = resset[i][j].id;
|
||||
for (unsigned int i = 0; i < nq; ++i) {
|
||||
for (unsigned int j = 0; j < k; ++j) {
|
||||
// ids[i * k + j] = resset[i][j].id;
|
||||
|
||||
// Fix(linxj): bug, reset[i][j] out of range
|
||||
ids[i * k + j] = ids_[resset[i][j].id];
|
||||
|
@ -714,27 +744,28 @@ void NsgIndex::Search(const float *query,
|
|||
}
|
||||
|
||||
//>> Debug: test single insert
|
||||
//int x_0 = resset[0].size();
|
||||
//for (int l = 0; l < resset[0].size(); ++l) {
|
||||
// int x_0 = resset[0].size();
|
||||
// for (int l = 0; l < resset[0].size(); ++l) {
|
||||
// resset[0].pop_back();
|
||||
//}
|
||||
//resset.clear();
|
||||
// resset.clear();
|
||||
|
||||
//ProfilerStart("xx.prof");
|
||||
//std::vector<Neighbor> resset;
|
||||
//GetNeighbors(query, resset, nsg, &params);
|
||||
//for (int i = 0; i < k; ++i) {
|
||||
// ProfilerStart("xx.prof");
|
||||
// std::vector<Neighbor> resset;
|
||||
// GetNeighbors(query, resset, nsg, &params);
|
||||
// for (int i = 0; i < k; ++i) {
|
||||
// ids[i] = resset[i].id;
|
||||
//dist[i] = resset[i].distance;
|
||||
// dist[i] = resset[i].distance;
|
||||
//}
|
||||
//ProfilerStop();
|
||||
// ProfilerStop();
|
||||
}
|
||||
|
||||
void NsgIndex::SetKnnGraph(Graph &g) {
|
||||
void
|
||||
NsgIndex::SetKnnGraph(Graph& g) {
|
||||
knng = std::move(g);
|
||||
}
|
||||
|
||||
//void NsgIndex::GetKnnGraphFromFile() {
|
||||
// void NsgIndex::GetKnnGraphFromFile() {
|
||||
// //std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/sift.1M.50NN.graph";
|
||||
// std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/sift.50NN.graph";
|
||||
//
|
||||
|
@ -759,6 +790,6 @@ void NsgIndex::SetKnnGraph(Graph &g) {
|
|||
// in.close();
|
||||
//}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,22 +15,19 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
#include "Neighbor.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
||||
|
||||
using node_t = int64_t;
|
||||
|
||||
enum class MetricType {
|
||||
|
@ -53,15 +50,15 @@ using Graph = std::vector<std::vector<node_t>>;
|
|||
class NsgIndex {
|
||||
public:
|
||||
size_t dimension;
|
||||
size_t ntotal; // totabl nb of indexed vectors
|
||||
MetricType metric_type; // L2 | IP
|
||||
size_t ntotal; // totabl nb of indexed vectors
|
||||
MetricType metric_type; // L2 | IP
|
||||
|
||||
float *ori_data_;
|
||||
long *ids_; // TODO: support different type
|
||||
Graph nsg; // final graph
|
||||
Graph knng; // reset after build
|
||||
float* ori_data_;
|
||||
int64_t* ids_; // TODO: support different type
|
||||
Graph nsg; // final graph
|
||||
Graph knng; // reset after build
|
||||
|
||||
node_t navigation_point; // offset of node in origin data
|
||||
node_t navigation_point; // offset of node in origin data
|
||||
|
||||
bool is_trained = false;
|
||||
|
||||
|
@ -69,91 +66,81 @@ class NsgIndex {
|
|||
* build and search parameter
|
||||
*/
|
||||
size_t search_length;
|
||||
size_t candidate_pool_size; // search deepth in fullset
|
||||
size_t candidate_pool_size; // search deepth in fullset
|
||||
size_t out_degree;
|
||||
|
||||
public:
|
||||
explicit NsgIndex(const size_t &dimension,
|
||||
const size_t &n,
|
||||
MetricType metric = MetricType::METRIC_L2);
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, MetricType metric = MetricType::METRIC_L2);
|
||||
|
||||
NsgIndex() = default;
|
||||
|
||||
virtual ~NsgIndex();
|
||||
|
||||
void SetKnnGraph(Graph &knng);
|
||||
void
|
||||
SetKnnGraph(Graph& knng);
|
||||
|
||||
virtual void Build_with_ids(size_t nb,
|
||||
const float *data,
|
||||
const long *ids,
|
||||
const BuildParams &parameters);
|
||||
virtual void
|
||||
Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters);
|
||||
|
||||
void Search(const float *query,
|
||||
const unsigned &nq,
|
||||
const unsigned &dim,
|
||||
const unsigned &k,
|
||||
float *dist,
|
||||
long *ids,
|
||||
SearchParams &params);
|
||||
void
|
||||
Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, int64_t* ids,
|
||||
SearchParams& params);
|
||||
|
||||
// Not support yet.
|
||||
//virtual void Add() = 0;
|
||||
//virtual void Add_with_ids() = 0;
|
||||
//virtual void Delete() = 0;
|
||||
//virtual void Delete_with_ids() = 0;
|
||||
//virtual void Rebuild(size_t nb,
|
||||
// virtual void Add() = 0;
|
||||
// virtual void Add_with_ids() = 0;
|
||||
// virtual void Delete() = 0;
|
||||
// virtual void Delete_with_ids() = 0;
|
||||
// virtual void Rebuild(size_t nb,
|
||||
// const float *data,
|
||||
// const long *ids,
|
||||
// const int64_t *ids,
|
||||
// const Parameters &parameters) = 0;
|
||||
//virtual void Build(size_t nb,
|
||||
// virtual void Build(size_t nb,
|
||||
// const float *data,
|
||||
// const BuildParam &parameters);
|
||||
|
||||
protected:
|
||||
virtual void InitNavigationPoint();
|
||||
virtual void
|
||||
InitNavigationPoint();
|
||||
|
||||
// link specify
|
||||
void GetNeighbors(const float *query,
|
||||
std::vector<Neighbor> &resset,
|
||||
std::vector<Neighbor> &fullset,
|
||||
boost::dynamic_bitset<> &has_calculated_dist);
|
||||
void
|
||||
GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset,
|
||||
boost::dynamic_bitset<>& has_calculated_dist);
|
||||
|
||||
// FindUnconnectedNode
|
||||
void GetNeighbors(const float *query,
|
||||
std::vector<Neighbor> &resset,
|
||||
std::vector<Neighbor> &fullset);
|
||||
void
|
||||
GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::vector<Neighbor>& fullset);
|
||||
|
||||
// search and navigation-point
|
||||
void GetNeighbors(const float *query,
|
||||
std::vector<Neighbor> &resset,
|
||||
Graph &graph,
|
||||
SearchParams *param = nullptr);
|
||||
void
|
||||
GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph& graph, SearchParams* param = nullptr);
|
||||
|
||||
void Link();
|
||||
void
|
||||
Link();
|
||||
|
||||
void SyncPrune(size_t q,
|
||||
std::vector<Neighbor> &pool,
|
||||
boost::dynamic_bitset<> &has_calculated,
|
||||
float *cut_graph_dist
|
||||
);
|
||||
void
|
||||
SyncPrune(size_t q, std::vector<Neighbor>& pool, boost::dynamic_bitset<>& has_calculated, float* cut_graph_dist);
|
||||
|
||||
void SelectEdge(unsigned &cursor,
|
||||
std::vector<Neighbor> &sort_pool,
|
||||
std::vector<Neighbor> &result,
|
||||
bool limit = false);
|
||||
void
|
||||
SelectEdge(unsigned& cursor, std::vector<Neighbor>& sort_pool, std::vector<Neighbor>& result, bool limit = false);
|
||||
|
||||
void InterInsert(unsigned n, std::vector<std::mutex> &mutex_vec, float *dist);
|
||||
void
|
||||
InterInsert(unsigned n, std::vector<std::mutex>& mutex_vec, float* dist);
|
||||
|
||||
void CheckConnectivity();
|
||||
void
|
||||
CheckConnectivity();
|
||||
|
||||
void DFS(size_t root, boost::dynamic_bitset<> &flags, int64_t &count);
|
||||
void
|
||||
DFS(size_t root, boost::dynamic_bitset<>& flags, int64_t& count);
|
||||
|
||||
void FindUnconnectedNode(boost::dynamic_bitset<> &flags, int64_t &root);
|
||||
void
|
||||
FindUnconnectedNode(boost::dynamic_bitset<>& flags, int64_t& root);
|
||||
|
||||
//private:
|
||||
// void GetKnnGraphFromFile();
|
||||
// private:
|
||||
// void GetKnnGraphFromFile();
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,21 +15,20 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
|
||||
#include "NSGHelper.h"
|
||||
|
||||
#include "knowhere/index/vector_index/nsg/NSGHelper.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
||||
// TODO: impl search && insert && return insert pos. why not just find and swap?
|
||||
int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) {
|
||||
int
|
||||
InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) {
|
||||
//>> Fix: Add assert
|
||||
for (int i = 0; i < K; ++i) {
|
||||
for (unsigned int i = 0; i < K; ++i) {
|
||||
assert(addr[i].id != nn.id);
|
||||
}
|
||||
|
||||
|
@ -37,7 +36,7 @@ int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) {
|
|||
int left = 0, right = K - 1;
|
||||
if (addr[left].distance > nn.distance) {
|
||||
//>> Fix: memmove overflow, dump when vector<Neighbor> deconstruct
|
||||
memmove((char *) &addr[left + 1], &addr[left], (K - 1) * sizeof(Neighbor));
|
||||
memmove((char*)&addr[left + 1], &addr[left], (K - 1) * sizeof(Neighbor));
|
||||
addr[left] = nn;
|
||||
return left;
|
||||
}
|
||||
|
@ -52,10 +51,10 @@ int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) {
|
|||
else
|
||||
left = mid;
|
||||
}
|
||||
//check equal ID
|
||||
// check equal ID
|
||||
|
||||
while (left > 0) {
|
||||
if (addr[left].distance < nn.distance) // pos is right
|
||||
if (addr[left].distance < nn.distance) // pos is right
|
||||
break;
|
||||
if (addr[left].id == nn.id)
|
||||
return K + 1;
|
||||
|
@ -65,24 +64,25 @@ int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) {
|
|||
return K + 1;
|
||||
|
||||
//>> Fix: memmove overflow, dump when vector<Neighbor> deconstruct
|
||||
memmove((char *) &addr[right + 1], &addr[right], (K - 1 - right) * sizeof(Neighbor));
|
||||
memmove((char*)&addr[right + 1], &addr[right], (K - 1 - right) * sizeof(Neighbor));
|
||||
addr[right] = nn;
|
||||
return right;
|
||||
}
|
||||
|
||||
// TODO: support L2 / IP
|
||||
float calculate(const float *a, const float *b, unsigned size) {
|
||||
float
|
||||
calculate(const float* a, const float* b, unsigned size) {
|
||||
float result = 0;
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __AVX__
|
||||
|
||||
#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm256_loadu_ps(addr1);\
|
||||
tmp2 = _mm256_loadu_ps(addr2);\
|
||||
tmp1 = _mm256_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm256_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm256_add_ps(dest, tmp1);
|
||||
tmp1 = _mm256_loadu_ps(addr1); \
|
||||
tmp2 = _mm256_loadu_ps(addr2); \
|
||||
tmp1 = _mm256_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm256_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm256_add_ps(dest, tmp1);
|
||||
|
||||
__m256 sum;
|
||||
__m256 l0, l1;
|
||||
|
@ -90,14 +90,16 @@ float calculate(const float *a, const float *b, unsigned size) {
|
|||
unsigned D = (size + 7) & ~7U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *r = b;
|
||||
const float *e_l = l + DD;
|
||||
const float *e_r = r + DD;
|
||||
float unpack[8] __attribute__ ((aligned (32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
const float* l = a;
|
||||
const float* r = b;
|
||||
const float* e_l = l + DD;
|
||||
const float* e_r = r + DD;
|
||||
float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
sum = _mm256_loadu_ps(unpack);
|
||||
if (DR) { AVX_L2SQR(e_l, e_r, sum, l0, r0); }
|
||||
if (DR) {
|
||||
AVX_L2SQR(e_l, e_r, sum, l0, r0);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
AVX_L2SQR(l, r, sum, l0, r0);
|
||||
|
@ -109,11 +111,11 @@ float calculate(const float *a, const float *b, unsigned size) {
|
|||
#else
|
||||
#ifdef __SSE2__
|
||||
#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm_load_ps(addr1);\
|
||||
tmp2 = _mm_load_ps(addr2);\
|
||||
tmp1 = _mm_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm_add_ps(dest, tmp1);
|
||||
tmp1 = _mm_load_ps(addr1); \
|
||||
tmp2 = _mm_load_ps(addr2); \
|
||||
tmp1 = _mm_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm_add_ps(dest, tmp1);
|
||||
|
||||
__m128 sum;
|
||||
__m128 l0, l1, l2, l3;
|
||||
|
@ -121,18 +123,22 @@ float calculate(const float *a, const float *b, unsigned size) {
|
|||
unsigned D = (size + 3) & ~3U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *r = b;
|
||||
const float *e_l = l + DD;
|
||||
const float *e_r = r + DD;
|
||||
float unpack[4] __attribute__ ((aligned (16))) = {0, 0, 0, 0};
|
||||
const float* l = a;
|
||||
const float* r = b;
|
||||
const float* e_l = l + DD;
|
||||
const float* e_r = r + DD;
|
||||
float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0};
|
||||
|
||||
sum = _mm_load_ps(unpack);
|
||||
switch (DR) {
|
||||
case 12:SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2);
|
||||
case 8:SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1);
|
||||
case 4:SSE_L2SQR(e_l, e_r, sum, l0, r0);
|
||||
default:break;
|
||||
case 12:
|
||||
SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2);
|
||||
case 8:
|
||||
SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1);
|
||||
case 4:
|
||||
SSE_L2SQR(e_l, e_r, sum, l0, r0);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
SSE_L2SQR(l, r, sum, l0, r0);
|
||||
|
@ -143,28 +149,28 @@ float calculate(const float *a, const float *b, unsigned size) {
|
|||
_mm_storeu_ps(unpack, sum);
|
||||
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
|
||||
|
||||
//nomal distance
|
||||
// nomal distance
|
||||
#else
|
||||
|
||||
float diff0, diff1, diff2, diff3;
|
||||
const float* last = a + size;
|
||||
const float* unroll_group = last - 3;
|
||||
const float* last = a + size;
|
||||
const float* unroll_group = last - 3;
|
||||
|
||||
/* Process 4 items with each loop for efficiency. */
|
||||
while (a < unroll_group) {
|
||||
diff0 = a[0] - b[0];
|
||||
diff1 = a[1] - b[1];
|
||||
diff2 = a[2] - b[2];
|
||||
diff3 = a[3] - b[3];
|
||||
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
|
||||
while (a < last) {
|
||||
diff0 = *a++ - *b++;
|
||||
result += diff0 * diff0;
|
||||
}
|
||||
/* Process 4 items with each loop for efficiency. */
|
||||
while (a < unroll_group) {
|
||||
diff0 = a[0] - b[0];
|
||||
diff1 = a[1] - b[1];
|
||||
diff2 = a[2] - b[2];
|
||||
diff3 = a[3] - b[3];
|
||||
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
|
||||
while (a < last) {
|
||||
diff0 = *a++ - *b++;
|
||||
result += diff0 * diff0;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
@ -172,7 +178,6 @@ float calculate(const float *a, const float *b, unsigned size) {
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <x86intrin.h>
|
||||
|
@ -26,14 +25,15 @@
|
|||
#include "NSG.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
||||
extern int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn);
|
||||
extern float calculate(const float *a, const float *b, unsigned size);
|
||||
extern int
|
||||
InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn);
|
||||
extern float
|
||||
calculate(const float* a, const float* b, unsigned size);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,31 +15,31 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "NSGIO.h"
|
||||
|
||||
#include "knowhere/index/vector_index/nsg/NSGIO.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
||||
void write_index(NsgIndex *index, MemoryIOWriter &writer) {
|
||||
void
|
||||
write_index(NsgIndex* index, MemoryIOWriter& writer) {
|
||||
writer(&index->ntotal, sizeof(index->ntotal), 1);
|
||||
writer(&index->dimension, sizeof(index->dimension), 1);
|
||||
writer(&index->navigation_point, sizeof(index->navigation_point), 1);
|
||||
writer(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1);
|
||||
writer(index->ids_, sizeof(long) * index->ntotal, 1);
|
||||
writer(index->ids_, sizeof(int64_t) * index->ntotal, 1);
|
||||
|
||||
for (unsigned i = 0; i < index->ntotal; ++i) {
|
||||
auto neighbor_num = (node_t) index->nsg[i].size();
|
||||
auto neighbor_num = (node_t)index->nsg[i].size();
|
||||
writer(&neighbor_num, sizeof(node_t), 1);
|
||||
writer(index->nsg[i].data(), neighbor_num * sizeof(node_t), 1);
|
||||
}
|
||||
}
|
||||
|
||||
NsgIndex *read_index(MemoryIOReader &reader) {
|
||||
NsgIndex*
|
||||
read_index(MemoryIOReader& reader) {
|
||||
size_t ntotal;
|
||||
size_t dimension;
|
||||
reader(&ntotal, sizeof(size_t), 1);
|
||||
|
@ -48,9 +48,9 @@ NsgIndex *read_index(MemoryIOReader &reader) {
|
|||
reader(&index->navigation_point, sizeof(index->navigation_point), 1);
|
||||
|
||||
index->ori_data_ = new float[index->ntotal * index->dimension];
|
||||
index->ids_ = new long[index->ntotal];
|
||||
index->ids_ = new int64_t[index->ntotal];
|
||||
reader(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1);
|
||||
reader(index->ids_, sizeof(long) * index->ntotal, 1);
|
||||
reader(index->ids_, sizeof(int64_t) * index->ntotal, 1);
|
||||
|
||||
index->nsg.reserve(index->ntotal);
|
||||
index->nsg.resize(index->ntotal);
|
||||
|
@ -66,6 +66,6 @@ NsgIndex *read_index(MemoryIOReader &reader) {
|
|||
return index;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,21 +15,21 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
#include "NSG.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
||||
extern void write_index(NsgIndex* index, MemoryIOWriter& writer);
|
||||
extern NsgIndex* read_index(MemoryIOReader& reader);
|
||||
extern void
|
||||
write_index(NsgIndex* index, MemoryIOWriter& writer);
|
||||
extern NsgIndex*
|
||||
read_index(MemoryIOReader& reader);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -15,12 +15,10 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
namespace algo {
|
||||
|
@ -29,21 +27,25 @@ using node_t = int64_t;
|
|||
|
||||
// TODO: search use simple neighbor
|
||||
struct Neighbor {
|
||||
node_t id; // offset of node in origin data
|
||||
node_t id; // offset of node in origin data
|
||||
float distance;
|
||||
bool has_explored;
|
||||
|
||||
Neighbor() = default;
|
||||
explicit Neighbor(node_t id, float distance, bool f) : id{id}, distance{distance}, has_explored(f) {}
|
||||
|
||||
explicit Neighbor(node_t id, float distance) : id{id}, distance{distance}, has_explored(false) {}
|
||||
explicit Neighbor(node_t id, float distance, bool f) : id{id}, distance{distance}, has_explored(f) {
|
||||
}
|
||||
|
||||
inline bool operator<(const Neighbor &other) const {
|
||||
explicit Neighbor(node_t id, float distance) : id{id}, distance{distance}, has_explored(false) {
|
||||
}
|
||||
|
||||
inline bool
|
||||
operator<(const Neighbor& other) const {
|
||||
return distance < other.distance;
|
||||
}
|
||||
};
|
||||
|
||||
//struct SimpleNeighbor {
|
||||
// struct SimpleNeighbor {
|
||||
// node_t id; // offset of node in origin data
|
||||
// float distance;
|
||||
//
|
||||
|
@ -57,7 +59,6 @@ struct Neighbor {
|
|||
|
||||
typedef std::lock_guard<std::mutex> LockGuard;
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace algo
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
|
||||
#include <random>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "SPTAG/AnnService/inc/Core/Common.h"
|
||||
#include "SPTAG/AnnService/inc/Core/VectorIndex.h"
|
||||
|
||||
|
||||
int
|
||||
main(int argc, char *argv[]) {
|
||||
using namespace SPTAG;
|
||||
const int d = 128;
|
||||
const int n = 100;
|
||||
|
||||
auto p_data = new float[n * d];
|
||||
|
||||
auto index = VectorIndex::CreateInstance(IndexAlgoType::KDT, VectorValueType::Float);
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
std::uniform_real_distribution<double> dist(1.0, 2.0);
|
||||
|
||||
for (auto i = 0; i < n; i++) {
|
||||
for (auto j = 0; j < d; j++) {
|
||||
p_data[i * d + j] = dist(mt) - 1;
|
||||
}
|
||||
}
|
||||
std::cout << "generate random n * d finished.";
|
||||
ByteArray data((uint8_t *) p_data, n * d * sizeof(float), true);
|
||||
|
||||
auto vectorset = std::make_shared<BasicVectorSet>(data, VectorValueType::Float, d, n);
|
||||
index->BuildIndex(vectorset, nullptr);
|
||||
|
||||
std::cout << index->GetFeatureDim();
|
||||
}
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
|
||||
#include "knowhere/index/vector_index/definitions.h"
|
||||
#include "knowhere/adapter/sptag.h"
|
||||
#include "knowhere/adapter/structure.h"
|
||||
|
||||
|
||||
using namespace zilliz::knowhere;
|
||||
|
||||
DatasetPtr
|
||||
generate_dataset(int64_t n, int64_t d, int64_t base) {
|
||||
auto elems = n * d;
|
||||
auto p_data = (float *) malloc(elems * sizeof(float));
|
||||
auto p_id = (int64_t *) malloc(elems * sizeof(int64_t));
|
||||
assert(p_data != nullptr && p_id != nullptr);
|
||||
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
for (auto j = 0; j < d; ++j) {
|
||||
p_data[i * d + j] = float(base + i);
|
||||
}
|
||||
p_id[i] = i;
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape{n, d};
|
||||
auto tensor = ConstructFloatTensorSmart((uint8_t *) p_data, elems * sizeof(float), shape);
|
||||
std::vector<TensorPtr> tensors{tensor};
|
||||
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
|
||||
auto tensor_schema = std::make_shared<Schema>(tensor_fields);
|
||||
|
||||
auto id_array = ConstructInt64ArraySmart((uint8_t *) p_id, n * sizeof(int64_t));
|
||||
std::vector<ArrayPtr> arrays{id_array};
|
||||
std::vector<FieldPtr> array_fields{ConstructInt64Field("id")};
|
||||
auto array_schema = std::make_shared<Schema>(tensor_fields);
|
||||
|
||||
auto dataset = std::make_shared<Dataset>(std::move(arrays), array_schema,
|
||||
std::move(tensors), tensor_schema);
|
||||
|
||||
return dataset;
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) {
|
||||
size_t size = sizeof(float) * n * d;
|
||||
auto v = (float *) malloc(size);
|
||||
// TODO: check malloc
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
for (auto j = 0; j < d; ++j) {
|
||||
v[i * d + j] = float(base + i);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TensorPtr> data;
|
||||
auto buffer = MakeMutableBufferSmart((uint8_t *) v, size);
|
||||
std::vector<int64_t> shape{n, d};
|
||||
auto float_type = std::make_shared<arrow::FloatType>();
|
||||
auto tensor = std::make_shared<Tensor>(float_type, buffer, shape);
|
||||
data.push_back(tensor);
|
||||
|
||||
Config meta;
|
||||
meta[META_ROWS] = int64_t (n);
|
||||
meta[META_DIM] = int64_t (d);
|
||||
meta[META_K] = int64_t (k);
|
||||
|
||||
auto type = std::make_shared<arrow::FloatType>();
|
||||
auto field = std::make_shared<Field>("data", type);
|
||||
std::vector<FieldPtr> fields{field};
|
||||
auto schema = std::make_shared<Schema>(fields);
|
||||
|
||||
return std::make_shared<Dataset>(data, schema);
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
main(int argc, char *argv[]) {
|
||||
auto kdt_index = std::make_shared<CPUKDTRNG>();
|
||||
|
||||
const auto d = 10;
|
||||
const auto k = 3;
|
||||
const auto nquery = 10;
|
||||
|
||||
// ID [0, 99]
|
||||
auto train = generate_dataset(100, d, 0);
|
||||
// ID [100]
|
||||
auto base = generate_dataset(1, d, 0);
|
||||
auto queries = generate_queries(nquery, d, k, 0);
|
||||
|
||||
// Build Preprocessor
|
||||
auto preprocessor = kdt_index->BuildPreprocessor(train, Config());
|
||||
|
||||
// Set Preprocessor
|
||||
kdt_index->set_preprocessor(preprocessor);
|
||||
|
||||
Config train_config;
|
||||
train_config["TPTNumber"] = "64";
|
||||
// Train
|
||||
kdt_index->Train(train, train_config);
|
||||
|
||||
// Add
|
||||
kdt_index->Add(base, Config());
|
||||
|
||||
auto binary = kdt_index->Serialize();
|
||||
auto new_index = std::make_shared<CPUKDTRNG>();
|
||||
new_index->Load(binary);
|
||||
// auto new_index = kdt_index;
|
||||
|
||||
Config search_config;
|
||||
search_config[META_K] = int64_t (k);
|
||||
|
||||
// Search
|
||||
auto result = new_index->Search(queries, search_config);
|
||||
|
||||
// Print Result
|
||||
{
|
||||
auto ids = result->array()[0];
|
||||
auto dists = result->array()[1];
|
||||
|
||||
std::stringstream ss_id;
|
||||
std::stringstream ss_dist;
|
||||
for (auto i = 0; i < nquery; i++) {
|
||||
for (auto j = 0; j < k; ++j) {
|
||||
ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
|
||||
ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
|
||||
}
|
||||
ss_id << std::endl;
|
||||
ss_dist << std::endl;
|
||||
}
|
||||
std::cout << "id\n" << ss_id.str() << std::endl;
|
||||
std::cout << "dist\n" << ss_dist.str() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
##############################
|
||||
include_directories(/usr/local/include/gperftools)
|
||||
link_directories(/usr/local/lib)
|
||||
|
||||
add_definitions(-std=c++11 -O3 -lboost -march=native -Wall -DINFO)
|
||||
|
||||
find_package(OpenMP)
|
||||
if (OPENMP_FOUND)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
else ()
|
||||
message(FATAL_ERROR "no OpenMP supprot")
|
||||
endif ()
|
||||
message(${OpenMP_CXX_FLAGS})
|
||||
|
||||
include_directories(${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/nsg)
|
||||
|
||||
aux_source_directory(${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/nsg nsg_src)
|
||||
|
||||
set(interface_src
|
||||
${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/ivf.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/gpu_ivf.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/cloner.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/idmap.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/nsg_index.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/adapter/structure.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/common/exception.cpp
|
||||
${CORE_SOURCE_DIR}/src/knowhere/common/timer.cpp
|
||||
../utils.cpp
|
||||
)
|
||||
|
||||
if(NOT TARGET test_nsg)
|
||||
add_executable(test_nsg
|
||||
test_nsg.cpp
|
||||
${interface_src}
|
||||
${nsg_src}
|
||||
${util_srcs}
|
||||
)
|
||||
endif()
|
||||
|
||||
target_link_libraries(test_nsg ${depend_libs} ${unittest_libs} ${basic_libs})
|
||||
##############################
|
||||
|
||||
install(TARGETS test_nsg DESTINATION unittest)
|
|
@ -1,111 +0,0 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "knowhere/common/exception.h"
|
||||
#include "knowhere/index/vector_index/gpu_ivf.h"
|
||||
#include "knowhere/index/vector_index/nsg_index.h"
|
||||
#include "knowhere/index/vector_index/nsg/nsg_io.h"
|
||||
|
||||
#include "../utils.h"
|
||||
|
||||
|
||||
using namespace zilliz::knowhere;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
using ::testing::Combine;
|
||||
|
||||
constexpr int64_t DEVICE_ID = 0;
|
||||
|
||||
class NSGInterfaceTest : public DataGen, public TestWithParam<::std::tuple<Config, Config>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
//Init_with_default();
|
||||
FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, 1024*1024*200, 1024*1024*600, 2);
|
||||
Generate(256, 10000, 1);
|
||||
index_ = std::make_shared<NSG>();
|
||||
std::tie(train_cfg, search_cfg) = GetParam();
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
FaissGpuResourceMgr::GetInstance().Free();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<NSG> index_;
|
||||
Config train_cfg;
|
||||
Config search_cfg;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(NSGparameters, NSGInterfaceTest,
|
||||
Values(std::make_tuple(
|
||||
// search length > out_degree
|
||||
Config::object{{"nlist", 128}, {"nprobe", 50}, {"knng", 100}, {"metric_type", "L2"},
|
||||
{"search_length", 60}, {"out_degree", 70}, {"candidate_pool_size", 500}},
|
||||
Config::object{{"k", 20}, {"search_length", 30}}))
|
||||
);
|
||||
|
||||
void AssertAnns(const DatasetPtr &result,
|
||||
const int &nq,
|
||||
const int &k) {
|
||||
auto ids = result->array()[0];
|
||||
for (auto i = 0; i < nq; i++) {
|
||||
EXPECT_EQ(i, *(ids->data()->GetValues<int64_t>(1, i * k)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(NSGInterfaceTest, basic_test) {
|
||||
assert(!xb.empty());
|
||||
|
||||
auto model = index_->Train(base_dataset, train_cfg);
|
||||
auto result = index_->Search(query_dataset, search_cfg);
|
||||
AssertAnns(result, nq, k);
|
||||
|
||||
auto binaryset = index_->Serialize();
|
||||
auto new_index = std::make_shared<NSG>();
|
||||
new_index->Load(binaryset);
|
||||
auto new_result = new_index->Search(query_dataset, Config::object{{"k", k}});
|
||||
AssertAnns(result, nq, k);
|
||||
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dimension(), dim);
|
||||
ASSERT_THROW({index_->Clone();}, zilliz::knowhere::KnowhereException);
|
||||
ASSERT_NO_THROW({
|
||||
index_->Add(base_dataset, Config());
|
||||
index_->Seal();
|
||||
});
|
||||
|
||||
{
|
||||
//std::cout << "k = 1" << std::endl;
|
||||
//new_index->Search(GenQuery(1), Config::object{{"k", 1}});
|
||||
//new_index->Search(GenQuery(10), Config::object{{"k", 1}});
|
||||
//new_index->Search(GenQuery(100), Config::object{{"k", 1}});
|
||||
//new_index->Search(GenQuery(1000), Config::object{{"k", 1}});
|
||||
//new_index->Search(GenQuery(10000), Config::object{{"k", 1}});
|
||||
|
||||
//std::cout << "k = 5" << std::endl;
|
||||
//new_index->Search(GenQuery(1), Config::object{{"k", 5}});
|
||||
//new_index->Search(GenQuery(20), Config::object{{"k", 5}});
|
||||
//new_index->Search(GenQuery(100), Config::object{{"k", 5}});
|
||||
//new_index->Search(GenQuery(300), Config::object{{"k", 5}});
|
||||
//new_index->Search(GenQuery(500), Config::object{{"k", 5}});
|
||||
}
|
||||
}
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
INITIALIZE_EASYLOGGINGPP
|
||||
|
||||
void InitLog() {
|
||||
el::Configurations defaultConf;
|
||||
defaultConf.setToDefault();
|
||||
defaultConf.set(el::Level::Debug,
|
||||
el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)");
|
||||
el::Loggers::reconfigureLogger("default", defaultConf);
|
||||
}
|
||||
|
||||
void DataGen::Init_with_default() {
|
||||
Generate(dim, nb, nq);
|
||||
}
|
||||
|
||||
void DataGen::Generate(const int &dim, const int &nb, const int &nq) {
|
||||
this->nb = nb;
|
||||
this->nq = nq;
|
||||
this->dim = dim;
|
||||
|
||||
GenAll(dim, nb, xb, ids, nq, xq);
|
||||
assert(xb.size() == dim * nb);
|
||||
assert(xq.size() == dim * nq);
|
||||
|
||||
base_dataset = generate_dataset(nb, dim, xb.data(), ids.data());
|
||||
query_dataset = generate_query_dataset(nq, dim, xq.data());
|
||||
|
||||
}
|
||||
zilliz::knowhere::DatasetPtr DataGen::GenQuery(const int &nq) {
|
||||
xq.resize(nq * dim);
|
||||
for (size_t i = 0; i < nq * dim; ++i) {
|
||||
xq[i] = xb[i];
|
||||
}
|
||||
return generate_query_dataset(nq, dim, xq.data());
|
||||
}
|
||||
|
||||
void GenAll(const int64_t dim,
|
||||
const int64_t &nb,
|
||||
std::vector<float> &xb,
|
||||
std::vector<int64_t> &ids,
|
||||
const int64_t &nq,
|
||||
std::vector<float> &xq) {
|
||||
xb.resize(nb * dim);
|
||||
xq.resize(nq * dim);
|
||||
ids.resize(nb);
|
||||
GenAll(dim, nb, xb.data(), ids.data(), nq, xq.data());
|
||||
}
|
||||
|
||||
void GenAll(const int64_t &dim,
|
||||
const int64_t &nb,
|
||||
float *xb,
|
||||
int64_t *ids,
|
||||
const int64_t &nq,
|
||||
float *xq) {
|
||||
GenBase(dim, nb, xb, ids);
|
||||
for (size_t i = 0; i < nq * dim; ++i) {
|
||||
xq[i] = xb[i];
|
||||
}
|
||||
}
|
||||
|
||||
void GenBase(const int64_t &dim,
|
||||
const int64_t &nb,
|
||||
float *xb,
|
||||
int64_t *ids) {
|
||||
for (auto i = 0; i < nb; ++i) {
|
||||
for (auto j = 0; j < dim; ++j) {
|
||||
//p_data[i * d + j] = float(base + i);
|
||||
xb[i * dim + j] = drand48();
|
||||
}
|
||||
xb[dim * i] += i / 1000.;
|
||||
ids[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
FileIOReader::FileIOReader(const std::string &fname) {
|
||||
name = fname;
|
||||
fs = std::fstream(name, std::ios::in | std::ios::binary);
|
||||
}
|
||||
|
||||
FileIOReader::~FileIOReader() {
|
||||
fs.close();
|
||||
}
|
||||
|
||||
size_t FileIOReader::operator()(void *ptr, size_t size) {
|
||||
fs.read(reinterpret_cast<char *>(ptr), size);
|
||||
return size;
|
||||
}
|
||||
|
||||
FileIOWriter::FileIOWriter(const std::string &fname) {
|
||||
name = fname;
|
||||
fs = std::fstream(name, std::ios::out | std::ios::binary);
|
||||
}
|
||||
|
||||
FileIOWriter::~FileIOWriter() {
|
||||
fs.close();
|
||||
}
|
||||
|
||||
size_t FileIOWriter::operator()(void *ptr, size_t size) {
|
||||
fs.write(reinterpret_cast<char *>(ptr), size);
|
||||
return size;
|
||||
}
|
||||
|
||||
using namespace zilliz::knowhere;
|
||||
|
||||
DatasetPtr
|
||||
generate_dataset(int64_t nb, int64_t dim, float *xb, long *ids) {
|
||||
std::vector<int64_t> shape{nb, dim};
|
||||
auto tensor = ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape);
|
||||
std::vector<TensorPtr> tensors{tensor};
|
||||
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
|
||||
auto tensor_schema = std::make_shared<Schema>(tensor_fields);
|
||||
|
||||
auto id_array = ConstructInt64Array((uint8_t *) ids, nb * sizeof(int64_t));
|
||||
std::vector<ArrayPtr> arrays{id_array};
|
||||
std::vector<FieldPtr> array_fields{ConstructInt64Field("id")};
|
||||
auto array_schema = std::make_shared<Schema>(tensor_fields);
|
||||
|
||||
auto dataset = std::make_shared<Dataset>(std::move(arrays), array_schema,
|
||||
std::move(tensors), tensor_schema);
|
||||
return dataset;
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
generate_query_dataset(int64_t nb, int64_t dim, float *xb) {
|
||||
std::vector<int64_t> shape{nb, dim};
|
||||
auto tensor = ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape);
|
||||
std::vector<TensorPtr> tensors{tensor};
|
||||
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
|
||||
auto tensor_schema = std::make_shared<Schema>(tensor_fields);
|
||||
|
||||
auto dataset = std::make_shared<Dataset>(std::move(tensors), tensor_schema);
|
||||
return dataset;
|
||||
}
|
|
@ -179,13 +179,13 @@ FileMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string&
|
|||
|
||||
ErrorCode
|
||||
FileMetadataSet::SaveMetadataToMemory(void **pGraphMemFile, int64_t &len) {
|
||||
// TODO: serialize file to mem?
|
||||
// TODO(lxj): serialize file to mem?
|
||||
return ErrorCode::Fail;
|
||||
}
|
||||
|
||||
ErrorCode
|
||||
FileMetadataSet::LoadMetadataFromMemory(void *pGraphMemFile) {
|
||||
// TODO: not support yet
|
||||
// TODO(lxj): not support yet
|
||||
return ErrorCode::Fail;
|
||||
}
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue