mirror of https://github.com/milvus-io/milvus.git
enhance: [bitset] extend op_find() to be able to search both 0 and 1 (#39176)
issue: #39124 `bitset::find_first()` and `bitset::find_next()` now accept one more parameter, which allows to search for `0` bit instead of `1` bit Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>pull/39229/head
parent
702347bbfd
commit
3447ff7310
|
@ -22,8 +22,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
|
|||
detail/platform/x86/instruction_set.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
|
||||
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
|
||||
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512cd -mbmi")
|
||||
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma -mbmi")
|
||||
|
||||
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
|
||||
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
|
||||
|
|
|
@ -546,23 +546,26 @@ class BitsetBase {
|
|||
return as_derived();
|
||||
}
|
||||
|
||||
// Find the index of the first bit set to true.
|
||||
// Find the index of the first bit set to either true (default), or false.
|
||||
inline std::optional<size_t>
|
||||
find_first() const {
|
||||
find_first(const bool is_set = true) const {
|
||||
return policy_type::op_find(
|
||||
this->data(), this->offset(), this->size(), 0);
|
||||
this->data(), this->offset(), this->size(), 0, is_set);
|
||||
}
|
||||
|
||||
// Find the index of the first bit set to true, starting from a given bit index.
|
||||
// Find the index of the first bit set to either true (default), or false, starting from a given bit index.
|
||||
inline std::optional<size_t>
|
||||
find_next(const size_t starting_bit_idx) const {
|
||||
find_next(const size_t starting_bit_idx, const bool is_set = true) const {
|
||||
const size_t size_v = this->size();
|
||||
if (starting_bit_idx + 1 >= size_v) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return policy_type::op_find(
|
||||
this->data(), this->offset(), this->size(), starting_bit_idx + 1);
|
||||
return policy_type::op_find(this->data(),
|
||||
this->offset(),
|
||||
this->size(),
|
||||
starting_bit_idx + 1,
|
||||
is_set);
|
||||
}
|
||||
|
||||
// Read multiple bits starting from a given bit index.
|
||||
|
|
|
@ -315,10 +315,11 @@ struct BitWiseBitsetPolicy {
|
|||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
const size_t starting_idx,
|
||||
const bool is_set) {
|
||||
for (size_t i = starting_idx; i < size; i++) {
|
||||
const auto proxy = get_proxy(data, start + i);
|
||||
if (proxy) {
|
||||
if (proxy == is_set) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -220,9 +220,10 @@ struct VectorizedElementWiseBitsetPolicy {
|
|||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
const size_t starting_idx,
|
||||
const bool is_set) {
|
||||
return ElementWiseBitsetPolicy<ElementT>::op_find(
|
||||
data, start, size, starting_idx);
|
||||
data, start, size, starting_idx, is_set);
|
||||
}
|
||||
|
||||
//
|
||||
|
|
|
@ -718,10 +718,10 @@ struct ElementWiseBitsetPolicy {
|
|||
|
||||
//
|
||||
static inline std::optional<size_t>
|
||||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
op_find_1(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
if (size == 0) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -788,6 +788,91 @@ struct ElementWiseBitsetPolicy {
|
|||
return std::nullopt;
|
||||
}
|
||||
|
||||
static inline std::optional<size_t>
|
||||
op_find_0(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
if (size == 0) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
//
|
||||
auto start_element = get_element(start + starting_idx);
|
||||
const auto end_element = get_element(start + size);
|
||||
|
||||
const auto start_shift = get_shift(start + starting_idx);
|
||||
const auto end_shift = get_shift(start + size);
|
||||
|
||||
// same element?
|
||||
if (start_element == end_element) {
|
||||
const data_type existing_v = ~data[start_element];
|
||||
|
||||
const data_type existing_mask = get_shift_mask_end(start_shift) &
|
||||
get_shift_mask_begin(end_shift);
|
||||
|
||||
const data_type value = existing_v & existing_mask;
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value);
|
||||
return size_t(ctz) + start_element * data_bits - start;
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// process the first element
|
||||
if (start_shift != 0) {
|
||||
const data_type existing_v = ~data[start_element];
|
||||
const data_type existing_mask = get_shift_mask_end(start_shift);
|
||||
|
||||
const data_type value = existing_v & existing_mask;
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value) +
|
||||
start_element * data_bits - start;
|
||||
return size_t(ctz);
|
||||
}
|
||||
|
||||
start_element += 1;
|
||||
}
|
||||
|
||||
// process the middle
|
||||
for (size_t i = start_element; i < end_element; i++) {
|
||||
const data_type value = ~data[i];
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value);
|
||||
return size_t(ctz) + i * data_bits - start;
|
||||
}
|
||||
}
|
||||
|
||||
// process the last element
|
||||
if (end_shift != 0) {
|
||||
const data_type existing_v = ~data[end_element];
|
||||
const data_type existing_mask = get_shift_mask_begin(end_shift);
|
||||
|
||||
const data_type value = existing_v & existing_mask;
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value);
|
||||
return size_t(ctz) + end_element * data_bits - start;
|
||||
}
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
//
|
||||
static inline std::optional<size_t>
|
||||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx,
|
||||
const bool is_set) {
|
||||
if (is_set) {
|
||||
return op_find_1(data, start, size, starting_idx);
|
||||
} else {
|
||||
return op_find_0(data, start, size, starting_idx);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
template <typename T, typename U, CompareOpType Op>
|
||||
static inline void
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace milvus {
|
||||
namespace bitset {
|
||||
namespace detail {
|
||||
|
|
|
@ -346,7 +346,7 @@ from_i32(const int32_t i) {
|
|||
//
|
||||
template <typename BitsetT>
|
||||
void
|
||||
TestFindImpl(BitsetT& bitset, const size_t max_v) {
|
||||
TestFindImpl(BitsetT& bitset, const size_t max_v, const bool is_set) {
|
||||
const size_t n = bitset.size();
|
||||
|
||||
std::default_random_engine rng(123);
|
||||
|
@ -361,9 +361,13 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
|
|||
}
|
||||
}
|
||||
|
||||
if (!is_set) {
|
||||
bitset.flip();
|
||||
}
|
||||
|
||||
StopWatch sw;
|
||||
|
||||
auto bit_idx = bitset.find_first();
|
||||
auto bit_idx = bitset.find_first(is_set);
|
||||
if (!bit_idx.has_value()) {
|
||||
ASSERT_EQ(one_pos.size(), 0);
|
||||
return;
|
||||
|
@ -372,7 +376,7 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
|
|||
for (size_t i = 0; i < one_pos.size(); i++) {
|
||||
ASSERT_TRUE(bit_idx.has_value()) << n << ", " << max_v;
|
||||
ASSERT_EQ(bit_idx.value(), one_pos[i]) << n << ", " << max_v;
|
||||
bit_idx = bitset.find_next(bit_idx.value());
|
||||
bit_idx = bitset.find_next(bit_idx.value(), is_set);
|
||||
}
|
||||
|
||||
ASSERT_FALSE(bit_idx.has_value())
|
||||
|
@ -387,32 +391,40 @@ template <typename BitsetT>
|
|||
void
|
||||
TestFindImpl() {
|
||||
for (const size_t n : typical_sizes) {
|
||||
for (const size_t pr : {1, 100}) {
|
||||
BitsetT bitset(n);
|
||||
bitset.reset();
|
||||
|
||||
if (print_log) {
|
||||
printf("Testing bitset, n=%zd, pr=%zd\n", n, pr);
|
||||
}
|
||||
|
||||
TestFindImpl(bitset, pr);
|
||||
|
||||
for (const size_t offset : typical_offsets) {
|
||||
if (offset >= n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const bool is_set : {true, false}) {
|
||||
for (const size_t pr : {1, 100}) {
|
||||
BitsetT bitset(n);
|
||||
bitset.reset();
|
||||
auto view = bitset.view(offset);
|
||||
|
||||
if (print_log) {
|
||||
printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n",
|
||||
printf("Testing bitset, n=%zd, is_set=%d, pr=%zd\n",
|
||||
n,
|
||||
offset,
|
||||
(is_set) ? 1 : 0,
|
||||
pr);
|
||||
}
|
||||
|
||||
TestFindImpl(view, pr);
|
||||
TestFindImpl(bitset, pr, is_set);
|
||||
|
||||
for (const size_t offset : typical_offsets) {
|
||||
if (offset >= n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bitset.reset();
|
||||
auto view = bitset.view(offset);
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset view, n=%zd, offset=%zd, "
|
||||
"is_set=%d, pr=%zd\n",
|
||||
n,
|
||||
offset,
|
||||
(is_set) ? 1 : 0,
|
||||
pr);
|
||||
}
|
||||
|
||||
TestFindImpl(view, pr, is_set);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue