Fix bug for constructing ArrayView with fixed-length type (#28186)

Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
pull/28267/head
cai.zhang 2023-11-07 23:36:21 +08:00 committed by GitHub
parent 304f232a02
commit 19230db7f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 4 deletions

View File

@ -120,13 +120,23 @@ class Array {
size_t size,
DataType element_type,
std::vector<uint64_t>&& element_offsets)
: length_(element_offsets.size()),
size_(size),
: size_(size),
offsets_(std::move(element_offsets)),
element_type_(element_type) {
delete[] data_;
data_ = new char[size];
std::copy(data, data + size, data_);
if (datatype_is_variable(element_type_)) {
length_ = offsets_.size();
} else {
// int8, int16, int32 are all promoted to int32
if (element_type_ == DataType::INT8 ||
element_type_ == DataType::INT16) {
length_ = size / sizeof(int32_t);
} else {
length_ = size / datatype_sizeof(element_type_);
}
}
}
Array(const Array& array) noexcept
@ -433,9 +443,19 @@ class ArrayView {
std::vector<uint64_t>&& element_offsets)
: size_(size),
element_type_(element_type),
offsets_(std::move(element_offsets)),
length_(element_offsets.size()) {
offsets_(std::move(element_offsets)) {
data_ = data;
if (datatype_is_variable(element_type_)) {
length_ = offsets_.size();
} else {
// int8, int16, int32 are all promoted to int32
if (element_type_ == DataType::INT8 ||
element_type_ == DataType::INT16) {
length_ = size / sizeof(int32_t);
} else {
length_ = size / datatype_sizeof(element_type_);
}
}
}
template <typename T>

View File

@ -32,6 +32,30 @@ TEST(Array, TestConstructArray) {
ASSERT_EQ(int_array.get_data<int>(i), i);
}
ASSERT_TRUE(int_array.is_same_array(field_int_array));
auto int_array_tmp = Array(
const_cast<char*>(int_array.data()),
int_array.byte_size(),
int_array.get_element_type(),
{});
ASSERT_TRUE(int_array_tmp == int_array);
auto int_8_array = Array(const_cast<char*>(int_array.data()),
int_array.byte_size(),
DataType::INT8,
{});
ASSERT_EQ(int_array.length(), int_8_array.length());
auto int_16_array = Array(const_cast<char*>(int_array.data()),
int_array.byte_size(),
DataType::INT16,
{});
ASSERT_EQ(int_array.length(), int_16_array.length());
auto int_array_view = ArrayView(
const_cast<char*>(int_array.data()),
int_array.byte_size(),
int_array.get_element_type(),
{});
ASSERT_EQ(int_array.length(), int_array_view.length());
ASSERT_EQ(int_array.byte_size(), int_array_view.byte_size());
ASSERT_EQ(int_array.get_element_type(), int_array_view.get_element_type());
milvus::proto::schema::ScalarField field_long_data;
milvus::proto::plan::Array field_long_array;
@ -47,6 +71,20 @@ TEST(Array, TestConstructArray) {
ASSERT_EQ(long_array.get_data<int64_t>(i), i);
}
ASSERT_TRUE(long_array.is_same_array(field_int_array));
auto long_array_tmp = Array(const_cast<char*>(long_array.data()),
long_array.byte_size(),
long_array.get_element_type(),
{});
ASSERT_TRUE(long_array_tmp == long_array);
auto long_array_view = ArrayView(
const_cast<char*>(long_array.data()),
long_array.byte_size(),
long_array.get_element_type(),
{});
ASSERT_EQ(long_array.length(), long_array_view.length());
ASSERT_EQ(long_array.byte_size(), long_array_view.byte_size());
ASSERT_EQ(long_array.get_element_type(),
long_array_view.get_element_type());
milvus::proto::schema::ScalarField field_string_data;
milvus::proto::plan::Array field_string_array;
@ -65,6 +103,26 @@ TEST(Array, TestConstructArray) {
std::to_string(i));
}
ASSERT_TRUE(string_array.is_same_array(field_string_array));
std::vector<uint64_t> string_element_offsets;
std::vector<uint64_t> string_view_element_offsets;
for (auto& offset : string_array.get_offsets()) {
string_element_offsets.emplace_back(offset);
string_view_element_offsets.emplace_back(offset);
}
auto string_array_tmp = Array(const_cast<char*>(string_array.data()),
string_array.byte_size(),
string_array.get_element_type(),
std::move(string_element_offsets));
ASSERT_TRUE(string_array_tmp == string_array);
auto string_array_view = ArrayView(
const_cast<char*>(string_array.data()),
string_array.byte_size(),
string_array.get_element_type(),
std::move(string_view_element_offsets));
ASSERT_EQ(string_array.length(), string_array_view.length());
ASSERT_EQ(string_array.byte_size(), string_array_view.byte_size());
ASSERT_EQ(string_array.get_element_type(),
string_array_view.get_element_type());
milvus::proto::schema::ScalarField field_bool_data;
milvus::proto::plan::Array field_bool_array;
@ -80,6 +138,20 @@ TEST(Array, TestConstructArray) {
ASSERT_EQ(bool_array.get_data<bool>(i), bool(i));
}
ASSERT_TRUE(bool_array.is_same_array(field_bool_array));
auto bool_array_tmp = Array(const_cast<char*>(bool_array.data()),
bool_array.byte_size(),
bool_array.get_element_type(),
{});
ASSERT_TRUE(bool_array_tmp == bool_array);
auto bool_array_view = ArrayView(
const_cast<char*>(bool_array.data()),
bool_array.byte_size(),
bool_array.get_element_type(),
{});
ASSERT_EQ(bool_array.length(), bool_array_view.length());
ASSERT_EQ(bool_array.byte_size(), bool_array_view.byte_size());
ASSERT_EQ(bool_array.get_element_type(),
bool_array_view.get_element_type());
milvus::proto::schema::ScalarField field_float_data;
milvus::proto::plan::Array field_float_array;
@ -95,6 +167,20 @@ TEST(Array, TestConstructArray) {
ASSERT_DOUBLE_EQ(float_array.get_data<float>(i), float(i * 0.1));
}
ASSERT_TRUE(float_array.is_same_array(field_float_array));
auto float_array_tmp = Array(const_cast<char*>(float_array.data()),
float_array.byte_size(),
float_array.get_element_type(),
{});
ASSERT_TRUE(float_array_tmp == float_array);
auto float_array_view = ArrayView(
const_cast<char*>(float_array.data()),
float_array.byte_size(),
float_array.get_element_type(),
{});
ASSERT_EQ(float_array.length(), float_array_view.length());
ASSERT_EQ(float_array.byte_size(), float_array_view.byte_size());
ASSERT_EQ(float_array.get_element_type(),
float_array_view.get_element_type());
milvus::proto::schema::ScalarField field_double_data;
milvus::proto::plan::Array field_double_array;
@ -111,6 +197,20 @@ TEST(Array, TestConstructArray) {
ASSERT_DOUBLE_EQ(double_array.get_data<double>(i), double(i * 0.1));
}
ASSERT_TRUE(double_array.is_same_array(field_double_array));
auto double_array_tmp = Array(const_cast<char*>(double_array.data()),
double_array.byte_size(),
double_array.get_element_type(),
{});
ASSERT_TRUE(double_array_tmp == double_array);
auto double_array_view = ArrayView(
const_cast<char*>(double_array.data()),
double_array.byte_size(),
double_array.get_element_type(),
{});
ASSERT_EQ(double_array.length(), double_array_view.length());
ASSERT_EQ(double_array.byte_size(), double_array_view.byte_size());
ASSERT_EQ(double_array.get_element_type(),
double_array_view.get_element_type());
milvus::proto::schema::ScalarField field_empty_data;
milvus::proto::plan::Array field_empty_array;