package importutil

import (
	"bufio"
	"context"
	"errors"
	"os"
	"path"
	"strconv"
	"strings"

	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/typeutil"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

const (
	JSONFileExt  = ".json"
	NumpyFileExt = ".npy"
)

type ImportWrapper struct {
	ctx              context.Context            // for canceling the parse process
	cancel           context.CancelFunc         // for canceling the parse process
	collectionSchema *schemapb.CollectionSchema // collection schema
	shardNum         int32                      // sharding number of the collection
	segmentSize      int32                      // maximum size of a segment in MB
	rowIDAllocator   *allocator.IDAllocator     // auto-id allocator
	callFlushFunc    func(fields map[string]storage.FieldData) error // callback function to flush a segment
}

func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int32,
	idAlloc *allocator.IDAllocator, flushFunc func(fields map[string]storage.FieldData) error) *ImportWrapper {
	if collectionSchema == nil {
		log.Error("import error: collection schema is nil")
		return nil
	}

	// ignore the RowID field and Timestamp field
	realSchema := &schemapb.CollectionSchema{
		Name:        collectionSchema.GetName(),
		Description: collectionSchema.GetDescription(),
		AutoID:      collectionSchema.GetAutoID(),
		Fields:      make([]*schemapb.FieldSchema, 0),
	}
	for i := 0; i < len(collectionSchema.Fields); i++ {
		schema := collectionSchema.Fields[i]
		if schema.GetName() == common.RowIDFieldName || schema.GetName() == common.TimeStampFieldName {
			continue
		}
		realSchema.Fields = append(realSchema.Fields, schema)
	}

	ctx, cancel := context.WithCancel(ctx)

	wrapper := &ImportWrapper{
		ctx:              ctx,
		cancel:           cancel,
		collectionSchema: realSchema,
		shardNum:         shardNum,
		segmentSize:      segmentSize,
		rowIDAllocator:   idAlloc,
		callFlushFunc:    flushFunc,
	}

	return wrapper
}

// Cancel can be used to cancel the parse process.
func (p *ImportWrapper) Cancel() error {
	p.cancel()
	return nil
}

func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldData, msg string, files []string) {
	stats := make([]zapcore.Field, 0)
	for k, v := range fieldsData {
		stats = append(stats, zap.Int(k, v.RowNum()))
	}

	if len(files) > 0 {
		stats = append(stats, zap.Any("files", files))
	}
	log.Debug(msg, stats...)
}
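// exampleNewImportWrapper is a hypothetical sketch (not part of the original code) showing
// how a caller might wire up an ImportWrapper. The shard number, segment size and flush
// callback below are illustrative only; a real caller (e.g. a DataNode import task) would
// persist the flushed segment instead of just logging its row counts.
func exampleNewImportWrapper(ctx context.Context, schema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator) *ImportWrapper {
	flushFunc := func(fields map[string]storage.FieldData) error {
		for name, data := range fields {
			log.Debug("example flush", zap.String("field", name), zap.Int("rows", data.RowNum()))
		}
		return nil
	}
	// 2 shards, 512 MB segment size limit (illustrative values)
	return NewImportWrapper(ctx, schema, 2, 512, idAlloc, flushFunc)
}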
func getFileNameAndExt(filePath string) (string, string) {
	fileName := path.Base(filePath)
	fileType := path.Ext(fileName)
	fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
	return fileNameWithoutExt, fileType
}

// Import is the entry of the import process.
// filePaths and rowBased are provided by the ImportTask.
// If onlyValidate is true, this process only does validation; no data is generated and callFlushFunc is not called.
func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate bool) error {
	if rowBased {
		// parse and consume row-based files
		// for row-based files, the JSONRowConsumer generates auto-id for the primary key and splits rows into segments
		// according to the shard number, so callFlushFunc is called inside the JSONRowConsumer
		for i := 0; i < len(filePaths); i++ {
			filePath := filePaths[i]
			_, fileType := getFileNameAndExt(filePath)
			log.Debug("import wrapper: row-based file", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

			if fileType == JSONFileExt {
				err := func() error {
					file, err := os.Open(filePath)
					if err != nil {
						return err
					}
					defer file.Close()

					reader := bufio.NewReader(file)
					parser := NewJSONParser(p.ctx, p.collectionSchema)
					var consumer *JSONRowConsumer
					if !onlyValidate {
						flushFunc := func(fields map[string]storage.FieldData) error {
							p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths)
							return p.callFlushFunc(fields)
						}
						consumer = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc)
					}
					validator := NewJSONRowValidator(p.collectionSchema, consumer)

					err = parser.ParseRows(reader, validator)
					if err != nil {
						return err
					}

					return nil
				}()

				if err != nil {
					log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
					return err
				}
			}
		}
	} else {
		// parse and consume column-based files
		// for column-based files, the XXXColumnConsumer only outputs map[string]storage.FieldData
		// after all columns are parsed/consumed, we combine them into a single map[string]storage.FieldData
		// and use splitFieldsData() to split the fields data into segments according to the shard number
		fieldsData := initSegmentData(p.collectionSchema)
		rowCount := 0

		// function to combine column data into fieldsData
		combineFunc := func(fields map[string]storage.FieldData) error {
			if len(fields) == 0 {
				return nil
			}

			p.printFieldsDataInfo(fields, "import wrapper: combine field data", nil)

			fieldNames := make([]string, 0)
			for k, v := range fields {
				// ignore fields with 0 rows
				if v.RowNum() == 0 {
					continue
				}

				// each column should be combined only once
				data, ok := fieldsData[k]
				if ok && data.RowNum() > 0 {
					return errors.New("the field " + k + " is duplicated")
				}

				// check the row count; only non-zero row fields are counted
				if rowCount > 0 && rowCount != v.RowNum() {
					return errors.New("the field " + k + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount))
				}
				rowCount = v.RowNum()

				// assign column data to fieldsData
				fieldsData[k] = v
				fieldNames = append(fieldNames, k)
			}

			return nil
		}

		// parse/validate/consume data
		for i := 0; i < len(filePaths); i++ {
			filePath := filePaths[i]
			fileName, fileType := getFileNameAndExt(filePath)
			log.Debug("import wrapper: column-based file", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

			if fileType == JSONFileExt {
				err := func() error {
					file, err := os.Open(filePath)
					if err != nil {
						log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
						return err
					}
					defer file.Close()

					reader := bufio.NewReader(file)
					parser := NewJSONParser(p.ctx, p.collectionSchema)
					var consumer *JSONColumnConsumer
					if !onlyValidate {
						consumer = NewJSONColumnConsumer(p.collectionSchema, combineFunc)
					}
					validator := NewJSONColumnValidator(p.collectionSchema, consumer)

					err = parser.ParseColumns(reader, validator)
					if err != nil {
						log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
						return err
					}

					return nil
				}()

				if err != nil {
					log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
					return err
				}
			} else if fileType == NumpyFileExt {
				file, err := os.Open(filePath)
				if err != nil {
					log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
					return err
				}
				defer file.Close()

				// the numpy parser returns a storage.FieldData; wrap it into a map[string]storage.FieldData so it can be combined
				flushFunc := func(field storage.FieldData) error {
					fields := make(map[string]storage.FieldData)
					fields[fileName] = field
					return combineFunc(fields)
				}

				// for numpy files, the file name (without extension) is used as the field name
				parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
				err = parser.Parse(file, fileName, onlyValidate)
				if err != nil {
					log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
					return err
				}
			}
		}

		// split fields data into segments
		err := p.splitFieldsData(fieldsData, filePaths)
		if err != nil {
			log.Error("import error: " + err.Error())
			return err
		}
	}

	return nil
}
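// exampleImportCalls is a hypothetical sketch (not part of the original code) showing the two
// ways Import may be driven; the file paths are placeholders. For row-based JSON files, segments
// are flushed by the JSONRowConsumer during parsing; for column-based files, each .npy file feeds
// the field whose name equals the file name without extension (e.g. "age.npy" -> field "age"),
// and splitFieldsData flushes the segments after all columns are combined.
func exampleImportCalls(wrapper *ImportWrapper) error {
	// row-based dry run: onlyValidate=true parses and validates rows but never flushes a segment
	if err := wrapper.Import([]string{"/tmp/rows.json"}, true, true); err != nil {
		return err
	}
	// column-based import: "age.npy" feeds the field named "age", the JSON file carries the remaining columns
	return wrapper.Import([]string{"/tmp/age.npy", "/tmp/columns.json"}, false, false)
}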
func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
	switch schema.DataType {
	case schemapb.DataType_Bool:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.BoolFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(bool))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Float:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.FloatFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(float32))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Double:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.DoubleFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(float64))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int8:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int8FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int8))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int16:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int16FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int16))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int32:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int32FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int32))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int64:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int64FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int64))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_BinaryVector:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.BinaryVectorFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).([]byte)...)
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_FloatVector:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.FloatVectorFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).([]float32)...)
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_String:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.StringFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(string))
			// keep the row counter consistent with the other field types
			arr.NumRows[0]++
			return nil
		}
	default:
		return nil
	}
}
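// exampleAppendRow is a hypothetical sketch (not part of the original code) of what the closures
// returned by appendFunc do: copy row n from a parsed source column into a per-shard target column
// of the same type, bumping the target's row counter. The field name "age" and the sample values
// are illustrative only.
func exampleAppendRow(p *ImportWrapper) error {
	schema := &schemapb.FieldSchema{Name: "age", DataType: schemapb.DataType_Int64}
	src := &storage.Int64FieldData{NumRows: []int64{2}, Data: []int64{18, 32}}
	dst := &storage.Int64FieldData{NumRows: []int64{0}, Data: make([]int64, 0)}
	appender := p.appendFunc(schema)
	// copy the second row (value 32) from src into dst; dst.Data becomes [32], dst.NumRows[0] becomes 1
	return appender(src, 1, dst)
}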
func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, files []string) error {
	if len(fieldsData) == 0 {
		return errors.New("import error: fields data is empty")
	}

	var primaryKey *schemapb.FieldSchema
	for i := 0; i < len(p.collectionSchema.Fields); i++ {
		schema := p.collectionSchema.Fields[i]
		if schema.GetIsPrimaryKey() {
			primaryKey = schema
		} else {
			_, ok := fieldsData[schema.GetName()]
			if !ok {
				return errors.New("import error: field " + schema.GetName() + " not provided")
			}
		}
	}

	if primaryKey == nil {
		return errors.New("import error: primary key field is not found")
	}

	rowCount := 0
	for _, v := range fieldsData {
		rowCount = v.RowNum()
		break
	}

	primaryData, ok := fieldsData[primaryKey.GetName()]
	if !ok {
		// the primary key column is not provided, generate auto-id for it
		if primaryKey.GetAutoID() {
			var rowIDBegin typeutil.UniqueID
			var rowIDEnd typeutil.UniqueID
			rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount))

			primaryDataArr := &storage.Int64FieldData{
				NumRows: []int64{int64(rowCount)},
				Data:    make([]int64, 0, rowCount),
			}
			for i := rowIDBegin; i < rowIDEnd; i++ {
				primaryDataArr.Data = append(primaryDataArr.Data, i)
			}

			primaryData = primaryDataArr
			fieldsData[primaryKey.GetName()] = primaryData
		}
	}

	if primaryData == nil || primaryData.RowNum() <= 0 {
		return errors.New("import error: primary key " + primaryKey.GetName() + " not provided")
	}

	// prepare segments
	segmentsData := make([]map[string]storage.FieldData, 0, p.shardNum)
	for i := 0; i < int(p.shardNum); i++ {
		segmentData := initSegmentData(p.collectionSchema)
		if segmentData == nil {
			return errors.New("import error: failed to initialize segment data")
		}
		segmentsData = append(segmentsData, segmentData)
	}

	// prepare append functions
	appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error)
	for i := 0; i < len(p.collectionSchema.Fields); i++ {
		schema := p.collectionSchema.Fields[i]
		appendFunc := p.appendFunc(schema)
		if appendFunc == nil {
			return errors.New("import error: unsupported field data type")
		}
		appendFunctions[schema.GetName()] = appendFunc
	}

	// split data into segments
	for i := 0; i < rowCount; i++ {
		id := primaryData.GetRow(i).(int64)
		// hash the primary key to a shard number
		hash, _ := typeutil.Hash32Int64(id)
		shard := hash % uint32(p.shardNum)
		for k := 0; k < len(p.collectionSchema.Fields); k++ {
			schema := p.collectionSchema.Fields[k]
			srcData := fieldsData[schema.GetName()]
			targetData := segmentsData[shard][schema.GetName()]
			appendFunc := appendFunctions[schema.GetName()]
			err := appendFunc(srcData, i, targetData)
			if err != nil {
				return err
			}
		}
	}

	// call the flush function for each shard
	for i := 0; i < int(p.shardNum); i++ {
		segmentData := segmentsData[i]
		p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files)
		err := p.callFlushFunc(segmentData)
		if err != nil {
			return err
		}
	}

	return nil
}
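// exampleShardForPrimaryKey is a hypothetical sketch (not part of the original code) that restates
// the shard-assignment rule used by splitFieldsData: hash the int64 primary key with
// typeutil.Hash32Int64 and take the result modulo the shard number, so rows sharing a primary key
// always land in the same shard/segment.
func exampleShardForPrimaryKey(id int64, shardNum int32) (uint32, error) {
	hash, err := typeutil.Hash32Int64(id)
	if err != nil {
		return 0, err
	}
	return hash % uint32(shardNum), nil
}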