From ffa06c77b6250e893808d3d9e6a6569703868a86 Mon Sep 17 00:00:00 2001 From: groot Date: Wed, 30 Mar 2022 16:25:30 +0800 Subject: [PATCH] Import util functions (#16237) Signed-off-by: groot --- internal/util/importutil/import_wrapper.go | 401 ++++++++++ .../util/importutil/import_wrapper_test.go | 205 +++++ internal/util/importutil/json_handler.go | 714 ++++++++++++++++++ internal/util/importutil/json_handler_test.go | 397 ++++++++++ internal/util/importutil/json_parser.go | 238 ++++++ internal/util/importutil/json_parser_test.go | 256 +++++++ 6 files changed, 2211 insertions(+) create mode 100644 internal/util/importutil/import_wrapper.go create mode 100644 internal/util/importutil/import_wrapper_test.go create mode 100644 internal/util/importutil/json_handler.go create mode 100644 internal/util/importutil/json_handler_test.go create mode 100644 internal/util/importutil/json_parser.go create mode 100644 internal/util/importutil/json_parser_test.go diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go new file mode 100644 index 0000000000..55985c5492 --- /dev/null +++ b/internal/util/importutil/import_wrapper.go @@ -0,0 +1,401 @@ +package importutil + +import ( + "bufio" + "context" + "errors" + "os" + "path" + "strconv" + + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +const ( + JSONFileExt = ".json" + NumpyFileExt = ".npy" +) + +type ImportWrapper struct { + ctx context.Context // for canceling parse process + cancel context.CancelFunc // for canceling parse process + collectionSchema *schemapb.CollectionSchema // collection schema + shardNum int32 // sharding number of the collection + segmentSize int32 // maximum size of a segment in MB + rowIDAllocator *allocator.IDAllocator // autoid allocator + + callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush a segment +} + +func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int32, + idAlloc *allocator.IDAllocator, flushFunc func(fields map[string]storage.FieldData) error) *ImportWrapper { + if collectionSchema == nil { + log.Error("import error: collection schema is nil") + return nil + } + + // ignore the RowID field and Timestamp field + realSchema := &schemapb.CollectionSchema{ + Name: collectionSchema.GetName(), + Description: collectionSchema.GetDescription(), + AutoID: collectionSchema.GetAutoID(), + Fields: make([]*schemapb.FieldSchema, 0), + } + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + if schema.GetName() == common.RowIDFieldName || schema.GetName() == common.TimeStampFieldName { + continue + } + realSchema.Fields = append(realSchema.Fields, schema) + } + + ctx, cancel := context.WithCancel(ctx) + + wrapper := &ImportWrapper{ + ctx: ctx, + cancel: cancel, + collectionSchema: realSchema, + shardNum: shardNum, + segmentSize: segmentSize, + rowIDAllocator: idAlloc, + callFlushFunc: flushFunc, + } + + return wrapper +} + +// this method can be used to cancel parse process +func (p *ImportWrapper) Cancel() error { + p.cancel() + return nil +} + +func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldData, msg 
string, files []string) { + stats := make([]zapcore.Field, 0) + for k, v := range fieldsData { + stats = append(stats, zap.Int(k, v.RowNum())) + } + for i := 0; i < len(files); i++ { + stats = append(stats, zap.String("file", files[i])) + } + log.Debug(msg, stats...) +} + +// import process entry +// filePath and rowBased are from ImportTask +// if onlyValidate is true, this process only do validation, no data generated, callFlushFunc will not be called +func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate bool) error { + if rowBased { + // parse and consume row-based files + // for row-based files, the JSONRowConsumer will generate autoid for primary key, and split rows into segments + // according to shard number, so the callFlushFunc will be called in the JSONRowConsumer + for i := 0; i < len(filePaths); i++ { + filePath := filePaths[i] + fileName := path.Base(filePath) + fileType := path.Ext(fileName) + log.Debug("imprort wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) + + if fileType == JSONFileExt { + err := func() error { + file, err := os.Open(filePath) + if err != nil { + return err + } + defer file.Close() + + reader := bufio.NewReader(file) + parser := NewJSONParser(p.ctx, p.collectionSchema) + var consumer *JSONRowConsumer + if !onlyValidate { + flushFunc := func(fields map[string]storage.FieldData) error { + p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths) + return p.callFlushFunc(fields) + } + consumer = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc) + } + validator := NewJSONRowValidator(p.collectionSchema, consumer) + + err = parser.ParseRows(reader, validator) + if err != nil { + return err + } + + return nil + }() + + if err != nil { + log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath)) + return err + } + } + } + } else { + // parse and consume row-based files + // for column-based files, the XXXColumnConsumer only output map[string]storage.FieldData + // after all columns are parsed/consumed, we need to combine map[string]storage.FieldData into one + // and use splitFieldsData() to split fields data into segments according to shard number + fieldsData := initSegmentData(p.collectionSchema) + rowCount := 0 + + // function to combine column data into fieldsData + combineFunc := func(fields map[string]storage.FieldData) error { + if len(fields) == 0 { + return nil + } + + fieldNames := make([]string, 0) + for k, v := range fields { + data, ok := fieldsData[k] + if ok && data.RowNum() > 0 { + return errors.New("imprort error: the field " + k + " is duplicated") + } + + fieldsData[k] = v + fieldNames = append(fieldNames, k) + + if rowCount == 0 { + rowCount = v.RowNum() + } else if rowCount != v.RowNum() { + return errors.New("imprort error: the field " + k + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount)) + } + } + + log.Debug("imprort wrapper: ", zap.Any("fieldNames", fieldNames), zap.Int("rowCount", rowCount)) + + return nil + } + + // parse/validate/consume data + for i := 0; i < len(filePaths); i++ { + filePath := filePaths[i] + fileName := path.Base(filePath) + fileType := path.Ext(fileName) + log.Debug("imprort wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) + + if fileType == JSONFileExt { + err := func() error { + file, err := os.Open(filePath) + if err != nil { + log.Error("imprort error: "+err.Error(), 
zap.String("filePath", filePath)) + return err + } + defer file.Close() + + reader := bufio.NewReader(file) + parser := NewJSONParser(p.ctx, p.collectionSchema) + var consumer *JSONColumnConsumer + if !onlyValidate { + consumer = NewJSONColumnConsumer(p.collectionSchema, combineFunc) + } + validator := NewJSONColumnValidator(p.collectionSchema, consumer) + + err = parser.ParseColumns(reader, validator) + if err != nil { + log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath)) + return err + } + + return nil + }() + + if err != nil { + log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath)) + return err + } + } else if fileType == NumpyFileExt { + + } + } + + // split fields data into segments + err := p.splitFieldsData(fieldsData, filePaths) + if err != nil { + log.Error("imprort error: " + err.Error()) + return err + } + } + + return nil +} + +func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { + switch schema.DataType { + case schemapb.DataType_Bool: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.BoolFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(bool)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Float: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.FloatFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(float32)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Double: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.DoubleFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(float64)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int8: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int8FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int8)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int16: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int16FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int16)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int32: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int32FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int32)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int64: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int64FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int64)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_BinaryVector: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.BinaryVectorFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]byte)...) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_FloatVector: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.FloatVectorFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]float32)...) 
+ arr.NumRows[0]++ + return nil + } + case schemapb.DataType_String: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.StringFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(string)) + return nil + } + default: + return nil + } +} + +func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, files []string) error { + if len(fieldsData) == 0 { + return errors.New("imprort error: fields data is empty") + } + + var primaryKey *schemapb.FieldSchema + for i := 0; i < len(p.collectionSchema.Fields); i++ { + schema := p.collectionSchema.Fields[i] + if schema.GetIsPrimaryKey() { + primaryKey = schema + } else { + _, ok := fieldsData[schema.GetName()] + if !ok { + return errors.New("imprort error: field " + schema.GetName() + " not provided") + } + } + } + if primaryKey == nil { + return errors.New("imprort error: primary key field is not found") + } + + rowCount := 0 + for _, v := range fieldsData { + rowCount = v.RowNum() + break + } + + primaryData, ok := fieldsData[primaryKey.GetName()] + if !ok { + // generate auto id for primary key + if primaryKey.GetAutoID() { + var rowIDBegin typeutil.UniqueID + var rowIDEnd typeutil.UniqueID + rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount)) + + primaryDataArr := primaryData.(*storage.Int64FieldData) + for i := rowIDBegin; i < rowIDEnd; i++ { + primaryDataArr.Data = append(primaryDataArr.Data, rowIDBegin+i) + } + } + } + + if primaryData.RowNum() <= 0 { + return errors.New("imprort error: primary key " + primaryKey.GetName() + " not provided") + } + + // prepare segemnts + segmentsData := make([]map[string]storage.FieldData, 0, p.shardNum) + for i := 0; i < int(p.shardNum); i++ { + segmentData := initSegmentData(p.collectionSchema) + if segmentData == nil { + return nil + } + segmentsData = append(segmentsData, segmentData) + } + + // prepare append functions + appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error) + for i := 0; i < len(p.collectionSchema.Fields); i++ { + schema := p.collectionSchema.Fields[i] + appendFunc := p.appendFunc(schema) + if appendFunc == nil { + return errors.New("imprort error: unsupported field data type") + } + appendFunctions[schema.GetName()] = appendFunc + } + + // split data into segments + for i := 0; i < rowCount; i++ { + id := primaryData.GetRow(i).(int64) + // hash to a shard number + hash, _ := typeutil.Hash32Int64(id) + shard := hash % uint32(p.shardNum) + + for k := 0; k < len(p.collectionSchema.Fields); k++ { + schema := p.collectionSchema.Fields[k] + srcData := fieldsData[schema.GetName()] + targetData := segmentsData[shard][schema.GetName()] + appendFunc := appendFunctions[schema.GetName()] + err := appendFunc(srcData, i, targetData) + if err != nil { + return err + } + } + } + + // call flush function + for i := 0; i < int(p.shardNum); i++ { + segmentData := segmentsData[i] + p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files) + err := p.callFlushFunc(segmentData) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go new file mode 100644 index 0000000000..133b073c52 --- /dev/null +++ b/internal/util/importutil/import_wrapper_test.go @@ -0,0 +1,205 @@ +package importutil + +import ( + "context" + "os" + "testing" + + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/proto/schemapb" + 
"github.com/milvus-io/milvus/internal/storage" + "github.com/stretchr/testify/assert" +) + +const ( + TempFilesPath = "/tmp/milvus_test/import/" +) + +func Test_NewImportWrapper(t *testing.T) { + ctx := context.Background() + wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, nil) + assert.Nil(t, wrapper) + + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: make([]*schemapb.FieldSchema, 0), + } + schema.Fields = append(schema.Fields, sampleSchema().Fields...) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 106, + Name: common.RowIDFieldName, + IsPrimaryKey: true, + AutoID: false, + Description: "int64", + DataType: schemapb.DataType_Int64, + }) + wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, nil) + assert.NotNil(t, wrapper) + + err := wrapper.Cancel() + assert.Nil(t, err) +} + +func saveFile(t *testing.T, filePath string, content []byte) *os.File { + fp, err := os.Create(filePath) + assert.Nil(t, err) + + _, err = fp.Write(content) + assert.Nil(t, err) + + return fp +} + +func Test_ImportRowBased(t *testing.T) { + ctx := context.Background() + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + idAllocator := newIDAllocator(ctx, t) + + content := []byte(`{ + "rows":[ + {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + {"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]}, + {"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]}, + {"field_bool": false, "field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]}, + {"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]} + ] + }`) + + filePath := TempFilesPath + "rows_1.json" + fp1 := saveFile(t, filePath, content) + defer fp1.Close() + + rowCount := 0 + flushFunc := func(fields map[string]storage.FieldData) error { + count := 0 + for _, data := range fields { + assert.Less(t, 0, data.RowNum()) + if count == 0 { + count = data.RowNum() + } else { + assert.Equal(t, count, data.RowNum()) + } + } + rowCount += count + return nil + } + + // success case + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + files := make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, true, false) + assert.Nil(t, err) + assert.Equal(t, 5, rowCount) + + // parse error + content = []byte(`{ + "rows":[ + {"field_bool": true, "field_int8": false, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], 
"field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + ] + }`) + + filePath = TempFilesPath + "rows_2.json" + fp2 := saveFile(t, filePath, content) + defer fp2.Close() + + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + files = make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, true, false) + assert.NotNil(t, err) + + // file doesn't exist + files = make([]string, 0) + files = append(files, "/dummy/dummy.json") + err = wrapper.Import(files, true, false) + assert.NotNil(t, err) + +} + +func Test_ImportColumnBased(t *testing.T) { + ctx := context.Background() + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + idAllocator := newIDAllocator(ctx, t) + + content := []byte(`{ + "field_bool": [true, false, true, true, true], + "field_int8": [10, 11, 12, 13, 14], + "field_int16": [100, 101, 102, 103, 104], + "field_int32": [1000, 1001, 1002, 1003, 1004], + "field_int64": [10000, 10001, 10002, 10003, 10004], + "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], + "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], + "field_string": ["a", "b", "c", "d", "e"], + "field_binary_vector": [ + [254, 1], + [253, 2], + [252, 3], + [251, 4], + [250, 5] + ], + "field_float_vector": [ + [1.1, 1.2, 1.3, 1.4], + [2.1, 2.2, 2.3, 2.4], + [3.1, 3.2, 3.3, 3.4], + [4.1, 4.2, 4.3, 4.4], + [5.1, 5.2, 5.3, 5.4] + ] + }`) + + filePath := TempFilesPath + "columns_1.json" + fp1 := saveFile(t, filePath, content) + defer fp1.Close() + + rowCount := 0 + flushFunc := func(fields map[string]storage.FieldData) error { + count := 0 + for _, data := range fields { + assert.Less(t, 0, data.RowNum()) + if count == 0 { + count = data.RowNum() + } else { + assert.Equal(t, count, data.RowNum()) + } + } + rowCount += count + return nil + } + + // success case + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + files := make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, false, false) + assert.Nil(t, err) + assert.Equal(t, 5, rowCount) + + // parse error + content = []byte(`{ + "field_bool": [true, false, true, true, true] + }`) + + filePath = TempFilesPath + "rows_2.json" + fp2 := saveFile(t, filePath, content) + defer fp2.Close() + + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + files = make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, false, false) + assert.NotNil(t, err) + + // file doesn't exist + files = make([]string, 0) + files = append(files, "/dummy/dummy.json") + err = wrapper.Import(files, false, false) + assert.NotNil(t, err) +} diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go new file mode 100644 index 0000000000..a71e071313 --- /dev/null +++ b/internal/util/importutil/json_handler.go @@ -0,0 +1,714 @@ +package importutil + +import ( + "errors" + "fmt" + "strconv" + + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" +) + +// interface to process rows data +type JSONRowHandler interface { + Handle(rows []map[string]interface{}) error +} + +// interface to process column data +type JSONColumnHandler interface { + Handle(columns map[string][]interface{}) error +} + +// method to get dimension of vecotor field +func 
getFieldDimension(schema *schemapb.FieldSchema) (int, error) { + for _, kvPair := range schema.GetTypeParams() { + key, value := kvPair.GetKey(), kvPair.GetValue() + if key == "dim" { + dim, err := strconv.Atoi(value) + if err != nil { + return 0, errors.New("vector dimension is invalid") + } + return dim, nil + } + } + + return 0, errors.New("vector dimension is not defined") +} + +// field value validator +type Validator struct { + validateFunc func(obj interface{}) error // validate data type function + convertFunc func(obj interface{}, field storage.FieldData) error // convert data function + primaryKey bool // true for primary key + autoID bool // only for primary key field + dimension int // only for vector field +} + +// method to construct valiator functions +func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[string]*Validator) error { + if collectionSchema == nil { + return errors.New("collection schema is nil") + } + + // json decoder parse all the numeric value into float64 + numericValidator := func(obj interface{}) error { + switch obj.(type) { + case float64: + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := "illegal numeric value " + s + return errors.New(msg) + } + + } + + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + + validators[schema.GetName()] = &Validator{} + validators[schema.GetName()].primaryKey = schema.GetIsPrimaryKey() + validators[schema.GetName()].autoID = schema.GetAutoID() + + switch schema.DataType { + case schemapb.DataType_Bool: + validators[schema.GetName()].validateFunc = func(obj interface{}) error { + switch obj.(type) { + case bool: + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := "illegal value " + s + " for bool type field " + schema.GetName() + return errors.New(msg) + } + + } + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := obj.(bool) + field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value) + field.(*storage.BoolFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Float: + validators[schema.GetName()].validateFunc = numericValidator + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := float32(obj.(float64)) + field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value) + field.(*storage.FloatFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Double: + validators[schema.GetName()].validateFunc = numericValidator + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := obj.(float64) + field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) + field.(*storage.DoubleFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int8: + validators[schema.GetName()].validateFunc = numericValidator + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int8(obj.(float64)) + field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value) + field.(*storage.Int8FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int16: + validators[schema.GetName()].validateFunc = numericValidator + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int16(obj.(float64)) + field.(*storage.Int16FieldData).Data = 
append(field.(*storage.Int16FieldData).Data, value) + field.(*storage.Int16FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int32: + validators[schema.GetName()].validateFunc = numericValidator + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int32(obj.(float64)) + field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value) + field.(*storage.Int32FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int64: + validators[schema.GetName()].validateFunc = numericValidator + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int64(obj.(float64)) + field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) + field.(*storage.Int64FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_BinaryVector: + dim, err := getFieldDimension(schema) + if err != nil { + return err + } + validators[schema.GetName()].dimension = dim + + validators[schema.GetName()].validateFunc = func(obj interface{}) error { + switch vt := obj.(type) { + case []interface{}: + if len(vt)*8 != dim { + msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + return errors.New(msg) + } + for i := 0; i < len(vt); i++ { + if e := numericValidator(vt[i]); e != nil { + msg := e.Error() + " for binary vector field " + schema.GetName() + return errors.New(msg) + } + + t := int(vt[i].(float64)) + if t >= 255 || t < 0 { + msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName() + return errors.New(msg) + } + } + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := s + " is not an array for binary vector field " + schema.GetName() + return errors.New(msg) + } + } + + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + arr := obj.([]interface{}) + for i := 0; i < len(arr); i++ { + value := byte(arr[i].(float64)) + field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value) + } + + field.(*storage.BinaryVectorFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_FloatVector: + dim, err := getFieldDimension(schema) + if err != nil { + return err + } + validators[schema.GetName()].dimension = dim + + validators[schema.GetName()].validateFunc = func(obj interface{}) error { + switch vt := obj.(type) { + case []interface{}: + if len(vt) != dim { + msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + return errors.New(msg) + } + for i := 0; i < len(vt); i++ { + if e := numericValidator(vt[i]); e != nil { + msg := e.Error() + " for float vector field " + schema.GetName() + return errors.New(msg) + } + } + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := s + " is not an array for float vector field " + schema.GetName() + return errors.New(msg) + } + } + + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + arr := obj.([]interface{}) + for i := 0; i < len(arr); i++ { + value := float32(arr[i].(float64)) + field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value) + } + field.(*storage.FloatVectorFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_String: + validators[schema.GetName()].validateFunc = func(obj interface{}) error { + switch obj.(type) { + case string: + 
return nil + default: + s := fmt.Sprintf("%v", obj) + msg := s + " is not a string for string type field " + schema.GetName() + return errors.New(msg) + } + } + + validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := obj.(string) + field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value) + field.(*storage.StringFieldData).NumRows[0]++ + return nil + } + default: + return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType))) + } + } + + return nil +} + +// row-based json format validator class +type JSONRowValidator struct { + downstream JSONRowHandler // downstream processor, typically is a JSONRowComsumer + validators map[string]*Validator // validators for each field + rowCounter int64 // how many rows have been validated +} + +func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream JSONRowHandler) *JSONRowValidator { + v := &JSONRowValidator{ + validators: make(map[string]*Validator), + downstream: downstream, + rowCounter: 0, + } + initValidators(collectionSchema, v.validators) + + return v +} + +func (v *JSONRowValidator) ValidateCount() int64 { + return v.rowCounter +} + +func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error { + if v.validators == nil || len(v.validators) == 0 { + return errors.New("JSON row validator is not initialized") + } + + // parse completed + if rows == nil { + log.Debug("JSON row validation finished") + if v.downstream != nil { + return v.downstream.Handle(rows) + } + return nil + } + + for i := 0; i < len(rows); i++ { + row := rows[i] + for name, validator := range v.validators { + if validator.primaryKey && validator.autoID { + // auto-generated primary key, ignore + continue + } + value, ok := row[name] + if !ok { + return errors.New("JSON row validator: field " + name + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10)) + } + + if err := validator.validateFunc(value); err != nil { + return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter, 10)) + } + } + } + + v.rowCounter += int64(len(rows)) + + if v.downstream != nil { + return v.downstream.Handle(rows) + } + + return nil +} + +// column-based json format validator class +type JSONColumnValidator struct { + downstream JSONColumnHandler // downstream processor, typically is a JSONColumnComsumer + validators map[string]*Validator // validators for each field + rowCounter map[string]int64 // row count of each field +} + +func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) *JSONColumnValidator { + v := &JSONColumnValidator{ + validators: make(map[string]*Validator), + downstream: downstream, + rowCounter: make(map[string]int64), + } + initValidators(schema, v.validators) + + for k := range v.validators { + v.rowCounter[k] = 0 + } + + return v +} + +func (v *JSONColumnValidator) ValidateCount() map[string]int64 { + return v.rowCounter +} + +func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error { + if v.validators == nil || len(v.validators) == 0 { + return errors.New("JSON column validator is not initialized") + } + + // parse completed + if columns == nil { + // all columns are parsed? 
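+		// the parser sends nil when parsing is complete: every field's row counter must
+		// match the largest counter seen, otherwise some column provided fewer rows than the others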
+ maxCount := int64(0) + for _, counter := range v.rowCounter { + if counter > maxCount { + maxCount = counter + } + } + for k := range v.validators { + counter, ok := v.rowCounter[k] + if !ok || counter != maxCount { + return errors.New("JSON column validator: the field " + k + " row count is not equal to other fields") + } + } + + log.Debug("JSON column validation finished") + if v.downstream != nil { + return v.downstream.Handle(nil) + } + return nil + } + + for name, values := range columns { + validator, ok := v.validators[name] + if !ok { + // not a valid field name + break + } + + for i := 0; i < len(values); i++ { + if err := validator.validateFunc(values[i]); err != nil { + return errors.New("JSON column validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter[name]+int64(i), 10)) + } + } + + v.rowCounter[name] += int64(len(values)) + } + + if v.downstream != nil { + return v.downstream.Handle(columns) + } + + return nil +} + +// row-based json format consumer class +type JSONRowConsumer struct { + collectionSchema *schemapb.CollectionSchema // collection schema + rowIDAllocator *allocator.IDAllocator // autoid allocator + validators map[string]*Validator // validators for each field + rowCounter int64 // how many rows have been consumed + shardNum int32 // sharding number of the collection + segmentsData []map[string]storage.FieldData // in-memory segments data + segmentSize int32 // maximum size of a segment in MB + primaryKey string // name of primary key + + callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment +} + +func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[string]storage.FieldData { + segmentData := make(map[string]storage.FieldData) + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + switch schema.DataType { + case schemapb.DataType_Bool: + segmentData[schema.GetName()] = &storage.BoolFieldData{ + Data: make([]bool, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Float: + segmentData[schema.GetName()] = &storage.FloatFieldData{ + Data: make([]float32, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Double: + segmentData[schema.GetName()] = &storage.DoubleFieldData{ + Data: make([]float64, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int8: + segmentData[schema.GetName()] = &storage.Int8FieldData{ + Data: make([]int8, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int16: + segmentData[schema.GetName()] = &storage.Int16FieldData{ + Data: make([]int16, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int32: + segmentData[schema.GetName()] = &storage.Int32FieldData{ + Data: make([]int32, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int64: + segmentData[schema.GetName()] = &storage.Int64FieldData{ + Data: make([]int64, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_BinaryVector: + dim, _ := getFieldDimension(schema) + segmentData[schema.GetName()] = &storage.BinaryVectorFieldData{ + Data: make([]byte, 0), + NumRows: []int64{0}, + Dim: dim, + } + case schemapb.DataType_FloatVector: + dim, _ := getFieldDimension(schema) + segmentData[schema.GetName()] = &storage.FloatVectorFieldData{ + Data: make([]float32, 0), + NumRows: []int64{0}, + Dim: dim, + } + case schemapb.DataType_String: + segmentData[schema.GetName()] = &storage.StringFieldData{ + Data: make([]string, 0), + NumRows: []int64{0}, + } + default: + log.Error("JSON row consumer error: unsupported data type", zap.Int("DataType", 
int(schema.DataType))) + return nil + } + } + + return segmentData +} + +func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int32, + flushFunc func(fields map[string]storage.FieldData) error) *JSONRowConsumer { + if collectionSchema == nil { + log.Error("JSON row consumer: collection schema is nil") + return nil + } + + v := &JSONRowConsumer{ + collectionSchema: collectionSchema, + rowIDAllocator: idAlloc, + validators: make(map[string]*Validator), + shardNum: shardNum, + segmentSize: segmentSize, + rowCounter: 0, + callFlushFunc: flushFunc, + } + + initValidators(collectionSchema, v.validators) + + v.segmentsData = make([]map[string]storage.FieldData, 0, shardNum) + for i := 0; i < int(shardNum); i++ { + segmentData := initSegmentData(collectionSchema) + if segmentData == nil { + return nil + } + v.segmentsData = append(v.segmentsData, segmentData) + } + + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + if schema.GetIsPrimaryKey() { + v.primaryKey = schema.GetName() + break + } + } + // primary key not found + if v.primaryKey == "" { + log.Error("JSON row consumer: collection schema has no primary key") + return nil + } + // primary key is autoid, id generator is required + if v.validators[v.primaryKey].autoID && idAlloc == nil { + log.Error("JSON row consumer: ID allocator is nil") + return nil + } + + return v +} + +func (v *JSONRowConsumer) flush(force bool) error { + // force flush all data + if force { + for i := 0; i < len(v.segmentsData); i++ { + segmentData := v.segmentsData[i] + rowNum := segmentData[v.primaryKey].RowNum() + if rowNum > 0 { + log.Debug("JSON row consumer: force flush segment", zap.Int("rows", rowNum)) + v.callFlushFunc(segmentData) + } + } + + return nil + } + + // segment size can be flushed + for i := 0; i < len(v.segmentsData); i++ { + segmentData := v.segmentsData[i] + memSize := 0 + for _, field := range segmentData { + memSize += field.GetMemorySize() + } + if memSize >= int(v.segmentSize)*1024*1024 { + log.Debug("JSON row consumer: flush fulled segment", zap.Int("bytes", memSize)) + v.callFlushFunc(segmentData) + v.segmentsData[i] = initSegmentData(v.collectionSchema) + } + } + + return nil +} + +func (v *JSONRowConsumer) Handle(rows []map[string]interface{}) error { + if v.validators == nil || len(v.validators) == 0 { + return errors.New("JSON row consumer is not initialized") + } + + // flush in necessery + if rows == nil { + err := v.flush(true) + log.Debug("JSON row consumer finished") + return err + } + + err := v.flush(false) + if err != nil { + return err + } + + // prepare autoid + primaryValidator := v.validators[v.primaryKey] + var rowIDBegin typeutil.UniqueID + var rowIDEnd typeutil.UniqueID + if primaryValidator.autoID { + var err error + rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows))) + if err != nil { + return errors.New("JSON row consumer: " + err.Error()) + } + if rowIDEnd-rowIDBegin != int64(len(rows)) { + return errors.New("JSON row consumer: failed to allocate ID for " + strconv.Itoa(len(rows)) + " rows") + } + } + + // consume rows + for i := 0; i < len(rows); i++ { + row := rows[i] + + // firstly get/generate the row id + var id int64 + if primaryValidator.autoID { + id = rowIDBegin + int64(i) + } else { + value := row[v.primaryKey] + id = int64(value.(float64)) + } + + // hash to a shard number + hash, _ := typeutil.Hash32Int64(id) + shard := hash % uint32(v.shardNum) + pkArray := 
v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData) + pkArray.Data = append(pkArray.Data, id) + + // convert value and consume + for name, validator := range v.validators { + if validator.primaryKey { + continue + } + value := row[name] + if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil { + return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter, 10)) + } + } + } + + v.rowCounter += int64(len(rows)) + + return nil +} + +// column-based json format consumer class +type JSONColumnConsumer struct { + collectionSchema *schemapb.CollectionSchema // collection schema + validators map[string]*Validator // validators for each field + fieldsData map[string]storage.FieldData // in-memory fields data + primaryKey string // name of primary key + + callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment +} + +func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, + flushFunc func(fields map[string]storage.FieldData) error) *JSONColumnConsumer { + if collectionSchema == nil { + return nil + } + + v := &JSONColumnConsumer{ + collectionSchema: collectionSchema, + validators: make(map[string]*Validator), + callFlushFunc: flushFunc, + } + initValidators(collectionSchema, v.validators) + v.fieldsData = initSegmentData(collectionSchema) + + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + if schema.GetIsPrimaryKey() { + v.primaryKey = schema.GetName() + break + } + } + + return v +} + +func (v *JSONColumnConsumer) flush() error { + // check row count, should be equal + rowCount := 0 + for name, field := range v.fieldsData { + if name == v.primaryKey && v.validators[v.primaryKey].autoID { + continue + } + cnt := field.RowNum() + if rowCount == 0 { + rowCount = cnt + } else if rowCount != cnt { + return errors.New("JSON column consumer: " + name + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount)) + } + } + + if rowCount == 0 { + return errors.New("JSON column consumer: row count is 0") + } + log.Debug("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount)) + + // output the fileds data, let outside split them into segments + return v.callFlushFunc(v.fieldsData) +} + +func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error { + if v.validators == nil || len(v.validators) == 0 { + return errors.New("JSON column consumer is not initialized") + } + + // flush at the end + if columns == nil { + err := v.flush() + log.Debug("JSON column consumer finished") + return err + } + + for name, values := range columns { + validator, ok := v.validators[name] + if !ok { + // not a valid field name + break + } + + if validator.primaryKey && validator.autoID { + // autoid is no need to provide + break + } + + // convert and consume data + for i := 0; i < len(values); i++ { + if err := validator.convertFunc(values[i], v.fieldsData[name]); err != nil { + return errors.New("JSON column consumer: " + err.Error() + " of field " + name) + } + } + } + + return nil +} diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go new file mode 100644 index 0000000000..1d607477e7 --- /dev/null +++ b/internal/util/importutil/json_handler_test.go @@ -0,0 +1,397 @@ +package importutil + +import ( + "context" + "strings" + "testing" + + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/proto/commonpb" + 
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/stretchr/testify/assert" +) + +type mockIDAllocator struct { +} + +func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { + return &rootcoordpb.AllocIDResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + ID: int64(1), + Count: req.Count, + }, nil +} + +func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator { + mockIDAllocator := &mockIDAllocator{} + + idAllocator, err := allocator.NewIDAllocator(ctx, mockIDAllocator, int64(1)) + assert.Nil(t, err) + err = idAllocator.Start() + assert.Nil(t, err) + + return idAllocator +} + +func Test_GetFieldDimension(t *testing.T) { + schema := &schemapb.FieldSchema{ + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + } + + dim, err := getFieldDimension(schema) + assert.Nil(t, err) + assert.Equal(t, 4, dim) + + schema.TypeParams = []*commonpb.KeyValuePair{ + {Key: "dim", Value: "abc"}, + } + dim, err = getFieldDimension(schema) + assert.NotNil(t, err) + assert.Equal(t, 0, dim) + + schema.TypeParams = []*commonpb.KeyValuePair{} + dim, err = getFieldDimension(schema) + assert.NotNil(t, err) + assert.Equal(t, 0, dim) +} + +func Test_InitValidators(t *testing.T) { + validators := make(map[string]*Validator) + err := initValidators(nil, validators) + assert.NotNil(t, err) + + // success case + err = initValidators(sampleSchema(), validators) + assert.Nil(t, err) + assert.Equal(t, len(sampleSchema().Fields), len(validators)) + + checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) { + v, ok := validators[funcName] + assert.True(t, ok) + err = v.validateFunc(validVal) + assert.Nil(t, err) + err = v.validateFunc(invalidVal) + assert.NotNil(t, err) + } + + // validate functions + var validVal interface{} = true + var invalidVal interface{} = "aa" + checkFunc("field_bool", validVal, invalidVal) + + validVal = float64(100) + invalidVal = "aa" + checkFunc("field_int8", validVal, invalidVal) + checkFunc("field_int16", validVal, invalidVal) + checkFunc("field_int32", validVal, invalidVal) + checkFunc("field_int64", validVal, invalidVal) + checkFunc("field_float", validVal, invalidVal) + checkFunc("field_double", validVal, invalidVal) + + validVal = "aa" + invalidVal = 100 + checkFunc("field_string", validVal, invalidVal) + + validVal = []interface{}{float64(100), float64(101)} + invalidVal = "aa" + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(100)} + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(100), float64(101), float64(102)} + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{true, true} + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(255), float64(-1)} + checkFunc("field_binary_vector", validVal, invalidVal) + + validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)} + invalidVal = true + checkFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(1), float64(2), float64(3)} + checkFunc("field_float_vector", validVal, invalidVal) + invalidVal 
= []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)} + checkFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{"a", "b", "c", "d"} + checkFunc("field_float_vector", validVal, invalidVal) + + // error cases + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: make([]*schemapb.FieldSchema, 0), + } + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "aa"}, + }, + }) + + validators = make(map[string]*Validator) + err = initValidators(schema, validators) + assert.NotNil(t, err) + + schema.Fields = make([]*schemapb.FieldSchema, 0) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 110, + Name: "field_binary_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "aa"}, + }, + }) + + err = initValidators(schema, validators) + assert.NotNil(t, err) +} + +func Test_JSONRowValidator(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + + // 0 row case + reader := strings.NewReader(`{ + "rows":[] + }`) + + validator := NewJSONRowValidator(schema, nil) + err := parser.ParseRows(reader, validator) + assert.Nil(t, err) + assert.Equal(t, int64(0), validator.ValidateCount()) + + // // missed some fields + // reader = strings.NewReader(`{ + // "rows":[ + // {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + // {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]} + // ] + // }`) + // err = parser.ParseRows(reader, validator) + // assert.NotNil(t, err) + + // invalid dimension + reader = strings.NewReader(`{ + "rows":[ + {"field_bool": true, "field_int8": true, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0, 1, 66, 128, 0, 1, 66], "field_float_vector": [1.1, 1.2, 1.3, 1.4]} + ] + }`) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + // invalid value type + reader = strings.NewReader(`{ + "rows":[ + {"field_bool": true, "field_int8": true, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]} + ] + }`) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + // init failed + validator.validators = nil + err = validator.Handle(nil) + assert.NotNil(t, err) +} + +func Test_JSONColumnValidator(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + + // 0 row case + reader := strings.NewReader(`{ + "field_bool": [], 
+ "field_int8": [], + "field_int16": [], + "field_int32": [], + "field_int64": [], + "field_float": [], + "field_double": [], + "field_string": [], + "field_binary_vector": [], + "field_float_vector": [] + }`) + + validator := NewJSONColumnValidator(schema, nil) + err := parser.ParseColumns(reader, validator) + assert.Nil(t, err) + for _, count := range validator.rowCounter { + assert.Equal(t, int64(0), count) + } + + // different row count + reader = strings.NewReader(`{ + "field_bool": [true], + "field_int8": [], + "field_int16": [], + "field_int32": [1, 2, 3], + "field_int64": [], + "field_float": [], + "field_double": [], + "field_string": [], + "field_binary_vector": [], + "field_float_vector": [] + }`) + + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + + // invalid value type + reader = strings.NewReader(`{ + "dummy": [], + "field_bool": [true], + "field_int8": [1], + "field_int16": [2], + "field_int32": [3], + "field_int64": [4], + "field_float": [1], + "field_double": [1], + "field_string": [9], + "field_binary_vector": [[254, 1]], + "field_float_vector": [[1.1, 1.2, 1.3, 1.4]] + }`) + + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + + // init failed + validator.validators = nil + err = validator.Handle(nil) + assert.NotNil(t, err) +} + +func Test_JSONRowConsumer(t *testing.T) { + ctx := context.Background() + idAllocator := newIDAllocator(ctx, t) + + schema := sampleSchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + + reader := strings.NewReader(`{ + "rows":[ + {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + {"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]}, + {"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]}, + {"field_bool": false, "field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]}, + {"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]} + ] + }`) + + var callTime int32 + var totalCount int + consumeFunc := func(fields map[string]storage.FieldData) error { + callTime++ + rowCount := 0 + for _, data := range fields { + if rowCount == 0 { + rowCount = data.RowNum() + } else { + assert.Equal(t, rowCount, data.RowNum()) + } + } + totalCount += rowCount + return nil + } + + var shardNum int32 = 2 + consumer := NewJSONRowConsumer(schema, idAllocator, shardNum, 1, consumeFunc) + assert.NotNil(t, consumer) + + validator := NewJSONRowValidator(schema, consumer) + err := parser.ParseRows(reader, validator) + assert.Nil(t, err) + 
assert.Equal(t, int64(5), validator.ValidateCount()) + + assert.Equal(t, shardNum, callTime) + assert.Equal(t, 5, totalCount) +} + +func Test_JSONColumnConsumer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + + reader := strings.NewReader(`{ + "field_bool": [true, false, true, true, true], + "field_int8": [10, 11, 12, 13, 14], + "field_int16": [100, 101, 102, 103, 104], + "field_int32": [1000, 1001, 1002, 1003, 1004], + "field_int64": [10000, 10001, 10002, 10003, 10004], + "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], + "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], + "field_string": ["a", "b", "c", "d", "e"], + "field_binary_vector": [ + [254, 1], + [253, 2], + [252, 3], + [251, 4], + [250, 5] + ], + "field_float_vector": [ + [1.1, 1.2, 1.3, 1.4], + [2.1, 2.2, 2.3, 2.4], + [3.1, 3.2, 3.3, 3.4], + [4.1, 4.2, 4.3, 4.4], + [5.1, 5.2, 5.3, 5.4] + ] + }`) + + callTime := 0 + rowCount := 0 + consumeFunc := func(fields map[string]storage.FieldData) error { + callTime++ + for _, data := range fields { + if rowCount == 0 { + rowCount = data.RowNum() + } else { + assert.Equal(t, rowCount, data.RowNum()) + } + } + return nil + } + + consumer := NewJSONColumnConsumer(schema, consumeFunc) + assert.NotNil(t, consumer) + + validator := NewJSONColumnValidator(schema, consumer) + err := parser.ParseColumns(reader, validator) + assert.Nil(t, err) + for _, count := range validator.ValidateCount() { + assert.Equal(t, int64(5), count) + } + + assert.Equal(t, 1, callTime) + assert.Equal(t, 5, rowCount) +} diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go new file mode 100644 index 0000000000..94227f550d --- /dev/null +++ b/internal/util/importutil/json_parser.go @@ -0,0 +1,238 @@ +package importutil + +import ( + "context" + "encoding/json" + "errors" + "io" + + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/schemapb" +) + +const ( + // root field of row-based json format + RowRootNode = "rows" + // initial size of a buffer + BufferSize = 1024 +) + +type JSONParser struct { + ctx context.Context // for canceling parse process + bufSize int64 // max rows in a buffer + fields map[string]int64 // fields need to be parsed +} + +// newImportManager helper function to create a importManager +func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser { + fields := make(map[string]int64) + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + fields[schema.GetName()] = 0 + } + + parser := &JSONParser{ + ctx: ctx, + bufSize: 4096, + fields: fields, + } + + return parser +} + +func (p *JSONParser) logError(msg string) error { + log.Error(msg) + return errors.New(msg) +} + +func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { + if handler == nil { + return p.logError("JSON parse handler is nil") + } + + dec := json.NewDecoder(r) + + t, err := dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + if t != json.Delim('{') { + return p.logError("JSON parse: invalid JSON format, the content should be started with'{'") + } + + // read the first level + for dec.More() { + // read the key + t, err := dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + key := t.(string) + + // the root key should be RowRootNode + if key != RowRootNode { + return 
p.logError("JSON parse: invalid row-based JSON format, the key " + RowRootNode + " is not found") + } + + // started by '[' + t, err = dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + + if t != json.Delim('[') { + return p.logError("JSON parse: invalid row-based JSON format, rows list should begin with '['") + } + + // read buffer + buf := make([]map[string]interface{}, 0, BufferSize) + for dec.More() { + var value interface{} + if err := dec.Decode(&value); err != nil { + return p.logError("JSON parse: " + err.Error()) + } + + switch value.(type) { + case map[string]interface{}: + break + default: + return p.logError("JSON parse: invalid JSON format, each row should be a key-value map") + } + + row := value.(map[string]interface{}) + + buf = append(buf, row) + if len(buf) >= int(p.bufSize) { + if err = handler.Handle(buf); err != nil { + return p.logError(err.Error()) + } + + // clear the buffer + buf = make([]map[string]interface{}, 0, BufferSize) + } + } + + // some rows in buffer not parsed, parse them + if len(buf) > 0 { + if err = handler.Handle(buf); err != nil { + return p.logError(err.Error()) + } + } + + // end by ']' + t, err = dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + + if t != json.Delim(']') { + return p.logError("JSON parse: invalid column-based JSON format, rows list should end with a ']'") + } + + // canceled? + select { + case <-p.ctx.Done(): + return p.logError("import task was canceled") + default: + break + } + + // this break means we require the first node must be RowRootNode + // once the RowRootNode is parsed, just finish + break + } + + // send nil to notify the handler all have done + return handler.Handle(nil) +} + +func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error { + if handler == nil { + return p.logError("JSON parse handler is nil") + } + + dec := json.NewDecoder(r) + + t, err := dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + if t != json.Delim('{') { + return p.logError("JSON parse: invalid JSON format, the content should be started with'{'") + } + + // read the first level + for dec.More() { + // read the key + t, err := dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + key := t.(string) + + // not a valid column name, skip + _, isValidField := p.fields[key] + + // started by '[' + t, err = dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + + if t != json.Delim('[') { + return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['") + } + + // read buffer + buf := make(map[string][]interface{}) + buf[key] = make([]interface{}, 0, BufferSize) + for dec.More() { + var value interface{} + if err := dec.Decode(&value); err != nil { + return p.logError("JSON parse: " + err.Error()) + } + + if !isValidField { + continue + } + + buf[key] = append(buf[key], value) + if len(buf[key]) >= int(p.bufSize) { + if err = handler.Handle(buf); err != nil { + return p.logError(err.Error()) + } + + // clear the buffer + buf[key] = make([]interface{}, 0, BufferSize) + } + } + + // some values in buffer not parsed, parse them + if len(buf[key]) > 0 { + if err = handler.Handle(buf); err != nil { + return p.logError(err.Error()) + } + } + + // end by ']' + t, err = dec.Token() + if err != nil { + return p.logError("JSON parse: " + err.Error()) + } + + if t != json.Delim(']') { + return p.logError("JSON parse: invalid 
column-based JSON format, each field should end with a ']'") + } + + // canceled? + select { + case <-p.ctx.Done(): + return p.logError("import task was canceled") + default: + break + } + } + + // send nil to notify the handler all have done + return handler.Handle(nil) +} diff --git a/internal/util/importutil/json_parser_test.go b/internal/util/importutil/json_parser_test.go new file mode 100644 index 0000000000..2de3721550 --- /dev/null +++ b/internal/util/importutil/json_parser_test.go @@ -0,0 +1,256 @@ +package importutil + +import ( + "context" + "strings" + "testing" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/stretchr/testify/assert" +) + +func sampleSchema() *schemapb.CollectionSchema { + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 102, + Name: "field_bool", + IsPrimaryKey: false, + Description: "bool", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: 103, + Name: "field_int8", + IsPrimaryKey: false, + Description: "int8", + DataType: schemapb.DataType_Int8, + }, + { + FieldID: 104, + Name: "field_int16", + IsPrimaryKey: false, + Description: "int16", + DataType: schemapb.DataType_Int16, + }, + { + FieldID: 105, + Name: "field_int32", + IsPrimaryKey: false, + Description: "int32", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: 106, + Name: "field_int64", + IsPrimaryKey: true, + AutoID: false, + Description: "int64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 107, + Name: "field_float", + IsPrimaryKey: false, + Description: "float", + DataType: schemapb.DataType_Float, + }, + { + FieldID: 108, + Name: "field_double", + IsPrimaryKey: false, + Description: "double", + DataType: schemapb.DataType_Double, + }, + { + FieldID: 109, + Name: "field_string", + IsPrimaryKey: false, + Description: "string", + DataType: schemapb.DataType_String, + }, + { + FieldID: 110, + Name: "field_binary_vector", + IsPrimaryKey: false, + Description: "binary_vector", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "16"}, + }, + }, + { + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + return schema +} + +func Test_ParserRows(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + parser.bufSize = 1 + + reader := strings.NewReader(`{ + "rows":[ + {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + {"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]}, + {"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]}, + {"field_bool": false, 
"field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]}, + {"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]} + ] + }`) + + err := parser.ParseRows(reader, nil) + assert.NotNil(t, err) + + validator := NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.Nil(t, err) + assert.Equal(t, int64(5), validator.ValidateCount()) + + reader = strings.NewReader(`{ + "dummy":[] + }`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`{ + "rows": + }`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`{ + "rows": [} + }`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`{ + "rows": {} + }`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`{ + "rows": [[]] + }`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`[]`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(``) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) +} + +func Test_ParserColumns(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + parser.bufSize = 1 + + reader := strings.NewReader(`{ + "field_bool": [true, false, true, true, true], + "field_int8": [10, 11, 12, 13, 14], + "field_int16": [100, 101, 102, 103, 104], + "field_int32": [1000, 1001, 1002, 1003, 1004], + "field_int64": [10000, 10001, 10002, 10003, 10004], + "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], + "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], + "field_string": ["a", "b", "c", "d", "e"], + "field_binary_vector": [ + [254, 1], + [253, 2], + [252, 3], + [251, 4], + [250, 5] + ], + "field_float_vector": [ + [1.1, 1.2, 1.3, 1.4], + [2.1, 2.2, 2.3, 2.4], + [3.1, 3.2, 3.3, 3.4], + [4.1, 4.2, 4.3, 4.4], + [5.1, 5.2, 5.3, 5.4] + ] + }`) + + err := parser.ParseColumns(reader, nil) + assert.NotNil(t, err) + + validator := NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.Nil(t, err) + counter := validator.ValidateCount() + for _, v := range counter { + assert.Equal(t, int64(5), v) + } + + reader = strings.NewReader(`{ + "dummy":[1, 2, 3] + }`) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.Nil(t, err) + + reader = strings.NewReader(`{ + "field_bool": + }`) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`{ + "field_bool":{} + }`) + validator = NewJSONColumnValidator(schema, nil) + err = 
parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`{ + "field_bool":[} + }`) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(`[]`) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + + reader = strings.NewReader(``) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) +}
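
Below is a minimal, test-style sketch of how a caller can drive ParseRows with its own handler. It assumes json_handler.go declares JSONRowHandler with a Handle(rows []map[string]interface{}) error method, which is inferred from the calls inside ParseRows; the rowCounter type and exampleParseRows function are illustrative names only, sampleSchema is the helper from the test file above, and the JSONRowValidator/JSONRowConsumer chain is deliberately skipped.

    package importutil

    import (
    	"context"
    	"strings"
    )

    // rowCounter is a hypothetical handler that only counts the rows delivered
    // by ParseRows. ParseRows hands over buffered batches of rows and finally
    // calls Handle(nil) to signal that parsing is complete.
    type rowCounter struct {
    	count int
    }

    func (c *rowCounter) Handle(rows []map[string]interface{}) error {
    	if rows == nil {
    		// final call: all rows have been delivered
    		return nil
    	}
    	c.count += len(rows)
    	return nil
    }

    // exampleParseRows feeds a tiny row-based document through the parser and
    // returns the number of rows seen. No field validation happens here.
    func exampleParseRows(ctx context.Context) (int, error) {
    	parser := NewJSONParser(ctx, sampleSchema())
    	counter := &rowCounter{}
    	data := `{"rows": [{"field_int64": 1}, {"field_int64": 2}]}`
    	if err := parser.ParseRows(strings.NewReader(data), counter); err != nil {
    		return 0, err
    	}
    	return counter.count, nil
    }

Because ParseRows buffers up to bufSize rows before each Handle call, a handler sees rows in batches rather than one at a time, and the trailing Handle(nil) is the only reliable end-of-input signal.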