mirror of https://github.com/milvus-io/milvus.git
Parse utf32 string of numpy file (#20176)
Signed-off-by: groot <yihua.mo@zilliz.com> Signed-off-by: groot <yihua.mo@zilliz.com>pull/20211/head
parent
e9cd2cb42a
commit
c6151ad351
|
@ -267,7 +267,7 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
|
||||||
// TODO add context
|
// TODO add context
|
||||||
size, err := p.chunkManager.Size(context.TODO(), filePath)
|
size, err := p.chunkManager.Size(context.TODO(), filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Any("err", err))
|
log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Error(err))
|
||||||
return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath)
|
return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,7 +334,7 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
|
||||||
if fileType == JSONFileExt {
|
if fileType == JSONFileExt {
|
||||||
err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
|
err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to parse row-based json file", zap.Any("err", err), zap.String("filePath", filePath))
|
log.Error("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} // no need to check else, since the fileValidation() already do this
|
} // no need to check else, since the fileValidation() already do this
|
||||||
|
@ -407,7 +407,7 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
|
||||||
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
|
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to parse column-based numpy file", zap.Any("err", err), zap.String("filePath", filePath))
|
log.Error("import wrapper: failed to parse column-based numpy file", zap.Error(err), zap.String("filePath", filePath))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -732,7 +732,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
||||||
// generate auto id for primary key and rowid field
|
// generate auto id for primary key and rowid field
|
||||||
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
|
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to alloc row ID", zap.Any("err", err))
|
log.Error("import wrapper: failed to alloc row ID", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -869,7 +869,7 @@ func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData,
|
||||||
// create a new segment
|
// create a new segment
|
||||||
segID, channelName, err := p.assignSegmentFunc(shardID)
|
segID, channelName, err := p.assignSegmentFunc(shardID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to assign a new segment", zap.Any("error", err), zap.Int("shardID", shardID))
|
log.Error("import wrapper: failed to assign a new segment", zap.Error(err), zap.Int("shardID", shardID))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -888,7 +888,7 @@ func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData,
|
||||||
// save binlogs
|
// save binlogs
|
||||||
fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
|
fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to save binlogs", zap.Any("error", err), zap.Int("shardID", shardID),
|
log.Error("import wrapper: failed to save binlogs", zap.Error(err), zap.Int("shardID", shardID),
|
||||||
zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
|
zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -914,7 +914,7 @@ func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error {
|
||||||
err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
|
err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("import wrapper: failed to save segment",
|
log.Error("import wrapper: failed to save segment",
|
||||||
zap.Any("error", err),
|
zap.Error(err),
|
||||||
zap.Int("shardID", segment.shardID),
|
zap.Int("shardID", segment.shardID),
|
||||||
zap.Int64("segmentID", segment.segmentID),
|
zap.Int64("segmentID", segment.segmentID),
|
||||||
zap.String("targetChannel", segment.targetChName))
|
zap.String("targetChannel", segment.targetChName))
|
||||||
|
|
|
@ -113,7 +113,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||||
// read the key
|
// read the key
|
||||||
t, err := dec.Token()
|
t, err := dec.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
log.Error("JSON parser: read json token error", zap.Error(err))
|
||||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||||
}
|
}
|
||||||
key := t.(string)
|
key := t.(string)
|
||||||
|
@ -128,7 +128,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||||
// started by '['
|
// started by '['
|
||||||
t, err = dec.Token()
|
t, err = dec.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
log.Error("JSON parser: read json token error", zap.Error(err))
|
||||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||||
for dec.More() {
|
for dec.More() {
|
||||||
var value interface{}
|
var value interface{}
|
||||||
if err := dec.Decode(&value); err != nil {
|
if err := dec.Decode(&value); err != nil {
|
||||||
log.Error("JSON parser: decode json value error", zap.Any("err", err))
|
log.Error("JSON parser: decode json value error", zap.Error(err))
|
||||||
return fmt.Errorf("JSON parser: decode json value error: %v", err)
|
return fmt.Errorf("JSON parser: decode json value error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||||
if len(buf) >= int(p.bufSize) {
|
if len(buf) >= int(p.bufSize) {
|
||||||
isEmpty = false
|
isEmpty = false
|
||||||
if err = handler.Handle(buf); err != nil {
|
if err = handler.Handle(buf); err != nil {
|
||||||
log.Error("JSON parser: parse values error", zap.Any("err", err))
|
log.Error("JSON parser: parse values error", zap.Error(err))
|
||||||
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||||
if len(buf) > 0 {
|
if len(buf) > 0 {
|
||||||
isEmpty = false
|
isEmpty = false
|
||||||
if err = handler.Handle(buf); err != nil {
|
if err = handler.Handle(buf); err != nil {
|
||||||
log.Error("JSON parser: parse values error", zap.Any("err", err))
|
log.Error("JSON parser: parse values error", zap.Error(err))
|
||||||
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||||
// end by ']'
|
// end by ']'
|
||||||
t, err = dec.Token()
|
t, err = dec.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
log.Error("JSON parser: read json token error", zap.Error(err))
|
||||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ import (
|
||||||
"github.com/sbinet/npyio"
|
"github.com/sbinet/npyio"
|
||||||
"github.com/sbinet/npyio/npy"
|
"github.com/sbinet/npyio/npy"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/text/encoding/unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -135,7 +136,7 @@ func convertNumpyType(typeStr string) (schemapb.DataType, error) {
|
||||||
if isStringType(typeStr) {
|
if isStringType(typeStr) {
|
||||||
return schemapb.DataType_VarChar, nil
|
return schemapb.DataType_VarChar, nil
|
||||||
}
|
}
|
||||||
log.Error("Numpy adapter: the numpy file data type not supported", zap.String("dataType", typeStr))
|
log.Error("Numpy adapter: the numpy file data type is not supported", zap.String("dataType", typeStr))
|
||||||
return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr)
|
return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -496,7 +497,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
||||||
// varchar length, this is the max length, some item is shorter than this length, but they also occupy bytes of max length
|
// varchar length, this is the max length, some item is shorter than this length, but they also occupy bytes of max length
|
||||||
maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type)
|
maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type)
|
||||||
if err != nil || maxLen <= 0 {
|
if err != nil || maxLen <= 0 {
|
||||||
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Any("err", err))
|
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Error(err))
|
||||||
return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
|
return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
|
||||||
}
|
}
|
||||||
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
|
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
|
||||||
|
@ -511,45 +512,33 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
||||||
data := make([]string, 0)
|
data := make([]string, 0)
|
||||||
for i := 0; i < readSize; i++ {
|
for i := 0; i < readSize; i++ {
|
||||||
if utf {
|
if utf {
|
||||||
// in the numpy file, each utf8 character occupy utf8.UTFMax bytes, each string occupys utf8.UTFMax*maxLen bytes
|
// in the numpy file with utf32 encoding, the dType could be like "<U2",
|
||||||
// for example, an ANSI character "a" only uses one byte, but it still occupy utf8.UTFMax bytes
|
// "<" is byteorder(LittleEndian), "U" means it is utf32 encoding, "2" means the max length of strings is 2(characters)
|
||||||
// a chinese character uses three bytes, it also occupy utf8.UTFMax bytes
|
// each character occupy 4 bytes, each string occupys 4*maxLen bytes
|
||||||
|
// for example, a numpy file has two strings: "a" and "bb", the maxLen is 2, byte order is LittleEndian
|
||||||
|
// the character "a" occupys 2*4=8 bytes(0x97,0x00,0x00,0x00,0x00,0x00,0x00,0x00),
|
||||||
|
// the "bb" occupys 8 bytes(0x97,0x00,0x00,0x00,0x98,0x00,0x00,0x00)
|
||||||
|
// for non-ascii characters, the unicode could be 1 ~ 4 bytes, each character occupys 4 bytes, too
|
||||||
raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)))
|
raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Numpy adapter: failed to read utf8 string from numpy file", zap.Int("i", i), zap.Any("err", err))
|
log.Error("Numpy adapter: failed to read utf32 bytes from numpy file", zap.Int("i", i), zap.Error(err))
|
||||||
return nil, fmt.Errorf("Numpy adapter: failed to read utf8 string from numpy file, error: %w", err)
|
return nil, fmt.Errorf("Numpy adapter: failed to read utf32 bytes from numpy file, error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var str string
|
str, err := decodeUtf32(raw, n.order)
|
||||||
for len(raw) > 0 {
|
if err != nil {
|
||||||
r, _ := utf8.DecodeRune(raw)
|
log.Error("Numpy adapter: failed todecode utf32 bytes", zap.Int("i", i), zap.Error(err))
|
||||||
if r == utf8.RuneError {
|
return nil, fmt.Errorf("Numpy adapter: failed to decode utf32 bytes, error: %w", err)
|
||||||
log.Error("Numpy adapter: failed to decode utf8 string from numpy file", zap.Any("raw", raw[:utf8.UTFMax]))
|
|
||||||
return nil, fmt.Errorf("Numpy adapter: failed to decode utf8 string from numpy file, error: illegal utf-8 encoding")
|
|
||||||
}
|
|
||||||
|
|
||||||
// only support ascii characters, because the numpy lib encode the utf8 bytes by its internal method,
|
|
||||||
// the encode/decode logic is not clear now, return error
|
|
||||||
n := n.order.Uint32(raw)
|
|
||||||
if n > 127 {
|
|
||||||
log.Error("Numpy adapter: a string contains non-ascii characters, not support yet", zap.Int32("utf8Code", r))
|
|
||||||
return nil, fmt.Errorf("Numpy adapter: a string contains non-ascii characters, not support yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
// if a string is shorter than maxLen, the tail characters will be filled with "\u0000"(in utf spec this is Null)
|
|
||||||
if r > 0 {
|
|
||||||
str += string(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
raw = raw[utf8.UTFMax:]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data = append(data, str)
|
data = append(data, str)
|
||||||
} else {
|
} else {
|
||||||
|
// in the numpy file with ansi encoding, the dType could be like "S2", maxLen is 2, each string occupys 2 bytes
|
||||||
|
// bytes.Index(buf, []byte{0}) tell us which position is the end of the string
|
||||||
buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen)))
|
buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Numpy adapter: failed to read string from numpy file", zap.Int("i", i), zap.Any("err", err))
|
log.Error("Numpy adapter: failed to read ascii bytes from numpy file", zap.Int("i", i), zap.Error(err))
|
||||||
return nil, fmt.Errorf("Numpy adapter: failed to read string from numpy file, error: %w", err)
|
return nil, fmt.Errorf("Numpy adapter: failed to read ascii bytes from numpy file, error: %w", err)
|
||||||
}
|
}
|
||||||
n := bytes.Index(buf, []byte{0})
|
n := bytes.Index(buf, []byte{0})
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
|
@ -564,3 +553,61 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
|
||||||
|
if len(src)%4 != 0 {
|
||||||
|
return "", fmt.Errorf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
var str string
|
||||||
|
for len(src) > 0 {
|
||||||
|
// check the high bytes, if high bytes are 0, the UNICODE is less than U+FFFF, we can use unicode.UTF16 to decode
|
||||||
|
isUtf16 := false
|
||||||
|
var lowbytesPosition int
|
||||||
|
uOrder := unicode.LittleEndian
|
||||||
|
if order == binary.LittleEndian {
|
||||||
|
if src[2] == 0 && src[3] == 0 {
|
||||||
|
isUtf16 = true
|
||||||
|
}
|
||||||
|
lowbytesPosition = 0
|
||||||
|
} else {
|
||||||
|
if src[0] == 0 && src[1] == 0 {
|
||||||
|
isUtf16 = true
|
||||||
|
}
|
||||||
|
lowbytesPosition = 2
|
||||||
|
uOrder = unicode.BigEndian
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUtf16 {
|
||||||
|
// use unicode.UTF16 to decode the low bytes to utf8
|
||||||
|
// utf32 and utf16 is same if the unicode code is less than 65535
|
||||||
|
if src[lowbytesPosition] != 0 || src[lowbytesPosition+1] != 0 {
|
||||||
|
decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder()
|
||||||
|
res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode utf32 binary bytes, error: %w", err)
|
||||||
|
}
|
||||||
|
str += string(res)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// convert the 4 bytes to a unicode and encode to utf8
|
||||||
|
// Golang strongly opposes utf32 coding, this kind of encoding has been excluded from standard lib
|
||||||
|
var x uint32
|
||||||
|
if order == binary.LittleEndian {
|
||||||
|
x = uint32(src[3])<<24 | uint32(src[2])<<16 | uint32(src[1])<<8 | uint32(src[0])
|
||||||
|
} else {
|
||||||
|
x = uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
|
||||||
|
}
|
||||||
|
r := rune(x)
|
||||||
|
utf8Code := make([]byte, 4)
|
||||||
|
utf8.EncodeRune(utf8Code, r)
|
||||||
|
if r == utf8.RuneError {
|
||||||
|
return "", fmt.Errorf("failed to convert 4 bytes unicode %d to utf8 rune", x)
|
||||||
|
}
|
||||||
|
str += string(utf8Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
src = src[4:]
|
||||||
|
}
|
||||||
|
return str, nil
|
||||||
|
}
|
||||||
|
|
|
@ -20,6 +20,8 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
|
@ -553,9 +555,62 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
||||||
assert.Nil(t, res)
|
assert.Nil(t, res)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("test read ascii characters", func(t *testing.T) {
|
t.Run("test read ascii characters with ansi", func(t *testing.T) {
|
||||||
|
npyReader := &npy.Reader{
|
||||||
|
Header: npy.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, 0)
|
||||||
|
values := []string{"ab", "ccc", "d"}
|
||||||
|
maxLen := 0
|
||||||
|
for _, str := range values {
|
||||||
|
if len(str) > maxLen {
|
||||||
|
maxLen = len(str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, str := range values {
|
||||||
|
for i := 0; i < maxLen; i++ {
|
||||||
|
if i < len(str) {
|
||||||
|
data = append(data, str[i])
|
||||||
|
} else {
|
||||||
|
data = append(data, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
npyReader.Header.Descr.Shape = append(npyReader.Header.Descr.Shape, len(values))
|
||||||
|
|
||||||
|
adapter := &NumpyAdapter{
|
||||||
|
reader: strings.NewReader(string(data)),
|
||||||
|
npyReader: npyReader,
|
||||||
|
readPosition: 0,
|
||||||
|
dataType: schemapb.DataType_VarChar,
|
||||||
|
}
|
||||||
|
|
||||||
|
// count should greater than 0
|
||||||
|
res, err := adapter.ReadString(0)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, res)
|
||||||
|
|
||||||
|
// maxLen is zero
|
||||||
|
npyReader.Header.Descr.Type = "S0"
|
||||||
|
res, err = adapter.ReadString(1)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, res)
|
||||||
|
|
||||||
|
npyReader.Header.Descr.Type = "S" + strconv.FormatInt(int64(maxLen), 10)
|
||||||
|
|
||||||
|
res, err = adapter.ReadString(len(values) + 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, len(values), len(res))
|
||||||
|
for i := 0; i < len(res); i++ {
|
||||||
|
assert.Equal(t, values[i], res[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test read ascii characters with utf32", func(t *testing.T) {
|
||||||
filePath := TempFilesPath + "varchar1.npy"
|
filePath := TempFilesPath + "varchar1.npy"
|
||||||
data := []string{"a", "bbb", "c", "dd", "eeee", "fff"}
|
data := []string{"a ", "bbb", " c", "dd", "eeee", "fff"}
|
||||||
err := CreateNumpyFile(filePath, data)
|
err := CreateNumpyFile(filePath, data)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
@ -583,9 +638,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
||||||
assert.Nil(t, res)
|
assert.Nil(t, res)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("test read non-ascii", func(t *testing.T) {
|
t.Run("test read non-ascii characters with utf32", func(t *testing.T) {
|
||||||
filePath := TempFilesPath + "varchar2.npy"
|
filePath := TempFilesPath + "varchar2.npy"
|
||||||
data := []string{"a三百", "马克bbb"}
|
data := []string{"で と ど ", " 马克bbb", "$(한)삼각*"}
|
||||||
err := CreateNumpyFile(filePath, data)
|
err := CreateNumpyFile(filePath, data)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
@ -596,7 +651,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
||||||
adapter, err := NewNumpyAdapter(file)
|
adapter, err := NewNumpyAdapter(file)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
res, err := adapter.ReadString(len(data))
|
res, err := adapter.ReadString(len(data))
|
||||||
assert.NotNil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Nil(t, res)
|
assert.Equal(t, len(data), len(res))
|
||||||
|
|
||||||
|
for i := 0; i < len(res); i++ {
|
||||||
|
assert.Equal(t, data[i], res[i])
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_DecodeUtf32(t *testing.T) {
|
||||||
|
// wrong input
|
||||||
|
res, err := decodeUtf32([]byte{1, 2}, binary.LittleEndian)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Empty(t, res)
|
||||||
|
|
||||||
|
// this string contains ascii characters and unicode characters
|
||||||
|
str := "ad◤三百🎵ゐ↙"
|
||||||
|
|
||||||
|
// utf32 littleEndian of str
|
||||||
|
src := []byte{97, 0, 0, 0, 100, 0, 0, 0, 228, 37, 0, 0, 9, 78, 0, 0, 126, 118, 0, 0, 181, 243, 1, 0, 144, 48, 0, 0, 153, 33, 0, 0}
|
||||||
|
res, err = decodeUtf32(src, binary.LittleEndian)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, str, res)
|
||||||
|
|
||||||
|
// utf32 bigEndian of str
|
||||||
|
src = []byte{0, 0, 0, 97, 0, 0, 0, 100, 0, 0, 37, 228, 0, 0, 78, 9, 0, 0, 118, 126, 0, 1, 243, 181, 0, 0, 48, 144, 0, 0, 33, 153}
|
||||||
|
res, err = decodeUtf32(src, binary.BigEndian)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, str, res)
|
||||||
|
}
|
||||||
|
|
|
@ -152,8 +152,8 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
||||||
if elementType != schema.DataType {
|
if elementType != schema.DataType {
|
||||||
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
|
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
|
||||||
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
|
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
|
||||||
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %d",
|
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %s",
|
||||||
getTypeName(elementType), schema.GetName(), schema.DataType)
|
getTypeName(elementType), schema.GetName(), getTypeName(schema.DataType))
|
||||||
}
|
}
|
||||||
|
|
||||||
// scalar field, the shape should be 1
|
// scalar field, the shape should be 1
|
||||||
|
|
Loading…
Reference in New Issue