From c6151ad3515c75619f1c691af950420182bdc4d9 Mon Sep 17 00:00:00 2001 From: groot Date: Mon, 31 Oct 2022 17:39:34 +0800 Subject: [PATCH] Parse utf32 string of numpy file (#20176) Signed-off-by: groot Signed-off-by: groot --- internal/util/importutil/import_wrapper.go | 14 +-- internal/util/importutil/json_parser.go | 12 +- internal/util/importutil/numpy_adapter.go | 109 +++++++++++++----- .../util/importutil/numpy_adapter_test.go | 93 ++++++++++++++- internal/util/importutil/numpy_parser.go | 4 +- 5 files changed, 180 insertions(+), 52 deletions(-) diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index e98d2a860e..01c94c30fa 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -267,7 +267,7 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { // TODO add context size, err := p.chunkManager.Size(context.TODO(), filePath) if err != nil { - log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Any("err", err)) + log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Error(err)) return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath) } @@ -334,7 +334,7 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error if fileType == JSONFileExt { err = p.parseRowBasedJSON(filePath, options.OnlyValidate) if err != nil { - log.Error("import wrapper: failed to parse row-based json file", zap.Any("err", err), zap.String("filePath", filePath)) + log.Error("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath)) return err } } // no need to check else, since the fileValidation() already do this @@ -407,7 +407,7 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc) if err != nil { - log.Error("import wrapper: failed to parse column-based numpy file", zap.Any("err", err), zap.String("filePath", filePath)) + log.Error("import wrapper: failed to parse column-based numpy file", zap.Error(err), zap.String("filePath", filePath)) return err } } @@ -732,7 +732,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F // generate auto id for primary key and rowid field rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount)) if err != nil { - log.Error("import wrapper: failed to alloc row ID", zap.Any("err", err)) + log.Error("import wrapper: failed to alloc row ID", zap.Error(err)) return err } @@ -869,7 +869,7 @@ func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, // create a new segment segID, channelName, err := p.assignSegmentFunc(shardID) if err != nil { - log.Error("import wrapper: failed to assign a new segment", zap.Any("error", err), zap.Int("shardID", shardID)) + log.Error("import wrapper: failed to assign a new segment", zap.Error(err), zap.Int("shardID", shardID)) return err } @@ -888,7 +888,7 @@ func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, // save binlogs fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID) if err != nil { - log.Error("import wrapper: failed to save binlogs", zap.Any("error", err), zap.Int("shardID", shardID), + log.Error("import wrapper: failed to save binlogs", zap.Error(err), zap.Int("shardID", shardID), zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName)) return err } @@ -914,7 +914,7 @@ func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error { err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount) if err != nil { log.Error("import wrapper: failed to save segment", - zap.Any("error", err), + zap.Error(err), zap.Int("shardID", segment.shardID), zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName)) diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index f9a60022bc..390a2d3e90 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -113,7 +113,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { // read the key t, err := dec.Token() if err != nil { - log.Error("JSON parser: read json token error", zap.Any("err", err)) + log.Error("JSON parser: read json token error", zap.Error(err)) return fmt.Errorf("JSON parser: read json token error: %v", err) } key := t.(string) @@ -128,7 +128,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { // started by '[' t, err = dec.Token() if err != nil { - log.Error("JSON parser: read json token error", zap.Any("err", err)) + log.Error("JSON parser: read json token error", zap.Error(err)) return fmt.Errorf("JSON parser: read json token error: %v", err) } @@ -142,7 +142,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { for dec.More() { var value interface{} if err := dec.Decode(&value); err != nil { - log.Error("JSON parser: decode json value error", zap.Any("err", err)) + log.Error("JSON parser: decode json value error", zap.Error(err)) return fmt.Errorf("JSON parser: decode json value error: %v", err) } @@ -170,7 +170,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { if len(buf) >= int(p.bufSize) { isEmpty = false if err = handler.Handle(buf); err != nil { - log.Error("JSON parser: parse values error", zap.Any("err", err)) + log.Error("JSON parser: parse values error", zap.Error(err)) return fmt.Errorf("JSON parser: parse values error: %v", err) } @@ -183,7 +183,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { if len(buf) > 0 { isEmpty = false if err = handler.Handle(buf); err != nil { - log.Error("JSON parser: parse values error", zap.Any("err", err)) + log.Error("JSON parser: parse values error", zap.Error(err)) return fmt.Errorf("JSON parser: parse values error: %v", err) } } @@ -191,7 +191,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { // end by ']' t, err = dec.Token() if err != nil { - log.Error("JSON parser: read json token error", zap.Any("err", err)) + log.Error("JSON parser: read json token error", zap.Error(err)) return fmt.Errorf("JSON parser: read json token error: %v", err) } diff --git a/internal/util/importutil/numpy_adapter.go b/internal/util/importutil/numpy_adapter.go index 27bd4dee9e..5d7740153e 100644 --- a/internal/util/importutil/numpy_adapter.go +++ b/internal/util/importutil/numpy_adapter.go @@ -34,6 +34,7 @@ import ( "github.com/sbinet/npyio" "github.com/sbinet/npyio/npy" "go.uber.org/zap" + "golang.org/x/text/encoding/unicode" ) var ( @@ -135,7 +136,7 @@ func convertNumpyType(typeStr string) (schemapb.DataType, error) { if isStringType(typeStr) { return schemapb.DataType_VarChar, nil } - log.Error("Numpy adapter: the numpy file data type not supported", zap.String("dataType", typeStr)) + log.Error("Numpy adapter: the numpy file data type is not supported", zap.String("dataType", typeStr)) return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr) } } @@ -496,7 +497,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) { // varchar length, this is the max length, some item is shorter than this length, but they also occupy bytes of max length maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type) if err != nil || maxLen <= 0 { - log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Any("err", err)) + log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Error(err)) return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err) } log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf)) @@ -511,45 +512,33 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) { data := make([]string, 0) for i := 0; i < readSize; i++ { if utf { - // in the numpy file, each utf8 character occupy utf8.UTFMax bytes, each string occupys utf8.UTFMax*maxLen bytes - // for example, an ANSI character "a" only uses one byte, but it still occupy utf8.UTFMax bytes - // a chinese character uses three bytes, it also occupy utf8.UTFMax bytes + // in the numpy file with utf32 encoding, the dType could be like " 0 { - r, _ := utf8.DecodeRune(raw) - if r == utf8.RuneError { - log.Error("Numpy adapter: failed to decode utf8 string from numpy file", zap.Any("raw", raw[:utf8.UTFMax])) - return nil, fmt.Errorf("Numpy adapter: failed to decode utf8 string from numpy file, error: illegal utf-8 encoding") - } - - // only support ascii characters, because the numpy lib encode the utf8 bytes by its internal method, - // the encode/decode logic is not clear now, return error - n := n.order.Uint32(raw) - if n > 127 { - log.Error("Numpy adapter: a string contains non-ascii characters, not support yet", zap.Int32("utf8Code", r)) - return nil, fmt.Errorf("Numpy adapter: a string contains non-ascii characters, not support yet") - } - - // if a string is shorter than maxLen, the tail characters will be filled with "\u0000"(in utf spec this is Null) - if r > 0 { - str += string(r) - } - - raw = raw[utf8.UTFMax:] + str, err := decodeUtf32(raw, n.order) + if err != nil { + log.Error("Numpy adapter: failed todecode utf32 bytes", zap.Int("i", i), zap.Error(err)) + return nil, fmt.Errorf("Numpy adapter: failed to decode utf32 bytes, error: %w", err) } data = append(data, str) } else { + // in the numpy file with ansi encoding, the dType could be like "S2", maxLen is 2, each string occupys 2 bytes + // bytes.Index(buf, []byte{0}) tell us which position is the end of the string buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen))) if err != nil { - log.Error("Numpy adapter: failed to read string from numpy file", zap.Int("i", i), zap.Any("err", err)) - return nil, fmt.Errorf("Numpy adapter: failed to read string from numpy file, error: %w", err) + log.Error("Numpy adapter: failed to read ascii bytes from numpy file", zap.Int("i", i), zap.Error(err)) + return nil, fmt.Errorf("Numpy adapter: failed to read ascii bytes from numpy file, error: %w", err) } n := bytes.Index(buf, []byte{0}) if n > 0 { @@ -564,3 +553,61 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) { return data, nil } + +func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) { + if len(src)%4 != 0 { + return "", fmt.Errorf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src)) + } + + var str string + for len(src) > 0 { + // check the high bytes, if high bytes are 0, the UNICODE is less than U+FFFF, we can use unicode.UTF16 to decode + isUtf16 := false + var lowbytesPosition int + uOrder := unicode.LittleEndian + if order == binary.LittleEndian { + if src[2] == 0 && src[3] == 0 { + isUtf16 = true + } + lowbytesPosition = 0 + } else { + if src[0] == 0 && src[1] == 0 { + isUtf16 = true + } + lowbytesPosition = 2 + uOrder = unicode.BigEndian + } + + if isUtf16 { + // use unicode.UTF16 to decode the low bytes to utf8 + // utf32 and utf16 is same if the unicode code is less than 65535 + if src[lowbytesPosition] != 0 || src[lowbytesPosition+1] != 0 { + decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder() + res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2]) + if err != nil { + return "", fmt.Errorf("failed to decode utf32 binary bytes, error: %w", err) + } + str += string(res) + } + } else { + // convert the 4 bytes to a unicode and encode to utf8 + // Golang strongly opposes utf32 coding, this kind of encoding has been excluded from standard lib + var x uint32 + if order == binary.LittleEndian { + x = uint32(src[3])<<24 | uint32(src[2])<<16 | uint32(src[1])<<8 | uint32(src[0]) + } else { + x = uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3]) + } + r := rune(x) + utf8Code := make([]byte, 4) + utf8.EncodeRune(utf8Code, r) + if r == utf8.RuneError { + return "", fmt.Errorf("failed to convert 4 bytes unicode %d to utf8 rune", x) + } + str += string(utf8Code) + } + + src = src[4:] + } + return str, nil +} diff --git a/internal/util/importutil/numpy_adapter_test.go b/internal/util/importutil/numpy_adapter_test.go index 5b3804711a..f352acdd92 100644 --- a/internal/util/importutil/numpy_adapter_test.go +++ b/internal/util/importutil/numpy_adapter_test.go @@ -20,6 +20,8 @@ import ( "encoding/binary" "io" "os" + "strconv" + "strings" "testing" "github.com/milvus-io/milvus-proto/go-api/schemapb" @@ -553,9 +555,62 @@ func Test_NumpyAdapterRead(t *testing.T) { assert.Nil(t, res) }) - t.Run("test read ascii characters", func(t *testing.T) { + t.Run("test read ascii characters with ansi", func(t *testing.T) { + npyReader := &npy.Reader{ + Header: npy.Header{}, + } + + data := make([]byte, 0) + values := []string{"ab", "ccc", "d"} + maxLen := 0 + for _, str := range values { + if len(str) > maxLen { + maxLen = len(str) + } + } + for _, str := range values { + for i := 0; i < maxLen; i++ { + if i < len(str) { + data = append(data, str[i]) + } else { + data = append(data, 0) + } + } + } + + npyReader.Header.Descr.Shape = append(npyReader.Header.Descr.Shape, len(values)) + + adapter := &NumpyAdapter{ + reader: strings.NewReader(string(data)), + npyReader: npyReader, + readPosition: 0, + dataType: schemapb.DataType_VarChar, + } + + // count should greater than 0 + res, err := adapter.ReadString(0) + assert.NotNil(t, err) + assert.Nil(t, res) + + // maxLen is zero + npyReader.Header.Descr.Type = "S0" + res, err = adapter.ReadString(1) + assert.NotNil(t, err) + assert.Nil(t, res) + + npyReader.Header.Descr.Type = "S" + strconv.FormatInt(int64(maxLen), 10) + + res, err = adapter.ReadString(len(values) + 1) + assert.Nil(t, err) + assert.Equal(t, len(values), len(res)) + for i := 0; i < len(res); i++ { + assert.Equal(t, values[i], res[i]) + } + }) + + t.Run("test read ascii characters with utf32", func(t *testing.T) { filePath := TempFilesPath + "varchar1.npy" - data := []string{"a", "bbb", "c", "dd", "eeee", "fff"} + data := []string{"a ", "bbb", " c", "dd", "eeee", "fff"} err := CreateNumpyFile(filePath, data) assert.Nil(t, err) @@ -583,9 +638,9 @@ func Test_NumpyAdapterRead(t *testing.T) { assert.Nil(t, res) }) - t.Run("test read non-ascii", func(t *testing.T) { + t.Run("test read non-ascii characters with utf32", func(t *testing.T) { filePath := TempFilesPath + "varchar2.npy" - data := []string{"a三百", "马克bbb"} + data := []string{"で と ど ", " 马克bbb", "$(한)삼각*"} err := CreateNumpyFile(filePath, data) assert.Nil(t, err) @@ -596,7 +651,33 @@ func Test_NumpyAdapterRead(t *testing.T) { adapter, err := NewNumpyAdapter(file) assert.Nil(t, err) res, err := adapter.ReadString(len(data)) - assert.NotNil(t, err) - assert.Nil(t, res) + assert.Nil(t, err) + assert.Equal(t, len(data), len(res)) + + for i := 0; i < len(res); i++ { + assert.Equal(t, data[i], res[i]) + } }) } + +func Test_DecodeUtf32(t *testing.T) { + // wrong input + res, err := decodeUtf32([]byte{1, 2}, binary.LittleEndian) + assert.NotNil(t, err) + assert.Empty(t, res) + + // this string contains ascii characters and unicode characters + str := "ad◤三百🎵ゐ↙" + + // utf32 littleEndian of str + src := []byte{97, 0, 0, 0, 100, 0, 0, 0, 228, 37, 0, 0, 9, 78, 0, 0, 126, 118, 0, 0, 181, 243, 1, 0, 144, 48, 0, 0, 153, 33, 0, 0} + res, err = decodeUtf32(src, binary.LittleEndian) + assert.Nil(t, err) + assert.Equal(t, str, res) + + // utf32 bigEndian of str + src = []byte{0, 0, 0, 97, 0, 0, 0, 100, 0, 0, 37, 228, 0, 0, 78, 9, 0, 0, 118, 126, 0, 1, 243, 181, 0, 0, 48, 144, 0, 0, 33, 153} + res, err = decodeUtf32(src, binary.BigEndian) + assert.Nil(t, err) + assert.Equal(t, str, res) +} diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go index 55cd334ad0..90ce1c4793 100644 --- a/internal/util/importutil/numpy_parser.go +++ b/internal/util/importutil/numpy_parser.go @@ -152,8 +152,8 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { if elementType != schema.DataType { log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType), zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType)) - return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %d", - getTypeName(elementType), schema.GetName(), schema.DataType) + return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %s", + getTypeName(elementType), schema.GetName(), getTypeName(schema.DataType)) } // scalar field, the shape should be 1