mirror of https://github.com/milvus-io/milvus.git
Enable term parser and executor
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
6412ebc0d4
commit
63c8f60c6e
|
@ -56,3 +56,6 @@ cmake_build/
|
|||
.DS_Store
|
||||
*.swp
|
||||
cwrapper_build
|
||||
**/.clangd/*
|
||||
**/compile_commands.json
|
||||
**/.lint
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
## Binlog
|
||||
|
||||
InsertBinlog、DeleteBinlog、DDLBinlog
|
||||
|
||||
Binlog is stored in a columnar storage format, every column in schema should be stored in a individual file. Timestamp, schema, row id and primary key allocated by system are four special columns. Schema column records the DDL of the collection.
|
||||
|
||||
|
||||
|
||||
## Event format
|
||||
|
||||
Binlog file consists of 4 bytes magic number and a series of events. The first event must be descriptor event.
|
||||
|
||||
### Event format
|
||||
|
||||
```
|
||||
+=====================================+
|
||||
| event | timestamp 0 : 8 | create timestamp
|
||||
| header +----------------------------+
|
||||
| | type_code 8 : 1 | event type code
|
||||
| +----------------------------+
|
||||
| | server_id 9 : 4 | write node id
|
||||
| +----------------------------+
|
||||
| | event_length 13 : 4 | length of event, including header and data
|
||||
| +----------------------------+
|
||||
| | next_position 17 : 4 | offset of next event from the start of file
|
||||
| +----------------------------+
|
||||
| | extra_headers 21 : x-21 | reserved part
|
||||
+=====================================+
|
||||
| event | fixed part x : y |
|
||||
| data +----------------------------+
|
||||
| | variable part |
|
||||
+=====================================+
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Descriptor Event format
|
||||
|
||||
```
|
||||
+=====================================+
|
||||
| event | timestamp 0 : 8 | create timestamp
|
||||
| header +----------------------------+
|
||||
| | type_code 8 : 1 | event type code
|
||||
| +----------------------------+
|
||||
| | server_id 9 : 4 | write node id
|
||||
| +----------------------------+
|
||||
| | event_length 13 : 4 | length of event, including header and data
|
||||
| +----------------------------+
|
||||
| | next_position 17 : 4 | offset of next event from the start of file
|
||||
+=====================================+
|
||||
| event | binlog_version 21 : 2 | binlog version
|
||||
| data +----------------------------+
|
||||
| | server_version 23 : 8 | write node version
|
||||
| +----------------------------+
|
||||
| | commit_id 31 : 8 | commit id of the programe in git
|
||||
| +----------------------------+
|
||||
| | header_length 39 : 1 | header length of other event
|
||||
| +----------------------------+
|
||||
| | collection_id 40 : 8 | collection id
|
||||
| +----------------------------+
|
||||
| | partition_id 48 : 8 | partition id (schema column does not need)
|
||||
| +----------------------------+
|
||||
| | segment_id 56 : 8 | segment id (schema column does not need)
|
||||
| +----------------------------+
|
||||
| | start_timestamp 64 : 1 | minimum timestamp allocated by master of all events in this file
|
||||
| +----------------------------+
|
||||
| | end_timestamp 65 : 1 | maximum timestamp allocated by master of all events in this file
|
||||
| +----------------------------+
|
||||
| | post-header 66 : n | array of n bytes, one byte per event type that the server knows about
|
||||
| | lengths for all |
|
||||
| | event types |
|
||||
+=====================================+
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Type code
|
||||
|
||||
```
|
||||
DESCRIPTOR_EVENT
|
||||
INSERT_EVENT
|
||||
DELETE_EVENT
|
||||
CREATE_COLLECTION_EVENT
|
||||
DROP_COLLECTION_EVENT
|
||||
CREATE_PARTITION_EVENT
|
||||
DROP_PARTITION_EVENT
|
||||
```
|
||||
|
||||
DESCRIPTOR_EVENT must appear in all column files and always be the first event.
|
||||
|
||||
INSERT_EVENT 可以出现在除DDL binlog文件外的其他列的binlog
|
||||
|
||||
DELETE_EVENT 只能用于primary key 的binlog文件(目前只有按照primary key删除)
|
||||
|
||||
CREATE_COLLECTION_EVENT、DROP_COLLECTION_EVENT、CREATE_PARTITION_EVENT、DROP_PARTITION_EVENT 只出现在DDL binlog文件
|
||||
|
||||
|
||||
|
||||
### Event data part
|
||||
|
||||
```
|
||||
event data part
|
||||
|
||||
INSERT_EVENT:
|
||||
+================================================+
|
||||
| event | fixed | start_timestamp x : 8 | min timestamp in this event
|
||||
| data | part +------------------------------+
|
||||
| | | end_timestamp x+8 : 8 | max timestamp in this event
|
||||
| | +------------------------------+
|
||||
| | | reserved x+16 : y-x-16 | reserved part
|
||||
| +--------+------------------------------+
|
||||
| |variable| parquet payloI ad | payload in parquet format
|
||||
| |part | |
|
||||
+================================================+
|
||||
|
||||
other events is similar with INSERT_EVENT
|
||||
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### Example
|
||||
|
||||
Schema
|
||||
|
||||
string | int | float(optional) | vector(512)
|
||||
|
||||
|
||||
|
||||
Request:
|
||||
|
||||
InsertRequest rows(1W)
|
||||
|
||||
DeleteRequest pk=1
|
||||
|
||||
DropPartition partitionTag="abc"
|
||||
|
||||
|
||||
|
||||
insert binlogs:
|
||||
|
||||
rowid, pk, ts, string, int, float, vector 6 files
|
||||
|
||||
all events are INSERT_EVENT
|
||||
float column file contains some NULL value
|
||||
|
||||
delete binlogs:
|
||||
|
||||
pk, ts 2 files
|
||||
|
||||
pk's events are DELETE_EVENT, ts's events are INSERT_EVENT
|
||||
|
||||
DDL binlogs:
|
||||
|
||||
ddl, ts
|
||||
|
||||
ddl's event is DROP_PARTITION_EVENT, ts's event is INSERT_EVENT
|
||||
|
||||
|
||||
|
||||
C++ interface
|
||||
|
||||
```c++
|
||||
typedef void* CPayloadWriter
|
||||
typedef struct CBuffer {
|
||||
char* data;
|
||||
int length;
|
||||
} CBuffer
|
||||
|
||||
typedef struct CStatus {
|
||||
int error_code;
|
||||
const char* error_msg;
|
||||
} CStatus
|
||||
|
||||
|
||||
// C++ interface
|
||||
// writer
|
||||
CPayloadWriter NewPayloadWriter(int columnType);
|
||||
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
|
||||
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
|
||||
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
|
||||
CStatus AddInt64ToPayload(CPayloadWriter payloadWriter, int64_t *values, int length);
|
||||
CStatus AddFloatToPayload(CPayloadWriter payloadWriter, float *values, int length);
|
||||
CStatus AddDoubleToPayload(CPayloadWriter payloadWriter, double *values, int length);
|
||||
CStatus AddOneStringToPayload(CPayloadWriter payloadWriter, char *cstr, int str_size);
|
||||
CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_t *values, int dimension, int length);
|
||||
CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *values, int dimension, int length);
|
||||
|
||||
CStatus FinishPayloadWriter(CPayloadWriter payloadWriter);
|
||||
CBuffer GetPayloadBufferFromWriter(CPayloadWriter payloadWriter);
|
||||
int GetPayloadLengthFromWriter(CPayloadWriter payloadWriter);
|
||||
CStatus ReleasePayloadWriter(CPayloadWriter handler);
|
||||
|
||||
// reader
|
||||
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
|
||||
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
|
||||
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
|
||||
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
|
||||
CStatus GetInt64FromPayload(CPayloadReader payloadReader, int64_t **values, int *length);
|
||||
CStatus GetFloatFromPayload(CPayloadReader payloadReader, float **values, int *length);
|
||||
CStatus GetDoubleFromPayload(CPayloadReader payloadReader, double **values, int *length);
|
||||
CStatus GetOneStringFromPayload(CPayloadReader payloadReader, int idx, char **cstr, int *str_size);
|
||||
CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader, uint8_t **values, int *dimension, int *length);
|
||||
CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, float **values, int *dimension, int *length);
|
||||
|
||||
int GetPayloadLengthFromReader(CPayloadReader payloadReader);
|
||||
CStatus ReleasePayloadReader(CPayloadReader payloadReader);
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ static auto map = [] {
|
|||
MetricType
|
||||
GetMetricType(const std::string& type_name) {
|
||||
auto real_name = to_lower_copy(type_name);
|
||||
AssertInfo(map.left.count(real_name), "metric type not found: " + type_name);
|
||||
AssertInfo(map.left.count(real_name), "metric type not found: (" + type_name + ")");
|
||||
return map.left.at(real_name);
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#include "utils/Types.h"
|
||||
#include <faiss/MetricType.h>
|
||||
#include <string>
|
||||
#include <boost/align/aligned_allocator.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
|
||||
|
@ -24,4 +26,15 @@ using MetricType = faiss::MetricType;
|
|||
faiss::MetricType
|
||||
GetMetricType(const std::string& type);
|
||||
|
||||
// NOTE: dependent type
|
||||
// used at meta-template programming
|
||||
template <class...>
|
||||
constexpr std::true_type always_true{};
|
||||
|
||||
template <class...>
|
||||
constexpr std::false_type always_false{};
|
||||
|
||||
template <typename T>
|
||||
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
|
||||
|
||||
} // namespace milvus
|
||||
|
|
|
@ -70,8 +70,6 @@ to_lower(const std::string& raw) {
|
|||
return data;
|
||||
}
|
||||
|
||||
template <class...>
|
||||
constexpr std::false_type always_false{};
|
||||
template <typename T>
|
||||
std::unique_ptr<Expr>
|
||||
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
|
||||
|
@ -85,31 +83,62 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
|
|||
|
||||
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
|
||||
auto op = RangeExpr::mapping_.at(op_name);
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(item.value().is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(item.value().is_number_integer());
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(item.value().is_number());
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
__builtin_unreachable();
|
||||
}
|
||||
T value = item.value();
|
||||
expr->conditions_.emplace_back(op, value);
|
||||
}
|
||||
std::sort(expr->conditions_.begin(), expr->conditions_.end());
|
||||
return expr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<Expr>
|
||||
ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
|
||||
auto expr = std::make_unique<TermExprImpl<T>>();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(body.is_array());
|
||||
expr->field_id_ = field_name;
|
||||
expr->data_type_ = data_type;
|
||||
for (auto& value : body) {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(value.is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(value.is_number_integer());
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(value.is_number());
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
__builtin_unreachable();
|
||||
}
|
||||
T real_value = value;
|
||||
expr->terms_.push_back(real_value);
|
||||
}
|
||||
std::sort(expr->terms_.begin(), expr->terms_.end());
|
||||
return expr;
|
||||
}
|
||||
|
||||
std::unique_ptr<Expr>
|
||||
ParseRangeNode(const Schema& schema, const Json& out_body) {
|
||||
Assert(out_body.is_object());
|
||||
Assert(out_body.size() == 1);
|
||||
auto out_iter = out_body.begin();
|
||||
auto field_name = out_iter.key();
|
||||
auto body = out_iter.value();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(!field_is_vector(data_type));
|
||||
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
PanicInfo("bool is not supported in Range node");
|
||||
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
|
||||
return ParseRangeNodeImpl<bool>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT8:
|
||||
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
|
||||
|
@ -128,6 +157,42 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
|
|||
}
|
||||
}
|
||||
|
||||
static std::unique_ptr<Expr>
|
||||
ParseTermNode(const Schema& schema, const Json& out_body) {
|
||||
Assert(out_body.size() == 1);
|
||||
auto out_iter = out_body.begin();
|
||||
auto field_name = out_iter.key();
|
||||
auto body = out_iter.value();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(!field_is_vector(data_type));
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
return ParseTermNodeImpl<bool>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
return ParseTermNodeImpl<int8_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
return ParseTermNodeImpl<int16_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return ParseTermNodeImpl<int32_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return ParseTermNodeImpl<int64_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return ParseTermNodeImpl<float>(schema, field_name, body);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return ParseTermNodeImpl<double>(schema, field_name, body);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported data_type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::unique_ptr<Plan>
|
||||
CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
|
||||
auto plan = std::make_unique<Plan>(schema);
|
||||
|
@ -143,6 +208,10 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
|
|||
if (pack.contains("vector")) {
|
||||
auto& out_body = pack.at("vector");
|
||||
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
|
||||
} else if (pack.contains("term")) {
|
||||
AssertInfo(!predicate, "unsupported complex DSL");
|
||||
auto& out_body = pack.at("term");
|
||||
predicate = ParseTermNode(schema, out_body);
|
||||
} else if (pack.contains("range")) {
|
||||
AssertInfo(!predicate, "unsupported complex DSL");
|
||||
auto& out_body = pack.at("range");
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <boost/align/aligned_allocator.hpp>
|
||||
|
||||
namespace milvus::query {
|
||||
using Json = nlohmann::json;
|
||||
|
@ -39,9 +38,6 @@ struct Plan {
|
|||
// TODO: add move extra info
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
|
||||
|
||||
struct Placeholder {
|
||||
// milvus::proto::service::PlaceholderGroup group_;
|
||||
std::string tag_;
|
||||
|
|
|
@ -27,7 +27,7 @@ create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk
|
|||
return nullptr;
|
||||
}
|
||||
auto& bitmaps = *bitmaps_opt.value();
|
||||
auto& src_vec = bitmaps.at(chunk_id);
|
||||
auto src_vec = ~bitmaps.at(chunk_id);
|
||||
auto dst = std::make_shared<faiss::ConcurrentBitset>(src_vec.size());
|
||||
auto iter = reinterpret_cast<BitmapChunk::block_type*>(dst->mutable_data());
|
||||
|
||||
|
|
|
@ -58,6 +58,10 @@ class ExecExprVisitor : ExprVisitor {
|
|||
auto
|
||||
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
|
||||
|
||||
private:
|
||||
segcore::SegmentSmallIndex& segment_;
|
||||
std::optional<RetType> ret_;
|
||||
|
|
|
@ -46,6 +46,10 @@ class ExecExprVisitor : ExprVisitor {
|
|||
auto
|
||||
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
|
||||
|
||||
private:
|
||||
segcore::SegmentSmallIndex& segment_;
|
||||
std::optional<RetType> ret_;
|
||||
|
@ -63,11 +67,6 @@ ExecExprVisitor::visit(BoolBinaryExpr& expr) {
|
|||
PanicInfo("unimplemented");
|
||||
}
|
||||
|
||||
void
|
||||
ExecExprVisitor::visit(TermExpr& expr) {
|
||||
PanicInfo("unimplemented");
|
||||
}
|
||||
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_func, ElementFunc element_func)
|
||||
|
@ -84,17 +83,17 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_fu
|
|||
auto& indexing_record = segment_.get_indexing_record();
|
||||
const segcore::ScalarIndexingEntry<T>& entry = indexing_record.get_scalar_entry<T>(field_offset);
|
||||
|
||||
RetType results(vec.chunk_size());
|
||||
RetType results(vec.num_chunk());
|
||||
auto indexing_barrier = indexing_record.get_finished_ack();
|
||||
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
|
||||
auto& result = results[chunk_id];
|
||||
auto indexing = entry.get_indexing(chunk_id);
|
||||
auto data = index_func(indexing);
|
||||
result = ~std::move(*data);
|
||||
result = std::move(*data);
|
||||
Assert(result.size() == segcore::DefaultElementPerChunk);
|
||||
}
|
||||
|
||||
for (auto chunk_id = indexing_barrier; chunk_id < vec.chunk_size(); ++chunk_id) {
|
||||
for (auto chunk_id = indexing_barrier; chunk_id < vec.num_chunk(); ++chunk_id) {
|
||||
auto& result = results[chunk_id];
|
||||
result.resize(segcore::DefaultElementPerChunk);
|
||||
auto chunk = vec.get_chunk(chunk_id);
|
||||
|
@ -126,32 +125,32 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
|
|||
switch (op) {
|
||||
case OpType::Equal: {
|
||||
auto index_func = [val](Index* index) { return index->In(1, &val); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x == val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x == val); });
|
||||
}
|
||||
|
||||
case OpType::NotEqual: {
|
||||
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x != val); });
|
||||
}
|
||||
|
||||
case OpType::GreaterEqual: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::GE); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x >= val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x >= val); });
|
||||
}
|
||||
|
||||
case OpType::GreaterThan: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::GT); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x > val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x > val); });
|
||||
}
|
||||
|
||||
case OpType::LessEqual: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::LE); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x <= val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x <= val); });
|
||||
}
|
||||
|
||||
case OpType::LessThan: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::LT); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x < val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x < val); });
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported range node");
|
||||
|
@ -167,16 +166,16 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
|
|||
if (false) {
|
||||
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessThan)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x < val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x < val2); });
|
||||
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessEqual)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x <= val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x <= val2); });
|
||||
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessThan)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x < val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x < val2); });
|
||||
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessEqual)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x <= val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x <= val2); });
|
||||
} else {
|
||||
PanicInfo("unsupported range node");
|
||||
}
|
||||
|
@ -226,4 +225,79 @@ ExecExprVisitor::visit(RangeExpr& expr) {
|
|||
ret_ = std::move(ret);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType {
|
||||
auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
|
||||
auto& records = segment_.get_insert_record();
|
||||
auto data_type = expr.data_type_;
|
||||
auto& schema = segment_.get_schema();
|
||||
auto field_offset_opt = schema.get_offset(expr.field_id_);
|
||||
Assert(field_offset_opt);
|
||||
auto field_offset = field_offset_opt.value();
|
||||
auto& field_meta = schema[field_offset];
|
||||
auto vec_ptr = records.get_entity<T>(field_offset);
|
||||
auto& vec = *vec_ptr;
|
||||
auto num_chunk = vec.num_chunk();
|
||||
RetType bitsets;
|
||||
|
||||
auto N = records.ack_responder_.GetAck();
|
||||
|
||||
// small batch
|
||||
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
|
||||
auto& chunk = vec.get_chunk(chunk_id);
|
||||
|
||||
auto size = chunk_id == num_chunk - 1 ? N - chunk_id * segcore::DefaultElementPerChunk
|
||||
: segcore::DefaultElementPerChunk;
|
||||
|
||||
boost::dynamic_bitset<> bitset(segcore::DefaultElementPerChunk);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
auto value = chunk[i];
|
||||
bool is_in = std::binary_search(expr.terms_.begin(), expr.terms_.end(), value);
|
||||
bitset[i] = is_in;
|
||||
}
|
||||
bitsets.emplace_back(std::move(bitset));
|
||||
}
|
||||
return bitsets;
|
||||
}
|
||||
|
||||
void
|
||||
ExecExprVisitor::visit(TermExpr& expr) {
|
||||
auto& field_meta = segment_.get_schema()[expr.field_id_];
|
||||
Assert(expr.data_type_ == field_meta.get_data_type());
|
||||
RetType ret;
|
||||
switch (expr.data_type_) {
|
||||
case DataType::BOOL: {
|
||||
ret = ExecTermVisitorImpl<bool>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT8: {
|
||||
ret = ExecTermVisitorImpl<int8_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT16: {
|
||||
ret = ExecTermVisitorImpl<int16_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT32: {
|
||||
ret = ExecTermVisitorImpl<int32_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT64: {
|
||||
ret = ExecTermVisitorImpl<int64_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
ret = ExecTermVisitorImpl<float>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
ret = ExecTermVisitorImpl<double>(expr);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
ret_ = std::move(ret);
|
||||
}
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -196,7 +196,7 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
}
|
||||
|
||||
ssize_t
|
||||
chunk_size() const {
|
||||
num_chunk() const {
|
||||
return chunks_.size();
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector
|
|||
|
||||
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
|
||||
Assert(source);
|
||||
auto chunk_size = source->chunk_size();
|
||||
auto chunk_size = source->num_chunk();
|
||||
assert(ack_end <= chunk_size);
|
||||
auto conf = get_build_conf();
|
||||
data_.grow_to_at_least(ack_end);
|
||||
|
@ -87,7 +87,7 @@ void
|
|||
ScalarIndexingEntry<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
|
||||
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
|
||||
Assert(source);
|
||||
auto chunk_size = source->chunk_size();
|
||||
auto chunk_size = source->num_chunk();
|
||||
assert(ack_end <= chunk_size);
|
||||
data_.grow_to_at_least(ack_end);
|
||||
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
|
||||
|
|
|
@ -467,16 +467,16 @@ SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
|
|||
auto dim = field.get_dim();
|
||||
|
||||
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
|
||||
auto chunk_size = record_.uids_.chunk_size();
|
||||
auto chunk_size = record_.uids_.num_chunk();
|
||||
|
||||
auto& uids = record_.uids_;
|
||||
auto entities = record_.get_entity<FloatVector>(offset);
|
||||
|
||||
std::vector<knowhere::DatasetPtr> datasets;
|
||||
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
|
||||
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
|
||||
auto entities_chunk = entities->get_chunk(chunk_id).data();
|
||||
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
|
||||
}
|
||||
for (auto& ds : datasets) {
|
||||
|
|
|
@ -241,10 +241,10 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
|
|||
auto entities = record_.get_entity<FloatVector>(offset);
|
||||
|
||||
std::vector<knowhere::DatasetPtr> datasets;
|
||||
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
|
||||
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
|
||||
auto entities_chunk = entities->get_chunk(chunk_id).data();
|
||||
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
|
||||
}
|
||||
for (auto& ds : datasets) {
|
||||
|
|
|
@ -26,4 +26,5 @@ target_link_libraries(all_tests
|
|||
pthread
|
||||
milvus_utils
|
||||
)
|
||||
|
||||
install (TARGETS all_tests DESTINATION unittest)
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
|
@ -52,7 +52,7 @@ TEST(ConcurrentVector, TestSingle) {
|
|||
c_vec.set_data(total_count, vec.data(), insert_size);
|
||||
total_count += insert_size;
|
||||
}
|
||||
ASSERT_EQ(c_vec.chunk_size(), (total_count + 31) / 32);
|
||||
ASSERT_EQ(c_vec.num_chunk(), (total_count + 31) / 32);
|
||||
for (int i = 0; i < total_count; ++i) {
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
auto std_data = d + i * dim;
|
||||
|
|
|
@ -321,7 +321,88 @@ TEST(Expr, TestRange) {
|
|||
auto ans = final[vec_id][offset];
|
||||
|
||||
auto val = age_col[i];
|
||||
auto ref = !ref_func(val);
|
||||
auto ref = ref_func(val);
|
||||
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Expr, TestTerm) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto vec_2k_3k = [] {
|
||||
std::string buf = "[";
|
||||
for (int i = 2000; i < 3000 - 1; ++i) {
|
||||
buf += std::to_string(i) + ", ";
|
||||
}
|
||||
buf += std::to_string(2999) + "]";
|
||||
return buf;
|
||||
}();
|
||||
|
||||
std::vector<std::tuple<std::string, std::function<bool(int)>>> testcases = {
|
||||
{R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }},
|
||||
{R"([2000])", [](int v) { return v == 2000; }},
|
||||
{R"([3000])", [](int v) { return v == 3000; }},
|
||||
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
|
||||
};
|
||||
|
||||
std::string dsl_string_tmp = R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"term": {
|
||||
"age": @@@@
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
|
||||
auto seg = CreateSegment(schema);
|
||||
int N = 10000;
|
||||
std::vector<int> age_col;
|
||||
int num_iters = 100;
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
auto raw_data = DataGen(schema, N, iter);
|
||||
auto new_age_col = raw_data.get_col<int>(1);
|
||||
age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end());
|
||||
seg->PreInsert(N);
|
||||
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
|
||||
}
|
||||
|
||||
auto seg_promote = dynamic_cast<SegmentSmallIndex*>(seg.get());
|
||||
ExecExprVisitor visitor(*seg_promote);
|
||||
for (auto [clause, ref_func] : testcases) {
|
||||
auto loc = dsl_string_tmp.find("@@@@");
|
||||
auto dsl_string = dsl_string_tmp;
|
||||
dsl_string.replace(loc, 4, clause);
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
|
||||
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
|
||||
|
||||
for (int i = 0; i < N * num_iters; ++i) {
|
||||
auto vec_id = i / DefaultElementPerChunk;
|
||||
auto offset = i % DefaultElementPerChunk;
|
||||
auto ans = final[vec_id][offset];
|
||||
|
||||
auto val = age_col[i];
|
||||
auto ref = ref_func(val);
|
||||
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,6 +31,14 @@ struct GeneratedData {
|
|||
memcpy(ret.data(), target.data(), target.size());
|
||||
return ret;
|
||||
}
|
||||
template <typename T>
|
||||
auto
|
||||
get_mutable_col(int index) {
|
||||
auto& target = cols_.at(index);
|
||||
assert(target.size() == row_ids_.size() * sizeof(T));
|
||||
auto ptr = reinterpret_cast<T*>(target.data());
|
||||
return ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
GeneratedData() = default;
|
||||
|
@ -58,6 +66,9 @@ GeneratedData::generate_rows(int N, SchemaPtr schema) {
|
|||
}
|
||||
}
|
||||
rows_ = std::move(result);
|
||||
raw_.raw_data = rows_.data();
|
||||
raw_.sizeof_per_row = schema->get_total_sizeof();
|
||||
raw_.count = N;
|
||||
}
|
||||
|
||||
inline GeneratedData
|
||||
|
@ -129,14 +140,12 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
|
|||
}
|
||||
GeneratedData res;
|
||||
res.cols_ = std::move(cols);
|
||||
res.generate_rows(N, schema);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
res.row_ids_.push_back(i);
|
||||
res.timestamps_.push_back(i);
|
||||
}
|
||||
res.raw_.raw_data = res.rows_.data();
|
||||
res.raw_.sizeof_per_row = schema->get_total_sizeof();
|
||||
res.raw_.count = N;
|
||||
|
||||
res.generate_rows(N, schema);
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
|
|
|
@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_
|
|||
st.error_msg = ErrorMsg("payload has finished");
|
||||
return st;
|
||||
}
|
||||
auto ast = builder->AppendValues(values, (dimension / 8) * length);
|
||||
auto ast = builder->AppendValues(values, length);
|
||||
if (!ast.ok()) {
|
||||
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
|
||||
st.error_msg = ErrorMsg(ast.message());
|
||||
|
@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *
|
|||
st.error_msg = ErrorMsg("payload has finished");
|
||||
return st;
|
||||
}
|
||||
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), dimension * length * sizeof(float));
|
||||
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), length);
|
||||
if (!ast.ok()) {
|
||||
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
|
||||
st.error_msg = ErrorMsg(ast.message());
|
||||
|
@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader,
|
|||
return st;
|
||||
}
|
||||
*dimension = array->byte_width() * 8;
|
||||
*length = array->length() / array->byte_width();
|
||||
*length = array->length();
|
||||
*values = (uint8_t *) array->raw_values();
|
||||
return st;
|
||||
}
|
||||
|
@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
|
|||
return st;
|
||||
}
|
||||
*dimension = array->byte_width() / sizeof(float);
|
||||
*length = array->length() / array->byte_width();
|
||||
*length = array->length();
|
||||
*values = (float *) array->raw_values();
|
||||
return st;
|
||||
}
|
||||
|
@ -478,12 +478,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
|
|||
extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) {
|
||||
auto p = reinterpret_cast<wrapper::PayloadReader *>(payloadReader);
|
||||
if (p->array == nullptr) return 0;
|
||||
auto ba = std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(p->array);
|
||||
if (ba == nullptr) {
|
||||
return p->array->length();
|
||||
} else {
|
||||
return ba->length() / ba->byte_width();
|
||||
}
|
||||
return p->array->length();
|
||||
}
|
||||
|
||||
extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) {
|
||||
|
|
|
@ -5,6 +5,7 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
typedef void *CPayloadWriter;
|
||||
|
||||
|
@ -19,7 +20,7 @@ typedef struct CStatus {
|
|||
} CStatus;
|
||||
|
||||
CPayloadWriter NewPayloadWriter(int columnType);
|
||||
//CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
|
||||
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
|
||||
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
|
||||
|
@ -39,7 +40,7 @@ CStatus ReleasePayloadWriter(CPayloadWriter handler);
|
|||
|
||||
typedef void *CPayloadReader;
|
||||
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
|
||||
//CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
|
||||
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
|
||||
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
|
||||
|
@ -55,4 +56,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader);
|
|||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -70,38 +70,38 @@ TEST(wrapper, inoutstream) {
|
|||
ASSERT_EQ(inarray->Value(4), 5);
|
||||
}
|
||||
|
||||
//TEST(wrapper, boolean) {
|
||||
// auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
// bool data[] = {true, false, true, false};
|
||||
//
|
||||
// auto st = AddBooleanToPayload(payload, data, 4);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// st = FinishPayloadWriter(payload);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// auto cb = GetPayloadBufferFromWriter(payload);
|
||||
// ASSERT_GT(cb.length, 0);
|
||||
// ASSERT_NE(cb.data, nullptr);
|
||||
// auto nums = GetPayloadLengthFromWriter(payload);
|
||||
// ASSERT_EQ(nums, 4);
|
||||
//
|
||||
// auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
// bool *values;
|
||||
// int length;
|
||||
// st = GetBoolFromPayload(reader, &values, &length);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// ASSERT_NE(values, nullptr);
|
||||
// ASSERT_EQ(length, 4);
|
||||
// length = GetPayloadLengthFromReader(reader);
|
||||
// ASSERT_EQ(length, 4);
|
||||
// for (int i = 0; i < length; i++) {
|
||||
// ASSERT_EQ(data[i], values[i]);
|
||||
// }
|
||||
//
|
||||
// st = ReleasePayloadWriter(payload);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// st = ReleasePayloadReader(reader);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
//}
|
||||
TEST(wrapper, boolean) {
|
||||
auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
bool data[] = {true, false, true, false};
|
||||
|
||||
auto st = AddBooleanToPayload(payload, data, 4);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = FinishPayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
auto cb = GetPayloadBufferFromWriter(payload);
|
||||
ASSERT_GT(cb.length, 0);
|
||||
ASSERT_NE(cb.data, nullptr);
|
||||
auto nums = GetPayloadLengthFromWriter(payload);
|
||||
ASSERT_EQ(nums, 4);
|
||||
|
||||
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
bool *values;
|
||||
int length;
|
||||
st = GetBoolFromPayload(reader, &values, &length);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
ASSERT_NE(values, nullptr);
|
||||
ASSERT_EQ(length, 4);
|
||||
length = GetPayloadLengthFromReader(reader);
|
||||
ASSERT_EQ(length, 4);
|
||||
for (int i = 0; i < length; i++) {
|
||||
ASSERT_EQ(data[i], values[i]);
|
||||
}
|
||||
|
||||
st = ReleasePayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = ReleasePayloadReader(reader);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
}
|
||||
|
||||
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \
|
||||
auto payload = NewPayloadWriter(COLUMN_TYPE); \
|
||||
|
|
|
@ -16,25 +16,311 @@ import (
|
|||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
type PayloadWriter struct {
|
||||
payloadWriterPtr C.CPayloadWriter
|
||||
}
|
||||
type (
|
||||
PayloadWriter struct {
|
||||
payloadWriterPtr C.CPayloadWriter
|
||||
colType schemapb.DataType
|
||||
}
|
||||
|
||||
PayloadReader struct {
|
||||
payloadReaderPtr C.CPayloadReader
|
||||
colType schemapb.DataType
|
||||
}
|
||||
)
|
||||
|
||||
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
|
||||
w := C.NewPayloadWriter(C.int(colType))
|
||||
if w == nil {
|
||||
return nil, errors.New("create Payload writer failed")
|
||||
}
|
||||
return &PayloadWriter{payloadWriterPtr: w}, nil
|
||||
return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
|
||||
switch len(dim) {
|
||||
case 0:
|
||||
switch w.colType {
|
||||
case schemapb.DataType_BOOL:
|
||||
val, ok := msgs.([]bool)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddBoolToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT8:
|
||||
val, ok := msgs.([]int8)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt8ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT16:
|
||||
val, ok := msgs.([]int16)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt16ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT32:
|
||||
val, ok := msgs.([]int32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt32ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT64:
|
||||
val, ok := msgs.([]int64)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt64ToPayload(val)
|
||||
|
||||
case schemapb.DataType_FLOAT:
|
||||
val, ok := msgs.([]float32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddFloatToPayload(val)
|
||||
|
||||
case schemapb.DataType_DOUBLE:
|
||||
val, ok := msgs.([]float64)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddDoubleToPayload(val)
|
||||
|
||||
case schemapb.DataType_STRING:
|
||||
val, ok := msgs.(string)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddOneStringToPayload(val)
|
||||
}
|
||||
case 1:
|
||||
switch w.colType {
|
||||
case schemapb.DataType_VECTOR_BINARY:
|
||||
val, ok := msgs.([]byte)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddBinaryVectorToPayload(val, dim[0])
|
||||
|
||||
case schemapb.DataType_VECTOR_FLOAT:
|
||||
val, ok := msgs.([]float32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddFloatVectorToPayload(val, dim[0])
|
||||
}
|
||||
|
||||
default:
|
||||
return errors.New("incorrect input numbers")
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.float)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.double)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
|
||||
if len(msg) == 0 {
|
||||
length := len(msg)
|
||||
if length == 0 {
|
||||
return errors.New("can't add empty string into payload")
|
||||
}
|
||||
cstr := C.CString(msg)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
st := C.AddOneStringToPayload(w.payloadWriterPtr, cstr, C.int(len(msg)))
|
||||
|
||||
cmsg := C.CString(msg)
|
||||
clength := C.int(length)
|
||||
defer C.free(unsafe.Pointer(cmsg))
|
||||
|
||||
st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength)
|
||||
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dimension > 0 && (%8 == 0)
|
||||
func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error {
|
||||
length := len(binVec)
|
||||
if length <= 0 {
|
||||
return errors.New("can't add empty binVec into payload")
|
||||
}
|
||||
|
||||
if dim <= 0 {
|
||||
return errors.New("dimension should be greater than 0")
|
||||
}
|
||||
|
||||
cBinVec := (*C.uint8_t)(&binVec[0])
|
||||
cDim := C.int(dim)
|
||||
cLength := C.int(length / (dim / 8))
|
||||
|
||||
st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dimension > 0 && (%8 == 0)
|
||||
func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error {
|
||||
length := len(floatVec)
|
||||
if length <= 0 {
|
||||
return errors.New("can't add empty floatVec into payload")
|
||||
}
|
||||
|
||||
if dim <= 0 {
|
||||
return errors.New("dimension should be greater than 0")
|
||||
}
|
||||
|
||||
cBinVec := (*C.float)(&floatVec[0])
|
||||
cDim := C.int(dim)
|
||||
cLength := C.int(length / dim)
|
||||
|
||||
st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
|
@ -56,13 +342,13 @@ func (w *PayloadWriter) FinishPayloadWriter() error {
|
|||
}
|
||||
|
||||
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
|
||||
//See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
|
||||
pointer := unsafe.Pointer(cb.data)
|
||||
length := int(cb.length)
|
||||
if length <= 0 {
|
||||
return nil, errors.New("empty buffer")
|
||||
}
|
||||
// refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
slice := (*[1 << 28]byte)(pointer)[:length:length]
|
||||
return slice, nil
|
||||
}
|
||||
|
@ -87,16 +373,71 @@ func (w *PayloadWriter) Close() error {
|
|||
return w.ReleasePayloadWriter()
|
||||
}
|
||||
|
||||
type PayloadReader struct {
|
||||
payloadReaderPtr C.CPayloadReader
|
||||
}
|
||||
|
||||
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New("create Payload reader failed, buffer is empty")
|
||||
}
|
||||
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
|
||||
return &PayloadReader{payloadReaderPtr: r}, nil
|
||||
return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil
|
||||
}
|
||||
|
||||
// Params:
|
||||
// `idx`: String index
|
||||
// Return:
|
||||
// `interface{}`: all types.
|
||||
// `int`: length, only meaningful to FLOAT/BINARY VECTOR type.
|
||||
// `error`: error.
|
||||
func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) {
|
||||
switch len(idx) {
|
||||
case 1:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_STRING:
|
||||
val, err := r.GetOneStringFromPayload(idx[0])
|
||||
return val, 0, err
|
||||
}
|
||||
case 0:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_BOOL:
|
||||
val, err := r.GetBoolFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT8:
|
||||
val, err := r.GetInt8FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT16:
|
||||
val, err := r.GetInt16FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT32:
|
||||
val, err := r.GetInt32FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT64:
|
||||
val, err := r.GetInt64FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_FLOAT:
|
||||
val, err := r.GetFloatFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_DOUBLE:
|
||||
val, err := r.GetDoubleFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_VECTOR_BINARY:
|
||||
return r.GetBinaryVectorFromPayload()
|
||||
|
||||
case schemapb.DataType_VECTOR_FLOAT:
|
||||
return r.GetFloatVectorFromPayload()
|
||||
default:
|
||||
return nil, 0, errors.New("Unknown type")
|
||||
}
|
||||
default:
|
||||
return nil, 0, errors.New("incorrect number of index")
|
||||
}
|
||||
|
||||
return nil, 0, errors.New("unknown error")
|
||||
}
|
||||
|
||||
func (r *PayloadReader) ReleasePayloadReader() error {
|
||||
|
@ -110,18 +451,169 @@ func (r *PayloadReader) ReleasePayloadReader() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
|
||||
var cMsg *C.bool
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
|
||||
var cMsg *C.int8_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
|
||||
var cMsg *C.int16_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
|
||||
var cMsg *C.int32_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
|
||||
var cMsg *C.int64_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
|
||||
var cMsg *C.float
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
|
||||
var cMsg *C.double
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
|
||||
var cStr *C.char
|
||||
var strSize C.int
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize)
|
||||
|
||||
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &strSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return C.GoStringN(cStr, strSize), nil
|
||||
return C.GoStringN(cStr, cSize), nil
|
||||
}
|
||||
|
||||
// ,dimension, error
|
||||
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
|
||||
var cMsg *C.uint8_t
|
||||
var cDim C.int
|
||||
var cLen C.int
|
||||
|
||||
st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, 0, errors.New(msg)
|
||||
}
|
||||
length := (cDim / 8) * cLen
|
||||
|
||||
slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length]
|
||||
return slice, int(cDim), nil
|
||||
}
|
||||
|
||||
// ,dimension, error
|
||||
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
|
||||
var cMsg *C.float
|
||||
var cDim C.int
|
||||
var cLen C.int
|
||||
|
||||
st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, 0, errors.New(msg)
|
||||
}
|
||||
length := cDim * cLen
|
||||
|
||||
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length]
|
||||
return slice, int(cDim), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
|
||||
|
|
|
@ -1,54 +1,426 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestNewPayloadWriter(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, w)
|
||||
err = w.Close()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPayLoadString(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello0")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello1")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello2")
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
|
||||
assert.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
str0, err := r.GetOneStringFromPayload(0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
str1, err := r.GetOneStringFromPayload(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
str2, err := r.GetOneStringFromPayload(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
|
||||
err = r.ReleasePayloadReader()
|
||||
assert.Nil(t, err)
|
||||
err = w.ReleasePayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
func TestPayload_ReaderandWriter(t *testing.T) {
|
||||
|
||||
t.Run("TestBool", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_BOOL)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddBoolToPayload([]bool{false, false, false, false})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]bool{false, false, false, false})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 8, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 8)
|
||||
bools, err := r.GetBoolFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
|
||||
ibools, _, err := r.GetDataFromPayload()
|
||||
bools = ibools.([]bool)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
|
||||
defer r.ReleasePayloadReader()
|
||||
|
||||
})
|
||||
|
||||
t.Run("TestInt8", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT8)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt8ToPayload([]int8{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int8{4, 5, 6})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT8, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int8s, err := r.GetInt8FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
|
||||
|
||||
iint8s, _, err := r.GetDataFromPayload()
|
||||
int8s = iint8s.([]int8)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt16", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT16)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt16ToPayload([]int16{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int16{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT16, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
int16s, err := r.GetInt16FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
|
||||
|
||||
iint16s, _, err := r.GetDataFromPayload()
|
||||
int16s = iint16s.([]int16)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt32", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT32)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt32ToPayload([]int32{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int32{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT32, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int32s, err := r.GetInt32FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
|
||||
|
||||
iint32s, _, err := r.GetDataFromPayload()
|
||||
int32s = iint32s.([]int32)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt64", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT64)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt64ToPayload([]int64{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int64{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT64, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int64s, err := r.GetInt64FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
|
||||
|
||||
iint64s, _, err := r.GetDataFromPayload()
|
||||
int64s = iint64s.([]int64)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestFloat32", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_FLOAT)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
float32s, err := r.GetFloatFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
|
||||
|
||||
ifloat32s, _, err := r.GetDataFromPayload()
|
||||
float32s = ifloat32s.([]float32)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestDouble", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_DOUBLE)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
float64s, err := r.GetDoubleFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
|
||||
|
||||
ifloat64s, _, err := r.GetDataFromPayload()
|
||||
float64s = ifloat64s.([]float64)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestAddOneString", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddOneStringToPayload("hello0")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello1")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello2")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload("hello3")
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
|
||||
assert.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
str0, err := r.GetOneStringFromPayload(0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
str1, err := r.GetOneStringFromPayload(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
str2, err := r.GetOneStringFromPayload(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
str3, err := r.GetOneStringFromPayload(3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str3, "hello3")
|
||||
|
||||
istr0, _, err := r.GetDataFromPayload(0)
|
||||
str0 = istr0.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
|
||||
istr1, _, err := r.GetDataFromPayload(1)
|
||||
str1 = istr1.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
|
||||
istr2, _, err := r.GetDataFromPayload(2)
|
||||
str2 = istr2.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
|
||||
istr3, _, err := r.GetDataFromPayload(3)
|
||||
str3 = istr3.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str3, "hello3")
|
||||
|
||||
err = r.ReleasePayloadReader()
|
||||
assert.Nil(t, err)
|
||||
err = w.ReleasePayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("TestBinaryVector", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
in := make([]byte, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
in[i] = 1
|
||||
}
|
||||
in2 := make([]byte, 8)
|
||||
for i := 0; i < 8; i++ {
|
||||
in2[i] = 1
|
||||
}
|
||||
|
||||
err = w.AddBinaryVectorToPayload(in, 8)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload(in2, 8)
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 24, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 24)
|
||||
|
||||
binVecs, dim, err := r.GetBinaryVectorFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
fmt.Println(binVecs)
|
||||
|
||||
ibinVecs, dim, err := r.GetDataFromPayload()
|
||||
assert.Nil(t, err)
|
||||
binVecs = ibinVecs.([]byte)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestFloatVector", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float32{3.0, 4.0}, 1)
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
|
||||
floatVecs, dim, err := r.GetFloatVectorFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
|
||||
|
||||
ifloatVecs, dim, err := r.GetDataFromPayload()
|
||||
assert.Nil(t, err)
|
||||
floatVecs = ifloatVecs.([]float32)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue