diff --git a/Makefile b/Makefile index 52e98dd21b..006115dfe8 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,10 @@ useasan = false ifeq (${USE_ASAN}, true) useasan = true endif -opensimd = OFF +use_dynamic_simd = OFF +ifdef USE_DYNAMIC_SIMD + use_dynamic_simd = ${USE_DYNAMIC_SIMD} +endif export GIT_BRANCH=master @@ -149,19 +152,19 @@ generated-proto: download-milvus-proto build-3rdparty build-cpp: generated-proto @echo "Building Milvus cpp library ..." - @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd}) + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) build-cpp-gpu: generated-proto @echo "Building Milvus cpp gpu library ..." - @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -g -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd}) + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -g -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) build-cpp-with-unittest: generated-proto @echo "Building Milvus cpp library with unittest ..." - @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd}) + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) build-cpp-with-coverage: generated-proto @echo "Building Milvus cpp library with coverage and unittest ..." 
- @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -a ${useasan} -c -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd}) + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -a ${useasan} -c -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) check-proto-product: generated-proto @(env bash $(PWD)/scripts/check_proto_product.sh) diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index 77168255b1..dd96e0a2e1 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -29,6 +29,10 @@ if ( MILVUS_GPU_VERSION ) add_definitions(-DMILVUS_GPU_VERSION) endif () +if ( USE_DYNAMIC_SIMD ) + add_definitions(-DUSE_DYNAMIC_SIMD) +endif() + project(core) include(CheckCXXCompilerFlag) if ( APPLE ) diff --git a/internal/core/src/CMakeLists.txt b/internal/core/src/CMakeLists.txt index abb2293020..70bc827fde 100644 --- a/internal/core/src/CMakeLists.txt +++ b/internal/core/src/CMakeLists.txt @@ -35,3 +35,6 @@ add_subdirectory( index ) add_subdirectory( query ) add_subdirectory( segcore ) add_subdirectory( indexbuilder ) +if(USE_DYNAMIC_SIMD) + add_subdirectory( simd ) +endif() diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 7b97a236f1..dd7fd77698 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -140,7 +140,7 @@ using IndexType = knowhere::IndexType; // Plus 1 because we can't use greater(>) symbol constexpr size_t REF_SIZE_THRESHOLD = 16 + 1; -using BitSetBlockType = BitsetType::block_type; +using BitsetBlockType = BitsetType::block_type; constexpr size_t BITSET_BLOCK_SIZE = sizeof(BitsetType::block_type); constexpr size_t BITSET_BLOCK_BIT_SIZE = sizeof(BitsetType::block_type) * 8; template diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index ecbe1db36c..c476e476a9 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -30,4 +30,8 @@ 
set(MILVUS_QUERY_SRCS PlanProto.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) -target_link_libraries(milvus_query milvus_index) +if(USE_DYNAMIC_SIMD) + target_link_libraries(milvus_query milvus_index milvus_simd) +else() + target_link_libraries(milvus_query milvus_index) +endif() diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 5d0a7d185f..3aea1e2f70 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -35,6 +35,8 @@ #include "segcore/SegmentGrowingImpl.h" #include "simdjson/error.h" #include "query/PlanProto.h" +#include "simd/hook.h" + namespace milvus::query { // THIS CONTAINS EXTRA BODY FOR VISITOR // WILL BE USED BY GENERATOR @@ -186,7 +188,10 @@ AppendOneChunk(BitsetType& result, const FixedVector& chunk_res) { // Append a value once instead of BITSET_BLOCK_BIT_SIZE times. auto AppendBlock = [&result](const bool* ptr, int n) { for (int i = 0; i < n; ++i) { - BitSetBlockType val = 0; +#if defined(USE_DYNAMIC_SIMD) + auto val = milvus::simd::get_bitset_block(ptr); +#else + BitsetBlockType val = 0; // This can use CPU SIMD optimzation uint8_t vals[BITSET_BLOCK_SIZE] = {0}; for (size_t j = 0; j < 8; ++j) { @@ -195,8 +200,9 @@ AppendOneChunk(BitsetType& result, const FixedVector& chunk_res) { } } for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) { - val |= BitSetBlockType(vals[j]) << (8 * j); + val |= BitsetBlockType(vals[j]) << (8 * j); } +#endif result.append(val); ptr += BITSET_BLOCK_SIZE * 8; } @@ -1782,11 +1788,31 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType { auto index_func = [&terms, n](Index* index) { return index->In(n, terms.data()); }; - auto elem_func = [&terms, &term_set](MayConstRef x) { - //// terms has already been sorted. 
- // return std::binary_search(terms.begin(), terms.end(), x); + +#if defined(USE_DYNAMIC_SIMD) + std::function x)> elem_func; + if (n <= milvus::simd::TERM_EXPR_IN_SIZE_THREAD) { + elem_func = [&terms, &term_set, n](MayConstRef x) { + if constexpr (std::is_integral::value || + std::is_floating_point::value) { + return milvus::simd::find_term_func(terms.data(), n, x); + } else { + // For string type, simd performance not better than set mode + static_assert(std::is_same::value || + std::is_same::value); + return term_set.find(x) != term_set.end(); + } + }; + } else { + elem_func = [&term_set, n](MayConstRef x) { + return term_set.find(x) != term_set.end(); + }; + } +#else + auto elem_func = [&term_set](MayConstRef x) { return term_set.find(x) != term_set.end(); }; +#endif return ExecRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); diff --git a/internal/core/src/simd/CMakeLists.txt b/internal/core/src/simd/CMakeLists.txt new file mode 100644 index 0000000000..64106eba5d --- /dev/null +++ b/internal/core/src/simd/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing permissions and limitations under the License + +set(MILVUS_SIMD_SRCS + ref.cpp + hook.cpp +) + +if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") + # x86 cpu simd + message ("simd using x86_64 mode") + list(APPEND MILVUS_SIMD_SRCS + sse2.cpp + sse4.cpp + avx2.cpp + avx512.cpp + ) + set_source_files_properties(sse4.cpp PROPERTIES COMPILE_FLAGS "-msse4.2") + set_source_files_properties(avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw") +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*") + # TODO: add arm cpu simd +endif() + +add_library(milvus_simd ${MILVUS_SIMD_SRCS}) + +# Link the milvus_simd library with other libraries as needed +target_link_libraries(milvus_simd milvus_log) \ No newline at end of file diff --git a/internal/core/src/simd/avx2.cpp b/internal/core/src/simd/avx2.cpp new file mode 100644 index 0000000000..0faa120198 --- /dev/null +++ b/internal/core/src/simd/avx2.cpp @@ -0,0 +1,237 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +#if defined(__x86_64__) + +#include "avx2.h" +#include "sse2.h" +#include "sse4.h" + +#include + +#include +#include + +namespace milvus { +namespace simd { + +BitsetBlockType +GetBitsetBlockAVX2(const bool* src) { + if constexpr (BITSET_BLOCK_SIZE == 8) { + // BitsetBlockType has 64 bits + __m256i highbit = _mm256_set1_epi8(0x7F); + uint32_t tmp[8]; + for (size_t i = 0; i < 2; i += 1) { + __m256i boolvec = _mm256_loadu_si256((__m256i*)&src[i * 32]); + __m256i highbits = _mm256_add_epi8(boolvec, highbit); + tmp[i] = _mm256_movemask_epi8(highbits); + } + + __m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp); + BitsetBlockType res[4]; + _mm256_storeu_si256((__m256i*)res, tmpvec); + return res[0]; + // __m128i tmpvec = _mm_loadu_si64(tmp); + // BitsetBlockType res; + // _mm_storeu_si64(&res, tmpvec); + // return res; + } else { + // Others has 32 bits + __m256i highbit = _mm256_set1_epi8(0x7F); + uint32_t tmp[8]; + __m256i boolvec = _mm256_loadu_si256((__m256i*)&src[0]); + __m256i highbits = _mm256_add_epi8(boolvec, highbit); + tmp[0] = _mm256_movemask_epi8(highbits); + + __m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp); + BitsetBlockType res[8]; + _mm256_storeu_si256((__m256i*)res, tmpvec); + return res[0]; + } +} + +template <> +bool +FindTermAVX2(const bool* src, size_t vec_size, bool val) { + __m256i ymm_target = _mm256_set1_epi8(val); + __m256i ymm_data; + size_t num_chunks = vec_size / 32; + + for (size_t i = 0; i < num_chunks; i++) { + ymm_data = + _mm256_loadu_si256(reinterpret_cast(src + 32 * i)); + __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); + int mask = _mm256_movemask_epi8(ymm_match); + if (mask != 0) { + return true; + } + } + + for (size_t i = 32 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val) { + __m256i ymm_target = _mm256_set1_epi8(val); + __m256i ymm_data; + size_t num_chunks = vec_size / 32; + + 
for (size_t i = 0; i < num_chunks; i++) { + ymm_data = + _mm256_loadu_si256(reinterpret_cast(src + 32 * i)); + __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); + int mask = _mm256_movemask_epi8(ymm_match); + if (mask != 0) { + return true; + } + } + + for (size_t i = 32 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val) { + __m256i ymm_target = _mm256_set1_epi16(val); + __m256i ymm_data; + size_t num_chunks = vec_size / 16; + size_t remaining_size = vec_size % 16; + for (size_t i = 0; i < num_chunks; i++) { + ymm_data = + _mm256_loadu_si256(reinterpret_cast(src + 16 * i)); + __m256i ymm_match = _mm256_cmpeq_epi16(ymm_data, ymm_target); + int mask = _mm256_movemask_epi8(ymm_match); + if (mask != 0) { + return true; + } + } + + for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val) { + __m256i ymm_target = _mm256_set1_epi32(val); + __m256i ymm_data; + size_t num_chunks = vec_size / 8; + size_t remaining_size = vec_size % 8; + + for (size_t i = 0; i < num_chunks; i++) { + ymm_data = + _mm256_loadu_si256(reinterpret_cast(src + 8 * i)); + __m256i ymm_match = _mm256_cmpeq_epi32(ymm_data, ymm_target); + int mask = _mm256_movemask_epi8(ymm_match); + if (mask != 0) { + return true; + } + } + + if (remaining_size == 0) { + return false; + } + return FindTermSSE2(src + 8 * num_chunks, remaining_size, val); +} + +template <> +bool +FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val) { + __m256i ymm_target = _mm256_set1_epi64x(val); + __m256i ymm_data; + size_t num_chunks = vec_size / 4; + size_t remaining_size = vec_size % 4; + + for (size_t i = 0; i < num_chunks; i++) { + ymm_data = + _mm256_loadu_si256(reinterpret_cast(src + 4 * i)); + __m256i ymm_match = 
_mm256_cmpeq_epi64(ymm_data, ymm_target);
+        int mask = _mm256_movemask_epi8(ymm_match);
+        if (mask != 0) {
+            return true;
+        }
+    }
+
+    for (size_t i = 4 * num_chunks; i < vec_size; ++i) {
+        if (src[i] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// FindTermAVX2 for float: returns true when any of the vec_size elements of
+// src compares equal to val. Full groups of 8 floats are compared with one
+// 256-bit load each; the remaining tail is compared element by element.
+// _CMP_EQ_OQ is an ordered comparison, so NaN never matches (same as the
+// scalar == used for the tail).
+template <>
+bool
+FindTermAVX2(const float* src, size_t vec_size, float val) {
+    __m256 ymm_target = _mm256_set1_ps(val);
+    __m256 ymm_data;
+    size_t num_chunks = vec_size / 8;
+
+    for (size_t i = 0; i < num_chunks; i++) {
+        ymm_data = _mm256_loadu_ps(src + 8 * i);
+        __m256 ymm_match = _mm256_cmp_ps(ymm_data, ymm_target, _CMP_EQ_OQ);
+        int mask = _mm256_movemask_ps(ymm_match);
+        if (mask != 0) {
+            return true;
+        }
+    }
+
+    for (size_t i = 8 * num_chunks; i < vec_size; ++i) {
+        if (src[i] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// FindTermAVX2 for double: returns true when any of the vec_size elements of
+// src compares equal to val. Full groups of 4 doubles are compared with one
+// 256-bit load each; the remaining tail is compared element by element.
+// _CMP_EQ_OQ is an ordered comparison, so NaN never matches (same as the
+// scalar == used for the tail).
+template <>
+bool
+FindTermAVX2(const double* src, size_t vec_size, double val) {
+    __m256d ymm_target = _mm256_set1_pd(val);
+    __m256d ymm_data;
+    size_t num_chunks = vec_size / 4;
+
+    for (size_t i = 0; i < num_chunks; i++) {
+        // A 256-bit register holds 4 doubles, so chunk i starts at
+        // src + 4 * i. The stride has to agree with the vec_size / 4 chunk
+        // count and with the tail loop below; a stride of 8 skips every
+        // other group and reads past the end of the array.
+        ymm_data = _mm256_loadu_pd(src + 4 * i);
+        __m256d ymm_match = _mm256_cmp_pd(ymm_data, ymm_target, _CMP_EQ_OQ);
+        int mask = _mm256_movemask_pd(ymm_match);
+        if (mask != 0) {
+            return true;
+        }
+    }
+
+    for (size_t i = 4 * num_chunks; i < vec_size; ++i) {
+        if (src[i] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+}  // namespace simd
+}  // namespace milvus
+
+#endif
diff --git a/internal/core/src/simd/avx2.h b/internal/core/src/simd/avx2.h
new file mode 100644
index 0000000000..7e811aaa2b
--- /dev/null
+++ b/internal/core/src/simd/avx2.h
@@ -0,0 +1,62 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "common.h" + +namespace milvus { +namespace simd { + +BitsetBlockType +GetBitsetBlockAVX2(const bool* src); + +template +bool +FindTermAVX2(const T* src, size_t vec_size, T va) { + CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermAVX2"); + return false; +} + +template <> +bool +FindTermAVX2(const bool* src, size_t vec_size, bool val); + +template <> +bool +FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val); + +template <> +bool +FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val); + +template <> +bool +FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val); + +template <> +bool +FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val); + +template <> +bool +FindTermAVX2(const float* src, size_t vec_size, float val); + +template <> +bool +FindTermAVX2(const double* src, size_t vec_size, double val); + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/avx512.cpp b/internal/core/src/simd/avx512.cpp new file mode 100644 index 0000000000..42a7a08c77 --- /dev/null +++ b/internal/core/src/simd/avx512.cpp @@ -0,0 +1,188 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "avx512.h" +#include + +#if defined(__x86_64__) +#include + +namespace milvus { +namespace simd { + +template <> +bool +FindTermAVX512(const bool* src, size_t vec_size, bool val) { + __m512i zmm_target = _mm512_set1_epi8(val); + __m512i zmm_data; + size_t num_chunks = vec_size / 64; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = + _mm512_loadu_si512(reinterpret_cast(src + 64 * i)); + __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); + if (mask != 0) { + return true; + } + } + + for (size_t i = 64 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val) { + __m512i zmm_target = _mm512_set1_epi8(val); + __m512i zmm_data; + size_t num_chunks = vec_size / 64; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = + _mm512_loadu_si512(reinterpret_cast(src + 64 * i)); + __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); + if (mask != 0) { + return true; + } + } + + for (size_t i = 64 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val) { + __m512i zmm_target = _mm512_set1_epi16(val); + __m512i zmm_data; + size_t num_chunks = vec_size / 32; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = + _mm512_loadu_si512(reinterpret_cast(src + 32 * i)); + __mmask32 mask = _mm512_cmpeq_epi16_mask(zmm_data, zmm_target); + if (mask != 0) { + return true; + } + } 
+ + for (size_t i = 32 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val) { + __m512i zmm_target = _mm512_set1_epi32(val); + __m512i zmm_data; + size_t num_chunks = vec_size / 16; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = + _mm512_loadu_si512(reinterpret_cast(src + 16 * i)); + __mmask16 mask = _mm512_cmpeq_epi32_mask(zmm_data, zmm_target); + if (mask != 0) { + return true; + } + } + + for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val) { + __m512i zmm_target = _mm512_set1_epi64(val); + __m512i zmm_data; + size_t num_chunks = vec_size / 8; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = + _mm512_loadu_si512(reinterpret_cast(src + 8 * i)); + __mmask8 mask = _mm512_cmpeq_epi64_mask(zmm_data, zmm_target); + if (mask != 0) { + return true; + } + } + + for (size_t i = 8 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX512(const float* src, size_t vec_size, float val) { + __m512 zmm_target = _mm512_set1_ps(val); + __m512 zmm_data; + size_t num_chunks = vec_size / 16; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = _mm512_loadu_ps(src + 16 * i); + __mmask16 mask = _mm512_cmp_ps_mask(zmm_data, zmm_target, _CMP_EQ_OQ); + if (mask != 0) { + return true; + } + } + + for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermAVX512(const double* src, size_t vec_size, double val) { + __m512d zmm_target = _mm512_set1_pd(val); + __m512d zmm_data; + size_t num_chunks = vec_size / 8; + + for (size_t i = 0; i < num_chunks; i++) { + zmm_data = _mm512_loadu_pd(src + 8 * i); + __mmask8 mask = 
_mm512_cmp_pd_mask(zmm_data, zmm_target, _CMP_EQ_OQ); + if (mask != 0) { + return true; + } + } + + for (size_t i = 8 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} +} // namespace simd +} // namespace milvus +#endif diff --git a/internal/core/src/simd/avx512.h b/internal/core/src/simd/avx512.h new file mode 100644 index 0000000000..f09c2c2116 --- /dev/null +++ b/internal/core/src/simd/avx512.h @@ -0,0 +1,59 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "common.h" + +namespace milvus { +namespace simd { + +template +bool +FindTermAVX512(const T* src, size_t vec_size, T va) { + CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermAVX512"); + return false; +} + +template <> +bool +FindTermAVX512(const bool* src, size_t vec_size, bool val); + +template <> +bool +FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val); + +template <> +bool +FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val); + +template <> +bool +FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val); + +template <> +bool +FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val); + +template <> +bool +FindTermAVX512(const float* src, size_t vec_size, float val); + +template <> +bool +FindTermAVX512(const double* src, size_t vec_size, double val); + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/common.h b/internal/core/src/simd/common.h new file mode 100644 index 0000000000..3cbe9c6e3e --- /dev/null +++ b/internal/core/src/simd/common.h @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +#pragma once + +#include +#include +#include + +namespace milvus { +namespace simd { + +using BitsetBlockType = unsigned long; +constexpr size_t BITSET_BLOCK_SIZE = sizeof(unsigned long); + +/* +* When the term list has at most TERM_EXPR_IN_SIZE_THREAD elements, +* SIMD search is faster for all numeric types. +* When the term list has more than TERM_EXPR_IN_SIZE_THREAD elements, +* set lookup is faster for all numeric types. +* 50 is an experimental value; a dynamic plan could be used to tune it +* for different situations. +*/ +const int TERM_EXPR_IN_SIZE_THREAD = 50; + +#define CHECK_SUPPORTED_TYPE(T, Message) \ + static_assert( \ + std::is_same::value || std::is_same::value || \ + std::is_same::value || \ + std::is_same::value || \ + std::is_same::value || \ + std::is_same::value || std::is_same::value, \ + Message); + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/hook.cpp b/internal/core/src/simd/hook.cpp new file mode 100644 index 0000000000..0ae5f24266 --- /dev/null +++ b/internal/core/src/simd/hook.cpp @@ -0,0 +1,171 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License.
+ +// -*- c++ -*- + +#include "hook.h" + +#include +#include +#include + +#include "ref.h" +#include "log/Log.h" +#if defined(__x86_64__) +#include "avx2.h" +#include "avx512.h" +#include "sse2.h" +#include "sse4.h" +#include "instruction_set.h" +#endif + +namespace milvus { +namespace simd { + +#if defined(__x86_64__) +bool use_avx512 = true; +bool use_avx2 = true; +bool use_sse4_2 = true; +bool use_sse2 = true; + +bool use_bitset_sse2; +bool use_find_term_sse2; +bool use_find_term_sse4_2; +bool use_find_term_avx2; +bool use_find_term_avx512; +#endif + +decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef; +FindTermPtr find_term_bool = FindTermRef; +FindTermPtr find_term_int8 = FindTermRef; +FindTermPtr find_term_int16 = FindTermRef; +FindTermPtr find_term_int32 = FindTermRef; +FindTermPtr find_term_int64 = FindTermRef; +FindTermPtr find_term_float = FindTermRef; +FindTermPtr find_term_double = FindTermRef; + +#if defined(__x86_64__) +bool +cpu_support_avx512() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() && + instruction_set_inst.AVX512BW()); +} + +bool +cpu_support_avx2() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.AVX2()); +} + +bool +cpu_support_sse4_2() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.SSE42()); +} + +bool +cpu_support_sse2() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.SSE2()); +} +#endif + +void +bitset_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + if (use_avx512 && cpu_support_avx512()) { + simd_type = "AVX512"; + // For now, sse2 has best performance + get_bitset_block = GetBitsetBlockSSE2; + use_bitset_sse2 = true; + } else if (use_avx2 && cpu_support_avx2()) { 
+ simd_type = "AVX2"; + // For now, sse2 has best performance + get_bitset_block = GetBitsetBlockSSE2; + use_bitset_sse2 = true; + } else if (use_sse4_2 && cpu_support_sse4_2()) { + simd_type = "SSE4"; + get_bitset_block = GetBitsetBlockSSE2; + use_bitset_sse2 = true; + } else if (use_sse2 && cpu_support_sse2()) { + simd_type = "SSE2"; + get_bitset_block = GetBitsetBlockSSE2; + use_bitset_sse2 = true; + } +#endif + // TODO: support arm cpu + LOG_SEGCORE_INFO_ << "bitset hook simd type: " << simd_type; +} + +void +find_term_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + if (use_avx512 && cpu_support_avx512()) { + simd_type = "AVX512"; + find_term_bool = FindTermAVX512; + find_term_int8 = FindTermAVX512; + find_term_int16 = FindTermAVX512; + find_term_int32 = FindTermAVX512; + find_term_int64 = FindTermAVX512; + find_term_float = FindTermAVX512; + find_term_double = FindTermAVX512; + use_find_term_avx512 = true; + } else if (use_avx2 && cpu_support_avx2()) { + simd_type = "AVX2"; + find_term_bool = FindTermAVX2; + find_term_int8 = FindTermAVX2; + find_term_int16 = FindTermAVX2; + find_term_int32 = FindTermAVX2; + find_term_int64 = FindTermAVX2; + find_term_float = FindTermAVX2; + find_term_double = FindTermAVX2; + use_find_term_avx2 = true; + } else if (use_sse4_2 && cpu_support_sse4_2()) { + simd_type = "SSE4"; + find_term_bool = FindTermSSE4; + find_term_int8 = FindTermSSE4; + find_term_int16 = FindTermSSE4; + find_term_int32 = FindTermSSE4; + find_term_int64 = FindTermSSE4; + find_term_float = FindTermSSE4; + find_term_double = FindTermSSE4; + use_find_term_sse4_2 = true; + } else if (use_sse2 && cpu_support_sse2()) { + simd_type = "SSE2"; + find_term_bool = FindTermSSE2; + find_term_int8 = FindTermSSE2; + find_term_int16 = FindTermSSE2; + find_term_int32 = FindTermSSE2; + find_term_int64 = FindTermSSE2; + find_term_float = FindTermSSE2; + find_term_double = FindTermSSE2; + 
use_find_term_sse2 = true; + } +#endif + // TODO: support arm cpu + LOG_SEGCORE_INFO_ << "find term hook simd type: " << simd_type; +} + +static int init_hook_ = []() { + bitset_hook(); + find_term_hook(); + return 0; +}(); + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/hook.h b/internal/core/src/simd/hook.h new file mode 100644 index 0000000000..050f660a10 --- /dev/null +++ b/internal/core/src/simd/hook.h @@ -0,0 +1,97 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "common.h" +namespace milvus { +namespace simd { + +extern BitsetBlockType (*get_bitset_block)(const bool* src); + +template +using FindTermPtr = bool (*)(const T* src, size_t size, T val); + +extern FindTermPtr find_term_bool; +extern FindTermPtr find_term_int8; +extern FindTermPtr find_term_int16; +extern FindTermPtr find_term_int32; +extern FindTermPtr find_term_int64; +extern FindTermPtr find_term_float; +extern FindTermPtr find_term_double; + +#if defined(__x86_64__) +// Flags that indicate whether runtime can choose +// these simd type or not when hook starts. +extern bool use_avx512; +extern bool use_avx2; +extern bool use_sse4_2; +extern bool use_sse2; + +// Flags that indicate which kind of simd for +// different function when hook ends. 
+extern bool use_bitset_sse2;
+extern bool use_find_term_sse2;
+extern bool use_find_term_sse4_2;
+extern bool use_find_term_avx2;
+extern bool use_find_term_avx512;
+#endif
+
+#if defined(__x86_64__)
+// Runtime CPU feature probes (defined in hook.cpp). The hooks consult them,
+// together with the use_* switches above, to choose an implementation.
+bool
+cpu_support_avx512();
+bool
+cpu_support_avx2();
+bool
+cpu_support_sse4_2();
+// Defined in hook.cpp alongside the three probes above; declared here so
+// all four are reachable through this header.
+bool
+cpu_support_sse2();
+#endif
+
+// Selects the get_bitset_block implementation for the current CPU.
+void
+bitset_hook();
+
+// Selects the find_term_* implementations for the current CPU.
+void
+find_term_hook();
+
+// Returns true when val occurs among the first `size` elements of `data`,
+// forwarding to the find_term_<type> function pointer that matches T.
+//
+// T must be one of bool, int8_t, int16_t, int32_t, int64_t, float or double.
+// Any other type is rejected at compile time; previously an integral type
+// outside this list passed the static_assert and then fell off the end of
+// the function without returning a value.
+template <typename T>
+bool
+find_term_func(const T* data, size_t size, T val) {
+    static_assert(
+        std::is_integral<T>::value || std::is_floating_point<T>::value,
+        "T must be integral or float/double type");
+
+    if constexpr (std::is_same_v<T, bool>) {
+        return milvus::simd::find_term_bool(data, size, val);
+    } else if constexpr (std::is_same_v<T, int8_t>) {
+        return milvus::simd::find_term_int8(data, size, val);
+    } else if constexpr (std::is_same_v<T, int16_t>) {
+        return milvus::simd::find_term_int16(data, size, val);
+    } else if constexpr (std::is_same_v<T, int32_t>) {
+        return milvus::simd::find_term_int32(data, size, val);
+    } else if constexpr (std::is_same_v<T, int64_t>) {
+        return milvus::simd::find_term_int64(data, size, val);
+    } else if constexpr (std::is_same_v<T, float>) {
+        return milvus::simd::find_term_float(data, size, val);
+    } else if constexpr (std::is_same_v<T, double>) {
+        return milvus::simd::find_term_double(data, size, val);
+    } else {
+        // The condition depends on T so that it only fires when this branch
+        // is actually instantiated.
+        static_assert(sizeof(T) == 0,
+                      "find_term_func has no implementation for this type");
+    }
+}
+
+}  // namespace simd
+}  // namespace milvus
diff --git a/internal/core/src/simd/instruction_set.h b/internal/core/src/simd/instruction_set.h
new file mode 100644
index 0000000000..a80686d160
--- /dev/null
+++ b/internal/core/src/simd/instruction_set.h
@@ -0,0 +1,368 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied.
See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace milvus { +namespace simd { + +class InstructionSet { + public: + static InstructionSet& + GetInstance() { + static InstructionSet inst; + return inst; + } + + private: + InstructionSet() + : nIds_{0}, + nExIds_{0}, + isIntel_{false}, + isAMD_{false}, + f_1_ECX_{0}, + f_1_EDX_{0}, + f_7_EBX_{0}, + f_7_ECX_{0}, + f_81_ECX_{0}, + f_81_EDX_{0}, + data_{}, + extdata_{} { + std::array cpui; + + // Calling __cpuid with 0x0 as the function_id argument + // gets the number of the highest valid function ID. + __cpuid(0, cpui[0], cpui[1], cpui[2], cpui[3]); + nIds_ = cpui[0]; + + for (int i = 0; i <= nIds_; ++i) { + __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); + data_.push_back(cpui); + } + + // Capture vendor string + char vendor[0x20]; + memset(vendor, 0, sizeof(vendor)); + *reinterpret_cast(vendor) = data_[0][1]; + *reinterpret_cast(vendor + 4) = data_[0][3]; + *reinterpret_cast(vendor + 8) = data_[0][2]; + vendor_ = vendor; + if (vendor_ == "GenuineIntel") { + isIntel_ = true; + } else if (vendor_ == "AuthenticAMD") { + isAMD_ = true; + } + + // load bitset with flags for function 0x00000001 + if (nIds_ >= 1) { + f_1_ECX_ = data_[1][2]; + f_1_EDX_ = data_[1][3]; + } + + // load bitset with flags for function 0x00000007 + if (nIds_ >= 7) { + f_7_EBX_ = data_[7][1]; + f_7_ECX_ = data_[7][2]; + } + + // Calling __cpuid with 0x80000000 as the function_id argument + // gets the number of the highest valid extended ID. 
+ __cpuid(0x80000000, cpui[0], cpui[1], cpui[2], cpui[3]); + nExIds_ = cpui[0]; + + char brand[0x40]; + memset(brand, 0, sizeof(brand)); + + for (int i = 0x80000000; i <= nExIds_; ++i) { + __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); + extdata_.push_back(cpui); + } + + // load bitset with flags for function 0x80000001 + if (nExIds_ >= (int)0x80000001) { + f_81_ECX_ = extdata_[1][2]; + f_81_EDX_ = extdata_[1][3]; + } + + // Interpret CPU brand string if reported + if (nExIds_ >= (int)0x80000004) { + memcpy(brand, extdata_[2].data(), sizeof(cpui)); + memcpy(brand + 16, extdata_[3].data(), sizeof(cpui)); + memcpy(brand + 32, extdata_[4].data(), sizeof(cpui)); + brand_ = brand; + } + }; + + public: + // getters + std::string + Vendor() { + return vendor_; + } + std::string + Brand() { + return brand_; + } + + bool + SSE3() { + return f_1_ECX_[0]; + } + bool + PCLMULQDQ() { + return f_1_ECX_[1]; + } + bool + MONITOR() { + return f_1_ECX_[3]; + } + bool + SSSE3() { + return f_1_ECX_[9]; + } + bool + FMA() { + return f_1_ECX_[12]; + } + bool + CMPXCHG16B() { + return f_1_ECX_[13]; + } + bool + SSE41() { + return f_1_ECX_[19]; + } + bool + SSE42() { + return f_1_ECX_[20]; + } + bool + MOVBE() { + return f_1_ECX_[22]; + } + bool + POPCNT() { + return f_1_ECX_[23]; + } + bool + AES() { + return f_1_ECX_[25]; + } + bool + XSAVE() { + return f_1_ECX_[26]; + } + bool + OSXSAVE() { + return f_1_ECX_[27]; + } + bool + AVX() { + return f_1_ECX_[28]; + } + bool + F16C() { + return f_1_ECX_[29]; + } + bool + RDRAND() { + return f_1_ECX_[30]; + } + + bool + MSR() { + return f_1_EDX_[5]; + } + bool + CX8() { + return f_1_EDX_[8]; + } + bool + SEP() { + return f_1_EDX_[11]; + } + bool + CMOV() { + return f_1_EDX_[15]; + } + bool + CLFSH() { + return f_1_EDX_[19]; + } + bool + MMX() { + return f_1_EDX_[23]; + } + bool + FXSR() { + return f_1_EDX_[24]; + } + bool + SSE() { + return f_1_EDX_[25]; + } + bool + SSE2() { + return f_1_EDX_[26]; + } + + bool + FSGSBASE() { + return 
f_7_EBX_[0]; + } + bool + BMI1() { + return f_7_EBX_[3]; + } + bool + HLE() { + return isIntel_ && f_7_EBX_[4]; + } + bool + AVX2() { + return f_7_EBX_[5]; + } + bool + BMI2() { + return f_7_EBX_[8]; + } + bool + ERMS() { + return f_7_EBX_[9]; + } + bool + INVPCID() { + return f_7_EBX_[10]; + } + bool + RTM() { + return isIntel_ && f_7_EBX_[11]; + } + bool + AVX512F() { + return f_7_EBX_[16]; + } + bool + AVX512DQ() { + return f_7_EBX_[17]; + } + bool + RDSEED() { + return f_7_EBX_[18]; + } + bool + ADX() { + return f_7_EBX_[19]; + } + bool + AVX512PF() { + return f_7_EBX_[26]; + } + bool + AVX512ER() { + return f_7_EBX_[27]; + } + bool + AVX512CD() { + return f_7_EBX_[28]; + } + bool + SHA() { + return f_7_EBX_[29]; + } + bool + AVX512BW() { + return f_7_EBX_[30]; + } + bool + AVX512VL() { + return f_7_EBX_[31]; + } + + bool + PREFETCHWT1() { + return f_7_ECX_[0]; + } + + bool + LAHF() { + return f_81_ECX_[0]; + } + bool + LZCNT() { + return isIntel_ && f_81_ECX_[5]; + } + bool + ABM() { + return isAMD_ && f_81_ECX_[5]; + } + bool + SSE4a() { + return isAMD_ && f_81_ECX_[6]; + } + bool + XOP() { + return isAMD_ && f_81_ECX_[11]; + } + bool + TBM() { + return isAMD_ && f_81_ECX_[21]; + } + + bool + SYSCALL() { + return isIntel_ && f_81_EDX_[11]; + } + bool + MMXEXT() { + return isAMD_ && f_81_EDX_[22]; + } + bool + RDTSCP() { + return isIntel_ && f_81_EDX_[27]; + } + bool + _3DNOWEXT() { + return isAMD_ && f_81_EDX_[30]; + } + bool + _3DNOW() { + return isAMD_ && f_81_EDX_[31]; + } + + private: + int nIds_; + int nExIds_; + std::string vendor_; + std::string brand_; + bool isIntel_; + bool isAMD_; + std::bitset<32> f_1_ECX_; + std::bitset<32> f_1_EDX_; + std::bitset<32> f_7_EBX_; + std::bitset<32> f_7_ECX_; + std::bitset<32> f_81_ECX_; + std::bitset<32> f_81_EDX_; + std::vector> data_; + std::vector> extdata_; +}; +} // namespace simd + +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/simd/ref.cpp b/internal/core/src/simd/ref.cpp 
new file mode 100644
index 0000000000..999bfa0458
--- /dev/null
+++ b/internal/core/src/simd/ref.cpp
@@ -0,0 +1,33 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#include "ref.h"
+
+namespace milvus {
+namespace simd {
+
+// Scalar reference implementation: packs the first BITSET_BLOCK_SIZE * 8
+// flags of src into one block, flag i ending up in bit i (bit 0 being the
+// least significant bit).
+BitsetBlockType
+GetBitsetBlockRef(const bool* src) {
+    BitsetBlockType block = 0;
+    for (size_t byte = 0; byte < BITSET_BLOCK_SIZE; ++byte) {
+        // Gather eight consecutive flags into one byte ...
+        uint8_t packed = 0;
+        for (size_t bit = 0; bit < 8; ++bit) {
+            packed |= uint8_t(src[byte * 8 + bit]) << bit;
+        }
+        // ... and drop that byte into its slot of the block.
+        block |= (BitsetBlockType)(packed) << (8 * byte);
+    }
+    return block;
+}
+
+}  // namespace simd
+}  // namespace milvus
diff --git a/internal/core/src/simd/ref.h b/internal/core/src/simd/ref.h
new file mode 100644
index 0000000000..604b0aa7c3
--- /dev/null
+++ b/internal/core/src/simd/ref.h
@@ -0,0 +1,34 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied.
See the License for the specific language governing permissions and limitations under the License.
+
+#pragma once
+
+#include "common.h"
+
+namespace milvus {
+namespace simd {
+
+BitsetBlockType
+GetBitsetBlockRef(const bool* src);
+
+// Scalar reference search: true when one of the `size` values starting at
+// `src` compares equal to `val`, false otherwise (including size == 0).
+template <typename T>
+bool
+FindTermRef(const T* src, size_t size, T val) {
+    const T* const end = src + size;
+    for (const T* cur = src; cur != end; ++cur) {
+        if (*cur == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+}  // namespace simd
+}  // namespace milvus
diff --git a/internal/core/src/simd/sse2.cpp b/internal/core/src/simd/sse2.cpp
new file mode 100644
index 0000000000..e7cb207757
--- /dev/null
+++ b/internal/core/src/simd/sse2.cpp
@@ -0,0 +1,262 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+ +#if defined(__x86_64__) + +#include "sse2.h" + +#include +#include + +namespace milvus { +namespace simd { + +#define ALIGNED(x) __attribute__((aligned(x))) + +BitsetBlockType +GetBitsetBlockSSE2(const bool* src) { + if constexpr (BITSET_BLOCK_SIZE == 8) { + // BitsetBlockType has 64 bits + __m128i highbit = _mm_set1_epi8(0x7F); + uint16_t tmp[4]; + for (size_t i = 0; i < 4; i += 1) { + // Outer function assert (src has 64 * n length) + __m128i boolvec = _mm_loadu_si128((__m128i*)&src[i * 16]); + __m128i highbits = _mm_add_epi8(boolvec, highbit); + tmp[i] = _mm_movemask_epi8(highbits); + } + + __m128i tmpvec = _mm_loadu_si64(tmp); + BitsetBlockType res; + _mm_storeu_si64(&res, tmpvec); + return res; + } else { + // Others has 32 bits + __m128i highbit = _mm_set1_epi8(0x7F); + uint16_t tmp[8]; + for (size_t i = 0; i < 2; i += 1) { + __m128i boolvec = _mm_loadu_si128((__m128i*)&src[i * 16]); + __m128i highbits = _mm_add_epi8(boolvec, highbit); + tmp[i] = _mm_movemask_epi8(highbits); + } + + __m128i tmpvec = _mm_loadu_si128((__m128i*)tmp); + BitsetBlockType res[4]; + _mm_storeu_si128((__m128i*)res, tmpvec); + return res[0]; + } +} + +template <> +bool +FindTermSSE2(const bool* src, size_t vec_size, bool val) { + __m128i xmm_target = _mm_set1_epi8(val); + __m128i xmm_data; + size_t num_chunks = vec_size / 16; + for (size_t i = 0; i < num_chunks; i++) { + xmm_data = + _mm_loadu_si128(reinterpret_cast(src + 16 * i)); + __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target); + int mask = _mm_movemask_epi8(xmm_match); + if (mask != 0) { + return true; + } + } + + for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + + return false; +} + +template <> +bool +FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) { + __m128i xmm_target = _mm_set1_epi8(val); + __m128i xmm_data; + size_t num_chunks = vec_size / 16; + for (size_t i = 0; i < num_chunks; i++) { + xmm_data = + _mm_loadu_si128(reinterpret_cast(src + 16 * 
i)); + __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target); + int mask = _mm_movemask_epi8(xmm_match); + if (mask != 0) { + return true; + } + } + + for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + + return false; +} + +template <> +bool +FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) { + __m128i xmm_target = _mm_set1_epi16(val); + __m128i xmm_data; + size_t num_chunks = vec_size / 8; + for (size_t i = 0; i < num_chunks; i++) { + xmm_data = + _mm_loadu_si128(reinterpret_cast(src + i * 8)); + __m128i xmm_match = _mm_cmpeq_epi16(xmm_data, xmm_target); + int mask = _mm_movemask_epi8(xmm_match); + if (mask != 0) { + return true; + } + } + + for (size_t i = 8 * num_chunks; i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +template <> +bool +FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val) { + size_t num_chunk = vec_size / 4; + size_t remaining_size = vec_size % 4; + + __m128i xmm_target = _mm_set1_epi32(val); + for (size_t i = 0; i < num_chunk; ++i) { + __m128i xmm_data = + _mm_loadu_si128(reinterpret_cast(src + i * 4)); + __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target); + int mask = _mm_movemask_epi8(xmm_match); + if (mask != 0) { + return true; + } + } + + const int32_t* remaining_ptr = src + num_chunk * 4; + if (remaining_size == 0) { + return false; + } else if (remaining_size == 1) { + return *remaining_ptr == val; + } else if (remaining_size == 2) { + __m128i xmm_data = + _mm_set_epi32(0, 0, *(remaining_ptr + 1), *(remaining_ptr)); + __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target); + int mask = _mm_movemask_epi8(xmm_match); + if ((mask & 0xFF) != 0) { + return true; + } + } else { + __m128i xmm_data = _mm_set_epi32( + 0, *(remaining_ptr + 2), *(remaining_ptr + 1), *(remaining_ptr)); + __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target); + int mask = _mm_movemask_epi8(xmm_match); + if ((mask & 0xFFF) != 0) { + 
return true;
+        }
+    }
+    return false;
+}
+
+// Returns true when any of the vec_size int64 values at src equals val.
+//
+// SSE2 has no 64-bit integer compare (_mm_cmpeq_epi64 is SSE4.1), so each
+// element is compared as two 32-bit halves and counts as a match only when
+// both halves of the same element match.
+template <>
+bool
+FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) {
+    size_t num_chunk = vec_size / 2;
+    size_t remaining_size = vec_size % 2;
+
+    // 32-bit lanes of the target are [low, high, low, high].
+    const __m128i xmm_target = _mm_set1_epi64x(val);
+    for (size_t i = 0; i < num_chunk; i++) {
+        // Unaligned load: src carries no 16-byte alignment guarantee.
+        __m128i xmm_vec =
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 2));
+
+        // Lane k is all-ones iff 32-bit half k equals the corresponding
+        // half of val. Comparing every lane against `low` (as before) let
+        // val == 1 match the element 0x0000000100000000.
+        __m128i cmp_half = _mm_cmpeq_epi32(xmm_vec, xmm_target);
+        // Swap the halves inside each 64-bit element and AND, so a lane
+        // stays set only when both halves of that element matched.
+        __m128i cmp_swapped =
+            _mm_shuffle_epi32(cmp_half, _MM_SHUFFLE(2, 3, 0, 1));
+        __m128i cmp_result = _mm_and_si128(cmp_half, cmp_swapped);
+
+        int mask = _mm_movemask_epi8(cmp_result);
+        if (mask != 0) {
+            return true;
+        }
+    }
+
+    // Odd element count: check the last value on its own.
+    if (remaining_size == 1) {
+        if (src[2 * num_chunk] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Returns true when any of the vec_size floats at src compares equal to
+// val; four values are compared per step, the tail is scanned one by one.
+template <>
+bool
+FindTermSSE2(const float* src, size_t vec_size, float val) {
+    size_t num_chunks = vec_size / 4;
+    __m128 xmm_target = _mm_set1_ps(val);
+    // size_t index: matches the type of num_chunks and cannot overflow in
+    // the `4 * i` offset.
+    for (size_t i = 0; i < num_chunks; ++i) {
+        __m128 xmm_data = _mm_loadu_ps(src + 4 * i);
+        __m128 xmm_match = _mm_cmpeq_ps(xmm_data, xmm_target);
+        int mask = _mm_movemask_ps(xmm_match);
+        if (mask != 0) {
+            return true;
+        }
+    }
+
+    for (size_t i = 4 * num_chunks; i < vec_size; ++i) {
+        if (src[i] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Returns true when any of the vec_size doubles at src compares equal to
+// val; two values are compared per step, the tail is scanned one by one.
+template <>
+bool
+FindTermSSE2(const double* src, size_t vec_size, double val) {
+    size_t num_chunks = vec_size / 2;
+    __m128d xmm_target = _mm_set1_pd(val);
+    for (size_t i = 0; i < num_chunks; ++i) {
+        __m128d xmm_data = _mm_loadu_pd(src + 2 * i);
+        __m128d xmm_match = _mm_cmpeq_pd(xmm_data, xmm_target);
+        int mask = _mm_movemask_pd(xmm_match);
+        if (mask != 0) {
+            return true;
+        }
+    }
+
+    for (size_t i = 2 * num_chunks;
i < vec_size; ++i) { + if (src[i] == val) { + return true; + } + } + return false; +} + +} // namespace simd +} // namespace milvus + +#endif diff --git a/internal/core/src/simd/sse2.h b/internal/core/src/simd/sse2.h new file mode 100644 index 0000000000..b7bbde86c0 --- /dev/null +++ b/internal/core/src/simd/sse2.h @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include +#include + +#include "common.h" +namespace milvus { +namespace simd { + +BitsetBlockType +GetBitsetBlockSSE2(const bool* src); + +template +bool +FindTermSSE2(const T* src, size_t vec_size, T va) { + CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermSSE2"); + return false; +} + +template <> +bool +FindTermSSE2(const bool* src, size_t vec_size, bool val); + +template <> +bool +FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val); + +template <> +bool +FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val); + +template <> +bool +FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val); + +template <> +bool +FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val); + +template <> +bool +FindTermSSE2(const float* src, size_t vec_size, float val); + +template <> +bool +FindTermSSE2(const double* src, size_t vec_size, double val); + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/sse4.cpp b/internal/core/src/simd/sse4.cpp new file 
mode 100644
index 0000000000..8585f9c648
--- /dev/null
+++ b/internal/core/src/simd/sse4.cpp
@@ -0,0 +1,110 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#if defined(__x86_64__)
+
+#include "sse4.h"
+#include "sse2.h"
+
+#include
+#include
+#include
+
+extern "C" {
+extern int
+sse2_strcmp(const char* s1, const char* s2);
+}
+namespace milvus {
+namespace simd {
+
+// Returns true when any of the vec_size int64 values at src equals val.
+// Two values are compared per step with the 64-bit lane compare
+// _mm_cmpeq_epi64; an odd trailing value is checked on its own.
+template <>
+bool
+FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val) {
+    size_t num_chunk = vec_size / 2;
+    size_t remaining_size = vec_size % 2;
+
+    __m128i xmm_target = _mm_set1_epi64x(val);
+    for (size_t i = 0; i < num_chunk; ++i) {
+        __m128i xmm_data =
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 2));
+        __m128i xmm_match = _mm_cmpeq_epi64(xmm_data, xmm_target);
+        int mask = _mm_movemask_epi8(xmm_match);
+        if (mask != 0) {
+            return true;
+        }
+    }
+    if (remaining_size == 1) {
+        if (src[2 * num_chunk] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Returns true when any of the vec_size strings at src equals val.
+template <>
+bool
+FindTermSSE4(const std::string* src, size_t vec_size, std::string val) {
+    for (size_t i = 0; i < vec_size; ++i) {
+        // StrCmpSSE4 follows strcmp: 0 means "equal". The previous test
+        // treated any non-zero result (a mismatch) as a hit.
+        if (StrCmpSSE4(src[i].c_str(), val.c_str()) == 0) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Returns true when any of the vec_size string views at src equals val.
+template <>
+bool
+FindTermSSE4(const std::string_view* src,
+             size_t vec_size,
+             std::string_view val) {
+    for (size_t i = 0; i < vec_size; ++i) {
+        // A string_view is not guaranteed to be NUL-terminated, so its
+        // data() must not be handed to the NUL-scanning StrCmpSSE4;
+        // compare by length and content instead.
+        if (src[i] == val) {
+            return true;
+        }
+    }
+    return false;
+}
+
+int
+StrCmpSSE4(const char* s1, const char* s2) {
+    // Compares two C strings 16 bytes per iteration and returns -1, +1 or 0
+    // according to the first differing byte (bytes compared as unsigned).
+    // NOTE(review): each iteration loads a full 16 bytes from both
+    // pointers regardless of where the strings end, so memory past the
+    // terminator is read - callers must make sure that is safe.
+    __m128i* ptr1 = reinterpret_cast<__m128i*>(const_cast(s1));
+    __m128i* ptr2 = reinterpret_cast<__m128i*>(const_cast(s2));
+
+    for (;; ptr1++, ptr2++) {
+        const __m128i a = _mm_loadu_si128(ptr1);
+        const __m128i b = _mm_loadu_si128(ptr2);
+
+        // Byte-wise "equal each" comparison with negated result, reporting
+        // the least significant (i.e. first) position that is set.
+        const uint8_t mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH |
+                             _SIDD_NEGATIVE_POLARITY | _SIDD_LEAST_SIGNIFICANT;
+
+        if (_mm_cmpistrc(a, b, mode)) {
+            // A difference was flagged inside this 16-byte window: fetch
+            // its index and order the two bytes found there.
+            const auto idx = _mm_cmpistri(a, b, mode);
+            const uint8_t b1 = (reinterpret_cast(ptr1))[idx];
+            const uint8_t b2 = (reinterpret_cast(ptr2))[idx];
+
+            if (b1 < b2) {
+                return -1;
+            } else if (b1 > b2) {
+                return +1;
+            } else {
+                return 0;
+            }
+        } else if (_mm_cmpistrz(a, b, mode)) {
+            // No difference in this window and the zero flag is set:
+            // presumably the terminator was reached - confirm against the
+            // _mm_cmpistrz documentation. The strings are equal.
+            break;
+        }
+    }
+    return 0;
+}
+
+}  // namespace simd
+}  // namespace milvus
+
+#endif
diff --git a/internal/core/src/simd/sse4.h b/internal/core/src/simd/sse4.h
new file mode 100644
index 0000000000..107ab519f7
--- /dev/null
+++ b/internal/core/src/simd/sse4.h
@@ -0,0 +1,41 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+ +#pragma once + +#include +#include + +#include +#include + +#include "common.h" +#include "sse2.h" +namespace milvus { +namespace simd { + +template +bool +FindTermSSE4(const T* src, size_t vec_size, T val) { + CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermSSE2"); + // SSE4 still hava 128bit, using same code with SSE2 + return FindTermSSE2(src, vec_size, val); +} + +template <> +bool +FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val); + +int +StrCmpSSE4(const char* s1, const char* s2); + +} // namespace simd +} // namespace milvus diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index afcb35a051..56922f4de4 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -115,3 +115,17 @@ install(TARGETS all_tests DESTINATION unittest) if (LINUX) add_subdirectory(bench) endif () + +if (USE_DYNAMIC_SIMD) +add_executable(dynamic_simd_test + test_simd.cpp) + +target_link_libraries(dynamic_simd_test + milvus_simd + milvus_log + gtest + ${CONAN_LIBS}) + +install(TARGETS dynamic_simd_test DESTINATION unittest) +endif() + diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 6173762a97..459dca80a9 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -4519,6 +4519,7 @@ TEST(CApiTest, AssembeChunkTest) { ASSERT_EQ(result[index++], chunk[i]) << i; } + chunk.clear(); for (int i = 0; i < 934; ++i) { chunk.push_back(i % 2 == 0); } @@ -4526,6 +4527,8 @@ TEST(CApiTest, AssembeChunkTest) { for (size_t i = 0; i < 934; i++) { ASSERT_EQ(result[index++], chunk[i]) << i; } + + chunk.clear(); for (int i = 0; i < 62; ++i) { chunk.push_back(i % 2 == 0); } @@ -4533,6 +4536,8 @@ TEST(CApiTest, AssembeChunkTest) { for (size_t i = 0; i < 62; i++) { ASSERT_EQ(result[index++], chunk[i]) << i; } + + chunk.clear(); for (int i = 0; i < 105; ++i) { chunk.push_back(i % 2 == 0); } @@ -4621,3 +4626,28 @@ 
TEST(CApiTest, SearchIdTest) { test(nt); } } + +TEST(CApiTest, AssembeChunkPerfTest) { + FixedVector chunk; + for (size_t i = 0; i < 100000000; ++i) { + chunk.push_back(i % 2 == 0); + } + BitsetType result; + // while (true) { + std::cout << "start test" << std::endl; + auto start = std::chrono::steady_clock::now(); + milvus::query::AppendOneChunk(result, chunk); + std::cout << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; + int index = 0; + for (size_t i = 0; i < 1000; i++) { + ASSERT_EQ(result[index++], chunk[i]) << i; + } + // } + // std::string s; + // boost::to_string(result, s); + // std::cout << s << std::endl; +} diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index e974031db1..0dd3ec46fd 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -1375,6 +1375,8 @@ TEST(Expr, TestExprs) { auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); @@ -1407,8 +1409,8 @@ TEST(Expr, TestExprs) { BinaryArithOpEvalRangeExpr = 6, }; - auto build_expr = - [&](enum ExprType test_type) -> std::shared_ptr { + auto build_expr = [&](enum ExprType test_type, + int n) -> std::shared_ptr { switch (test_type) { case UnaryRangeExpr: return std::make_shared>( @@ -1418,11 +1420,22 @@ TEST(Expr, TestExprs) { proto::plan::GenericValue::ValCase::kInt64Val); break; case TermExprImpl: { - std::vector retrieve_ints = {1, 4, 6}; - return std::make_shared>( - ColumnInfo(int64_fid, DataType::INT64), + std::vector retrieve_ints; + for (int i = 0; i < n; ++i) { + 
retrieve_ints.push_back("xxxxxx" + std::to_string(i % 10)); + } + return std::make_shared>( + ColumnInfo(str1_fid, DataType::VARCHAR), retrieve_ints, - proto::plan::GenericValue::ValCase::kInt64Val); + proto::plan::GenericValue::ValCase::kStringVal); + // std::vector retrieve_ints; + // for (int i = 0; i < n; ++i) { + // retrieve_ints.push_back(i); + // } + // return std::make_shared>( + // ColumnInfo(double_fid, DataType::DOUBLE), + // retrieve_ints, + // proto::plan::GenericValue::ValCase::kFloatVal); break; } case CompareExpr: { @@ -1499,15 +1512,25 @@ TEST(Expr, TestExprs) { break; } }; - auto expr = build_expr(UnaryRangeExpr); - std::cout << "start test" << std::endl; - auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*expr); - std::cout << "cost: " - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() - << "us" << std::endl; + auto test_case = [&](int n) { + auto expr = build_expr(TermExprImpl, n); + std::cout << "start test" << std::endl; + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*expr); + std::cout << n << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; + }; + test_case(3); + test_case(10); + test_case(20); + test_case(30); + test_case(50); + test_case(100); + test_case(200); + // test_case(500); } TEST(Expr, TestCompareWithScalarIndexMaris) { diff --git a/internal/core/unittest/test_simd.cpp b/internal/core/unittest/test_simd.cpp new file mode 100644 index 0000000000..b8a3606394 --- /dev/null +++ b/internal/core/unittest/test_simd.cpp @@ -0,0 +1,759 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__x86_64__) +#include "simd/hook.h" +#include "simd/sse2.h" +#include "simd/sse4.h" +#include "simd/avx2.h" +#include "simd/avx512.h" + +using namespace std; +using namespace milvus::simd; + +template +using FixedVector = boost::container::vector; + +#define PRINT_SKPI_TEST \ + std::cout \ + << "skip " \ + << ::testing::UnitTest::GetInstance()->current_test_info()->name() \ + << std::endl; + +TEST(GetBitSetBlock, base_test_sse) { + FixedVector src; + for (int i = 0; i < 64; ++i) { + src.push_back(false); + } + + auto res = GetBitsetBlockSSE2(src.data()); + std::cout << res << std::endl; + ASSERT_EQ(res, 0); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(true); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0xffffffffffffffff); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x5555555555555555); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 4 == 0 ? true : false); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x1111111111111111); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 8 == 0 ? 
true : false); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x0101010101010101); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 16 == 0 ? true : false); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x0001000100010001); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 32 == 0 ? true : false); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x0000000100000001); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 5 == 0 ? true : false); + } + res = GetBitsetBlockSSE2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x1084210842108421); +} + +TEST(GetBitSetBlock, base_test_avx2) { + FixedVector src; + for (int i = 0; i < 64; ++i) { + src.push_back(false); + } + + auto res = GetBitsetBlockAVX2(src.data()); + std::cout << res << std::endl; + ASSERT_EQ(res, 0); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(true); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0xffffffffffffffff); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x5555555555555555); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 4 == 0 ? true : false); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x1111111111111111); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 8 == 0 ? true : false); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x0101010101010101); + + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 16 == 0 ? 
true : false); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x0001000100010001); +/* Flags set only at i = 0 and i = 32; the expected block has bits 0 and 32 set. */ + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 32 == 0 ? true : false); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x0000000100000001); +/* Flags set at every multiple of 5 below 64; the expected block has bits 0, 5, ..., 60 set. */ + src.clear(); + for (int i = 0; i < 64; ++i) { + src.push_back(i % 5 == 0 ? true : false); + } + res = GetBitsetBlockAVX2(src.data()); + std::cout << std::hex << res << std::endl; + ASSERT_EQ(res, 0x1084210842108421); +} +/* FindTermSSE2 on bool data, searched with 1, 17 and 34 elements: true is reported only after a true element has been appended, false is found from the start. NOTE(review): the template arguments of FixedVector and std::vector are missing throughout this chunk (they look stripped by text extraction) - restore them from the original file. */ +TEST(FindTermSSE2, bool_type) { + FixedVector vecs; + vecs.push_back(false); + + auto res = FindTermSSE2(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), false); + ASSERT_EQ(res, true); + + for (int i = 0; i < 16; i++) { + vecs.push_back(false); + } + + res = FindTermSSE2(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), false); + ASSERT_EQ(res, true); + + vecs.push_back(true); + for (int i = 0; i < 16; i++) { + vecs.push_back(false); + } + res = FindTermSSE2(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, true); +} +/* FindTermSSE2 with int8_t terms over the values 0..99: 0, 10 and 99 are found, 100 and 127 are not; 127 is found once it has been appended. */ +TEST(FindTermSSE2, int8_type) { + std::vector vecs; + for (int i = 0; i < 100; i++) { + vecs.push_back(i); + } + + auto res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)0); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)10); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)99); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)100); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)127); + ASSERT_EQ(res, false); + vecs.push_back(127); + res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)127); + ASSERT_EQ(res, true); +} +/* FindTermSSE2 with int16_t terms over the values 0..999: 0, 10 and 999 are found, 1000 and 1270 are not; 1000 is found once it has been appended. */ +TEST(FindTermSSE2, int16_type) { + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } +
+ auto res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)0); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)10); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)999); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)1000); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)1270); + ASSERT_EQ(res, false); + vecs.push_back(1000); + res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)1000); + ASSERT_EQ(res, true); +} +/* FindTermSSE2 with int terms over the values 0..999, then 1000, 1001 and 1002 appended one at a time: after each append the new value is found and the next one is not; 1270 is never present. Presumably the one-element growth exercises different tail lengths of the vectorised loop - TODO confirm against the implementation. */ +TEST(FindTermSSE2, int32_type) { + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermSSE2(vecs.data(), vecs.size(), 0); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 10); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 999); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 1000); + ASSERT_EQ(res, false); + + vecs.push_back(1000); + res = FindTermSSE2(vecs.data(), vecs.size(), 1000); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 1001); + ASSERT_EQ(res, false); + + vecs.push_back(1001); + res = FindTermSSE2(vecs.data(), vecs.size(), 1001); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 1002); + ASSERT_EQ(res, false); + + vecs.push_back(1002); + res = FindTermSSE2(vecs.data(), vecs.size(), 1002); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 1003); + ASSERT_EQ(res, false); + + res = FindTermSSE2(vecs.data(), vecs.size(), 1270); + ASSERT_EQ(res, false); +} +/* FindTermSSE2 with int64_t terms over the values 0..999: 0, 10 and 999 are found, 1000 and 1270 are not; 1005 is found once it has been appended. */ +TEST(FindTermSSE2, int64_type) { + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)0); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)10); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)999); + ASSERT_EQ(res, true); + res =
FindTermSSE2(vecs.data(), vecs.size(), (int64_t)1000); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)1270); + ASSERT_EQ(res, false); + vecs.push_back(1005); + res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)1005); + ASSERT_EQ(res, true); +} +/* FindTermSSE2 with float terms over the values i + 0.01 for i in 0..9999: 0.01 and 10.01 are found, 10000.01 and 12700.02 are not; 1.001 is found once it has been appended. */ +TEST(FindTermSSE2, float_type) { + std::vector vecs; + for (int i = 0; i < 10000; i++) { + vecs.push_back(i + 0.01); + } + + auto res = FindTermSSE2(vecs.data(), vecs.size(), (float)0.01); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (float)10.01); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), (float)10000.01); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), (float)12700.02); + ASSERT_EQ(res, false); + vecs.push_back(1.001); + res = FindTermSSE2(vecs.data(), vecs.size(), (float)1.001); + ASSERT_EQ(res, true); +} +/* FindTermSSE2 with double terms over the values i + 0.01 for i in 0..9999: 0.01 and 10.01 are found, 10000.01 and 12700.01 are not; 1.001 is found once it has been appended. */ +TEST(FindTermSSE2, double_type) { + std::vector vecs; + for (int i = 0; i < 10000; i++) { + vecs.push_back(i + 0.01); + } + + auto res = FindTermSSE2(vecs.data(), vecs.size(), 0.01); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 10.01); + ASSERT_EQ(res, true); + res = FindTermSSE2(vecs.data(), vecs.size(), 10000.01); + ASSERT_EQ(res, false); + res = FindTermSSE2(vecs.data(), vecs.size(), 12700.01); + ASSERT_EQ(res, false); + vecs.push_back(1.001); + res = FindTermSSE2(vecs.data(), vecs.size(), 1.001); + ASSERT_EQ(res, true); +} +/* FindTermSSE4 with int64_t terms over the values 0..999; returns early when cpu_support_sse4_2() is false. 0, 1 and 999 are found, 1000 and 2000 are not; 1000 is found once it has been appended. NOTE(review): PRINT_SKPI_TEST is defined outside this chunk; the name looks like a typo for PRINT_SKIP_TEST. */ +TEST(FindTermSSE4, int64_type) { + if (!cpu_support_sse4_2()) { + PRINT_SKPI_TEST + return; + } + std::vector srcs; + for (size_t i = 0; i < 1000; i++) { + srcs.push_back(i); + } + + auto res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)0); + ASSERT_EQ(res, true); + res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)1); + ASSERT_EQ(res, true); + res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)999); + ASSERT_EQ(res, true); + res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)1000); + ASSERT_EQ(res, false); + res =
FindTermSSE4(srcs.data(), srcs.size(), (int64_t)2000); + ASSERT_EQ(res, false); + srcs.push_back(1000); + res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)1000); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 on bool data, searched with 1, 17 and 34 elements; returns early when cpu_support_avx2() is false. true is reported only after a true element has been appended. NOTE(review): srcs is filled with 0..999 but never read in this test. */ +TEST(FindTermAVX2, bool_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector srcs; + for (size_t i = 0; i < 1000; i++) { + srcs.push_back(i); + } + FixedVector vecs; + vecs.push_back(false); + + auto res = FindTermAVX2(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), false); + ASSERT_EQ(res, true); + + for (int i = 0; i < 16; i++) { + vecs.push_back(false); + } + + res = FindTermAVX2(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), false); + ASSERT_EQ(res, true); + + vecs.push_back(true); + for (int i = 0; i < 16; i++) { + vecs.push_back(false); + } + res = FindTermAVX2(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 with int8_t terms over the values 0..99; returns early when cpu_support_avx2() is false. 0, 10 and 99 are found, 100 and 127 are not; 127 is found once it has been appended. */ +TEST(FindTermAVX2, int8_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 100; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)0); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)10); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)99); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)100); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)127); + ASSERT_EQ(res, false); + vecs.push_back(127); + res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)127); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 with int16_t terms over the values 0..999; returns early when cpu_support_avx2() is false. 0, 10 and 999 are found, 1000 and 1270 are not; 1270 is found once it has been appended. */ +TEST(FindTermAVX2, int16_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)0); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(),
vecs.size(), (int16_t)10); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)999); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)1000); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)1270); + ASSERT_EQ(res, false); + vecs.push_back(1270); + res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)1270); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 with int terms over the values 0..999; returns early when cpu_support_avx2() is false. 0, 10 and 999 are found, 1000 and 1270 are not; 1270 is found once it has been appended. */ +TEST(FindTermAVX2, int32_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX2(vecs.data(), vecs.size(), 0); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), 10); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), 999); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), 1000); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), 1270); + ASSERT_EQ(res, false); + vecs.push_back(1270); + res = FindTermAVX2(vecs.data(), vecs.size(), 1270); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 with int64_t terms over the values 0..999; returns early when cpu_support_avx2() is false. 0, 10 and 999 are found, 1000 and 1270 are not; 1270 is found once it has been appended. */ +TEST(FindTermAVX2, int64_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)0); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)10); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)999); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)1000); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)1270); + ASSERT_EQ(res, false); + vecs.push_back(1270); + res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)1270); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 with float terms over the values i + 0.01 for i in 0..9999; returns early when cpu_support_avx2() is false. 0.01 and 10.01 are found, 10000.01 and 12700.02 are not; 12700.02 is found once it has been appended. */ +TEST(FindTermAVX2, float_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 10000; i++) { +
vecs.push_back(i + 0.01); + } + + auto res = FindTermAVX2(vecs.data(), vecs.size(), (float)0.01); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (float)10.01); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), (float)10000.01); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), (float)12700.02); + ASSERT_EQ(res, false); + vecs.push_back(12700.02); + res = FindTermAVX2(vecs.data(), vecs.size(), (float)12700.02); + ASSERT_EQ(res, true); +} +/* FindTermAVX2 with double terms over the values i + 0.01 for i in 0..9999; returns early when cpu_support_avx2() is false. 0.01 and 10.01 are found, 10000.01 and 12700.01 are not; 12700.01 is found once it has been appended. */ +TEST(FindTermAVX2, double_type) { + if (!cpu_support_avx2()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 10000; i++) { + vecs.push_back(i + 0.01); + } + + auto res = FindTermAVX2(vecs.data(), vecs.size(), 0.01); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), 10.01); + ASSERT_EQ(res, true); + res = FindTermAVX2(vecs.data(), vecs.size(), 10000.01); + ASSERT_EQ(res, false); + res = FindTermAVX2(vecs.data(), vecs.size(), 12700.01); + ASSERT_EQ(res, false); + vecs.push_back(12700.01); + res = FindTermAVX2(vecs.data(), vecs.size(), 12700.01); + ASSERT_EQ(res, true); +} +/* FindTermAVX512 on bool data, searched with 1, 17 and 34 elements; returns early when cpu_support_avx512() is false. true is reported only after a true element has been appended. NOTE(review): srcs is filled with 0..999 but never read in this test. */ +TEST(FindTermAVX512, bool_type) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector srcs; + for (size_t i = 0; i < 1000; i++) { + srcs.push_back(i); + } + FixedVector vecs; + vecs.push_back(false); + + auto res = FindTermAVX512(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), false); + ASSERT_EQ(res, true); + + for (int i = 0; i < 16; i++) { + vecs.push_back(false); + } + + res = FindTermAVX512(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), false); + ASSERT_EQ(res, true); + + vecs.push_back(true); + for (int i = 0; i < 16; i++) { + vecs.push_back(false); + } + res = FindTermAVX512(vecs.data(), vecs.size(), true); + ASSERT_EQ(res, true); +} +/* FindTermAVX512 with int8_t terms over the values 0..99; returns early when cpu_support_avx512() is false. 0, 10 and 99 are found, 100 and 127 are not; 127 is found once it has been appended. */ +TEST(FindTermAVX512, int8_type) { + if (!cpu_support_avx512())
{ + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 100; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)0); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)10); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)99); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)100); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)127); + ASSERT_EQ(res, false); + vecs.push_back(127); + res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)127); + ASSERT_EQ(res, true); +} +/* FindTermAVX512 with int16_t terms over the values 0..999; returns early when cpu_support_avx512() is false. 0, 10 and 999 are found, 1000 and 1270 are not; 1270 is found once it has been appended. */ +TEST(FindTermAVX512, int16_type) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)0); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)10); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)999); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)1000); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)1270); + ASSERT_EQ(res, false); + vecs.push_back(1270); + res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)1270); + ASSERT_EQ(res, true); +} +/* FindTermAVX512 with int terms over the values 0..999; returns early when cpu_support_avx512() is false. 0, 10 and 999 are found, 1000 and 1270 are not; 1270 is found once it has been appended. */ +TEST(FindTermAVX512, int32_type) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX512(vecs.data(), vecs.size(), 0); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), 10); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), 999); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), 1000); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), 1270); + ASSERT_EQ(res, false); +
vecs.push_back(1270); + res = FindTermAVX512(vecs.data(), vecs.size(), 1270); + ASSERT_EQ(res, true); +} +/* FindTermAVX512 with int64_t terms over the values 0..999; returns early when cpu_support_avx512() is false. 0, 10 and 999 are found, 1000 and 1270 are not; 1270 is found once it has been appended. */ +TEST(FindTermAVX512, int64_type) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 1000; i++) { + vecs.push_back(i); + } + + auto res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)0); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)10); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)999); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)1000); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)1270); + ASSERT_EQ(res, false); + vecs.push_back(1270); + res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)1270); + ASSERT_EQ(res, true); +} +/* FindTermAVX512 with float terms over the values i + 0.01 for i in 0..9999; returns early when cpu_support_avx512() is false. 0.01 and 10.01 are found, 10000.01 and 12700.02 are not; 12700.02 is found once it has been appended. */ +TEST(FindTermAVX512, float_type) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 10000; i++) { + vecs.push_back(i + 0.01); + } + + auto res = FindTermAVX512(vecs.data(), vecs.size(), (float)0.01); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (float)10.01); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), (float)10000.01); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), (float)12700.02); + ASSERT_EQ(res, false); + vecs.push_back(12700.02); + res = FindTermAVX512(vecs.data(), vecs.size(), (float)12700.02); + ASSERT_EQ(res, true); +} +/* Calls StrCmpSSE4 on each of test0..test999 against test0, then on two identical 1000-character strings of x; returns early when cpu_support_sse4_2() is false. NOTE(review): nothing is asserted here - the loop results are discarded and the final result is only printed; the suite name StrCmpSS4 also looks like a typo for StrCmpSSE4. */ +TEST(StrCmpSS4, string_type) { + if (!cpu_support_sse4_2()) { + PRINT_SKPI_TEST + return; + } + + std::vector s1; + for (int i = 0; i < 1000; ++i) { + s1.push_back("test" + std::to_string(i)); + } + + for (int i = 0; i < 1000; ++i) { + auto res = StrCmpSSE4(s1[i].c_str(), "test0"); + } + + string s2; + string s3; + for (int i = 0; i < 1000; ++i) { + s2.push_back('x'); + } + for (int i = 0; i < 1000; ++i) { + s3.push_back('x'); + } + + auto res =
StrCmpSSE4(s2.c_str(), s3.c_str()); + std::cout << res << std::endl; +} +/* FindTermAVX512 with double terms over the values i + 0.01 for i in 0..9999; returns early when cpu_support_avx512() is false. 0.01 and 10.01 are found, 10000.01 and 12700.01 are not; 12700.01 is found once it has been appended. */ +TEST(FindTermAVX512, double_type) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < 10000; i++) { + vecs.push_back(i + 0.01); + } + + auto res = FindTermAVX512(vecs.data(), vecs.size(), 0.01); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), 10.01); + ASSERT_EQ(res, true); + res = FindTermAVX512(vecs.data(), vecs.size(), 10000.01); + ASSERT_EQ(res, false); + res = FindTermAVX512(vecs.data(), vecs.size(), 12700.01); + ASSERT_EQ(res, false); + vecs.push_back(12700.01); + res = FindTermAVX512(vecs.data(), vecs.size(), 12700.01); + ASSERT_EQ(res, true); +} + +#endif +/* Test binary entry point: passes the command-line arguments to InitGoogleTest and returns the result of RUN_ALL_TESTS(). */ +int +main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/scripts/core_build.sh b/scripts/core_build.sh index bc2f096479..a510e01cfd 100755 --- a/scripts/core_build.sh +++ b/scripts/core_build.sh @@ -104,8 +104,9 @@ EMBEDDED_MILVUS="OFF" BUILD_DISK_ANN="OFF" USE_ASAN="OFF" OPEN_SIMD="OFF" +USE_DYNAMIC_SIMD="OFF" -while getopts "p:d:t:s:f:n:i:a:ulrcghzmeb" arg; do +while getopts "p:d:t:s:f:n:i:y:a:ulrcghzmeb" arg; do case $arg in f) CUSTOM_THIRDPARTY_PATH=$OPTARG @@ -163,6 +164,9 @@ while getopts "p:d:t:s:f:n:i:a:ulrcghzmeb" arg; do i) OPEN_SIMD=$OPTARG ;; + y) + USE_DYNAMIC_SIMD=$OPTARG + ;; h) # help echo " @@ -260,6 +264,7 @@ ${CMAKE_EXTRA_ARGS} \ -DBUILD_DISK_ANN=${BUILD_DISK_ANN} \ -DUSE_ASAN=${USE_ASAN} \ -DOPEN_SIMD=${OPEN_SIMD} \ +-DUSE_DYNAMIC_SIMD=${USE_DYNAMIC_SIMD} -DCPU_ARCH=${CPU_ARCH} \ ${CPP_SRC_DIR}" diff --git a/scripts/run_cpp_unittest.sh b/scripts/run_cpp_unittest.sh index d0c577cb9f..1e94ea865b 100755 --- a/scripts/run_cpp_unittest.sh +++ b/scripts/run_cpp_unittest.sh @@ -48,6 +48,14 @@ for UNITTEST_DIR in "${UNITTEST_DIRS[@]}"; do echo ${UNITTEST_DIR}/all_tests "run failed" exit 1 fi + if [ -f "${UNITTEST_DIR}/dynamic_simd_test" ]; then +
echo "Running dynamic simd test" + ${UNITTEST_DIR}/dynamic_simd_test + if [ $? -ne 0 ]; then + echo ${UNITTEST_DIR}/dynamic_simd_test "run failed" + exit 1 + fi + fi done # run cwrapper unittest