[Cherry-Pick] Remove arrow uasge in FieldData (#22726)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/22845/head
xige-16 2023-03-20 10:41:56 +08:00 committed by GitHub
parent 2a0ad67021
commit 9aa99aedbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1226 additions and 775 deletions

View File

@ -0,0 +1,46 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "storage/BinlogReader.h"
namespace milvus::storage {
Status
BinlogReader::Read(int64_t nbytes, void* out) {
auto remain = size_ - tell_;
if (nbytes > remain) {
return Status(SERVER_UNEXPECTED_ERROR, "out range of binlog data");
}
std::memcpy(out, data_.get() + tell_, nbytes);
tell_ += nbytes;
return Status(SERVER_SUCCESS, "");
}
std::pair<Status, std::shared_ptr<uint8_t[]>>
BinlogReader::Read(int64_t nbytes) {
auto remain = size_ - tell_;
if (nbytes > remain) {
return std::make_pair(
Status(SERVER_UNEXPECTED_ERROR, "out range of binlog data"),
nullptr);
}
auto res = std::shared_ptr<uint8_t[]>(new uint8_t[nbytes]);
std::memcpy(res.get(), data_.get() + tell_, nbytes);
tell_ += nbytes;
return std::make_pair(Status(SERVER_SUCCESS, ""), res);
}
} // namespace milvus::storage

View File

@ -0,0 +1,59 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include "utils/Status.h"
#include "exceptions/EasyAssert.h"
namespace milvus::storage {
class BinlogReader {
public:
explicit BinlogReader(const std::shared_ptr<uint8_t[]> binlog_data,
int64_t length)
: data_(binlog_data), size_(length), tell_(0) {
}
explicit BinlogReader(const uint8_t* binlog_data, int64_t length)
: size_(length), tell_(0) {
data_ = std::shared_ptr<uint8_t[]>(new uint8_t[length]);
std::memcpy(data_.get(), binlog_data, length);
}
Status
Read(int64_t nbytes, void* out);
std::pair<Status, std::shared_ptr<uint8_t[]>>
Read(int64_t nbytes);
int64_t
Tell() const {
return tell_;
}
private:
int64_t size_;
int64_t tell_;
std::shared_ptr<uint8_t[]> data_;
};
using BinlogReaderPtr = std::shared_ptr<BinlogReader>;
} // namespace milvus::storage

View File

@ -27,9 +27,11 @@ set(STORAGE_FILES
PayloadStream.cpp
DataCodec.cpp
Util.cpp
FieldData.cpp
PayloadReader.cpp
PayloadWriter.cpp
FieldData.cpp
BinlogReader.cpp
FieldDataFactory.cpp
IndexData.cpp
InsertData.cpp
Event.cpp

View File

@ -19,6 +19,7 @@
#include "storage/Util.h"
#include "storage/InsertData.h"
#include "storage/IndexData.h"
#include "storage/BinlogReader.h"
#include "exceptions/EasyAssert.h"
#include "common/Consts.h"
@ -26,8 +27,8 @@ namespace milvus::storage {
// deserialize remote insert and index file
std::unique_ptr<DataCodec>
DeserializeRemoteFileData(PayloadInputStream* input_stream) {
DescriptorEvent descriptor_event(input_stream);
DeserializeRemoteFileData(BinlogReaderPtr reader) {
DescriptorEvent descriptor_event(reader);
DataType data_type =
DataType(descriptor_event.event_data.fix_part.data_type);
auto descriptor_fix_part = descriptor_event.event_data.fix_part;
@ -35,13 +36,13 @@ DeserializeRemoteFileData(PayloadInputStream* input_stream) {
descriptor_fix_part.partition_id,
descriptor_fix_part.segment_id,
descriptor_fix_part.field_id};
EventHeader header(input_stream);
EventHeader header(reader);
switch (header.event_type_) {
case EventType::InsertEvent: {
auto event_data_length =
header.event_length_ - header.next_position_;
auto insert_event_data =
InsertEventData(input_stream, event_data_length, data_type);
InsertEventData(reader, event_data_length, data_type);
auto insert_data =
std::make_unique<InsertData>(insert_event_data.field_data);
insert_data->SetFieldDataMeta(data_meta);
@ -53,7 +54,7 @@ DeserializeRemoteFileData(PayloadInputStream* input_stream) {
auto event_data_length =
header.event_length_ - header.next_position_;
auto index_event_data =
IndexEventData(input_stream, event_data_length, data_type);
IndexEventData(reader, event_data_length, data_type);
auto index_data =
std::make_unique<IndexData>(index_event_data.field_data);
index_data->SetFieldDataMeta(data_meta);
@ -76,53 +77,25 @@ DeserializeRemoteFileData(PayloadInputStream* input_stream) {
// For now, no file header in file data
std::unique_ptr<DataCodec>
DeserializeLocalFileData(PayloadInputStream* input_stream) {
DeserializeLocalFileData(BinlogReaderPtr reader) {
PanicInfo("not supported");
}
std::unique_ptr<DataCodec>
DeserializeFileData(const uint8_t* input_data, int64_t length) {
auto input_stream =
std::make_shared<PayloadInputStream>(input_data, length);
auto medium_type = ReadMediumType(input_stream.get());
DeserializeFileData(const std::shared_ptr<uint8_t[]> input_data,
int64_t length) {
auto binlog_reader = std::make_shared<BinlogReader>(input_data, length);
auto medium_type = ReadMediumType(binlog_reader);
switch (medium_type) {
case StorageType::Remote: {
return DeserializeRemoteFileData(input_stream.get());
return DeserializeRemoteFileData(binlog_reader);
}
case StorageType::LocalDisk: {
auto ret = input_stream->Seek(0);
AssertInfo(ret.ok(), "seek input stream failed");
return DeserializeLocalFileData(input_stream.get());
return DeserializeLocalFileData(binlog_reader);
}
default:
PanicInfo("unsupported medium type");
}
}
// local insert file format
// -------------------------------------
// | Rows(int) | Dim(int) | InsertData |
// -------------------------------------
std::unique_ptr<DataCodec>
DeserializeLocalInsertFileData(const uint8_t* input_data,
int64_t length,
DataType data_type) {
auto input_stream =
std::make_shared<PayloadInputStream>(input_data, length);
LocalInsertEvent event(input_stream.get(), data_type);
return std::make_unique<InsertData>(event.field_data);
}
// local index file format: which indexSize = sizeOf(IndexData)
// --------------------------------------------------
// | IndexSize(uint64) | degree(uint32) | IndexData |
// --------------------------------------------------
std::unique_ptr<DataCodec>
DeserializeLocalIndexFileData(const uint8_t* input_data, int64_t length) {
auto input_stream =
std::make_shared<PayloadInputStream>(input_data, length);
LocalIndexEvent event(input_stream.get());
return std::make_unique<IndexData>(event.field_data);
}
} // namespace milvus::storage

View File

@ -23,12 +23,13 @@
#include "storage/Types.h"
#include "storage/FieldData.h"
#include "storage/PayloadStream.h"
#include "storage/BinlogReader.h"
namespace milvus::storage {
class DataCodec {
public:
explicit DataCodec(std::shared_ptr<FieldData> data, CodecType type)
explicit DataCodec(FieldDataPtr data, CodecType type)
: field_data_(data), codec_type_(type) {
}
@ -62,33 +63,25 @@ class DataCodec {
return field_data_->get_data_type();
}
std::unique_ptr<Payload>
GetPayload() const {
return field_data_->get_payload();
FieldDataPtr
GetFieldData() const {
return field_data_;
}
protected:
CodecType codec_type_;
std::pair<Timestamp, Timestamp> time_range_;
std::shared_ptr<FieldData> field_data_;
FieldDataPtr field_data_;
};
// Deserialize the data stream of the file obtained from remote or local
std::unique_ptr<DataCodec>
DeserializeFileData(const uint8_t* input, int64_t length);
DeserializeFileData(const std::shared_ptr<uint8_t[]> input, int64_t length);
std::unique_ptr<DataCodec>
DeserializeLocalInsertFileData(const uint8_t* input_data,
int64_t length,
DataType data_type);
DeserializeRemoteFileData(BinlogReaderPtr reader);
std::unique_ptr<DataCodec>
DeserializeLocalIndexFileData(const uint8_t* input_data, int64_t length);
std::unique_ptr<DataCodec>
DeserializeRemoteFileData(PayloadInputStream* input_stream);
std::unique_ptr<DataCodec>
DeserializeLocalFileData(PayloadInputStream* input_stream);
DeserializeLocalFileData(BinlogReaderPtr reader);
} // namespace milvus::storage

View File

@ -30,6 +30,7 @@
#include "storage/IndexData.h"
#include "storage/ThreadPool.h"
#include "storage/Util.h"
#include "storage/FieldDataFactory.h"
#define FILEMANAGER_TRY try {
#define FILEMANAGER_CATCH \
@ -91,9 +92,12 @@ EncodeAndUploadIndexSlice(RemoteChunkManager* remote_chunk_manager,
auto& local_chunk_manager = LocalChunkManager::GetInstance();
auto buf = std::unique_ptr<uint8_t[]>(new uint8_t[batch_size]);
local_chunk_manager.Read(file, offset, buf.get(), batch_size);
auto fieldData = std::make_shared<FieldData>(buf.get(), batch_size);
auto indexData = std::make_shared<IndexData>(fieldData);
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
DataType::INT8);
field_data->FillFieldData(buf.get(), batch_size);
auto indexData = std::make_shared<IndexData>(field_data);
indexData->set_index_meta(index_meta);
indexData->SetFieldDataMeta(field_meta);
auto serialized_index_data = indexData->serialize_to_remote_file();
@ -217,7 +221,7 @@ DownloadAndDecodeRemoteIndexfile(RemoteChunkManager* remote_chunk_manager,
auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[fileSize]);
remote_chunk_manager->Read(file, buf.get(), fileSize);
return DeserializeFileData(buf.get(), fileSize);
return DeserializeFileData(buf, fileSize);
}
uint64_t
@ -238,12 +242,13 @@ DiskFileManagerImpl::CacheBatchIndexFilesToDisk(
uint64_t offset = local_file_init_offfset;
for (int i = 0; i < batch_size; ++i) {
auto res = futures[i].get();
auto index_payload = res->GetPayload();
auto index_size = index_payload->rows * sizeof(uint8_t);
local_chunk_manager.Write(local_file_name,
offset,
const_cast<uint8_t*>(index_payload->raw_data),
index_size);
auto index_data = res->GetFieldData();
auto index_size = index_data->Size();
local_chunk_manager.Write(
local_file_name,
offset,
reinterpret_cast<uint8_t*>(const_cast<void*>(index_data->Data())),
index_size);
offset += index_size;
}

View File

@ -18,6 +18,7 @@
#include "storage/Util.h"
#include "storage/PayloadReader.h"
#include "storage/PayloadWriter.h"
#include "storage/FieldDataFactory.h"
#include "exceptions/EasyAssert.h"
#include "utils/Json.h"
#include "common/Consts.h"
@ -67,14 +68,14 @@ GetEventFixPartSize(EventType EventTypeCode) {
}
}
EventHeader::EventHeader(PayloadInputStream* input) {
auto ast = input->Read(sizeof(timestamp_), &timestamp_);
EventHeader::EventHeader(BinlogReaderPtr reader) {
auto ast = reader->Read(sizeof(timestamp_), &timestamp_);
assert(ast.ok());
ast = input->Read(sizeof(event_type_), &event_type_);
ast = reader->Read(sizeof(event_type_), &event_type_);
assert(ast.ok());
ast = input->Read(sizeof(event_length_), &event_length_);
ast = reader->Read(sizeof(event_length_), &event_length_);
assert(ast.ok());
ast = input->Read(sizeof(next_position_), &next_position_);
ast = reader->Read(sizeof(next_position_), &next_position_);
assert(ast.ok());
}
@ -95,21 +96,20 @@ EventHeader::Serialize() {
return res;
}
DescriptorEventDataFixPart::DescriptorEventDataFixPart(
PayloadInputStream* input) {
auto ast = input->Read(sizeof(collection_id), &collection_id);
DescriptorEventDataFixPart::DescriptorEventDataFixPart(BinlogReaderPtr reader) {
auto ast = reader->Read(sizeof(collection_id), &collection_id);
assert(ast.ok());
ast = input->Read(sizeof(partition_id), &partition_id);
ast = reader->Read(sizeof(partition_id), &partition_id);
assert(ast.ok());
ast = input->Read(sizeof(segment_id), &segment_id);
ast = reader->Read(sizeof(segment_id), &segment_id);
assert(ast.ok());
ast = input->Read(sizeof(field_id), &field_id);
ast = reader->Read(sizeof(field_id), &field_id);
assert(ast.ok());
ast = input->Read(sizeof(start_timestamp), &start_timestamp);
ast = reader->Read(sizeof(start_timestamp), &start_timestamp);
assert(ast.ok());
ast = input->Read(sizeof(end_timestamp), &end_timestamp);
ast = reader->Read(sizeof(end_timestamp), &end_timestamp);
assert(ast.ok());
ast = input->Read(sizeof(data_type), &data_type);
ast = reader->Read(sizeof(data_type), &data_type);
assert(ast.ok());
}
@ -138,20 +138,20 @@ DescriptorEventDataFixPart::Serialize() {
return res;
}
DescriptorEventData::DescriptorEventData(PayloadInputStream* input) {
fix_part = DescriptorEventDataFixPart(input);
DescriptorEventData::DescriptorEventData(BinlogReaderPtr reader) {
fix_part = DescriptorEventDataFixPart(reader);
for (auto i = int8_t(EventType::DescriptorEvent);
i < int8_t(EventType::EventTypeEnd);
i++) {
post_header_lengths.push_back(GetEventFixPartSize(EventType(i)));
}
auto ast =
input->Read(post_header_lengths.size(), post_header_lengths.data());
reader->Read(post_header_lengths.size(), post_header_lengths.data());
assert(ast.ok());
ast = input->Read(sizeof(extra_length), &extra_length);
ast = reader->Read(sizeof(extra_length), &extra_length);
assert(ast.ok());
extra_bytes = std::vector<uint8_t>(extra_length);
ast = input->Read(extra_length, extra_bytes.data());
ast = reader->Read(extra_length, extra_bytes.data());
assert(ast.ok());
milvus::json json =
@ -192,35 +192,46 @@ DescriptorEventData::Serialize() {
return res;
}
BaseEventData::BaseEventData(PayloadInputStream* input,
BaseEventData::BaseEventData(BinlogReaderPtr reader,
int event_length,
DataType data_type) {
auto ast = input->Read(sizeof(start_timestamp), &start_timestamp);
auto ast = reader->Read(sizeof(start_timestamp), &start_timestamp);
AssertInfo(ast.ok(), "read start timestamp failed");
ast = input->Read(sizeof(end_timestamp), &end_timestamp);
ast = reader->Read(sizeof(end_timestamp), &end_timestamp);
AssertInfo(ast.ok(), "read end timestamp failed");
int payload_length =
event_length - sizeof(start_timestamp) - sizeof(end_timestamp);
auto res = input->Read(payload_length);
auto res = reader->Read(payload_length);
AssertInfo(res.first.ok(), "read payload failed");
auto payload_reader = std::make_shared<PayloadReader>(
res.ValueOrDie()->data(), payload_length, data_type);
res.second.get(), payload_length, data_type);
field_data = payload_reader->get_field_data();
}
// TODO :: handle string and bool type
std::vector<uint8_t>
BaseEventData::Serialize() {
auto payload = field_data->get_payload();
auto data_type = field_data->get_data_type();
std::shared_ptr<PayloadWriter> payload_writer;
if (milvus::datatype_is_vector(payload->data_type)) {
AssertInfo(payload->dimension.has_value(), "empty dimension");
payload_writer = std::make_unique<PayloadWriter>(
payload->data_type, payload->dimension.value());
if (milvus::datatype_is_vector(data_type)) {
payload_writer =
std::make_unique<PayloadWriter>(data_type, field_data->get_dim());
} else {
payload_writer = std::make_unique<PayloadWriter>(payload->data_type);
payload_writer = std::make_unique<PayloadWriter>(data_type);
}
if (datatype_is_string(data_type)) {
for (size_t offset = 0; offset < field_data->get_num_rows(); ++offset) {
payload_writer->add_one_string_payload(
reinterpret_cast<const char*>(field_data->RawValue(offset)),
field_data->get_element_size(offset));
}
} else {
auto payload = Payload{data_type,
static_cast<const uint8_t*>(field_data->Data()),
field_data->get_num_rows(),
field_data->get_dim()};
payload_writer->add_payload(payload);
}
payload_writer->add_payload(*payload.get());
payload_writer->finish();
auto payload_buffer = payload_writer->get_payload_buffer();
auto len =
@ -236,11 +247,11 @@ BaseEventData::Serialize() {
return res;
}
BaseEvent::BaseEvent(PayloadInputStream* input, DataType data_type) {
event_header = EventHeader(input);
BaseEvent::BaseEvent(BinlogReaderPtr reader, DataType data_type) {
event_header = EventHeader(reader);
auto event_data_length =
event_header.event_length_ - event_header.next_position_;
event_data = BaseEventData(input, event_data_length, data_type);
event_data = BaseEventData(reader, event_data_length, data_type);
}
std::vector<uint8_t>
@ -263,9 +274,9 @@ BaseEvent::Serialize() {
return res;
}
DescriptorEvent::DescriptorEvent(PayloadInputStream* input) {
event_header = EventHeader(input);
event_data = DescriptorEventData(input);
DescriptorEvent::DescriptorEvent(BinlogReaderPtr reader) {
event_header = EventHeader(reader);
event_data = DescriptorEventData(reader);
}
std::vector<uint8_t>
@ -291,42 +302,11 @@ DescriptorEvent::Serialize() {
return res;
}
LocalInsertEvent::LocalInsertEvent(PayloadInputStream* input,
DataType data_type) {
auto ret = input->Read(sizeof(row_num), &row_num);
AssertInfo(ret.ok(), "read input stream failed");
ret = input->Read(sizeof(dimension), &dimension);
AssertInfo(ret.ok(), "read input stream failed");
int data_size = milvus::datatype_sizeof(data_type) * row_num;
auto insert_data_bytes = input->Read(data_size);
auto insert_data = reinterpret_cast<const uint8_t*>(
insert_data_bytes.ValueOrDie()->data());
std::shared_ptr<arrow::ArrayBuilder> builder = nullptr;
if (milvus::datatype_is_vector(data_type)) {
builder = CreateArrowBuilder(data_type, dimension);
} else {
builder = CreateArrowBuilder(data_type);
}
// TODO :: handle string type
Payload payload{data_type, insert_data, row_num, dimension};
AddPayloadToArrowBuilder(builder, payload);
std::shared_ptr<arrow::Array> array;
auto finish_ret = builder->Finish(&array);
AssertInfo(finish_ret.ok(), "arrow builder finish failed");
field_data = std::make_shared<FieldData>(array, data_type);
}
std::vector<uint8_t>
LocalInsertEvent::Serialize() {
auto payload = field_data->get_payload();
row_num = payload->rows;
dimension = 1;
if (milvus::datatype_is_vector(payload->data_type)) {
assert(payload->dimension.has_value());
dimension = payload->dimension.value();
}
int payload_size = GetPayloadSize(payload.get());
int row_num = field_data->get_num_rows();
int dimension = field_data->get_dim();
int payload_size = field_data->Size();
int len = sizeof(row_num) + sizeof(dimension) + payload_size;
std::vector<uint8_t> res(len);
@ -335,36 +315,27 @@ LocalInsertEvent::Serialize() {
offset += sizeof(row_num);
memcpy(res.data() + offset, &dimension, sizeof(dimension));
offset += sizeof(dimension);
memcpy(res.data() + offset, payload->raw_data, payload_size);
memcpy(res.data() + offset, field_data->Data(), payload_size);
return res;
}
LocalIndexEvent::LocalIndexEvent(PayloadInputStream* input) {
auto ret = input->Read(sizeof(index_size), &index_size);
AssertInfo(ret.ok(), "read input stream failed");
ret = input->Read(sizeof(degree), &degree);
AssertInfo(ret.ok(), "read input stream failed");
auto binary_index = input->Read(index_size);
LocalIndexEvent::LocalIndexEvent(BinlogReaderPtr reader) {
auto ret = reader->Read(sizeof(index_size), &index_size);
AssertInfo(ret.ok(), "read binlog failed");
ret = reader->Read(sizeof(degree), &degree);
AssertInfo(ret.ok(), "read binlog failed");
auto binary_index_data =
reinterpret_cast<const int8_t*>(binary_index.ValueOrDie()->data());
auto builder = std::make_shared<arrow::Int8Builder>();
auto append_ret = builder->AppendValues(binary_index_data,
binary_index_data + index_size);
AssertInfo(append_ret.ok(), "append data to arrow builder failed");
std::shared_ptr<arrow::Array> array;
auto finish_ret = builder->Finish(&array);
AssertInfo(finish_ret.ok(), "arrow builder finish failed");
field_data = std::make_shared<FieldData>(array, DataType::INT8);
auto res = reader->Read(index_size);
AssertInfo(res.first.ok(), "read payload failed");
auto payload_reader = std::make_shared<PayloadReader>(
res.second.get(), index_size, DataType::INT8);
field_data = payload_reader->get_field_data();
}
std::vector<uint8_t>
LocalIndexEvent::Serialize() {
auto payload = field_data->get_payload();
index_size = payload->rows;
index_size = field_data->Size();
int len = sizeof(index_size) + sizeof(degree) + index_size;
std::vector<uint8_t> res(len);
@ -373,7 +344,7 @@ LocalIndexEvent::Serialize() {
offset += sizeof(index_size);
memcpy(res.data() + offset, &degree, sizeof(degree));
offset += sizeof(degree);
memcpy(res.data() + offset, payload->raw_data, index_size);
memcpy(res.data() + offset, field_data->Data(), index_size);
return res;
}

View File

@ -23,8 +23,8 @@
#include "common/Types.h"
#include "storage/Types.h"
#include "storage/PayloadStream.h"
#include "storage/FieldData.h"
#include "storage/BinlogReader.h"
namespace milvus::storage {
@ -36,7 +36,7 @@ struct EventHeader {
EventHeader() {
}
explicit EventHeader(PayloadInputStream* input);
explicit EventHeader(BinlogReaderPtr reader);
std::vector<uint8_t>
Serialize();
@ -53,7 +53,7 @@ struct DescriptorEventDataFixPart {
DescriptorEventDataFixPart() {
}
explicit DescriptorEventDataFixPart(PayloadInputStream* input);
explicit DescriptorEventDataFixPart(BinlogReaderPtr reader);
std::vector<uint8_t>
Serialize();
@ -68,7 +68,7 @@ struct DescriptorEventData {
DescriptorEventData() {
}
explicit DescriptorEventData(PayloadInputStream* input);
explicit DescriptorEventData(BinlogReaderPtr reader);
std::vector<uint8_t>
Serialize();
@ -77,11 +77,11 @@ struct DescriptorEventData {
struct BaseEventData {
Timestamp start_timestamp;
Timestamp end_timestamp;
std::shared_ptr<FieldData> field_data;
FieldDataPtr field_data;
BaseEventData() {
}
explicit BaseEventData(PayloadInputStream* input,
explicit BaseEventData(BinlogReaderPtr reader,
int event_length,
DataType data_type);
@ -95,7 +95,7 @@ struct DescriptorEvent {
DescriptorEvent() {
}
explicit DescriptorEvent(PayloadInputStream* input);
explicit DescriptorEvent(BinlogReaderPtr reader);
std::vector<uint8_t>
Serialize();
@ -107,7 +107,7 @@ struct BaseEvent {
BaseEvent() {
}
explicit BaseEvent(PayloadInputStream* input, DataType data_type);
explicit BaseEvent(BinlogReaderPtr reader, DataType data_type);
std::vector<uint8_t>
Serialize();
@ -138,13 +138,7 @@ int
GetEventFixPartSize(EventType EventTypeCode);
struct LocalInsertEvent {
int row_num;
int dimension;
std::shared_ptr<FieldData> field_data;
LocalInsertEvent() {
}
explicit LocalInsertEvent(PayloadInputStream* input, DataType data_type);
FieldDataPtr field_data;
std::vector<uint8_t>
Serialize();
@ -153,11 +147,11 @@ struct LocalInsertEvent {
struct LocalIndexEvent {
uint64_t index_size;
uint32_t degree;
std::shared_ptr<FieldData> field_data;
FieldDataPtr field_data;
LocalIndexEvent() {
}
explicit LocalIndexEvent(PayloadInputStream* input);
explicit LocalIndexEvent(BinlogReaderPtr reader);
std::vector<uint8_t>
Serialize();

View File

@ -38,6 +38,22 @@ class NotImplementedException : public std::exception {
std::string exception_message_;
};
class NotSupportedDataTypeException : public std::exception {
public:
explicit NotSupportedDataTypeException(const std::string& msg)
: std::exception(), exception_message_(msg) {
}
const char*
what() const noexcept {
return exception_message_.c_str();
}
virtual ~NotSupportedDataTypeException() {
}
private:
std::string exception_message_;
};
class LocalChunkManagerException : public std::runtime_error {
public:
explicit LocalChunkManagerException(const std::string& msg)

View File

@ -15,84 +15,135 @@
// limitations under the License.
#include "storage/FieldData.h"
#include "exceptions/EasyAssert.h"
#include "storage/Util.h"
#include "common/FieldMeta.h"
namespace milvus::storage {
FieldData::FieldData(const Payload& payload) {
std::shared_ptr<arrow::ArrayBuilder> builder;
data_type_ = payload.data_type;
if (milvus::datatype_is_vector(data_type_)) {
AssertInfo(payload.dimension.has_value(), "empty dimension");
builder = CreateArrowBuilder(data_type_, payload.dimension.value());
} else {
builder = CreateArrowBuilder(data_type_);
}
AddPayloadToArrowBuilder(builder, payload);
auto ast = builder->Finish(&array_);
AssertInfo(ast.ok(), "builder failed to finish");
}
// TODO ::Check arrow type with data_type
FieldData::FieldData(std::shared_ptr<arrow::Array> array, DataType data_type)
: array_(array), data_type_(data_type) {
}
FieldData::FieldData(const uint8_t* data, int length)
: data_type_(DataType::INT8) {
auto builder = std::make_shared<arrow::Int8Builder>();
auto ret = builder->AppendValues(data, data + length);
AssertInfo(ret.ok(), "append value to builder failed");
ret = builder->Finish(&array_);
AssertInfo(ret.ok(), "builder failed to finish");
}
bool
FieldData::get_bool_payload(int idx) const {
AssertInfo(array_ != nullptr, "null arrow array");
AssertInfo(array_->type()->id() == arrow::Type::type::BOOL,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::BooleanArray>(array_);
AssertInfo(idx < array_->length(), "out range of bool array");
return array->Value(idx);
}
template <typename Type, bool is_scalar>
void
FieldData::get_one_string_payload(int idx, char** cstr, int* str_size) const {
AssertInfo(array_ != nullptr, "null arrow array");
AssertInfo(array_->type()->id() == arrow::Type::type::STRING,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::StringArray>(array_);
AssertInfo(idx < array->length(), "index out of range array.length");
arrow::StringArray::offset_type length;
*cstr = (char*)array->GetValue(idx, &length);
*str_size = length;
}
std::unique_ptr<Payload>
FieldData::get_payload() const {
AssertInfo(array_ != nullptr, "null arrow array");
auto raw_data_info = std::make_unique<Payload>();
raw_data_info->rows = array_->length();
raw_data_info->data_type = data_type_;
raw_data_info->raw_data = GetRawValuesFromArrowArray(array_, data_type_);
if (milvus::datatype_is_vector(data_type_)) {
raw_data_info->dimension =
GetDimensionFromArrowArray(array_, data_type_);
FieldDataImpl<Type, is_scalar>::FillFieldData(const void* source,
ssize_t element_count) {
AssertInfo(element_count % dim_ == 0, "invalid element count");
if (element_count == 0) {
return;
}
return raw_data_info;
AssertInfo(field_data_.size() == 0, "no empty field vector");
field_data_.resize(element_count);
std::copy_n(
static_cast<const Type*>(source), element_count, field_data_.data());
}
// TODO :: handle string type
int
FieldData::get_data_size() const {
auto payload = get_payload();
return GetPayloadSize(payload.get());
template <typename Type, bool is_scalar>
void
FieldDataImpl<Type, is_scalar>::FillFieldData(
const std::shared_ptr<arrow::Array> array) {
AssertInfo(array != nullptr, "null arrow array");
auto element_count = array->length() * dim_;
if (element_count == 0) {
return;
}
switch (data_type_) {
case DataType::BOOL: {
AssertInfo(array->type()->id() == arrow::Type::type::BOOL,
"inconsistent data type");
auto bool_array =
std::dynamic_pointer_cast<arrow::BooleanArray>(array);
FixedVector<bool> values(element_count);
for (size_t index = 0; index < element_count; ++index) {
values[index] = bool_array->Value(index);
}
return FillFieldData(values.data(), element_count);
}
case DataType::INT8: {
AssertInfo(array->type()->id() == arrow::Type::type::INT8,
"inconsistent data type");
auto int8_array =
std::dynamic_pointer_cast<arrow::Int8Array>(array);
return FillFieldData(int8_array->raw_values(), element_count);
}
case DataType::INT16: {
AssertInfo(array->type()->id() == arrow::Type::type::INT16,
"inconsistent data type");
auto int16_array =
std::dynamic_pointer_cast<arrow::Int16Array>(array);
return FillFieldData(int16_array->raw_values(), element_count);
}
case DataType::INT32: {
AssertInfo(array->type()->id() == arrow::Type::type::INT32,
"inconsistent data type");
auto int32_array =
std::dynamic_pointer_cast<arrow::Int32Array>(array);
return FillFieldData(int32_array->raw_values(), element_count);
}
case DataType::INT64: {
AssertInfo(array->type()->id() == arrow::Type::type::INT64,
"inconsistent data type");
auto int64_array =
std::dynamic_pointer_cast<arrow::Int64Array>(array);
return FillFieldData(int64_array->raw_values(), element_count);
}
case DataType::FLOAT: {
AssertInfo(array->type()->id() == arrow::Type::type::FLOAT,
"inconsistent data type");
auto float_array =
std::dynamic_pointer_cast<arrow::FloatArray>(array);
return FillFieldData(float_array->raw_values(), element_count);
}
case DataType::DOUBLE: {
AssertInfo(array->type()->id() == arrow::Type::type::DOUBLE,
"inconsistent data type");
auto double_array =
std::dynamic_pointer_cast<arrow::DoubleArray>(array);
return FillFieldData(double_array->raw_values(), element_count);
}
case DataType::STRING:
case DataType::VARCHAR: {
AssertInfo(array->type()->id() == arrow::Type::type::STRING,
"inconsistent data type");
auto string_array =
std::dynamic_pointer_cast<arrow::StringArray>(array);
std::vector<std::string> values(element_count);
for (size_t index = 0; index < element_count; ++index) {
values[index] = string_array->GetString(index);
}
return FillFieldData(values.data(), element_count);
}
case DataType::VECTOR_FLOAT: {
AssertInfo(
array->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY,
"inconsistent data type");
auto vector_array =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(array);
return FillFieldData(vector_array->raw_values(), element_count);
}
case DataType::VECTOR_BINARY: {
AssertInfo(
array->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY,
"inconsistent data type");
auto vector_array =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(array);
return FillFieldData(vector_array->raw_values(), element_count);
}
default: {
throw NotSupportedDataTypeException(GetName() + "::FillFieldData" +
" not support data type " +
datatype_name(data_type_));
}
}
}
// scalar data
template class FieldDataImpl<bool, true>;
template class FieldDataImpl<unsigned char, false>;
template class FieldDataImpl<int8_t, true>;
template class FieldDataImpl<int16_t, true>;
template class FieldDataImpl<int32_t, true>;
template class FieldDataImpl<int64_t, true>;
template class FieldDataImpl<float, true>;
template class FieldDataImpl<double, true>;
template class FieldDataImpl<std::string, true>;
// vector data
template class FieldDataImpl<int8_t, false>;
template class FieldDataImpl<float, false>;
} // namespace milvus::storage

View File

@ -16,59 +16,54 @@
#pragma once
#include <iostream>
#include <string>
#include <memory>
#include "arrow/api.h"
#include "storage/Types.h"
#include "storage/PayloadStream.h"
#include "storage/FieldDataInterface.h"
namespace milvus::storage {
using DataType = milvus::DataType;
class FieldData {
template <typename Type>
class FieldData : public FieldDataImpl<Type, true> {
public:
explicit FieldData(const Payload& payload);
explicit FieldData(std::shared_ptr<arrow::Array> raw_data,
DataType data_type);
explicit FieldData(const uint8_t* data, int length);
// explicit FieldData(std::unique_ptr<uint8_t[]> data, int length, DataType data_type): data_(std::move(data)),
// data_len_(length), data_type_(data_type) {}
~FieldData() = default;
DataType
get_data_type() const {
return data_type_;
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
explicit FieldData(DataType data_type)
: FieldDataImpl<Type, true>::FieldDataImpl(1, data_type) {
}
bool
get_bool_payload(int idx) const;
void
get_one_string_payload(int idx, char** cstr, int* str_size) const;
// get the bytes stream of the arrow array data
std::unique_ptr<Payload>
get_payload() const;
int
get_payload_length() const {
return array_->length();
}
int
get_data_size() const;
private:
std::shared_ptr<arrow::Array> array_;
// std::unique_ptr<uint8_t[]> data_;
// int64_t data_len_;
DataType data_type_;
};
template <>
class FieldData<std::string> : public FieldDataStringImpl {
public:
static_assert(IsScalar<std::string> || std::is_same_v<std::string, PkType>);
explicit FieldData(DataType data_type) : FieldDataStringImpl(data_type) {
}
};
template <>
class FieldData<FloatVector> : public FieldDataImpl<float, false> {
public:
explicit FieldData(int64_t dim, DataType data_type)
: FieldDataImpl<float, false>::FieldDataImpl(dim, data_type) {
}
};
template <>
class FieldData<BinaryVector> : public FieldDataImpl<uint8_t, false> {
public:
explicit FieldData(int64_t dim, DataType data_type)
: binary_dim_(dim), FieldDataImpl(dim / 8, data_type) {
Assert(dim % 8 == 0);
}
int64_t
get_dim() const {
return binary_dim_;
}
private:
int64_t binary_dim_;
};
using FieldDataPtr = std::shared_ptr<FieldDataBase>;
} // namespace milvus::storage

View File

@ -0,0 +1,53 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "storage/FieldDataFactory.h"
#include "storage/Exception.h"
namespace milvus::storage {
FieldDataPtr
FieldDataFactory::CreateFieldData(const DataType& type, const int64_t dim) {
switch (type) {
case DataType::BOOL:
return std::make_shared<FieldData<bool>>(type);
case DataType::INT8:
return std::make_shared<FieldData<int8_t>>(type);
case DataType::INT16:
return std::make_shared<FieldData<int16_t>>(type);
case DataType::INT32:
return std::make_shared<FieldData<int32_t>>(type);
case DataType::INT64:
return std::make_shared<FieldData<int64_t>>(type);
case DataType::FLOAT:
return std::make_shared<FieldData<float>>(type);
case DataType::DOUBLE:
return std::make_shared<FieldData<double>>(type);
case DataType::STRING:
case DataType::VARCHAR:
return std::make_shared<FieldData<std::string>>(type);
case DataType::VECTOR_FLOAT:
return std::make_shared<FieldData<FloatVector>>(dim, type);
case DataType::VECTOR_BINARY:
return std::make_shared<FieldData<BinaryVector>>(dim, type);
default:
throw NotSupportedDataTypeException(
GetName() + "::CreateFieldData" + " not support data type " +
datatype_name(type));
}
}
} // namespace milvus::storage

View File

@ -0,0 +1,48 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "storage/FieldData.h"
namespace milvus::storage {
class FieldDataFactory {
private:
FieldDataFactory() = default;
FieldDataFactory(const FieldDataFactory&) = delete;
FieldDataFactory
operator=(const FieldDataFactory&) = delete;
public:
static FieldDataFactory&
GetInstance() {
static FieldDataFactory inst;
return inst;
}
std::string
GetName() const {
return "FieldDataFactory";
}
FieldDataPtr
CreateFieldData(const DataType& type, const int64_t dim = 1);
};
} // namespace milvus::storage

View File

@ -0,0 +1,172 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <vector>
#include <string>
#include "arrow/api.h"
#include "common/FieldMeta.h"
#include "common/Utils.h"
#include "common/VectorTrait.h"
#include "exceptions/EasyAssert.h"
#include "storage/Exception.h"
namespace milvus::storage {
using DataType = milvus::DataType;
class FieldDataBase {
public:
explicit FieldDataBase(DataType data_type) : data_type_(data_type) {
}
virtual ~FieldDataBase() = default;
virtual void
FillFieldData(const void* source, ssize_t element_count) = 0;
virtual void
FillFieldData(const std::shared_ptr<arrow::Array> array) = 0;
virtual const void*
Data() const = 0;
virtual const void*
RawValue(ssize_t offset) const = 0;
virtual int64_t
Size() const = 0;
public:
virtual int
get_num_rows() const = 0;
virtual int64_t
get_dim() const = 0;
virtual int64_t
get_element_size(ssize_t offset) const = 0;
DataType
get_data_type() const {
return data_type_;
}
protected:
const DataType data_type_;
};
template <typename Type, bool is_scalar = false>
class FieldDataImpl : public FieldDataBase {
public:
// constants
using Chunk = FixedVector<Type>;
FieldDataImpl(FieldDataImpl&&) = delete;
FieldDataImpl(const FieldDataImpl&) = delete;
FieldDataImpl&
operator=(FieldDataImpl&&) = delete;
FieldDataImpl&
operator=(const FieldDataImpl&) = delete;
public:
explicit FieldDataImpl(ssize_t dim, DataType data_type)
: FieldDataBase(data_type), dim_(is_scalar ? 1 : dim) {
}
void
FillFieldData(const void* source, ssize_t element_count) override;
void
FillFieldData(const std::shared_ptr<arrow::Array> array) override;
std::string
GetName() const {
return "FieldDataImpl";
}
const void*
Data() const override {
return field_data_.data();
}
const void*
RawValue(ssize_t offset) const override {
return &field_data_[offset];
}
int64_t
Size() const override {
return sizeof(Type) * field_data_.size();
}
public:
int
get_num_rows() const override {
auto len = field_data_.size();
AssertInfo(len % dim_ == 0, "field data size not aligned");
return len / dim_;
}
int64_t
get_dim() const override {
return dim_;
}
int64_t
get_element_size(ssize_t offset) const override {
return sizeof(Type) * dim_;
}
protected:
Chunk field_data_;
private:
const ssize_t dim_;
};
class FieldDataStringImpl : public FieldDataImpl<std::string, true> {
public:
explicit FieldDataStringImpl(DataType data_type)
: FieldDataImpl<std::string, true>(1, data_type) {
}
const void*
RawValue(ssize_t offset) const {
return field_data_[offset].c_str();
}
int64_t
Size() const {
int64_t data_size = 0;
for (size_t offset = 0; offset < field_data_.size(); ++offset) {
data_size += get_element_size(offset);
}
return data_size;
}
public:
int64_t
get_element_size(ssize_t offset) const {
return field_data_[offset].size();
}
};
} // namespace milvus::storage

View File

@ -85,7 +85,7 @@ IndexData::serialize_to_remote_file() {
GetEventFixPartSize(EventType(i)));
}
des_event_data.extras[ORIGIN_SIZE_KEY] =
std::to_string(field_data_->get_data_size());
std::to_string(field_data_->Size());
des_event_data.extras[INDEX_BUILD_ID_KEY] =
std::to_string(index_meta_->build_id);

View File

@ -27,7 +27,7 @@ namespace milvus::storage {
// TODO :: indexParams storage in a single file
class IndexData : public DataCodec {
public:
explicit IndexData(std::shared_ptr<FieldData> data)
explicit IndexData(FieldDataPtr data)
: DataCodec(data, CodecType::IndexDataType) {
}

View File

@ -81,7 +81,7 @@ InsertData::serialize_to_remote_file() {
GetEventFixPartSize(EventType(i)));
}
des_event_data.extras[ORIGIN_SIZE_KEY] =
std::to_string(field_data_->get_data_size());
std::to_string(field_data_->Size());
auto& des_event_header = descriptor_event.event_header;
// TODO :: set timestamp

View File

@ -25,7 +25,7 @@ namespace milvus::storage {
class InsertData : public DataCodec {
public:
explicit InsertData(std::shared_ptr<FieldData> data)
explicit InsertData(FieldDataPtr data)
: DataCodec(data, CodecType::InsertDataType) {
}

View File

@ -352,21 +352,18 @@ MinioChunkManager::GetObjectBuffer(const std::string& bucket_name,
request.SetBucket(bucket_name.c_str());
request.SetKey(object_name.c_str());
request.SetResponseStreamFactory([buf, size]() {
std::unique_ptr<Aws::StringStream> stream(
Aws::New<Aws::StringStream>(""));
stream->rdbuf()->pubsetbuf(static_cast<char*>(buf), size);
return stream.release();
});
auto outcome = client_->GetObject(request);
if (!outcome.IsSuccess()) {
THROWS3ERROR(GetObjectBuffer);
}
std::stringstream ss;
ss << outcome.GetResultWithOwnership().GetBody().rdbuf();
uint64_t realSize = size;
if (ss.str().size() <= size) {
memcpy(buf, ss.str().data(), ss.str().size());
realSize = ss.str().size();
} else {
memcpy(buf, ss.str().data(), size);
}
return realSize;
return size;
}
std::vector<std::string>

View File

@ -16,6 +16,8 @@
#include "storage/PayloadReader.h"
#include "exceptions/EasyAssert.h"
#include "storage/FieldDataFactory.h"
#include "storage/Util.h"
namespace milvus::storage {
PayloadReader::PayloadReader(std::shared_ptr<PayloadInputStream> input,
@ -48,33 +50,12 @@ PayloadReader::init(std::shared_ptr<PayloadInputStream> input) {
"arrow chunk size in arrow column should be 1");
auto array = column->chunk(0);
AssertInfo(array != nullptr, "empty arrow array of PayloadReader");
field_data_ = std::make_shared<FieldData>(array, column_type_);
}
bool
PayloadReader::get_bool_payload(int idx) const {
AssertInfo(field_data_ != nullptr, "empty payload");
return field_data_->get_bool_payload(idx);
}
void
PayloadReader::get_one_string_Payload(int idx,
char** cstr,
int* str_size) const {
AssertInfo(field_data_ != nullptr, "empty payload");
return field_data_->get_one_string_payload(idx, cstr, str_size);
}
std::unique_ptr<Payload>
PayloadReader::get_payload() const {
AssertInfo(field_data_ != nullptr, "empty payload");
return field_data_->get_payload();
}
int
PayloadReader::get_payload_length() const {
AssertInfo(field_data_ != nullptr, "empty payload");
return field_data_->get_payload_length();
dim_ = datatype_is_vector(column_type_)
? GetDimensionFromArrowArray(array, column_type_)
: 1;
field_data_ =
FieldDataFactory::GetInstance().CreateFieldData(column_type_, dim_);
field_data_->FillFieldData(array);
}
} // namespace milvus::storage

View File

@ -36,26 +36,15 @@ class PayloadReader {
void
init(std::shared_ptr<PayloadInputStream> input);
bool
get_bool_payload(int idx) const;
void
get_one_string_Payload(int idx, char** cstr, int* str_size) const;
std::unique_ptr<Payload>
get_payload() const;
int
get_payload_length() const;
std::shared_ptr<FieldData>
const FieldDataPtr
get_field_data() const {
return field_data_;
}
private:
DataType column_type_;
std::shared_ptr<FieldData> field_data_;
int dim_;
FieldDataPtr field_data_;
};
} // namespace milvus::storage

View File

@ -26,12 +26,12 @@
namespace milvus::storage {
StorageType
ReadMediumType(PayloadInputStream* input_stream) {
AssertInfo(input_stream->Tell().Equals(arrow::Result<int64_t>(0)),
ReadMediumType(BinlogReaderPtr reader) {
AssertInfo(reader->Tell() == 0,
"medium type must be parsed from stream header");
int32_t magic_num;
auto ret = input_stream->Read(sizeof(magic_num), &magic_num);
AssertInfo(ret.ok(), "read input stream failed");
auto ret = reader->Read(sizeof(magic_num), &magic_num);
AssertInfo(ret.ok(), "read binlog failed");
if (magic_num == MAGIC_NUM) {
return StorageType::Remote;
}
@ -246,98 +246,6 @@ CreateArrowSchema(DataType data_type, int dim) {
}
}
// TODO ::handle string type
int64_t
GetPayloadSize(const Payload* payload) {
switch (payload->data_type) {
case DataType::BOOL:
return payload->rows * sizeof(bool);
case DataType::INT8:
return payload->rows * sizeof(int8_t);
case DataType::INT16:
return payload->rows * sizeof(int16_t);
case DataType::INT32:
return payload->rows * sizeof(int32_t);
case DataType::INT64:
return payload->rows * sizeof(int64_t);
case DataType::FLOAT:
return payload->rows * sizeof(float);
case DataType::DOUBLE:
return payload->rows * sizeof(double);
case DataType::VECTOR_FLOAT: {
Assert(payload->dimension.has_value());
return payload->rows * payload->dimension.value() * sizeof(float);
}
case DataType::VECTOR_BINARY: {
Assert(payload->dimension.has_value());
return payload->rows * payload->dimension.value();
}
default:
PanicInfo("unsupported data type");
}
}
const uint8_t*
GetRawValuesFromArrowArray(std::shared_ptr<arrow::Array> data,
DataType data_type) {
switch (data_type) {
case DataType::INT8: {
AssertInfo(data->type()->id() == arrow::Type::type::INT8,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::Int8Array>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::INT16: {
AssertInfo(data->type()->id() == arrow::Type::type::INT16,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::Int16Array>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::INT32: {
AssertInfo(data->type()->id() == arrow::Type::type::INT32,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::Int32Array>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::INT64: {
AssertInfo(data->type()->id() == arrow::Type::type::INT64,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::Int64Array>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::FLOAT: {
AssertInfo(data->type()->id() == arrow::Type::type::FLOAT,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::FloatArray>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::DOUBLE: {
AssertInfo(data->type()->id() == arrow::Type::type::DOUBLE,
"inconsistent data type");
auto array = std::dynamic_pointer_cast<arrow::DoubleArray>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::VECTOR_FLOAT: {
AssertInfo(
data->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY,
"inconsistent data type");
auto array =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
case DataType::VECTOR_BINARY: {
AssertInfo(
data->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY,
"inconsistent data type");
auto array =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(data);
return reinterpret_cast<const uint8_t*>(array->raw_values());
}
default:
PanicInfo("unsupported data type");
}
}
int
GetDimensionFromArrowArray(std::shared_ptr<arrow::Array> data,
DataType data_type) {

View File

@ -22,12 +22,13 @@
#include "storage/PayloadStream.h"
#include "storage/FileManager.h"
#include "storage/BinlogReader.h"
#include "knowhere/comp/index_param.h"
namespace milvus::storage {
StorageType
ReadMediumType(PayloadInputStream* input_stream);
ReadMediumType(BinlogReaderPtr reader);
void
AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
@ -50,13 +51,6 @@ CreateArrowSchema(DataType data_type);
std::shared_ptr<arrow::Schema>
CreateArrowSchema(DataType data_type, int dim);
int64_t
GetPayloadSize(const Payload* payload);
const uint8_t*
GetRawValuesFromArrowArray(std::shared_ptr<arrow::Array> array,
DataType data_type);
int
GetDimensionFromArrowArray(std::shared_ptr<arrow::Array> array,
DataType data_type);

View File

@ -19,6 +19,7 @@
#include "storage/parquet_c.h"
#include "storage/PayloadReader.h"
#include "storage/PayloadWriter.h"
#include "storage/FieldData.h"
#include "common/CGoHelper.h"
using Payload = milvus::storage::Payload;
@ -218,8 +219,11 @@ ReleasePayloadWriter(CPayloadWriter handler) {
}
}
extern "C" CPayloadReader
NewPayloadReader(int columnType, uint8_t* buffer, int64_t buf_size) {
extern "C" CStatus
NewPayloadReader(int columnType,
uint8_t* buffer,
int64_t buf_size,
CPayloadReader* c_reader) {
auto column_type = static_cast<milvus::DataType>(columnType);
switch (column_type) {
case milvus::DataType::BOOL:
@ -236,19 +240,26 @@ NewPayloadReader(int columnType, uint8_t* buffer, int64_t buf_size) {
break;
}
default: {
return nullptr;
return milvus::FailureCStatus(UnexpectedError,
"unsupported data type");
}
}
auto p = std::make_unique<PayloadReader>(buffer, buf_size, column_type);
return reinterpret_cast<CPayloadReader>(p.release());
try {
auto p = std::make_unique<PayloadReader>(buffer, buf_size, column_type);
*c_reader = (CPayloadReader)(p.release());
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
extern "C" CStatus
GetBoolFromPayload(CPayloadReader payloadReader, int idx, bool* value) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
*value = p->get_bool_payload(idx);
auto field_data = p->get_field_data();
*value = *reinterpret_cast<const bool*>(field_data->RawValue(idx));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -259,10 +270,10 @@ extern "C" CStatus
GetInt8FromPayload(CPayloadReader payloadReader, int8_t** values, int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<int8_t*>(raw_data);
*length = ret->rows;
auto field_data = p->get_field_data();
*length = field_data->get_num_rows();
*values =
reinterpret_cast<int8_t*>(const_cast<void*>(field_data->Data()));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -275,10 +286,10 @@ GetInt16FromPayload(CPayloadReader payloadReader,
int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<int16_t*>(raw_data);
*length = ret->rows;
auto field_data = p->get_field_data();
*length = field_data->get_num_rows();
*values =
reinterpret_cast<int16_t*>(const_cast<void*>(field_data->Data()));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -291,10 +302,10 @@ GetInt32FromPayload(CPayloadReader payloadReader,
int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<int32_t*>(raw_data);
*length = ret->rows;
auto field_data = p->get_field_data();
*length = field_data->get_num_rows();
*values =
reinterpret_cast<int32_t*>(const_cast<void*>(field_data->Data()));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -307,10 +318,10 @@ GetInt64FromPayload(CPayloadReader payloadReader,
int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<int64_t*>(raw_data);
*length = ret->rows;
auto field_data = p->get_field_data();
*length = field_data->get_num_rows();
*values =
reinterpret_cast<int64_t*>(const_cast<void*>(field_data->Data()));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -321,10 +332,10 @@ extern "C" CStatus
GetFloatFromPayload(CPayloadReader payloadReader, float** values, int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<float*>(raw_data);
*length = ret->rows;
auto field_data = p->get_field_data();
*length = field_data->get_num_rows();
*values =
reinterpret_cast<float*>(const_cast<void*>(field_data->Data()));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -337,10 +348,10 @@ GetDoubleFromPayload(CPayloadReader payloadReader,
int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<double*>(raw_data);
*length = ret->rows;
auto field_data = p->get_field_data();
*length = field_data->get_num_rows();
*values =
reinterpret_cast<double*>(const_cast<void*>(field_data->Data()));
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -354,7 +365,9 @@ GetOneStringFromPayload(CPayloadReader payloadReader,
int* str_size) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
p->get_one_string_Payload(idx, cstr, str_size);
auto field_data = p->get_field_data();
*cstr = (char*)(const_cast<void*>(field_data->RawValue(idx)));
*str_size = field_data->get_element_size(idx);
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -368,10 +381,10 @@ GetBinaryVectorFromPayload(CPayloadReader payloadReader,
int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
*values = const_cast<uint8_t*>(ret->raw_data);
*length = ret->rows;
*dimension = ret->dimension.value();
auto field_data = p->get_field_data();
*values = (uint8_t*)field_data->Data();
*dimension = field_data->get_dim();
*length = field_data->get_num_rows();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -385,11 +398,10 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader,
int* length) {
try {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
auto ret = p->get_payload();
auto raw_data = const_cast<uint8_t*>(ret->raw_data);
*values = reinterpret_cast<float*>(raw_data);
*length = ret->rows;
*dimension = ret->dimension.value();
auto field_data = p->get_field_data();
*values = (float*)field_data->Data();
*dimension = field_data->get_dim();
*length = field_data->get_num_rows();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
@ -399,12 +411,20 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader,
extern "C" int
GetPayloadLengthFromReader(CPayloadReader payloadReader) {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
return p->get_payload_length();
auto field_data = p->get_field_data();
return field_data->get_num_rows();
}
extern "C" void
extern "C" CStatus
ReleasePayloadReader(CPayloadReader payloadReader) {
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
delete (p);
ReleaseArrowUnused();
try {
AssertInfo(payloadReader != nullptr,
"released payloadReader should not be null pointer");
auto p = reinterpret_cast<PayloadReader*>(payloadReader);
delete (p);
ReleaseArrowUnused();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}

View File

@ -74,8 +74,11 @@ ReleasePayloadWriter(CPayloadWriter handler);
//============= payload reader ======================
typedef void* CPayloadReader;
CPayloadReader
NewPayloadReader(int columnType, uint8_t* buffer, int64_t buf_size);
CStatus
NewPayloadReader(int columnType,
uint8_t* buffer,
int64_t buf_size,
CPayloadReader* c_reader);
CStatus
GetBoolFromPayload(CPayloadReader payloadReader, int idx, bool* value);
CStatus
@ -116,7 +119,8 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader,
int
GetPayloadLengthFromReader(CPayloadReader payloadReader);
void
CStatus
ReleasePayloadReader(CPayloadReader payloadReader);
#ifdef __cplusplus

View File

@ -19,17 +19,190 @@
#include "storage/DataCodec.h"
#include "storage/InsertData.h"
#include "storage/IndexData.h"
#include "storage/FieldDataFactory.h"
#include "common/Consts.h"
#include "utils/Json.h"
using namespace milvus;
TEST(storage, InsertDataBool) {
FixedVector<bool> data = {true, false, true, false, true};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::BOOL);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::BOOL);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<bool> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataInt8) {
FixedVector<int8_t> data = {1, 2, 3, 4, 5};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT8);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT8);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<int8_t> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataInt16) {
FixedVector<int16_t> data = {1, 2, 3, 4, 5};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT16);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT16);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<int16_t> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataInt32) {
FixedVector<int32_t> data = {true, false, true, false, true};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT32);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT32);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<int32_t> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataInt64) {
FixedVector<int64_t> data = {1, 2, 3, 4, 5};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT64);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT64);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<int64_t> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataString) {
FixedVector<std::string> data = {
"test1", "test2", "test3", "test4", "test5"};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::VARCHAR);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VARCHAR);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<std::string> new_data(data.size());
for (int i = 0; i < data.size(); ++i) {
new_data[i] = reinterpret_cast<const char*>(new_payload->RawValue(i));
ASSERT_EQ(new_payload->get_element_size(i), data[i].size());
}
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataFloat) {
std::vector<float> data = {1, 2, 3, 4, 5};
storage::Payload payload{storage::DataType::FLOAT,
reinterpret_cast<const uint8_t*>(data.data()),
int(data.size())};
auto field_data = std::make_shared<storage::FieldData>(payload);
FixedVector<float> data = {1, 2, 3, 4, 5};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::FLOAT);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
@ -37,30 +210,27 @@ TEST(storage, InsertDataFloat) {
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
reinterpret_cast<const uint8_t*>(serialized_bytes.data()),
serialized_bytes.size());
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetPayload();
ASSERT_EQ(new_payload->data_type, storage::DataType::FLOAT);
ASSERT_EQ(new_payload->rows, data.size());
std::vector<float> new_data(data.size());
memcpy(new_data.data(),
new_payload->raw_data,
new_payload->rows * sizeof(float));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::FLOAT);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<float> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataVectorFloat) {
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
int DIM = 2;
storage::Payload payload{storage::DataType::VECTOR_FLOAT,
reinterpret_cast<const uint8_t*>(data.data()),
int(data.size()) / DIM,
DIM};
auto field_data = std::make_shared<storage::FieldData>(payload);
TEST(storage, InsertDataDouble) {
FixedVector<double> data = {1.0, 2.0, 3.0, 4.2, 5.3};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::DOUBLE);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
@ -68,72 +238,107 @@ TEST(storage, InsertDataVectorFloat) {
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
reinterpret_cast<const uint8_t*>(serialized_bytes.data()),
serialized_bytes.size());
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetPayload();
ASSERT_EQ(new_payload->data_type, storage::DataType::VECTOR_FLOAT);
ASSERT_EQ(new_payload->rows, data.size() / DIM);
std::vector<float> new_data(data.size());
memcpy(new_data.data(),
new_payload->raw_data,
new_payload->rows * sizeof(float) * DIM);
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::DOUBLE);
ASSERT_EQ(new_payload->get_num_rows(), data.size());
FixedVector<double> new_data(data.size());
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, LocalInsertDataVectorFloat) {
TEST(storage, InsertDataFloatVector) {
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
int DIM = 2;
storage::Payload payload{storage::DataType::VECTOR_FLOAT,
reinterpret_cast<const uint8_t*>(data.data()),
int(data.size()) / DIM,
DIM};
auto field_data = std::make_shared<storage::FieldData>(payload);
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::VECTOR_FLOAT, DIM);
field_data->FillFieldData(data.data(), data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes =
insert_data.Serialize(storage::StorageType::LocalDisk);
auto new_insert_data = storage::DeserializeLocalInsertFileData(
reinterpret_cast<const uint8_t*>(serialized_bytes.data()),
serialized_bytes.size(),
storage::DataType::VECTOR_FLOAT);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
auto new_payload = new_insert_data->GetPayload();
ASSERT_EQ(new_payload->data_type, storage::DataType::VECTOR_FLOAT);
ASSERT_EQ(new_payload->rows, data.size() / DIM);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_FLOAT);
ASSERT_EQ(new_payload->get_num_rows(), data.size() / DIM);
std::vector<float> new_data(data.size());
memcpy(new_data.data(),
new_payload->raw_data,
new_payload->rows * sizeof(float) * DIM);
new_payload->Data(),
new_payload->get_num_rows() * sizeof(float) * DIM);
ASSERT_EQ(data, new_data);
}
TEST(storage, LocalIndexData) {
TEST(storage, InsertDataBinaryVector) {
std::vector<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
storage::Payload payload{storage::DataType::INT8,
reinterpret_cast<const uint8_t*>(data.data()),
int(data.size())};
auto field_data = std::make_shared<storage::FieldData>(payload);
storage::IndexData indexData_data(field_data);
auto serialized_bytes =
indexData_data.Serialize(storage::StorageType::LocalDisk);
int DIM = 16;
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::VECTOR_BINARY, DIM);
field_data->FillFieldData(data.data(), data.size());
auto new_index_data = storage::DeserializeLocalIndexFileData(
reinterpret_cast<const uint8_t*>(serialized_bytes.data()),
serialized_bytes.size());
ASSERT_EQ(new_index_data->GetCodecType(), storage::IndexDataType);
auto new_payload = new_index_data->GetPayload();
ASSERT_EQ(new_payload->data_type, storage::DataType::INT8);
ASSERT_EQ(new_payload->rows, data.size());
storage::InsertData insert_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
insert_data.SetFieldDataMeta(field_data_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_insert_data = storage::DeserializeFileData(
serialized_data_ptr, serialized_bytes.size());
ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
ASSERT_EQ(new_insert_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_payload = new_insert_data->GetFieldData();
ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_BINARY);
ASSERT_EQ(new_payload->get_num_rows(), data.size() * 8 / DIM);
std::vector<uint8_t> new_data(data.size());
memcpy(new_data.data(),
new_payload->raw_data,
new_payload->rows * sizeof(uint8_t));
memcpy(new_data.data(), new_payload->Data(), new_payload->Size());
ASSERT_EQ(data, new_data);
}
TEST(storage, IndexData) {
std::vector<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
auto field_data =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT8);
field_data->FillFieldData(data.data(), data.size());
storage::IndexData index_data(field_data);
storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
index_data.SetFieldDataMeta(field_data_meta);
index_data.SetTimestamps(0, 100);
storage::IndexMeta index_meta{102, 103, 104, 1};
index_data.set_index_meta(index_meta);
auto serialized_bytes = index_data.Serialize(storage::StorageType::Remote);
std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
[&](uint8_t*) {});
auto new_index_data = storage::DeserializeFileData(serialized_data_ptr,
serialized_bytes.size());
ASSERT_EQ(new_index_data->GetCodecType(), storage::IndexDataType);
ASSERT_EQ(new_index_data->GetTimeRage(),
std::make_pair(Timestamp(0), Timestamp(100)));
auto new_field_data = new_index_data->GetFieldData();
ASSERT_EQ(new_field_data->get_data_type(), storage::DataType::INT8);
ASSERT_EQ(new_field_data->Size(), data.size());
std::vector<uint8_t> new_data(data.size());
memcpy(new_data.data(), new_field_data->Data(), new_field_data->Size());
ASSERT_EQ(data, new_data);
}

View File

@ -21,8 +21,8 @@
#include "storage/MinioChunkManager.h"
#include "storage/DiskFileManagerImpl.h"
#include "storage/ThreadPool.h"
#include "storage/FieldDataFactory.h"
#include "config/ConfigChunkManager.h"
#include "config/ConfigKnowhere.h"
#include "test_utils/indexbuilder_test_utils.h"
using namespace std;
@ -50,9 +50,9 @@ class DiskAnnFileManagerTest : public testing::Test {
TEST_F(DiskAnnFileManagerTest, AddFilePositive) {
auto& lcm = LocalChunkManager::GetInstance();
auto rcm = std::make_unique<MinioChunkManager>(storage_config_);
string testBucketName = "test-diskann";
storage_config_.bucket_name = testBucketName;
auto rcm = std::make_unique<MinioChunkManager>(storage_config_);
if (!rcm->BucketExists(testBucketName)) {
rcm->CreateBucket(testBucketName);
}
@ -83,7 +83,6 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositive) {
std::vector<std::string> remote_files;
for (auto& file2size : remote_files_to_size) {
std::cout << file2size.first << std::endl;
remote_files.emplace_back(file2size.first);
}
diskAnnFileManager->CacheIndexToDisk(remote_files);
@ -93,22 +92,32 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositive) {
auto buf = std::unique_ptr<uint8_t[]>(new uint8_t[file_size]);
lcm.Read(file, buf.get(), file_size);
auto index = FieldData(buf.get(), file_size);
auto payload = index.get_payload();
auto rows = payload->rows;
auto rawData = payload->raw_data;
auto index =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT8);
index->FillFieldData(buf.get(), file_size);
auto rows = index->get_num_rows();
auto rawData = (uint8_t*)(index->Data());
EXPECT_EQ(rows, index_size);
EXPECT_EQ(rawData[0], data[0]);
EXPECT_EQ(rawData[4], data[4]);
}
auto objects =
rcm->ListWithPrefix(diskAnnFileManager->GetRemoteIndexObjectPrefix());
for (auto obj : objects) {
rcm->Remove(obj);
}
ok = rcm->DeleteBucket(testBucketName);
EXPECT_EQ(ok, true);
}
TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) {
auto& lcm = LocalChunkManager::GetInstance();
auto rcm = std::make_unique<MinioChunkManager>(storage_config_);
string testBucketName = "test-diskann";
storage_config_.bucket_name = testBucketName;
auto rcm = std::make_unique<MinioChunkManager>(storage_config_);
if (!rcm->BucketExists(testBucketName)) {
rcm->CreateBucket(testBucketName);
}
@ -149,15 +158,25 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) {
auto buf = std::unique_ptr<uint8_t[]>(new uint8_t[file_size]);
lcm.Read(file, buf.get(), file_size);
auto index = FieldData(buf.get(), file_size);
auto payload = index.get_payload();
auto rows = payload->rows;
auto rawData = payload->raw_data;
auto index =
milvus::storage::FieldDataFactory::GetInstance().CreateFieldData(
storage::DataType::INT8);
index->FillFieldData(buf.get(), file_size);
auto rows = index->get_num_rows();
auto rawData = (uint8_t*)(index->Data());
EXPECT_EQ(rows, index_size);
EXPECT_EQ(rawData[0], data[0]);
EXPECT_EQ(rawData[4], data[4]);
}
auto objects =
rcm->ListWithPrefix(diskAnnFileManager->GetRemoteIndexObjectPrefix());
for (auto obj : objects) {
rcm->Remove(obj);
}
ok = rcm->DeleteBucket(testBucketName);
EXPECT_EQ(ok, true);
}
int

View File

@ -291,6 +291,10 @@ class IndexTest : public ::testing::TestWithParam<Param> {
void
SetUp() override {
storage_config_ = get_default_storage_config();
// auto rcm = std::make_shared<storage::MinioChunkManager>(storage_config_);
// if (!rcm->BucketExists(storage_config_.bucket_name)) {
// rcm->CreateBucket(storage_config_.bucket_name);
// }
auto param = GetParam();
index_type = param.first;

View File

@ -22,39 +22,27 @@ using namespace std;
using namespace milvus;
using namespace milvus::storage;
class LocalChunkManagerTest : public testing::Test {
public:
LocalChunkManagerTest() {
}
~LocalChunkManagerTest() {
}
virtual void
SetUp() {
std::string local_path_prefix = "/tmp/local-test-dir";
ChunkMangerConfig::SetLocalRootPath(local_path_prefix);
}
};
class LocalChunkManagerTest : public testing::Test {};
TEST_F(LocalChunkManagerTest, DirPositive) {
auto& lcm = LocalChunkManager::GetInstance();
string path_prefix = lcm.GetPathPrefix();
lcm.RemoveDir(path_prefix);
lcm.CreateDir(path_prefix);
string test_dir = lcm.GetPathPrefix() + "/local-test-dir/";
lcm.RemoveDir(test_dir);
lcm.CreateDir(test_dir);
bool exist = lcm.DirExist(path_prefix);
bool exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, true);
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}
TEST_F(LocalChunkManagerTest, FilePositive) {
auto& lcm = LocalChunkManager::GetInstance();
string path_prefix = lcm.GetPathPrefix();
string test_dir = lcm.GetPathPrefix() + "/local-test-dir";
string file = "/tmp/local-test-dir/test-file";
string file = test_dir + "/test-file";
auto exist = lcm.Exist(file);
EXPECT_EQ(exist, false);
lcm.CreateFile(file);
@ -65,16 +53,16 @@ TEST_F(LocalChunkManagerTest, FilePositive) {
exist = lcm.Exist(file);
EXPECT_EQ(exist, false);
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}
TEST_F(LocalChunkManagerTest, WritePositive) {
auto& lcm = LocalChunkManager::GetInstance();
string path_prefix = lcm.GetPathPrefix();
string test_dir = lcm.GetPathPrefix() + "/local-test-dir";
string file = "/tmp/local-test-dir/test-write-positive";
string file = test_dir + "/test-write-positive";
auto exist = lcm.Exist(file);
EXPECT_EQ(exist, false);
lcm.CreateFile(file);
@ -98,17 +86,17 @@ TEST_F(LocalChunkManagerTest, WritePositive) {
EXPECT_EQ(size, datasize);
delete[] bigdata;
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}
TEST_F(LocalChunkManagerTest, ReadPositive) {
auto& lcm = LocalChunkManager::GetInstance();
string path_prefix = lcm.GetPathPrefix();
string test_dir = lcm.GetPathPrefix() + "/local-test-dir";
uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23};
string path = "/tmp/local-test-dir/test-read-positive";
string path = test_dir + "/test-read-positive";
lcm.CreateFile(path);
lcm.Write(path, data, sizeof(data));
bool exist = lcm.Exist(path);
@ -145,16 +133,16 @@ TEST_F(LocalChunkManagerTest, ReadPositive) {
EXPECT_EQ(readdata[3], 0x34);
EXPECT_EQ(readdata[4], 0x23);
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}
TEST_F(LocalChunkManagerTest, WriteOffset) {
auto& lcm = LocalChunkManager::GetInstance();
string path_prefix = lcm.GetPathPrefix();
string test_dir = lcm.GetPathPrefix() + "/local-test-dir";
string file = "/tmp/local-test-dir/test-write-offset";
string file = test_dir + "/test-write-offset";
auto exist = lcm.Exist(file);
EXPECT_EQ(exist, false);
lcm.CreateFile(file);
@ -189,16 +177,16 @@ TEST_F(LocalChunkManagerTest, WriteOffset) {
EXPECT_EQ(read_data[8], 0x34);
EXPECT_EQ(read_data[9], 0x23);
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}
TEST_F(LocalChunkManagerTest, ReadOffset) {
auto& lcm = LocalChunkManager::GetInstance();
string path_prefix = lcm.GetPathPrefix();
string test_dir = lcm.GetPathPrefix() + "/local-test-dir";
string file = "/tmp/local-test-dir/test-read-offset";
string file = test_dir + "/test-read-offset";
lcm.CreateFile(file);
auto exist = lcm.Exist(file);
EXPECT_EQ(exist, true);
@ -225,15 +213,14 @@ TEST_F(LocalChunkManagerTest, ReadOffset) {
EXPECT_EQ(size, 1);
EXPECT_EQ(read_data[0], 0x98);
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}
TEST_F(LocalChunkManagerTest, GetSizeOfDir) {
auto& lcm = LocalChunkManager::GetInstance();
auto path_prefix = lcm.GetPathPrefix();
auto test_dir = path_prefix + "/" + "test_dir/";
auto test_dir = lcm.GetPathPrefix() + "/local-test-dir";
EXPECT_EQ(lcm.DirExist(test_dir), false);
lcm.CreateDir(test_dir);
EXPECT_EQ(lcm.DirExist(test_dir), true);
@ -241,7 +228,7 @@ TEST_F(LocalChunkManagerTest, GetSizeOfDir) {
uint8_t data[] = {0x17, 0x32, 0x00, 0x34, 0x23, 0x23, 0x87, 0x98};
// test get size of file in test_dir
auto file1 = test_dir + "file";
auto file1 = test_dir + "/file";
auto res = lcm.CreateFile(file1);
EXPECT_EQ(res, true);
lcm.Write(file1, data, sizeof(data));
@ -251,15 +238,15 @@ TEST_F(LocalChunkManagerTest, GetSizeOfDir) {
EXPECT_EQ(exist, false);
// test get dir size with nested dirs
auto nest_dir = test_dir + "nest_dir/";
auto file2 = nest_dir + "file";
auto nest_dir = test_dir + "/nest_dir";
auto file2 = nest_dir + "/file";
res = lcm.CreateFile(file2);
EXPECT_EQ(res, true);
lcm.Write(file2, data, sizeof(data));
EXPECT_EQ(lcm.GetSizeOfDir(test_dir), sizeof(data));
lcm.RemoveDir(test_dir);
lcm.RemoveDir(path_prefix);
exist = lcm.DirExist(path_prefix);
lcm.RemoveDir(test_dir);
exist = lcm.DirExist(test_dir);
EXPECT_EQ(exist, false);
}

View File

@ -125,11 +125,11 @@ TEST_F(MinioChunkManagerTest, ReadPositive) {
bool exist = chunk_manager_->Exist(path);
EXPECT_EQ(exist, true);
auto size = chunk_manager_->Size(path);
EXPECT_EQ(size, 5);
EXPECT_EQ(size, sizeof(data));
uint8_t readdata[20] = {0};
size = chunk_manager_->Read(path, readdata, 20);
EXPECT_EQ(size, 5);
size = chunk_manager_->Read(path, readdata, sizeof(data));
EXPECT_EQ(size, sizeof(data));
EXPECT_EQ(readdata[0], 0x17);
EXPECT_EQ(readdata[1], 0x32);
EXPECT_EQ(readdata[2], 0x45);
@ -147,9 +147,9 @@ TEST_F(MinioChunkManagerTest, ReadPositive) {
exist = chunk_manager_->Exist(path);
EXPECT_EQ(exist, true);
size = chunk_manager_->Size(path);
EXPECT_EQ(size, 5);
size = chunk_manager_->Read(path, readdata, 20);
EXPECT_EQ(size, 5);
EXPECT_EQ(size, sizeof(dataWithNULL));
size = chunk_manager_->Read(path, readdata, sizeof(dataWithNULL));
EXPECT_EQ(size, sizeof(dataWithNULL));
EXPECT_EQ(readdata[0], 0x17);
EXPECT_EQ(readdata[1], 0x32);
EXPECT_EQ(readdata[2], 0x00);

View File

@ -108,8 +108,10 @@ TEST(storage, boolean) {
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 4);
auto reader = NewPayloadReader(
int(milvus::DataType::BOOL), (uint8_t*)cb.data, cb.length);
CPayloadReader reader;
st = NewPayloadReader(
int(milvus::DataType::BOOL), (uint8_t*)cb.data, cb.length, &reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
bool* values;
int length = GetPayloadLengthFromReader(reader);
ASSERT_EQ(length, 4);
@ -121,42 +123,46 @@ TEST(storage, boolean) {
}
ReleasePayloadWriter(payload);
ReleasePayloadReader(reader);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
}
#define NUMERIC_TEST( \
TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) \
TEST(wrapper, TEST_NAME) { \
auto payload = NewPayloadWriter(COLUMN_TYPE); \
DATA_TYPE data[] = {-1, 1, -100, 100}; \
\
auto st = ADD_FUNC(payload, data, 4); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
st = FinishPayloadWriter(payload); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
auto cb = GetPayloadBufferFromWriter(payload); \
ASSERT_GT(cb.length, 0); \
ASSERT_NE(cb.data, nullptr); \
auto nums = GetPayloadLengthFromWriter(payload); \
ASSERT_EQ(nums, 4); \
\
auto reader = \
NewPayloadReader(COLUMN_TYPE, (uint8_t*)cb.data, cb.length); \
DATA_TYPE* values; \
int length; \
st = GET_FUNC(reader, &values, &length); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
ASSERT_NE(values, nullptr); \
ASSERT_EQ(length, 4); \
length = GetPayloadLengthFromReader(reader); \
ASSERT_EQ(length, 4); \
\
for (int i = 0; i < length; i++) { \
ASSERT_EQ(data[i], values[i]); \
} \
\
ReleasePayloadWriter(payload); \
ReleasePayloadReader(reader); \
#define NUMERIC_TEST( \
TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) \
TEST(wrapper, TEST_NAME) { \
auto payload = NewPayloadWriter(COLUMN_TYPE); \
DATA_TYPE data[] = {-1, 1, -100, 100}; \
\
auto st = ADD_FUNC(payload, data, 4); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
st = FinishPayloadWriter(payload); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
auto cb = GetPayloadBufferFromWriter(payload); \
ASSERT_GT(cb.length, 0); \
ASSERT_NE(cb.data, nullptr); \
auto nums = GetPayloadLengthFromWriter(payload); \
ASSERT_EQ(nums, 4); \
\
CPayloadReader reader; \
st = NewPayloadReader( \
COLUMN_TYPE, (uint8_t*)cb.data, cb.length, &reader); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
DATA_TYPE* values; \
int length; \
st = GET_FUNC(reader, &values, &length); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
ASSERT_NE(values, nullptr); \
ASSERT_EQ(length, 4); \
length = GetPayloadLengthFromReader(reader); \
ASSERT_EQ(length, 4); \
\
for (int i = 0; i < length; i++) { \
ASSERT_EQ(data[i], values[i]); \
} \
\
ReleasePayloadWriter(payload); \
st = ReleasePayloadReader(reader); \
ASSERT_EQ(st.error_code, ErrorCode::Success); \
}
NUMERIC_TEST(int8,
@ -215,8 +221,10 @@ TEST(storage, stringarray) {
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 3);
auto reader = NewPayloadReader(
int(milvus::DataType::VARCHAR), (uint8_t*)cb.data, cb.length);
CPayloadReader reader;
st = NewPayloadReader(
int(milvus::DataType::VARCHAR), (uint8_t*)cb.data, cb.length, &reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
int length = GetPayloadLengthFromReader(reader);
ASSERT_EQ(length, 3);
char *v0, *v1, *v2;
@ -246,7 +254,8 @@ TEST(storage, stringarray) {
ASSERT_EQ(v2[2], 0);
ReleasePayloadWriter(payload);
ReleasePayloadReader(reader);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
}
TEST(storage, binary_vector) {
@ -265,8 +274,12 @@ TEST(storage, binary_vector) {
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 4);
auto reader = NewPayloadReader(
int(milvus::DataType::VECTOR_BINARY), (uint8_t*)cb.data, cb.length);
CPayloadReader reader;
st = NewPayloadReader(int(milvus::DataType::VECTOR_BINARY),
(uint8_t*)cb.data,
cb.length,
&reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
uint8_t* values;
int length;
int dim;
@ -283,7 +296,8 @@ TEST(storage, binary_vector) {
}
ReleasePayloadWriter(payload);
ReleasePayloadReader(reader);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
}
TEST(storage, binary_vector_empty) {
@ -297,12 +311,17 @@ TEST(storage, binary_vector_empty) {
// ASSERT_EQ(cb.data, nullptr);
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 0);
auto reader = NewPayloadReader(
int(milvus::DataType::VECTOR_BINARY), (uint8_t*)cb.data, cb.length);
CPayloadReader reader;
st = NewPayloadReader(int(milvus::DataType::VECTOR_BINARY),
(uint8_t*)cb.data,
cb.length,
&reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
ASSERT_EQ(0, GetPayloadLengthFromReader(reader));
// ASSERT_EQ(reader, nullptr);
ReleasePayloadWriter(payload);
ReleasePayloadReader(reader);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
}
TEST(storage, float_vector) {
@ -321,8 +340,12 @@ TEST(storage, float_vector) {
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 4);
auto reader = NewPayloadReader(
int(milvus::DataType::VECTOR_FLOAT), (uint8_t*)cb.data, cb.length);
CPayloadReader reader;
st = NewPayloadReader(int(milvus::DataType::VECTOR_FLOAT),
(uint8_t*)cb.data,
cb.length,
&reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
float* values;
int length;
int dim;
@ -339,7 +362,8 @@ TEST(storage, float_vector) {
}
ReleasePayloadWriter(payload);
ReleasePayloadReader(reader);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
}
TEST(storage, float_vector_empty) {
@ -353,12 +377,17 @@ TEST(storage, float_vector_empty) {
// ASSERT_EQ(cb.data, nullptr);
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 0);
auto reader = NewPayloadReader(
int(milvus::DataType::VECTOR_FLOAT), (uint8_t*)cb.data, cb.length);
CPayloadReader reader;
st = NewPayloadReader(int(milvus::DataType::VECTOR_FLOAT),
(uint8_t*)cb.data,
cb.length,
&reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
ASSERT_EQ(0, GetPayloadLengthFromReader(reader));
// ASSERT_EQ(reader, nullptr);
ReleasePayloadWriter(payload);
ReleasePayloadReader(reader);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::Success);
}
TEST(storage, int8_2) {

View File

@ -70,8 +70,8 @@ type PayloadReaderInterface interface {
GetBinaryVectorFromPayload() ([]byte, int, error)
GetFloatVectorFromPayload() ([]float32, int, error)
GetPayloadLengthFromReader() (int, error)
ReleasePayloadReader()
Close()
ReleasePayloadReader() error
Close() error
}
// PayloadWriter writes data into payload

View File

@ -645,14 +645,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Bool, buffer)
assert.Nil(t, err)
_, err = r.GetBoolFromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetBoolFromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Bool, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetInt8Error", func(t *testing.T) {
@ -669,14 +662,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Int8, buffer)
assert.Nil(t, err)
_, err = r.GetInt8FromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetInt8FromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Int8, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetInt16Error", func(t *testing.T) {
@ -693,14 +679,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Int16, buffer)
assert.Nil(t, err)
_, err = r.GetInt16FromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetInt16FromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Int16, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetInt32Error", func(t *testing.T) {
@ -717,14 +696,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Int32, buffer)
assert.Nil(t, err)
_, err = r.GetInt32FromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetInt32FromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Int32, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetInt64Error", func(t *testing.T) {
@ -741,14 +713,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Int64, buffer)
assert.Nil(t, err)
_, err = r.GetInt64FromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetInt64FromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Int64, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetFloatError", func(t *testing.T) {
@ -765,14 +730,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Float, buffer)
assert.Nil(t, err)
_, err = r.GetFloatFromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetFloatFromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Float, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetDoubleError", func(t *testing.T) {
@ -789,14 +747,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_Double, buffer)
assert.Nil(t, err)
_, err = r.GetDoubleFromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetDoubleFromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_Double, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetStringError", func(t *testing.T) {
@ -813,14 +764,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_String, buffer)
assert.Nil(t, err)
_, err = r.GetStringFromPayload()
assert.NotNil(t, err)
r.colType = 999
_, err = r.GetStringFromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_String, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetBinaryVectorError", func(t *testing.T) {
@ -837,14 +781,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_BinaryVector, buffer)
assert.Nil(t, err)
_, _, err = r.GetBinaryVectorFromPayload()
assert.NotNil(t, err)
r.colType = 999
_, _, err = r.GetBinaryVectorFromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_BinaryVector, buffer)
assert.NotNil(t, err)
})
t.Run("TestGetFloatVectorError", func(t *testing.T) {
@ -861,14 +798,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) {
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReaderCgo(schemapb.DataType_FloatVector, buffer)
assert.Nil(t, err)
_, _, err = r.GetFloatVectorFromPayload()
assert.NotNil(t, err)
r.colType = 999
_, _, err = r.GetFloatVectorFromPayload()
_, err = NewPayloadReaderCgo(schemapb.DataType_FloatVector, buffer)
assert.NotNil(t, err)
})

View File

@ -73,8 +73,8 @@ func (r *PayloadReader) GetDataFromPayload() (interface{}, int, error) {
}
// ReleasePayloadReader release payload reader.
func (r *PayloadReader) ReleasePayloadReader() {
r.Close()
func (r *PayloadReader) ReleasePayloadReader() error {
return r.Close()
}
// GetBoolFromPayload returns bool slice from payload.
@ -308,8 +308,8 @@ func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
}
// Close closes the payload reader
func (r *PayloadReader) Close() {
r.reader.Close()
func (r *PayloadReader) Close() error {
return r.reader.Close()
}
// ReadDataFromAllRowGroups iterates all row groups of file.Reader, and convert column to E.

View File

@ -28,9 +28,10 @@ func NewPayloadReaderCgo(colType schemapb.DataType, buf []byte) (*PayloadReaderC
if len(buf) == 0 {
return nil, errors.New("create Payload reader failed, buffer is empty")
}
r := C.NewPayloadReader(C.int(colType), (*C.uint8_t)(unsafe.Pointer(&buf[0])), C.int64_t(len(buf)))
if r == nil {
return nil, errors.New("failed to read parquet from buffer")
var r C.CPayloadReader
status := C.NewPayloadReader(C.int(colType), (*C.uint8_t)(unsafe.Pointer(&buf[0])), C.int64_t(len(buf)), &r)
if err := HandleCStatus(&status, "NewPayloadReader failed"); err != nil {
return nil, err
}
return &PayloadReaderCgo{payloadReaderPtr: r, colType: colType}, nil
}
@ -81,8 +82,13 @@ func (r *PayloadReaderCgo) GetDataFromPayload() (interface{}, int, error) {
}
// ReleasePayloadReader release payload reader.
func (r *PayloadReaderCgo) ReleasePayloadReader() {
C.ReleasePayloadReader(r.payloadReaderPtr)
func (r *PayloadReaderCgo) ReleasePayloadReader() error {
status := C.ReleasePayloadReader(r.payloadReaderPtr)
if err := HandleCStatus(&status, "ReleasePayloadReader failed"); err != nil {
return err
}
return nil
}
// GetBoolFromPayload returns bool slice from payload.
@ -303,8 +309,8 @@ func (r *PayloadReaderCgo) GetPayloadLengthFromReader() (int, error) {
}
// Close closes the payload reader
func (r *PayloadReaderCgo) Close() {
r.ReleasePayloadReader()
func (r *PayloadReaderCgo) Close() error {
return r.ReleasePayloadReader()
}
// HandleCStatus deal with the error returned from CGO