Add Testcase for indexing

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-09-15 10:00:00 +08:00 committed by yefu.chen
parent dc916ec10c
commit b80de55ac8
964 changed files with 200171 additions and 852 deletions

View File

@ -1,20 +1,203 @@
project(sulvim_core)
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
cmake_minimum_required( VERSION 3.14 )
add_definitions(-DELPP_THREAD_SAFE)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
message( STATUS "Building using CMake version: ${CMAKE_VERSION}" )
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake" )
include( Utils )
# **************************** Build time, type and code version ****************************
get_current_time( BUILD_TIME )
message( STATUS "Build time = ${BUILD_TIME}" )
get_build_type( TARGET BUILD_TYPE
DEFAULT "Release" )
message( STATUS "Build type = ${BUILD_TYPE}" )
get_milvus_version( TARGET MILVUS_VERSION
DEFAULT "0.10.0" )
message( STATUS "Build version = ${MILVUS_VERSION}" )
get_last_commit_id( LAST_COMMIT_ID )
message( STATUS "LAST_COMMIT_ID = ${LAST_COMMIT_ID}" )
#configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/src/version.h.in
# ${CMAKE_CURRENT_SOURCE_DIR}/src/version.h @ONLY )
# unset(CMAKE_EXPORT_COMPILE_COMMANDS CACHE)
set( CMAKE_EXPORT_COMPILE_COMMANDS ON )
# **************************** Project ****************************
project( milvus VERSION "${MILVUS_VERSION}" )
cmake_minimum_required(VERSION 3.16)
set( CMAKE_CXX_STANDARD 17 )
set( CMAKE_CXX_STANDARD_REQUIRED on )
set (CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include_directories(src)
add_subdirectory(src)
add_subdirectory(unittest)
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include
FILES_MATCHING PATTERN "*_c.h"
)
set( MILVUS_SOURCE_DIR ${PROJECT_SOURCE_DIR} )
set( MILVUS_BINARY_DIR ${PROJECT_BINARY_DIR} )
set( MILVUS_ENGINE_SRC ${PROJECT_SOURCE_DIR}/src )
set( MILVUS_THIRDPARTY_SRC ${PROJECT_SOURCE_DIR}/thirdparty )
install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
# This will set RPATH to all excutable TARGET
# self-installed dynamic libraries will be correctly linked by excutable
set( CMAKE_INSTALL_RPATH "/usr/lib" "${CMAKE_INSTALL_PREFIX}/lib" )
set( CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE )
# **************************** Dependencies ****************************
include( CTest )
include( BuildUtils )
include( DefineOptions )
include( ExternalProject )
include( FetchContent )
include_directories(thirdparty)
set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
set(FETCHCONTENT_QUIET OFF)
include( ThirdPartyPackages )
find_package(OpenMP REQUIRED)
# **************************** Compiler arguments ****************************
message( STATUS "Building Milvus CPU version" )
#append_flags( CMAKE_CXX_FLAGS
# FLAGS
# "-fPIC"
# "-DELPP_THREAD_SAFE"
# "-fopenmp"
# "-Werror"
# )
# **************************** Coding style check tools ****************************
find_package( ClangTools )
set( BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support" )
message(STATUS "CMAKE_SOURCE_DIR is at ${CMAKE_SOURCE_DIR}" )
if("$ENV{CMAKE_EXPORT_COMPILE_COMMANDS}" STREQUAL "1" OR CLANG_TIDY_FOUND)
# Generate a Clang compile_commands.json "compilation database" file for use
# with various development tools, such as Vim's YouCompleteMe plugin.
# See http://clang.llvm.org/docs/JSONCompilationDatabase.html
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
endif()
#
# "make lint" target
#
if ( NOT MILVUS_VERBOSE_LINT )
set( MILVUS_LINT_QUIET "--quiet" )
endif ()
if ( NOT LINT_EXCLUSIONS_FILE )
# source files matching a glob from a line in this file
# will be excluded from linting (cpplint, clang-tidy, clang-format)
set( LINT_EXCLUSIONS_FILE ${BUILD_SUPPORT_DIR}/lint_exclusions.txt )
endif ()
find_program( CPPLINT_BIN NAMES cpplint cpplint.py HINTS ${BUILD_SUPPORT_DIR} )
message( STATUS "Found cpplint executable at ${CPPLINT_BIN}" )
#
# "make lint" targets
#
add_custom_target( lint
${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_cpplint.py
--cpplint_binary ${CPPLINT_BIN}
--exclude_globs ${LINT_EXCLUSIONS_FILE}
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}
${MILVUS_LINT_QUIET}
)
#
# "make clang-format" and "make check-clang-format" targets
#
if ( ${CLANG_FORMAT_FOUND} )
# runs clang format and updates files in place.
add_custom_target( clang-format
${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_format.py
--clang_format_binary ${CLANG_FORMAT_BIN}
--exclude_globs ${LINT_EXCLUSIONS_FILE}
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src
--fix
${MILVUS_LINT_QUIET} )
# runs clang format and exits with a non-zero exit code if any files need to be reformatted
add_custom_target( check-clang-format
${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_format.py
--clang_format_binary ${CLANG_FORMAT_BIN}
--exclude_globs ${LINT_EXCLUSIONS_FILE}
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src
${MILVUS_LINT_QUIET} )
endif ()
#
# "make clang-tidy" and "make check-clang-tidy" targets
#
if ( ${CLANG_TIDY_FOUND} )
# runs clang-tidy and attempts to fix any warning automatically
add_custom_target( clang-tidy
${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_tidy.py
--clang_tidy_binary ${CLANG_TIDY_BIN}
--exclude_globs ${LINT_EXCLUSIONS_FILE}
--compile_commands ${CMAKE_BINARY_DIR}/compile_commands.json
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src
--fix
${MILVUS_LINT_QUIET} )
# runs clang-tidy and exits with a non-zero exit code if any errors are found.
add_custom_target( check-clang-tidy
${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_tidy.py
--clang_tidy_binary ${CLANG_TIDY_BIN}
--exclude_globs ${LINT_EXCLUSIONS_FILE}
--compile_commands ${CMAKE_BINARY_DIR}/compile_commands.json
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src
${MILVUS_LINT_QUIET} )
endif ()
#
# Validate and print out Milvus configuration options
#
config_summary()
# **************************** Source files ****************************
add_subdirectory( thirdparty )
add_subdirectory( src )
# Unittest lib
if ( BUILD_UNIT_TEST STREQUAL "ON" )
if ( BUILD_COVERAGE STREQUAL "ON" )
append_flags( CMAKE_CXX_FLAGS
FLAGS
"-fprofile-arcs"
"-ftest-coverage"
)
endif ()
append_flags( CMAKE_CXX_FLAGS FLAGS "-DELPP_DISABLE_LOGS")
add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/unittest )
endif ()
add_custom_target( Clean-All COMMAND ${CMAKE_BUILD_TOOL} clean )
# **************************** Install ****************************
if ( NOT MILVUS_DB_PATH )
set( MILVUS_DB_PATH "${CMAKE_INSTALL_PREFIX}" )
endif ()
set( GPU_ENABLE "false" )

20
core/CMakeLists_old.txt Normal file
View File

@ -0,0 +1,20 @@
project(sulvim_core)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
cmake_minimum_required(VERSION 3.16)
set( CMAKE_CXX_STANDARD 17 )
set( CMAKE_CXX_STANDARD_REQUIRED on )
set (CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include_directories(src)
add_subdirectory(src)
add_subdirectory(unittest)
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include
FILES_MATCHING PATTERN "*_c.h"
)
install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)

View File

@ -1,8 +1,156 @@
#!/bin/bash
if [[ -d "./build" ]]; then
rm -rf build
# Compile jobs variable; Usage: $ jobs=12 ./build.sh ...
if [[ ! ${jobs+1} ]]; then
jobs=$(nproc)
fi
mkdir build && cd build
cmake ..
make -j8 && make install
BUILD_OUTPUT_DIR="cmake_build"
BUILD_TYPE="Debug"
BUILD_UNITTEST="OFF"
INSTALL_PREFIX=$(pwd)/milvus
MAKE_CLEAN="OFF"
BUILD_COVERAGE="OFF"
DB_PATH="/tmp/milvus"
PROFILING="OFF"
RUN_CPPLINT="OFF"
CUDA_COMPILER=/usr/local/cuda/bin/nvcc
GPU_VERSION="OFF" #defaults to CPU version
WITH_PROMETHEUS="ON"
CUDA_ARCH="DEFAULT"
CUSTOM_THIRDPARTY_PATH=""
while getopts "p:d:t:s:f:ulrcghzme" arg; do
case $arg in
f)
CUSTOM_THIRDPARTY_PATH=$OPTARG
;;
p)
INSTALL_PREFIX=$OPTARG
;;
d)
DB_PATH=$OPTARG
;;
t)
BUILD_TYPE=$OPTARG # BUILD_TYPE
;;
u)
echo "Build and run unittest cases"
BUILD_UNITTEST="ON"
;;
l)
RUN_CPPLINT="ON"
;;
r)
if [[ -d ${BUILD_OUTPUT_DIR} ]]; then
MAKE_CLEAN="ON"
fi
;;
c)
BUILD_COVERAGE="ON"
;;
z)
PROFILING="ON"
;;
g)
GPU_VERSION="ON"
;;
e)
WITH_PROMETHEUS="OFF"
;;
s)
CUDA_ARCH=$OPTARG
;;
h) # help
echo "
parameter:
-f: custom paths of thirdparty downloaded files(default: NULL)
-p: install prefix(default: $(pwd)/milvus)
-d: db data path(default: /tmp/milvus)
-t: build type(default: Debug)
-u: building unit test options(default: OFF)
-l: run cpplint, clang-format and clang-tidy(default: OFF)
-r: remove previous build directory(default: OFF)
-c: code coverage(default: OFF)
-z: profiling(default: OFF)
-g: build GPU version(default: OFF)
-e: build without prometheus(default: OFF)
-s: build with CUDA arch(default:DEFAULT), for example '-gencode=compute_61,code=sm_61;-gencode=compute_75,code=sm_75'
-h: help
usage:
./build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} -s \${CUDA_ARCH} -f\${CUSTOM_THIRDPARTY_PATH} [-u] [-l] [-r] [-c] [-z] [-g] [-m] [-e] [-h]
"
exit 0
;;
?)
echo "ERROR! unknown argument"
exit 1
;;
esac
done
if [[ ! -d ${BUILD_OUTPUT_DIR} ]]; then
mkdir ${BUILD_OUTPUT_DIR}
fi
cd ${BUILD_OUTPUT_DIR}
# remove make cache since build.sh -l use default variables
# force update the variables each time
make rebuild_cache >/dev/null 2>&1
if [[ ${MAKE_CLEAN} == "ON" ]]; then
echo "Runing make clean in ${BUILD_OUTPUT_DIR} ..."
make clean
exit 0
fi
CMAKE_CMD="cmake \
-DBUILD_UNIT_TEST=${BUILD_UNITTEST} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX}
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DOpenBLAS_SOURCE=AUTO \
-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \
-DBUILD_COVERAGE=${BUILD_COVERAGE} \
-DMILVUS_DB_PATH=${DB_PATH} \
-DENABLE_CPU_PROFILING=${PROFILING} \
-DMILVUS_GPU_VERSION=${GPU_VERSION} \
-DMILVUS_WITH_PROMETHEUS=${WITH_PROMETHEUS} \
-DMILVUS_CUDA_ARCH=${CUDA_ARCH} \
-DCUSTOM_THIRDPARTY_DOWNLOAD_PATH=${CUSTOM_THIRDPARTY_PATH} \
../"
echo ${CMAKE_CMD}
${CMAKE_CMD}
if [[ ${RUN_CPPLINT} == "ON" ]]; then
# cpplint check
make lint
if [ $? -ne 0 ]; then
echo "ERROR! cpplint check failed"
exit 1
fi
echo "cpplint check passed!"
# clang-format check
make check-clang-format
if [ $? -ne 0 ]; then
echo "ERROR! clang-format check failed"
exit 1
fi
echo "clang-format check passed!"
# clang-tidy check
make check-clang-tidy
if [ $? -ne 0 ]; then
echo "ERROR! clang-tidy check failed"
exit 1
fi
echo "clang-tidy check passed!"
else
# compile and build
make -j ${jobs} install || exit 1
fi

231
core/cmake/BuildUtils.cmake Normal file
View File

@ -0,0 +1,231 @@
# Define a function that check last file modification
function(Check_Last_Modify cache_check_lists_file_path working_dir last_modified_commit_id)
if(EXISTS "${working_dir}")
if(EXISTS "${cache_check_lists_file_path}")
set(GIT_LOG_SKIP_NUM 0)
set(_MATCH_ALL ON CACHE BOOL "Match all")
set(_LOOP_STATUS ON CACHE BOOL "Whether out of loop")
file(STRINGS ${cache_check_lists_file_path} CACHE_IGNORE_TXT)
while(_LOOP_STATUS)
foreach(_IGNORE_ENTRY ${CACHE_IGNORE_TXT})
if(NOT _IGNORE_ENTRY MATCHES "^[^#]+")
continue()
endif()
set(_MATCH_ALL OFF)
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --name-status --pretty= WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE CHANGE_FILES)
if(NOT CHANGE_FILES STREQUAL "")
string(REPLACE "\n" ";" _CHANGE_FILES ${CHANGE_FILES})
foreach(_FILE_ENTRY ${_CHANGE_FILES})
string(REGEX MATCH "[^ \t]+$" _FILE_NAME ${_FILE_ENTRY})
execute_process(COMMAND sh -c "echo ${_FILE_NAME} | grep ${_IGNORE_ENTRY}" RESULT_VARIABLE return_code)
if (return_code EQUAL 0)
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID)
set (${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE)
set(_LOOP_STATUS OFF)
endif()
endforeach()
else()
set(_LOOP_STATUS OFF)
endif()
endforeach()
if(_MATCH_ALL)
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID)
set (${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE)
set(_LOOP_STATUS OFF)
endif()
math(EXPR GIT_LOG_SKIP_NUM "${GIT_LOG_SKIP_NUM} + 1")
endwhile(_LOOP_STATUS)
else()
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID)
set (${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE)
endif()
else()
message(FATAL_ERROR "The directory ${working_dir} does not exist")
endif()
endfunction()
# Define a function that extracts a cached package
function(ExternalProject_Use_Cache project_name package_file install_path)
message(STATUS "Will use cached package file: ${package_file}")
ExternalProject_Add(${project_name}
DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E echo
"No download step needed (using cached package)"
CONFIGURE_COMMAND ${CMAKE_COMMAND} -E echo
"No configure step needed (using cached package)"
BUILD_COMMAND ${CMAKE_COMMAND} -E echo
"No build step needed (using cached package)"
INSTALL_COMMAND ${CMAKE_COMMAND} -E echo
"No install step needed (using cached package)"
)
# We want our tar files to contain the Install/<package> prefix (not for any
# very special reason, only for consistency and so that we can identify them
# in the extraction logs) which means that we must extract them in the
# binary (top-level build) directory to have them installed in the right
# place for subsequent ExternalProjects to pick them up. It seems that the
# only way to control the working directory is with Add_Step!
ExternalProject_Add_Step(${project_name} extract
ALWAYS 1
COMMAND
${CMAKE_COMMAND} -E echo
"Extracting ${package_file} to ${install_path}"
COMMAND
${CMAKE_COMMAND} -E tar xzf ${package_file} ${install_path}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)
ExternalProject_Add_StepTargets(${project_name} extract)
endfunction()
# Define a function that to create a new cached package
function(ExternalProject_Create_Cache project_name package_file install_path cache_username cache_password cache_path)
if(EXISTS ${package_file})
message(STATUS "Removing existing package file: ${package_file}")
file(REMOVE ${package_file})
endif()
string(REGEX REPLACE "(.+)/.+$" "\\1" package_dir ${package_file})
if(NOT EXISTS ${package_dir})
file(MAKE_DIRECTORY ${package_dir})
endif()
message(STATUS "Will create cached package file: ${package_file}")
ExternalProject_Add_Step(${project_name} package
DEPENDEES install
BYPRODUCTS ${package_file}
COMMAND ${CMAKE_COMMAND} -E echo "Updating cached package file: ${package_file}"
COMMAND ${CMAKE_COMMAND} -E tar czvf ${package_file} ${install_path}
COMMAND ${CMAKE_COMMAND} -E echo "Uploading package file ${package_file} to ${cache_path}"
COMMAND curl -u${cache_username}:${cache_password} -T ${package_file} ${cache_path}
)
ExternalProject_Add_StepTargets(${project_name} package)
endfunction()
function(ADD_THIRDPARTY_LIB LIB_NAME)
set(options)
set(one_value_args SHARED_LIB STATIC_LIB)
set(multi_value_args DEPS INCLUDE_DIRECTORIES)
cmake_parse_arguments(ARG
"${options}"
"${one_value_args}"
"${multi_value_args}"
${ARGN})
if(ARG_UNPARSED_ARGUMENTS)
message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
endif()
if(ARG_STATIC_LIB AND ARG_SHARED_LIB)
if(NOT ARG_STATIC_LIB)
message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}")
endif()
set(AUG_LIB_NAME "${LIB_NAME}_static")
add_library(${AUG_LIB_NAME} STATIC IMPORTED)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}")
if(ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif()
message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}")
if(ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif()
set(AUG_LIB_NAME "${LIB_NAME}_shared")
add_library(${AUG_LIB_NAME} SHARED IMPORTED)
if(WIN32)
# Mark the ".lib" location as part of a Windows DLL
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}")
else()
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}")
endif()
if(ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif()
message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}")
if(ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif()
elseif(ARG_STATIC_LIB)
set(AUG_LIB_NAME "${LIB_NAME}_static")
add_library(${AUG_LIB_NAME} STATIC IMPORTED)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}")
if(ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif()
message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}")
if(ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif()
elseif(ARG_SHARED_LIB)
set(AUG_LIB_NAME "${LIB_NAME}_shared")
add_library(${AUG_LIB_NAME} SHARED IMPORTED)
if(WIN32)
# Mark the ".lib" location as part of a Windows DLL
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}")
else()
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}")
endif()
message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}")
if(ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif()
if(ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif()
else()
message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}")
endif()
endfunction()
MACRO (import_mysql_inc)
find_path (MYSQL_INCLUDE_DIR
NAMES "mysql.h"
PATH_SUFFIXES "mysql")
if (${MYSQL_INCLUDE_DIR} STREQUAL "MYSQL_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Could not found MySQL include directory")
else ()
include_directories(${MYSQL_INCLUDE_DIR})
endif ()
ENDMACRO (import_mysql_inc)
MACRO(using_ccache_if_defined MILVUS_USE_CCACHE)
if (MILVUS_USE_CCACHE)
find_program(CCACHE_FOUND ccache)
if (CCACHE_FOUND)
message(STATUS "Using ccache: ${CCACHE_FOUND}")
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND})
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND})
# let ccache preserve C++ comments, because some of them may be
# meaningful to the compiler
set(ENV{CCACHE_COMMENTS} "1")
endif (CCACHE_FOUND)
endif ()
ENDMACRO(using_ccache_if_defined)

View File

@ -0,0 +1,156 @@
macro(set_option_category name)
set(MILVUS_OPTION_CATEGORY ${name})
list(APPEND "MILVUS_OPTION_CATEGORIES" ${name})
endmacro()
macro(define_option name description default)
option(${name} ${description} ${default})
list(APPEND "MILVUS_${MILVUS_OPTION_CATEGORY}_OPTION_NAMES" ${name})
set("${name}_OPTION_DESCRIPTION" ${description})
set("${name}_OPTION_DEFAULT" ${default})
set("${name}_OPTION_TYPE" "bool")
endmacro()
function(list_join lst glue out)
if ("${${lst}}" STREQUAL "")
set(${out} "" PARENT_SCOPE)
return()
endif ()
list(GET ${lst} 0 joined)
list(REMOVE_AT ${lst} 0)
foreach (item ${${lst}})
set(joined "${joined}${glue}${item}")
endforeach ()
set(${out} ${joined} PARENT_SCOPE)
endfunction()
macro(define_option_string name description default)
set(${name} ${default} CACHE STRING ${description})
list(APPEND "MILVUS_${MILVUS_OPTION_CATEGORY}_OPTION_NAMES" ${name})
set("${name}_OPTION_DESCRIPTION" ${description})
set("${name}_OPTION_DEFAULT" "\"${default}\"")
set("${name}_OPTION_TYPE" "string")
set("${name}_OPTION_ENUM" ${ARGN})
list_join("${name}_OPTION_ENUM" "|" "${name}_OPTION_ENUM")
if (NOT ("${${name}_OPTION_ENUM}" STREQUAL ""))
set_property(CACHE ${name} PROPERTY STRINGS ${ARGN})
endif ()
endmacro()
#----------------------------------------------------------------------
set_option_category("Milvus Build Option")
define_option(MILVUS_GPU_VERSION "Build GPU version" OFF)
#----------------------------------------------------------------------
set_option_category("Thirdparty")
set(MILVUS_DEPENDENCY_SOURCE_DEFAULT "BUNDLED")
define_option_string(MILVUS_DEPENDENCY_SOURCE
"Method to use for acquiring MILVUS's build dependencies"
"${MILVUS_DEPENDENCY_SOURCE_DEFAULT}"
"AUTO"
"BUNDLED"
"SYSTEM")
define_option(MILVUS_USE_CCACHE "Use ccache when compiling (if available)" ON)
define_option(MILVUS_VERBOSE_THIRDPARTY_BUILD
"Show output from ExternalProjects rather than just logging to files" ON)
define_option(MILVUS_WITH_EASYLOGGINGPP "Build with Easylogging++ library" ON)
define_option(MILVUS_WITH_GRPC "Build with GRPC" OFF)
define_option(MILVUS_WITH_ZLIB "Build with zlib compression" ON)
define_option(MILVUS_WITH_OPENTRACING "Build with Opentracing" ON)
define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON)
define_option(MILVUS_WITH_PULSAR "Build with pulsar-client" ON)
#----------------------------------------------------------------------
set_option_category("Test and benchmark")
unset(MILVUS_BUILD_TESTS CACHE)
if (BUILD_UNIT_TEST)
define_option(MILVUS_BUILD_TESTS "Build the MILVUS googletest unit tests" ON)
else ()
define_option(MILVUS_BUILD_TESTS "Build the MILVUS googletest unit tests" OFF)
endif (BUILD_UNIT_TEST)
#----------------------------------------------------------------------
macro(config_summary)
message(STATUS "---------------------------------------------------------------------")
message(STATUS "MILVUS version: ${MILVUS_VERSION}")
message(STATUS)
message(STATUS "Build configuration summary:")
message(STATUS " Generator: ${CMAKE_GENERATOR}")
message(STATUS " Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS " Source directory: ${CMAKE_CURRENT_SOURCE_DIR}")
if (${CMAKE_EXPORT_COMPILE_COMMANDS})
message(
STATUS " Compile commands: ${CMAKE_CURRENT_BINARY_DIR}/compile_commands.json")
endif ()
foreach (category ${MILVUS_OPTION_CATEGORIES})
message(STATUS)
message(STATUS "${category} options:")
set(option_names ${MILVUS_${category}_OPTION_NAMES})
set(max_value_length 0)
foreach (name ${option_names})
string(LENGTH "\"${${name}}\"" value_length)
if (${max_value_length} LESS ${value_length})
set(max_value_length ${value_length})
endif ()
endforeach ()
foreach (name ${option_names})
if ("${${name}_OPTION_TYPE}" STREQUAL "string")
set(value "\"${${name}}\"")
else ()
set(value "${${name}}")
endif ()
set(default ${${name}_OPTION_DEFAULT})
set(description ${${name}_OPTION_DESCRIPTION})
string(LENGTH ${description} description_length)
if (${description_length} LESS 70)
string(
SUBSTRING
" "
${description_length} -1 description_padding)
else ()
set(description_padding "
")
endif ()
set(comment "[${name}]")
if ("${value}" STREQUAL "${default}")
set(comment "[default] ${comment}")
endif ()
if (NOT ("${${name}_OPTION_ENUM}" STREQUAL ""))
set(comment "${comment} [${${name}_OPTION_ENUM}]")
endif ()
string(
SUBSTRING "${value} "
0 ${max_value_length} value)
message(STATUS " ${description} ${description_padding} ${value} ${comment}")
endforeach ()
endforeach ()
endmacro()

View File

@ -0,0 +1,111 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Tries to find the clang-tidy and clang-format modules
#
# Usage of this module as follows:
#
# find_package(ClangTools)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# ClangToolsBin_HOME -
# When set, this path is inspected instead of standard library binary locations
# to find clang-tidy and clang-format
#
# This module defines
# CLANG_TIDY_BIN, The path to the clang tidy binary
# CLANG_TIDY_FOUND, Whether clang tidy was found
# CLANG_FORMAT_BIN, The path to the clang format binary
# CLANG_TIDY_FOUND, Whether clang format was found
find_program(CLANG_TIDY_BIN
NAMES
clang-tidy-7.0
clang-tidy-7
clang-tidy-6.0
clang-tidy-5.0
clang-tidy-4.0
clang-tidy-3.9
clang-tidy-3.8
clang-tidy-3.7
clang-tidy-3.6
clang-tidy
PATHS ${ClangTools_PATH} $ENV{CLANG_TOOLS_PATH} /usr/local/bin /usr/bin
NO_DEFAULT_PATH
)
if ( "${CLANG_TIDY_BIN}" STREQUAL "CLANG_TIDY_BIN-NOTFOUND" )
set(CLANG_TIDY_FOUND 0)
message("clang-tidy not found")
else()
set(CLANG_TIDY_FOUND 1)
message("clang-tidy found at ${CLANG_TIDY_BIN}")
endif()
if (CLANG_FORMAT_VERSION)
find_program(CLANG_FORMAT_BIN
NAMES clang-format-${CLANG_FORMAT_VERSION}
PATHS
${ClangTools_PATH}
$ENV{CLANG_TOOLS_PATH}
/usr/local/bin /usr/bin
NO_DEFAULT_PATH
)
# If not found yet, search alternative locations
if (("${CLANG_FORMAT_BIN}" STREQUAL "CLANG_FORMAT_BIN-NOTFOUND") AND APPLE)
# Homebrew ships older LLVM versions in /usr/local/opt/llvm@version/
STRING(REGEX REPLACE "^([0-9]+)\\.[0-9]+" "\\1" CLANG_FORMAT_MAJOR_VERSION "${CLANG_FORMAT_VERSION}")
STRING(REGEX REPLACE "^[0-9]+\\.([0-9]+)" "\\1" CLANG_FORMAT_MINOR_VERSION "${CLANG_FORMAT_VERSION}")
if ("${CLANG_FORMAT_MINOR_VERSION}" STREQUAL "0")
find_program(CLANG_FORMAT_BIN
NAMES clang-format
PATHS /usr/local/opt/llvm@${CLANG_FORMAT_MAJOR_VERSION}/bin
NO_DEFAULT_PATH
)
else()
find_program(CLANG_FORMAT_BIN
NAMES clang-format
PATHS /usr/local/opt/llvm@${CLANG_FORMAT_VERSION}/bin
NO_DEFAULT_PATH
)
endif()
endif()
else()
find_program(CLANG_FORMAT_BIN
NAMES
clang-format-7.0
clang-format-7
clang-format-6.0
clang-format-5.0
clang-format-4.0
clang-format-3.9
clang-format-3.8
clang-format-3.7
clang-format-3.6
clang-format
PATHS ${ClangTools_PATH} $ENV{CLANG_TOOLS_PATH} /usr/local/bin /usr/bin
NO_DEFAULT_PATH
)
endif()
if ( "${CLANG_FORMAT_BIN}" STREQUAL "CLANG_FORMAT_BIN-NOTFOUND" )
set(CLANG_FORMAT_FOUND 0)
message("clang-format not found")
else()
set(CLANG_FORMAT_FOUND 1)
message("clang-format found at ${CLANG_FORMAT_BIN}")
endif()

View File

@ -0,0 +1,172 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
message(STATUS "Using ${MILVUS_DEPENDENCY_SOURCE} approach to find dependencies")
# For each dependency, set dependency source to global default, if unset
foreach (DEPENDENCY ${MILVUS_THIRDPARTY_DEPENDENCIES})
if ("${${DEPENDENCY}_SOURCE}" STREQUAL "")
set(${DEPENDENCY}_SOURCE ${MILVUS_DEPENDENCY_SOURCE})
endif ()
endforeach ()
# ----------------------------------------------------------------------
# Identify OS
if (UNIX)
if (APPLE)
set(CMAKE_OS_NAME "osx" CACHE STRING "Operating system name" FORCE)
else (APPLE)
## Check for Debian GNU/Linux ________________
find_file(DEBIAN_FOUND debian_version debconf.conf
PATHS /etc
)
if (DEBIAN_FOUND)
set(CMAKE_OS_NAME "debian" CACHE STRING "Operating system name" FORCE)
endif (DEBIAN_FOUND)
## Check for Fedora _________________________
find_file(FEDORA_FOUND fedora-release
PATHS /etc
)
if (FEDORA_FOUND)
set(CMAKE_OS_NAME "fedora" CACHE STRING "Operating system name" FORCE)
endif (FEDORA_FOUND)
## Check for RedHat _________________________
find_file(REDHAT_FOUND redhat-release inittab.RH
PATHS /etc
)
if (REDHAT_FOUND)
set(CMAKE_OS_NAME "redhat" CACHE STRING "Operating system name" FORCE)
endif (REDHAT_FOUND)
## Extra check for Ubuntu ____________________
if (DEBIAN_FOUND)
## At its core Ubuntu is a Debian system, with
## a slightly altered configuration; hence from
## a first superficial inspection a system will
## be considered as Debian, which signifies an
## extra check is required.
find_file(UBUNTU_EXTRA legal issue
PATHS /etc
)
if (UBUNTU_EXTRA)
## Scan contents of file
file(STRINGS ${UBUNTU_EXTRA} UBUNTU_FOUND
REGEX Ubuntu
)
## Check result of string search
if (UBUNTU_FOUND)
set(CMAKE_OS_NAME "ubuntu" CACHE STRING "Operating system name" FORCE)
set(DEBIAN_FOUND FALSE)
find_program(LSB_RELEASE_EXEC lsb_release)
execute_process(COMMAND ${LSB_RELEASE_EXEC} -rs
OUTPUT_VARIABLE LSB_RELEASE_ID_SHORT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
STRING(REGEX REPLACE "\\." "_" UBUNTU_VERSION "${LSB_RELEASE_ID_SHORT}")
endif (UBUNTU_FOUND)
endif (UBUNTU_EXTRA)
endif (DEBIAN_FOUND)
endif (APPLE)
endif (UNIX)
# ----------------------------------------------------------------------
# thirdparty directory
set(THIRDPARTY_DIR "${MILVUS_SOURCE_DIR}/thirdparty")
# ----------------------------------------------------------------------
# ExternalProject options
string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE)
set(EP_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}")
set(EP_C_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}")
# Set -fPIC on all external projects
set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -fPIC")
set(EP_C_FLAGS "${EP_C_FLAGS} -fPIC")
# CC/CXX environment variables are captured on the first invocation of the
# builder (e.g make or ninja) instead of when CMake is invoked into to build
# directory. This leads to issues if the variables are exported in a subshell
# and the invocation of make/ninja is in distinct subshell without the same
# environment (CC/CXX).
set(EP_COMMON_TOOLCHAIN -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER})
if (CMAKE_AR)
set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_AR=${CMAKE_AR})
endif ()
if (CMAKE_RANLIB)
set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_RANLIB=${CMAKE_RANLIB})
endif ()
# External projects are still able to override the following declarations.
# cmake command line will favor the last defined variable when a duplicate is
# encountered. This requires that `EP_COMMON_CMAKE_ARGS` is always the first
# argument.
set(EP_COMMON_CMAKE_ARGS
${EP_COMMON_TOOLCHAIN}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_C_FLAGS=${EP_C_FLAGS}
-DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_C_FLAGS}
-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS})
if (NOT MILVUS_VERBOSE_THIRDPARTY_BUILD)
set(EP_LOG_OPTIONS LOG_CONFIGURE 1 LOG_BUILD 1 LOG_INSTALL 1 LOG_DOWNLOAD 1)
else ()
set(EP_LOG_OPTIONS)
endif ()
# Ensure that a default make is set
if ("${MAKE}" STREQUAL "")
find_program(MAKE make)
endif ()
if (NOT DEFINED MAKE_BUILD_ARGS)
set(MAKE_BUILD_ARGS "-j8")
endif ()
message(STATUS "Third Party MAKE_BUILD_ARGS = ${MAKE_BUILD_ARGS}")
# ----------------------------------------------------------------------
# Find pthreads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
# ----------------------------------------------------------------------
# Versions and URLs for toolchain builds, which also can be used to configure
# offline builds
# Read toolchain versions from cpp/thirdparty/versions.txt
file(STRINGS "${THIRDPARTY_DIR}/versions.txt" TOOLCHAIN_VERSIONS_TXT)
foreach (_VERSION_ENTRY ${TOOLCHAIN_VERSIONS_TXT})
# Exclude comments
if (NOT _VERSION_ENTRY MATCHES "^[^#][A-Za-z0-9-_]+_VERSION=")
continue()
endif ()
string(REGEX MATCH "^[^=]*" _LIB_NAME ${_VERSION_ENTRY})
string(REPLACE "${_LIB_NAME}=" "" _LIB_VERSION ${_VERSION_ENTRY})
# Skip blank or malformed lines
if (${_LIB_VERSION} STREQUAL "")
continue()
endif ()
# For debugging
#message(STATUS "${_LIB_NAME}: ${_LIB_VERSION}")
set(${_LIB_NAME} "${_LIB_VERSION}")
endforeach ()

102
core/cmake/Utils.cmake Normal file
View File

