mirror of https://github.com/milvus-io/milvus.git
Parse utf32 string of numpy file (#20176)
Signed-off-by: groot <yihua.mo@zilliz.com> Signed-off-by: groot <yihua.mo@zilliz.com>pull/20211/head
parent
e9cd2cb42a
commit
c6151ad351
|
@ -267,7 +267,7 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
|
|||
// TODO add context
|
||||
size, err := p.chunkManager.Size(context.TODO(), filePath)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Any("err", err))
|
||||
log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Error(err))
|
||||
return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath)
|
||||
}
|
||||
|
||||
|
@ -334,7 +334,7 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
|
|||
if fileType == JSONFileExt {
|
||||
err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to parse row-based json file", zap.Any("err", err), zap.String("filePath", filePath))
|
||||
log.Error("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath))
|
||||
return err
|
||||
}
|
||||
} // no need to check else, since the fileValidation() already do this
|
||||
|
@ -407,7 +407,7 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
|
|||
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
|
||||
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to parse column-based numpy file", zap.Any("err", err), zap.String("filePath", filePath))
|
||||
log.Error("import wrapper: failed to parse column-based numpy file", zap.Error(err), zap.String("filePath", filePath))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -732,7 +732,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
// generate auto id for primary key and rowid field
|
||||
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to alloc row ID", zap.Any("err", err))
|
||||
log.Error("import wrapper: failed to alloc row ID", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -869,7 +869,7 @@ func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData,
|
|||
// create a new segment
|
||||
segID, channelName, err := p.assignSegmentFunc(shardID)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to assign a new segment", zap.Any("error", err), zap.Int("shardID", shardID))
|
||||
log.Error("import wrapper: failed to assign a new segment", zap.Error(err), zap.Int("shardID", shardID))
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -888,7 +888,7 @@ func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData,
|
|||
// save binlogs
|
||||
fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to save binlogs", zap.Any("error", err), zap.Int("shardID", shardID),
|
||||
log.Error("import wrapper: failed to save binlogs", zap.Error(err), zap.Int("shardID", shardID),
|
||||
zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
|
||||
return err
|
||||
}
|
||||
|
@ -914,7 +914,7 @@ func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error {
|
|||
err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to save segment",
|
||||
zap.Any("error", err),
|
||||
zap.Error(err),
|
||||
zap.Int("shardID", segment.shardID),
|
||||
zap.Int64("segmentID", segment.segmentID),
|
||||
zap.String("targetChannel", segment.targetChName))
|
||||
|
|
|
@ -113,7 +113,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
// read the key
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
||||
log.Error("JSON parser: read json token error", zap.Error(err))
|
||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||
}
|
||||
key := t.(string)
|
||||
|
@ -128,7 +128,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
// started by '['
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
||||
log.Error("JSON parser: read json token error", zap.Error(err))
|
||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -142,7 +142,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
for dec.More() {
|
||||
var value interface{}
|
||||
if err := dec.Decode(&value); err != nil {
|
||||
log.Error("JSON parser: decode json value error", zap.Any("err", err))
|
||||
log.Error("JSON parser: decode json value error", zap.Error(err))
|
||||
return fmt.Errorf("JSON parser: decode json value error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -170,7 +170,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
if len(buf) >= int(p.bufSize) {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
log.Error("JSON parser: parse values error", zap.Any("err", err))
|
||||
log.Error("JSON parser: parse values error", zap.Error(err))
|
||||
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -183,7 +183,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
if len(buf) > 0 {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
log.Error("JSON parser: parse values error", zap.Any("err", err))
|
||||
log.Error("JSON parser: parse values error", zap.Error(err))
|
||||
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
// end by ']'
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
||||
log.Error("JSON parser: read json token error", zap.Error(err))
|
||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
"github.com/sbinet/npyio"
|
||||
"github.com/sbinet/npyio/npy"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -135,7 +136,7 @@ func convertNumpyType(typeStr string) (schemapb.DataType, error) {
|
|||
if isStringType(typeStr) {
|
||||
return schemapb.DataType_VarChar, nil
|
||||
}
|
||||
log.Error("Numpy adapter: the numpy file data type not supported", zap.String("dataType", typeStr))
|
||||
log.Error("Numpy adapter: the numpy file data type is not supported", zap.String("dataType", typeStr))
|
||||
return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr)
|
||||
}
|
||||
}
|
||||
|
@ -496,7 +497,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
|||
// varchar length, this is the max length, some item is shorter than this length, but they also occupy bytes of max length
|
||||
maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type)
|
||||
if err != nil || maxLen <= 0 {
|
||||
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Any("err", err))
|
||||
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Error(err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
|
||||
}
|
||||
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
|
||||
|
@ -511,45 +512,33 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
|||
data := make([]string, 0)
|
||||
for i := 0; i < readSize; i++ {
|
||||
if utf {
|
||||
// in the numpy file, each utf8 character occupy utf8.UTFMax bytes, each string occupys utf8.UTFMax*maxLen bytes
|
||||
// for example, an ANSI character "a" only uses one byte, but it still occupy utf8.UTFMax bytes
|
||||
// a chinese character uses three bytes, it also occupy utf8.UTFMax bytes
|
||||
// in the numpy file with utf32 encoding, the dType could be like "<U2",
|
||||
// "<" is byteorder(LittleEndian), "U" means it is utf32 encoding, "2" means the max length of strings is 2(characters)
|
||||
// each character occupy 4 bytes, each string occupys 4*maxLen bytes
|
||||
// for example, a numpy file has two strings: "a" and "bb", the maxLen is 2, byte order is LittleEndian
|
||||
// the character "a" occupys 2*4=8 bytes(0x97,0x00,0x00,0x00,0x00,0x00,0x00,0x00),
|
||||
// the "bb" occupys 8 bytes(0x97,0x00,0x00,0x00,0x98,0x00,0x00,0x00)
|
||||
// for non-ascii characters, the unicode could be 1 ~ 4 bytes, each character occupys 4 bytes, too
|
||||
raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)))
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read utf8 string from numpy file", zap.Int("i", i), zap.Any("err", err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read utf8 string from numpy file, error: %w", err)
|
||||
log.Error("Numpy adapter: failed to read utf32 bytes from numpy file", zap.Int("i", i), zap.Error(err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read utf32 bytes from numpy file, error: %w", err)
|
||||
}
|
||||
|
||||
var str string
|
||||
for len(raw) > 0 {
|
||||
r, _ := utf8.DecodeRune(raw)
|
||||
if r == utf8.RuneError {
|
||||
log.Error("Numpy adapter: failed to decode utf8 string from numpy file", zap.Any("raw", raw[:utf8.UTFMax]))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to decode utf8 string from numpy file, error: illegal utf-8 encoding")
|
||||
}
|
||||
|
||||
// only support ascii characters, because the numpy lib encode the utf8 bytes by its internal method,
|
||||
// the encode/decode logic is not clear now, return error
|
||||
n := n.order.Uint32(raw)
|
||||
if n > 127 {
|
||||
log.Error("Numpy adapter: a string contains non-ascii characters, not support yet", zap.Int32("utf8Code", r))
|
||||
return nil, fmt.Errorf("Numpy adapter: a string contains non-ascii characters, not support yet")
|
||||
}
|
||||
|
||||
// if a string is shorter than maxLen, the tail characters will be filled with "\u0000"(in utf spec this is Null)
|
||||
if r > 0 {
|
||||
str += string(r)
|
||||
}
|
||||
|
||||
raw = raw[utf8.UTFMax:]
|
||||
str, err := decodeUtf32(raw, n.order)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed todecode utf32 bytes", zap.Int("i", i), zap.Error(err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to decode utf32 bytes, error: %w", err)
|
||||
}
|
||||
|
||||
data = append(data, str)
|
||||
} else {
|
||||
// in the numpy file with ansi encoding, the dType could be like "S2", maxLen is 2, each string occupys 2 bytes
|
||||
// bytes.Index(buf, []byte{0}) tell us which position is the end of the string
|
||||
buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen)))
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read string from numpy file", zap.Int("i", i), zap.Any("err", err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read string from numpy file, error: %w", err)
|
||||
log.Error("Numpy adapter: failed to read ascii bytes from numpy file", zap.Int("i", i), zap.Error(err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read ascii bytes from numpy file, error: %w", err)
|
||||
}
|
||||
n := bytes.Index(buf, []byte{0})
|
||||
if n > 0 {
|
||||
|
@ -564,3 +553,61 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
|||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
|
||||
if len(src)%4 != 0 {
|
||||
return "", fmt.Errorf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src))
|
||||
}
|
||||
|
||||
var str string
|
||||
for len(src) > 0 {
|
||||
// check the high bytes, if high bytes are 0, the UNICODE is less than U+FFFF, we can use unicode.UTF16 to decode
|
||||
isUtf16 := false
|
||||
var lowbytesPosition int
|
||||
uOrder := unicode.LittleEndian
|
||||
if order == binary.LittleEndian {
|
||||
if src[2] == 0 && src[3] == 0 {
|
||||
isUtf16 = true
|
||||
}
|
||||
lowbytesPosition = 0
|
||||
} else {
|
||||
if src[0] == 0 && src[1] == 0 {
|
||||
isUtf16 = true
|
||||
}
|
||||
lowbytesPosition = 2
|
||||
uOrder = unicode.BigEndian
|
||||
}
|
||||
|
||||
if isUtf16 {
|
||||
// use unicode.UTF16 to decode the low bytes to utf8
|
||||
// utf32 and utf16 is same if the unicode code is less than 65535
|
||||
if src[lowbytesPosition] != 0 || src[lowbytesPosition+1] != 0 {
|
||||
decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder()
|
||||
res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode utf32 binary bytes, error: %w", err)
|
||||
}
|
||||
str += string(res)
|
||||
}
|
||||
} else {
|
||||
// convert the 4 bytes to a unicode and encode to utf8
|
||||
// Golang strongly opposes utf32 coding, this kind of encoding has been excluded from standard lib
|
||||
var x uint32
|
||||
if order == binary.LittleEndian {
|
||||
x = uint32(src[3])<<24 | uint32(src[2])<<16 | uint32(src[1])<<8 | uint32(src[0])
|
||||
} else {
|
||||
x = uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
|
||||
}
|
||||
r := rune(x)
|
||||
utf8Code := make([]byte, 4)
|
||||
utf8.EncodeRune(utf8Code, r)
|
||||
if r == utf8.RuneError {
|
||||
return "", fmt.Errorf("failed to convert 4 bytes unicode %d to utf8 rune", x)
|
||||
}
|
||||
str += string(utf8Code)
|
||||
}
|
||||
|
||||
src = src[4:]
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
@ -553,9 +555,62 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
t.Run("test read ascii characters", func(t *testing.T) {
|
||||
t.Run("test read ascii characters with ansi", func(t *testing.T) {
|
||||
npyReader := &npy.Reader{
|
||||
Header: npy.Header{},
|
||||
}
|
||||
|
||||
data := make([]byte, 0)
|
||||
values := []string{"ab", "ccc", "d"}
|
||||
maxLen := 0
|
||||
for _, str := range values {
|
||||
if len(str) > maxLen {
|
||||
maxLen = len(str)
|
||||
}
|
||||
}
|
||||
for _, str := range values {
|
||||
for i := 0; i < maxLen; i++ {
|
||||
if i < len(str) {
|
||||
data = append(data, str[i])
|
||||
} else {
|
||||
data = append(data, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
npyReader.Header.Descr.Shape = append(npyReader.Header.Descr.Shape, len(values))
|
||||
|
||||
adapter := &NumpyAdapter{
|
||||
reader: strings.NewReader(string(data)),
|
||||
npyReader: npyReader,
|
||||
readPosition: 0,
|
||||
dataType: schemapb.DataType_VarChar,
|
||||
}
|
||||
|
||||
// count should greater than 0
|
||||
res, err := adapter.ReadString(0)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
// maxLen is zero
|
||||
npyReader.Header.Descr.Type = "S0"
|
||||
res, err = adapter.ReadString(1)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
npyReader.Header.Descr.Type = "S" + strconv.FormatInt(int64(maxLen), 10)
|
||||
|
||||
res, err = adapter.ReadString(len(values) + 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(values), len(res))
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, values[i], res[i])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test read ascii characters with utf32", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "varchar1.npy"
|
||||
data := []string{"a", "bbb", "c", "dd", "eeee", "fff"}
|
||||
data := []string{"a ", "bbb", " c", "dd", "eeee", "fff"}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -583,9 +638,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
t.Run("test read non-ascii", func(t *testing.T) {
|
||||
t.Run("test read non-ascii characters with utf32", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "varchar2.npy"
|
||||
data := []string{"a三百", "马克bbb"}
|
||||
data := []string{"で と ど ", " 马克bbb", "$(한)삼각*"}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -596,7 +651,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
res, err := adapter.ReadString(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(data), len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_DecodeUtf32(t *testing.T) {
|
||||
// wrong input
|
||||
res, err := decodeUtf32([]byte{1, 2}, binary.LittleEndian)
|
||||
assert.NotNil(t, err)
|
||||
assert.Empty(t, res)
|
||||
|
||||
// this string contains ascii characters and unicode characters
|
||||
str := "ad◤三百🎵ゐ↙"
|
||||
|
||||
// utf32 littleEndian of str
|
||||
src := []byte{97, 0, 0, 0, 100, 0, 0, 0, 228, 37, 0, 0, 9, 78, 0, 0, 126, 118, 0, 0, 181, 243, 1, 0, 144, 48, 0, 0, 153, 33, 0, 0}
|
||||
res, err = decodeUtf32(src, binary.LittleEndian)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str, res)
|
||||
|
||||
// utf32 bigEndian of str
|
||||
src = []byte{0, 0, 0, 97, 0, 0, 0, 100, 0, 0, 37, 228, 0, 0, 78, 9, 0, 0, 118, 126, 0, 1, 243, 181, 0, 0, 48, 144, 0, 0, 33, 153}
|
||||
res, err = decodeUtf32(src, binary.BigEndian)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str, res)
|
||||
}
|
||||
|
|
|
@ -152,8 +152,8 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
|||
if elementType != schema.DataType {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
|
||||
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
|
||||
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %d",
|
||||
getTypeName(elementType), schema.GetName(), schema.DataType)
|
||||
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %s",
|
||||
getTypeName(elementType), schema.GetName(), getTypeName(schema.DataType))
|
||||
}
|
||||
|
||||
// scalar field, the shape should be 1
|
||||
|
|
Loading…
Reference in New Issue