enhance: [bitset] extend op_find() to be able to search both 0 and 1 (#39176)

issue: #39124 

`bitset::find_first()` and `bitset::find_next()` now accept an additional
parameter that allows searching for a `0` bit instead of a `1` bit.

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
pull/39229/head
Alexander Guzhva 2025-01-14 01:50:58 +00:00 committed by GitHub
parent 702347bbfd
commit 3447ff7310
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 143 additions and 39 deletions

View File

@ -22,8 +22,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
detail/platform/x86/instruction_set.cpp
)
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512cd -mbmi")
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma -mbmi")
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")

View File

@ -546,23 +546,26 @@ class BitsetBase {
return as_derived();
}
// Find the index of the first bit set to true.
// Find the index of the first bit set to either true (default), or false.
inline std::optional<size_t>
find_first() const {
find_first(const bool is_set = true) const {
return policy_type::op_find(
this->data(), this->offset(), this->size(), 0);
this->data(), this->offset(), this->size(), 0, is_set);
}
// Find the index of the first bit set to true, starting from a given bit index.
// Find the index of the first bit set to either true (default), or false, starting from a given bit index.
inline std::optional<size_t>
find_next(const size_t starting_bit_idx) const {
find_next(const size_t starting_bit_idx, const bool is_set = true) const {
const size_t size_v = this->size();
if (starting_bit_idx + 1 >= size_v) {
return std::nullopt;
}
return policy_type::op_find(
this->data(), this->offset(), this->size(), starting_bit_idx + 1);
return policy_type::op_find(this->data(),
this->offset(),
this->size(),
starting_bit_idx + 1,
is_set);
}
// Read multiple bits starting from a given bit index.

View File

@ -315,10 +315,11 @@ struct BitWiseBitsetPolicy {
op_find(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
const size_t starting_idx,
const bool is_set) {
for (size_t i = starting_idx; i < size; i++) {
const auto proxy = get_proxy(data, start + i);
if (proxy) {
if (proxy == is_set) {
return i;
}
}

View File

@ -220,9 +220,10 @@ struct VectorizedElementWiseBitsetPolicy {
op_find(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
const size_t starting_idx,
const bool is_set) {
return ElementWiseBitsetPolicy<ElementT>::op_find(
data, start, size, starting_idx);
data, start, size, starting_idx, is_set);
}
//

View File

@ -718,10 +718,10 @@ struct ElementWiseBitsetPolicy {
//
static inline std::optional<size_t>
op_find(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
op_find_1(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
if (size == 0) {
return std::nullopt;
}
@ -788,6 +788,91 @@ struct ElementWiseBitsetPolicy {
return std::nullopt;
}
static inline std::optional<size_t>
op_find_0(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
if (size == 0) {
return std::nullopt;
}
//
auto start_element = get_element(start + starting_idx);
const auto end_element = get_element(start + size);
const auto start_shift = get_shift(start + starting_idx);
const auto end_shift = get_shift(start + size);
// same element?
if (start_element == end_element) {
const data_type existing_v = ~data[start_element];
const data_type existing_mask = get_shift_mask_end(start_shift) &
get_shift_mask_begin(end_shift);
const data_type value = existing_v & existing_mask;
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value);
return size_t(ctz) + start_element * data_bits - start;
} else {
return std::nullopt;
}
}
// process the first element
if (start_shift != 0) {
const data_type existing_v = ~data[start_element];
const data_type existing_mask = get_shift_mask_end(start_shift);
const data_type value = existing_v & existing_mask;
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value) +
start_element * data_bits - start;
return size_t(ctz);
}
start_element += 1;
}
// process the middle
for (size_t i = start_element; i < end_element; i++) {
const data_type value = ~data[i];
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value);
return size_t(ctz) + i * data_bits - start;
}
}
// process the last element
if (end_shift != 0) {
const data_type existing_v = ~data[end_element];
const data_type existing_mask = get_shift_mask_begin(end_shift);
const data_type value = existing_v & existing_mask;
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value);
return size_t(ctz) + end_element * data_bits - start;
}
}
return std::nullopt;
}
//
// Searches for the first bit whose value equals `is_set`.
//
// Forwards to op_find_1() when `is_set` is true and to op_find_0() when it
// is false, passing the remaining arguments through unchanged, and returns
// whatever the selected routine returns.
static inline std::optional<size_t>
op_find(const data_type* const data,
        const size_t start,
        const size_t size,
        const size_t starting_idx,
        const bool is_set) {
    return is_set ? op_find_1(data, start, size, starting_idx)
                  : op_find_0(data, start, size, starting_idx);
}
//
template <typename T, typename U, CompareOpType Op>
static inline void

View File

@ -16,6 +16,8 @@
#pragma once
#include <cstddef>
namespace milvus {
namespace bitset {
namespace detail {

View File

@ -346,7 +346,7 @@ from_i32(const int32_t i) {
//
template <typename BitsetT>
void
TestFindImpl(BitsetT& bitset, const size_t max_v) {
TestFindImpl(BitsetT& bitset, const size_t max_v, const bool is_set) {
const size_t n = bitset.size();
std::default_random_engine rng(123);
@ -361,9 +361,13 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
}
}
if (!is_set) {
bitset.flip();
}
StopWatch sw;
auto bit_idx = bitset.find_first();
auto bit_idx = bitset.find_first(is_set);
if (!bit_idx.has_value()) {
ASSERT_EQ(one_pos.size(), 0);
return;
@ -372,7 +376,7 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
for (size_t i = 0; i < one_pos.size(); i++) {
ASSERT_TRUE(bit_idx.has_value()) << n << ", " << max_v;
ASSERT_EQ(bit_idx.value(), one_pos[i]) << n << ", " << max_v;
bit_idx = bitset.find_next(bit_idx.value());
bit_idx = bitset.find_next(bit_idx.value(), is_set);
}
ASSERT_FALSE(bit_idx.has_value())
@ -387,32 +391,40 @@ template <typename BitsetT>
void
TestFindImpl() {
for (const size_t n : typical_sizes) {
for (const size_t pr : {1, 100}) {
BitsetT bitset(n);
bitset.reset();
if (print_log) {
printf("Testing bitset, n=%zd, pr=%zd\n", n, pr);
}
TestFindImpl(bitset, pr);
for (const size_t offset : typical_offsets) {
if (offset >= n) {
continue;
}
for (const bool is_set : {true, false}) {
for (const size_t pr : {1, 100}) {
BitsetT bitset(n);
bitset.reset();
auto view = bitset.view(offset);
if (print_log) {
printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n",
printf("Testing bitset, n=%zd, is_set=%d, pr=%zd\n",
n,
offset,
(is_set) ? 1 : 0,
pr);
}
TestFindImpl(view, pr);
TestFindImpl(bitset, pr, is_set);
for (const size_t offset : typical_offsets) {
if (offset >= n) {
continue;
}
bitset.reset();
auto view = bitset.view(offset);
if (print_log) {
printf(
"Testing bitset view, n=%zd, offset=%zd, "
"is_set=%d, pr=%zd\n",
n,
offset,
(is_set) ? 1 : 0,
pr);
}
TestFindImpl(view, pr, is_set);
}
}
}
}