mirror of https://github.com/milvus-io/milvus.git
Add string type payload cgo wrapper
Signed-off-by: neza2017 <yefu.chen@zilliz.com>pull/4973/head^2
parent
aeca8f85a5
commit
0d8273c7cc
|
@ -19,7 +19,7 @@ typedef struct CStatus {
|
|||
} CStatus;
|
||||
|
||||
CPayloadWriter NewPayloadWriter(int columnType);
|
||||
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
//CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
|
||||
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
|
||||
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
|
||||
|
@ -39,7 +39,7 @@ CStatus ReleasePayloadWriter(CPayloadWriter handler);
|
|||
|
||||
typedef void *CPayloadReader;
|
||||
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
|
||||
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
//CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
|
||||
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
|
||||
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
|
||||
|
|
|
@ -70,38 +70,38 @@ TEST(wrapper, inoutstream) {
|
|||
ASSERT_EQ(inarray->Value(4), 5);
|
||||
}
|
||||
|
||||
TEST(wrapper, boolean) {
|
||||
auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
bool data[] = {true, false, true, false};
|
||||
|
||||
auto st = AddBooleanToPayload(payload, data, 4);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = FinishPayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
auto cb = GetPayloadBufferFromWriter(payload);
|
||||
ASSERT_GT(cb.length, 0);
|
||||
ASSERT_NE(cb.data, nullptr);
|
||||
auto nums = GetPayloadLengthFromWriter(payload);
|
||||
ASSERT_EQ(nums, 4);
|
||||
|
||||
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
bool *values;
|
||||
int length;
|
||||
st = GetBoolFromPayload(reader, &values, &length);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
ASSERT_NE(values, nullptr);
|
||||
ASSERT_EQ(length, 4);
|
||||
length = GetPayloadLengthFromReader(reader);
|
||||
ASSERT_EQ(length, 4);
|
||||
for (int i = 0; i < length; i++) {
|
||||
ASSERT_EQ(data[i], values[i]);
|
||||
}
|
||||
|
||||
st = ReleasePayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = ReleasePayloadReader(reader);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
}
|
||||
//TEST(wrapper, boolean) {
|
||||
// auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
// bool data[] = {true, false, true, false};
|
||||
//
|
||||
// auto st = AddBooleanToPayload(payload, data, 4);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// st = FinishPayloadWriter(payload);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// auto cb = GetPayloadBufferFromWriter(payload);
|
||||
// ASSERT_GT(cb.length, 0);
|
||||
// ASSERT_NE(cb.data, nullptr);
|
||||
// auto nums = GetPayloadLengthFromWriter(payload);
|
||||
// ASSERT_EQ(nums, 4);
|
||||
//
|
||||
// auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
// bool *values;
|
||||
// int length;
|
||||
// st = GetBoolFromPayload(reader, &values, &length);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// ASSERT_NE(values, nullptr);
|
||||
// ASSERT_EQ(length, 4);
|
||||
// length = GetPayloadLengthFromReader(reader);
|
||||
// ASSERT_EQ(length, 4);
|
||||
// for (int i = 0; i < length; i++) {
|
||||
// ASSERT_EQ(data[i], values[i]);
|
||||
// }
|
||||
//
|
||||
// st = ReleasePayloadWriter(payload);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// st = ReleasePayloadReader(reader);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
//}
|
||||
|
||||
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \
|
||||
auto payload = NewPayloadWriter(COLUMN_TYPE); \
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package storage
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/cwrapper
|
||||
|
||||
#cgo LDFLAGS: -L${SRCDIR}/cwrapper/output -l:libwrapper.a -l:libparquet.a -l:libarrow.a -l:libthrift.a -l:libutf8proc.a -lstdc++ -lm
|
||||
#include <stdlib.h>
|
||||
#include "ParquetWrapper.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
type PayloadWriter struct {
|
||||
payloadWriterPtr C.CPayloadWriter
|
||||
}
|
||||
|
||||
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
|
||||
w := C.NewPayloadWriter(C.int(colType))
|
||||
if w == nil {
|
||||
return nil, errors.New("create Payload writer failed")
|
||||
}
|
||||
return &PayloadWriter{payloadWriterPtr: w}, nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
|
||||
if len(msg) == 0 {
|
||||
return errors.New("can't add empty string into payload")
|
||||
}
|
||||
cstr := C.CString(msg)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
st := C.AddOneStringToPayload(w.payloadWriterPtr, cstr, C.int(len(msg)))
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) FinishPayloadWriter() error {
|
||||
st := C.FinishPayloadWriter(w.payloadWriterPtr)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
|
||||
//See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
|
||||
pointer := unsafe.Pointer(cb.data)
|
||||
length := int(cb.length)
|
||||
if length <= 0 {
|
||||
return nil, errors.New("empty buffer")
|
||||
}
|
||||
slice := (*[1 << 28]byte)(pointer)[:length:length]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) GetPayloadLengthFromWriter() (int, error) {
|
||||
length := C.GetPayloadLengthFromWriter(w.payloadWriterPtr)
|
||||
return int(length), nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) ReleasePayloadWriter() error {
|
||||
st := C.ReleasePayloadWriter(w.payloadWriterPtr)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) Close() error {
|
||||
return w.ReleasePayloadWriter()
|
||||
}
|
||||
|
||||
type PayloadReader struct {
|
||||
payloadReaderPtr C.CPayloadReader
|
||||
}
|
||||
|
||||
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New("create Payload reader failed, buffer is empty")
|
||||
}
|
||||
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
|
||||
return &PayloadReader{payloadReaderPtr: r}, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) ReleasePayloadReader() error {
|
||||
st := C.ReleasePayloadReader(r.payloadReaderPtr)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
|
||||
var cStr *C.char
|
||||
var strSize C.int
|
||||
|
||||
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &strSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return C.GoStringN(cStr, strSize), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
|
||||
length := C.GetPayloadLengthFromReader(r.payloadReaderPtr)
|
||||
return int(length), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) Close() error {
|
||||
return r.ReleasePayloadReader()
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestNewPayloadWriter(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, w)
|
||||
err = w.Close()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPayLoadString(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello0")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello1")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello2")
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
|
||||
assert.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
str0, err := r.GetOneStringFromPayload(0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
str1, err := r.GetOneStringFromPayload(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
str2, err := r.GetOneStringFromPayload(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
|
||||
err = r.ReleasePayloadReader()
|
||||
assert.Nil(t, err)
|
||||
err = w.ReleasePayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
}
|
Loading…
Reference in New Issue