@ -0,0 +1,102 @@
# get build time
MACRO(get_current_time CURRENT_TIME)
execute_process(COMMAND "date" "+%Y-%m-%d %H:%M.%S" OUTPUT_VARIABLE ${CURRENT_TIME})
string(REGEX REPLACE "\n" "" ${CURRENT_TIME} ${${CURRENT_TIME}})
ENDMACRO(get_current_time)
# get build type
MACRO(get_build_type)
cmake_parse_arguments(BUILD_TYPE "" "TARGET;DEFAULT" "" ${ARGN})
if (NOT DEFINED CMAKE_BUILD_TYPE)
set(${BUILD_TYPE_TARGET} ${BUILD_TYPE_DEFAULT})
elseif (CMAKE_BUILD_TYPE STREQUAL "Release")
set(${BUILD_TYPE_TARGET} "Release")
elseif (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(${BUILD_TYPE_TARGET} "Debug")
else ()
set(${BUILD_TYPE_TARGET} ${BUILD_TYPE_DEFAULT})
endif ()
ENDMACRO(get_build_type)
# get git branch name
MACRO(get_git_branch_name GIT_BRANCH_NAME)
set(GIT_BRANCH_NAME_REGEX "[0-9]+\\.[0-9]+\\.[0-9]")
execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | sed 's/.*(\\(.*\\))/\\1/' | sed 's/.*, //' | sed 's=[a-zA-Z]*\/==g'"
OUTPUT_VARIABLE ${GIT_BRANCH_NAME})
if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}")
execute_process(COMMAND "git" rev-parse --abbrev-ref HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME})
endif ()
if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}")
execute_process(COMMAND "git" symbolic-ref -q --short HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME})
endif ()
message(DEBUG "GIT_BRANCH_NAME = ${GIT_BRANCH_NAME}")
# Some unexpected case
if (NOT GIT_BRANCH_NAME STREQUAL "")
string(REGEX REPLACE "\n" "" GIT_BRANCH_NAME ${GIT_BRANCH_NAME})
else ()
set(GIT_BRANCH_NAME "#")
endif ()
ENDMACRO(get_git_branch_name)
# get last commit id
MACRO(get_last_commit_id LAST_COMMIT_ID)
execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | awk '{print $2}'"
OUTPUT_VARIABLE ${LAST_COMMIT_ID})
message(DEBUG "LAST_COMMIT_ID = ${${LAST_COMMIT_ID}}")
if (NOT LAST_COMMIT_ID STREQUAL "")
string(REGEX REPLACE "\n" "" ${LAST_COMMIT_ID} ${${LAST_COMMIT_ID}})
else ()
set(LAST_COMMIT_ID "Unknown")
endif ()
ENDMACRO(get_last_commit_id)
# get milvus version
MACRO(get_milvus_version)
cmake_parse_arguments(VER "" "TARGET;DEFAULT" "" ${ARGN})
# Step 1: get branch name
get_git_branch_name(GIT_BRANCH_NAME)
message(DEBUG ${GIT_BRANCH_NAME})
# Step 2: match MAJOR.MINOR.PATCH format or set DEFAULT value
string(REGEX MATCH "([0-9]+)\\.([0-9]+)\\.([0-9]+)" ${VER_TARGET} ${GIT_BRANCH_NAME})
if (NOT ${VER_TARGET})
set(${VER_TARGET} ${VER_DEFAULT})
endif()
ENDMACRO(get_milvus_version)
# set definition
MACRO(set_milvus_definition DEF_PASS_CMAKE MILVUS_DEF)
if (${${DEF_PASS_CMAKE}})
add_compile_definitions(${MILVUS_DEF})
endif()
ENDMACRO(set_milvus_definition)
MACRO(append_flags target)
cmake_parse_arguments(M "" "" "FLAGS" ${ARGN})
foreach(FLAG IN ITEMS ${M_FLAGS})
set(${target} "${${target}} ${FLAG}")
endforeach()
ENDMACRO(append_flags)
macro(create_executable)
cmake_parse_arguments(E "" "TARGET" "SRCS;LIBS;DEFS" ${ARGN})
add_executable(${E_TARGET})
target_sources(${E_TARGET} PRIVATE ${E_SRCS})
target_link_libraries(${E_TARGET} PRIVATE ${E_LIBS})
target_compile_definitions(${E_TARGET} PRIVATE ${E_DEFS})
endmacro()
macro(create_library)
cmake_parse_arguments(L "" "TARGET" "SRCS;LIBS;DEFS" ${ARGN})
add_library(${L_TARGET} ${L_SRCS})
target_link_libraries(${L_TARGET} PRIVATE ${L_LIBS})
target_compile_definitions(${L_TARGET} PRIVATE ${L_DEFS})
endmacro()

View File

@ -1,4 +1,75 @@
add_subdirectory(utils)
add_subdirectory(dog_segment)
#add_subdirectory(index)
add_subdirectory(query)
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
include_directories(${MILVUS_SOURCE_DIR})
include_directories(${MILVUS_ENGINE_SRC})
include_directories(${MILVUS_THIRDPARTY_SRC})
#include_directories(${MILVUS_ENGINE_SRC}/grpc)
set(FOUND_OPENBLAS "unknown")
add_subdirectory(index)
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
foreach (DIR ${INDEX_INCLUDE_DIRS})
include_directories(${DIR})
endforeach ()
add_subdirectory( utils )
add_subdirectory( log)
add_subdirectory( dog_segment)
add_subdirectory( cache )
add_subdirectory( query )
# add_subdirectory( db ) # target milvus_engine
# add_subdirectory( server )
# set(link_lib
# milvus_engine
# # dog_segment
# #query
# utils
# curl
# )
# set( BOOST_LIB libboost_system.a
# libboost_filesystem.a
# libboost_serialization.a
# )
# set( THIRD_PARTY_LIBS yaml-cpp
# )
# target_link_libraries( server
# PUBLIC ${link_lib}
# ${THIRD_PARTY_LIBS}
# ${BOOST_LIB}
# )
# # **************************** Get&Print Include Directories ****************************
# get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES )
# foreach ( dir ${dirs} )
# message( STATUS "Current Include DIRS: ")
# endforeach ()
# set( SERVER_LIBS server )
# add_executable( milvus_server ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp
# )
# #target_include_directories(db PUBLIC ${PROJECT_BINARY_DIR}/thirdparty/pulsar-client-cpp/pulsar-client-cpp-src/pulsar-client-cpp/include)
# target_link_libraries( milvus_server PRIVATE ${SERVER_LIBS} )
# install( TARGETS milvus_server DESTINATION bin )

View File

@ -0,0 +1,4 @@
add_subdirectory(utils)
add_subdirectory(dog_segment)
#add_subdirectory(index)
add_subdirectory(query)

20
core/src/cache/CMakeLists.txt vendored Normal file
View File

@ -0,0 +1,20 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
aux_source_directory( ${MILVUS_ENGINE_SRC}/cache CACHE_FILES )
add_library( cache STATIC )
target_sources( cache PRIVATE ${CACHE_FILES}
CacheMgr.inl
Cache.inl
)
target_include_directories( cache PUBLIC ${MILVUS_ENGINE_SRC}/cache )

104
core/src/cache/Cache.h vendored Normal file
View File

@ -0,0 +1,104 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "LRU.h"
#include "utils/Log.h"
#include <atomic>
#include <mutex>
#include <set>
#include <string>
namespace milvus {
namespace cache {
template <typename ItemObj>
class Cache {
public:
// mem_capacity, units:GB
Cache(int64_t capacity_gb, int64_t cache_max_count, const std::string& header = "");
~Cache() = default;
int64_t
usage() const {
return usage_;
}
// unit: BYTE
int64_t
capacity() const {
return capacity_;
}
// unit: BYTE
void
set_capacity(int64_t capacity);
double
freemem_percent() const {
return freemem_percent_;
}
void
set_freemem_percent(double percent) {
freemem_percent_ = percent;
}
size_t
size() const;
bool
exists(const std::string& key);
ItemObj
get(const std::string& key);
void
insert(const std::string& key, const ItemObj& item);
void
erase(const std::string& key);
bool
reserve(const int64_t size);
void
print();
void
clear();
private:
void
insert_internal(const std::string& key, const ItemObj& item);
void
erase_internal(const std::string& key);
void
free_memory_internal(const int64_t target_size);
private:
std::string header_;
int64_t usage_;
int64_t capacity_;
double freemem_percent_;
LRU<std::string, ItemObj> lru_;
mutable std::mutex mutex_;
};
} // namespace cache
} // namespace milvus
#include "cache/Cache.inl"

191
core/src/cache/Cache.inl vendored Normal file
View File

@ -0,0 +1,191 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
namespace milvus {
namespace cache {
constexpr double DEFAULT_THRESHOLD_PERCENT = 0.7;
template <typename ItemObj>
Cache<ItemObj>::Cache(int64_t capacity, int64_t cache_max_count, const std::string& header)
: header_(header),
usage_(0),
capacity_(capacity),
freemem_percent_(DEFAULT_THRESHOLD_PERCENT),
lru_(cache_max_count) {
}
template <typename ItemObj>
void
Cache<ItemObj>::set_capacity(int64_t capacity) {
std::lock_guard<std::mutex> lock(mutex_);
if (capacity > 0) {
capacity_ = capacity;
free_memory_internal(capacity);
}
}
template <typename ItemObj>
size_t
Cache<ItemObj>::size() const {
std::lock_guard<std::mutex> lock(mutex_);
return lru_.size();
}
template <typename ItemObj>
bool
Cache<ItemObj>::exists(const std::string& key) {
std::lock_guard<std::mutex> lock(mutex_);
return lru_.exists(key);
}
template <typename ItemObj>
ItemObj
Cache<ItemObj>::get(const std::string& key) {
std::lock_guard<std::mutex> lock(mutex_);
if (!lru_.exists(key)) {
return nullptr;
}
return lru_.get(key);
}
template <typename ItemObj>
void
Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
std::lock_guard<std::mutex> lock(mutex_);
insert_internal(key, item);
}
template <typename ItemObj>
void
Cache<ItemObj>::erase(const std::string& key) {
std::lock_guard<std::mutex> lock(mutex_);
erase_internal(key);
}
template <typename ItemObj>
bool
Cache<ItemObj>::reserve(const int64_t item_size) {
std::lock_guard<std::mutex> lock(mutex_);
if (item_size > capacity_) {
LOG_SERVER_ERROR_ << header_ << " item size " << (item_size >> 20) << "MB too big to insert into cache capacity"
<< (capacity_ >> 20) << "MB";
return false;
}
if (item_size > capacity_ - usage_) {
free_memory_internal(capacity_ - item_size);
}
return true;
}
template <typename ItemObj>
void
Cache<ItemObj>::clear() {
std::lock_guard<std::mutex> lock(mutex_);
lru_.clear();
usage_ = 0;
LOG_SERVER_DEBUG_ << header_ << " Clear cache !";
}
template <typename ItemObj>
void
Cache<ItemObj>::print() {
std::lock_guard<std::mutex> lock(mutex_);
size_t cache_count = lru_.size();
// for (auto it = lru_.begin(); it != lru_.end(); ++it) {
// LOG_SERVER_DEBUG_ << it->first;
// }
LOG_SERVER_DEBUG_ << header_ << " [item count]: " << cache_count << ", [usage] " << (usage_ >> 20)
<< "MB, [capacity] " << (capacity_ >> 20) << "MB";
}
template <typename ItemObj>
void
Cache<ItemObj>::insert_internal(const std::string& key, const ItemObj& item) {
if (item == nullptr) {
return;
}
size_t item_size = item->Size();
// if key already exist, subtract old item size
if (lru_.exists(key)) {
const ItemObj& old_item = lru_.get(key);
usage_ -= old_item->Size();
}
// plus new item size
usage_ += item_size;
// if usage exceed capacity, free some items
if (usage_ > capacity_) {
LOG_SERVER_DEBUG_ << header_ << " Current usage " << (usage_ >> 20) << "MB is too high for capacity "
<< (capacity_ >> 20) << "MB, start free memory";
free_memory_internal(capacity_);
}
// insert new item
lru_.put(key, item);
LOG_SERVER_DEBUG_ << header_ << " Insert " << key << " size: " << (item_size >> 20) << "MB into cache";
LOG_SERVER_DEBUG_ << header_ << " Count: " << lru_.size() << ", Usage: " << (usage_ >> 20) << "MB, Capacity: "
<< (capacity_ >> 20) << "MB";
}
template <typename ItemObj>
void
Cache<ItemObj>::erase_internal(const std::string& key) {
if (!lru_.exists(key)) {
return;
}
const ItemObj& item = lru_.get(key);
size_t item_size = item->Size();
lru_.erase(key);
usage_ -= item_size;
LOG_SERVER_DEBUG_ << header_ << " Erase " << key << " size: " << (item_size >> 20) << "MB from cache";
LOG_SERVER_DEBUG_ << header_ << " Count: " << lru_.size() << ", Usage: " << (usage_ >> 20) << "MB, Capacity: "
<< (capacity_ >> 20) << "MB";
}
template <typename ItemObj>
void
Cache<ItemObj>::free_memory_internal(const int64_t target_size) {
int64_t threshold = std::min((int64_t)(capacity_ * freemem_percent_), target_size);
int64_t delta_size = usage_ - threshold;
if (delta_size <= 0) {
delta_size = 1; // ensure at least one item erased
}
std::set<std::string> key_array;
int64_t released_size = 0;
auto it = lru_.rbegin();
while (it != lru_.rend() && released_size < delta_size) {
auto& key = it->first;
auto& obj_ptr = it->second;
key_array.emplace(key);
released_size += obj_ptr->Size();
++it;
}
LOG_SERVER_DEBUG_ << header_ << " To be released memory size: " << (released_size >> 20) << "MB";
for (auto& key : key_array) {
erase_internal(key);
}
}
} // namespace cache
} // namespace milvus

72
core/src/cache/CacheMgr.h vendored Normal file
View File

@ -0,0 +1,72 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "Cache.h"
// #include "s/Metrics.h"
#include "utils/Log.h"
#include <memory>
#include <string>
namespace milvus {
namespace cache {
template <typename ItemObj>
class CacheMgr {
public:
virtual uint64_t
ItemCount() const;
virtual bool
ItemExists(const std::string& key);
virtual ItemObj
GetItem(const std::string& key);
virtual void
InsertItem(const std::string& key, const ItemObj& data);
virtual void
EraseItem(const std::string& key);
virtual bool
Reserve(const int64_t size);
virtual void
PrintInfo();
virtual void
ClearCache();
int64_t
CacheUsage() const;
int64_t
CacheCapacity() const;
void
SetCapacity(int64_t capacity);
protected:
CacheMgr();
virtual ~CacheMgr();
protected:
std::shared_ptr<Cache<ItemObj>> cache_;
};
} // namespace cache
} // namespace milvus
#include "cache/CacheMgr.inl"

137
core/src/cache/CacheMgr.inl vendored Normal file
View File

@ -0,0 +1,137 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
namespace milvus {
namespace cache {
template <typename ItemObj>
CacheMgr<ItemObj>::CacheMgr() {
}
template <typename ItemObj>
CacheMgr<ItemObj>::~CacheMgr() {
}
template <typename ItemObj>
uint64_t
CacheMgr<ItemObj>::ItemCount() const {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return 0;
}
return (uint64_t)(cache_->size());
}
template <typename ItemObj>
bool
CacheMgr<ItemObj>::ItemExists(const std::string& key) {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return false;
}
return cache_->exists(key);
}
template <typename ItemObj>
ItemObj
CacheMgr<ItemObj>::GetItem(const std::string& key) {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return nullptr;
}
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
return cache_->get(key);
}
template <typename ItemObj>
void
CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data) {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return;
}
cache_->insert(key, data);
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
}
template <typename ItemObj>
void
CacheMgr<ItemObj>::EraseItem(const std::string& key) {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return;
}
cache_->erase(key);
// server::Metrics::GetInstance().CacheAccessTotalIncrement();
}
template <typename ItemObj>
bool
CacheMgr<ItemObj>::Reserve(const int64_t size) {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return false;
}
return cache_->reserve(size);
}
template <typename ItemObj>
void
CacheMgr<ItemObj>::PrintInfo() {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return;
}
cache_->print();
}
template <typename ItemObj>
void
CacheMgr<ItemObj>::ClearCache() {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return;
}
cache_->clear();
}
template <typename ItemObj>
int64_t
CacheMgr<ItemObj>::CacheUsage() const {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return 0;
}
return cache_->usage();
}
template <typename ItemObj>
int64_t
CacheMgr<ItemObj>::CacheCapacity() const {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return 0;
}
return cache_->capacity();
}
template <typename ItemObj>
void
CacheMgr<ItemObj>::SetCapacity(int64_t capacity) {
if (cache_ == nullptr) {
LOG_SERVER_ERROR_ << "Cache doesn't exist";
return;
}
cache_->set_capacity(capacity);
}
} // namespace cache
} // namespace milvus

49
core/src/cache/CpuCacheMgr.cpp vendored Normal file
View File

@ -0,0 +1,49 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "cache/CpuCacheMgr.h"
#include <utility>
// #include <fiu/fiu-local.h>
#include "config/ServerConfig.h"
#include "utils/Log.h"
namespace milvus {
namespace cache {
CpuCacheMgr::CpuCacheMgr() {
// cache_ = std::make_shared<Cache<DataObjPtr>>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]");
// if (config.cache.cpu_cache_threshold() > 0.0) {
// cache_->set_freemem_percent(config.cache.cpu_cache_threshold());
// }
ConfigMgr::GetInstance().Attach("cache.cache_size", this);
}
CpuCacheMgr::~CpuCacheMgr() {
ConfigMgr::GetInstance().Detach("cache.cache_size", this);
}
CpuCacheMgr&
CpuCacheMgr::GetInstance() {
static CpuCacheMgr s_mgr;
return s_mgr;
}
void
CpuCacheMgr::ConfigUpdate(const std::string& name) {
// SetCapacity(config.cache.cache_size());
}
} // namespace cache
} // namespace milvus

40
core/src/cache/CpuCacheMgr.h vendored Normal file
View File

@ -0,0 +1,40 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "cache/CacheMgr.h"
#include "cache/DataObj.h"
#include "config/ConfigMgr.h"
namespace milvus {
namespace cache {
class CpuCacheMgr : public CacheMgr<DataObjPtr>, public ConfigObserver {
private:
CpuCacheMgr();
~CpuCacheMgr();
public:
static CpuCacheMgr&
GetInstance();
public:
void
ConfigUpdate(const std::string& name) override;
};
} // namespace cache
} // namespace milvus

28
core/src/cache/DataObj.h vendored Normal file
View File

@ -0,0 +1,28 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
namespace milvus {
namespace cache {
class DataObj {
public:
virtual int64_t
Size() = 0;
};
using DataObjPtr = std::shared_ptr<DataObj>;
} // namespace cache
} // namespace milvus

63
core/src/cache/GpuCacheMgr.cpp vendored Normal file
View File

@ -0,0 +1,63 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "cache/GpuCacheMgr.h"
#include "config/ServerConfig.h"
#include "utils/Log.h"
// #include <fiu/fiu-local.h>
#include <sstream>
#include <utility>
namespace milvus {
namespace cache {
#ifdef MILVUS_GPU_VERSION
std::mutex GpuCacheMgr::global_mutex_;
std::unordered_map<int64_t, GpuCacheMgrPtr> GpuCacheMgr::instance_;
GpuCacheMgr::GpuCacheMgr(int64_t gpu_id) : gpu_id_(gpu_id) {
std::string header = "[CACHE GPU" + std::to_string(gpu_id) + "]";
cache_ = std::make_shared<Cache<DataObjPtr>>(config.gpu.cache_size(), 1UL << 32, header);
if (config.gpu.cache_threshold() > 0.0) {
cache_->set_freemem_percent(config.gpu.cache_threshold());
}
ConfigMgr::GetInstance().Attach("gpu.cache_threshold", this);
}
GpuCacheMgr::~GpuCacheMgr() {
ConfigMgr::GetInstance().Detach("gpu.cache_threshold", this);
}
GpuCacheMgrPtr
GpuCacheMgr::GetInstance(int64_t gpu_id) {
if (instance_.find(gpu_id) == instance_.end()) {
std::lock_guard<std::mutex> lock(global_mutex_);
if (instance_.find(gpu_id) == instance_.end()) {
instance_[gpu_id] = std::make_shared<GpuCacheMgr>(gpu_id);
}
}
return instance_[gpu_id];
}
void
GpuCacheMgr::ConfigUpdate(const std::string& name) {
std::lock_guard<std::mutex> lock(global_mutex_);
for (auto& it : instance_) {
it.second->SetCapacity(config.gpu.cache_size());
}
}
#endif
} // namespace cache
} // namespace milvus

51
core/src/cache/GpuCacheMgr.h vendored Normal file
View File

@ -0,0 +1,51 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include "cache/CacheMgr.h"
#include "cache/DataObj.h"
#include "config/ConfigMgr.h"
namespace milvus {
namespace cache {
#ifdef MILVUS_GPU_VERSION
class GpuCacheMgr;
using GpuCacheMgrPtr = std::shared_ptr<GpuCacheMgr>;
using MutexPtr = std::shared_ptr<std::mutex>;
class GpuCacheMgr : public CacheMgr<DataObjPtr>, public ConfigObserver {
public:
explicit GpuCacheMgr(int64_t gpu_id);
~GpuCacheMgr();
static GpuCacheMgrPtr
GetInstance(int64_t gpu_id);
public:
void
ConfigUpdate(const std::string& name) override;
private:
int64_t gpu_id_;
static std::mutex global_mutex_;
static std::unordered_map<int64_t, GpuCacheMgrPtr> instance_;
};
#endif
} // namespace cache
} // namespace milvus

116
core/src/cache/LRU.h vendored Normal file
View File

@ -0,0 +1,116 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <cstddef>
#include <list>
#include <stdexcept>
#include <unordered_map>
#include <utility>
namespace milvus {
namespace cache {
template <typename key_t, typename value_t>
class LRU {
public:
typedef typename std::pair<key_t, value_t> key_value_pair_t;
typedef typename std::list<key_value_pair_t>::iterator list_iterator_t;
typedef typename std::list<key_value_pair_t>::reverse_iterator reverse_list_iterator_t;
explicit LRU(size_t max_size) : max_size_(max_size) {
}
void
put(const key_t& key, const value_t& value) {
auto it = cache_items_map_.find(key);
cache_items_list_.push_front(key_value_pair_t(key, value));
if (it != cache_items_map_.end()) {
cache_items_list_.erase(it->second);
cache_items_map_.erase(it);
}
cache_items_map_[key] = cache_items_list_.begin();
if (cache_items_map_.size() > max_size_) {
auto last = cache_items_list_.end();
last--;
cache_items_map_.erase(last->first);
cache_items_list_.pop_back();
}
}
const value_t&
get(const key_t& key) {
auto it = cache_items_map_.find(key);
if (it == cache_items_map_.end()) {
throw std::range_error("There is no such key in cache");
} else {
cache_items_list_.splice(cache_items_list_.begin(), cache_items_list_, it->second);
return it->second->second;
}
}
void
erase(const key_t& key) {
auto it = cache_items_map_.find(key);
if (it != cache_items_map_.end()) {
cache_items_list_.erase(it->second);
cache_items_map_.erase(it);
}
}
bool
exists(const key_t& key) const {
return cache_items_map_.find(key) != cache_items_map_.end();
}
size_t
size() const {
return cache_items_map_.size();
}
list_iterator_t
begin() {
iter_ = cache_items_list_.begin();
return iter_;
}
list_iterator_t
end() {
return cache_items_list_.end();
}
reverse_list_iterator_t
rbegin() {
return cache_items_list_.rbegin();
}
reverse_list_iterator_t
rend() {
return cache_items_list_.rend();
}
void
clear() {
cache_items_list_.clear();
cache_items_map_.clear();
}
private:
std::list<key_value_pair_t> cache_items_list_;
std::unordered_map<key_t, list_iterator_t> cache_items_map_;
size_t max_size_;
list_iterator_t iter_;
};
} // namespace cache
} // namespace milvus

View File

@ -0,0 +1,31 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
# library
set( CONFIG_SRCS ConfigMgr.h
ConfigMgr.cpp
ConfigType.h
ConfigType.cpp
ServerConfig.h
ServerConfig.cpp
)
set( CONFIG_LIBS yaml-cpp
)
create_library(
TARGET config
SRCS ${CONFIG_SRCS}
LIBS ${CONFIG_LIBS}
)

View File

@ -0,0 +1,223 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <yaml-cpp/yaml.h>
#include <cstring>
#include <limits>
#include <unordered_map>
#include<iostream>
#include "config/ConfigMgr.h"
#include "config/ServerConfig.h"
namespace {
const int64_t MB = (1024ll * 1024);
const int64_t GB = (1024ll * 1024 * 1024);
void
Flatten(const YAML::Node& node, std::unordered_map<std::string, std::string>& target, const std::string& prefix) {
for (auto& it : node) {
auto key = prefix.empty() ? it.first.as<std::string>() : prefix + "." + it.first.as<std::string>();
switch (it.second.Type()) {
case YAML::NodeType::Null: {
target[key] = "";
break;
}
case YAML::NodeType::Scalar: {
target[key] = it.second.as<std::string>();
break;
}
case YAML::NodeType::Sequence: {
std::string value;
for (auto& sub : it.second) value += sub.as<std::string>() + ",";
target[key] = value;
break;
}
case YAML::NodeType::Map: {
Flatten(it.second, target, key);
break;
}
case YAML::NodeType::Undefined: {
throw "Unexpected";
}
default:
break;
}
}
}
void
ThrowIfNotSuccess(const milvus::ConfigStatus& cs) {
if (cs.set_return != milvus::SetReturn::SUCCESS) {
throw cs;
}
}
}; // namespace
namespace milvus {
ConfigMgr ConfigMgr::instance;
ConfigMgr::ConfigMgr() {
config_list_ = {
/* general */
{"timezone",
CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)},
/* network */
{"network.address", CreateStringConfig("network.address", false, &config.network.address.value,
"0.0.0.0", nullptr, nullptr)},
{"network.port", CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value,
19530, nullptr, nullptr)},
/* pulsar */
{"pulsar.address", CreateStringConfig("pulsar.address", false, &config.pulsar.address.value,
"localhost", nullptr, nullptr)},
{"pulsar.port", CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value,
6650, nullptr, nullptr)},
/* log */
{"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)},
{"logs.trace.enable",
CreateBoolConfig("logs.trace.enable", false, &config.logs.trace.enable.value, true, nullptr, nullptr)},
{"logs.path",
CreateStringConfig("logs.path", false, &config.logs.path.value, "/var/lib/milvus/logs", nullptr, nullptr)},
{"logs.max_log_file_size", CreateSizeConfig("logs.max_log_file_size", false, 512 * MB, 4096 * MB,
&config.logs.max_log_file_size.value, 1024 * MB, nullptr, nullptr)},
{"logs.log_rotate_num", CreateIntegerConfig("logs.log_rotate_num", false, 0, 1024,
&config.logs.log_rotate_num.value, 0, nullptr, nullptr)},
/* tracing */
{"tracing.json_config_path", CreateStringConfig("tracing.json_config_path", false,
&config.tracing.json_config_path.value, "", nullptr, nullptr)},
/* invisible */
/* engine */
{"engine.build_index_threshold",
CreateIntegerConfig("engine.build_index_threshold", false, 0, std::numeric_limits<int64_t>::max(),
&config.engine.build_index_threshold.value, 4096, nullptr, nullptr)},
{"engine.search_combine_nq",
CreateIntegerConfig("engine.search_combine_nq", true, 0, std::numeric_limits<int64_t>::max(),
&config.engine.search_combine_nq.value, 64, nullptr, nullptr)},
{"engine.use_blas_threshold",
CreateIntegerConfig("engine.use_blas_threshold", true, 0, std::numeric_limits<int64_t>::max(),
&config.engine.use_blas_threshold.value, 1100, nullptr, nullptr)},
{"engine.omp_thread_num",
CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits<int64_t>::max(),
&config.engine.omp_thread_num.value, 0, nullptr, nullptr)},
{"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value,
SimdType::AUTO, nullptr, nullptr)},
};
}
void
ConfigMgr::Init() {
std::lock_guard<std::mutex> lock(GetConfigMutex());
for (auto& kv : config_list_) {
kv.second->Init();
}
}
void
ConfigMgr::Load(const std::string& path) {
/* load from milvus.yaml */
auto yaml = YAML::LoadFile(path);
/* make it flattened */
std::unordered_map<std::string, std::string> flattened;
// auto proxy_yaml = yaml["porxy"];
auto other_yaml = YAML::Node{};
other_yaml["pulsar"] = yaml["pulsar"];
Flatten(yaml["proxy"], flattened, "");
Flatten(other_yaml, flattened, "");
// Flatten(yaml["proxy"], flattened, "");
/* update config */
for (auto& it : flattened) Set(it.first, it.second, false);
}
void
ConfigMgr::Set(const std::string& name, const std::string& value, bool update) {
std::cout<<"InSet Config "<< name <<std::endl;
if (config_list_.find(name) == config_list_.end()){
std::cout<<"Config "<< name << " not found!"<<std::endl;
return;
}
try {
auto& config = config_list_.at(name);
std::unique_lock<std::mutex> lock(GetConfigMutex());
/* update=false when loading from config file */
if (not update) {
ThrowIfNotSuccess(config->Set(value, update));
} else if (config->modifiable_) {
/* set manually */
ThrowIfNotSuccess(config->Set(value, update));
lock.unlock();
Notify(name);
} else {
throw ConfigStatus(SetReturn::IMMUTABLE, "Config " + name + " is not modifiable");
}
} catch (ConfigStatus& cs) {
throw cs;
} catch (...) {
throw "Config " + name + " not found.";
}
}
std::string
ConfigMgr::Get(const std::string& name) const {
try {
auto& config = config_list_.at(name);
std::lock_guard<std::mutex> lock(GetConfigMutex());
return config->Get();
} catch (...) {
throw "Config " + name + " not found.";
}
}
std::string
ConfigMgr::Dump() const {
std::stringstream ss;
for (auto& kv : config_list_) {
auto& config = kv.second;
ss << config->name_ << ": " << config->Get() << std::endl;
}
return ss.str();
}
void
ConfigMgr::Attach(const std::string& name, ConfigObserver* observer) {
std::lock_guard<std::mutex> lock(observer_mutex_);
observers_[name].push_back(observer);
}
void
ConfigMgr::Detach(const std::string& name, ConfigObserver* observer) {
std::lock_guard<std::mutex> lock(observer_mutex_);
if (observers_.find(name) == observers_.end())
return;
auto& ob_list = observers_[name];
ob_list.remove(observer);
}
void
ConfigMgr::Notify(const std::string& name) {
std::lock_guard<std::mutex> lock(observer_mutex_);
if (observers_.find(name) == observers_.end())
return;
auto& ob_list = observers_[name];
for (auto& ob : ob_list) {
ob->ConfigUpdate(name);
}
}
} // namespace milvus

View File

@ -0,0 +1,92 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "config/ServerConfig.h"
namespace milvus {
class ConfigObserver {
public:
virtual ~ConfigObserver() {
}
virtual void
ConfigUpdate(const std::string& name) = 0;
};
using ConfigObserverPtr = std::shared_ptr<ConfigObserver>;
class ConfigMgr {
public:
static ConfigMgr&
GetInstance() {
return instance;
}
private:
static ConfigMgr instance;
public:
ConfigMgr();
ConfigMgr(const ConfigMgr&) = delete;
ConfigMgr&
operator=(const ConfigMgr&) = delete;
ConfigMgr(ConfigMgr&&) = delete;
ConfigMgr&
operator=(ConfigMgr&&) = delete;
public:
void
Init();
void
Load(const std::string& path);
void
Set(const std::string& name, const std::string& value, bool update = true);
std::string
Get(const std::string& name) const;
std::string
Dump() const;
public:
// Shared pointer should not be used here
void
Attach(const std::string& name, ConfigObserver* observer);
void
Detach(const std::string& name, ConfigObserver* observer);
private:
void
Notify(const std::string& name);
private:
std::unordered_map<std::string, BaseConfigPtr> config_list_;
std::mutex mutex_;
std::unordered_map<std::string, std::list<ConfigObserver*>> observers_;
std::mutex observer_mutex_;
};
} // namespace milvus

View File

@ -0,0 +1,528 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "config/ConfigType.h"
#include "config/ServerConfig.h"
#include <strings.h>
#include <algorithm>
#include <cassert>
#include <functional>
#include <sstream>
#include <string>
namespace {
std::unordered_map<std::string, int64_t> BYTE_UNITS = {
{"b", 1},
{"k", 1024},
{"m", 1024 * 1024},
{"g", 1024 * 1024 * 1024},
};
bool
is_integer(const std::string& s) {
if (not s.empty() && (std::isdigit(s[0]) || s[0] == '-')) {
auto ss = s.substr(1);
return std::find_if(ss.begin(), ss.end(), [](unsigned char c) { return !std::isdigit(c); }) == ss.end();
}
return false;
}
bool
is_number(const std::string& s) {
return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isdigit(c); }) == s.end();
}
bool
is_alpha(const std::string& s) {
return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isalpha(c); }) == s.end();
}
template <typename T>
bool
boundary_check(T val, T lower_bound, T upper_bound) {
return lower_bound <= val && val <= upper_bound;
}
bool
parse_bool(const std::string& str, std::string& err) {
if (!strcasecmp(str.c_str(), "true"))
return true;
else if (!strcasecmp(str.c_str(), "false"))
return false;
else
err = "The specified value must be true or false";
return false;
}
std::string
str_tolower(std::string s) {
std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
return s;
}
int64_t
parse_bytes(const std::string& str, std::string& err) {
try {
if (str.find_first_of('-') != std::string::npos) {
std::stringstream ss;
ss << "The specified value for memory (" << str << ") should be a positive integer.";
err = ss.str();
return 0;
}
std::string s = str;
if (is_number(s))
return std::stoll(s);
if (s.length() == 0)
return 0;
auto last_two = s.substr(s.length() - 2, 2);
auto last_one = s.substr(s.length() - 1);
if (is_alpha(last_two) && is_alpha(last_one))
if (last_one == "b" or last_one == "B")
s = s.substr(0, s.length() - 1);
auto& units = BYTE_UNITS;
auto suffix = str_tolower(s.substr(s.length() - 1));
std::string digits_part;
if (is_number(suffix)) {
digits_part = s;
suffix = 'b';
} else {
digits_part = s.substr(0, s.length() - 1);
}
if (is_number(digits_part) && (units.find(suffix) != units.end() || is_number(suffix))) {
auto digits = std::stoll(digits_part);
return digits * units[suffix];
} else {
std::stringstream ss;
ss << "The specified value for memory (" << str << ") should specify the units."
<< "The postfix should be one of the `b` `k` `m` `g` characters";
err = ss.str();
}
} catch (...) {
err = "Unknown error happened on parse bytes.";
}
return 0;
}
} // namespace
// Use (void) to silent unused warnings.
#define assertm(exp, msg) assert(((void)msg, exp))
namespace milvus {
std::vector<std::string>
OptionValue(const configEnum& ce) {
std::vector<std::string> ret;
for (auto& e : ce) {
ret.emplace_back(e.first);
}
return ret;
}
BaseConfig::BaseConfig(const char* name, const char* alias, bool modifiable)
: name_(name), alias_(alias), modifiable_(modifiable) {
}
void
BaseConfig::Init() {
assertm(not inited_, "already initialized");
inited_ = true;
}
BoolConfig::BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
std::function<bool(bool val, std::string& err)> is_valid_fn,
std::function<bool(bool val, bool prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
default_value_(default_value),
is_valid_fn_(std::move(is_valid_fn)),
update_fn_(std::move(update_fn)) {
}
void
BoolConfig::Init() {
BaseConfig::Init();
assert(config_ != nullptr);
*config_ = default_value_;
}
ConfigStatus
BoolConfig::Set(const std::string& val, bool update) {
assertm(inited_, "uninitialized");
try {
if (update and not modifiable_) {
std::stringstream ss;
ss << "Config " << name_ << " is immutable.";
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
}
std::string err;
bool value = parse_bool(val, err);
if (not err.empty())
return ConfigStatus(SetReturn::INVALID, err);
if (is_valid_fn_ && not is_valid_fn_(value, err))
return ConfigStatus(SetReturn::INVALID, err);
bool prev = *config_;
*config_ = value;
if (update && update_fn_ && not update_fn_(value, prev, err)) {
*config_ = prev;
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
}
return ConfigStatus(SetReturn::SUCCESS, "");
} catch (std::exception& e) {
return ConfigStatus(SetReturn::EXCEPTION, e.what());
} catch (...) {
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
}
}
std::string
BoolConfig::Get() {
assertm(inited_, "uninitialized");
return *config_ ? "true" : "false";
}
StringConfig::StringConfig(
const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value,
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
default_value_(default_value),
is_valid_fn_(std::move(is_valid_fn)),
update_fn_(std::move(update_fn)) {
}
void
StringConfig::Init() {
BaseConfig::Init();
assert(config_ != nullptr);
*config_ = default_value_;
}
ConfigStatus
StringConfig::Set(const std::string& val, bool update) {
assertm(inited_, "uninitialized");
try {
if (update and not modifiable_) {
std::stringstream ss;
ss << "Config " << name_ << " is immutable.";
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
}
std::string err;
if (is_valid_fn_ && not is_valid_fn_(val, err))
return ConfigStatus(SetReturn::INVALID, err);
std::string prev = *config_;
*config_ = val;
if (update && update_fn_ && not update_fn_(val, prev, err)) {
*config_ = prev;
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
}
return ConfigStatus(SetReturn::SUCCESS, "");
} catch (std::exception& e) {
return ConfigStatus(SetReturn::EXCEPTION, e.what());
} catch (...) {
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
}
}
std::string
StringConfig::Get() {
assertm(inited_, "uninitialized");
return *config_;
}
EnumConfig::EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
enum_value_(enumd),
default_value_(default_value),
is_valid_fn_(std::move(is_valid_fn)),
update_fn_(std::move(update_fn)) {
}
void
EnumConfig::Init() {
BaseConfig::Init();
assert(enum_value_ != nullptr);
assertm(not enum_value_->empty(), "enum value empty");
assert(config_ != nullptr);
*config_ = default_value_;
}
ConfigStatus
EnumConfig::Set(const std::string& val, bool update) {
assertm(inited_, "uninitialized");
try {
if (update and not modifiable_) {
std::stringstream ss;
ss << "Config " << name_ << " is immutable.";
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
}
if (enum_value_->find(val) == enum_value_->end()) {
auto option_values = OptionValue(*enum_value_);
std::stringstream ss;
ss << "Config " << name_ << "(" << val << ") must be one of following: ";
for (size_t i = 0; i < option_values.size() - 1; ++i) {
ss << option_values[i] << ", ";
}
ss << option_values.back() << ".";
return ConfigStatus(SetReturn::ENUM_VALUE_NOTFOUND, ss.str());
}
int64_t value = enum_value_->at(val);
std::string err;
if (is_valid_fn_ && not is_valid_fn_(value, err)) {
return ConfigStatus(SetReturn::INVALID, err);
}
int64_t prev = *config_;
*config_ = value;
if (update && update_fn_ && not update_fn_(value, prev, err)) {
*config_ = prev;
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
}
return ConfigStatus(SetReturn::SUCCESS, "");
} catch (std::exception& e) {
return ConfigStatus(SetReturn::EXCEPTION, e.what());
} catch (...) {
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
}
}
std::string
EnumConfig::Get() {
assertm(inited_, "uninitialized");
for (auto& it : *enum_value_) {
if (*config_ == it.second) {
return it.first;
}
}
return "unknown";
}
IntegerConfig::IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound,
int64_t upper_bound, int64_t* config, int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
lower_bound_(lower_bound),
upper_bound_(upper_bound),
default_value_(default_value),
is_valid_fn_(std::move(is_valid_fn)),
update_fn_(std::move(update_fn)) {
}
void
IntegerConfig::Init() {
BaseConfig::Init();
assert(config_ != nullptr);
*config_ = default_value_;
}
ConfigStatus
IntegerConfig::Set(const std::string& val, bool update) {
assertm(inited_, "uninitialized");
try {
if (update and not modifiable_) {
std::stringstream ss;
ss << "Config " << name_ << " is immutable.";
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
}
if (not is_integer(val)) {
std::stringstream ss;
ss << "Config " << name_ << "(" << val << ") must be a integer.";
return ConfigStatus(SetReturn::INVALID, ss.str());
}
int64_t value = std::stoll(val);
if (not boundary_check<int64_t>(value, lower_bound_, upper_bound_)) {
std::stringstream ss;
ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_
<< "].";
return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
}
std::string err;
if (is_valid_fn_ && not is_valid_fn_(value, err))
return ConfigStatus(SetReturn::INVALID, err);
int64_t prev = *config_;
*config_ = value;
if (update && update_fn_ && not update_fn_(value, prev, err)) {
*config_ = prev;
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
}
return ConfigStatus(SetReturn::SUCCESS, "");
} catch (std::exception& e) {
return ConfigStatus(SetReturn::EXCEPTION, e.what());
} catch (...) {
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
}
}
std::string
IntegerConfig::Get() {
assertm(inited_, "uninitialized");
return std::to_string(*config_);
}
FloatingConfig::FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound,
double upper_bound, double* config, double default_value,
std::function<bool(double val, std::string& err)> is_valid_fn,
std::function<bool(double val, double prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
lower_bound_(lower_bound),
upper_bound_(upper_bound),
default_value_(default_value),
is_valid_fn_(std::move(is_valid_fn)),
update_fn_(std::move(update_fn)) {
}
void
FloatingConfig::Init() {
BaseConfig::Init();
assert(config_ != nullptr);
*config_ = default_value_;
}
ConfigStatus
FloatingConfig::Set(const std::string& val, bool update) {
assertm(inited_, "uninitialized");
try {
if (update and not modifiable_) {
std::stringstream ss;
ss << "Config " << name_ << " is immutable.";
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
}
double value = std::stod(val);
if (not boundary_check<double>(value, lower_bound_, upper_bound_)) {
std::stringstream ss;
ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_
<< "].";
return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
}
std::string err;
if (is_valid_fn_ && not is_valid_fn_(value, err))
return ConfigStatus(SetReturn::INVALID, err);
double prev = *config_;
*config_ = value;
if (update && update_fn_ && not update_fn_(value, prev, err)) {
*config_ = prev;
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
}
return ConfigStatus(SetReturn::SUCCESS, "");
} catch (std::exception& e) {
return ConfigStatus(SetReturn::EXCEPTION, e.what());
} catch (...) {
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
}
}
std::string
FloatingConfig::Get() {
assertm(inited_, "uninitialized");
return std::to_string(*config_);
}
SizeConfig::SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn)
: BaseConfig(name, alias, modifiable),
config_(config),
lower_bound_(lower_bound),
upper_bound_(upper_bound),
default_value_(default_value),
is_valid_fn_(std::move(is_valid_fn)),
update_fn_(std::move(update_fn)) {
}
void
SizeConfig::Init() {
BaseConfig::Init();
assert(config_ != nullptr);
*config_ = default_value_;
}
ConfigStatus
SizeConfig::Set(const std::string& val, bool update) {
assertm(inited_, "uninitialized");
try {
if (update and not modifiable_) {
std::stringstream ss;
ss << "Config " << name_ << " is immutable.";
return ConfigStatus(SetReturn::IMMUTABLE, ss.str());
}
std::string err;
int64_t value = parse_bytes(val, err);
if (not err.empty()) {
return ConfigStatus(SetReturn::INVALID, err);
}
if (not boundary_check<int64_t>(value, lower_bound_, upper_bound_)) {
std::stringstream ss;
ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << " Byte, " << upper_bound_
<< " Byte].";
return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str());
}
if (is_valid_fn_ && not is_valid_fn_(value, err)) {
return ConfigStatus(SetReturn::INVALID, err);
}
int64_t prev = *config_;
*config_ = value;
if (update && update_fn_ && not update_fn_(value, prev, err)) {
*config_ = prev;
return ConfigStatus(SetReturn::UPDATE_FAILURE, err);
}
return ConfigStatus(SetReturn::SUCCESS, "");
} catch (std::exception& e) {
return ConfigStatus(SetReturn::EXCEPTION, e.what());
} catch (...) {
return ConfigStatus(SetReturn::UNEXPECTED, "unexpected");
}
}
std::string
SizeConfig::Get() {
assertm(inited_, "uninitialized");
return std::to_string(*config_);
}
} // namespace milvus

View File

@ -0,0 +1,235 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace milvus {
using configEnum = const std::unordered_map<std::string, int64_t>;
std::vector<std::string>
OptionValue(const configEnum& ce);
enum SetReturn {
SUCCESS = 1,
IMMUTABLE,
ENUM_VALUE_NOTFOUND,
INVALID,
OUT_OF_RANGE,
UPDATE_FAILURE,
EXCEPTION,
UNEXPECTED,
};
struct ConfigStatus {
ConfigStatus(SetReturn sr, std::string msg) : set_return(sr), message(std::move(msg)) {
}
SetReturn set_return;
std::string message;
};
class BaseConfig {
public:
BaseConfig(const char* name, const char* alias, bool modifiable);
virtual ~BaseConfig() = default;
public:
bool inited_ = false;
const char* name_;
const char* alias_;
const bool modifiable_;
public:
virtual void
Init();
virtual ConfigStatus
Set(const std::string& value, bool update) = 0;
virtual std::string
Get() = 0;
};
using BaseConfigPtr = std::shared_ptr<BaseConfig>;
class BoolConfig : public BaseConfig {
public:
BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value,
std::function<bool(bool val, std::string& err)> is_valid_fn,
std::function<bool(bool val, bool prev, std::string& err)> update_fn);
private:
bool* config_;
const bool default_value_;
std::function<bool(bool val, std::string& err)> is_valid_fn_;
std::function<bool(bool val, bool prev, std::string& err)> update_fn_;
public:
void
Init() override;
ConfigStatus
Set(const std::string& value, bool update) override;
std::string
Get() override;
};
class StringConfig : public BaseConfig {
public:
StringConfig(const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value,
std::function<bool(const std::string& val, std::string& err)> is_valid_fn,
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn);
private:
std::string* config_;
const char* default_value_;
std::function<bool(const std::string& val, std::string& err)> is_valid_fn_;
std::function<bool(const std::string& val, const std::string& prev, std::string& err)> update_fn_;
public:
void
Init() override;
ConfigStatus
Set(const std::string& value, bool update) override;
std::string
Get() override;
};
class EnumConfig : public BaseConfig {
public:
EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config,
int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
private:
int64_t* config_;
configEnum* enum_value_;
const int64_t default_value_;
std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;
public:
void
Init() override;
ConfigStatus
Set(const std::string& value, bool update) override;
std::string
Get() override;
};
class IntegerConfig : public BaseConfig {
public:
IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value,
std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
private:
int64_t* config_;
int64_t lower_bound_;
int64_t upper_bound_;
const int64_t default_value_;
std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;
public:
void
Init() override;
ConfigStatus
Set(const std::string& value, bool update) override;
std::string
Get() override;
};
class FloatingConfig : public BaseConfig {
public:
FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, double upper_bound,
double* config, double default_value, std::function<bool(double val, std::string& err)> is_valid_fn,
std::function<bool(double val, double prev, std::string& err)> update_fn);
private:
double* config_;
double lower_bound_;
double upper_bound_;
const double default_value_;
std::function<bool(double val, std::string& err)> is_valid_fn_;
std::function<bool(double val, double prev, std::string& err)> update_fn_;
public:
void
Init() override;
ConfigStatus
Set(const std::string& value, bool update) override;
std::string
Get() override;
};
class SizeConfig : public BaseConfig {
public:
SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound,
int64_t* config, int64_t default_value, std::function<bool(int64_t val, std::string& err)> is_valid_fn,
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn);
private:
int64_t* config_;
int64_t lower_bound_;
int64_t upper_bound_;
const int64_t default_value_;
std::function<bool(int64_t val, std::string& err)> is_valid_fn_;
std::function<bool(int64_t val, int64_t prev, std::string& err)> update_fn_;
public:
void
Init() override;
ConfigStatus
Set(const std::string& value, bool update) override;
std::string
Get() override;
};
#define CreateBoolConfig(name, modifiable, config_addr, default, is_valid, update) \
std::make_shared<BoolConfig>(name, nullptr, modifiable, config_addr, (default), is_valid, update)
#define CreateStringConfig(name, modifiable, config_addr, default, is_valid, update) \
std::make_shared<StringConfig>(name, nullptr, modifiable, config_addr, (default), is_valid, update)
#define CreateEnumConfig(name, modifiable, enumd, config_addr, default, is_valid, update) \
std::make_shared<EnumConfig>(name, nullptr, modifiable, enumd, config_addr, (default), is_valid, update)
#define CreateIntegerConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
std::make_shared<IntegerConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
is_valid, update)
#define CreateFloatingConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
std::make_shared<FloatingConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
is_valid, update)
#define CreateSizeConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \
std::make_shared<SizeConfig>(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \
is_valid, update)
} // namespace milvus

View File

@ -0,0 +1,493 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <cstring>
#include <functional>
#include "config/ServerConfig.h"
#include "gtest/gtest.h"
namespace milvus {
#define _MODIFIABLE (true)
#define _IMMUTABLE (false)
template <typename T>
class Utils {
public:
bool
validate_fn(const T& value, std::string& err) {
validate_value = value;
return true;
}
bool
update_fn(const T& value, const T& prev, std::string& err) {
new_value = value;
prev_value = prev;
return true;
}
protected:
T validate_value;
T new_value;
T prev_value;
};
/* ValidBoolConfigTest */
class ValidBoolConfigTest : public testing::Test, public Utils<bool> {
protected:
};
TEST_F(ValidBoolConfigTest, init_load_update_get_test) {
auto validate = std::bind(&ValidBoolConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidBoolConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
bool bool_value = true;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, false, validate, update);
ASSERT_EQ(bool_value, true);
ASSERT_EQ(bool_config->modifiable_, true);
bool_config->Init();
ASSERT_EQ(bool_value, false);
ASSERT_EQ(bool_config->Get(), "false");
{
// now `bool_value` is `false`, calling Set(update=false) to set it to `true`, but not notify update_fn()
validate_value = false;
new_value = false;
prev_value = true;
ConfigStatus status(SetReturn::SUCCESS, "");
status = bool_config->Set("true", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(bool_value, true);
EXPECT_EQ(bool_config->Get(), "true");
// expect change
EXPECT_EQ(validate_value, true);
// expect not change
EXPECT_EQ(new_value, false);
EXPECT_EQ(prev_value, true);
}
{
// now `bool_value` is `true`, calling Set(update=true) to set it to `false`, will notify update_fn()
validate_value = true;
new_value = true;
prev_value = false;
ConfigStatus status(SetReturn::SUCCESS, "");
status = bool_config->Set("false", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(bool_value, false);
EXPECT_EQ(bool_config->Get(), "false");
// expect change
EXPECT_EQ(validate_value, false);
EXPECT_EQ(new_value, false);
EXPECT_EQ(prev_value, true);
}
}
/* ValidStringConfigTest */
class ValidStringConfigTest : public testing::Test, public Utils<std::string> {
protected:
};
TEST_F(ValidStringConfigTest, init_load_update_get_test) {
auto validate = std::bind(&ValidStringConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidStringConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", validate, update);
ASSERT_EQ(string_value, "");
ASSERT_EQ(string_config->modifiable_, true);
string_config->Init();
ASSERT_EQ(string_value, "Magic");
ASSERT_EQ(string_config->Get(), "Magic");
{
// now `string_value` is `Magic`, calling Set(update=false) to set it to `cigaM`, but not notify update_fn()
validate_value = "";
new_value = "";
prev_value = "";
ConfigStatus status(SetReturn::SUCCESS, "");
status = string_config->Set("cigaM", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(string_value, "cigaM");
EXPECT_EQ(string_config->Get(), "cigaM");
// expect change
EXPECT_EQ(validate_value, "cigaM");
// expect not change
EXPECT_EQ(new_value, "");
EXPECT_EQ(prev_value, "");
}
{
// now `string_value` is `cigaM`, calling Set(update=true) to set it to `Check`, will notify update_fn()
validate_value = "";
new_value = "";
prev_value = "";
ConfigStatus status(SetReturn::SUCCESS, "");
status = string_config->Set("Check", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(string_value, "Check");
EXPECT_EQ(string_config->Get(), "Check");
// expect change
EXPECT_EQ(validate_value, "Check");
EXPECT_EQ(new_value, "Check");
EXPECT_EQ(prev_value, "cigaM");
}
}
/* ValidIntegerConfigTest */
class ValidIntegerConfigTest : public testing::Test, public Utils<int64_t> {
protected:
};
TEST_F(ValidIntegerConfigTest, init_load_update_get_test) {
auto validate = std::bind(&ValidIntegerConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidIntegerConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
int64_t integer_value = 0;
auto integer_config = CreateIntegerConfig("i", _MODIFIABLE, -100, 100, &integer_value, 42, validate, update);
ASSERT_EQ(integer_value, 0);
ASSERT_EQ(integer_config->modifiable_, true);
integer_config->Init();
ASSERT_EQ(integer_value, 42);
ASSERT_EQ(integer_config->Get(), "42");
{
// now `integer_value` is `42`, calling Set(update=false) to set it to `24`, but not notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = integer_config->Set("24", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(integer_value, 24);
EXPECT_EQ(integer_config->Get(), "24");
// expect change
EXPECT_EQ(validate_value, 24);
// expect not change
EXPECT_EQ(new_value, 0);
EXPECT_EQ(prev_value, 0);
}
{
// now `integer_value` is `24`, calling Set(update=true) to set it to `36`, will notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = integer_config->Set("36", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(integer_value, 36);
EXPECT_EQ(integer_config->Get(), "36");
// expect change
EXPECT_EQ(validate_value, 36);
EXPECT_EQ(new_value, 36);
EXPECT_EQ(prev_value, 24);
}
}
/* ValidFloatingConfigTest */
class ValidFloatingConfigTest : public testing::Test, public Utils<double> {
protected:
};
TEST_F(ValidFloatingConfigTest, init_load_update_get_test) {
auto validate =
std::bind(&ValidFloatingConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidFloatingConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
double floating_value = 0.0;
auto floating_config = CreateFloatingConfig("f", _MODIFIABLE, -10.0, 10.0, &floating_value, 3.14, validate, update);
ASSERT_FLOAT_EQ(floating_value, 0.0);
ASSERT_EQ(floating_config->modifiable_, true);
floating_config->Init();
ASSERT_FLOAT_EQ(floating_value, 3.14);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 3.14);
{
// now `floating_value` is `3.14`, calling Set(update=false) to set it to `6.22`, but not notify update_fn()
validate_value = 0.0;
new_value = 0.0;
prev_value = 0.0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = floating_config->Set("6.22", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_FLOAT_EQ(floating_value, 6.22);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 6.22);
// expect change
ASSERT_FLOAT_EQ(validate_value, 6.22);
// expect not change
ASSERT_FLOAT_EQ(new_value, 0.0);
ASSERT_FLOAT_EQ(prev_value, 0.0);
}
{
// now `integer_value` is `6.22`, calling Set(update=true) to set it to `-3.14`, will notify update_fn()
validate_value = 0.0;
new_value = 0.0;
prev_value = 0.0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = floating_config->Set("-3.14", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_FLOAT_EQ(floating_value, -3.14);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), -3.14);
// expect change
ASSERT_FLOAT_EQ(validate_value, -3.14);
ASSERT_FLOAT_EQ(new_value, -3.14);
ASSERT_FLOAT_EQ(prev_value, 6.22);
}
}
/* ValidEnumConfigTest */
class ValidEnumConfigTest : public testing::Test, public Utils<int64_t> {
protected:
};
// template <>
// int64_t Utils<int64_t>::validate_value = 0;
// template <>
// int64_t Utils<int64_t>::new_value = 0;
// template <>
// int64_t Utils<int64_t>::prev_value = 0;
TEST_F(ValidEnumConfigTest, init_load_update_get_test) {
auto validate = std::bind(&ValidEnumConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidEnumConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value = 0;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, validate, update);
ASSERT_EQ(enum_value, 0);
ASSERT_EQ(enum_config->modifiable_, true);
enum_config->Init();
ASSERT_EQ(enum_value, 1);
ASSERT_EQ(enum_config->Get(), "a");
{
// now `enum_value` is `a`, calling Set(update=false) to set it to `b`, but not notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = enum_config->Set("b", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_EQ(enum_value, 2);
ASSERT_EQ(enum_config->Get(), "b");
// expect change
ASSERT_EQ(validate_value, 2);
// expect not change
ASSERT_EQ(new_value, 0);
ASSERT_EQ(prev_value, 0);
}
{
// now `enum_value` is `b`, calling Set(update=true) to set it to `c`, will notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = enum_config->Set("c", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
ASSERT_EQ(enum_value, 3);
ASSERT_EQ(enum_config->Get(), "c");
// expect change
ASSERT_EQ(validate_value, 3);
ASSERT_EQ(new_value, 3);
ASSERT_EQ(prev_value, 2);
}
}
/* ValidSizeConfigTest */
class ValidSizeConfigTest : public testing::Test, public Utils<int64_t> {
protected:
};
// template <>
// int64_t Utils<int64_t>::validate_value = 0;
// template <>
// int64_t Utils<int64_t>::new_value = 0;
// template <>
// int64_t Utils<int64_t>::prev_value = 0;
TEST_F(ValidSizeConfigTest, init_load_update_get_test) {
auto validate = std::bind(&ValidSizeConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2);
auto update = std::bind(&ValidSizeConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3);
int64_t size_value = 0;
auto size_config = CreateSizeConfig("i", _MODIFIABLE, 0, 1024 * 1024, &size_value, 1024, validate, update);
ASSERT_EQ(size_value, 0);
ASSERT_EQ(size_config->modifiable_, true);
size_config->Init();
ASSERT_EQ(size_value, 1024);
ASSERT_EQ(size_config->Get(), "1024");
{
// now `size_value` is `1024`, calling Set(update=false) to set it to `4096`, but not notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = size_config->Set("4096", false);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(size_value, 4096);
EXPECT_EQ(size_config->Get(), "4096");
// expect change
EXPECT_EQ(validate_value, 4096);
// expect not change
EXPECT_EQ(new_value, 0);
EXPECT_EQ(prev_value, 0);
}
{
// now `size_value` is `4096`, calling Set(update=true) to set it to `256kb`, will notify update_fn()
validate_value = 0;
new_value = 0;
prev_value = 0;
ConfigStatus status(SetReturn::SUCCESS, "");
status = size_config->Set("256kb", true);
EXPECT_EQ(status.set_return, SetReturn::SUCCESS);
EXPECT_EQ(size_value, 256 * 1024);
EXPECT_EQ(size_config->Get(), "262144");
// expect change
EXPECT_EQ(validate_value, 262144);
EXPECT_EQ(new_value, 262144);
EXPECT_EQ(prev_value, 4096);
}
}
class ValidTest : public testing::Test {
protected:
configEnum family{
{"ipv4", 1},
{"ipv6", 2},
};
struct Server {
bool running = true;
std::string hostname;
int64_t family = 0;
int64_t port = 0;
double uptime = 0;
};
Server server;
protected:
void
SetUp() override {
config_list = {
CreateBoolConfig("running", true, &server.running, true, nullptr, nullptr),
CreateStringConfig("hostname", true, &server.hostname, "Magic", nullptr, nullptr),
CreateEnumConfig("socket_family", false, &family, &server.family, 2, nullptr, nullptr),
CreateIntegerConfig("port", true, 1024, 65535, &server.port, 19530, nullptr, nullptr),
CreateFloatingConfig("uptime", true, 0, 9999.0, &server.uptime, 0, nullptr, nullptr),
};
}
void
TearDown() override {
}
protected:
void
Init() {
for (auto& config : config_list) {
config->Init();
}
}
void
Load() {
std::unordered_map<std::string, std::string> config_file{
{"running", "false"},
};
for (auto& c : config_file) Set(c.first, c.second, false);
}
void
Set(const std::string& name, const std::string& value, bool update = true) {
for (auto& config : config_list) {
if (std::strcmp(name.c_str(), config->name_) == 0) {
config->Set(value, update);
return;
}
}
throw "Config " + name + " not found.";
}
std::string
Get(const std::string& name) {
for (auto& config : config_list) {
if (std::strcmp(name.c_str(), config->name_) == 0) {
return config->Get();
}
}
throw "Config " + name + " not found.";
}
std::vector<BaseConfigPtr> config_list;
};
} // namespace milvus

View File

@ -0,0 +1,861 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "config/ServerConfig.h"
#include "gtest/gtest.h"
namespace milvus {
#define _MODIFIABLE (true)
#define _IMMUTABLE (false)
template <typename T>
class Utils {
public:
static bool
valid_check_failure(const T& value, std::string& err) {
err = "Value is invalid.";
return false;
}
static bool
update_failure(const T& value, const T& prev, std::string& err) {
err = "Update is failure";
return false;
}
static bool
valid_check_raise_string(const T& value, std::string& err) {
throw "string exception";
}
static bool
valid_check_raise_exception(const T& value, std::string& err) {
throw std::bad_alloc();
}
};
/* BoolConfigTest */
class BoolConfigTest : public testing::Test, public Utils<bool> {};
TEST_F(BoolConfigTest, nullptr_init_test) {
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, nullptr, true, nullptr, nullptr);
ASSERT_DEATH(bool_config->Init(), "nullptr");
}
TEST_F(BoolConfigTest, init_twice_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
ASSERT_DEATH(
{
bool_config->Init();
bool_config->Init();
},
"initialized");
}
TEST_F(BoolConfigTest, non_init_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
ASSERT_DEATH(bool_config->Set("false", true), "uninitialized");
ASSERT_DEATH(bool_config->Get(), "uninitialized");
}
TEST_F(BoolConfigTest, immutable_update_test) {
bool bool_value = false;
auto bool_config = CreateBoolConfig("b", _IMMUTABLE, &bool_value, true, nullptr, nullptr);
bool_config->Init();
ASSERT_EQ(bool_value, true);
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(bool_value, true);
}
TEST_F(BoolConfigTest, set_invalid_value_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set(" false", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
status = bool_config->Set("false ", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
status = bool_config->Set("afalse", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
status = bool_config->Set("falsee", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
status = bool_config->Set("abcdefg", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
status = bool_config->Set("123456", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
status = bool_config->Set("", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
}
TEST_F(BoolConfigTest, valid_check_fail_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_failure, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("123456", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(bool_config->Get(), "true");
}
TEST_F(BoolConfigTest, update_fail_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, update_failure);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(bool_config->Get(), "true");
}
TEST_F(BoolConfigTest, string_exception_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_string, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(bool_config->Get(), "true");
}
TEST_F(BoolConfigTest, standard_exception_test) {
bool bool_value;
auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_exception, nullptr);
bool_config->Init();
ConfigStatus status(SUCCESS, "");
status = bool_config->Set("false", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(bool_config->Get(), "true");
}
/* StringConfigTest */
class StringConfigTest : public testing::Test, public Utils<std::string> {};
TEST_F(StringConfigTest, nullptr_init_test) {
auto string_config = CreateStringConfig("s", true, nullptr, "Magic", nullptr, nullptr);
ASSERT_DEATH(string_config->Init(), "nullptr");
}
TEST_F(StringConfigTest, init_twice_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr);
ASSERT_DEATH(
{
string_config->Init();
string_config->Init();
},
"initialized");
}
TEST_F(StringConfigTest, non_init_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr);
ASSERT_DEATH(string_config->Set("value", true), "uninitialized");
ASSERT_DEATH(string_config->Get(), "uninitialized");
}
TEST_F(StringConfigTest, immutable_update_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _IMMUTABLE, &string_value, "Magic", nullptr, nullptr);
string_config->Init();
ASSERT_EQ(string_value, "Magic");
ConfigStatus status(SUCCESS, "");
status = string_config->Set("cigaM", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(string_value, "Magic");
}
TEST_F(StringConfigTest, valid_check_fail_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_failure, nullptr);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("123456", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(string_config->Get(), "Magic");
}
TEST_F(StringConfigTest, update_fail_test) {
std::string string_value;
auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, update_failure);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("Mi", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(string_config->Get(), "Magic");
}
TEST_F(StringConfigTest, string_exception_test) {
std::string string_value;
auto string_config =
CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_string, nullptr);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("any", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(string_config->Get(), "Magic");
}
TEST_F(StringConfigTest, standard_exception_test) {
std::string string_value;
auto string_config =
CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_exception, nullptr);
string_config->Init();
ConfigStatus status(SUCCESS, "");
status = string_config->Set("any", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(string_config->Get(), "Magic");
}
/* IntegerConfigTest */
class IntegerConfigTest : public testing::Test, public Utils<int64_t> {};
TEST_F(IntegerConfigTest, nullptr_init_test) {
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, nullptr, 19530, nullptr, nullptr);
ASSERT_DEATH(integer_config->Init(), "nullptr");
}
TEST_F(IntegerConfigTest, init_twice_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
ASSERT_DEATH(
{
integer_config->Init();
integer_config->Init();
},
"initialized");
}
TEST_F(IntegerConfigTest, non_init_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
ASSERT_DEATH(integer_config->Set("42", true), "uninitialized");
ASSERT_DEATH(integer_config->Get(), "uninitialized");
}
TEST_F(IntegerConfigTest, immutable_update_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", _IMMUTABLE, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
integer_config->Init();
ASSERT_EQ(integer_value, 19530);
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(integer_value, 19530);
}
TEST_F(IntegerConfigTest, set_invalid_value_test) {
}
TEST_F(IntegerConfigTest, valid_check_fail_test) {
int64_t integer_value;
auto integer_config =
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_failure, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "19530");
}
TEST_F(IntegerConfigTest, update_fail_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, update_failure);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(integer_config->Get(), "19530");
}
TEST_F(IntegerConfigTest, string_exception_test) {
int64_t integer_value;
auto integer_config =
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_string, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(integer_config->Get(), "19530");
}
TEST_F(IntegerConfigTest, standard_exception_test) {
int64_t integer_value;
auto integer_config =
CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_exception, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("2048", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(integer_config->Get(), "19530");
}
TEST_F(IntegerConfigTest, out_of_range_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr);
integer_config->Init();
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("1023", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(integer_config->Get(), "19530");
}
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("65536", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(integer_config->Get(), "19530");
}
}
TEST_F(IntegerConfigTest, invalid_bound_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 100, 0, &integer_value, 50, nullptr, nullptr);
integer_config->Init();
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("30", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(integer_config->Get(), "50");
}
TEST_F(IntegerConfigTest, invalid_format_test) {
int64_t integer_value;
auto integer_config = CreateIntegerConfig("i", true, 0, 100, &integer_value, 50, nullptr, nullptr);
integer_config->Init();
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("3-0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("30-", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("+30", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("a30", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("30a", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
{
ConfigStatus status(SUCCESS, "");
status = integer_config->Set("3a0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(integer_config->Get(), "50");
}
}
/* FloatingConfigTest */
class FloatingConfigTest : public testing::Test, public Utils<double> {};
TEST_F(FloatingConfigTest, nullptr_init_test) {
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, nullptr, 4.5, nullptr, nullptr);
ASSERT_DEATH(floating_config->Init(), "nullptr");
}
TEST_F(FloatingConfigTest, init_twice_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
ASSERT_DEATH(
{
floating_config->Init();
floating_config->Init();
},
"initialized");
}
TEST_F(FloatingConfigTest, non_init_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
ASSERT_DEATH(floating_config->Set("3.14", true), "uninitialized");
ASSERT_DEATH(floating_config->Get(), "uninitialized");
}
TEST_F(FloatingConfigTest, immutable_update_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", _IMMUTABLE, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr);
floating_config->Init();
ASSERT_FLOAT_EQ(floating_value, 4.5);
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
TEST_F(FloatingConfigTest, set_invalid_value_test) {
}
TEST_F(FloatingConfigTest, valid_check_fail_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_failure, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
TEST_F(FloatingConfigTest, update_fail_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, update_failure);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
TEST_F(FloatingConfigTest, string_exception_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_string, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
TEST_F(FloatingConfigTest, standard_exception_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("1.23", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
TEST_F(FloatingConfigTest, out_of_range_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr);
floating_config->Init();
{
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("0.99", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
{
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("10.00", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
}
TEST_F(FloatingConfigTest, invalid_bound_test) {
double floating_value;
auto floating_config =
CreateFloatingConfig("f", true, 9.9, 1.0, &floating_value, 4.5, valid_check_raise_exception, nullptr);
floating_config->Init();
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("6.0", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
TEST_F(FloatingConfigTest, DISABLED_invalid_format_test) {
double floating_value;
auto floating_config = CreateFloatingConfig("f", true, 1.0, 100.0, &floating_value, 4.5, nullptr, nullptr);
floating_config->Init();
{
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("6.0.1", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
{
ConfigStatus status(SUCCESS, "");
status = floating_config->Set("6a0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5);
}
}
/* EnumConfigTest */
class EnumConfigTest : public testing::Test, public Utils<int64_t> {};
TEST_F(EnumConfigTest, nullptr_init_test) {
configEnum testEnum{
{"e", 1},
};
int64_t testEnumValue;
auto enum_config_1 = CreateEnumConfig("e", _MODIFIABLE, &testEnum, nullptr, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config_1->Init(), "nullptr");
auto enum_config_2 = CreateEnumConfig("e", _MODIFIABLE, nullptr, &testEnumValue, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config_2->Init(), "nullptr");
auto enum_config_3 = CreateEnumConfig("e", _MODIFIABLE, nullptr, nullptr, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config_3->Init(), "nullptr");
}
TEST_F(EnumConfigTest, init_twice_test) {
configEnum testEnum{
{"e", 1},
};
int64_t enum_value;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr);
ASSERT_DEATH(
{
enum_config->Init();
enum_config->Init();
},
"initialized");
}
TEST_F(EnumConfigTest, non_init_test) {
configEnum testEnum{
{"e", 1},
};
int64_t enum_value;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config->Set("e", true), "uninitialized");
ASSERT_DEATH(enum_config->Get(), "uninitialized");
}
TEST_F(EnumConfigTest, immutable_update_test) {
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value = 0;
auto enum_config = CreateEnumConfig("e", _IMMUTABLE, &testEnum, &enum_value, 1, nullptr, nullptr);
enum_config->Init();
ASSERT_EQ(enum_value, 1);
ConfigStatus status(SUCCESS, "");
status = enum_config->Set("b", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(enum_value, 1);
}
TEST_F(EnumConfigTest, set_invalid_value_check) {
configEnum testEnum{
{"a", 1},
};
int64_t enum_value = 0;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, nullptr, nullptr);
enum_config->Init();
ConfigStatus status(SUCCESS, "");
status = enum_config->Set("b", true);
ASSERT_EQ(status.set_return, SetReturn::ENUM_VALUE_NOTFOUND);
ASSERT_EQ(enum_config->Get(), "a");
}
TEST_F(EnumConfigTest, empty_enum_test) {
configEnum testEnum{};
int64_t enum_value;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr);
ASSERT_DEATH(enum_config->Init(), "empty");
}
TEST_F(EnumConfigTest, valid_check_fail_test) {
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_failure, nullptr);
enum_config->Init();
ConfigStatus status(SUCCESS, "");
status = enum_config->Set("b", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(enum_config->Get(), "a");
}
TEST_F(EnumConfigTest, update_fail_test) {
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, nullptr, update_failure);
enum_config->Init();
ConfigStatus status(SUCCESS, "");
status = enum_config->Set("b", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(enum_config->Get(), "a");
}
TEST_F(EnumConfigTest, string_exception_test) {
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value;
auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_raise_string, nullptr);
enum_config->Init();
ConfigStatus status(SUCCESS, "");
status = enum_config->Set("b", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(enum_config->Get(), "a");
}
TEST_F(EnumConfigTest, standard_exception_test) {
configEnum testEnum{
{"a", 1},
{"b", 2},
{"c", 3},
};
int64_t enum_value;
auto enum_config =
CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_raise_exception, nullptr);
enum_config->Init();
ConfigStatus status(SUCCESS, "");
status = enum_config->Set("b", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(enum_config->Get(), "a");
}
/* SizeConfigTest */
class SizeConfigTest : public testing::Test, public Utils<int64_t> {};
TEST_F(SizeConfigTest, nullptr_init_test) {
auto size_config = CreateSizeConfig("i", true, 1024, 4096, nullptr, 2048, nullptr, nullptr);
ASSERT_DEATH(size_config->Init(), "nullptr");
}
TEST_F(SizeConfigTest, init_twice_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
ASSERT_DEATH(
{
size_config->Init();
size_config->Init();
},
"initialized");
}
TEST_F(SizeConfigTest, non_init_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
ASSERT_DEATH(size_config->Set("3000", true), "uninitialized");
ASSERT_DEATH(size_config->Get(), "uninitialized");
}
TEST_F(SizeConfigTest, immutable_update_test) {
int64_t size_value = 0;
auto size_config = CreateSizeConfig("i", _IMMUTABLE, 1024, 4096, &size_value, 2048, nullptr, nullptr);
size_config->Init();
ASSERT_EQ(size_value, 2048);
ConfigStatus status(SUCCESS, "");
status = size_config->Set("3000", true);
ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE);
ASSERT_EQ(size_value, 2048);
}
TEST_F(SizeConfigTest, set_invalid_value_test) {
}
TEST_F(SizeConfigTest, valid_check_fail_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_failure, nullptr);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("3000", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(size_config->Get(), "2048");
}
TEST_F(SizeConfigTest, update_fail_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, update_failure);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("3000", true);
ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE);
ASSERT_EQ(size_config->Get(), "2048");
}
TEST_F(SizeConfigTest, string_exception_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_raise_string, nullptr);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("3000", true);
ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED);
ASSERT_EQ(size_config->Get(), "2048");
}
TEST_F(SizeConfigTest, standard_exception_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_raise_exception, nullptr);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("3000", true);
ASSERT_EQ(status.set_return, SetReturn::EXCEPTION);
ASSERT_EQ(size_config->Get(), "2048");
}
TEST_F(SizeConfigTest, out_of_range_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
size_config->Init();
{
ConfigStatus status(SUCCESS, "");
status = size_config->Set("1023", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(size_config->Get(), "2048");
}
{
ConfigStatus status(SUCCESS, "");
status = size_config->Set("4097", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(size_config->Get(), "2048");
}
}
TEST_F(SizeConfigTest, negative_integer_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("-3KB", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(size_config->Get(), "2048");
}
TEST_F(SizeConfigTest, invalid_bound_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 100, 0, &size_value, 50, nullptr, nullptr);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("30", true);
ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE);
ASSERT_EQ(size_config->Get(), "50");
}
TEST_F(SizeConfigTest, invalid_unit_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
size_config->Init();
ConfigStatus status(SUCCESS, "");
status = size_config->Set("1 TB", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(size_config->Get(), "2048");
}
TEST_F(SizeConfigTest, invalid_format_test) {
int64_t size_value;
auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr);
size_config->Init();
{
ConfigStatus status(SUCCESS, "");
status = size_config->Set("a10GB", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(size_config->Get(), "2048");
}
{
ConfigStatus status(SUCCESS, "");
status = size_config->Set("200*0", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(size_config->Get(), "2048");
}
{
ConfigStatus status(SUCCESS, "");
status = size_config->Set("10AB", true);
ASSERT_EQ(status.set_return, SetReturn::INVALID);
ASSERT_EQ(size_config->Get(), "2048");
}
}
} // namespace milvus

View File

@ -0,0 +1,62 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "config/ServerConfig.h"
namespace milvus {
std::mutex config_mutex;
std::mutex&
GetConfigMutex() {
return config_mutex;
}
ServerConfig config;
std::vector<std::string>
ParsePreloadCollection(const std::string& str) {
std::stringstream ss(str);
std::vector<std::string> collections;
std::string collection;
while (std::getline(ss, collection, ',')) {
collections.push_back(collection);
}
return collections;
}
std::vector<int64_t>
ParseGPUDevices(const std::string& str) {
std::stringstream ss(str);
std::vector<int64_t> devices;
std::unordered_set<int64_t> device_set;
std::string device;
while (std::getline(ss, device, ',')) {
if (device.length() < 4) {
/* Invalid format string */
return {};
}
device_set.insert(std::stoll(device.substr(3)));
}
for (auto dev : device_set) devices.push_back(dev);
return devices;
}
} // namespace milvus

View File

@ -0,0 +1,112 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "config/ConfigType.h"
namespace milvus {
extern std::mutex&
GetConfigMutex();
template <typename T>
class ConfigValue {
public:
explicit ConfigValue(T init_value) : value(std::move(init_value)) {
}
const T&
operator()() {
std::lock_guard<std::mutex> lock(GetConfigMutex());
return value;
}
public:
T value;
};
enum ClusterRole {
RW = 1,
RO,
};
enum SimdType {
AUTO = 1,
SSE,
AVX2,
AVX512,
};
const configEnum SimdMap{
{"auto", SimdType::AUTO},
{"sse", SimdType::SSE},
{"avx2", SimdType::AVX2},
{"avx512", SimdType::AVX512},
};
struct ServerConfig {
using String = ConfigValue<std::string>;
using Bool = ConfigValue<bool>;
using Integer = ConfigValue<int64_t>;
using Floating = ConfigValue<double>;
String timezone{"unknown"};
struct Network {
String address{"unknown"};
Integer port{0};
} network;
struct Pulsar{
String address{"localhost"};
Integer port{6650};
}pulsar;
struct Engine {
Integer build_index_threshold{4096};
Integer search_combine_nq{0};
Integer use_blas_threshold{0};
Integer omp_thread_num{0};
Integer simd_type{0};
} engine;
struct Tracing {
String json_config_path{"unknown"};
} tracing;
struct Logs {
String level{"unknown"};
struct Trace {
Bool enable{false};
} trace;
String path{"unknown"};
Integer max_log_file_size{0};
Integer log_rotate_num{0};
} logs;
};
extern ServerConfig config;
extern std::mutex _config_mutex;
std::vector<std::string>
ParsePreloadCollection(const std::string&);
std::vector<int64_t>
ParseGPUDevices(const std::string&);
} // namespace milvus

View File

@ -0,0 +1,19 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include "config/ServerConfig.h"
TEST(ServerConfigTest, parse_invalid_devices) {
auto collections = milvus::ParseGPUDevices("gpu0,gpu1");
ASSERT_EQ(collections.size(), 0);
}

View File

@ -13,5 +13,4 @@ add_library(milvus_dog_segment SHARED
)
#add_dependencies( segment sqlite mysqlpp )
target_link_libraries(milvus_dog_segment tbb milvus_utils pthread)
target_link_libraries(milvus_dog_segment tbb utils pthread knowhere log)

View File

@ -38,6 +38,7 @@ namespace milvus::dog_segment {
template <typename Type>
using FixedVector = std::vector<Type>;
constexpr int64_t DefaultElementPerChunk = 32 * 1024;
template <typename Type>
class ThreadSafeVector {
@ -91,7 +92,7 @@ class VectorBase {
virtual void set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
};
template <typename Type, bool is_scalar = false, ssize_t ElementsPerChunk = 32 * 1024>
template <typename Type, bool is_scalar = false, ssize_t ElementsPerChunk = DefaultElementPerChunk>
class ConcurrentVector : public VectorBase {
public:
// constants

View File

@ -0,0 +1,37 @@
#pragma once
#include "AckResponder.h"
#include "SegmentDefs.h"
namespace milvus::dog_segment {
struct DeletedRecord {
std::atomic<int64_t> reserved = 0;
AckResponder ack_responder_;
ConcurrentVector<Timestamp, true> timestamps_;
ConcurrentVector<idx_t, true> uids_;
struct TmpBitmap {
// Just for query
int64_t del_barrier = 0;
std::vector<bool> bitmap;
};
std::shared_ptr<TmpBitmap> lru_;
std::shared_mutex shared_mutex_;
DeletedRecord(): lru_(std::make_shared<TmpBitmap>()) {}
auto get_lru_entry() {
std::shared_lock lck(shared_mutex_);
return lru_;
}
void insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry) {
std::lock_guard lck(shared_mutex_);
if(new_entry->del_barrier <= lru_->del_barrier) {
// DO NOTHING
return;
}
lru_ = std::move(new_entry);
}
};
}

View File

@ -1,56 +1,55 @@
// #include "IndexMeta.h"
// #include <mutex>
// #include <cassert>
// namespace milvus::dog_segment {
//
// Status
// IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
// IndexConfig config) {
// Entry entry{
// index_name,
// field_name,
// type,
// mode,
// std::move(config)
// };
// VerifyEntry(entry);
//
// if (entries_.count(index_name)) {
// throw std::invalid_argument("duplicate index_name");
// }
// // TODO: support multiple indexes for single field
// assert(!lookups_.count(field_name));
// lookups_[field_name] = index_name;
// entries_[index_name] = std::move(entry);
//
// return Status::OK();
// }
//
// Status
// IndexMeta::DropEntry(const std::string& index_name) {
// assert(entries_.count(index_name));
// auto entry = std::move(entries_[index_name]);
// if(lookups_[entry.field_name] == index_name) {
// lookups_.erase(entry.field_name);
// }
// return Status::OK();
// }
//
// void IndexMeta::VerifyEntry(const Entry &entry) {
// auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode);
// if(!is_mode_valid) {
// throw std::invalid_argument("invalid mode");
// }
//
// auto& schema = *schema_;
// auto& field_meta = schema[entry.index_name];
// // TODO checking
// if(field_meta.is_vector()) {
// assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
// } else {
// assert(false);
// }
// }
//
// } // namespace milvus::dog_segment
//
#include "IndexMeta.h"
#include <mutex>
#include <cassert>
namespace milvus::dog_segment {
Status
IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
IndexConfig config) {
Entry entry{
index_name,
field_name,
type,
mode,
std::move(config)
};
VerifyEntry(entry);
if (entries_.count(index_name)) {
throw std::invalid_argument("duplicate index_name");
}
// TODO: support multiple indexes for single field
assert(!lookups_.count(field_name));
lookups_[field_name] = index_name;
entries_[index_name] = std::move(entry);
return Status::OK();
}
Status
IndexMeta::DropEntry(const std::string& index_name) {
assert(entries_.count(index_name));
auto entry = std::move(entries_[index_name]);
if(lookups_[entry.field_name] == index_name) {
lookups_.erase(entry.field_name);
}
return Status::OK();
}
void IndexMeta::VerifyEntry(const Entry &entry) {
auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode);
if(!is_mode_valid) {
throw std::invalid_argument("invalid mode");
}
auto& schema = *schema_;
auto& field_meta = schema[entry.field_name];
// TODO checking
if(field_meta.is_vector()) {
assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ);
} else {
assert(false);
}
}
} // namespace milvus::dog_segment

View File

@ -3,55 +3,56 @@
//#include <shared_mutex>
//
//#include "SegmentDefs.h"
//#include "knowhere/index/IndexType.h"
//
// #include "dog_segment/SegmentBase.h"
#include "dog_segment/SegmentDefs.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/common/Config.h"
#include <map>
#include <memory>
class IndexMeta;
namespace milvus::dog_segment {
//// TODO: this is
//class IndexMeta {
// public:
// IndexMeta(SchemaPtr schema) : schema_(schema) {
// }
// using IndexType = knowhere::IndexType;
// using IndexMode = knowhere::IndexMode;
// using IndexConfig = knowhere::Config;
//
// struct Entry {
// std::string index_name;
// std::string field_name;
// IndexType type;
// IndexMode mode;
// IndexConfig config;
// };
//
// Status
// AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
// IndexConfig config);
//
// Status
// DropEntry(const std::string& index_name);
//
// const std::map<std::string, Entry>&
// get_entries() {
// return entries_;
// }
//
// const Entry& lookup_by_field(const std::string& field_name) {
// auto index_name = lookups_.at(field_name);
// return entries_.at(index_name);
// }
// private:
// void
// VerifyEntry(const Entry& entry);
//
// private:
// SchemaPtr schema_;
// std::map<std::string, Entry> entries_; // index_name => Entry
// std::map<std::string, std::string> lookups_; // field_name => index_name
//};
//
// TODO: this is
class IndexMeta {
public:
IndexMeta(SchemaPtr schema) : schema_(schema) {
}
using IndexType = knowhere::IndexType;
using IndexMode = knowhere::IndexMode;
using IndexConfig = knowhere::Config;
struct Entry {
std::string index_name;
std::string field_name;
IndexType type;
IndexMode mode;
IndexConfig config;
};
Status
AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode,
IndexConfig config);
Status
DropEntry(const std::string& index_name);
const std::map<std::string, Entry>&
get_entries() {
return entries_;
}
const Entry& lookup_by_field(const std::string& field_name) {
auto index_name = lookups_.at(field_name);
return entries_.at(index_name);
}
private:
void
VerifyEntry(const Entry& entry);
private:
SchemaPtr schema_;
std::map<std::string, Entry> entries_; // index_name => Entry
std::map<std::string, std::string> lookups_; // field_name => index_name
};
using IndexMetaPtr = std::shared_ptr<IndexMeta>;
//
} // namespace milvus::dog_segment
//

View File

@ -7,6 +7,7 @@
#include "utils/Types.h"
// #include "knowhere/index/Index.h"
#include "utils/Status.h"
#include "dog_segment/IndexMeta.h"
namespace milvus::dog_segment {
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
@ -152,6 +153,13 @@ class Schema {
return sizeof_infos_;
}
std::optional<int> get_offset(const std::string& field_name) {
if(!offsets_.count(field_name)) {
return std::nullopt;
} else {
return offsets_[field_name];
}
}
const FieldMeta&
operator[](const std::string& field_name) const {
@ -160,7 +168,6 @@ class Schema {
auto offset = offset_iter->second;
return (*this)[offset];
}
private:
// this is where data holds
std::vector<FieldMeta> fields_;
@ -173,5 +180,6 @@ class Schema {
};
using SchemaPtr = std::shared_ptr<Schema>;
using idx_t = int64_t;
} // namespace milvus::dog_segment

View File

@ -5,6 +5,10 @@
#include <thread>
#include <queue>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
namespace milvus::dog_segment {
int
TestABI() {
@ -13,12 +17,31 @@ TestABI() {
std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema, IndexMetaPtr remote_index_meta) {
if (remote_index_meta == nullptr) {
auto index_meta = std::make_shared<IndexMeta>(schema);
auto dim = schema->operator[]("fakevec").get_dim();
// TODO: this is merge of query conf and insert conf
// TODO: should be splitted into multiple configs
auto conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0},
};
index_meta->AddEntry("fakeindex", "fakevec", knowhere::IndexEnum::INDEX_FAISS_IVFPQ,
knowhere::IndexMode::MODE_CPU, conf);
remote_index_meta = index_meta;
}
auto segment = std::make_unique<SegmentNaive>(schema, remote_index_meta);
return segment;
}
SegmentNaive::Record::Record(const Schema& schema) : uids_(1), timestamps_(1) {
for (auto& field : schema) {
SegmentNaive::Record::Record(const Schema &schema) : uids_(1), timestamps_(1) {
for (auto &field : schema) {
if (field.is_vector()) {
assert(field.get_data_type() == DataType::VECTOR_FLOAT);
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<float>>(field.get_dim()));
@ -41,31 +64,32 @@ SegmentNaive::PreDelete(int64_t size) {
return reserved_begin;
}
auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier) -> std::shared_ptr<DeletedRecord::TmpBitmap> {
auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp,
int64_t insert_barrier) -> std::shared_ptr<DeletedRecord::TmpBitmap> {
auto old = deleted_record_.get_lru_entry();
if(old->del_barrier == del_barrier) {
if (old->del_barrier == del_barrier) {
return old;
}
auto current = std::make_shared<DeletedRecord::TmpBitmap>(*old);
auto& vec = current->bitmap;
auto &vec = current->bitmap;
if(del_barrier < old->del_barrier) {
for(auto del_index = del_barrier; del_index < old->del_barrier; ++del_index) {
if (del_barrier < old->del_barrier) {
for (auto del_index = del_barrier; del_index < old->del_barrier; ++del_index) {
// get uid in delete logs
auto uid = deleted_record_.uids_[del_index];
// map uid to corrensponding offsets, select the max one, which should be the target
// the max one should be closest to query_timestamp, so the delete log should refer to it
int64_t the_offset = -1;
auto [iter_b, iter_e] = uid2offset_.equal_range(uid);
for(auto iter = iter_b; iter != iter_e; ++iter) {
auto[iter_b, iter_e] = uid2offset_.equal_range(uid);
for (auto iter = iter_b; iter != iter_e; ++iter) {
auto offset = iter->second;
if(record_.timestamps_[offset] < query_timestamp) {
if (record_.timestamps_[offset] < query_timestamp) {
assert(offset < vec.size());
the_offset = std::max(the_offset, offset);
}
}
// if not found, skip
if(the_offset == -1) {
if (the_offset == -1) {
continue;
}
// otherwise, clear the flag
@ -74,29 +98,29 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
return current;
} else {
vec.resize(insert_barrier);
for(auto del_index = old->del_barrier; del_index < del_barrier; ++del_index) {
for (auto del_index = old->del_barrier; del_index < del_barrier; ++del_index) {
// get uid in delete logs
auto uid = deleted_record_.uids_[del_index];
// map uid to corrensponding offsets, select the max one, which should be the target
// the max one should be closest to query_timestamp, so the delete log should refer to it
int64_t the_offset = -1;
auto [iter_b, iter_e] = uid2offset_.equal_range(uid);
for(auto iter = iter_b; iter != iter_e; ++iter) {
auto[iter_b, iter_e] = uid2offset_.equal_range(uid);
for (auto iter = iter_b; iter != iter_e; ++iter) {
auto offset = iter->second;
if(offset >= insert_barrier){
if (offset >= insert_barrier) {
continue;
}
if(offset >= vec.size()) {
if (offset >= vec.size()) {
continue;
}
if(record_.timestamps_[offset] < query_timestamp) {
if (record_.timestamps_[offset] < query_timestamp) {
assert(offset < vec.size());
the_offset = std::max(the_offset, offset);
}
}
// if not found, skip
if(the_offset == -1) {
if (the_offset == -1) {
continue;
}
@ -109,11 +133,11 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times
}
Status
SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw,
const DogDataChunk& entities_raw) {
SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, const Timestamp *timestamps_raw,
const DogDataChunk &entities_raw) {
assert(entities_raw.count == size);
assert(entities_raw.sizeof_per_row == schema_->get_total_sizeof());
auto raw_data = reinterpret_cast<const char*>(entities_raw.raw_data);
auto raw_data = reinterpret_cast<const char *>(entities_raw.raw_data);
// std::vector<char> entities(raw_data, raw_data + size * len_per_row);
auto len_per_row = entities_raw.sizeof_per_row;
@ -138,7 +162,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_r
std::vector<Timestamp> timestamps(size);
// #pragma omp parallel for
for (int index = 0; index < size; ++index) {
auto [t, uid, order_index] = ordering[index];
auto[t, uid, order_index] = ordering[index];
timestamps[index] = t;
uids[index] = uid;
for (int fid = 0; fid < schema_->size(); ++fid) {
@ -156,7 +180,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_r
record_.entity_vec_[fid]->set_data_raw(reserved_begin, entities[fid].data(), size);
}
for(int i = 0; i < uids.size(); ++i) {
for (int i = 0; i < uids.size(); ++i) {
auto uid = uids[i];
// NOTE: this must be the last step, cannot be put above
uid2offset_.insert(std::make_pair(uid, reserved_begin + i));
@ -197,7 +221,8 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_r
}
Status
SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw) {
SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_raw,
const Timestamp *timestamps_raw) {
std::vector<std::tuple<Timestamp, idx_t>> ordering;
ordering.resize(size);
// #pragma omp parallel for
@ -209,7 +234,7 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_r
std::vector<Timestamp> timestamps(size);
// #pragma omp parallel for
for (int index = 0; index < size; ++index) {
auto [t, uid] = ordering[index];
auto[t, uid] = ordering[index];
timestamps[index] = t;
uids[index] = uid;
}
@ -228,44 +253,15 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_r
// TODO: remove mock
Status
SegmentNaive::QueryImpl(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) {
throw std::runtime_error("unimplemented");
// auto ack_count = ack_count_.load();
// assert(query == nullptr);
// assert(schema_->size() >= 1);
// const auto& field = schema_->operator[](0);
// assert(field.get_data_type() == DataType::VECTOR_FLOAT);
// assert(field.get_name() == "fakevec");
// auto dim = field.get_dim();
// // assume query vector is [0, 0, ..., 0]
// std::vector<float> query_vector(dim, 0);
// auto& target_vec = record.entity_vecs_[0];
// int current_index = -1;
// float min_diff = std::numeric_limits<float>::max();
// for (int index = 0; index < ack_count; ++index) {
// float diff = 0;
// int offset = index * dim;
// for (auto d = 0; d < dim; ++d) {
// auto v = target_vec[offset + d] - query_vector[d];
// diff += v * v;
// }
// if (diff < min_diff) {
// min_diff = diff;
// current_index = index;
// }
// }
// QueryResult query_result;
// query_result.row_num_ = 1;
// query_result.result_distances_.push_back(min_diff);
// query_result.result_ids_.push_back(record.uids_[current_index]);
// query_result.data_chunk_ = nullptr;
// result = std::move(query_result);
// return Status::OK();
SegmentNaive::QueryImpl(const query::QueryPtr &query, Timestamp timestamp, QueryResult &result) {
// assert(query);
throw std::runtime_error("unimplemnted");
}
template<typename RecordType>
int64_t get_barrier(const RecordType& record, Timestamp timestamp) {
auto& vec = record.timestamps_;
int64_t get_barrier(const RecordType &record, Timestamp timestamp) {
auto &vec = record.timestamps_;
int64_t beg = 0;
int64_t end = record.ack_responder_.GetAck();
while (beg < end) {
@ -280,11 +276,11 @@ int64_t get_barrier(const RecordType& record, Timestamp timestamp) {
}
Status
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) {
// TODO: enable delete
// TODO: enable index
if(query_info == nullptr) {
if (query_info == nullptr) {
query_info = std::make_shared<query::Query>();
query_info->field_name = "fakevec";
query_info->topK = 10;
@ -294,12 +290,12 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
std::default_random_engine e(42);
std::uniform_real_distribution<> dis(0.0, 1.0);
query_info->query_raw_data.resize(query_info->num_queries * dim);
for(auto& x: query_info->query_raw_data) {
for (auto &x: query_info->query_raw_data) {
x = dis(e);
}
}
auto& field = schema_->operator[](query_info->field_name);
auto &field = schema_->operator[](query_info->field_name);
assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
@ -308,7 +304,7 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
auto barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, barrier);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, barrier);
if (!bitmap_holder) {
throw std::runtime_error("fuck");
@ -316,13 +312,13 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
auto bitmap = &bitmap_holder->bitmap;
if(topK > barrier) {
if (topK > barrier) {
topK = barrier;
}
auto get_L2_distance = [dim](const float* a, const float* b) {
auto get_L2_distance = [dim](const float *a, const float *b) {
float L2_distance = 0;
for(auto i = 0; i < dim; ++i) {
for (auto i = 0; i < dim; ++i) {
auto d = a[i] - b[i];
L2_distance += d * d;
}
@ -332,18 +328,18 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
std::vector<std::priority_queue<std::pair<float, int>>> records(num_queries);
// TODO: optimize
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_[0]);
for(int64_t i = 0; i < barrier; ++i) {
if(i < bitmap->size() && bitmap->at(i)) {
for (int64_t i = 0; i < barrier; ++i) {
if (i < bitmap->size() && bitmap->at(i)) {
continue;
}
auto element = vec_ptr->get_element(i);
for(auto query_id = 0; query_id < num_queries; ++query_id) {
for (auto query_id = 0; query_id < num_queries; ++query_id) {
auto query_blob = query_info->query_raw_data.data() + query_id * dim;
auto dis = get_L2_distance(query_blob, element);
auto& record = records[query_id];
if(record.size() < topK) {
auto &record = records[query_id];
if (record.size() < topK) {
record.emplace(dis, i);
} else if(record.top().first > dis) {
} else if (record.top().first > dis) {
record.emplace(dis, i);
record.pop();
}
@ -359,11 +355,11 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
result.result_ids_.resize(row_num);
result.result_distances_.resize(row_num);
for(int q_id = 0; q_id < num_queries; ++q_id) {
for (int q_id = 0; q_id < num_queries; ++q_id) {
// reverse
for(int i = 0; i < topK; ++i) {
for (int i = 0; i < topK; ++i) {
auto dst_id = topK - 1 - i + q_id * topK;
auto [dis, offset] = records[q_id].top();
auto[dis, offset] = records[q_id].top();
records[q_id].pop();
result.result_ids_[dst_id] = record_.uids_[offset];
result.result_distances_[dst_id] = dis;
@ -384,39 +380,62 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
Status
SegmentNaive::Close() {
if (this->record_.reserved != this->record_.ack_responder_.GetAck()) {
std::runtime_error("insert not ready");
}
if (this->deleted_record_.reserved != this->record_.ack_responder_.GetAck()) {
std::runtime_error("delete not ready");
}
state_ = SegmentState::Closed;
return Status::OK();
// auto src_record = GetMutableRecord();
// assert(src_record);
//
// auto dst_record = std::make_shared<ImmutableRecord>(schema_->size());
//
// auto data_move = [](auto& dst_vec, const auto& src_vec) {
// assert(dst_vec.size() == 0);
// dst_vec.insert(dst_vec.begin(), src_vec.begin(), src_vec.end());
// };
// data_move(dst_record->uids_, src_record->uids_);
// data_move(dst_record->timestamps_, src_record->uids_);
//
// assert(src_record->entity_vecs_.size() == schema_->size());
// assert(dst_record->entity_vecs_.size() == schema_->size());
// for (int i = 0; i < schema_->size(); ++i) {
// data_move(dst_record->entity_vecs_[i], src_record->entity_vecs_[i]);
// }
// bool ready_old = false;
// record_immutable_ = dst_record;
// ready_immutable_.compare_exchange_strong(ready_old, true);
// if (ready_old) {
// throw std::logic_error("Close may be called twice, with potential race condition");
// }
// return Status::OK();
}
template<typename Type>
knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry) {
auto offset_opt = schema_->get_offset(entry.field_name);
assert(offset_opt.has_value());
auto offset = offset_opt.value();
auto field = (*schema_)[offset];
auto dim = field.get_dim();
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
auto chunk_size = record_.uids_.chunk_size();
auto &uids = record_.uids_;
auto entities = record_.get_vec_entity<float>(offset);
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
auto &uids_chunk = uids.get_chunk(chunk_id);
auto &entities_chunk = entities->get_chunk(chunk_id);
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDatasetWithIds(count, dim, entities_chunk.data(), uids_chunk.data()));
}
for (auto &ds: datasets) {
indexing->Train(ds, entry.config);
}
for (auto &ds: datasets) {
indexing->Add(ds, entry.config);
}
return indexing;
}
Status
SegmentNaive::BuildIndex() {
throw std::runtime_error("unimplemented");
// assert(ready_immutable_);
// throw std::runtime_error("unimplemented");
for (auto&[index_name, entry]: index_meta_->get_entries()) {
assert(entry.index_name == index_name);
const auto &field = (*schema_)[entry.field_name];
if (field.is_vector()) {
assert(field.get_data_type() == engine::DataType::VECTOR_FLOAT);
auto index_ptr = BuildVecIndexImpl<float>(entry);
indexings_[index_name] = index_ptr;
} else {
throw std::runtime_error("unimplemented");
}
}
return Status::OK();
}
} // namespace milvus::dog_segment

View File

@ -4,6 +4,7 @@
#include <tbb/concurrent_vector.h>
#include <shared_mutex>
#include <knowhere/index/vector_index/VecIndex.h>
#include "AckResponder.h"
#include "ConcurrentVector.h"
@ -11,7 +12,7 @@
// #include "knowhere/index/structured_index/StructuredIndex.h"
#include "query/GeneralQuery.h"
#include "utils/Status.h"
using idx_t = int64_t;
#include "dog_segment/DeletedRecord.h"
namespace milvus::dog_segment {
struct ColumnBasedDataChunk {
@ -87,6 +88,27 @@ class SegmentNaive : public SegmentBase {
return Status::OK();
}
public:
ssize_t
get_row_count() const override {
return record_.ack_responder_.GetAck();
}
SegmentState
get_state() const override {
return state_.load(std::memory_order_relaxed);
}
ssize_t
get_deleted_count() const override {
return 0;
}
public:
friend std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema, IndexMetaPtr index_meta);
explicit SegmentNaive(SchemaPtr schema, IndexMetaPtr index_meta)
: schema_(schema), index_meta_(index_meta), record_(*schema) {
}
private:
struct MutableRecord {
ConcurrentVector<uint64_t> uids_;
@ -103,79 +125,31 @@ class SegmentNaive : public SegmentBase {
ConcurrentVector<idx_t, true> uids_;
std::vector<std::shared_ptr<VectorBase>> entity_vec_;
Record(const Schema& schema);
template<typename Type>
auto get_vec_entity(int offset) {
return std::static_pointer_cast<ConcurrentVector<Type>>(entity_vec_[offset]);
}
};
tbb::concurrent_unordered_multimap<idx_t, int64_t> uid2offset_;
struct DeletedRecord {
std::atomic<int64_t> reserved = 0;
AckResponder ack_responder_;
ConcurrentVector<Timestamp, true> timestamps_;
ConcurrentVector<idx_t, true> uids_;
struct TmpBitmap {
// Just for query
int64_t del_barrier = 0;
std::vector<char> bitmap;
};
std::shared_ptr<TmpBitmap> lru_;
std::shared_mutex shared_mutex_;
DeletedRecord(): lru_(std::make_shared<TmpBitmap>()) {}
auto get_lru_entry() {
std::shared_lock lck(shared_mutex_);
return lru_;
}
void insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry) {
std::lock_guard lck(shared_mutex_);
if(new_entry->del_barrier <= lru_->del_barrier) {
// DO NOTHING
return;
}
lru_ = std::move(new_entry);
}
};
std::shared_ptr<DeletedRecord::TmpBitmap> get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier);
Status
QueryImpl(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results);
public:
ssize_t
get_row_count() const override {
return record_.ack_responder_.GetAck();
}
SegmentState
get_state() const override {
return state_.load(std::memory_order_relaxed);
}
ssize_t
get_deleted_count() const override {
return 0;
}
public:
friend std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema, IndexMetaPtr index_meta);
explicit SegmentNaive(SchemaPtr schema, IndexMetaPtr index_meta)
: schema_(schema), index_meta_(index_meta), record_(*schema) {
}
template<typename Type>
knowhere::IndexPtr BuildVecIndexImpl(const IndexMeta::Entry& entry);
private:
SchemaPtr schema_;
IndexMetaPtr index_meta_;
std::atomic<SegmentState> state_ = SegmentState::Open;
Record record_;
DeletedRecord deleted_record_;
// tbb::concurrent_unordered_map<uint64_t, int> internal_indexes_;
// std::shared_ptr<MutableRecord> record_mutable_;
// // to determined that if immutable data if available
// std::shared_ptr<ImmutableRecord> record_immutable_ = nullptr;
// std::unordered_map<int, knowhere::VecIndexPtr> vec_indexings_;
// // TODO: scalar indexing
// // std::unordered_map<int, knowhere::IndexPtr> scalar_indexings_;
// tbb::concurrent_unordered_multimap<int, Timestamp> delete_logs_;
IndexMetaPtr index_meta_;
std::unordered_map<std::string, knowhere::IndexPtr> indexings_; // index_name => indexing
};
} // namespace milvus::dog_segment

View File

@ -3,6 +3,9 @@
#include "SegmentBase.h"
#include "segment_c.h"
#include "Partition.h"
#include <knowhere/index/vector_index/VecIndex.h>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
CSegmentBase
@ -46,9 +49,6 @@ Insert(CSegmentBase c_segment,
dataChunk.count = count;
auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk);
// TODO: delete print
// std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
return res.code();
}
@ -58,7 +58,7 @@ PreInsert(CSegmentBase c_segment, long int size) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
// TODO: delete print
// std::cout << "PreInsert segment " << std::endl;
std::cout << "PreInsert segment " << std::endl;
return segment->PreInsert(size);
}
@ -81,7 +81,7 @@ PreDelete(CSegmentBase c_segment, long int size) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
// TODO: delete print
// std::cout << "PreDelete segment " << std::endl;
std::cout << "PreDelete segment " << std::endl;
return segment->PreDelete(size);
}
@ -114,6 +114,13 @@ Close(CSegmentBase c_segment) {
return status.code();
}
int
BuildIndex(CSegmentBase c_segment) {
auto segment = (milvus::dog_segment::SegmentBase*)c_segment;
auto status = segment->BuildIndex();
return status.code();
}
bool
IsOpened(CSegmentBase c_segment) {

View File

@ -50,6 +50,9 @@ Search(CSegmentBase c_segment,
int
Close(CSegmentBase c_segment);
int
BuildIndex(CSegmentBase c_segment);
bool
IsOpened(CSegmentBase c_segment);

1
core/src/index/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
cmake_build

View File

@ -0,0 +1,84 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.12)
message(STATUS "------------------------------KNOWHERE-----------------------------------")
message(STATUS "Building using CMake version: ${CMAKE_VERSION}")
project(knowhere LANGUAGES C CXX)
set(CMAKE_CXX_STANDARD 17)
# if no build build type is specified, default to release builds
if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif (NOT CMAKE_BUILD_TYPE)
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
message(STATUS "building milvus_engine on x86 architecture")
set(KNOWHERE_BUILD_ARCH x86_64)
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "(ppc)")
message(STATUS "building milvus_engine on ppc architecture")
set(KNOWHERE_BUILD_ARCH ppc64le)
else ()
message(WARNING "unknown processor type")
message(WARNING "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}")
set(KNOWHERE_BUILD_ARCH unknown)
endif ()
if (CMAKE_BUILD_TYPE STREQUAL "Release")
set(BUILD_TYPE "release")
else ()
set(BUILD_TYPE "debug")
endif ()
message(STATUS "Build type = ${BUILD_TYPE}")
set(INDEX_SOURCE_DIR ${PROJECT_SOURCE_DIR})
set(INDEX_BINARY_DIR ${PROJECT_BINARY_DIR})
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${INDEX_SOURCE_DIR}/cmake")
include(ExternalProject)
include(DefineOptionsCore)
include(BuildUtilsCore)
using_ccache_if_defined( KNOWHERE_USE_CCACHE )
message(STATUS "Building Knowhere CPU version")
if (MILVUS_SUPPORT_SPTAG)
message(STATUS "Building Knowhere with SPTAG supported")
add_compile_definitions("MILVUS_SUPPORT_SPTAG")
endif ()
include(ThirdPartyPackagesCore)
if (CMAKE_BUILD_TYPE STREQUAL "Release")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp")
else ()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp")
endif ()
add_subdirectory(knowhere)
if (BUILD_COVERAGE STREQUAL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage")
endif ()
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
#if (KNOWHERE_BUILD_TESTS)
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
# add_subdirectory(unittest)
#endif ()
config_summary()

View File

@ -0,0 +1,112 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "index/archive/KnowhereResource.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
#include "config/ServerConfig.h"
#include "faiss/FaissHook.h"
// #include "scheduler/Utils.h"
#include "utils/Error.h"
#include "utils/Log.h"
// #include <fiu/fiu-local.h>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
namespace milvus {
namespace engine {
constexpr int64_t M_BYTE = 1024 * 1024;
Status
KnowhereResource::Initialize() {
auto simd_type = config.engine.simd_type();
if (simd_type == SimdType::AVX512) {
faiss::faiss_use_avx512 = true;
faiss::faiss_use_avx2 = false;
faiss::faiss_use_sse = false;
} else if (simd_type == SimdType::AVX2) {
faiss::faiss_use_avx512 = false;
faiss::faiss_use_avx2 = true;
faiss::faiss_use_sse = false;
} else if (simd_type == SimdType::SSE) {
faiss::faiss_use_avx512 = false;
faiss::faiss_use_avx2 = false;
faiss::faiss_use_sse = true;
} else {
faiss::faiss_use_avx512 = true;
faiss::faiss_use_avx2 = true;
faiss::faiss_use_sse = true;
}
std::string cpu_flag;
if (faiss::hook_init(cpu_flag)) {
std::cout << "FAISS hook " << cpu_flag << std::endl;
LOG_ENGINE_DEBUG_ << "FAISS hook " << cpu_flag;
} else {
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
}
#ifdef MILVUS_GPU_VERSION
bool enable_gpu = config.gpu.enable();
// fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
if (!enable_gpu) {
return Status::OK();
}
struct GpuResourceSetting {
int64_t pinned_memory = 256 * M_BYTE;
int64_t temp_memory = 256 * M_BYTE;
int64_t resource_num = 2;
};
using GpuResourcesArray = std::map<int64_t, GpuResourceSetting>;
GpuResourcesArray gpu_resources;
// get build index gpu resource
std::vector<int64_t> build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices());
for (auto gpu_id : build_index_gpus) {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
}
// get search gpu resource
std::vector<int64_t> search_gpus = ParseGPUDevices(config.gpu.search_devices());
for (auto& gpu_id : search_gpus) {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
}
// init gpu resources
for (auto& gpu_resource : gpu_resources) {
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(gpu_resource.first, gpu_resource.second.pinned_memory,
gpu_resource.second.temp_memory,
gpu_resource.second.resource_num);
}
#endif
return Status::OK();
}
Status
KnowhereResource::Finalize() {
#ifdef MILVUS_GPU_VERSION
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
#endif
return Status::OK();
}
} // namespace engine
} // namespace milvus

View File

@ -0,0 +1,29 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "utils/Status.h"
namespace milvus {
namespace engine {
class KnowhereResource {
public:
static Status
Initialize();
static Status
Finalize();
};
} // namespace engine
} // namespace milvus

77
core/src/index/build.sh Executable file
View File

@ -0,0 +1,77 @@
#!/bin/bash
BUILD_TYPE="Debug"
BUILD_UNITTEST="OFF"
INSTALL_PREFIX=$(pwd)/cmake_build
MAKE_CLEAN="OFF"
PROFILING="OFF"
while getopts "p:d:t:uhrcgm" arg
do
case $arg in
t)
BUILD_TYPE=$OPTARG # BUILD_TYPE
;;
u)
echo "Build and run unittest cases" ;
BUILD_UNITTEST="ON";
;;
p)
INSTALL_PREFIX=$OPTARG
;;
r)
if [[ -d cmake_build ]]; then
rm ./cmake_build -r
MAKE_CLEAN="ON"
fi
;;
g)
PROFILING="ON"
;;
h) # help
echo "
parameter:
-t: build type(default: Debug)
-u: building unit test options(default: OFF)
-p: install prefix(default: $(pwd)/knowhere)
-r: remove previous build directory(default: OFF)
-g: profiling(default: OFF)
usage:
./build.sh -t \${BUILD_TYPE} [-u] [-h] [-g] [-r] [-c]
"
exit 0
;;
?)
echo "unknown argument"
exit 1
;;
esac
done
if [[ ! -d cmake_build ]]; then
mkdir cmake_build
MAKE_CLEAN="ON"
fi
cd cmake_build
CUDA_COMPILER=/usr/local/cuda/bin/nvcc
if [[ ${MAKE_CLEAN} == "ON" ]]; then
CMAKE_CMD="cmake -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX}
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \
-DMILVUS_ENABLE_PROFILING=${PROFILING} \
../"
echo ${CMAKE_CMD}
${CMAKE_CMD}
make clean
fi
make -j 8 || exit 1
make install || exit 1

View File

@ -0,0 +1,218 @@
# Define a function that check last file modification
function(Check_Last_Modify cache_check_lists_file_path working_dir last_modified_commit_id)
if (EXISTS "${working_dir}")
if (EXISTS "${cache_check_lists_file_path}")
set(GIT_LOG_SKIP_NUM 0)
set(_MATCH_ALL ON CACHE BOOL "Match all")
set(_LOOP_STATUS ON CACHE BOOL "Whether out of loop")
file(STRINGS ${cache_check_lists_file_path} CACHE_IGNORE_TXT)
while (_LOOP_STATUS)
foreach (_IGNORE_ENTRY ${CACHE_IGNORE_TXT})
if (NOT _IGNORE_ENTRY MATCHES "^[^#]+")
continue()
endif ()
set(_MATCH_ALL OFF)
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --name-status --pretty= WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE CHANGE_FILES)
if (NOT CHANGE_FILES STREQUAL "")
string(REPLACE "\n" ";" _CHANGE_FILES ${CHANGE_FILES})
foreach (_FILE_ENTRY ${_CHANGE_FILES})
string(REGEX MATCH "[^ \t]+$" _FILE_NAME ${_FILE_ENTRY})
execute_process(COMMAND sh -c "echo ${_FILE_NAME} | grep ${_IGNORE_ENTRY}" RESULT_VARIABLE return_code)
if (return_code EQUAL 0)
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID)
set(${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE)
set(_LOOP_STATUS OFF)
endif ()
endforeach ()
else ()
set(_LOOP_STATUS OFF)
endif ()
endforeach ()
if (_MATCH_ALL)
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID)
set(${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE)
set(_LOOP_STATUS OFF)
endif ()
math(EXPR GIT_LOG_SKIP_NUM "${GIT_LOG_SKIP_NUM} + 1")
endwhile (_LOOP_STATUS)
else ()
execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID)
set(${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE)
endif ()
else ()
message(FATAL_ERROR "The directory ${working_dir} does not exist")
endif ()
endfunction()
# Define a function that extracts a cached package
function(ExternalProject_Use_Cache project_name package_file install_path)
message(STATUS "Will use cached package file: ${package_file}")
ExternalProject_Add(${project_name}
DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E echo
"No download step needed (using cached package)"
CONFIGURE_COMMAND ${CMAKE_COMMAND} -E echo
"No configure step needed (using cached package)"
BUILD_COMMAND ${CMAKE_COMMAND} -E echo
"No build step needed (using cached package)"
INSTALL_COMMAND ${CMAKE_COMMAND} -E echo
"No install step needed (using cached package)"
)
# We want our tar files to contain the Install/<package> prefix (not for any
# very special reason, only for consistency and so that we can identify them
# in the extraction logs) which means that we must extract them in the
# binary (top-level build) directory to have them installed in the right
# place for subsequent ExternalProjects to pick them up. It seems that the
# only way to control the working directory is with Add_Step!
ExternalProject_Add_Step(${project_name} extract
ALWAYS 1
COMMAND
${CMAKE_COMMAND} -E echo
"Extracting ${package_file} to ${install_path}"
COMMAND
${CMAKE_COMMAND} -E tar xzf ${package_file} ${install_path}
WORKING_DIRECTORY ${INDEX_BINARY_DIR}
)
ExternalProject_Add_StepTargets(${project_name} extract)
endfunction()
# Define a function that to create a new cached package
function(ExternalProject_Create_Cache project_name package_file install_path cache_username cache_password cache_path)
if (EXISTS ${package_file})
message(STATUS "Removing existing package file: ${package_file}")
file(REMOVE ${package_file})
endif ()
string(REGEX REPLACE "(.+)/.+$" "\\1" package_dir ${package_file})
if (NOT EXISTS ${package_dir})
file(MAKE_DIRECTORY ${package_dir})
endif ()
message(STATUS "Will create cached package file: ${package_file}")
ExternalProject_Add_Step(${project_name} package
DEPENDEES install
BYPRODUCTS ${package_file}
COMMAND ${CMAKE_COMMAND} -E echo "Updating cached package file: ${package_file}"
COMMAND ${CMAKE_COMMAND} -E tar czvf ${package_file} ${install_path}
COMMAND ${CMAKE_COMMAND} -E echo "Uploading package file ${package_file} to ${cache_path}"
COMMAND curl -u${cache_username}:${cache_password} -T ${package_file} ${cache_path}
)
ExternalProject_Add_StepTargets(${project_name} package)
endfunction()
function(ADD_THIRDPARTY_LIB LIB_NAME)
set(options)
set(one_value_args SHARED_LIB STATIC_LIB)
set(multi_value_args DEPS INCLUDE_DIRECTORIES)
cmake_parse_arguments(ARG
"${options}"
"${one_value_args}"
"${multi_value_args}"
${ARGN})
if (ARG_UNPARSED_ARGUMENTS)
message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
endif ()
if (ARG_STATIC_LIB AND ARG_SHARED_LIB)
if (NOT ARG_STATIC_LIB)
message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}")
endif ()
set(AUG_LIB_NAME "${LIB_NAME}_static")
add_library(${AUG_LIB_NAME} STATIC IMPORTED)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}")
if (ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif ()
message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}")
if (ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif ()
set(AUG_LIB_NAME "${LIB_NAME}_shared")
add_library(${AUG_LIB_NAME} SHARED IMPORTED)
if (WIN32)
# Mark the ".lib" location as part of a Windows DLL
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}")
else ()
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}")
endif ()
if (ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif ()
message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}")
if (ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif ()
elseif (ARG_STATIC_LIB)
set(AUG_LIB_NAME "${LIB_NAME}_static")
add_library(${AUG_LIB_NAME} STATIC IMPORTED)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}")
if (ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif ()
message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}")
if (ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif ()
elseif (ARG_SHARED_LIB)
set(AUG_LIB_NAME "${LIB_NAME}_shared")
add_library(${AUG_LIB_NAME} SHARED IMPORTED)
if (WIN32)
# Mark the ".lib" location as part of a Windows DLL
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}")
else ()
set_target_properties(${AUG_LIB_NAME}
PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}")
endif ()
message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}")
if (ARG_DEPS)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}")
endif ()
if (ARG_INCLUDE_DIRECTORIES)
set_target_properties(${AUG_LIB_NAME}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${ARG_INCLUDE_DIRECTORIES}")
endif ()
else ()
message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}")
endif ()
endfunction()
MACRO(using_ccache_if_defined KNOWHERE_USE_CCACHE)
if (MILVUS_USE_CCACHE)
find_program(CCACHE_FOUND ccache)
if (CCACHE_FOUND)
message(STATUS "Using ccache: ${CCACHE_FOUND}")
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND})
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND})
# let ccache preserve C++ comments, because some of them may be
# meaningful to the compiler
set(ENV{CCACHE_COMMENTS} "1")
endif (CCACHE_FOUND)
endif ()
ENDMACRO(using_ccache_if_defined)

View File

@ -0,0 +1,169 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
macro(set_option_category name)
set(KNOWHERE_OPTION_CATEGORY ${name})
list(APPEND "KNOWHERE_OPTION_CATEGORIES" ${name})
endmacro()
macro(define_option name description default)
option(${name} ${description} ${default})
list(APPEND "KNOWHERE_${KNOWHERE_OPTION_CATEGORY}_OPTION_NAMES" ${name})
set("${name}_OPTION_DESCRIPTION" ${description})
set("${name}_OPTION_DEFAULT" ${default})
set("${name}_OPTION_TYPE" "bool")
endmacro()
function(list_join lst glue out)
if ("${${lst}}" STREQUAL "")
set(${out} "" PARENT_SCOPE)
return()
endif ()
list(GET ${lst} 0 joined)
list(REMOVE_AT ${lst} 0)
foreach (item ${${lst}})
set(joined "${joined}${glue}${item}")
endforeach ()
set(${out} ${joined} PARENT_SCOPE)
endfunction()
macro(define_option_string name description default)
set(${name} ${default} CACHE STRING ${description})
list(APPEND "KNOWHERE_${KNOWHERE_OPTION_CATEGORY}_OPTION_NAMES" ${name})
set("${name}_OPTION_DESCRIPTION" ${description})
set("${name}_OPTION_DEFAULT" "\"${default}\"")
set("${name}_OPTION_TYPE" "string")
set("${name}_OPTION_ENUM" ${ARGN})
list_join("${name}_OPTION_ENUM" "|" "${name}_OPTION_ENUM")
if (NOT ("${${name}_OPTION_ENUM}" STREQUAL ""))
set_property(CACHE ${name} PROPERTY STRINGS ${ARGN})
endif ()
endmacro()
#----------------------------------------------------------------------
set_option_category("Thirdparty")
set(KNOWHERE_DEPENDENCY_SOURCE_DEFAULT "BUNDLED")
define_option_string(KNOWHERE_DEPENDENCY_SOURCE
"Method to use for acquiring KNOWHERE's build dependencies"
"${KNOWHERE_DEPENDENCY_SOURCE_DEFAULT}"
"AUTO"
"BUNDLED"
"SYSTEM")
define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" OFF)
define_option(KNOWHERE_VERBOSE_THIRDPARTY_BUILD
"Show output from ExternalProjects rather than just logging to files" ON)
define_option(KNOWHERE_BOOST_USE_SHARED "Rely on boost shared libraries where relevant" OFF)
define_option(KNOWHERE_BOOST_VENDORED "Use vendored Boost instead of existing Boost. \
Note that this requires linking Boost statically" OFF)
define_option(KNOWHERE_BOOST_HEADER_ONLY "Use only BOOST headers" OFF)
define_option(KNOWHERE_WITH_ARROW "Build with ARROW" OFF)
define_option(KNOWHERE_WITH_OPENBLAS "Build with OpenBLAS library" ON)
define_option(KNOWHERE_WITH_FAISS "Build with FAISS library" ON)
define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" OFF)
define_option(FAISS_WITH_MKL "Build FAISS with MKL" OFF)
define_option(MILVUS_CUDA_ARCH "Build with CUDA arch" "DEFAULT")
#----------------------------------------------------------------------
set_option_category("Test and benchmark")
if (BUILD_UNIT_TEST)
define_option(KNOWHERE_BUILD_TESTS "Build the KNOWHERE googletest unit tests" ON)
else ()
define_option(KNOWHERE_BUILD_TESTS "Build the KNOWHERE googletest unit tests" OFF)
endif (BUILD_UNIT_TEST)
#----------------------------------------------------------------------
macro(config_summary)
message(STATUS "---------------------------------------------------------------------")
message(STATUS "KNOWHERE version: ${KNOWHERE_VERSION}")
message(STATUS)
message(STATUS "Build configuration summary:")
message(STATUS " Generator: ${CMAKE_GENERATOR}")
message(STATUS " Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS " Source directory: ${CMAKE_CURRENT_SOURCE_DIR}")
if (${CMAKE_EXPORT_COMPILE_COMMANDS})
message(
STATUS " Compile commands: ${CMAKE_CURRENT_BINARY_DIR}/compile_commands.json")
endif ()
foreach (category ${KNOWHERE_OPTION_CATEGORIES})
message(STATUS)
message(STATUS "${category} options:")
set(option_names ${KNOWHERE_${category}_OPTION_NAMES})
set(max_value_length 0)
foreach (name ${option_names})
string(LENGTH "\"${${name}}\"" value_length)
if (${max_value_length} LESS ${value_length})
set(max_value_length ${value_length})
endif ()
endforeach ()
foreach (name ${option_names})
if ("${${name}_OPTION_TYPE}" STREQUAL "string")
set(value "\"${${name}}\"")
else ()
set(value "${${name}}")
endif ()
set(default ${${name}_OPTION_DEFAULT})
set(description ${${name}_OPTION_DESCRIPTION})
string(LENGTH ${description} description_length)
if (${description_length} LESS 70)
string(
SUBSTRING
" "
${description_length} -1 description_padding)
else ()
set(description_padding "
")
endif ()
set(comment "[${name}]")
if ("${value}" STREQUAL "${default}")
set(comment "[default] ${comment}")
endif ()
if (NOT ("${${name}_OPTION_ENUM}" STREQUAL ""))
set(comment "${comment} [${${name}_OPTION_ENUM}]")
endif ()
string(
SUBSTRING "${value} "
0 ${max_value_length} value)
message(STATUS " ${description} ${description_padding} ${value} ${comment}")
endforeach ()
endforeach ()
endmacro()

View File

@ -0,0 +1,431 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# - Find Arrow (arrow/api.h, libarrow.a, libarrow.so)
# This module defines
# ARROW_FOUND, whether Arrow has been found
# ARROW_FULL_SO_VERSION, full shared object version of found Arrow "100.0.0"
# ARROW_IMPORT_LIB, path to libarrow's import library (Windows only)
# ARROW_INCLUDE_DIR, directory containing headers
# ARROW_LIBS, deprecated. Use ARROW_LIB_DIR instead
# ARROW_LIB_DIR, directory containing Arrow libraries
# ARROW_SHARED_IMP_LIB, deprecated. Use ARROW_IMPORT_LIB instead
# ARROW_SHARED_LIB, path to libarrow's shared library
# ARROW_SO_VERSION, shared object version of found Arrow such as "100"
# ARROW_STATIC_LIB, path to libarrow.a
# ARROW_VERSION, version of found Arrow
# ARROW_VERSION_MAJOR, major version of found Arrow
# ARROW_VERSION_MINOR, minor version of found Arrow
# ARROW_VERSION_PATCH, patch version of found Arrow
include(FindPkgConfig)
include(FindPackageHandleStandardArgs)
set(ARROW_SEARCH_LIB_PATH_SUFFIXES)
if(CMAKE_LIBRARY_ARCHITECTURE)
list(APPEND ARROW_SEARCH_LIB_PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}")
endif()
list(APPEND ARROW_SEARCH_LIB_PATH_SUFFIXES
"lib64"
"lib32"
"lib"
"bin")
set(ARROW_CONFIG_SUFFIXES
"_RELEASE"
"_RELWITHDEBINFO"
"_MINSIZEREL"
"_DEBUG"
"")
if(CMAKE_BUILD_TYPE)
string(TOUPPER ${CMAKE_BUILD_TYPE} ARROW_CONFIG_SUFFIX_PREFERRED)
set(ARROW_CONFIG_SUFFIX_PREFERRED "_${ARROW_CONFIG_SUFFIX_PREFERRED}")
list(INSERT ARROW_CONFIG_SUFFIXES 0 "${ARROW_CONFIG_SUFFIX_PREFERRED}")
endif()
if(NOT DEFINED ARROW_MSVC_STATIC_LIB_SUFFIX)
if(MSVC)
set(ARROW_MSVC_STATIC_LIB_SUFFIX "_static")
else()
set(ARROW_MSVC_STATIC_LIB_SUFFIX "")
endif()
endif()
# Internal function.
#
# Set shared library name for ${base_name} to ${output_variable}.
#
# Example:
# arrow_build_shared_library_name(ARROW_SHARED_LIBRARY_NAME arrow)
# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.so on Linux
# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.dylib on macOS
# # -> ARROW_SHARED_LIBRARY_NAME=arrow.dll with MSVC on Windows
# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.dll with MinGW on Windows
function(arrow_build_shared_library_name output_variable base_name)
set(${output_variable}
"${CMAKE_SHARED_LIBRARY_PREFIX}${base_name}${CMAKE_SHARED_LIBRARY_SUFFIX}"
PARENT_SCOPE)
endfunction()
# Internal function.
#
# Set import library name for ${base_name} to ${output_variable}.
# This is useful only for MSVC build. Import library is used only
# with MSVC build.
#
# Example:
# arrow_build_import_library_name(ARROW_IMPORT_LIBRARY_NAME arrow)
# # -> ARROW_IMPORT_LIBRARY_NAME=arrow on Linux (meaningless)
# # -> ARROW_IMPORT_LIBRARY_NAME=arrow on macOS (meaningless)
# # -> ARROW_IMPORT_LIBRARY_NAME=arrow.lib with MSVC on Windows
# # -> ARROW_IMPORT_LIBRARY_NAME=libarrow.dll.a with MinGW on Windows
function(arrow_build_import_library_name output_variable base_name)
set(${output_variable}
"${CMAKE_IMPORT_LIBRARY_PREFIX}${base_name}${CMAKE_IMPORT_LIBRARY_SUFFIX}"
PARENT_SCOPE)
endfunction()
# Internal function.
#
# Set static library name for ${base_name} to ${output_variable}.
#
# Example:
# arrow_build_static_library_name(ARROW_STATIC_LIBRARY_NAME arrow)
# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.a on Linux
# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.a on macOS
# # -> ARROW_STATIC_LIBRARY_NAME=arrow.lib with MSVC on Windows
# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.dll.a with MinGW on Windows
function(arrow_build_static_library_name output_variable base_name)
set(
${output_variable}
"${CMAKE_STATIC_LIBRARY_PREFIX}${base_name}${ARROW_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE)
endfunction()
# Internal function.
#
# Set macro value for ${macro_name} in ${header_content} to ${output_variable}.
#
# Example:
# arrow_extract_macro_value(version_major
# "ARROW_VERSION_MAJOR"
# "#define ARROW_VERSION_MAJOR 1.0.0")
# # -> version_major=1.0.0
function(arrow_extract_macro_value output_variable macro_name header_content)
string(REGEX MATCH "#define +${macro_name} +[^\r\n]+" macro_definition
"${header_content}")
string(REGEX
REPLACE "^#define +${macro_name} +(.+)$" "\\1" macro_value "${macro_definition}")
set(${output_variable} "${macro_value}" PARENT_SCOPE)
endfunction()
# Internal macro only for arrow_find_package.
#
# Find package in HOME.
macro(arrow_find_package_home)
find_path(${prefix}_include_dir "${header_path}"
PATHS "${home}"
PATH_SUFFIXES "include"
NO_DEFAULT_PATH)
set(include_dir "${${prefix}_include_dir}")
set(${prefix}_INCLUDE_DIR "${include_dir}" PARENT_SCOPE)
if(MSVC)
set(CMAKE_SHARED_LIBRARY_SUFFIXES_ORIGINAL ${CMAKE_FIND_LIBRARY_SUFFIXES})
# .dll isn't found by find_library with MSVC because .dll isn't included in
# CMAKE_FIND_LIBRARY_SUFFIXES.
list(APPEND CMAKE_FIND_LIBRARY_SUFFIXES "${CMAKE_SHARED_LIBRARY_SUFFIX}")
endif()
find_library(${prefix}_shared_lib
NAMES "${shared_lib_name}"
PATHS "${home}"
PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES}
NO_DEFAULT_PATH)
if(MSVC)
set(CMAKE_SHARED_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES_ORIGINAL})
endif()
set(shared_lib "${${prefix}_shared_lib}")
set(${prefix}_SHARED_LIB "${shared_lib}" PARENT_SCOPE)
if(shared_lib)
add_library(${target_shared} SHARED IMPORTED)
set_target_properties(${target_shared} PROPERTIES IMPORTED_LOCATION "${shared_lib}")
if(include_dir)
set_target_properties(${target_shared}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}")
endif()
find_library(${prefix}_import_lib
NAMES "${import_lib_name}"
PATHS "${home}"
PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES}
NO_DEFAULT_PATH)
set(import_lib "${${prefix}_import_lib}")
set(${prefix}_IMPORT_LIB "${import_lib}" PARENT_SCOPE)
if(import_lib)
set_target_properties(${target_shared} PROPERTIES IMPORTED_IMPLIB "${import_lib}")
endif()
endif()
find_library(${prefix}_static_lib
NAMES "${static_lib_name}"
PATHS "${home}"
PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES}
NO_DEFAULT_PATH)
set(static_lib "${${prefix}_static_lib}")
set(${prefix}_STATIC_LIB "${static_lib}" PARENT_SCOPE)
if(static_lib)
add_library(${target_static} STATIC IMPORTED)
set_target_properties(${target_static} PROPERTIES IMPORTED_LOCATION "${static_lib}")
if(include_dir)
set_target_properties(${target_static}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}")
endif()
endif()
endmacro()
# Internal macro only for arrow_find_package.
#
# Find package by CMake package configuration.
macro(arrow_find_package_cmake_package_configuration)
# ARROW-5575: We need to split target files for each component
if(TARGET ${target_shared} OR TARGET ${target_static})
set(${cmake_package_name}_FOUND TRUE)
else()
find_package(${cmake_package_name} CONFIG)
endif()
if(${cmake_package_name}_FOUND)
set(${prefix}_USE_CMAKE_PACKAGE_CONFIG TRUE PARENT_SCOPE)
if(TARGET ${target_shared})
foreach(suffix ${ARROW_CONFIG_SUFFIXES})
get_target_property(shared_lib ${target_shared} IMPORTED_LOCATION${suffix})
if(shared_lib)
# Remove shared library version:
# libarrow.so.100.0.0 -> libarrow.so
# Because ARROW_HOME and pkg-config approaches don't add
# shared library version.
string(REGEX
REPLACE "(${CMAKE_SHARED_LIBRARY_SUFFIX})[.0-9]+$" "\\1" shared_lib
"${shared_lib}")
set(${prefix}_SHARED_LIB "${shared_lib}" PARENT_SCOPE)
break()
endif()
endforeach()
endif()
if(TARGET ${target_static})
foreach(suffix ${ARROW_CONFIG_SUFFIXES})
get_target_property(static_lib ${target_static} IMPORTED_LOCATION${suffix})
if(static_lib)
set(${prefix}_STATIC_LIB "${static_lib}" PARENT_SCOPE)
break()
endif()
endforeach()
endif()
endif()
endmacro()
# Internal macro only for arrow_find_package.
#
# Find package by pkg-config.
macro(arrow_find_package_pkg_config)
pkg_check_modules(${prefix}_PC ${pkg_config_name})
if(${prefix}_PC_FOUND)
set(${prefix}_USE_PKG_CONFIG TRUE PARENT_SCOPE)
set(include_dir "${${prefix}_PC_INCLUDEDIR}")
set(lib_dir "${${prefix}_PC_LIBDIR}")
set(shared_lib_paths "${${prefix}_PC_LINK_LIBRARIES}")
# Use the first shared library path as the IMPORTED_LOCATION
# for ${target_shared}. This assumes that the first shared library
# path is the shared library path for this module.
list(GET shared_lib_paths 0 first_shared_lib_path)
# Use the rest shared library paths as the INTERFACE_LINK_LIBRARIES
# for ${target_shared}. This assumes that the rest shared library
# paths are dependency library paths for this module.
list(LENGTH shared_lib_paths n_shared_lib_paths)
if(n_shared_lib_paths LESS_EQUAL 1)
set(rest_shared_lib_paths)
else()
list(SUBLIST
shared_lib_paths
1
-1
rest_shared_lib_paths)
endif()
set(${prefix}_VERSION "${${prefix}_PC_VERSION}" PARENT_SCOPE)
set(${prefix}_INCLUDE_DIR "${include_dir}" PARENT_SCOPE)
set(${prefix}_SHARED_LIB "${first_shared_lib_path}" PARENT_SCOPE)
add_library(${target_shared} SHARED IMPORTED)
set_target_properties(${target_shared}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
"${include_dir}"
INTERFACE_LINK_LIBRARIES
"${rest_shared_lib_paths}"
IMPORTED_LOCATION
"${first_shared_lib_path}")
find_library(${prefix}_static_lib
NAMES "${static_lib_name}"
PATHS "${lib_dir}"
NO_DEFAULT_PATH)
set(static_lib "${${prefix}_static_lib}")
set(${prefix}_STATIC_LIB "${static_lib}" PARENT_SCOPE)
if(static_lib)
add_library(${target_static} STATIC IMPORTED)
set_target_properties(${target_static}
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}"
IMPORTED_LOCATION "${static_lib}")
endif()
endif()
endmacro()
function(arrow_find_package
prefix
home
base_name
header_path
cmake_package_name
pkg_config_name)
arrow_build_shared_library_name(shared_lib_name ${base_name})
arrow_build_import_library_name(import_lib_name ${base_name})
arrow_build_static_library_name(static_lib_name ${base_name})
set(target_shared ${base_name}_shared)
set(target_static ${base_name}_static)
if(home)
arrow_find_package_home()
set(${prefix}_FIND_APPROACH "HOME: ${home}" PARENT_SCOPE)
else()
arrow_find_package_cmake_package_configuration()
if(${cmake_package_name}_FOUND)
set(${prefix}_FIND_APPROACH
"CMake package configuration: ${cmake_package_name}"
PARENT_SCOPE)
else()
arrow_find_package_pkg_config()
set(${prefix}_FIND_APPROACH "pkg-config: ${pkg_config_name}" PARENT_SCOPE)
endif()
endif()
if(NOT include_dir)
if(TARGET ${target_shared})
get_target_property(include_dir ${target_shared} INTERFACE_INCLUDE_DIRECTORIES)
elseif(TARGET ${target_static})
get_target_property(include_dir ${target_static} INTERFACE_INCLUDE_DIRECTORIES)
endif()
endif()
if(include_dir)
set(${prefix}_INCLUDE_DIR "${include_dir}" PARENT_SCOPE)
endif()
if(shared_lib)
get_filename_component(lib_dir "${shared_lib}" DIRECTORY)
elseif(static_lib)
get_filename_component(lib_dir "${static_lib}" DIRECTORY)
else()
set(lib_dir NOTFOUND)
endif()
set(${prefix}_LIB_DIR "${lib_dir}" PARENT_SCOPE)
# For backward compatibility
set(${prefix}_LIBS "${lib_dir}" PARENT_SCOPE)
endfunction()
if(NOT "$ENV{ARROW_HOME}" STREQUAL "")
file(TO_CMAKE_PATH "$ENV{ARROW_HOME}" ARROW_HOME)
endif()
arrow_find_package(ARROW
"${ARROW_HOME}"
arrow
arrow/api.h
Arrow
arrow)
if(ARROW_HOME)
if(ARROW_INCLUDE_DIR)
file(READ "${ARROW_INCLUDE_DIR}/arrow/util/config.h" ARROW_CONFIG_H_CONTENT)
arrow_extract_macro_value(ARROW_VERSION_MAJOR "ARROW_VERSION_MAJOR"
"${ARROW_CONFIG_H_CONTENT}")
arrow_extract_macro_value(ARROW_VERSION_MINOR "ARROW_VERSION_MINOR"
"${ARROW_CONFIG_H_CONTENT}")
arrow_extract_macro_value(ARROW_VERSION_PATCH "ARROW_VERSION_PATCH"
"${ARROW_CONFIG_H_CONTENT}")
if("${ARROW_VERSION_MAJOR}" STREQUAL ""
OR "${ARROW_VERSION_MINOR}" STREQUAL ""
OR "${ARROW_VERSION_PATCH}" STREQUAL "")
set(ARROW_VERSION "0.0.0")
else()
set(ARROW_VERSION
"${ARROW_VERSION_MAJOR}.${ARROW_VERSION_MINOR}.${ARROW_VERSION_PATCH}")
endif()
arrow_extract_macro_value(ARROW_SO_VERSION_QUOTED "ARROW_SO_VERSION"
"${ARROW_CONFIG_H_CONTENT}")
string(REGEX REPLACE "^\"(.+)\"$" "\\1" ARROW_SO_VERSION "${ARROW_SO_VERSION_QUOTED}")
arrow_extract_macro_value(ARROW_FULL_SO_VERSION_QUOTED "ARROW_FULL_SO_VERSION"
"${ARROW_CONFIG_H_CONTENT}")
string(REGEX
REPLACE "^\"(.+)\"$" "\\1" ARROW_FULL_SO_VERSION
"${ARROW_FULL_SO_VERSION_QUOTED}")
endif()
else()
if(ARROW_USE_CMAKE_PACKAGE_CONFIG)
find_package(Arrow CONFIG)
elseif(ARROW_USE_PKG_CONFIG)
pkg_get_variable(ARROW_SO_VERSION arrow so_version)
pkg_get_variable(ARROW_FULL_SO_VERSION arrow full_so_version)
endif()
endif()
set(ARROW_ABI_VERSION ${ARROW_SO_VERSION})
mark_as_advanced(ARROW_ABI_VERSION
ARROW_CONFIG_SUFFIXES
ARROW_FULL_SO_VERSION
ARROW_IMPORT_LIB
ARROW_INCLUDE_DIR
ARROW_LIBS
ARROW_LIB_DIR
ARROW_SEARCH_LIB_PATH_SUFFIXES
ARROW_SHARED_IMP_LIB
ARROW_SHARED_LIB
ARROW_SO_VERSION
ARROW_STATIC_LIB
ARROW_VERSION
ARROW_VERSION_MAJOR
ARROW_VERSION_MINOR
ARROW_VERSION_PATCH)
find_package_handle_standard_args(Arrow REQUIRED_VARS
# The first required variable is shown
# in the found message. So this list is
# not sorted alphabetically.
ARROW_INCLUDE_DIR
ARROW_LIB_DIR
ARROW_FULL_SO_VERSION
ARROW_SO_VERSION
VERSION_VAR
ARROW_VERSION)
set(ARROW_FOUND ${Arrow_FOUND})
if(Arrow_FOUND AND NOT Arrow_FIND_QUIETLY)
message(STATUS "Arrow version: ${ARROW_VERSION} (${ARROW_FIND_APPROACH})")
message(STATUS "Arrow SO and ABI version: ${ARROW_SO_VERSION}")
message(STATUS "Arrow full SO version: ${ARROW_FULL_SO_VERSION}")
message(STATUS "Found the Arrow core shared library: ${ARROW_SHARED_LIB}")
message(STATUS "Found the Arrow core import library: ${ARROW_IMPORT_LIB}")
message(STATUS "Found the Arrow core static library: ${ARROW_STATIC_LIB}")
endif()

View File

@ -0,0 +1,93 @@
if (OpenBLAS_FOUND) # the git version propose a OpenBLASConfig.cmake
message(STATUS "OpenBLASConfig found")
set(OpenBLAS_INCLUDE_DIR ${OpenBLAS_INCLUDE_DIRS})
else()
message("OpenBLASConfig not found")
unset(OpenBLAS_DIR CACHE)
set(OpenBLAS_INCLUDE_SEARCH_PATHS
/usr/local/openblas/include
/usr/include
/usr/include/openblas
/usr/include/openblas-base
/usr/local/include
/usr/local/include/openblas
/usr/local/include/openblas-base
/opt/OpenBLAS/include
/usr/local/opt/openblas/include
$ENV{OpenBLAS_HOME}
$ENV{OpenBLAS_HOME}/include
)
set(OpenBLAS_LIB_SEARCH_PATHS
/usr/local/openblas/lib
/lib/
/lib/openblas-base
/lib64/
/usr/lib
/usr/lib/openblas-base
/usr/lib64
/usr/local/lib
/usr/local/lib64
/usr/local/opt/openblas/lib
/opt/OpenBLAS/lib
$ENV{OpenBLAS}
$ENV{OpenBLAS}/lib
$ENV{OpenBLAS_HOME}
$ENV{OpenBLAS_HOME}/lib
)
set(DEFAULT_OpenBLAS_LIB_PATH
/usr/local/openblas/lib
${OPENBLAS_PREFIX}/lib)
message("DEFAULT_OpenBLAS_LIB_PATH: ${DEFAULT_OpenBLAS_LIB_PATH}")
find_path(OpenBLAS_INCLUDE_DIR NAMES openblas_config.h lapacke.h PATHS ${OpenBLAS_INCLUDE_SEARCH_PATHS})
find_library(OpenBLAS_LIB NAMES openblas PATHS ${DEFAULT_OpenBLAS_LIB_PATH} NO_DEFAULT_PATH)
find_library(OpenBLAS_LIB NAMES openblas PATHS ${OpenBLAS_LIB_SEARCH_PATHS})
# mostly for debian
find_library(Lapacke_LIB NAMES lapacke PATHS ${DEFAULT_OpenBLAS_LIB_PATH} NO_DEFAULT_PATH)
find_library(Lapacke_LIB NAMES lapacke PATHS ${OpenBLAS_LIB_SEARCH_PATHS})
set(OpenBLAS_FOUND ON)
# Check include files
if(NOT OpenBLAS_INCLUDE_DIR)
set(OpenBLAS_FOUND OFF)
message(STATUS "Could not find OpenBLAS include. Turning OpenBLAS_FOUND off")
else()
message(STATUS "find OpenBLAS include:${OpenBLAS_INCLUDE_DIR} ")
endif()
# Check libraries
if(NOT OpenBLAS_LIB)
set(OpenBLAS_FOUND OFF)
message(STATUS "Could not find OpenBLAS lib. Turning OpenBLAS_FOUND off")
else()
message(STATUS "find OpenBLAS lib:${OpenBLAS_LIB} ")
endif()
if (OpenBLAS_FOUND)
set(FOUND_OPENBLAS "true" PARENT_SCOPE)
set(OpenBLAS_LIBRARIES ${OpenBLAS_LIB})
STRING(REGEX REPLACE "/libopenblas.so" "" OpenBLAS_LIB_DIR ${OpenBLAS_LIBRARIES})
message(STATUS "find OpenBLAS libraries:${OpenBLAS_LIBRARIES} ")
if (Lapacke_LIB)
set(OpenBLAS_LIBRARIES ${OpenBLAS_LIBRARIES} ${Lapacke_LIB})
endif()
if (NOT OpenBLAS_FIND_QUIETLY)
message(STATUS "Found OpenBLAS libraries: ${OpenBLAS_LIBRARIES}")
message(STATUS "Found OpenBLAS include: ${OpenBLAS_INCLUDE_DIR}")
endif()
else()
set(FOUND_OPENBLAS "false" PARENT_SCOPE)
if (OpenBLAS_FIND_REQUIRED)
message(FATAL_ERROR "Could not find OpenBLAS")
endif()
endif()
endif()
mark_as_advanced(
OpenBLAS_INCLUDE_DIR
OpenBLAS_LIBRARIES
OpenBLAS_LIB_DIR
)

View File

@ -0,0 +1,655 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
set(KNOWHERE_THIRDPARTY_DEPENDENCIES
Arrow
FAISS
GTest
OpenBLAS
MKL
)
message(STATUS "Using ${KNOWHERE_DEPENDENCY_SOURCE} approach to find dependencies")
# For each dependency, set dependency source to global default, if unset
foreach (DEPENDENCY ${KNOWHERE_THIRDPARTY_DEPENDENCIES})
if ("${${DEPENDENCY}_SOURCE}" STREQUAL "")
set(${DEPENDENCY}_SOURCE ${KNOWHERE_DEPENDENCY_SOURCE})
endif ()
endforeach ()
macro(build_dependency DEPENDENCY_NAME)
if ("${DEPENDENCY_NAME}" STREQUAL "Arrow")
build_arrow()
elseif ("${DEPENDENCY_NAME}" STREQUAL "GTest")
build_gtest()
elseif ("${DEPENDENCY_NAME}" STREQUAL "OpenBLAS")
build_openblas()
elseif ("${DEPENDENCY_NAME}" STREQUAL "FAISS")
build_faiss()
elseif ("${DEPENDENCY_NAME}" STREQUAL "MKL")
build_mkl()
else ()
message(FATAL_ERROR "Unknown thirdparty dependency to build: ${DEPENDENCY_NAME}")
endif ()
endmacro()
macro(resolve_dependency DEPENDENCY_NAME)
if (${DEPENDENCY_NAME}_SOURCE STREQUAL "AUTO")
find_package(${DEPENDENCY_NAME} MODULE)
if (NOT ${${DEPENDENCY_NAME}_FOUND})
build_dependency(${DEPENDENCY_NAME})
endif ()
elseif (${DEPENDENCY_NAME}_SOURCE STREQUAL "BUNDLED")
build_dependency(${DEPENDENCY_NAME})
elseif (${DEPENDENCY_NAME}_SOURCE STREQUAL "SYSTEM")
find_package(${DEPENDENCY_NAME} REQUIRED)
endif ()
endmacro()
# ----------------------------------------------------------------------
# Identify OS
if (UNIX)
if (APPLE)
set(CMAKE_OS_NAME "osx" CACHE STRING "Operating system name" FORCE)
else (APPLE)
## Check for Debian GNU/Linux ________________
find_file(DEBIAN_FOUND debian_version debconf.conf
PATHS /etc
)
if (DEBIAN_FOUND)
set(CMAKE_OS_NAME "debian" CACHE STRING "Operating system name" FORCE)
endif (DEBIAN_FOUND)
## Check for Fedora _________________________
find_file(FEDORA_FOUND fedora-release
PATHS /etc
)
if (FEDORA_FOUND)
set(CMAKE_OS_NAME "fedora" CACHE STRING "Operating system name" FORCE)
endif (FEDORA_FOUND)
## Check for RedHat _________________________
find_file(REDHAT_FOUND redhat-release inittab.RH
PATHS /etc
)
if (REDHAT_FOUND)
set(CMAKE_OS_NAME "redhat" CACHE STRING "Operating system name" FORCE)
endif (REDHAT_FOUND)
## Extra check for Ubuntu ____________________
if (DEBIAN_FOUND)
## At its core Ubuntu is a Debian system, with
## a slightly altered configuration; hence from
## a first superficial inspection a system will
## be considered as Debian, which signifies an
## extra check is required.
find_file(UBUNTU_EXTRA legal issue
PATHS /etc
)
if (UBUNTU_EXTRA)
## Scan contents of file
file(STRINGS ${UBUNTU_EXTRA} UBUNTU_FOUND
REGEX Ubuntu
)
## Check result of string search
if (UBUNTU_FOUND)
set(CMAKE_OS_NAME "ubuntu" CACHE STRING "Operating system name" FORCE)
set(DEBIAN_FOUND FALSE)
endif (UBUNTU_FOUND)
endif (UBUNTU_EXTRA)
endif (DEBIAN_FOUND)
endif (APPLE)
endif (UNIX)
# ----------------------------------------------------------------------
# thirdparty directory
set(THIRDPARTY_DIR "${INDEX_SOURCE_DIR}/thirdparty")
# ----------------------------------------------------------------------
# ExternalProject options
string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE)
set(FAISS_FLAGS "-DELPP_THREAD_SAFE -fopenmp -Werror=return-type")
set(EP_CXX_FLAGS "${FAISS_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}")
set(EP_C_FLAGS "${FAISS_FLAGS} ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}")
if (NOT MSVC)
# Set -fPIC on all external projects
set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -fPIC")
set(EP_C_FLAGS "${EP_C_FLAGS} -fPIC")
endif ()
# CC/CXX environment variables are captured on the first invocation of the
# builder (e.g make or ninja) instead of when CMake is invoked into to build
# directory. This leads to issues if the variables are exported in a subshell
# and the invocation of make/ninja is in distinct subshell without the same
# environment (CC/CXX).
set(EP_COMMON_TOOLCHAIN -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER})
if (CMAKE_AR)
set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_AR=${CMAKE_AR})
endif ()
if (CMAKE_RANLIB)
set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_RANLIB=${CMAKE_RANLIB})
endif ()
# External projects are still able to override the following declarations.
# cmake command line will favor the last defined variable when a duplicate is
# encountered. This requires that `EP_COMMON_CMAKE_ARGS` is always the first
# argument.
set(EP_COMMON_CMAKE_ARGS
${EP_COMMON_TOOLCHAIN}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_C_FLAGS=${EP_C_FLAGS}
-DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_C_FLAGS}
-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS})
if (NOT KNOWHERE_VERBOSE_THIRDPARTY_BUILD)
set(EP_LOG_OPTIONS LOG_CONFIGURE 1 LOG_BUILD 1 LOG_INSTALL 1 LOG_DOWNLOAD 1)
else ()
set(EP_LOG_OPTIONS)
endif ()
# Ensure that a default make is set
if ("${MAKE}" STREQUAL "")
if (NOT MSVC)
find_program(MAKE make)
endif ()
endif ()
set(MAKE_BUILD_ARGS "-j6")
# ----------------------------------------------------------------------
# Find pthreads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
# ----------------------------------------------------------------------
# Versions and URLs for toolchain builds, which also can be used to configure
# offline builds
# Read toolchain versions from cpp/thirdparty/versions.txt
file(STRINGS "${THIRDPARTY_DIR}/versions.txt" TOOLCHAIN_VERSIONS_TXT)
foreach (_VERSION_ENTRY ${TOOLCHAIN_VERSIONS_TXT})
# Exclude comments
if (NOT _VERSION_ENTRY MATCHES "^[^#][A-Za-z0-9-_]+_VERSION=")
continue()
endif ()
string(REGEX MATCH "^[^=]*" _LIB_NAME ${_VERSION_ENTRY})
string(REPLACE "${_LIB_NAME}=" "" _LIB_VERSION ${_VERSION_ENTRY})
# Skip blank or malformed lines
if (${_LIB_VERSION} STREQUAL "")
continue()
endif ()
# For debugging
#message(STATUS "${_LIB_NAME}: ${_LIB_VERSION}")
set(${_LIB_NAME} "${_LIB_VERSION}")
endforeach ()
set(FAISS_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/faiss)
if (DEFINED ENV{KNOWHERE_ARROW_URL})
set(ARROW_SOURCE_URL "$ENV{KNOWHERE_ARROW_URL}")
else ()
set(ARROW_SOURCE_URL
"https://github.com/apache/arrow.git"
)
endif ()
if (DEFINED ENV{KNOWHERE_GTEST_URL})
set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}")
else ()
set(GTEST_SOURCE_URL
"https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz")
endif ()
if (DEFINED ENV{KNOWHERE_OPENBLAS_URL})
set(OPENBLAS_SOURCE_URL "$ENV{KNOWHERE_OPENBLAS_URL}")
else ()
set(OPENBLAS_SOURCE_URL
"https://github.com/xianyi/OpenBLAS/archive/v${OPENBLAS_VERSION}.tar.gz")
endif ()
# ----------------------------------------------------------------------
# ARROW
set(ARROW_PREFIX "${INDEX_BINARY_DIR}/arrow_ep-prefix/src/arrow_ep/cpp")
macro(build_arrow)
message(STATUS "Building Apache ARROW-${ARROW_VERSION} from source")
set(ARROW_STATIC_LIB_NAME arrow)
set(ARROW_LIB_DIR "${ARROW_PREFIX}/lib")
set(ARROW_STATIC_LIB
"${ARROW_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${ARROW_STATIC_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}"
)
set(ARROW_INCLUDE_DIR "${ARROW_PREFIX}/include")
set(ARROW_CMAKE_ARGS
${EP_COMMON_CMAKE_ARGS}
-DARROW_BUILD_STATIC=ON
-DARROW_BUILD_SHARED=OFF
-DARROW_USE_GLOG=OFF
-DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX}
-DCMAKE_INSTALL_LIBDIR=${ARROW_LIB_DIR}
-DARROW_CUDA=OFF
-DARROW_FLIGHT=OFF
-DARROW_GANDIVA=OFF
-DARROW_GANDIVA_JAVA=OFF
-DARROW_HDFS=OFF
-DARROW_HIVESERVER2=OFF
-DARROW_ORC=OFF
-DARROW_PARQUET=OFF
-DARROW_PLASMA=OFF
-DARROW_PLASMA_JAVA_CLIENT=OFF
-DARROW_PYTHON=OFF
-DARROW_WITH_BZ2=OFF
-DARROW_WITH_ZLIB=OFF
-DARROW_WITH_LZ4=OFF
-DARROW_WITH_SNAPPY=OFF
-DARROW_WITH_ZSTD=OFF
-DARROW_WITH_BROTLI=OFF
-DCMAKE_BUILD_TYPE=Release
-DARROW_DEPENDENCY_SOURCE=BUNDLED #Build all arrow dependencies from source instead of calling find_package first
-DBOOST_SOURCE=AUTO #try to find BOOST in the system default locations and build from source if not found
)
externalproject_add(arrow_ep
GIT_REPOSITORY
${ARROW_SOURCE_URL}
GIT_TAG
${ARROW_VERSION}
GIT_SHALLOW
TRUE
SOURCE_SUBDIR
cpp
${EP_LOG_OPTIONS}
CMAKE_ARGS
${ARROW_CMAKE_ARGS}
BUILD_COMMAND
""
INSTALL_COMMAND
${MAKE} ${MAKE_BUILD_ARGS} install
BUILD_BYPRODUCTS
"${ARROW_STATIC_LIB}"
)
file(MAKE_DIRECTORY "${ARROW_INCLUDE_DIR}")
add_library(arrow STATIC IMPORTED)
set_target_properties(arrow
PROPERTIES IMPORTED_LOCATION "${ARROW_STATIC_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${ARROW_INCLUDE_DIR}")
add_dependencies(arrow arrow_ep)
set(JEMALLOC_PREFIX "${INDEX_BINARY_DIR}/arrow_ep-prefix/src/arrow_ep-build/jemalloc_ep-prefix/src/jemalloc_ep")
add_custom_command(TARGET arrow_ep POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${ARROW_LIB_DIR}
COMMAND ${CMAKE_COMMAND} -E copy ${JEMALLOC_PREFIX}/lib/libjemalloc_pic.a ${ARROW_LIB_DIR}
DEPENDS ${JEMALLOC_PREFIX}/lib/libjemalloc_pic.a)
endmacro()
if (KNOWHERE_WITH_ARROW AND NOT TARGET arrow_ep)
resolve_dependency(Arrow)
link_directories(SYSTEM ${ARROW_LIB_DIR})
include_directories(SYSTEM ${ARROW_INCLUDE_DIR})
endif ()
# ----------------------------------------------------------------------
# OpenBLAS
set(OPENBLAS_PREFIX "${INDEX_BINARY_DIR}/openblas_ep-prefix/src/openblas_ep")
macro(build_openblas)
message(STATUS "Building OpenBLAS-${OPENBLAS_VERSION} from source")
set(OpenBLAS_INCLUDE_DIR "${OPENBLAS_PREFIX}/include")
set(OpenBLAS_LIB_DIR "${OPENBLAS_PREFIX}/lib")
set(OPENBLAS_SHARED_LIB
"${OPENBLAS_PREFIX}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}openblas${CMAKE_SHARED_LIBRARY_SUFFIX}")
set(OPENBLAS_STATIC_LIB
"${OPENBLAS_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}")
set(OPENBLAS_CMAKE_ARGS
${EP_COMMON_CMAKE_ARGS}
-DCMAKE_BUILD_TYPE=Release
-DBUILD_SHARED_LIBS=ON
-DBUILD_STATIC_LIBS=ON
-DTARGET=CORE2
-DDYNAMIC_ARCH=1
-DDYNAMIC_OLDER=1
-DUSE_THREAD=0
-DUSE_OPENMP=0
-DFC=gfortran
-DCC=gcc
-DINTERFACE64=0
-DNUM_THREADS=128
-DNO_LAPACKE=1
"-DVERSION=${OPENBLAS_VERSION}"
"-DCMAKE_INSTALL_PREFIX=${OPENBLAS_PREFIX}"
-DCMAKE_INSTALL_LIBDIR=lib)
externalproject_add(openblas_ep
URL
${OPENBLAS_SOURCE_URL}
${EP_LOG_OPTIONS}
CMAKE_ARGS
${OPENBLAS_CMAKE_ARGS}
BUILD_COMMAND
${MAKE}
${MAKE_BUILD_ARGS}
BUILD_IN_SOURCE
1
INSTALL_COMMAND
${MAKE}
PREFIX=${OPENBLAS_PREFIX}
install
BUILD_BYPRODUCTS
${OPENBLAS_SHARED_LIB}
${OPENBLAS_STATIC_LIB})
file(MAKE_DIRECTORY "${OpenBLAS_INCLUDE_DIR}")
add_library(openblas SHARED IMPORTED)
set_target_properties(
openblas
PROPERTIES
IMPORTED_LOCATION "${OPENBLAS_SHARED_LIB}"
LIBRARY_OUTPUT_NAME "openblas"
INTERFACE_INCLUDE_DIRECTORIES "${OpenBLAS_INCLUDE_DIR}")
add_dependencies(openblas openblas_ep)
get_target_property(OpenBLAS_INCLUDE_DIR openblas INTERFACE_INCLUDE_DIRECTORIES)
set(OpenBLAS_LIBRARIES "${OPENBLAS_SHARED_LIB}")
endmacro()
if (KNOWHERE_WITH_OPENBLAS)
resolve_dependency(OpenBLAS)
include_directories(SYSTEM "${OpenBLAS_INCLUDE_DIR}")
link_directories(SYSTEM "${OpenBLAS_LIB_DIR}")
endif()
# ----------------------------------------------------------------------
# Google gtest
macro(build_gtest)
message(STATUS "Building gtest-${GTEST_VERSION} from source")
set(GTEST_VENDORED TRUE)
set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}")
if (APPLE)
set(GTEST_CMAKE_CXX_FLAGS
${GTEST_CMAKE_CXX_FLAGS}
-DGTEST_USE_OWN_TR1_TUPLE=1
-Wno-unused-value
-Wno-ignored-attributes)
endif ()
set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep")
set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include")
set(GTEST_STATIC_LIB
"${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}")
set(GTEST_MAIN_STATIC_LIB
"${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}")
set(GTEST_CMAKE_ARGS
${EP_COMMON_CMAKE_ARGS}
"-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}"
"-DCMAKE_INSTALL_LIBDIR=lib"
-DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS}
-DCMAKE_BUILD_TYPE=Release)
set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include")
set(GMOCK_STATIC_LIB
"${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}"
)
ExternalProject_Add(googletest_ep
URL
${GTEST_SOURCE_URL}
BUILD_COMMAND
${MAKE}
${MAKE_BUILD_ARGS}
BUILD_BYPRODUCTS
${GTEST_STATIC_LIB}
${GTEST_MAIN_STATIC_LIB}
${GMOCK_STATIC_LIB}
CMAKE_ARGS
${GTEST_CMAKE_ARGS}
${EP_LOG_OPTIONS})
# The include directory must exist before it is referenced by a target.
file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}")
add_library(gtest STATIC IMPORTED)
set_target_properties(gtest
PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
add_library(gtest_main STATIC IMPORTED)
set_target_properties(gtest_main
PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
add_library(gmock STATIC IMPORTED)
set_target_properties(gmock
PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
add_dependencies(gtest googletest_ep)
add_dependencies(gtest_main googletest_ep)
add_dependencies(gmock googletest_ep)
endmacro()
# if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep)
if ( NOT TARGET gtest AND KNOWHERE_BUILD_TESTS )
resolve_dependency(GTest)
if (NOT GTEST_VENDORED)
endif ()
# TODO: Don't use global includes but rather target_include_directories
get_target_property(GTEST_INCLUDE_DIR gtest INTERFACE_INCLUDE_DIRECTORIES)
link_directories(SYSTEM "${GTEST_PREFIX}/lib")
include_directories(SYSTEM ${GTEST_INCLUDE_DIR})
endif ()
# ----------------------------------------------------------------------
# MKL
macro(build_mkl)
if (FAISS_WITH_MKL)
if (EXISTS "/proc/cpuinfo")
FILE(READ /proc/cpuinfo PROC_CPUINFO)
SET(VENDOR_ID_RX "vendor_id[ \t]*:[ \t]*([a-zA-Z]+)\n")
STRING(REGEX MATCH "${VENDOR_ID_RX}" VENDOR_ID "${PROC_CPUINFO}")
STRING(REGEX REPLACE "${VENDOR_ID_RX}" "\\1" VENDOR_ID "${VENDOR_ID}")
if (NOT ${VENDOR_ID} STREQUAL "GenuineIntel")
set(FAISS_WITH_MKL OFF)
endif ()
endif ()
find_path(MKL_LIB_PATH
NAMES "libmkl_intel_ilp64.a" "libmkl_gnu_thread.a" "libmkl_core.a"
PATH_SUFFIXES "intel/compilers_and_libraries_${MKL_VERSION}/linux/mkl/lib/intel64/")
if (${MKL_LIB_PATH} STREQUAL "MKL_LIB_PATH-NOTFOUND")
message(FATAL_ERROR "Could not find MKL libraries")
endif ()
message(STATUS "MKL lib path = ${MKL_LIB_PATH}")
set(MKL_LIBS
${MKL_LIB_PATH}/libmkl_intel_ilp64.a
${MKL_LIB_PATH}/libmkl_gnu_thread.a
${MKL_LIB_PATH}/libmkl_core.a
)
endif ()
endmacro()
# ----------------------------------------------------------------------
# FAISS
macro(build_faiss)
message(STATUS "Building FAISS-${FAISS_VERSION} from source")
set(FAISS_PREFIX "${INDEX_BINARY_DIR}/faiss_ep-prefix/src/faiss_ep")
set(FAISS_INCLUDE_DIR "${FAISS_PREFIX}/include")
set(FAISS_STATIC_LIB
"${FAISS_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}faiss${CMAKE_STATIC_LIBRARY_SUFFIX}")
if (CCACHE_FOUND)
set(FAISS_C_COMPILER "${CCACHE_FOUND} ${CMAKE_C_COMPILER}")
if (MILVUS_GPU_VERSION)
set(FAISS_CXX_COMPILER "${CMAKE_CXX_COMPILER}")
set(FAISS_CUDA_COMPILER "${CCACHE_FOUND} ${CMAKE_CUDA_COMPILER}")
else ()
set(FAISS_CXX_COMPILER "${CCACHE_FOUND} ${CMAKE_CXX_COMPILER}")
endif()
else ()
set(FAISS_C_COMPILER "${CMAKE_C_COMPILER}")
set(FAISS_CXX_COMPILER "${CMAKE_CXX_COMPILER}")
endif()
set(FAISS_CONFIGURE_ARGS
"--prefix=${FAISS_PREFIX}"
"CC=${FAISS_C_COMPILER}"
"CXX=${FAISS_CXX_COMPILER}"
"NVCC=${FAISS_CUDA_COMPILER}"
"CFLAGS=${EP_C_FLAGS}"
"CXXFLAGS=${EP_CXX_FLAGS} -mf16c -O3"
--without-python)
if (FAISS_WITH_MKL)
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"CPPFLAGS=-DFINTEGER=long -DMKL_ILP64 -m64 -I${MKL_LIB_PATH}/../../include"
"LDFLAGS=-L${MKL_LIB_PATH}"
)
else ()
message(STATUS "Build Faiss with OpenBlas/LAPACK")
if(OpenBLAS_FOUND)
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"LDFLAGS=-L${OpenBLAS_LIB_DIR}")
else()
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"LDFLAGS=-L${OPENBLAS_PREFIX}/lib")
endif()
endif ()
if (MILVUS_GPU_VERSION)
if (NOT MILVUS_CUDA_ARCH OR MILVUS_CUDA_ARCH STREQUAL "DEFAULT")
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}"
"--with-cuda-arch=-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75"
)
else()
STRING(REPLACE ";" " " MILVUS_CUDA_ARCH "${MILVUS_CUDA_ARCH}")
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}"
"--with-cuda-arch=${MILVUS_CUDA_ARCH}"
)
endif ()
else ()
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"CPPFLAGS=-DUSE_CPU"
--without-cuda)
endif ()
message(STATUS "Building FAISS with configure args -${FAISS_CONFIGURE_ARGS}")
if (DEFINED ENV{FAISS_SOURCE_URL})
set(FAISS_SOURCE_URL "$ENV{FAISS_SOURCE_URL}")
externalproject_add(faiss_ep
URL
${FAISS_SOURCE_URL}
${EP_LOG_OPTIONS}
CONFIGURE_COMMAND
"./configure"
${FAISS_CONFIGURE_ARGS}
BUILD_COMMAND
${MAKE} ${MAKE_BUILD_ARGS} all
BUILD_IN_SOURCE
1
INSTALL_COMMAND
${MAKE} install
BUILD_BYPRODUCTS
${FAISS_STATIC_LIB})
else ()
externalproject_add(faiss_ep
DOWNLOAD_COMMAND
""
SOURCE_DIR
${FAISS_SOURCE_DIR}
${EP_LOG_OPTIONS}
CONFIGURE_COMMAND
"./configure"
${FAISS_CONFIGURE_ARGS}
BUILD_COMMAND
${MAKE} ${MAKE_BUILD_ARGS} all
BUILD_IN_SOURCE
1
INSTALL_COMMAND
${MAKE} install
BUILD_BYPRODUCTS
${FAISS_STATIC_LIB})
endif ()
if(NOT OpenBLAS_FOUND)
message("add faiss dependencies: openblas_ep")
ExternalProject_Add_StepDependencies(faiss_ep configure openblas_ep)
endif()
file(MAKE_DIRECTORY "${FAISS_INCLUDE_DIR}")
add_library(faiss STATIC IMPORTED)
set_target_properties(
faiss
PROPERTIES
IMPORTED_LOCATION "${FAISS_STATIC_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${FAISS_INCLUDE_DIR}"
)
if (FAISS_WITH_MKL)
set_target_properties(
faiss
PROPERTIES
INTERFACE_LINK_LIBRARIES "${MKL_LIBS}")
else ()
set_target_properties(
faiss
PROPERTIES
INTERFACE_LINK_LIBRARIES "${OpenBLAS_LIBRARIES}")
endif ()
add_dependencies(faiss faiss_ep)
endmacro()
if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep)
if (FAISS_WITH_MKL)
resolve_dependency(MKL)
else ()
message("faiss with no mkl")
endif ()
resolve_dependency(FAISS)
get_target_property(FAISS_INCLUDE_DIR faiss INTERFACE_INCLUDE_DIRECTORIES)
include_directories(SYSTEM "${FAISS_INCLUDE_DIR}")
link_directories(SYSTEM ${FAISS_PREFIX}/lib/)
endif ()

View File

@ -0,0 +1,140 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
include_directories(${INDEX_SOURCE_DIR}/knowhere)
include_directories(${INDEX_SOURCE_DIR}/thirdparty)
if (MILVUS_SUPPORT_SPTAG)
include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService)
set(SPTAG_SOURCE_DIR ${INDEX_SOURCE_DIR}/thirdparty/SPTAG)
file(GLOB HDR_FILES
${SPTAG_SOURCE_DIR}/AnnService/inc/Core/*.h
${SPTAG_SOURCE_DIR}/AnnService/inc/Core/Common/*.h
${SPTAG_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h
${SPTAG_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h
${SPTAG_SOURCE_DIR}/AnnService/inc/Helper/*.h)
file(GLOB SRC_FILES
${SPTAG_SOURCE_DIR}/AnnService/src/Core/*.cpp
${SPTAG_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp
${SPTAG_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp
${SPTAG_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp
${SPTAG_SOURCE_DIR}/AnnService/src/Helper/*.cpp)
if (NOT TARGET SPTAGLibStatic)
add_library(SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES})
endif ()
endif ()
set(external_srcs
knowhere/common/Exception.cpp
knowhere/common/Log.cpp
knowhere/common/Timer.cpp
)
set(vector_index_srcs
knowhere/index/vector_index/adapter/VectorAdapter.cpp
knowhere/index/vector_index/helpers/FaissIO.cpp
knowhere/index/vector_index/helpers/IndexParameter.cpp
knowhere/index/vector_index/impl/nsg/Distance.cpp
knowhere/index/vector_index/impl/nsg/NSG.cpp
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp
knowhere/index/vector_index/impl/nsg/NSGIO.cpp
knowhere/index/vector_index/ConfAdapter.cpp
knowhere/index/vector_index/ConfAdapterMgr.cpp
knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
knowhere/index/vector_index/FaissBaseIndex.cpp
knowhere/index/vector_index/IndexBinaryIDMAP.cpp
knowhere/index/vector_index/IndexBinaryIVF.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/IndexIVFPQ.cpp
knowhere/index/vector_index/IndexIVFSQ.cpp
knowhere/index/IndexType.cpp
knowhere/index/vector_index/VecIndexFactory.cpp
knowhere/index/vector_index/IndexAnnoy.cpp
knowhere/index/vector_index/IndexRHNSW.cpp
knowhere/index/vector_index/IndexHNSW.cpp
knowhere/index/vector_index/IndexRHNSWFlat.cpp
knowhere/index/vector_index/IndexRHNSWSQ.cpp
knowhere/index/vector_index/IndexRHNSWPQ.cpp
)
set(vector_offset_index_srcs
knowhere/index/vector_offset_index/OffsetBaseIndex.cpp
knowhere/index/vector_offset_index/IndexIVF_NM.cpp
knowhere/index/vector_offset_index/IndexNSG_NM.cpp
)
if (MILVUS_SUPPORT_SPTAG)
set(vector_index_srcs
knowhere/index/vector_index/adapter/SptagAdapter.cpp
knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
knowhere/index/vector_index/IndexSPTAG.cpp
${vector_index_srcs}
)
endif ()
set(depend_libs
faiss
gomp
gfortran
pthread
)
if (MILVUS_SUPPORT_SPTAG)
set(depend_libs
SPTAGLibStatic
${depend_libs}
)
endif ()
if (NOT TARGET knowhere)
add_library(
knowhere STATIC
${external_srcs}
${vector_index_srcs}
${vector_offset_index_srcs}
)
endif ()
target_link_libraries(
knowhere
${depend_libs}
)
set(INDEX_INCLUDE_DIRS
${INDEX_SOURCE_DIR}/knowhere
${INDEX_SOURCE_DIR}/thirdparty
${FAISS_INCLUDE_DIR}
${OpenBLAS_INCLUDE_DIR}
${LAPACK_INCLUDE_DIR}
)
if (MILVUS_SUPPORT_SPTAG)
set(INDEX_INCLUDE_DIRS
${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService
${INDEX_INCLUDE_DIRS}
)
endif ()
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
# **************************** Get&Print Include Directories ****************************
get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES )
foreach ( dir ${dirs} )
message( STATUS "Knowhere Current Include DIRS: " ${dir} )
endforeach ()

View File

@ -0,0 +1,87 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <string.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace milvus {
namespace knowhere {
struct Binary {
std::shared_ptr<uint8_t[]> data;
int64_t size = 0;
};
using BinaryPtr = std::shared_ptr<Binary>;
inline uint8_t*
CopyBinary(const BinaryPtr& bin) {
uint8_t* newdata = new uint8_t[bin->size];
memcpy(newdata, bin->data.get(), bin->size);
return newdata;
}
class BinarySet {
public:
BinaryPtr
GetByName(const std::string& name) const {
return binary_map_.at(name);
}
void
Append(const std::string& name, BinaryPtr binary) {
binary_map_[name] = std::move(binary);
}
void
Append(const std::string& name, std::shared_ptr<uint8_t[]> data, int64_t size) {
auto binary = std::make_shared<Binary>();
binary->data = data;
binary->size = size;
binary_map_[name] = std::move(binary);
}
// void
// Append(const std::string &name, void *data, int64_t size, ID id) {
// Binary binary;
// binary.data = data;
// binary.size = size;
// binary.id = id;
// binary_map_[name] = binary;
//}
BinaryPtr
Erase(const std::string& name) {
BinaryPtr result = nullptr;
auto it = binary_map_.find(name);
if (it != binary_map_.end()) {
result = it->second;
binary_map_.erase(it);
}
return result;
}
void
clear() {
binary_map_.clear();
}
public:
std::map<std::string, BinaryPtr> binary_map_;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,22 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include "utils/Json.h"
namespace milvus {
namespace knowhere {
using Config = milvus::json;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,61 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <any>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
namespace milvus {
namespace knowhere {
using Value = std::any;
using ValuePtr = std::shared_ptr<Value>;
class Dataset {
public:
Dataset() = default;
template <typename T>
void
Set(const std::string& k, T&& v) {
std::lock_guard<std::mutex> lk(mutex_);
data_[k] = std::make_shared<Value>(std::forward<T>(v));
}
template <typename T>
T
Get(const std::string& k) {
std::lock_guard<std::mutex> lk(mutex_);
try {
return std::any_cast<T>(*(data_.at(k)));
} catch (...) {
throw std::logic_error("Can't find this key");
}
}
const std::map<std::string, ValuePtr>&
data() const {
return data_;
}
private:
std::mutex mutex_;
std::map<std::string, ValuePtr> data_;
};
using DatasetPtr = std::shared_ptr<Dataset>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,46 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <cstdio>
#include <utility>
#include "Log.h"
#include "knowhere/common/Exception.h"
namespace milvus {
namespace knowhere {
KnowhereException::KnowhereException(std::string msg) : msg_(std::move(msg)) {
}
KnowhereException::KnowhereException(const std::string& m, const char* funcName, const char* file, int line) {
std::string filename;
try {
size_t pos;
std::string file_path(file);
pos = file_path.find_last_of('/');
filename = file_path.substr(pos + 1);
} catch (std::exception& e) {
LOG_KNOWHERE_ERROR_ << e.what();
}
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, filename.c_str(), line, m.c_str());
msg_.resize(size + 1);
snprintf(&msg_[0], msg_.size(), "Error in %s at %s:%d: %s", funcName, filename.c_str(), line, m.c_str());
}
const char*
KnowhereException::what() const noexcept {
return msg_.c_str();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,49 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <exception>
#include <string>
namespace milvus {
namespace knowhere {
class KnowhereException : public std::exception {
public:
explicit KnowhereException(std::string msg);
KnowhereException(const std::string& msg, const char* funName, const char* file, int line);
const char*
what() const noexcept override;
std::string msg_;
};
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
#define KNOWHERE_THROW_MSG(MSG) \
do { \
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
} while (false)
#define KNOHERE_THROW_FORMAT(FMT, ...) \
do { \
std::string __s; \
int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \
__s.resize(__size + 1); \
snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \
throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
} while (false)
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,85 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/common/Log.h"
#include <cstdarg>
#include <cstdio>
#include <memory>
#include <string>
namespace milvus {
namespace knowhere {
std::string
LogOut(const char* pattern, ...) {
size_t len = strnlen(pattern, 1024) + 256;
auto str_p = std::make_unique<char[]>(len);
memset(str_p.get(), 0, len);
va_list vl;
va_start(vl, pattern);
vsnprintf(str_p.get(), len, pattern, vl); // NOLINT
va_end(vl);
return std::string(str_p.get());
}
void
SetThreadName(const std::string& name) {
pthread_setname_np(pthread_self(), name.c_str());
}
std::string
GetThreadName() {
std::string thread_name = "unamed";
char name[16];
size_t len = 16;
auto err = pthread_getname_np(pthread_self(), name, len);
if (not err) {
thread_name = name;
}
return thread_name;
}
void
log_trace_(const std::string& s) {
LOG_KNOWHERE_TRACE_ << s;
}
void
log_debug_(const std::string& s) {
LOG_KNOWHERE_DEBUG_ << s;
}
void
log_info_(const std::string& s) {
LOG_KNOWHERE_INFO_ << s;
}
void
log_warning_(const std::string& s) {
LOG_KNOWHERE_WARNING_ << s;
}
void
log_error_(const std::string& s) {
LOG_KNOWHERE_ERROR_ << s;
}
void
log_fatal_(const std::string& s) {
LOG_KNOWHERE_FATAL_ << s;
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,74 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <string>
#include "easyloggingpp/easylogging++.h"
namespace milvus {
namespace knowhere {
std::string
LogOut(const char* pattern, ...);
void
SetThreadName(const std::string& name);
std::string
GetThreadName();
void
log_trace_(const std::string&);
void
log_debug_(const std::string&);
void
log_info_(const std::string&);
void
log_warning_(const std::string&);
void
log_error_(const std::string&);
void
log_fatal_(const std::string&);
/*
* Please use LOG_MODULE_LEVEL_C macro in member function of class
* and LOG_MODULE_LEVEL_ macro in other functions.
*/
/////////////////////////////////////////////////////////////////////////////////////////////////
#define KNOWHERE_MODULE_NAME "KNOWHERE"
#define KNOWHERE_MODULE_CLASS_FUNCTION \
LogOut("[%s][%s::%s][%s] ", KNOWHERE_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str())
#define KNOWHERE_MODULE_FUNCTION LogOut("[%s][%s][%s] ", KNOWHERE_MODULE_NAME, __FUNCTION__, GetThreadName().c_str())
#define LOG_KNOWHERE_TRACE_C LOG(TRACE) << KNOWHERE_MODULE_CLASS_FUNCTION
#define LOG_KNOWHERE_DEBUG_C LOG(DEBUG) << KNOWHERE_MODULE_CLASS_FUNCTION
#define LOG_KNOWHERE_INFO_C LOG(INFO) << KNOWHERE_MODULE_CLASS_FUNCTION
#define LOG_KNOWHERE_WARNING_C LOG(WARNING) << KNOWHERE_MODULE_CLASS_FUNCTION
#define LOG_KNOWHERE_ERROR_C LOG(ERROR) << KNOWHERE_MODULE_CLASS_FUNCTION
#define LOG_KNOWHERE_FATAL_C LOG(FATAL) << KNOWHERE_MODULE_CLASS_FUNCTION
#define LOG_KNOWHERE_TRACE_ LOG(TRACE) << KNOWHERE_MODULE_FUNCTION
#define LOG_KNOWHERE_DEBUG_ LOG(DEBUG) << KNOWHERE_MODULE_FUNCTION
#define LOG_KNOWHERE_INFO_ LOG(INFO) << KNOWHERE_MODULE_FUNCTION
#define LOG_KNOWHERE_WARNING_ LOG(WARNING) << KNOWHERE_MODULE_FUNCTION
#define LOG_KNOWHERE_ERROR_ LOG(ERROR) << KNOWHERE_MODULE_FUNCTION
#define LOG_KNOWHERE_FATAL_ LOG(FATAL) << KNOWHERE_MODULE_FUNCTION
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,74 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <iostream>
#include <utility>
#include "knowhere/common/Log.h"
#include "knowhere/common/Timer.h"
namespace milvus {
namespace knowhere {
TimeRecorder::TimeRecorder(std::string hdr, int64_t log_level) : header_(std::move(hdr)), log_level_(log_level) {
start_ = last_ = stdclock::now();
}
std::string
TimeRecorder::GetTimeSpanStr(double span) {
std::string str_sec = std::to_string(span * 0.000001) + ((span > 1000000) ? " seconds" : " second");
std::string str_ms = std::to_string(span * 0.001) + " ms";
return str_sec + " [" + str_ms + "]";
}
void
TimeRecorder::PrintTimeRecord(const std::string& msg, double span) {
std::string str_log;
if (!header_.empty()) {
str_log += header_ + ": ";
}
str_log += msg;
str_log += " (";
str_log += TimeRecorder::GetTimeSpanStr(span);
str_log += ")";
switch (log_level_) {
case 0:
std::cout << str_log << std::endl;
break;
default:
LOG_KNOWHERE_DEBUG_ << str_log;
break;
}
}
double
TimeRecorder::RecordSection(const std::string& msg) {
stdclock::time_point curr = stdclock::now();
double span = (std::chrono::duration<double, std::micro>(curr - last_)).count();
last_ = curr;
PrintTimeRecord(msg, span);
return span;
}
double
TimeRecorder::ElapseFromBegin(const std::string& msg) {
stdclock::time_point curr = stdclock::now();
double span = (std::chrono::duration<double, std::micro>(curr - start_)).count();
PrintTimeRecord(msg, span);
return span;
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,49 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <chrono>
#include <string>
namespace milvus {
namespace knowhere {
class TimeRecorder {
using stdclock = std::chrono::high_resolution_clock;
public:
// trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
explicit TimeRecorder(std::string hdr, int64_t log_level = 0);
virtual ~TimeRecorder() = default;
double
RecordSection(const std::string& msg);
double
ElapseFromBegin(const std::string& msg);
static std::string
GetTimeSpanStr(double span);
private:
void
PrintTimeRecord(const std::string& msg, double span);
private:
std::string header_;
stdclock::time_point start_;
stdclock::time_point last_;
int64_t log_level_;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,29 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <string>
#include <vector>
namespace milvus {
namespace knowhere {
using MetricType = std::string;
// using IndexType = std::string;
using IDType = int64_t;
using FloatType = float;
using BinaryType = uint8_t;
using GraphType = std::vector<std::vector<IDType>>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,35 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include "cache/DataObj.h"
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/Config.h"
namespace milvus {
namespace knowhere {
class Index : public milvus::cache::DataObj {
public:
virtual BinarySet
Serialize(const Config& config) = 0;
virtual void
Load(const BinarySet&) = 0;
};
using IndexPtr = std::shared_ptr<Index>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,43 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <unordered_map>
#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"
namespace milvus {
namespace knowhere {
/* used in 0.8.0 */
namespace IndexEnum {
const char* INVALID = "";
const char* INDEX_FAISS_IDMAP = "FLAT";
const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT";
const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";
const char* INDEX_FAISS_IVFSQ8H = "IVF_SQ8_HYBRID";
const char* INDEX_FAISS_BIN_IDMAP = "BIN_FLAT";
const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT";
const char* INDEX_NSG = "NSG";
#ifdef MILVUS_SUPPORT_SPTAG
const char* INDEX_SPTAG_KDT_RNT = "SPTAG_KDT_RNT";
const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT";
#endif
const char* INDEX_HNSW = "HNSW";
const char* INDEX_RHNSWFlat = "RHNSW_FLAT";
const char* INDEX_RHNSWPQ = "RHNSW_PQ";
const char* INDEX_RHNSWSQ = "RHNSW_SQ";
const char* INDEX_ANNOY = "ANNOY";
} // namespace IndexEnum
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,72 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <string>
namespace milvus {
namespace knowhere {
/* used in 0.7.0 */
enum class OldIndexType {
INVALID = 0,
FAISS_IDMAP = 1,
FAISS_IVFFLAT_CPU,
FAISS_IVFFLAT_GPU,
FAISS_IVFFLAT_MIX, // build on gpu and search on cpu
FAISS_IVFPQ_CPU,
FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU,
FAISS_IVFSQ8_MIX,
FAISS_IVFSQ8_CPU,
FAISS_IVFSQ8_GPU,
FAISS_IVFSQ8_HYBRID, // only support build on gpu.
NSG_MIX,
FAISS_IVFPQ_MIX,
SPTAG_BKT_RNT_CPU,
HNSW,
ANNOY,
RHNSW_FLAT,
RHNSW_PQ,
RHNSW_SQ,
FAISS_BIN_IDMAP = 100,
FAISS_BIN_IVFLAT_CPU = 101,
};
using IndexType = std::string;
/* used in 0.8.0 */
namespace IndexEnum {
extern const char* INVALID;
extern const char* INDEX_FAISS_IDMAP;
extern const char* INDEX_FAISS_IVFFLAT;
extern const char* INDEX_FAISS_IVFPQ;
extern const char* INDEX_FAISS_IVFSQ8;
extern const char* INDEX_FAISS_IVFSQ8H;
extern const char* INDEX_FAISS_BIN_IDMAP;
extern const char* INDEX_FAISS_BIN_IVFFLAT;
extern const char* INDEX_NSG;
#ifdef MILVUS_SUPPORT_SPTAG
extern const char* INDEX_SPTAG_KDT_RNT;
extern const char* INDEX_SPTAG_BKT_RNT;
#endif
extern const char* INDEX_HNSW;
extern const char* INDEX_RHNSWFlat;
extern const char* INDEX_RHNSWPQ;
extern const char* INDEX_RHNSWSQ;
extern const char* INDEX_ANNOY;
} // namespace IndexEnum
enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 };
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,30 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include "knowhere/common/Dataset.h"
namespace milvus {
namespace knowhere {
class Preprocessor {
public:
virtual DatasetPtr
Preprocess(const DatasetPtr& input) = 0;
};
using PreprocessorPtr = std::shared_ptr<Preprocessor>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,86 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <map>
#include <memory>
#include <string>
#include "faiss/utils/ConcurrentBitset.h"
#include "knowhere/index/Index.h"
namespace milvus {
namespace knowhere {
enum OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 };
static std::map<std::string, OperatorType> s_map_operator_type = {
{"LT", OperatorType::LT},
{"LE", OperatorType::LE},
{"GT", OperatorType::GT},
{"GE", OperatorType::GE},
};
template <typename T>
struct IndexStructure {
IndexStructure() : a_(0), idx_(0) {
}
explicit IndexStructure(const T a) : a_(a), idx_(0) {
}
IndexStructure(const T a, const size_t idx) : a_(a), idx_(idx) {
}
bool
operator<(const IndexStructure& b) const {
return a_ < b.a_;
}
bool
operator<=(const IndexStructure& b) const {
return a_ <= b.a_;
}
bool
operator>(const IndexStructure& b) const {
return a_ > b.a_;
}
bool
operator>=(const IndexStructure& b) const {
return a_ >= b.a_;
}
bool
operator==(const IndexStructure& b) const {
return a_ == b.a_;
}
T a_;
size_t idx_;
};
template <typename T>
class StructuredIndex : public Index {
public:
virtual void
Build(const size_t n, const T* values) = 0;
virtual const faiss::ConcurrentBitsetPtr
In(const size_t n, const T* values) = 0;
virtual const faiss::ConcurrentBitsetPtr
NotIn(const size_t n, const T* values) = 0;
virtual const faiss::ConcurrentBitsetPtr
Range(const T value, const OperatorType op) = 0;
virtual const faiss::ConcurrentBitsetPtr
Range(const T lower_bound_value, bool lb_inclusive, const T upper_bound_value, bool ub_inclusive) = 0;
};
template <typename T>
using StructuredIndexPtr = std::shared_ptr<StructuredIndex<T>>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,153 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <src/index/knowhere/knowhere/common/Log.h>
#include <algorithm>
#include <memory>
#include <utility>
#include "knowhere/index/structured_index/StructuredIndexFlat.h"
namespace milvus {
namespace knowhere {
template <typename T>
StructuredIndexFlat<T>::StructuredIndexFlat() : is_built_(false), data_() {
}
template <typename T>
StructuredIndexFlat<T>::StructuredIndexFlat(const size_t n, const T* values) : is_built_(false) {
Build(n, values);
}
template <typename T>
StructuredIndexFlat<T>::~StructuredIndexFlat() {
}
template <typename T>
void
StructuredIndexFlat<T>::Build(const size_t n, const T* values) {
data_.reserve(n);
T* p = const_cast<T*>(values);
for (size_t i = 0; i < n; ++i) {
data_.emplace_back(IndexStructure(*p++, i));
}
is_built_ = true;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::In(const size_t n, const T* values) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
for (size_t i = 0; i < n; ++i) {
for (const auto& index : data_) {
if (index->a_ == *(values + i)) {
bitset->set(index->idx_);
}
}
}
return bitset;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size(), 0xff);
for (size_t i = 0; i < n; ++i) {
for (const auto& index : data_) {
if (index->a_ == *(values + i)) {
bitset->clear(index->idx_);
}
}
}
return bitset;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::Range(const T value, const OperatorType op) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
auto lb = data_.begin();
auto ub = data_.end();
for (; lb <= ub; lb++) {
switch (op) {
case OperatorType::LT:
if (lb < IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
case OperatorType::LE:
if (lb <= IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
case OperatorType::GT:
if (lb > IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
case OperatorType::GE:
if (lb >= IndexStructure<T>(value)) {
bitset->set(lb->idx_);
}
break;
default:
KNOWHERE_THROW_MSG("Invalid OperatorType:" + std::to_string((int)op) + "!");
}
}
return bitset;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexFlat<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
if (lower_bound_value > upper_bound_value) {
std::swap(lower_bound_value, upper_bound_value);
std::swap(lb_inclusive, ub_inclusive);
}
auto lb = data_.begin();
auto ub = data_.end();
for (; lb <= ub; ++lb) {
if (lb_inclusive && ub_inclusive) {
if (lb >= IndexStructure<T>(lower_bound_value) && lb <= IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
} else if (lb_inclusive && !ub_inclusive) {
if (lb >= IndexStructure<T>(lower_bound_value) && lb < IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
} else if (!lb_inclusive && ub_inclusive) {
if (lb > IndexStructure<T>(lower_bound_value) && lb <= IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
} else {
if (lb > IndexStructure<T>(lower_bound_value) && lb < IndexStructure<T>(upper_bound_value)) {
bitset->set(lb->idx_);
}
}
}
return bitset;
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,80 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/index/structured_index/StructuredIndex.h"
namespace milvus {
namespace knowhere {
template <typename T>
class StructuredIndexFlat : public StructuredIndex<T> {
public:
StructuredIndexFlat();
StructuredIndexFlat(const size_t n, const T* values);
~StructuredIndexFlat();
BinarySet
Serialize(const Config& config = Config()) override;
void
Load(const BinarySet& index_binary) override;
void
Build(const size_t n, const T* values) override;
void
build();
const faiss::ConcurrentBitsetPtr
In(const size_t n, const T* values) override;
const faiss::ConcurrentBitsetPtr
NotIn(const size_t n, const T* values) override;
const faiss::ConcurrentBitsetPtr
Range(const T value, const OperatorType op) override;
const faiss::ConcurrentBitsetPtr
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
const std::vector<IndexStructure<T>>&
GetData() {
return data_;
}
int64_t
Size() override {
return (int64_t)data_.size();
}
bool
IsBuilt() const {
return is_built_;
}
private:
bool is_built_;
std::vector<IndexStructure<T>> data_;
};
template <typename T>
using StructuredIndexFlatPtr = std::shared_ptr<StructuredIndexFlat<T>>;
} // namespace knowhere
} // namespace milvus
#include "knowhere/index/structured_index/StructuredIndexFlat-inl.h"

View File

@ -0,0 +1,199 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <src/index/knowhere/knowhere/common/Log.h>
#include <algorithm>
#include <memory>
#include <utility>
#include "knowhere/index/structured_index/StructuredIndexSort.h"
namespace milvus {
namespace knowhere {
template <typename T>
StructuredIndexSort<T>::StructuredIndexSort() : is_built_(false), data_() {
}
template <typename T>
StructuredIndexSort<T>::StructuredIndexSort(const size_t n, const T* values) : is_built_(false) {
StructuredIndexSort<T>::Build(n, values);
}
template <typename T>
StructuredIndexSort<T>::~StructuredIndexSort() {
}
template <typename T>
void
StructuredIndexSort<T>::Build(const size_t n, const T* values) {
data_.reserve(n);
T* p = const_cast<T*>(values);
for (size_t i = 0; i < n; ++i) {
data_.emplace_back(IndexStructure(*p++, i));
}
build();
}
template <typename T>
void
StructuredIndexSort<T>::build() {
if (is_built_)
return;
if (data_.size() == 0) {
// todo: throw an exception
KNOWHERE_THROW_MSG("StructuredIndexSort cannot build null values!");
}
std::sort(data_.begin(), data_.end());
is_built_ = true;
}
template <typename T>
BinarySet
StructuredIndexSort<T>::Serialize(const milvus::knowhere::Config& config) {
if (!is_built_) {
build();
}
auto index_data_size = data_.size() * sizeof(IndexStructure<T>);
std::shared_ptr<uint8_t[]> index_data(new uint8_t[index_data_size]);
memcpy(index_data.get(), data_.data(), index_data_size);
std::shared_ptr<uint8_t[]> index_length(new uint8_t[sizeof(size_t)]);
auto index_size = data_.size();
memcpy(index_length.get(), &index_size, sizeof(size_t));
BinarySet res_set;
res_set.Append("index_data", index_data, index_data_size);
res_set.Append("index_length", index_length, sizeof(size_t));
return res_set;
}
template <typename T>
void
StructuredIndexSort<T>::Load(const milvus::knowhere::BinarySet& index_binary) {
try {
size_t index_size;
auto index_length = index_binary.GetByName("index_length");
memcpy(&index_size, index_length->data.get(), (size_t)index_length->size);
auto index_data = index_binary.GetByName("index_data");
data_.resize(index_size);
memcpy(data_.data(), index_data->data.get(), (size_t)index_data->size);
is_built_ = true;
} catch (...) {
KNOHWERE_ERROR_MSG("StructuredIndexSort Load failed!");
}
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexSort<T>::In(const size_t n, const T* values) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
for (; lb < ub; ++lb) {
if (lb->a_ != *(values + i)) {
LOG_KNOWHERE_ERROR_ << "error happens in StructuredIndexSort<T>::In, experted value is: "
<< *(values + i) << ", but real value is: " << lb->a_;
}
bitset->set(lb->idx_);
}
}
return bitset;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexSort<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size(), 0xff);
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
for (; lb < ub; ++lb) {
if (lb->a_ != *(values + i)) {
LOG_KNOWHERE_ERROR_ << "error happens in StructuredIndexSort<T>::NotIn, experted value is: "
<< *(values + i) << ", but real value is: " << lb->a_;
}
bitset->clear(lb->idx_);
}
}
return bitset;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexSort<T>::Range(const T value, const OperatorType op) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
auto lb = data_.begin();
auto ub = data_.end();
switch (op) {
case OperatorType::LT:
ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
break;
case OperatorType::LE:
ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
break;
case OperatorType::GT:
lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
break;
case OperatorType::GE:
lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
break;
default:
KNOWHERE_THROW_MSG("Invalid OperatorType:" + std::to_string((int)op) + "!");
}
for (; lb < ub; ++lb) {
bitset->set(lb->idx_);
}
return bitset;
}
template <typename T>
const faiss::ConcurrentBitsetPtr
StructuredIndexSort<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
if (!is_built_) {
build();
}
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(data_.size());
if (lower_bound_value > upper_bound_value) {
std::swap(lower_bound_value, upper_bound_value);
std::swap(lb_inclusive, ub_inclusive);
}
auto lb = data_.begin();
auto ub = data_.end();
if (lb_inclusive) {
lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
} else {
lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
}
if (ub_inclusive) {
ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
} else {
ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
}
for (; lb < ub; ++lb) {
bitset->set(lb->idx_);
}
return bitset;
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,80 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/index/structured_index/StructuredIndex.h"
namespace milvus {
namespace knowhere {
template <typename T>
class StructuredIndexSort : public StructuredIndex<T> {
public:
StructuredIndexSort();
StructuredIndexSort(const size_t n, const T* values);
~StructuredIndexSort();
BinarySet
Serialize(const Config& config = Config()) override;
void
Load(const BinarySet& index_binary) override;
void
Build(const size_t n, const T* values) override;
void
build();
const faiss::ConcurrentBitsetPtr
In(const size_t n, const T* values) override;
const faiss::ConcurrentBitsetPtr
NotIn(const size_t n, const T* values) override;
const faiss::ConcurrentBitsetPtr
Range(const T value, const OperatorType op) override;
const faiss::ConcurrentBitsetPtr
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
const std::vector<IndexStructure<T>>&
GetData() {
return data_;
}
int64_t
Size() override {
return (int64_t)data_.size();
}
bool
IsBuilt() const {
return is_built_;
}
private:
bool is_built_;
std::vector<IndexStructure<T>> data_;
};
template <typename T>
using StructuredIndexSortPtr = std::shared_ptr<StructuredIndexSort<T>>;
} // namespace knowhere
} // namespace milvus
#include "knowhere/index/structured_index/StructuredIndexSort-inl.h"

View File

@ -0,0 +1,372 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/ConfAdapter.h"
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
#include "faiss/gpu/utils/DeviceUtils.h"
#endif
namespace milvus {
namespace knowhere {
static const int64_t MIN_NLIST = 1;
static const int64_t MAX_NLIST = 1LL << 20;
static const int64_t MIN_NPROBE = 1;
static const int64_t MAX_NPROBE = MAX_NLIST;
static const int64_t DEFAULT_MIN_DIM = 1;
static const int64_t DEFAULT_MAX_DIM = 32768;
static const int64_t DEFAULT_MIN_ROWS = 1; // minimum size for build index
static const int64_t DEFAULT_MAX_ROWS = 50000000;
static const std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
#define CheckIntByRange(key, min, max) \
if (!oricfg.contains(key) || !oricfg[key].is_number_integer() || oricfg[key].get<int64_t>() > max || \
oricfg[key].get<int64_t>() < min) { \
return false; \
}
#define CheckIntByValues(key, container) \
if (!oricfg.contains(key) || !oricfg[key].is_number_integer()) { \
return false; \
} else { \
auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get<int64_t>()); \
if (finder == std::end(container)) { \
return false; \
} \
}
#define CheckStrByValues(key, container) \
if (!oricfg.contains(key) || !oricfg[key].is_string()) { \
return false; \
} else { \
auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get<std::string>()); \
if (finder == std::end(container)) { \
return false; \
} \
}
bool
ConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
return true;
}
bool
ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
const int64_t DEFAULT_MIN_K = 1;
const int64_t DEFAULT_MAX_K = 16384;
CheckIntByRange(knowhere::meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K);
return true;
}
int64_t
MatchNlist(int64_t size, int64_t nlist) {
const int64_t TYPICAL_COUNT = 1000000;
const int64_t PER_NLIST = 16384;
if (nlist * TYPICAL_COUNT > size * PER_NLIST) {
// nlist is too large, adjust to a proper value
nlist = std::max(1L, size * PER_NLIST / TYPICAL_COUNT);
}
return nlist;
}
bool
IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// auto tune params
auto nq = oricfg[knowhere::meta::ROWS].get<int64_t>();
auto nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist);
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
int64_t max_nprobe = MAX_NPROBE;
#ifdef MILVUS_GPU_VERSION
if (mode == IndexMode::MODE_GPU) {
max_nprobe = faiss::gpu::getMaxKSelection();
}
#endif
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, max_nprobe);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
const int64_t DEFAULT_NBITS = 8;
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
return IVFConfAdapter::CheckTrain(oricfg, mode);
}
bool
IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
const int64_t DEFAULT_NBITS = 8;
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// auto tune params
oricfg[knowhere::IndexParams::nlist] =
MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);
CheckIntByValues(knowhere::IndexParams::m, resset);
return true;
}
void
IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset) {
resset.clear();
/*
* Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
* no precomputed codes. Precomputed codes supports any number of dimensions, but will involve memory overheads.
*/
static const std::vector<int64_t> support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
static const std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
if (!(dimension % dimperquantizer)) {
auto subquantzier_num = dimension / dimperquantizer;
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);
if (finder != support_subquantizer.end()) {
resset.push_back(subquantzier_num);
}
}
}
}
bool
NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
const int64_t MIN_KNNG = 5;
const int64_t MAX_KNNG = 300;
const int64_t MIN_SEARCH_LENGTH = 10;
const int64_t MAX_SEARCH_LENGTH = 300;
const int64_t MIN_OUT_DEGREE = 5;
const int64_t MAX_OUT_DEGREE = 300;
const int64_t MIN_CANDIDATE_POOL_SIZE = 50;
const int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::knng, MIN_KNNG, MAX_KNNG);
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
CheckIntByRange(knowhere::IndexParams::out_degree, MIN_OUT_DEGREE, MAX_OUT_DEGREE);
CheckIntByRange(knowhere::IndexParams::candidate, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE);
// auto tune params
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), 8192);
int64_t nprobe = int(oricfg[knowhere::IndexParams::nlist].get<int64_t>() * 0.1);
oricfg[knowhere::IndexParams::nprobe] = nprobe < 1 ? 1 : nprobe;
return true;
}
bool
NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MIN_SEARCH_LENGTH = 1;
static int64_t MAX_SEARCH_LENGTH = 300;
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);
CheckIntByValues(knowhere::IndexParams::PQM, resset);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_EFCONSTRUCTION = 8;
static int64_t MAX_EFCONSTRUCTION = 512;
static int64_t MIN_M = 4;
static int64_t MAX_M = 64;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
bool
BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static const std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
knowhere::Metric::TANIMOTO, knowhere::Metric::SUBSTRUCTURE,
knowhere::Metric::SUPERSTRUCTURE};
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
return true;
}
bool
BinIVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static const std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
knowhere::Metric::TANIMOTO};
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
int64_t nlist = oricfg[knowhere::IndexParams::nlist];
CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
return true;
}
bool
ANNOYConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MIN_NTREES = 1;
// too large of n_trees takes much time, if there is real requirement, change this threshold.
static int64_t MAX_NTREES = 1024;
CheckIntByRange(knowhere::IndexParams::n_trees, MIN_NTREES, MAX_NTREES);
return ConfAdapter::CheckTrain(oricfg, mode);
}
bool
ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
CheckIntByRange(knowhere::IndexParams::search_k, std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max());
return ConfAdapter::CheckSearch(oricfg, type, mode);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,124 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <vector>
#include "knowhere/common/Config.h"
#include "knowhere/index/IndexType.h"
namespace milvus {
namespace knowhere {
class ConfAdapter {
public:
virtual bool
CheckTrain(Config& oricfg, const IndexMode mode);
virtual bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode);
};
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
class IVFConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class IVFSQConfAdapter : public IVFConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
};
class IVFPQConfAdapter : public IVFConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
static void
GetValidMList(int64_t dimension, std::vector<int64_t>& resset);
};
class NSGConfAdapter : public IVFConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class BinIDMAPConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
};
class BinIVFConfAdapter : public IVFConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
};
class HNSWConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class ANNOYConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class RHNSWFlatConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class RHNSWPQConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
class RHNSWSQConfAdapter : public ConfAdapter {
public:
bool
CheckTrain(Config& oricfg, const IndexMode mode) override;
bool
CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,59 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/ConfAdapterMgr.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
namespace milvus {
namespace knowhere {
ConfAdapterPtr
AdapterMgr::GetAdapter(const IndexType type) {
if (!init_) {
RegisterAdapter();
}
try {
return collection_.at(type)();
} catch (...) {
KNOWHERE_THROW_MSG("Can not find confadapter: " + type);
}
}
#define REGISTER_CONF_ADAPTER(T, TYPE, NAME) static AdapterMgr::register_t<T> reg_##NAME##_(TYPE)
void
AdapterMgr::RegisterAdapter() {
init_ = true;
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_FAISS_IDMAP, idmap_adapter);
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexEnum::INDEX_FAISS_IVFFLAT, ivf_adapter);
REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_adapter);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter);
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter);
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter);
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter);
#ifdef MILVUS_SUPPORT_SPTAG
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_KDT_RNT, sptag_kdt_adapter);
REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_BKT_RNT, sptag_bkt_adapter);
#endif
REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexEnum::INDEX_HNSW, hnsw_adapter);
REGISTER_CONF_ADAPTER(ANNOYConfAdapter, IndexEnum::INDEX_ANNOY, annoy_adapter);
REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter);
REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter);
REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,51 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <unordered_map>
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/ConfAdapter.h"
namespace milvus {
namespace knowhere {
class AdapterMgr {
public:
template <typename T>
struct register_t {
explicit register_t(const IndexType type) {
AdapterMgr::GetInstance().collection_[type] = ([] { return std::make_shared<T>(); });
}
};
static AdapterMgr&
GetInstance() {
static AdapterMgr instance;
return instance;
}
ConfAdapterPtr
GetAdapter(const IndexType indexType);
void
RegisterAdapter();
protected:
bool init_ = false;
std::unordered_map<IndexType, std::function<ConfAdapterPtr()>> collection_;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,51 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <faiss/index_io.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
BinarySet
FaissBaseBinaryIndex::SerializeImpl(const IndexType& type) {
try {
faiss::IndexBinary* index = index_.get();
MemoryIOWriter writer;
faiss::write_index_binary(index, &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
BinarySet res_set;
res_set.Append("BinaryIVF", data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary, const IndexType& type) {
auto binary = index_binary.GetByName("BinaryIVF");
MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();
faiss::IndexBinary* index = faiss::read_index_binary(&reader);
index_.reset(index);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,42 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include <faiss/IndexBinary.h>
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/Dataset.h"
#include "knowhere/index/IndexType.h"
namespace milvus {
namespace knowhere {
class FaissBaseBinaryIndex {
protected:
explicit FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index) : index_(std::move(index)) {
}
virtual BinarySet
SerializeImpl(const IndexType& type);
virtual void
LoadImpl(const BinarySet& index_binary, const IndexType& type);
public:
std::shared_ptr<faiss::IndexBinary> index_ = nullptr;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,61 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <faiss/index_io.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
BinarySet
FaissBaseIndex::SerializeImpl(const IndexType& type) {
try {
faiss::Index* index = index_.get();
MemoryIOWriter writer;
faiss::write_index(index, &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
BinarySet res_set;
// TODO(linxj): use virtual func Name() instead of raw string.
res_set.Append("IVF", data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
FaissBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
auto binary = binary_set.GetByName("IVF");
MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();
faiss::Index* index = faiss::read_index(&reader);
index_.reset(index);
SealImpl();
}
void
FaissBaseIndex::SealImpl() {
}
// FaissBaseIndex::~FaissBaseIndex() {}
//
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,44 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <utility>
#include <faiss/Index.h>
#include "knowhere/common/BinarySet.h"
#include "knowhere/index/IndexType.h"
namespace milvus {
namespace knowhere {
class FaissBaseIndex {
protected:
explicit FaissBaseIndex(std::shared_ptr<faiss::Index> index) : index_(std::move(index)) {
}
virtual BinarySet
SerializeImpl(const IndexType& type);
virtual void
LoadImpl(const BinarySet&, const IndexType& type);
virtual void
SealImpl();
public:
std::shared_ptr<faiss::Index> index_ = nullptr;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,172 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexAnnoy.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
BinarySet
IndexAnnoy::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto metric_type_length = metric_type_.length();
std::shared_ptr<uint8_t[]> metric_type(new uint8_t[metric_type_length]);
memcpy(metric_type.get(), metric_type_.data(), metric_type_.length());
auto dim = Dim();
std::shared_ptr<uint8_t[]> dim_data(new uint8_t[sizeof(uint64_t)]);
memcpy(dim_data.get(), &dim, sizeof(uint64_t));
size_t index_length = index_->get_index_length();
std::shared_ptr<uint8_t[]> index_data(new uint8_t[index_length]);
memcpy(index_data.get(), index_->get_index(), index_length);
BinarySet res_set;
res_set.Append("annoy_metric_type", metric_type, metric_type_length);
res_set.Append("annoy_dim", dim_data, sizeof(uint64_t));
res_set.Append("annoy_index_data", index_data, index_length);
return res_set;
}
void
IndexAnnoy::Load(const BinarySet& index_binary) {
auto metric_type = index_binary.GetByName("annoy_metric_type");
metric_type_.resize(static_cast<size_t>(metric_type->size));
memcpy(metric_type_.data(), metric_type->data.get(), static_cast<size_t>(metric_type->size));
auto dim_data = index_binary.GetByName("annoy_dim");
uint64_t dim;
memcpy(&dim, dim_data->data.get(), static_cast<size_t>(dim_data->size));
if (metric_type_ == Metric::L2) {
index_ = std::make_shared<AnnoyIndex<int64_t, float, ::Euclidean, ::Kiss64Random>>(dim);
} else if (metric_type_ == Metric::IP) {
index_ = std::make_shared<AnnoyIndex<int64_t, float, ::DotProduct, ::Kiss64Random>>(dim);
} else {
KNOWHERE_THROW_MSG("metric not supported " + metric_type_);
}
auto index_data = index_binary.GetByName("annoy_index_data");
char* p = nullptr;
if (!index_->load_index(reinterpret_cast<void*>(index_data->data.get()), index_data->size, &p)) {
std::string error_msg(p);
free(p);
KNOWHERE_THROW_MSG(error_msg);
}
}
void
IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
if (index_) {
// it is builded all
LOG_KNOWHERE_DEBUG_ << "IndexAnnoy::BuildAll: index_ has been built!";
return;
}
GET_TENSOR(dataset_ptr)
metric_type_ = config[Metric::TYPE];
if (metric_type_ == Metric::L2) {
index_ = std::make_shared<AnnoyIndex<int64_t, float, ::Euclidean, ::Kiss64Random>>(dim);
} else if (metric_type_ == Metric::IP) {
index_ = std::make_shared<AnnoyIndex<int64_t, float, ::DotProduct, ::Kiss64Random>>(dim);
} else {
KNOWHERE_THROW_MSG("metric not supported " + metric_type_);
}
for (int i = 0; i < rows; ++i) {
index_->add_item(p_ids[i], static_cast<const float*>(p_data) + dim * i);
}
index_->build(config[IndexParams::n_trees].get<int64_t>());
}
DatasetPtr
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA_DIM(dataset_ptr)
auto k = config[meta::TOPK].get<int64_t>();
auto search_k = config[IndexParams::search_k].get<int64_t>();
auto all_num = rows * k;
auto p_id = static_cast<int64_t*>(malloc(all_num * sizeof(int64_t)));
auto p_dist = static_cast<float*>(malloc(all_num * sizeof(float)));
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<int64_t> result;
result.reserve(k);
std::vector<float> distances;
distances.reserve(k);
index_->get_nns_by_vector(static_cast<const float*>(p_data) + i * dim, k, search_k, &result, &distances,
blacklist);
int64_t result_num = result.size();
auto local_p_id = p_id + k * i;
auto local_p_dist = p_dist + k * i;
memcpy(local_p_id, result.data(), result_num * sizeof(int64_t));
memcpy(local_p_dist, distances.data(), result_num * sizeof(float));
for (; result_num < k; result_num++) {
local_p_id[result_num] = -1;
local_p_dist[result_num] = 1.0 / 0.0;
}
}
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
int64_t
IndexAnnoy::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->get_n_items();
}
int64_t
IndexAnnoy::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->get_dim();
}
void
IndexAnnoy::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = index_->cal_size();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,74 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include "annoy/src/annoylib.h"
#include "annoy/src/kissrandom.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
class IndexAnnoy : public VecIndex {
public:
IndexAnnoy() {
index_type_ = IndexEnum::INDEX_ANNOY;
}
BinarySet
Serialize(const Config& config) override;
void
Load(const BinarySet& index_binary) override;
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("Annoy not support build item dynamically, please invoke BuildAll interface.");
}
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("Annoy not support add item dynamically, please invoke BuildAll interface.");
}
void
AddWithoutIds(const DatasetPtr&, const Config&) override {
KNOWHERE_THROW_MSG("Incremental index is not supported");
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
int64_t
Count() override;
int64_t
Dim() override;
void
UpdateIndexSize() override;
private:
MetricType metric_type_;
std::shared_ptr<AnnoyIndexInterface<int64_t, float>> index_ = nullptr;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,163 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
#include <faiss/IndexBinaryFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/index_factory.h>
#include <string>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
namespace milvus {
namespace knowhere {
BinarySet
BinaryIDMAP::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
}
void
BinaryIDMAP::Load(const BinarySet& index_binary) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(index_binary, index_type_);
}
DatasetPtr
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr)
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
int64_t
BinaryIDMAP::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
BinaryIDMAP::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->d;
}
void
BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
GET_TENSOR_DATA_ID(dataset_ptr)
index_->add_with_ids(rows, reinterpret_cast<const uint8_t*>(p_data), p_ids);
}
void
BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// users will assign the metric type when querying
// so we let Tanimoto be the default type
constexpr faiss::MetricType metric_type = faiss::METRIC_Tanimoto;
const char* desc = "BFlat";
auto dim = config[meta::DIM].get<int64_t>();
auto index = faiss::index_binary_factory(dim, desc, metric_type);
index_.reset(index);
}
const uint8_t*
BinaryIDMAP::GetRawVectors() {
try {
auto file_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get());
auto flat_index = dynamic_cast<faiss::IndexBinaryFlat*>(file_index->index);
return flat_index->xb.data();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
const int64_t*
BinaryIDMAP::GetRawIds() {
try {
auto file_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get());
return file_index->id_map.data();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
GET_TENSOR_DATA(dataset_ptr)
std::vector<int64_t> new_ids(rows);
for (int i = 0; i < rows; ++i) {
new_ids[i] = i;
}
index_->add_with_ids(rows, reinterpret_cast<const uint8_t*>(p_data), new_ids.data());
}
void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
// assign the metric type
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto i_distances = reinterpret_cast<int32_t*>(distances);
bin_flat_index->search(n, data, k, i_distances, labels, bitset_);
// if hamming, it need transform int32 to float
if (bin_flat_index->metric_type == faiss::METRIC_Hamming) {
int64_t num = n * k;
for (int64_t i = 0; i < num; i++) {
distances[i] = static_cast<float>(i_distances[i]);
}
}
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,81 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
public:
BinaryIDMAP() : FaissBaseBinaryIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_BIN_IDMAP;
}
explicit BinaryIDMAP(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_BIN_IDMAP;
}
BinarySet
Serialize(const Config&) override;
void
Load(const BinarySet&) override;
void
Train(const DatasetPtr&, const Config&) override;
void
Add(const DatasetPtr&, const Config&) override;
void
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
int64_t
Count() override;
int64_t
Dim() override;
int64_t
IndexSize() override {
return Count() * Dim() / 8;
}
virtual const uint8_t*
GetRawVectors();
virtual const int64_t*
GetRawIds();
protected:
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
protected:
std::mutex mutex_;
};
using BinaryIDMAPPtr = std::shared_ptr<BinaryIDMAP>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,157 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
#include <faiss/IndexBinaryFlat.h>
#include <faiss/IndexBinaryIVF.h>
#include <chrono>
#include <string>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
namespace milvus {
namespace knowhere {
using stdclock = std::chrono::high_resolution_clock;
BinarySet
BinaryIVF::Serialize(const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
}
void
BinaryIVF::Load(const BinarySet& index_binary) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(index_binary, index_type_);
}
DatasetPtr
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA(dataset_ptr)
try {
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
} catch (faiss::FaissException& e) {
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
int64_t
BinaryIVF::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
BinaryIVF::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->d;
}
void
BinaryIVF::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
auto bin_ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
auto nb = bin_ivf_index->invlists->compute_ntotal();
auto nlist = bin_ivf_index->nlist;
auto code_size = bin_ivf_index->code_size;
// binary ivf codes, ids and quantizer
index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size;
}
void
BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR(dataset_ptr)
int64_t nlist = config[IndexParams::nlist];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, metric_type);
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, nlist, metric_type);
index->train(rows, static_cast<const uint8_t*>(p_data));
index->add_with_ids(rows, static_cast<const uint8_t*>(p_data), p_ids);
index_ = index;
}
std::shared_ptr<faiss::IVFSearchParameters>
BinaryIVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
params->nprobe = config[IndexParams::nprobe];
// params->max_codes = config["max_code"];
return params;
}
void
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
stdclock::time_point before = stdclock::now();
auto i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, data, k, i_distances, labels, bitset_);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
// if hamming, it need transform int32 to float
if (ivf_index->metric_type == faiss::METRIC_Hamming) {
int64_t num = n * k;
for (int64_t i = 0; i < num; i++) {
distances[i] = static_cast<float>(i_distances[i]);
}
}
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,88 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include <faiss/IndexIVF.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
public:
BinaryIVF() : FaissBaseBinaryIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
}
explicit BinaryIVF(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
}
BinarySet
Serialize(const Config& config) override;
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override {
Train(dataset_ptr, config);
}
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override {
KNOWHERE_THROW_MSG("not support yet");
}
void
AddWithoutIds(const DatasetPtr&, const Config&) override {
KNOWHERE_THROW_MSG("AddWithoutIds is not supported");
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
int64_t
Count() override;
int64_t
Dim() override;
void
UpdateIndexSize() override;
protected:
virtual std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config& config);
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
protected:
std::mutex mutex_;
};
using BinaryIVFIndexPtr = std::shared_ptr<BinaryIVF>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,222 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexHNSW.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "hnswlib/hnswalg.h"
#include "hnswlib/space_ip.h"
#include "hnswlib/space_l2.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
// void
// normalize_vector(float* data, float* norm_array, size_t dim) {
// float norm = 0.0f;
// for (int i = 0; i < dim; i++) norm += data[i] * data[i];
// norm = 1.0f / (sqrtf(norm) + 1e-30f);
// for (int i = 0; i < dim; i++) norm_array[i] = data[i] * norm;
// }
BinarySet
IndexHNSW::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
MemoryIOWriter writer;
index_->saveIndex(writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
BinarySet res_set;
res_set.Append("HNSW", data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexHNSW::Load(const BinarySet& index_binary) {
try {
auto binary = index_binary.GetByName("HNSW");
MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space);
index_->loadIndex(reader);
normalize = index_->metric_type_ == 1; // 1 == InnerProduct
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
try {
auto dim = dataset_ptr->Get<int64_t>(meta::DIM);
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
hnswlib::SpaceInterface<float>* space;
std::string metric_type = config[Metric::TYPE];
if (metric_type == Metric::L2) {
space = new hnswlib::L2Space(dim);
} else if (metric_type == Metric::IP) {
space = new hnswlib::InnerProductSpace(dim);
normalize = true;
} else {
KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
}
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
config[IndexParams::efConstruction].get<int64_t>());
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
GET_TENSOR_DATA_ID(dataset_ptr)
// if (normalize) {
// std::vector<float> ep_norm_vector(Dim());
// normalize_vector((float*)(p_data), ep_norm_vector.data(), Dim());
// index_->addPoint((void*)(ep_norm_vector.data()), p_ids[0]);
// #pragma omp parallel for
// for (int i = 1; i < rows; ++i) {
// std::vector<float> norm_vector(Dim());
// normalize_vector((float*)(p_data + Dim() * i), norm_vector.data(), Dim());
// index_->addPoint((void*)(norm_vector.data()), p_ids[i]);
// }
// } else {
// index_->addPoint((void*)(p_data), p_ids[0]);
// #pragma omp parallel for
// for (int i = 1; i < rows; ++i) {
// index_->addPoint((void*)(p_data + Dim() * i), p_ids[i]);
// }
// }
index_->addPoint(p_data, p_ids[0]);
#pragma omp parallel for
for (int i = 1; i < rows; ++i) {
faiss::BuilderSuspend::check_wait();
index_->addPoint((reinterpret_cast<const float*>(p_data) + Dim() * i), p_ids[i]);
}
}
DatasetPtr
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA(dataset_ptr)
size_t k = config[meta::TOPK].get<int64_t>();
size_t id_size = sizeof(int64_t) * k;
size_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
index_->setEf(config[IndexParams::ef]);
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<P> ret;
const float* single_query = reinterpret_cast<const float*>(p_data) + i * Dim();
// if (normalize) {
// std::vector<float> norm_vector(Dim());
// normalize_vector((float*)(single_query), norm_vector.data(), Dim());
// ret = index_->searchKnn((float*)(norm_vector.data()), config[meta::TOPK].get<int64_t>(), compare);
// } else {
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn(single_query, k, compare, blacklist);
while (ret.size() < k) {
ret.emplace_back(std::make_pair(-1, -1));
}
std::vector<float> dist;
std::vector<int64_t> ids;
if (normalize) {
std::transform(ret.begin(), ret.end(), std::back_inserter(dist),
[](const std::pair<float, int64_t>& e) { return float(1 - e.first); });
} else {
std::transform(ret.begin(), ret.end(), std::back_inserter(dist),
[](const std::pair<float, int64_t>& e) { return e.first; });
}
std::transform(ret.begin(), ret.end(), std::back_inserter(ids),
[](const std::pair<float, int64_t>& e) { return e.second; });
memcpy(p_dist + i * k, dist.data(), dist_size);
memcpy(p_id + i * k, ids.data(), id_size);
}
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
int64_t
IndexHNSW::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->cur_element_count;
}
int64_t
IndexHNSW::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return (*static_cast<size_t*>(index_->dist_func_param_));
}
void
IndexHNSW::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = index_->cal_size();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,67 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include "hnswlib/hnswlib.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
class IndexHNSW : public VecIndex {
public:
IndexHNSW() {
index_type_ = IndexEnum::INDEX_HNSW;
}
BinarySet
Serialize(const Config& config) override;
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override;
void
AddWithoutIds(const DatasetPtr&, const Config&) override {
KNOWHERE_THROW_MSG("Incremental index is not supported");
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
int64_t
Count() override;
int64_t
Dim() override;
void
UpdateIndexSize() override;
private:
bool normalize = false;
std::mutex mutex_;
std::shared_ptr<hnswlib::HierarchicalNSW<float>> index_;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,234 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include <faiss/AutoTune.h>
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/clone_index.h>
#include <faiss/index_factory.h>
#include <faiss/index_io.h>
#ifdef MILVUS_GPU_VERSION
#include <faiss/gpu/GpuCloner.h>
#endif
#include <string>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
namespace milvus {
namespace knowhere {
BinarySet
IDMAP::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
}
void
IDMAP::Load(const BinarySet& binary_set) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(binary_set, index_type_);
}
void
IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// users will assign the metric type when querying
// so we let L2 be the default type
constexpr faiss::MetricType metric_type = faiss::METRIC_L2;
const char* desc = "IDMap,Flat";
auto dim = config[meta::DIM].get<int64_t>();
auto index = faiss::index_factory(dim, desc, metric_type);
index_.reset(index);
}
void
IDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
GET_TENSOR_DATA_ID(dataset_ptr)
index_->add_with_ids(rows, reinterpret_cast<const float*>(p_data), p_ids);
}
void
IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const void*>(meta::TENSOR);
// TODO: caiyd need check
std::vector<int64_t> new_ids(rows);
for (int i = 0; i < rows; ++i) {
new_ids[i] = i;
}
index_->add_with_ids(rows, reinterpret_cast<const float*>(p_data), new_ids.data());
}
DatasetPtr
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr)
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
#if 0
DatasetPtr
IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
// GETTENSOR(dataset)
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
int64_t k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
// todo: enable search by id (zhiru)
// auto blacklist = dataset_ptr->Get<faiss::ConcurrentBitsetPtr>("bitset");
// index_->searchById(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, blacklist);
index_->search_by_id(rows, p_data, k, p_dist, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
#endif
int64_t
IDMAP::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
IDMAP::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->d;
}
VecIndexPtr
IDMAP::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIDMAP>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
#else
KNOWHERE_THROW_MSG("Calling IDMAP::CopyCpuToGpu when we are using CPU version");
#endif
}
const float*
IDMAP::GetRawVectors() {
try {
auto file_index = dynamic_cast<faiss::IndexIDMap*>(index_.get());
auto flat_index = dynamic_cast<faiss::IndexFlat*>(file_index->index);
return flat_index->xb.data();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
const int64_t*
IDMAP::GetRawIds() {
try {
auto file_index = dynamic_cast<faiss::IndexIDMap*>(index_.get());
return file_index->id_map.data();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
#if 0
DatasetPtr
IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
// GETTENSOR(dataset)
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
size_t p_x_size = sizeof(float) * elems;
auto p_x = (float*)malloc(p_x_size);
index_->get_vector_by_id(1, p_data, p_x, bitset_);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::TENSOR, p_x);
return ret_ds;
}
#endif
void
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
// assign the metric type
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, data, k, distances, labels, bitset_);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,92 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <utility>
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
class IDMAP : public VecIndex, public FaissBaseIndex {
public:
IDMAP() : FaissBaseIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_IDMAP;
}
explicit IDMAP(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IDMAP;
}
BinarySet
Serialize(const Config&) override;
void
Load(const BinarySet&) override;
void
Train(const DatasetPtr&, const Config&) override;
void
Add(const DatasetPtr&, const Config&) override;
void
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr
QueryById(const DatasetPtr& dataset, const Config& config) override;
#endif
int64_t
Count() override;
int64_t
Dim() override;
int64_t
IndexSize() override {
return Count() * Dim() * sizeof(FloatType);
}
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
#endif
VecIndexPtr
CopyCpuToGpu(const int64_t, const Config&);
virtual const float*
GetRawVectors();
virtual const int64_t*
GetRawIds();
protected:
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
protected:
std::mutex mutex_;
};
using IDMAPPtr = std::shared_ptr<IDMAP>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,349 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <faiss/AutoTune.h>
#include <faiss/IVFlib.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVF.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/clone_index.h>
#include <faiss/index_factory.h>
#include <faiss/index_io.h>
#ifdef MILVUS_GPU_VERSION
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/GpuCloner.h>
#endif
#include <chrono>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
namespace milvus {
namespace knowhere {
using stdclock = std::chrono::high_resolution_clock;
BinarySet
IVF::Serialize(const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
std::lock_guard<std::mutex> lk(mutex_);
return SerializeImpl(index_type_);
}
void
IVF::Load(const BinarySet& binary_set) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(binary_set, index_type_);
}
void
IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr)
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
auto nlist = config[IndexParams::nlist].get<int64_t>();
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
}
void
IVF::Add(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
std::lock_guard<std::mutex> lk(mutex_);
GET_TENSOR_DATA_ID(dataset_ptr)
index_->add_with_ids(rows, reinterpret_cast<const float*>(p_data), p_ids);
}
void
IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
std::lock_guard<std::mutex> lk(mutex_);
GET_TENSOR_DATA(dataset_ptr)
index_->add(rows, reinterpret_cast<const float*>(p_data));
}
DatasetPtr
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA(dataset_ptr)
try {
auto k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
// printf("%llu", p_id[i]);
// printf("\n");
// printf("%.6f", p_dist[i]);
// printf("\n");
// ss_res_id << p_id[i] << " ";
// ss_res_dist << p_dist[i] << " ";
// }
// std::cout << std::endl << "after search: " << std::endl;
// std::cout << ss_res_id.str() << std::endl;
// std::cout << ss_res_dist.str() << std::endl << std::endl;
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
} catch (faiss::FaissException& e) {
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
#if 0
DatasetPtr
IVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
try {
int64_t k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
// todo: enable search by id (zhiru)
// auto blacklist = dataset_ptr->Get<faiss::ConcurrentBitsetPtr>("bitset");
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
index_ivf->search_by_id(rows, p_data, k, p_dist, p_id, bitset_);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
// printf("%llu", res_ids[i]);
// printf("\n");
// printf("%.6f", res_dis[i]);
// printf("\n");
// ss_res_id << res_ids[i] << " ";
// ss_res_dist << res_dis[i] << " ";
// }
// std::cout << std::endl << "after search: " << std::endl;
// std::cout << ss_res_id.str() << std::endl;
// std::cout << ss_res_dist.str() << std::endl << std::endl;
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
} catch (faiss::FaissException& e) {
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
DatasetPtr
IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
try {
size_t p_x_size = sizeof(float) * elems;
auto p_x = (float*)malloc(p_x_size);
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
index_ivf->get_vector_by_id(1, p_data, p_x, bitset_);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::TENSOR, p_x);
return ret_ds;
} catch (faiss::FaissException& e) {
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
#endif
int64_t
IVF::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
IVF::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->d;
}
void
IVF::Seal() {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
SealImpl();
}
void
IVF::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
auto ivf_index = dynamic_cast<faiss::IndexIVFFlat*>(index_.get());
auto nb = ivf_index->invlists->compute_ntotal();
auto nlist = ivf_index->nlist;
auto code_size = ivf_index->code_size;
// ivf codes, ivf ids and quantizer
index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size;
}
VecIndexPtr
IVF::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIVF>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
#else
KNOWHERE_THROW_MSG("Calling IVF::CopyCpuToGpu when we are using CPU version");
#endif
}
void
IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) {
int64_t K = k + 1;
auto ntotal = Count();
size_t dim = config[meta::DIM];
auto batch_size = 1000;
auto tail_batch_size = ntotal % batch_size;
auto batch_search_count = ntotal / batch_size;
auto total_search_count = tail_batch_size == 0 ? batch_search_count : batch_search_count + 1;
std::vector<float> res_dis(K * batch_size);
graph.resize(ntotal);
GraphType res_vec(total_search_count);
for (int i = 0; i < total_search_count; ++i) {
// it is usually used in NSG::train, to check BuilderSuspend
faiss::BuilderSuspend::check_wait();
auto b_size = (i == (total_search_count - 1)) && tail_batch_size != 0 ? tail_batch_size : batch_size;
auto& res = res_vec[i];
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];
node.resize(k);
auto start_pos = j * K + 1;
for (int m = 0, cursor = start_pos; m < k && cursor < start_pos + k; ++m, ++cursor) {
node[m] = res[cursor];
}
}
}
}
std::shared_ptr<faiss::IVFSearchParameters>
IVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
params->nprobe = config[IndexParams::nprobe];
// params->max_codes = config["max_codes"];
return params;
}
void
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
stdclock::time_point before = stdclock::now();
if (params->nprobe > 1 && n <= 4) {
ivf_index->parallel_mode = 1;
} else {
ivf_index->parallel_mode = 0;
}
ivf_index->search(n, data, k, distances, labels, bitset_);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
}
void
IVF::SealImpl() {
#ifdef MILVUS_GPU_VERSION
faiss::Index* index = index_.get();
auto idx = dynamic_cast<faiss::IndexIVF*>(index);
if (idx != nullptr) {
idx->to_readonly();
}
#endif
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,101 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include <faiss/IndexIVF.h>
#include "knowhere/common/Typedef.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
class IVF : public VecIndex, public FaissBaseIndex {
public:
IVF() : FaissBaseIndex(nullptr) {
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
}
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT;
}
BinarySet
Serialize(const Config&) override;
void
Load(const BinarySet&) override;
void
Train(const DatasetPtr&, const Config&) override;
void
Add(const DatasetPtr&, const Config&) override;
void
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr
QueryById(const DatasetPtr& dataset, const Config& config) override;
#endif
int64_t
Count() override;
int64_t
Dim() override;
void
UpdateIndexSize() override;
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) override;
#endif
virtual void
Seal();
virtual VecIndexPtr
CopyCpuToGpu(const int64_t, const Config&);
virtual void
GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config);
protected:
virtual std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config&);
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
void
SealImpl() override;
protected:
std::mutex mutex_;
};
using IVFPtr = std::shared_ptr<IVF>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,100 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <string>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/clone_index.h>
#ifdef MILVUS_GPU_VERSION
#include <faiss/gpu/GpuCloner.h>
#endif
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
#endif
namespace milvus {
namespace knowhere {
void
IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr)
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFPQ(
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m].get<int64_t>(),
config[IndexParams::nbits].get<int64_t>(), metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
}
VecIndexPtr
IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIVFPQ>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
#else
KNOWHERE_THROW_MSG("Calling IVFPQ::CopyCpuToGpu when we are using CPU version");
#endif
}
std::shared_ptr<faiss::IVFSearchParameters>
IVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
params->nprobe = config[IndexParams::nprobe];
// params->scan_table_threshold = config["scan_table_threhold"]
// params->polysemous_ht = config["polysemous_ht"]
// params->max_codes = config["max_codes"]
return params;
}
void
IVFPQ::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
auto ivfpq_index = dynamic_cast<faiss::IndexIVFPQ*>(index_.get());
auto nb = ivfpq_index->invlists->compute_ntotal();
auto code_size = ivfpq_index->code_size;
auto pq = ivfpq_index->pq;
auto nlist = ivfpq_index->nlist;
auto d = ivfpq_index->d;
// ivf codes, ivf ids and quantizer
auto capacity = nb * code_size + nb * sizeof(int64_t) + nlist * d * sizeof(float);
auto centroid_table = pq.M * pq.ksub * pq.dsub * sizeof(float);
auto precomputed_table = nlist * pq.M * pq.ksub * sizeof(float);
if (precomputed_table > ivfpq_index->precomputed_table_max_bytes) {
// will not precompute table
precomputed_table = 0;
}
index_size_ = capacity + centroid_table + precomputed_table;
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,49 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <utility>
#include "knowhere/index/vector_index/IndexIVF.h"
namespace milvus {
namespace knowhere {
class IVFPQ : public IVF {
public:
IVFPQ() : IVF() {
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
}
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFPQ;
}
void
Train(const DatasetPtr&, const Config&) override;
VecIndexPtr
CopyCpuToGpu(const int64_t, const Config&) override;
void
UpdateIndexSize() override;
protected:
std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config& config) override;
};
using IVFPQPtr = std::shared_ptr<IVFPQ>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,88 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <memory>
#include <string>
#ifdef MILVUS_GPU_VERSION
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/GpuCloner.h>
#endif
#include <faiss/IndexFlat.h>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/clone_index.h>
#include <faiss/index_factory.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
namespace milvus {
namespace knowhere {
void
IVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
GET_TENSOR_DATA_DIM(dataset_ptr)
// std::stringstream index_type;
// index_type << "IVF" << config[IndexParams::nlist] << ","
// << "SQ" << config[IndexParams::nbits];
// index_ = std::shared_ptr<faiss::Index>(
// faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>())));
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type);
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexIVFScalarQuantizer(
coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(), faiss::QuantizerType::QT_8bit, metric_type));
index_->train(rows, reinterpret_cast<const float*>(p_data));
}
VecIndexPtr
IVFSQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIVFSQ>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
#else
KNOWHERE_THROW_MSG("Calling IVFSQ::CopyCpuToGpu when we are using CPU version");
#endif
}
void
IVFSQ::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
auto ivfsq_index = dynamic_cast<faiss::IndexIVFScalarQuantizer*>(index_.get());
auto nb = ivfsq_index->invlists->compute_ntotal();
auto code_size = ivfsq_index->code_size;
auto nlist = ivfsq_index->nlist;
auto d = ivfsq_index->d;
// ivf codes, ivf ids, sq trained vectors and quantizer
index_size_ = nb * code_size + nb * sizeof(int64_t) + 2 * d * sizeof(float) + nlist * d * sizeof(float);
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,45 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <utility>
#include "knowhere/index/vector_index/IndexIVF.h"
namespace milvus {
namespace knowhere {
class IVFSQ : public IVF {
public:
IVFSQ() : IVF() {
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
}
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8;
}
void
Train(const DatasetPtr&, const Config&) override;
VecIndexPtr
CopyCpuToGpu(const int64_t, const Config&) override;
void
UpdateIndexSize() override;
};
using IVFSQPtr = std::shared_ptr<IVFSQ>;
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,181 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <string>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Timer.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
#include "knowhere/index/vector_index/impl/nsg/NSGIO.h"
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#endif
namespace milvus {
namespace knowhere {
BinarySet
NSG::Serialize(const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
std::lock_guard<std::mutex> lk(mutex_);
impl::NsgIndex* index = index_.get();
MemoryIOWriter writer;
impl::write_index(index, writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
BinarySet res_set;
res_set.Append("NSG", data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
NSG::Load(const BinarySet& index_binary) {
try {
std::lock_guard<std::mutex> lk(mutex_);
auto binary = index_binary.GetByName("NSG");
MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();
auto index = impl::read_index(reader);
index_.reset(index);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
DatasetPtr
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA_DIM(dataset_ptr)
try {
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
impl::SearchParams s_params;
s_params.search_length = config[IndexParams::search_length];
s_params.k = config[meta::TOPK];
{
std::lock_guard<std::mutex> lk(mutex_);
index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id,
s_params, blacklist);
}
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
auto idmap = std::make_shared<IDMAP>();
idmap->Train(dataset_ptr, config);
idmap->AddWithoutIds(dataset_ptr, config);
impl::Graph knng;
const float* raw_data = idmap->GetRawVectors();
const int64_t k = config[IndexParams::knng].get<int64_t>();
#ifdef MILVUS_GPU_VERSION
const int64_t device_id = config[knowhere::meta::DEVICEID].get<int64_t>();
if (device_id == -1) {
auto preprocess_index = std::make_shared<IVF>();
preprocess_index->Train(dataset_ptr, config);
preprocess_index->AddWithoutIds(dataset_ptr, config);
preprocess_index->GenGraph(raw_data, k, knng, config);
} else {
auto gpu_idx = cloner::CopyCpuToGpu(idmap, device_id, config);
auto gpu_idmap = std::dynamic_pointer_cast<GPUIDMAP>(gpu_idx);
gpu_idmap->GenGraph(raw_data, k, knng, config);
}
#else
auto preprocess_index = std::make_shared<IVF>();
preprocess_index->Train(dataset_ptr, config);
preprocess_index->AddWithoutIds(dataset_ptr, config);
preprocess_index->GenGraph(raw_data, k, knng, config);
#endif
impl::BuildParams b_params;
b_params.candidate_pool_size = config[IndexParams::candidate];
b_params.out_degree = config[IndexParams::out_degree];
b_params.search_length = config[IndexParams::search_length];
GET_TENSOR(dataset_ptr)
impl::NsgIndex::Metric_Type metric;
auto metric_str = config[Metric::TYPE].get<std::string>();
if (metric_str == knowhere::Metric::IP) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_IP;
} else if (metric_str == knowhere::Metric::L2) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_L2;
} else {
KNOWHERE_THROW_MSG("Metric is not supported");
}
index_ = std::make_shared<impl::NsgIndex>(dim, rows, metric);
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
int64_t
NSG::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
NSG::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->dimension;
}
void
NSG::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = index_->GetSize();
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,79 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/VecIndex.h"
namespace milvus {
namespace knowhere {
namespace impl {
class NsgIndex;
}
class NSG : public VecIndex {
public:
explicit NSG(const int64_t gpu_num = -1) : gpu_(gpu_num) {
if (gpu_ >= 0) {
index_mode_ = IndexMode::MODE_GPU;
}
index_type_ = IndexEnum::INDEX_NSG;
}
BinarySet
Serialize(const Config&) override;
void
Load(const BinarySet&) override;
void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override {
Train(dataset_ptr, config);
}
void
Train(const DatasetPtr&, const Config&) override;
void
Add(const DatasetPtr&, const Config&) override {
KNOWHERE_THROW_MSG("Incremental index is not supported");
}
void
AddWithoutIds(const DatasetPtr&, const Config&) override {
KNOWHERE_THROW_MSG("Addwithoutids is not supported");
}
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
int64_t
Count() override;
int64_t
Dim() override;
private:
std::mutex mutex_;
int64_t gpu_;
std::shared_ptr<impl::NsgIndex> index_;
};
using NSGIndexPtr = std::shared_ptr<NSG>();
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,148 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexRHNSW.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
BinarySet
IndexRHNSW::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
MemoryIOWriter writer;
writer.name = this->index_type() + "_Index";
faiss::write_index(index_.get(), &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
BinarySet res_set;
res_set.Append(writer.name, data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSW::Load(const BinarySet& index_binary) {
try {
MemoryIOReader reader;
reader.name = this->index_type() + "_Index";
auto binary = index_binary.GetByName(reader.name);
reader.total = static_cast<size_t>(binary->size);
reader.data_ = binary->data.get();
auto idx = faiss::read_index(&reader);
index_.reset(idx);
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of Train, please use IndexRHNSW(Flat/SQ/PQ) instead!");
}
void
IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset_ptr)
index_->add(rows, reinterpret_cast<const float*>(p_data));
}
DatasetPtr
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GET_TENSOR_DATA(dataset_ptr)
auto k = config[meta::TOPK].get<int64_t>();
int64_t id_size = sizeof(int64_t) * k;
int64_t dist_size = sizeof(float) * k;
auto p_id = static_cast<int64_t*>(malloc(id_size * rows));
auto p_dist = static_cast<float*>(malloc(dist_size * rows));
for (auto i = 0; i < k * rows; ++i) {
p_id[i] = -1;
p_dist[i] = -1;
}
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
real_index->hnsw.efSearch = (config[IndexParams::ef]);
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, blacklist);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
int64_t
IndexRHNSW::Count() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->ntotal;
}
int64_t
IndexRHNSW::Dim() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
return index_->d;
}
void
IndexRHNSW::UpdateIndexSize() {
KNOWHERE_THROW_MSG(
"IndexRHNSW has no implementation of UpdateIndexSize, please use IndexRHNSW(Flat/SQ/PQ) instead!");
}
/*
BinarySet
IndexRHNSW::SerializeImpl(const milvus::knowhere::IndexType &type) { return BinarySet(); }
void
IndexRHNSW::SealImpl() {}
void
IndexRHNSW::LoadImpl(const milvus::knowhere::BinarySet &, const milvus::knowhere::IndexType &type) {}
*/
void
IndexRHNSW::AddWithoutIds(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config) {
KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of AddWithoutIds, please use IndexRHNSW(Flat/SQ/PQ) instead!");
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,67 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include <faiss/index_io.h>
#include "faiss/IndexRHNSW.h"
namespace milvus {
namespace knowhere {
class IndexRHNSW : public VecIndex, public FaissBaseIndex {
public:
IndexRHNSW() : FaissBaseIndex(nullptr) {
index_type_ = IndexEnum::INVALID;
}
explicit IndexRHNSW(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
index_type_ = IndexEnum::INVALID;
}
BinarySet
Serialize(const Config& config) override;
void
Load(const BinarySet& index_binary) override;
void
Train(const DatasetPtr& dataset_ptr, const Config& config) override;
void
Add(const DatasetPtr& dataset_ptr, const Config& config) override;
void
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
int64_t
Count() override;
int64_t
Dim() override;
void
UpdateIndexSize() override;
};
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,107 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexRHNSWFlat.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace milvus {
namespace knowhere {
IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, milvus::knowhere::MetricType metric) {
faiss::MetricType mt =
metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT;
index_ = std::shared_ptr<faiss::Index>(new faiss::IndexRHNSWFlat(d, M, mt));
}
BinarySet
IndexRHNSWFlat::Serialize(const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
try {
auto res_set = IndexRHNSW::Serialize(config);
MemoryIOWriter writer;
writer.name = this->index_type() + "_Data";
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWFlat*>(index_) failed during Serialize!");
}
auto storage_index = dynamic_cast<faiss::IndexFlat*>(real_idx->storage);
faiss::write_index(storage_index, &writer);
std::shared_ptr<uint8_t[]> data(writer.data_);
res_set.Append(writer.name, data, writer.rp);
return res_set;
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWFlat::Load(const BinarySet& index_binary) {
try {
IndexRHNSW::Load(index_binary);
MemoryIOReader reader;
reader.name = this->index_type() + "_Data";
auto binary = index_binary.GetByName(reader.name);
reader.total = static_cast<size_t>(binary->size);
reader.data_ = binary->data.get();
auto real_idx = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWFlat*>(index_) failed during Load!");
}
real_idx->storage = faiss::read_index(&reader);
real_idx->init_hnsw();
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWFlat::Train(const DatasetPtr& dataset_ptr, const Config& config) {
try {
GET_TENSOR_DATA_DIM(dataset_ptr)
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto idx = new faiss::IndexRHNSWFlat(int(dim), config[IndexParams::M], metric_type);
idx->hnsw.efConstruction = config[IndexParams::efConstruction];
index_ = std::shared_ptr<faiss::Index>(idx);
index_->train(rows, reinterpret_cast<const float*>(p_data));
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void
IndexRHNSWFlat::UpdateIndexSize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
index_size_ = dynamic_cast<faiss::IndexRHNSWFlat*>(index_.get())->cal_size();
}
} // namespace knowhere
} // namespace milvus

Some files were not shown because too many files have changed in this diff Show More