fix: fix get array error for int type (#35154)

#35055

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
pull/35022/head
zhagnlu 2024-08-01 14:30:12 +08:00 committed by GitHub
parent b4d0f4df0a
commit f8c1b138a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 10 deletions

View File

@ -142,13 +142,19 @@ template <typename T>
void void
BitmapIndex<T>::BuildArrayField(const std::vector<FieldDataPtr>& field_datas) { BitmapIndex<T>::BuildArrayField(const std::vector<FieldDataPtr>& field_datas) {
int64_t offset = 0; int64_t offset = 0;
using GetType = std::conditional_t<std::is_same_v<T, int8_t> ||
std::is_same_v<T, int16_t> ||
std::is_same_v<T, int32_t>,
int32_t,
T>;
for (const auto& data : field_datas) { for (const auto& data : field_datas) {
auto slice_row_num = data->get_num_rows(); auto slice_row_num = data->get_num_rows();
for (size_t i = 0; i < slice_row_num; ++i) { for (size_t i = 0; i < slice_row_num; ++i) {
auto array = auto array =
reinterpret_cast<const milvus::Array*>(data->RawValue(i)); reinterpret_cast<const milvus::Array*>(data->RawValue(i));
for (size_t j = 0; j < array->length(); ++j) { for (size_t j = 0; j < array->length(); ++j) {
auto val = array->template get_data<T>(j); auto val = static_cast<T>(array->template get_data<GetType>(j));
data_[val].add(offset); data_[val].add(offset);
} }
offset++; offset++;
@ -294,10 +300,12 @@ BitmapIndex<T>::DeserializeIndexMeta(const uint8_t* data_ptr,
template <typename T> template <typename T>
void void
BitmapIndex<T>::ChooseIndexBuildMode() { BitmapIndex<T>::ChooseIndexLoadMode(int64_t index_length) {
if (data_.size() <= DEFAULT_BITMAP_INDEX_CARDINALITY_BOUND) { if (index_length <= DEFAULT_BITMAP_INDEX_CARDINALITY_BOUND) {
LOG_DEBUG("load bitmap index with bitset mode");
build_mode_ = BitmapIndexBuildMode::BITSET; build_mode_ = BitmapIndexBuildMode::BITSET;
} else { } else {
LOG_DEBUG("load bitmap index with raw roaring mode");
build_mode_ = BitmapIndexBuildMode::ROARING; build_mode_ = BitmapIndexBuildMode::ROARING;
} }
} }
@ -306,6 +314,7 @@ template <typename T>
void void
BitmapIndex<T>::DeserializeIndexData(const uint8_t* data_ptr, BitmapIndex<T>::DeserializeIndexData(const uint8_t* data_ptr,
size_t index_length) { size_t index_length) {
ChooseIndexLoadMode(index_length);
for (size_t i = 0; i < index_length; ++i) { for (size_t i = 0; i < index_length; ++i) {
T key; T key;
memcpy(&key, data_ptr, sizeof(T)); memcpy(&key, data_ptr, sizeof(T));
@ -315,11 +324,10 @@ BitmapIndex<T>::DeserializeIndexData(const uint8_t* data_ptr,
value = roaring::Roaring::read(reinterpret_cast<const char*>(data_ptr)); value = roaring::Roaring::read(reinterpret_cast<const char*>(data_ptr));
data_ptr += value.getSizeInBytes(); data_ptr += value.getSizeInBytes();
ChooseIndexBuildMode();
if (build_mode_ == BitmapIndexBuildMode::BITSET) { if (build_mode_ == BitmapIndexBuildMode::BITSET) {
bitsets_[key] = ConvertRoaringToBitset(value); bitsets_[key] = ConvertRoaringToBitset(value);
data_.erase(key); } else {
data_[key] = value;
} }
} }
} }
@ -328,6 +336,7 @@ template <>
void void
BitmapIndex<std::string>::DeserializeIndexData(const uint8_t* data_ptr, BitmapIndex<std::string>::DeserializeIndexData(const uint8_t* data_ptr,
size_t index_length) { size_t index_length) {
ChooseIndexLoadMode(index_length);
for (size_t i = 0; i < index_length; ++i) { for (size_t i = 0; i < index_length; ++i) {
size_t key_size; size_t key_size;
memcpy(&key_size, data_ptr, sizeof(size_t)); memcpy(&key_size, data_ptr, sizeof(size_t));
@ -340,7 +349,11 @@ BitmapIndex<std::string>::DeserializeIndexData(const uint8_t* data_ptr,
value = roaring::Roaring::read(reinterpret_cast<const char*>(data_ptr)); value = roaring::Roaring::read(reinterpret_cast<const char*>(data_ptr));
data_ptr += value.getSizeInBytes(); data_ptr += value.getSizeInBytes();
bitsets_[key] = ConvertRoaringToBitset(value); if (build_mode_ == BitmapIndexBuildMode::BITSET) {
bitsets_[key] = ConvertRoaringToBitset(value);
} else {
data_[key] = value;
}
} }
} }

View File

@ -146,7 +146,7 @@ class BitmapIndex : public ScalarIndex<T> {
DeserializeIndexData(const uint8_t* data_ptr, size_t index_length); DeserializeIndexData(const uint8_t* data_ptr, size_t index_length);
void void
ChooseIndexBuildMode(); ChooseIndexLoadMode(int64_t index_length);
bool bool
ShouldSkip(const T lower_value, const T upper_value, const OpType op); ShouldSkip(const T lower_value, const T upper_value, const OpType op);

View File

@ -194,8 +194,8 @@ class ArrayBitmapIndexTest : public testing::Test {
auto serialized_bytes = insert_data.Serialize(storage::Remote); auto serialized_bytes = insert_data.Serialize(storage::Remote);
auto log_path = fmt::format("{}/{}/{}/{}/{}/{}", auto log_path = fmt::format("/{}/{}/{}/{}/{}/{}",
"test_array_bitmap", "/tmp/test_array_bitmap",
collection_id, collection_id,
partition_id, partition_id,
segment_id, segment_id,