diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go index 47b90889a0..a4aaf34710 100644 --- a/internal/util/importutil/import_util.go +++ b/internal/util/importutil/import_util.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/typeutil" ) func isCanceled(ctx context.Context) bool { @@ -519,3 +520,22 @@ func getTypeName(dt schemapb.DataType) string { return "InvalidType" } } + +func pkToShard(pk interface{}, shardNum uint32) (uint32, error) { + var shard uint32 + strPK, ok := interface{}(pk).(string) + if ok { + hash := typeutil.HashString2Uint32(strPK) + shard = hash % shardNum + } else { + intPK, ok := interface{}(pk).(int64) + if !ok { + log.Error("Numpy parser: primary key field must be int64 or varchar") + return 0, fmt.Errorf("primary key field must be int64 or varchar") + } + hash, _ := typeutil.Hash32Int64(intPK) + shard = hash % shardNum + } + + return shard, nil +} diff --git a/internal/util/importutil/import_util_test.go b/internal/util/importutil/import_util_test.go index 0e87651960..c20cc51b36 100644 --- a/internal/util/importutil/import_util_test.go +++ b/internal/util/importutil/import_util_test.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" ) @@ -611,3 +612,31 @@ func Test_GetTypeName(t *testing.T) { str = getTypeName(schemapb.DataType_None) assert.Equal(t, "InvalidType", str) } + +func Test_PkToShard(t *testing.T) { + a := int32(99) + shard, err := pkToShard(a, 2) + assert.Error(t, err) + assert.Zero(t, shard) + + s := "abcdef" + shardNum := uint32(3) + shard, err = pkToShard(s, shardNum) + assert.NoError(t, err) + hash := typeutil.HashString2Uint32(s) + assert.Equal(t, hash%shardNum, shard) + + pk := int64(100) + shardNum = uint32(4) + shard, err = pkToShard(pk, shardNum) + assert.NoError(t, err) + hash, _ = typeutil.Hash32Int64(pk) + assert.Equal(t, hash%shardNum, shard) + + pk = int64(99999) + shardNum = uint32(5) + shard, err = pkToShard(pk, shardNum) + assert.NoError(t, err) + hash, _ = typeutil.Hash32Int64(pk) + assert.Equal(t, hash%shardNum, shard) +} diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index b9bf63d9fd..79d4667eb2 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -34,7 +34,6 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/timerecord" - "github.com/milvus-io/milvus/internal/util/typeutil" ) const ( @@ -47,7 +46,7 @@ const ( // this limitation is to avoid this OOM risk: // for column-based file, we read all its data into memory, if the user inputs a large file, the read() method may // cost extra memory and lead to OOM. - MaxFileSize = 1 * 1024 * 1024 * 1024 // 1GB + MaxFileSize = 16 * 1024 * 1024 * 1024 // 16GB // this limitation is to avoid this OOM risk: // sometimes system segment max size is a large number, a single segment fields data might cause OOM.
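The pkToShard helper added in import_util.go above hashes a primary key to a shard index: string keys go through typeutil.HashString2Uint32, int64 keys through typeutil.Hash32Int64, and any other type is rejected. A minimal, self-contained sketch of that idea follows; it substitutes hash/fnv from the standard library for the internal typeutil hashers, so the pkToShardSketch name and the printed shard values are illustrative only and will not match what the real import path computes.

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// pkToShardSketch mirrors the shape of pkToShard: string and int64 primary
// keys are hashed and reduced modulo the shard count, any other type is an
// error. NOTE: fnv is a stand-in for Milvus' internal typeutil hashers, so
// the shard numbers produced here are illustrative only.
func pkToShardSketch(pk interface{}, shardNum uint32) (uint32, error) {
	h := fnv.New32a()
	switch v := pk.(type) {
	case string:
		h.Write([]byte(v))
	case int64:
		var buf [8]byte
		binary.LittleEndian.PutUint64(buf[:], uint64(v))
		h.Write(buf[:])
	default:
		return 0, fmt.Errorf("primary key field must be int64 or varchar")
	}
	return h.Sum32() % shardNum, nil
}

func main() {
	shard, _ := pkToShardSketch("abcdef", 3)
	fmt.Println("varchar pk ->", shard)

	shard, _ = pkToShardSketch(int64(100), 4)
	fmt.Println("int64 pk ->", shard)

	// an int32 key is rejected, just like in pkToShard
	if _, err := pkToShardSketch(int32(99), 2); err != nil {
		fmt.Println("int32 pk ->", err)
	}
}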
@@ -175,42 +174,6 @@ func (p *ImportWrapper) Cancel() error { return nil } -func (p *ImportWrapper) validateColumnBasedFiles(filePaths []string, collectionSchema *schemapb.CollectionSchema) error { - requiredFieldNames := make(map[string]interface{}) - for _, schema := range p.collectionSchema.Fields { - if schema.GetIsPrimaryKey() { - if !schema.GetAutoID() { - requiredFieldNames[schema.GetName()] = nil - } - } else { - requiredFieldNames[schema.GetName()] = nil - } - } - - // check redundant file - fileNames := make(map[string]interface{}) - for _, filePath := range filePaths { - name, _ := GetFileNameAndExt(filePath) - fileNames[name] = nil - _, ok := requiredFieldNames[name] - if !ok { - log.Error("import wrapper: the file has no corresponding field in collection", zap.String("fieldName", name)) - return fmt.Errorf("the file '%s' has no corresponding field in collection", filePath) - } - } - - // check missed file - for name := range requiredFieldNames { - _, ok := fileNames[name] - if !ok { - log.Error("import wrapper: there is no file corresponding to field", zap.String("fieldName", name)) - return fmt.Errorf("there is no file corresponding to field '%s'", name) - } - } - - return nil -} - // fileValidation verify the input paths // if all the files are json type, return true // if all the files are numpy type, return false, and not allow duplicate file name @@ -278,22 +241,6 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { totalSize += size } - // especially for column-base, total size of files cannot exceed MaxTotalSizeInMemory - if totalSize > MaxTotalSizeInMemory { - log.Error("import wrapper: total size of files exceeds the maximum size", zap.Int64("totalSize", totalSize), zap.Int64("MaxTotalSize", MaxTotalSizeInMemory)) - return rowBased, fmt.Errorf("total size(%d bytes) of all files exceeds the maximum size: %d bytes", totalSize, MaxTotalSizeInMemory) - } - - // check redundant files for column-based import - // if the field is primary key and autoid is false, the file is required - // any redundant file is not allowed - if !rowBased { - err := p.validateColumnBasedFiles(filePaths, p.collectionSchema) - if err != nil { - return rowBased, err - } - } - return rowBased, nil } @@ -337,84 +284,26 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error triggerGC() } } else { - // parse and consume column-based files - // for column-based files, the XXXColumnConsumer only output map[string]storage.FieldData - // after all columns are parsed/consumed, we need to combine map[string]storage.FieldData into one - // and use splitFieldsData() to split fields data into segments according to shard number - fieldsData := initSegmentData(p.collectionSchema) - if fieldsData == nil { - log.Error("import wrapper: failed to initialize FieldData list") - return fmt.Errorf("failed to initialize FieldData list") + // parse and consume column-based files(currently support numpy) + // for column-based files, the NumpyParser will generate autoid for primary key, and split rows into segments + // according to shard number, so the flushFunc will be called in the NumpyParser + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { + printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths) + return p.flushFunc(fields, shardID) } - - rowCount := 0 - - // function to combine column data into fieldsData - combineFunc := func(fields map[storage.FieldID]storage.FieldData) error { - if len(fields) 
== 0 { - return nil - } - - printFieldsDataInfo(fields, "import wrapper: combine field data", nil) - for k, v := range fields { - // ignore 0 row field - if v.RowNum() == 0 { - log.Warn("import wrapper: empty FieldData ignored", zap.Int64("fieldID", k)) - continue - } - - // ignore internal fields: RowIDField and TimeStampField - if k == common.RowIDField || k == common.TimeStampField { - log.Warn("import wrapper: internal fields should not be provided", zap.Int64("fieldID", k)) - continue - } - - // each column should be only combined once - data, ok := fieldsData[k] - if ok && data.RowNum() > 0 { - return fmt.Errorf("the field %d is duplicated", k) - } - - // check the row count. only count non-zero row fields - if rowCount > 0 && rowCount != v.RowNum() { - return fmt.Errorf("the field %d row count %d doesn't equal to others row count: %d", k, v.RowNum(), rowCount) - } - rowCount = v.RowNum() - - // assign column data to fieldsData - fieldsData[k] = v - } - - return nil - } - - // parse/validate/consume data - for i := 0; i < len(filePaths); i++ { - filePath := filePaths[i] - _, fileType := GetFileNameAndExt(filePath) - log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) - - if fileType == NumpyFileExt { - err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc) - - if err != nil { - log.Error("import wrapper: failed to parse column-based numpy file", zap.Error(err), zap.String("filePath", filePath)) - return err - } - } - // no need to check else, since the fileValidation() already do this - } - - // trigger after read finished - triggerGC() - - // split fields data into segments - err := p.splitFieldsData(fieldsData, SingleBlockSize) + parser, err := NewNumpyParser(p.ctx, p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc) if err != nil { return err } - // trigger after write finished + err = parser.Parse(filePaths) + if err != nil { + return err + } + + p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...) + + // trigger after parse finished triggerGC() } @@ -437,6 +326,7 @@ func (p *ImportWrapper) reportPersisted(reportAttempts uint, tr *timerecord.Time // report file process state p.importResult.State = commonpb.ImportState_ImportPersisted + log.Info("import wrapper: report import result", zap.Any("importResult", p.importResult)) // persist state task is valuable, retry more times in case fail this task only because of network error reportErr := retry.Do(p.ctx, func() error { return p.reportFunc(p.importResult) @@ -554,297 +444,12 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er } // for row-based files, auto-id is generated within JSONRowConsumer - if consumer != nil { - p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...) - } + p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...) 
tr.Elapse("parsed") return nil } -// parseColumnBasedNumpy is the entry of column-based numpy import operation -func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool, - combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error { - tr := timerecord.NewTimeRecorder("numpy parser: " + filePath) - - fileName, _ := GetFileNameAndExt(filePath) - - // for minio storage, chunkManager will download file into local memory - // for local storage, chunkManager open the file directly - file, err := p.chunkManager.Reader(p.ctx, filePath) - if err != nil { - return err - } - defer file.Close() - - var id storage.FieldID - var found = false - for _, field := range p.collectionSchema.Fields { - if field.GetName() == fileName { - id = field.GetFieldID() - found = true - break - } - } - - // if the numpy file name is not mapping to a field name, ignore it - if !found { - return nil - } - - // the numpy parser return a storage.FieldData, here construct a map[string]storage.FieldData to combine - flushFunc := func(field storage.FieldData) error { - fields := make(map[storage.FieldID]storage.FieldData) - fields[id] = field - return combineFunc(fields) - } - - // for numpy file, we say the file name(without extension) is the filed name - parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc) - err = parser.Parse(file, fileName, onlyValidate) - if err != nil { - return err - } - - tr.Elapse("parsed") - return nil -} - -// appendFunc defines the methods to append data to storage.FieldData -func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { - switch schema.DataType { - case schemapb.DataType_Bool: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.BoolFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(bool)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_Float: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.FloatFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(float32)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_Double: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.DoubleFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(float64)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_Int8: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int8FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int8)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_Int16: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int16FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int16)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_Int32: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int32FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int32)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_Int64: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int64FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int64)) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_BinaryVector: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := 
target.(*storage.BinaryVectorFieldData) - arr.Data = append(arr.Data, src.GetRow(n).([]byte)...) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_FloatVector: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.FloatVectorFieldData) - arr.Data = append(arr.Data, src.GetRow(n).([]float32)...) - arr.NumRows[0]++ - return nil - } - case schemapb.DataType_String, schemapb.DataType_VarChar: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.StringFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(string)) - return nil - } - default: - return nil - } -} - -// splitFieldsData is to split the in-memory data(parsed from column-based files) into blocks, each block save to a binlog file -func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, blockSize int64) error { - if len(fieldsData) == 0 { - log.Error("import wrapper: fields data is empty") - return fmt.Errorf("fields data is empty") - } - - tr := timerecord.NewTimeRecorder("import wrapper: split field data") - defer tr.Elapse("finished") - - // check existence of each field - // check row count, all fields row count must be equal - // firstly get the max row count - rowCount := 0 - rowCounter := make(map[string]int) - var primaryKey *schemapb.FieldSchema - for i := 0; i < len(p.collectionSchema.Fields); i++ { - schema := p.collectionSchema.Fields[i] - if schema.GetIsPrimaryKey() { - primaryKey = schema - } - - if !schema.GetAutoID() { - v, ok := fieldsData[schema.GetFieldID()] - if !ok { - log.Error("import wrapper: field not provided", zap.String("fieldName", schema.GetName())) - return fmt.Errorf("field '%s' not provided", schema.GetName()) - } - rowCounter[schema.GetName()] = v.RowNum() - if v.RowNum() > rowCount { - rowCount = v.RowNum() - } - } - } - if primaryKey == nil { - log.Error("import wrapper: primary key field is not found") - return fmt.Errorf("primary key field is not found") - } - - for name, count := range rowCounter { - if count != rowCount { - log.Error("import wrapper: field row count is not equal to other fields row count", zap.String("fieldName", name), - zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount)) - return fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount) - } - } - log.Info("import wrapper: try to split a block with row count", zap.Int("rowCount", rowCount)) - - primaryData, ok := fieldsData[primaryKey.GetFieldID()] - if !ok { - log.Error("import wrapper: primary key field is not provided", zap.String("keyName", primaryKey.GetName())) - return fmt.Errorf("primary key field is not provided") - } - - // generate auto id for primary key and rowid field - rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount)) - if err != nil { - log.Error("import wrapper: failed to alloc row ID", zap.Error(err)) - return fmt.Errorf("failed to alloc row ID, error: %w", err) - } - - rowIDField := fieldsData[common.RowIDField] - rowIDFieldArr := rowIDField.(*storage.Int64FieldData) - for i := rowIDBegin; i < rowIDEnd; i++ { - rowIDFieldArr.Data = append(rowIDFieldArr.Data, i) - } - - if primaryKey.GetAutoID() { - log.Info("import wrapper: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin)) - - // reset the primary keys, as we know, only int64 pk can be auto-generated - primaryDataArr := &storage.Int64FieldData{ - NumRows: []int64{int64(rowCount)}, - Data: 
make([]int64, 0, rowCount), - } - for i := rowIDBegin; i < rowIDEnd; i++ { - primaryDataArr.Data = append(primaryDataArr.Data, i) - } - - primaryData = primaryDataArr - fieldsData[primaryKey.GetFieldID()] = primaryData - p.importResult.AutoIds = append(p.importResult.AutoIds, rowIDBegin, rowIDEnd) - } - - if primaryData.RowNum() <= 0 { - log.Error("import wrapper: primary key is not provided", zap.String("keyName", primaryKey.GetName())) - return fmt.Errorf("the primary key '%s' is not provided", primaryKey.GetName()) - } - - // prepare segemnts - segmentsData := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum) - for i := 0; i < int(p.shardNum); i++ { - segmentData := initSegmentData(p.collectionSchema) - if segmentData == nil { - log.Error("import wrapper: failed to initialize FieldData list") - return fmt.Errorf("failed to initialize FieldData list") - } - segmentsData = append(segmentsData, segmentData) - } - - // prepare append functions - appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error) - for i := 0; i < len(p.collectionSchema.Fields); i++ { - schema := p.collectionSchema.Fields[i] - appendFuncErr := p.appendFunc(schema) - if appendFuncErr == nil { - log.Error("import wrapper: unsupported field data type") - return fmt.Errorf("unsupported field data type: %d", schema.GetDataType()) - } - appendFunctions[schema.GetName()] = appendFuncErr - } - - // split data into shards - for i := 0; i < rowCount; i++ { - // hash to a shard number - var shard uint32 - pk := primaryData.GetRow(i) - strPK, ok := interface{}(pk).(string) - if ok { - hash := typeutil.HashString2Uint32(strPK) - shard = hash % uint32(p.shardNum) - } else { - intPK, ok := interface{}(pk).(int64) - if !ok { - log.Error("import wrapper: primary key field must be int64 or varchar") - return fmt.Errorf("primary key field must be int64 or varchar") - } - hash, _ := typeutil.Hash32Int64(intPK) - shard = hash % uint32(p.shardNum) - } - - // set rowID field - rowIDField := segmentsData[shard][common.RowIDField].(*storage.Int64FieldData) - rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64)) - - // append row to shard - for k := 0; k < len(p.collectionSchema.Fields); k++ { - schema := p.collectionSchema.Fields[k] - srcData := fieldsData[schema.GetFieldID()] - targetData := segmentsData[shard][schema.GetFieldID()] - appendFunc := appendFunctions[schema.GetName()] - err := appendFunc(srcData, i, targetData) - if err != nil { - return err - } - } - - // when the estimated size is close to blockSize, force flush - err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, false) - if err != nil { - return err - } - } - - // force flush at the end - return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, true) -} - // flushFunc is the callback function for parsers generate segment and save binlog files func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error { // if fields data is empty, do nothing diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 798840e2f8..a3c6217200 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -579,79 +579,6 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) { tr.Record("parse large json file " + filePath) } -func 
Test_ImportWrapperValidateColumnBasedFiles(t *testing.T) { - ctx := context.Background() - - cm := &MockChunkManager{ - size: 1, - } - - idAllocator := newIDAllocator(ctx, t, nil) - shardNum := 2 - segmentSize := 512 // unit: MB - - schema := &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "ID", - IsPrimaryKey: true, - AutoID: true, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 102, - Name: "Age", - IsPrimaryKey: false, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 103, - Name: "Vector", - IsPrimaryKey: false, - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "10"}, - }, - }, - }, - } - - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) - - // file for PK is redundant - files := []string{"ID.npy", "Age.npy", "Vector.npy"} - err := wrapper.validateColumnBasedFiles(files, schema) - assert.NotNil(t, err) - - // file for PK is not redundant - schema.Fields[0].AutoID = false - err = wrapper.validateColumnBasedFiles(files, schema) - assert.Nil(t, err) - - // file missed - files = []string{"Age.npy", "Vector.npy"} - err = wrapper.validateColumnBasedFiles(files, schema) - assert.NotNil(t, err) - - files = []string{"ID.npy", "Vector.npy"} - err = wrapper.validateColumnBasedFiles(files, schema) - assert.NotNil(t, err) - - // redundant file - files = []string{"ID.npy", "Age.npy", "Vector.npy", "dummy.npy"} - err = wrapper.validateColumnBasedFiles(files, schema) - assert.NotNil(t, err) - - // correct input - files = []string{"ID.npy", "Age.npy", "Vector.npy"} - err = wrapper.validateColumnBasedFiles(files, schema) - assert.Nil(t, err) -} - func Test_ImportWrapperFileValidation(t *testing.T) { ctx := context.Background() @@ -668,7 +595,7 @@ func Test_ImportWrapperFileValidation(t *testing.T) { FieldID: 101, Name: "uid", IsPrimaryKey: true, - AutoID: false, + AutoID: true, DataType: schemapb.DataType_Int64, }, { @@ -684,84 +611,76 @@ func Test_ImportWrapperFileValidation(t *testing.T) { wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) - // unsupported file type - files := []string{"uid.txt"} - rowBased, err := wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) + t.Run("unsupported file type", func(t *testing.T) { + files := []string{"uid.txt"} + rowBased, err := wrapper.fileValidation(files) + assert.Error(t, err) + assert.False(t, rowBased) + }) - // file missed - files = []string{"uid.npy"} - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) + t.Run("duplicate files", func(t *testing.T) { + files := []string{"a/1.json", "b/1.json"} + rowBased, err := wrapper.fileValidation(files) + assert.Error(t, err) + assert.True(t, rowBased) - // redundant file - files = []string{"uid.npy", "b/bol.npy", "c/no.npy"} - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) + files = []string{"a/uid.npy", "uid.npy", "b/bol.npy"} + rowBased, err = wrapper.fileValidation(files) + assert.Error(t, err) + assert.False(t, rowBased) + }) - // duplicate files - files = []string{"a/1.json", "b/1.json"} - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.True(t, rowBased) + t.Run("unsupported file for row-based", func(t *testing.T) { + files := []string{"a/uid.json", "b/bol.npy"} + rowBased, 
err := wrapper.fileValidation(files) + assert.Error(t, err) + assert.True(t, rowBased) + }) - files = []string{"a/uid.npy", "uid.npy", "b/bol.npy"} - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) + t.Run("unsupported file for column-based", func(t *testing.T) { + files := []string{"a/uid.npy", "b/bol.json"} + rowBased, err := wrapper.fileValidation(files) + assert.Error(t, err) + assert.False(t, rowBased) + }) - // unsupported file for row-based - files = []string{"a/uid.json", "b/bol.npy"} - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.True(t, rowBased) + t.Run("valid cases", func(t *testing.T) { + files := []string{"a/1.json", "b/2.json"} + rowBased, err := wrapper.fileValidation(files) + assert.NoError(t, err) + assert.True(t, rowBased) - // unsupported file for column-based - files = []string{"a/uid.npy", "b/bol.json"} - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) + files = []string{"a/uid.npy", "b/bol.npy"} + rowBased, err = wrapper.fileValidation(files) + assert.NoError(t, err) + assert.False(t, rowBased) + }) - // valid cases - files = []string{"a/1.json", "b/2.json"} - rowBased, err = wrapper.fileValidation(files) - assert.Nil(t, err) - assert.True(t, rowBased) + t.Run("empty file", func(t *testing.T) { + files := []string{} + cm.size = 0 + wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) + rowBased, err := wrapper.fileValidation(files) + assert.NoError(t, err) + assert.False(t, rowBased) + }) - files = []string{"a/uid.npy", "b/bol.npy"} - rowBased, err = wrapper.fileValidation(files) - assert.Nil(t, err) - assert.False(t, rowBased) + t.Run("file size exceed MaxFileSize limit", func(t *testing.T) { + files := []string{"a/1.json"} + cm.size = MaxFileSize + 1 + wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) + rowBased, err := wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.True(t, rowBased) + }) - // empty file - cm.size = 0 - wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) - - // file size exceed MaxFileSize limit - cm.size = MaxFileSize + 1 - wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) - - // total files size exceed MaxTotalSizeInMemory limit - cm.size = MaxFileSize - 1 - files = append(files, "3.npy") - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) - - // failed to get file size - cm.sizeErr = errors.New("error") - rowBased, err = wrapper.fileValidation(files) - assert.NotNil(t, err) - assert.False(t, rowBased) + t.Run("failed to get file size", func(t *testing.T) { + files := []string{"a/1.json"} + cm.sizeErr = errors.New("error") + rowBased, err := wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.True(t, rowBased) + }) } func Test_ImportWrapperReportFailRowBased(t *testing.T) { @@ -1001,122 +920,6 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) { assert.Nil(t, err) } -func Test_ImportWrapperSplitFieldsData(t *testing.T) { - ctx := context.Background() - - cm := &MockChunkManager{} - - idAllocator := newIDAllocator(ctx, t, nil) - - 
rowCounter := &rowCounterTest{} - assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) - - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - TaskId: 1, - DatanodeId: 1, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, - } - reportFunc := func(res *rootcoordpb.ImportResult) error { - return nil - } - - schema := &schemapb.CollectionSchema{ - Name: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "uid", - IsPrimaryKey: true, - AutoID: true, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 102, - Name: "flag", - IsPrimaryKey: false, - DataType: schemapb.DataType_Bool, - }, - }, - } - - wrapper := NewImportWrapper(ctx, schema, 2, 1024*1024, idAllocator, cm, importResult, reportFunc) - wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) - - // nil input - err := wrapper.splitFieldsData(nil, 0) - assert.NotNil(t, err) - - // split 100 rows to 4 blocks, success - rowCount := 100 - input := initSegmentData(schema) - for j := 0; j < rowCount; j++ { - pkField := input[101].(*storage.Int64FieldData) - pkField.Data = append(pkField.Data, int64(j)) - - flagField := input[102].(*storage.BoolFieldData) - flagField.Data = append(flagField.Data, true) - } - - err = wrapper.splitFieldsData(input, 512) - assert.Nil(t, err) - assert.Equal(t, 2, len(importResult.AutoIds)) - assert.Equal(t, 4, rowCounter.callTime) - assert.Equal(t, rowCount, rowCounter.rowCount) - - // alloc id failed - wrapper.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error")) - err = wrapper.splitFieldsData(input, 512) - assert.NotNil(t, err) - wrapper.rowIDAllocator = newIDAllocator(ctx, t, nil) - - // row count of fields are unequal - schema.Fields[0].AutoID = false - input = initSegmentData(schema) - for j := 0; j < rowCount; j++ { - pkField := input[101].(*storage.Int64FieldData) - pkField.Data = append(pkField.Data, int64(j)) - if j%2 == 0 { - continue - } - flagField := input[102].(*storage.BoolFieldData) - flagField.Data = append(flagField.Data, true) - } - err = wrapper.splitFieldsData(input, 512) - assert.NotNil(t, err) - - // primary key not found - wrapper.collectionSchema.Fields[0].IsPrimaryKey = false - err = wrapper.splitFieldsData(input, 512) - assert.NotNil(t, err) - wrapper.collectionSchema.Fields[0].IsPrimaryKey = true - - // primary key is varchar, success - wrapper.collectionSchema.Fields[0].DataType = schemapb.DataType_VarChar - input = initSegmentData(schema) - for j := 0; j < rowCount; j++ { - pkField := input[101].(*storage.StringFieldData) - pkField.Data = append(pkField.Data, strconv.FormatInt(int64(j), 10)) - - flagField := input[102].(*storage.BoolFieldData) - flagField.Data = append(flagField.Data, true) - } - rowCounter.callTime = 0 - rowCounter.rowCount = 0 - importResult.AutoIds = []int64{} - err = wrapper.splitFieldsData(input, 1024) - assert.Nil(t, err) - assert.Equal(t, 0, len(importResult.AutoIds)) - assert.Equal(t, 2, rowCounter.callTime) - assert.Equal(t, rowCount, rowCounter.rowCount) -} - func Test_ImportWrapperReportPersisted(t *testing.T) { ctx := context.Background() tr := timerecord.NewTimeRecorder("test") diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index ffe48cc04d..1e76865317 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -35,15 +35,11 @@ 
import ( const ( // root field of row-based json format RowRootNode = "rows" - // minimal size of a buffer - MinBufferSize = 1024 - // split file into batches no more than this count - MaxBatchCount = 16 ) type JSONParser struct { ctx context.Context // for canceling parse process - bufSize int64 // max rows in a buffer + bufRowCount int // max rows in a buffer fields map[string]int64 // fields need to be parsed name2FieldID map[string]storage.FieldID } @@ -69,7 +65,7 @@ func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSch parser := &JSONParser{ ctx: ctx, - bufSize: MinBufferSize, + bufRowCount: 1024, fields: fields, name2FieldID: name2FieldID, } @@ -84,19 +80,24 @@ func adjustBufSize(parser *JSONParser, collectionSchema *schemapb.CollectionSche return } - // split the file into no more than MaxBatchCount batches to parse // for high dimensional vector, the bufSize is a small value, read few rows each time // for low dimensional vector, the bufSize is a large value, read more rows each time - maxRows := MaxFileSize / sizePerRecord - bufSize := maxRows / MaxBatchCount - - // bufSize should not be less than MinBufferSize - if bufSize < MinBufferSize { - bufSize = MinBufferSize + bufRowCount := parser.bufRowCount + for { + if bufRowCount*sizePerRecord > SingleBlockSize { + bufRowCount-- + } else { + break + } } - log.Info("JSON parser: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize)) - parser.bufSize = int64(bufSize) + // at least one row per buffer + if bufRowCount <= 0 { + bufRowCount = 1 + } + + log.Info("JSON parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount)) + parser.bufRowCount = bufRowCount } func (p *JSONParser) verifyRow(raw interface{}) (map[storage.FieldID]interface{}, error) { @@ -185,7 +186,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { } // read buffer - buf := make([]map[storage.FieldID]interface{}, 0, MinBufferSize) + buf := make([]map[storage.FieldID]interface{}, 0, p.bufRowCount) for dec.More() { var value interface{} if err := dec.Decode(&value); err != nil { @@ -199,7 +200,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { } buf = append(buf, row) - if len(buf) >= int(p.bufSize) { + if len(buf) >= p.bufRowCount { isEmpty = false if err = handler.Handle(buf); err != nil { log.Error("JSON parser: failed to convert row value to entity", zap.Error(err)) @@ -207,7 +208,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { } // clear the buffer - buf = make([]map[storage.FieldID]interface{}, 0, MinBufferSize) + buf = make([]map[storage.FieldID]interface{}, 0, p.bufRowCount) } } diff --git a/internal/util/importutil/json_parser_test.go b/internal/util/importutil/json_parser_test.go index 045b4f86fc..1c6f91caeb 100644 --- a/internal/util/importutil/json_parser_test.go +++ b/internal/util/importutil/json_parser_test.go @@ -28,7 +28,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" ) @@ -58,11 +57,7 @@ func Test_AdjustBufSize(t *testing.T) { schema := sampleSchema() parser := NewJSONParser(ctx, schema) assert.NotNil(t, parser) - - sizePerRecord, err := typeutil.EstimateSizePerRecord(schema) - assert.Nil(t, err) - assert.Greater(t, sizePerRecord, 0) - assert.Equal(t, 
MaxBatchCount, MaxFileSize/(sizePerRecord*int(parser.bufSize))) + assert.Greater(t, parser.bufRowCount, 0) // huge row schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{ @@ -70,9 +65,7 @@ func Test_AdjustBufSize(t *testing.T) { } parser = NewJSONParser(ctx, schema) assert.NotNil(t, parser) - sizePerRecord, _ = typeutil.EstimateSizePerRecord(schema) - - assert.Equal(t, 7, MaxFileSize/(sizePerRecord*int(parser.bufSize))) + assert.Greater(t, parser.bufRowCount, 0) // no change schema = &schemapb.CollectionSchema{ @@ -83,8 +76,7 @@ func Test_AdjustBufSize(t *testing.T) { } parser = NewJSONParser(ctx, schema) assert.NotNil(t, parser) - - assert.Equal(t, int64(MinBufferSize), parser.bufSize) + assert.Greater(t, parser.bufRowCount, 0) } func Test_JSONParserParseRows_IntPK(t *testing.T) { @@ -127,8 +119,8 @@ func Test_JSONParserParseRows_IntPK(t *testing.T) { } t.Run("parse success", func(t *testing.T) { - // set bufSize = 4, means call handle() after reading 4 rows - parser.bufSize = 4 + // set bufRowCount = 4, means call handle() after reading 4 rows + parser.bufRowCount = 4 err = parser.ParseRows(reader, consumer) assert.Nil(t, err) assert.Equal(t, len(content.Rows), len(consumer.rows)) @@ -285,12 +277,12 @@ func Test_JSONParserParseRows_IntPK(t *testing.T) { }` consumer.handleErr = errors.New("error") reader = strings.NewReader(content) - parser.bufSize = 2 + parser.bufRowCount = 2 err = parser.ParseRows(reader, consumer) assert.NotNil(t, err) reader = strings.NewReader(content) - parser.bufSize = 5 + parser.bufRowCount = 5 err = parser.ParseRows(reader, consumer) assert.NotNil(t, err) diff --git a/internal/util/importutil/numpy_adapter.go b/internal/util/importutil/numpy_adapter.go index 4efb183a16..19ca7374a3 100644 --- a/internal/util/importutil/numpy_adapter.go +++ b/internal/util/importutil/numpy_adapter.go @@ -86,11 +86,13 @@ type NumpyAdapter struct { func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) { r, err := npyio.NewReader(reader) if err != nil { + log.Error("Numpy adapter: failed to read numpy header", zap.Error(err)) return nil, err } dataType, err := convertNumpyType(r.Header.Descr.Type) if err != nil { + log.Error("Numpy adapter: failed to detect data type", zap.Error(err)) return nil, err } @@ -109,12 +111,11 @@ func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) { zap.Uint8("minorVer", r.Header.Minor), zap.String("ByteOrder", adapter.order.String())) - return adapter, err + return adapter, nil } // convertNumpyType gets data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector) func convertNumpyType(typeStr string) (schemapb.DataType, error) { - log.Info("Numpy adapter: parse numpy file dtype", zap.String("dtype", typeStr)) switch typeStr { case "b1", " 0 { buf = buf[:n] } + data = append(data, string(buf)) } } @@ -557,6 +608,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) { func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) { if len(src)%4 != 0 { + log.Error("Numpy adapter: invalid utf32 bytes length, the byte array length should be multiple of 4", zap.Int("byteLen", len(src))) return "", fmt.Errorf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src)) } @@ -586,6 +638,7 @@ func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) { decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder() res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2]) if err != nil { + log.Error("Numpy 
adapter: failed to decode utf32 binary bytes", zap.Error(err)) return "", fmt.Errorf("failed to decode utf32 binary bytes, error: %w", err) } str += string(res) @@ -603,6 +656,7 @@ func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) { utf8Code := make([]byte, 4) utf8.EncodeRune(utf8Code, r) if r == utf8.RuneError { + log.Error("Numpy adapter: failed to convert 4 bytes unicode to utf8 rune", zap.Uint32("code", x)) return "", fmt.Errorf("failed to convert 4 bytes unicode %d to utf8 rune", x) } str += string(utf8Code) diff --git a/internal/util/importutil/numpy_adapter_test.go b/internal/util/importutil/numpy_adapter_test.go index f352acdd92..d47b32f408 100644 --- a/internal/util/importutil/numpy_adapter_test.go +++ b/internal/util/importutil/numpy_adapter_test.go @@ -17,6 +17,7 @@ package importutil import ( + "bytes" "encoding/binary" "io" "os" @@ -40,12 +41,12 @@ func Test_CreateNumpyFile(t *testing.T) { // directory doesn't exist data1 := []float32{1, 2, 3, 4, 5} err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1) - assert.NotNil(t, err) + assert.Error(t, err) // invalid data type data2 := make(map[string]int) err = CreateNumpyFile("/tmp/dummy.npy", data2) - assert.NotNil(t, err) + assert.Error(t, err) } func Test_CreateNumpyData(t *testing.T) { @@ -53,12 +54,12 @@ func Test_CreateNumpyData(t *testing.T) { data1 := []float32{1, 2, 3, 4, 5} buf, err := CreateNumpyData(data1) assert.NotNil(t, buf) - assert.Nil(t, err) + assert.NoError(t, err) // invalid data type data2 := make(map[string]int) buf, err = CreateNumpyData(data2) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, buf) } @@ -66,7 +67,7 @@ func Test_ConvertNumpyType(t *testing.T) { checkFunc := func(inputs []string, output schemapb.DataType) { for i := 0; i < len(inputs); i++ { dt, err := convertNumpyType(inputs[i]) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, output, dt) } } @@ -80,7 +81,7 @@ func Test_ConvertNumpyType(t *testing.T) { checkFunc([]string{"f8", "f8", "float64"}, schemapb.DataType_Double) dt, err := convertNumpyType("dummy") - assert.NotNil(t, err) + assert.Error(t, err) assert.Equal(t, schemapb.DataType_None, dt) } @@ -88,25 +89,25 @@ func Test_StringLen(t *testing.T) { len, utf, err := stringLen("S1") assert.Equal(t, 1, len) assert.False(t, utf) - assert.Nil(t, err) + assert.NoError(t, err) len, utf, err = stringLen("2S") assert.Equal(t, 2, len) assert.False(t, utf) - assert.Nil(t, err) + assert.NoError(t, err) len, utf, err = stringLen("4U") assert.Equal(t, 4, len) assert.True(t, utf) - assert.Nil(t, err) + assert.NoError(t, err) len, utf, err = stringLen("dummy") - assert.NotNil(t, err) + assert.Error(t, err) assert.Equal(t, 0, len) assert.False(t, utf) } @@ -129,207 +130,337 @@ func Test_NumpyAdapterSetByteOrder(t *testing.T) { } func Test_NumpyAdapterReadError(t *testing.T) { - adapter := &NumpyAdapter{ - reader: nil, - npyReader: nil, - } - // reader size is zero - t.Run("test size is zero", func(t *testing.T) { - _, err := adapter.ReadBool(0) - assert.NotNil(t, err) - _, err = adapter.ReadUint8(0) - assert.NotNil(t, err) - _, err = adapter.ReadInt8(0) - assert.NotNil(t, err) - _, err = adapter.ReadInt16(0) - assert.NotNil(t, err) - _, err = adapter.ReadInt32(0) - assert.NotNil(t, err) - _, err = adapter.ReadInt64(0) - assert.NotNil(t, err) - _, err = adapter.ReadFloat32(0) - assert.NotNil(t, err) - _, err = adapter.ReadFloat64(0) - assert.NotNil(t, err) - }) + // t.Run("test size is zero", func(t *testing.T) { + // adapter := &NumpyAdapter{ + // reader: 
nil, + // npyReader: nil, + // } + // _, err := adapter.ReadBool(0) + // assert.Error(t, err) + // _, err = adapter.ReadUint8(0) + // assert.Error(t, err) + // _, err = adapter.ReadInt8(0) + // assert.Error(t, err) + // _, err = adapter.ReadInt16(0) + // assert.Error(t, err) + // _, err = adapter.ReadInt32(0) + // assert.Error(t, err) + // _, err = adapter.ReadInt64(0) + // assert.Error(t, err) + // _, err = adapter.ReadFloat32(0) + // assert.Error(t, err) + // _, err = adapter.ReadFloat64(0) + // assert.Error(t, err) + // }) - adapter = &NumpyAdapter{ - reader: &MockReader{}, - npyReader: &npy.Reader{}, + createAdatper := func(dt schemapb.DataType) *NumpyAdapter { + adapter := &NumpyAdapter{ + reader: &MockReader{}, + npyReader: &npy.Reader{ + Header: npy.Header{}, + }, + dataType: dt, + order: binary.BigEndian, + } + adapter.npyReader.Header.Descr.Shape = []int{1} + return adapter } t.Run("test read bool", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "bool" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadBool(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Bool) data, err = adapter.ReadBool(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1}) + data, err = adapter.ReadBool(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadBool(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read uint8", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "u1" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) + adapter.npyReader.Header.Descr.Type = "dummy" data, err := adapter.ReadUint8(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter.npyReader.Header.Descr.Type = "u1" data, err = adapter.ReadUint8(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1}) + data, err = adapter.ReadUint8(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadUint8(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read int8", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "i1" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadInt8(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Int8) data, err = adapter.ReadInt8(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1}) + data, err = adapter.ReadInt8(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadInt8(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read int16", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "i2" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadInt16(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // 
reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Int16) data, err = adapter.ReadInt16(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1, 2}) + data, err = adapter.ReadInt16(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadInt16(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read int32", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "i4" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadInt32(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Int32) data, err = adapter.ReadInt32(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4}) + data, err = adapter.ReadInt32(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadInt32(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read int64", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "i8" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadInt64(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Int64) data, err = adapter.ReadInt64(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + data, err = adapter.ReadInt64(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadInt64(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read float", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "f4" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadFloat32(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Float) data, err = adapter.ReadFloat32(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4}) + data, err = adapter.ReadFloat32(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadFloat32(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read double", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "f8" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadFloat64(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_Double) data, err = adapter.ReadFloat64(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + data, err = adapter.ReadFloat64(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = 
adapter.ReadFloat64(1) + assert.Nil(t, data) + assert.NoError(t, err) }) t.Run("test read varchar", func(t *testing.T) { - adapter.npyReader.Header.Descr.Type = "U3" + // type mismatch + adapter := createAdatper(schemapb.DataType_None) data, err := adapter.ReadString(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) - adapter.npyReader.Header.Descr.Type = "dummy" + // reader is nil, cannot read + adapter = createAdatper(schemapb.DataType_VarChar) + adapter.reader = nil + adapter.npyReader.Header.Descr.Type = "S3" data, err = adapter.ReadString(1) assert.Nil(t, data) - assert.NotNil(t, err) + assert.Error(t, err) + + // read one element from reader + adapter.reader = strings.NewReader("abc") + data, err = adapter.ReadString(1) + assert.NotEmpty(t, data) + assert.NoError(t, err) + + // nothing to read + data, err = adapter.ReadString(1) + assert.Nil(t, data) + assert.NoError(t, err) }) } func Test_NumpyAdapterRead(t *testing.T) { err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) + assert.NoError(t, err) defer os.RemoveAll(TempFilesPath) t.Run("test read bool", func(t *testing.T) { filePath := TempFilesPath + "bool.npy" data := []bool{true, false, true, false} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadBool(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadBool(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadBool(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) // incorrect type read resu1, err := adapter.ReadUint8(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resu1) resi1, err := adapter.ReadInt8(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resi1) resi2, err := adapter.ReadInt16(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resi2) resi4, err := adapter.ReadInt32(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resi4) resi8, err := adapter.ReadInt64(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resi8) resf4, err := adapter.ReadFloat32(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resf4) resf8, err := adapter.ReadFloat64(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resf8) }) @@ -337,35 +468,38 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "uint8.npy" data := []uint8{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadUint8(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadUint8(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, 
len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadUint8(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) // incorrect type read resb, err := adapter.ReadBool(len(data)) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, resb) }) @@ -373,30 +507,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "int8.npy" data := []int8{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadInt8(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadInt8(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadInt8(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -404,30 +541,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "int16.npy" data := []int16{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadInt16(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadInt16(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadInt16(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -435,30 +575,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "int32.npy" data := []int32{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadInt32(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadInt32(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadInt32(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -466,30 +609,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "int64.npy" data := []int64{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := 
adapter.ReadInt64(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadInt64(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadInt64(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -497,30 +643,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "float.npy" data := []float32{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadFloat32(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadFloat32(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadFloat32(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -528,30 +677,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "double.npy" data := []float64{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + // partly read res, err := adapter.ReadFloat64(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadFloat64(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadFloat64(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -589,19 +741,19 @@ func Test_NumpyAdapterRead(t *testing.T) { // count should greater than 0 res, err := adapter.ReadString(0) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, res) // maxLen is zero npyReader.Header.Descr.Type = "S0" res, err = adapter.ReadString(1) - assert.NotNil(t, err) + assert.Error(t, err) assert.Nil(t, res) npyReader.Header.Descr.Type = "S" + strconv.FormatInt(int64(maxLen), 10) res, err = adapter.ReadString(len(values) + 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(values), len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, values[i], res[i]) @@ -612,29 +764,33 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "varchar1.npy" data := []string{"a ", "bbb", " c", "dd", "eeee", "fff"} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) + + // partly read res, err := adapter.ReadString(len(data) - 1) - assert.Nil(t, err) + assert.NoError(t, err) 
assert.Equal(t, len(data)-1, len(res)) for i := 0; i < len(res); i++ { assert.Equal(t, data[i], res[i]) } + // read the left data res, err = adapter.ReadString(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, data[len(data)-1], res[0]) + // nothing to read res, err = adapter.ReadString(len(data)) - assert.NotNil(t, err) + assert.NoError(t, err) assert.Nil(t, res) }) @@ -642,16 +798,16 @@ func Test_NumpyAdapterRead(t *testing.T) { filePath := TempFilesPath + "varchar2.npy" data := []string{"で と ど ", " 马克bbb", "$(한)삼각*"} err := CreateNumpyFile(filePath, data) - assert.Nil(t, err) + assert.NoError(t, err) file, err := os.Open(filePath) - assert.Nil(t, err) + assert.NoError(t, err) defer file.Close() adapter, err := NewNumpyAdapter(file) - assert.Nil(t, err) + assert.NoError(t, err) res, err := adapter.ReadString(len(data)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, len(data), len(res)) for i := 0; i < len(res); i++ { @@ -663,7 +819,7 @@ func Test_NumpyAdapterRead(t *testing.T) { func Test_DecodeUtf32(t *testing.T) { // wrong input res, err := decodeUtf32([]byte{1, 2}, binary.LittleEndian) - assert.NotNil(t, err) + assert.Error(t, err) assert.Empty(t, res) // this string contains ascii characters and unicode characters @@ -672,12 +828,12 @@ func Test_DecodeUtf32(t *testing.T) { // utf32 littleEndian of str src := []byte{97, 0, 0, 0, 100, 0, 0, 0, 228, 37, 0, 0, 9, 78, 0, 0, 126, 118, 0, 0, 181, 243, 1, 0, 144, 48, 0, 0, 153, 33, 0, 0} res, err = decodeUtf32(src, binary.LittleEndian) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, str, res) // utf32 bigEndian of str src = []byte{0, 0, 0, 97, 0, 0, 0, 100, 0, 0, 37, 228, 0, 0, 78, 9, 0, 0, 118, 126, 0, 1, 243, 181, 0, 0, 48, 144, 0, 0, 33, 153} res, err = decodeUtf32(src, binary.BigEndian) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, str, res) } diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go index abd59d908b..cc8ed552bc 100644 --- a/internal/util/importutil/numpy_parser.go +++ b/internal/util/importutil/numpy_parser.go @@ -20,280 +20,527 @@ import ( "context" "errors" "fmt" - "io" "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/timerecord" + "github.com/milvus-io/milvus/internal/util/typeutil" "go.uber.org/zap" ) -type ColumnDesc struct { - name string // name of the target column - dt schemapb.DataType // data type of the target column - elementCount int // how many elements need to be read - dimension int // only for vector +type NumpyColumnReader struct { + fieldName string // name of the target column + fieldID storage.FieldID // ID of the target column + dataType schemapb.DataType // data type of the target column + rowCount int // how many rows need to be read + dimension int // only for vector + file storage.FileReader // file to be read + reader *NumpyAdapter // data reader +} + +func closeReaders(columnReaders []*NumpyColumnReader) { + for _, reader := range columnReaders { + if reader.file != nil { + err := reader.file.Close() + if err != nil { + log.Error("Numper parser: failed to close numpy file", zap.String("fileName", reader.fieldName+NumpyFileExt)) + } + } + } } type NumpyParser struct { ctx context.Context // for canceling parse process 
collectionSchema *schemapb.CollectionSchema // collection schema - columnDesc *ColumnDesc // description for target column - - columnData storage.FieldData // in-memory column data - callFlushFunc func(field storage.FieldData) error // call back function to output column data + rowIDAllocator *allocator.IDAllocator // autoid allocator + shardNum int32 // sharding number of the collection + blockSize int64 // maximum size of a read block(unit:byte) + chunkManager storage.ChunkManager // storage interfaces to browse/read the files + autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25 + callFlushFunc ImportFlushFunc // call back function to flush segment } // NewNumpyParser is helper function to create a NumpyParser -func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema, - flushFunc func(field storage.FieldData) error) *NumpyParser { - if collectionSchema == nil || flushFunc == nil { - return nil +func NewNumpyParser(ctx context.Context, + collectionSchema *schemapb.CollectionSchema, + idAlloc *allocator.IDAllocator, + shardNum int32, + blockSize int64, + chunkManager storage.ChunkManager, + flushFunc ImportFlushFunc) (*NumpyParser, error) { + if collectionSchema == nil { + log.Error("Numper parser: collection schema is nil") + return nil, errors.New("collection schema is nil") + } + + if idAlloc == nil { + log.Error("Numper parser: id allocator is nil") + return nil, errors.New("id allocator is nil") + } + + if chunkManager == nil { + log.Error("Numper parser: chunk manager pointer is nil") + return nil, errors.New("chunk manager pointer is nil") + } + + if flushFunc == nil { + log.Error("Numper parser: flush function is nil") + return nil, errors.New("flush function is nil") } parser := &NumpyParser{ ctx: ctx, collectionSchema: collectionSchema, - columnDesc: &ColumnDesc{}, + rowIDAllocator: idAlloc, + shardNum: shardNum, + blockSize: blockSize, + chunkManager: chunkManager, + autoIDRange: make([]int64, 0), callFlushFunc: flushFunc, } - return parser + return parser, nil } -func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { - if adapter == nil { - log.Error("Numpy parser: numpy adapter is nil") - return errors.New("numpy adapter is nil") +func (p *NumpyParser) IDRange() []int64 { + return p.autoIDRange +} + +// Parse is the function entry +func (p *NumpyParser) Parse(filePaths []string) error { + // check redundant files for column-based import + // if the field is primary key and autoid is false, the file is required + // any redundant file is not allowed + err := p.validateFileNames(filePaths) + if err != nil { + return err } - // check existence of the target field - var schema *schemapb.FieldSchema - for i := 0; i < len(p.collectionSchema.Fields); i++ { - schema = p.collectionSchema.Fields[i] - if schema.GetName() == fieldName { - p.columnDesc.name = fieldName - break - } + // open files and verify file header + readers, err := p.createReaders(filePaths) + // make sure all the files are closed finially, must call this method before the function return + defer closeReaders(readers) + if err != nil { + return err } - if p.columnDesc.name == "" { - log.Error("Numpy parser: Numpy parser: the field is not found in collection schema", zap.String("fieldName", fieldName)) - return fmt.Errorf("the field name '%s' is not found in collection schema", fieldName) - } - - p.columnDesc.dt = schema.DataType - elementType := adapter.GetType() - shape := adapter.GetShape() - - var err 
error - // 1. field data type should be consist to numpy data type - // 2. vector field dimension should be consist to numpy shape - if schemapb.DataType_FloatVector == schema.DataType { - // float32/float64 numpy file can be used for float vector file, 2 reasons: - // 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit - // 2. for float64 numpy file, the performance is worse than float32 numpy file - if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double { - log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType), - zap.String("fieldName", fieldName)) - return fmt.Errorf("illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), schema.GetName()) - } - - // vector field, the shape should be 2 - if len(shape) != 2 { - log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)), - zap.String("fieldName", fieldName)) - return fmt.Errorf("illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, schema.GetName()) - } - - // shape[0] is row count, shape[1] is element count per row - p.columnDesc.elementCount = shape[0] * shape[1] - - p.columnDesc.dimension, err = getFieldDimension(schema) - if err != nil { - return err - } - - if shape[1] != p.columnDesc.dimension { - log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName), - zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", p.columnDesc.dimension)) - return fmt.Errorf("illegal dimension %d of numpy file for float vector field '%s', dimension should be %d", - shape[1], schema.GetName(), p.columnDesc.dimension) - } - } else if schemapb.DataType_BinaryVector == schema.DataType { - if elementType != schemapb.DataType_BinaryVector { - log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType), - zap.String("fieldName", fieldName)) - return fmt.Errorf("illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), schema.GetName()) - } - - // vector field, the shape should be 2 - if len(shape) != 2 { - log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)), - zap.String("fieldName", fieldName)) - return fmt.Errorf("illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, schema.GetName()) - } - - // shape[0] is row count, shape[1] is element count per row - p.columnDesc.elementCount = shape[0] * shape[1] - - p.columnDesc.dimension, err = getFieldDimension(schema) - if err != nil { - return err - } - - if shape[1] != p.columnDesc.dimension/8 { - log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName), - zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", p.columnDesc.dimension)) - return fmt.Errorf("illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d", - shape[1]*8, schema.GetName(), p.columnDesc.dimension) - } - } else { - if elementType != schema.DataType { - log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType), - zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType)) - return fmt.Errorf("illegal data type %s of numpy file for scalar 
field '%s' with type %s", - getTypeName(elementType), schema.GetName(), getTypeName(schema.DataType)) - } - - // scalar field, the shape should be 1 - if len(shape) != 1 { - log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)), - zap.String("fieldName", fieldName)) - return fmt.Errorf("illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, schema.GetName()) - } - - p.columnDesc.elementCount = shape[0] + // read all data from the numpy files + err = p.consume(readers) + if err != nil { + return err } return nil } +// validateFileNames is to check redundant file and missed file +func (p *NumpyParser) validateFileNames(filePaths []string) error { + requiredFieldNames := make(map[string]interface{}) + for _, schema := range p.collectionSchema.Fields { + if schema.GetIsPrimaryKey() { + if !schema.GetAutoID() { + requiredFieldNames[schema.GetName()] = nil + } + } else { + requiredFieldNames[schema.GetName()] = nil + } + } + + // check redundant file + fileNames := make(map[string]interface{}) + for _, filePath := range filePaths { + name, _ := GetFileNameAndExt(filePath) + fileNames[name] = nil + _, ok := requiredFieldNames[name] + if !ok { + log.Error("Numpy parser: the file has no corresponding field in collection", zap.String("fieldName", name)) + return fmt.Errorf("the file '%s' has no corresponding field in collection", filePath) + } + } + + // check missed file + for name := range requiredFieldNames { + _, ok := fileNames[name] + if !ok { + log.Error("Numpy parser: there is no file corresponding to field", zap.String("fieldName", name)) + return fmt.Errorf("there is no file corresponding to field '%s'", name) + } + } + + return nil +} + +// createReaders open the files and verify file header +func (p *NumpyParser) createReaders(filePaths []string) ([]*NumpyColumnReader, error) { + readers := make([]*NumpyColumnReader, 0) + + for _, filePath := range filePaths { + fileName, _ := GetFileNameAndExt(filePath) + + // check existence of the target field + var schema *schemapb.FieldSchema + for i := 0; i < len(p.collectionSchema.Fields); i++ { + tmpSchema := p.collectionSchema.Fields[i] + if tmpSchema.GetName() == fileName { + schema = tmpSchema + break + } + } + + if schema == nil { + log.Error("Numpy parser: the field is not found in collection schema", zap.String("fileName", fileName)) + return nil, fmt.Errorf("the field name '%s' is not found in collection schema", fileName) + } + + file, err := p.chunkManager.Reader(p.ctx, filePath) + if err != nil { + log.Error("Numpy parser: failed to read the file", zap.String("filePath", filePath), zap.Error(err)) + return nil, fmt.Errorf("failed to read the file '%s', error: %s", filePath, err.Error()) + } + + adapter, err := NewNumpyAdapter(file) + if err != nil { + log.Error("Numpy parser: failed to read the file header", zap.String("filePath", filePath), zap.Error(err)) + return nil, fmt.Errorf("failed to read the file header '%s', error: %s", filePath, err.Error()) + } + + if file == nil || adapter == nil { + log.Error("Numpy parser: failed to open file", zap.String("filePath", filePath)) + return nil, fmt.Errorf("failed to open file '%s'", filePath) + } + + dim, _ := getFieldDimension(schema) + columnReader := &NumpyColumnReader{ + fieldName: schema.GetName(), + fieldID: schema.GetFieldID(), + dataType: schema.GetDataType(), + dimension: dim, + file: file, + reader: adapter, + } + + // the validation method only check the file header information + err = 
p.validateHeader(columnReader) + if err != nil { + return nil, err + } + readers = append(readers, columnReader) + } + + // row count of each file should be equal + if len(readers) > 0 { + firstReader := readers[0] + rowCount := firstReader.rowCount + for i := 1; i < len(readers); i++ { + compareReader := readers[i] + if rowCount != compareReader.rowCount { + log.Error("Numpy parser: the row count of files are not equal", + zap.String("firstFile", firstReader.fieldName), zap.Int("firstRowCount", firstReader.rowCount), + zap.String("compareFile", compareReader.fieldName), zap.Int("compareRowCount", compareReader.rowCount)) + return nil, fmt.Errorf("the row count(%d) of file '%s.npy' is not equal to row count(%d) of file '%s.npy'", + firstReader.rowCount, firstReader.fieldName, compareReader.rowCount, compareReader.fieldName) + } + } + } + + return readers, nil +} + +// validateHeader is to verify numpy file header, file header information should match field's schema +func (p *NumpyParser) validateHeader(columnReader *NumpyColumnReader) error { + if columnReader == nil || columnReader.reader == nil { + log.Error("Numpy parser: numpy reader is nil") + return errors.New("numpy adapter is nil") + } + + elementType := columnReader.reader.GetType() + shape := columnReader.reader.GetShape() + columnReader.rowCount = shape[0] + + // 1. field data type should be consist to numpy data type + // 2. vector field dimension should be consist to numpy shape + if schemapb.DataType_FloatVector == columnReader.dataType { + // float32/float64 numpy file can be used for float vector file, 2 reasons: + // 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit + // 2. for float64 numpy file, the performance is worse than float32 numpy file + if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double { + log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType), + zap.String("fieldName", columnReader.fieldName)) + return fmt.Errorf("illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), + columnReader.fieldName) + } + + // vector field, the shape should be 2 + if len(shape) != 2 { + log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)), + zap.String("fieldName", columnReader.fieldName)) + return fmt.Errorf("illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, + columnReader.fieldName) + } + + if shape[1] != columnReader.dimension { + log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", columnReader.fieldName), + zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", columnReader.dimension)) + return fmt.Errorf("illegal dimension %d of numpy file for float vector field '%s', dimension should be %d", + shape[1], columnReader.fieldName, columnReader.dimension) + } + } else if schemapb.DataType_BinaryVector == columnReader.dataType { + if elementType != schemapb.DataType_BinaryVector { + log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType), + zap.String("fieldName", columnReader.fieldName)) + return fmt.Errorf("illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), + columnReader.fieldName) + } + + // vector field, the shape should be 2 + if len(shape) != 2 { + log.Error("Numpy parser: illegal 
shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)), + zap.String("fieldName", columnReader.fieldName)) + return fmt.Errorf("illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, + columnReader.fieldName) + } + + if shape[1] != columnReader.dimension/8 { + log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", columnReader.fieldName), + zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", columnReader.dimension)) + return fmt.Errorf("illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d", + shape[1]*8, columnReader.fieldName, columnReader.dimension) + } + } else { + if elementType != columnReader.dataType { + log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType), + zap.String("fieldName", columnReader.fieldName), zap.Any("fieldDataType", columnReader.dataType)) + return fmt.Errorf("illegal data type %s of numpy file for scalar field '%s' with type %s", + getTypeName(elementType), columnReader.fieldName, getTypeName(columnReader.dataType)) + } + + // scalar field, the shape should be 1 + if len(shape) != 1 { + log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)), + zap.String("fieldName", columnReader.fieldName)) + return fmt.Errorf("illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, columnReader.fieldName) + } + } + + return nil +} + +// calcRowCountPerBlock calculates a proper value for a batch row count to read file +func (p *NumpyParser) calcRowCountPerBlock() (int64, error) { + sizePerRecord, err := typeutil.EstimateSizePerRecord(p.collectionSchema) + if err != nil { + log.Error("Numpy parser: failed to estimate size of each row", zap.Error(err)) + return 0, fmt.Errorf("failed to estimate size of each row: %s", err.Error()) + } + + if sizePerRecord <= 0 { + log.Error("Numpy parser: failed to estimate size of each row, the collection schema might be empty") + return 0, fmt.Errorf("failed to estimate size of each row: the collection schema might be empty") + } + + // the sizePerRecord is estimate value, if the schema contains varchar field, the value is not accurate + // we will read data block by block, by default, each block size is 16MB + // rowCountPerBlock is the estimated row count for a block + rowCountPerBlock := p.blockSize / int64(sizePerRecord) + if rowCountPerBlock <= 0 { + rowCountPerBlock = 1 // make sure the value is positive + } + + log.Info("Numper parser: calculate row count per block to read file", zap.Int64("rowCountPerBlock", rowCountPerBlock), + zap.Int64("blockSize", p.blockSize), zap.Int("sizePerRecord", sizePerRecord)) + return rowCountPerBlock, nil +} + // consume method reads numpy data section into a storage.FieldData // please note it will require a large memory block(the memory size is almost equal to numpy file size) -func (p *NumpyParser) consume(adapter *NumpyAdapter) error { - switch p.columnDesc.dt { +func (p *NumpyParser) consume(columnReaders []*NumpyColumnReader) error { + rowCountPerBlock, err := p.calcRowCountPerBlock() + if err != nil { + return err + } + + // prepare shards + shards := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum) + for i := 0; i < int(p.shardNum); i++ { + segmentData := initSegmentData(p.collectionSchema) + if segmentData == nil { + log.Error("import wrapper: failed to initialize FieldData 
list") + return fmt.Errorf("failed to initialize FieldData list") + } + shards = append(shards, segmentData) + } + tr := timerecord.NewTimeRecorder("consume performance") + defer tr.Elapse("end") + // read data from files, batch by batch + for { + readRowCount := 0 + segmentData := make(map[storage.FieldID]storage.FieldData) + for _, reader := range columnReaders { + fieldData, err := p.readData(reader, int(rowCountPerBlock)) + if err != nil { + return err + } + + if readRowCount == 0 { + readRowCount = fieldData.RowNum() + } else if readRowCount != fieldData.RowNum() { + log.Error("Numpy parser: data block's row count mismatch", zap.Int("firstBlockRowCount", readRowCount), + zap.Int("thisBlockRowCount", fieldData.RowNum()), zap.Int64("rowCountPerBlock", rowCountPerBlock)) + return fmt.Errorf("data block's row count mismatch: %d vs %d", readRowCount, fieldData.RowNum()) + } + + segmentData[reader.fieldID] = fieldData + } + + // nothing to read + if readRowCount == 0 { + break + } + tr.Record("readData") + // split data to shards + err = p.splitFieldsData(segmentData, shards) + if err != nil { + return err + } + tr.Record("splitFieldsData") + // when the estimated size is close to blockSize, save to binlog + err = tryFlushBlocks(p.ctx, shards, p.collectionSchema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, false) + if err != nil { + return err + } + tr.Record("tryFlushBlocks") + } + + // force flush at the end + return tryFlushBlocks(p.ctx, shards, p.collectionSchema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, true) +} + +// readData method reads numpy data section into a storage.FieldData +func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (storage.FieldData, error) { + switch columnReader.dataType { case schemapb.DataType_Bool: - data, err := adapter.ReadBool(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadBool(rowCount) if err != nil { log.Error("Numpy parser: failed to read bool array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read bool array: %s", err.Error()) } - p.columnData = &storage.BoolFieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.BoolFieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } - + }, nil case schemapb.DataType_Int8: - data, err := adapter.ReadInt8(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadInt8(rowCount) if err != nil { log.Error("Numpy parser: failed to read int8 array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read int8 array: %s", err.Error()) } - p.columnData = &storage.Int8FieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.Int8FieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_Int16: - data, err := adapter.ReadInt16(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadInt16(rowCount) if err != nil { - log.Error("Numpy parser: failed to int16 bool array", zap.Error(err)) - return err + log.Error("Numpy parser: failed to int16 array", zap.Error(err)) + return nil, fmt.Errorf("failed to read int16 array: %s", err.Error()) } - p.columnData = &storage.Int16FieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.Int16FieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_Int32: - data, err := adapter.ReadInt32(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadInt32(rowCount) if err != nil { log.Error("Numpy parser: 
failed to read int32 array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read int32 array: %s", err.Error()) } - p.columnData = &storage.Int32FieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.Int32FieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_Int64: - data, err := adapter.ReadInt64(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadInt64(rowCount) if err != nil { log.Error("Numpy parser: failed to read int64 array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read int64 array: %s", err.Error()) } - p.columnData = &storage.Int64FieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.Int64FieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_Float: - data, err := adapter.ReadFloat32(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadFloat32(rowCount) if err != nil { log.Error("Numpy parser: failed to read float array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read float array: %s", err.Error()) } - p.columnData = &storage.FloatFieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.FloatFieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_Double: - data, err := adapter.ReadFloat64(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadFloat64(rowCount) if err != nil { log.Error("Numpy parser: failed to read double array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read double array: %s", err.Error()) } - p.columnData = &storage.DoubleFieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.DoubleFieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_VarChar: - data, err := adapter.ReadString(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadString(rowCount) if err != nil { log.Error("Numpy parser: failed to read varchar array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read varchar array: %s", err.Error()) } - p.columnData = &storage.StringFieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.StringFieldData{ + NumRows: []int64{int64(len(data))}, Data: data, - } + }, nil case schemapb.DataType_BinaryVector: - data, err := adapter.ReadUint8(p.columnDesc.elementCount) + data, err := columnReader.reader.ReadUint8(rowCount * (columnReader.dimension / 8)) if err != nil { log.Error("Numpy parser: failed to read binary vector array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read binary vector array: %s", err.Error()) } - p.columnData = &storage.BinaryVectorFieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.BinaryVectorFieldData{ + NumRows: []int64{int64(len(data) * 8 / columnReader.dimension)}, Data: data, - Dim: p.columnDesc.dimension, - } + Dim: columnReader.dimension, + }, nil case schemapb.DataType_FloatVector: // float32/float64 numpy file can be used for float vector file, 2 reasons: // 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit // 2. 
for float64 numpy file, the performance is worse than float32 numpy file - elementType := adapter.GetType() + elementType := columnReader.reader.GetType() var data []float32 var err error if elementType == schemapb.DataType_Float { - data, err = adapter.ReadFloat32(p.columnDesc.elementCount) + data, err = columnReader.reader.ReadFloat32(rowCount * columnReader.dimension) if err != nil { log.Error("Numpy parser: failed to read float vector array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read float vector array: %s", err.Error()) } } else if elementType == schemapb.DataType_Double { - data = make([]float32, 0, p.columnDesc.elementCount) - data64, err := adapter.ReadFloat64(p.columnDesc.elementCount) + data = make([]float32, 0, columnReader.rowCount) + data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension) if err != nil { log.Error("Numpy parser: failed to read float vector array", zap.Error(err)) - return err + return nil, fmt.Errorf("failed to read float vector array: %s", err.Error()) } for _, f64 := range data64 { @@ -301,40 +548,255 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { } } - p.columnData = &storage.FloatVectorFieldData{ - NumRows: []int64{int64(p.columnDesc.elementCount)}, + return &storage.FloatVectorFieldData{ + NumRows: []int64{int64(len(data) / columnReader.dimension)}, Data: data, - Dim: p.columnDesc.dimension, + Dim: columnReader.dimension, + }, nil + default: + log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", columnReader.dataType), + zap.String("fieldName", columnReader.fieldName)) + return nil, fmt.Errorf("unsupported data type %s of field '%s'", getTypeName(columnReader.dataType), + columnReader.fieldName) + } +} + +// appendFunc defines the methods to append data to storage.FieldData +func (p *NumpyParser) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { + switch schema.DataType { + case schemapb.DataType_Bool: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.BoolFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(bool)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Float: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.FloatFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(float32)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Double: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.DoubleFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(float64)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int8: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int8FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int8)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int16: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int16FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int16)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int32: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int32FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int32)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_Int64: + return func(src storage.FieldData, n int, target storage.FieldData) error { + 
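+			// note: the Int64 case below is representative of every closure appendFunc returns: it copies
+			// row n from the parsed column block (src) into the per-shard FieldData (target), and
+			// splitFieldsData calls it once per row after hashing the primary key to a shard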
arr := target.(*storage.Int64FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int64)) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_BinaryVector: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.BinaryVectorFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]byte)...) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_FloatVector: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.FloatVectorFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]float32)...) + arr.NumRows[0]++ + return nil + } + case schemapb.DataType_String, schemapb.DataType_VarChar: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.StringFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(string)) + return nil } default: - log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", p.columnDesc.dt), zap.String("fieldName", p.columnDesc.name)) - return fmt.Errorf("unsupported data type %s of field '%s'", getTypeName(p.columnDesc.dt), p.columnDesc.name) + return nil + } +} + +func (p *NumpyParser) prepareAppendFunctions() (map[string]func(src storage.FieldData, n int, target storage.FieldData) error, error) { + appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error) + for i := 0; i < len(p.collectionSchema.Fields); i++ { + schema := p.collectionSchema.Fields[i] + appendFuncErr := p.appendFunc(schema) + if appendFuncErr == nil { + log.Error("Numpy parser: unsupported field data type") + return nil, fmt.Errorf("unsupported field data type: %d", schema.GetDataType()) + } + appendFunctions[schema.GetName()] = appendFuncErr + } + return appendFunctions, nil +} + +// checkRowCount checks existence of each field, and returns the primary key schema +// check row count, all fields row count must be equal +func (p *NumpyParser) checkRowCount(fieldsData map[storage.FieldID]storage.FieldData) (int, *schemapb.FieldSchema, error) { + rowCount := 0 + rowCounter := make(map[string]int) + var primaryKey *schemapb.FieldSchema + for i := 0; i < len(p.collectionSchema.Fields); i++ { + schema := p.collectionSchema.Fields[i] + if schema.GetIsPrimaryKey() { + primaryKey = schema + } + + if !schema.GetAutoID() { + v, ok := fieldsData[schema.GetFieldID()] + if !ok { + log.Error("Numpy parser: field not provided", zap.String("fieldName", schema.GetName())) + return 0, nil, fmt.Errorf("field '%s' not provided", schema.GetName()) + } + rowCounter[schema.GetName()] = v.RowNum() + if v.RowNum() > rowCount { + rowCount = v.RowNum() + } + } + } + if primaryKey == nil { + log.Error("Numpy parser: primary key field is not found") + return 0, nil, fmt.Errorf("primary key field is not found") + } + + for name, count := range rowCounter { + if count != rowCount { + log.Error("Numpy parser: field row count is not equal to other fields row count", zap.String("fieldName", name), + zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount)) + return 0, nil, fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount) + } + } + // log.Info("Numpy parser: try to split a block with row count", zap.Int("rowCount", rowCount)) + + return rowCount, primaryKey, nil +} + +// splitFieldsData is to split the in-memory data(parsed from column-based files) into shards +func (p *NumpyParser) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, shards 
[]map[storage.FieldID]storage.FieldData) error { + if len(fieldsData) == 0 { + log.Error("Numpy parser: fields data to split is empty") + return fmt.Errorf("fields data to split is empty") + } + + if len(shards) != int(p.shardNum) { + log.Error("Numpy parser: block count is not equal to collection shard number", zap.Int("shardsLen", len(shards)), + zap.Int32("shardNum", p.shardNum)) + return fmt.Errorf("block count %d is not equal to collection shard number %d", len(shards), p.shardNum) + } + + rowCount, primaryKey, err := p.checkRowCount(fieldsData) + if err != nil { + return err + } + + // generate auto id for primary key and rowid field + rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount)) + if err != nil { + log.Error("Numpy parser: failed to alloc row ID", zap.Int("rowCount", rowCount), zap.Error(err)) + return fmt.Errorf("failed to alloc %d rows ID, error: %w", rowCount, err) + } + + rowIDField, ok := fieldsData[common.RowIDField] + if !ok { + rowIDField = &storage.Int64FieldData{ + Data: make([]int64, 0), + NumRows: []int64{0}, + } + fieldsData[common.RowIDField] = rowIDField + } + rowIDFieldArr := rowIDField.(*storage.Int64FieldData) + for i := rowIDBegin; i < rowIDEnd; i++ { + rowIDFieldArr.Data = append(rowIDFieldArr.Data, i) + } + + // reset the primary keys, as we know, only int64 pk can be auto-generated + if primaryKey.GetAutoID() { + log.Info("Numpy parser: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin)) + if primaryKey.GetDataType() != schemapb.DataType_Int64 { + log.Error("Numpy parser: primary key field is auto-generated but the field type is not int64") + return fmt.Errorf("primary key field is auto-generated but the field type is not int64") + } + + primaryDataArr := &storage.Int64FieldData{ + NumRows: []int64{int64(rowCount)}, + Data: make([]int64, 0, rowCount), + } + for i := rowIDBegin; i < rowIDEnd; i++ { + primaryDataArr.Data = append(primaryDataArr.Data, i) + } + + fieldsData[primaryKey.GetFieldID()] = primaryDataArr + p.autoIDRange = append(p.autoIDRange, rowIDBegin, rowIDEnd) + } + + // if the primary key is not auto-gernerate and user doesn't provide, return error + primaryData, ok := fieldsData[primaryKey.GetFieldID()] + if !ok || primaryData.RowNum() <= 0 { + log.Error("Numpy parser: primary key field is not provided", zap.String("keyName", primaryKey.GetName())) + return fmt.Errorf("primary key '%s' field data is not provided", primaryKey.GetName()) + } + + // prepare append functions + appendFunctions, err := p.prepareAppendFunctions() + if err != nil { + return err + } + + // split data into shards + for i := 0; i < rowCount; i++ { + // hash to a shard number + + pk := primaryData.GetRow(i) + shard, err := pkToShard(pk, uint32(p.shardNum)) + if err != nil { + return err + } + + // set rowID field + rowIDField := shards[shard][common.RowIDField].(*storage.Int64FieldData) + rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64)) + + // append row to shard + for k := 0; k < len(p.collectionSchema.Fields); k++ { + schema := p.collectionSchema.Fields[k] + srcData := fieldsData[schema.GetFieldID()] + targetData := shards[shard][schema.GetFieldID()] + if srcData == nil || targetData == nil { + log.Error("Numpy parser: cannot append data since source or target field data is nil", + zap.String("FieldName", schema.GetName()), + zap.Bool("sourceNil", srcData == nil), zap.Bool("targetNil", targetData == nil)) + return fmt.Errorf("cannot append data for field '%s' since source or target field 
data is nil", + primaryKey.GetName()) + } + appendFunc := appendFunctions[schema.GetName()] + err := appendFunc(srcData, i, targetData) + if err != nil { + return err + } + } } return nil } - -func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error { - adapter, err := NewNumpyAdapter(reader) - if err != nil { - return err - } - - // the validation method only check the file header information - err = p.validate(adapter, fieldName) - if err != nil { - return err - } - - if onlyValidate { - return nil - } - - // read all data from the numpy file - err = p.consume(adapter) - if err != nil { - return err - } - - return p.callFlushFunc(p.columnData) -} diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go index 258c27f694..6e81a4871f 100644 --- a/internal/util/importutil/numpy_parser_test.go +++ b/internal/util/importutil/numpy_parser_test.go @@ -18,498 +18,802 @@ package importutil import ( "context" + "errors" "os" "testing" - "github.com/sbinet/npyio/npy" "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/timerecord" ) -func Test_NewNumpyParser(t *testing.T) { +func createLocalChunkManager(t *testing.T) storage.ChunkManager { ctx := context.Background() + // NewDefaultFactory() use "/tmp/milvus" as default root path, and cannot specify root path + // NewChunkManagerFactory() can specify the root path + f := storage.NewChunkManagerFactory("local", storage.RootPath(TempFilesPath)) + cm, err := f.NewPersistentStorageChunkManager(ctx) + assert.NoError(t, err) - parser := NewNumpyParser(ctx, nil, nil) - assert.Nil(t, parser) + return cm } -func Test_NumpyParserValidate(t *testing.T) { +func createNumpyParser(t *testing.T) *NumpyParser { ctx := context.Background() - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) - defer os.RemoveAll(TempFilesPath) - schema := sampleSchema() - flushFunc := func(field storage.FieldData) error { + idAllocator := newIDAllocator(ctx, t, nil) + + cm := createLocalChunkManager(t) + + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter := &NumpyAdapter{npyReader: &npy.Reader{}} - - t.Run("not support DataType_String", func(t *testing.T) { - // string type is not supported - p := NewNumpyParser(ctx, &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: 109, - Name: "FieldString", - IsPrimaryKey: false, - Description: "string", - DataType: schemapb.DataType_String, - }, - }, - }, flushFunc) - err = p.validate(adapter, "dummy") - assert.NotNil(t, err) - err = p.validate(adapter, "FieldString") - assert.NotNil(t, err) - }) - - // reader is nil - parser := NewNumpyParser(ctx, schema, flushFunc) - err = parser.validate(nil, "") - assert.NotNil(t, err) - - t.Run("validate scalar", func(t *testing.T) { - filePath := TempFilesPath + "scalar_1.npy" - data1 := []float64{0, 1, 2, 3, 4, 5} - err := CreateNumpyFile(filePath, data1) - assert.Nil(t, err) - - file1, err := os.Open(filePath) - assert.Nil(t, err) - defer file1.Close() - - adapter, err := NewNumpyAdapter(file1) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldDouble") - assert.Nil(t, err) - assert.Equal(t, len(data1), parser.columnDesc.elementCount) - - err = parser.validate(adapter, "") - assert.NotNil(t, err) - - // 
data type mismatch - filePath = TempFilesPath + "scalar_2.npy" - data2 := []int64{0, 1, 2, 3, 4, 5} - err = CreateNumpyFile(filePath, data2) - assert.Nil(t, err) - - file2, err := os.Open(filePath) - assert.Nil(t, err) - defer file2.Close() - - adapter, err = NewNumpyAdapter(file2) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldDouble") - assert.NotNil(t, err) - - // shape mismatch - filePath = TempFilesPath + "scalar_2.npy" - data3 := [][2]float64{{1, 1}} - err = CreateNumpyFile(filePath, data3) - assert.Nil(t, err) - - file3, err := os.Open(filePath) - assert.Nil(t, err) - defer file2.Close() - - adapter, err = NewNumpyAdapter(file3) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldDouble") - assert.NotNil(t, err) - }) - - t.Run("validate binary vector", func(t *testing.T) { - filePath := TempFilesPath + "binary_vector_1.npy" - data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}} - err := CreateNumpyFile(filePath, data1) - assert.Nil(t, err) - - file1, err := os.Open(filePath) - assert.Nil(t, err) - defer file1.Close() - - adapter, err := NewNumpyAdapter(file1) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldBinaryVector") - assert.Nil(t, err) - assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount) - - // data type mismatch - filePath = TempFilesPath + "binary_vector_2.npy" - data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}} - err = CreateNumpyFile(filePath, data2) - assert.Nil(t, err) - - file2, err := os.Open(filePath) - assert.Nil(t, err) - defer file2.Close() - - adapter, err = NewNumpyAdapter(file2) - assert.NotNil(t, err) - assert.Nil(t, adapter) - - // shape mismatch - filePath = TempFilesPath + "binary_vector_3.npy" - data3 := []uint8{1, 2, 3} - err = CreateNumpyFile(filePath, data3) - assert.Nil(t, err) - - file3, err := os.Open(filePath) - assert.Nil(t, err) - defer file3.Close() - - adapter, err = NewNumpyAdapter(file3) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldBinaryVector") - assert.NotNil(t, err) - - // shape[1] mismatch - filePath = TempFilesPath + "binary_vector_4.npy" - data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}} - err = CreateNumpyFile(filePath, data4) - assert.Nil(t, err) - - file4, err := os.Open(filePath) - assert.Nil(t, err) - defer file4.Close() - - adapter, err = NewNumpyAdapter(file4) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldBinaryVector") - assert.NotNil(t, err) - - // dimension mismatch - p := NewNumpyParser(ctx, &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: 109, - Name: "FieldBinaryVector", - DataType: schemapb.DataType_BinaryVector, - }, - }, - }, flushFunc) - - err = p.validate(adapter, "FieldBinaryVector") - assert.NotNil(t, err) - }) - - t.Run("validate float vector", func(t *testing.T) { - filePath := TempFilesPath + "Float_vector.npy" - data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}} - err := CreateNumpyFile(filePath, data1) - assert.Nil(t, err) - - file1, err := os.Open(filePath) - assert.Nil(t, err) - defer file1.Close() - - adapter, err := NewNumpyAdapter(file1) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldFloatVector") - assert.Nil(t, err) - assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount) - - // data type mismatch - filePath = TempFilesPath + "float_vector_2.npy" - data2 := 
[][4]int32{{0, 1, 2, 3}} - err = CreateNumpyFile(filePath, data2) - assert.Nil(t, err) - - file2, err := os.Open(filePath) - assert.Nil(t, err) - defer file2.Close() - - adapter, err = NewNumpyAdapter(file2) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldFloatVector") - assert.NotNil(t, err) - - // shape mismatch - filePath = TempFilesPath + "float_vector_3.npy" - data3 := []float32{1, 2, 3} - err = CreateNumpyFile(filePath, data3) - assert.Nil(t, err) - - file3, err := os.Open(filePath) - assert.Nil(t, err) - defer file3.Close() - - adapter, err = NewNumpyAdapter(file3) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldFloatVector") - assert.NotNil(t, err) - - // shape[1] mismatch - filePath = TempFilesPath + "float_vector_4.npy" - data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}} - err = CreateNumpyFile(filePath, data4) - assert.Nil(t, err) - - file4, err := os.Open(filePath) - assert.Nil(t, err) - defer file4.Close() - - adapter, err = NewNumpyAdapter(file4) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - err = parser.validate(adapter, "FieldFloatVector") - assert.NotNil(t, err) - - // dimension mismatch - p := NewNumpyParser(ctx, &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: 109, - Name: "FieldFloatVector", - DataType: schemapb.DataType_FloatVector, - }, - }, - }, flushFunc) - - err = p.validate(adapter, "FieldFloatVector") - assert.NotNil(t, err) - }) + parser, err := NewNumpyParser(ctx, schema, idAllocator, 2, 100, cm, flushFunc) + assert.NoError(t, err) + assert.NotNil(t, parser) + return parser } -func Test_NumpyParserParse(t *testing.T) { +func findSchema(schema *schemapb.CollectionSchema, dt schemapb.DataType) *schemapb.FieldSchema { + fields := schema.Fields + for _, field := range fields { + if field.GetDataType() == dt { + return field + } + } + return nil +} + +func Test_NewNumpyParser(t *testing.T) { ctx := context.Background() + + parser, err := NewNumpyParser(ctx, nil, nil, 2, 100, nil, nil) + assert.Error(t, err) + assert.Nil(t, parser) + + schema := sampleSchema() + parser, err = NewNumpyParser(ctx, schema, nil, 2, 100, nil, nil) + assert.Error(t, err) + assert.Nil(t, parser) + + idAllocator := newIDAllocator(ctx, t, nil) + parser, err = NewNumpyParser(ctx, schema, idAllocator, 2, 100, nil, nil) + assert.Error(t, err) + assert.Nil(t, parser) + + cm := createLocalChunkManager(t) + + parser, err = NewNumpyParser(ctx, schema, idAllocator, 2, 100, cm, nil) + assert.Error(t, err) + assert.Nil(t, parser) + + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { + return nil + } + parser, err = NewNumpyParser(ctx, schema, idAllocator, 2, 100, cm, flushFunc) + assert.NoError(t, err) + assert.NotNil(t, parser) +} + +func Test_NumpyParserValidateFileNames(t *testing.T) { + parser := createNumpyParser(t) + + // file has no corresponding field in collection + err := parser.validateFileNames([]string{"dummy.npy"}) + assert.Error(t, err) + + // there is no file corresponding to field + fileNames := []string{ + "FieldBool.npy", + "FieldInt8.npy", + "FieldInt16.npy", + "FieldInt32.npy", + "FieldInt64.npy", + "FieldFloat.npy", + "FieldDouble.npy", + "FieldString.npy", + "FieldBinaryVector.npy", + } + err = parser.validateFileNames(fileNames) + assert.Error(t, err) + + //valid + fileNames = append(fileNames, "FieldFloatVector.npy") + err = parser.validateFileNames(fileNames) + assert.NoError(t, err) +} + +func Test_NumpyParserValidateHeader(t 
*testing.T) { err := os.MkdirAll(TempFilesPath, os.ModePerm) assert.Nil(t, err) defer os.RemoveAll(TempFilesPath) - schema := sampleSchema() + parser := createNumpyParser(t) - checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) { + // nil input error + err = parser.validateHeader(nil) + assert.Error(t, err) - filePath := TempFilesPath + fieldName + ".npy" - err := CreateNumpyFile(filePath, data) + validateHeader := func(data interface{}, fieldSchema *schemapb.FieldSchema) error { + filePath := TempFilesPath + fieldSchema.GetName() + ".npy" + + err = CreateNumpyFile(filePath, data) assert.Nil(t, err) - func() { - file, err := os.Open(filePath) - assert.Nil(t, err) - defer file.Close() + file, err := os.Open(filePath) + assert.Nil(t, err) + defer file.Close() - parser := NewNumpyParser(ctx, schema, callback) - err = parser.Parse(file, fieldName, false) - assert.Nil(t, err) - }() + adapter, err := NewNumpyAdapter(file) + assert.Nil(t, err) - // validation failed - func() { - file, err := os.Open(filePath) - assert.Nil(t, err) - defer file.Close() - - parser := NewNumpyParser(ctx, schema, callback) - err = parser.Parse(file, "dummy", false) - assert.NotNil(t, err) - }() - - // read data error - func() { - parser := NewNumpyParser(ctx, schema, callback) - err = parser.Parse(&MockReader{}, fieldName, false) - assert.NotNil(t, err) - }() + dim, _ := getFieldDimension(fieldSchema) + columnReader := &NumpyColumnReader{ + fieldName: fieldSchema.GetName(), + fieldID: fieldSchema.GetFieldID(), + dataType: fieldSchema.GetDataType(), + dimension: dim, + file: file, + reader: adapter, + } + err = parser.validateHeader(columnReader) + return err } - t.Run("parse scalar bool", func(t *testing.T) { - data := []bool{true, false, true, false, true} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) + t.Run("veridate float vector numpy", func(t *testing.T) { + // numpy file is not vectors + data1 := []int32{1, 2, 3, 4} + schema := findSchema(sampleSchema(), schemapb.DataType_FloatVector) + err = validateHeader(data1, schema) + assert.Error(t, err) - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } + // field data type is not float vector type + data2 := []float32{1.1, 2.1, 3.1, 4.1} + err = validateHeader(data2, schema) + assert.Error(t, err) - return nil + // dimension mismatch + data3 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}} + schema = &schemapb.FieldSchema{ + FieldID: 111, + Name: "FieldFloatVector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "99"}, + }, } - checkFunc(data, "FieldBool", flushFunc) + err = validateHeader(data3, schema) + assert.Error(t, err) }) - t.Run("parse scalar int8", func(t *testing.T) { - data := []int8{1, 2, 3, 4, 5} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) + t.Run("veridate binary vector numpy", func(t *testing.T) { + // numpy file is not vectors + data1 := []int32{1, 2, 3, 4} + schema := findSchema(sampleSchema(), schemapb.DataType_BinaryVector) + err = validateHeader(data1, schema) + assert.Error(t, err) - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } + // field data type is not binary vector type + data2 := []uint8{1, 2, 3, 4, 5, 6} + err = validateHeader(data2, schema) + 
assert.Error(t, err) - return nil + // dimension mismatch + data3 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}} + schema = &schemapb.FieldSchema{ + FieldID: 110, + Name: "FieldBinaryVector", + IsPrimaryKey: false, + Description: "binary_vector", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "99"}, + }, } - checkFunc(data, "FieldInt8", flushFunc) + err = validateHeader(data3, schema) + assert.Error(t, err) }) - t.Run("parse scalar int16", func(t *testing.T) { - data := []int16{1, 2, 3, 4, 5} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) + t.Run("veridate scalar numpy", func(t *testing.T) { + // data type mismatch + data1 := []int32{1, 2, 3, 4} + schema := findSchema(sampleSchema(), schemapb.DataType_Int8) + err = validateHeader(data1, schema) + assert.Error(t, err) - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } + // illegal shape + data2 := [][2]int8{{1, 2}, {3, 4}, {5, 6}} + err = validateHeader(data2, schema) + assert.Error(t, err) + }) +} - return nil - } - checkFunc(data, "FieldInt16", flushFunc) +func Test_NumpyParserCreateReaders(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + cm := createLocalChunkManager(t) + parser := createNumpyParser(t) + + // no field match the filename + t.Run("no field match the filename", func(t *testing.T) { + filePath := TempFilesPath + "dummy.npy" + files := []string{filePath} + readers, err := parser.createReaders(files) + assert.Error(t, err) + assert.Empty(t, readers) + defer closeReaders(readers) }) - t.Run("parse scalar int32", func(t *testing.T) { - data := []int32{1, 2, 3, 4, 5} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) - - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data, "FieldInt32", flushFunc) + // file doesn't exist + t.Run("file doesnt exist", func(t *testing.T) { + filePath := TempFilesPath + "FieldBool.npy" + files := []string{filePath} + readers, err := parser.createReaders(files) + assert.Error(t, err) + assert.Empty(t, readers) + defer closeReaders(readers) }) - t.Run("parse scalar int64", func(t *testing.T) { - data := []int64{1, 2, 3, 4, 5} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) - - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data, "FieldInt64", flushFunc) + // not a numpy file + t.Run("not a numpy file", func(t *testing.T) { + ctx := context.Background() + filePath := TempFilesPath + "FieldBool.npy" + files := []string{filePath} + err = cm.Write(ctx, filePath, []byte{1, 2, 3}) + readers, err := parser.createReaders(files) + assert.Error(t, err) + assert.Empty(t, readers) + defer closeReaders(readers) }) - t.Run("parse scalar float", func(t *testing.T) { - data := []float32{1, 2, 3, 4, 5} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) - - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } - - return nil + t.Run("succeed", func(t *testing.T) { + files := createSampleNumpyFiles(t, cm) + readers, err := parser.createReaders(files) + assert.NoError(t, err) + assert.Equal(t, len(files), len(readers)) 
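+		// each returned reader should mirror its field's schema: same name, field ID, data type and dimension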
+ for i := 0; i < len(readers); i++ { + reader := readers[i] + schema := findSchema(sampleSchema(), reader.dataType) + assert.NotNil(t, schema) + assert.Equal(t, schema.GetName(), reader.fieldName) + assert.Equal(t, schema.GetFieldID(), reader.fieldID) + dim, _ := getFieldDimension(schema) + assert.Equal(t, dim, reader.dimension) } - checkFunc(data, "FieldFloat", flushFunc) + defer closeReaders(readers) }) - t.Run("parse scalar double", func(t *testing.T) { - data := []float64{1, 2, 3, 4, 5} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) + t.Run("row count doesnt equal", func(t *testing.T) { + files := createSampleNumpyFiles(t, cm) + filePath := TempFilesPath + "FieldBool.npy" + err = CreateNumpyFile(filePath, []bool{true}) + assert.Nil(t, err) - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data, "FieldDouble", flushFunc) + readers, err := parser.createReaders(files) + assert.Error(t, err) + assert.Empty(t, readers) + defer closeReaders(readers) }) - t.Run("parse scalar varchar", func(t *testing.T) { - data := []string{"abcd", "sdb", "ok", "milvus"} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) + t.Run("velidate header failed", func(t *testing.T) { + filePath := TempFilesPath + "FieldBool.npy" + err = CreateNumpyFile(filePath, []int32{1, 2, 3, 4, 5}) + assert.Nil(t, err) + files := []string{filePath} + readers, err := parser.createReaders(files) + assert.Error(t, err) + assert.Empty(t, readers) + closeReaders(readers) + }) +} - for i := 0; i < len(data); i++ { - assert.Equal(t, data[i], field.GetRow(i)) - } +func Test_NumpyParserReadData(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) - return nil + cm := createLocalChunkManager(t) + parser := createNumpyParser(t) + + t.Run("general cases", func(t *testing.T) { + files := createSampleNumpyFiles(t, cm) + readers, err := parser.createReaders(files) + assert.NoError(t, err) + assert.Equal(t, len(files), len(readers)) + defer closeReaders(readers) + + // each sample file has 5 rows, read the first 2 rows + for _, reader := range readers { + fieldData, err := parser.readData(reader, 2) + assert.NoError(t, err) + assert.Equal(t, 2, fieldData.RowNum()) } - checkFunc(data, "FieldString", flushFunc) + + // read the left rows + for _, reader := range readers { + fieldData, err := parser.readData(reader, 100) + assert.NoError(t, err) + assert.Equal(t, 3, fieldData.RowNum()) + } + + // unsupport data type + columnReader := &NumpyColumnReader{ + fieldName: "dummy", + dataType: schemapb.DataType_None, + } + fieldData, err := parser.readData(columnReader, 2) + assert.Error(t, err) + assert.Nil(t, fieldData) }) - t.Run("parse binary vector", func(t *testing.T) { - data := [][2]uint8{{1, 2}, {3, 4}, {5, 6}} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) + readEmptyFunc := func(filedName string, data interface{}) { + filePath := TempFilesPath + filedName + ".npy" + err = CreateNumpyFile(filePath, data) + assert.Nil(t, err) - for i := 0; i < len(data); i++ { - row := field.GetRow(i).([]uint8) - for k := 0; k < len(row); k++ { - assert.Equal(t, data[i][k], row[k]) + readers, err := parser.createReaders([]string{filePath}) + assert.NoError(t, err) + assert.Equal(t, 1, len(readers)) + defer 
closeReaders(readers)
+
+		// row count 0 is not allowed
+		fieldData, err := parser.readData(readers[0], 0)
+		assert.Error(t, err)
+		assert.Nil(t, fieldData)
+
+		// nothing to read
+		_, err = parser.readData(readers[0], 2)
+		assert.NoError(t, err)
+	}
+
+	readBatchFunc := func(fieldName string, data interface{}, dataLen int, getValue func(k int) interface{}) {
+		filePath := TempFilesPath + fieldName + ".npy"
+		err = CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
+
+		readers, err := parser.createReaders([]string{filePath})
+		assert.NoError(t, err)
+		assert.Equal(t, 1, len(readers))
+		defer closeReaders(readers)
+
+		readPosition := 2
+		fieldData, err := parser.readData(readers[0], readPosition)
+		assert.NoError(t, err)
+		assert.Equal(t, readPosition, fieldData.RowNum())
+		for i := 0; i < readPosition; i++ {
+			assert.Equal(t, getValue(i), fieldData.GetRow(i))
+		}
+
+		if dataLen > readPosition {
+			fieldData, err = parser.readData(readers[0], dataLen+1)
+			assert.NoError(t, err)
+			assert.Equal(t, dataLen-readPosition, fieldData.RowNum())
+			for i := readPosition; i < dataLen; i++ {
+				assert.Equal(t, getValue(i), fieldData.GetRow(i-readPosition))
+			}
+		}
+	}
+
+	t.Run("read bool", func(t *testing.T) {
+		readEmptyFunc("FieldBool", []bool{})
+
+		data := []bool{true, false, true, false, false, true}
+		readBatchFunc("FieldBool", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	t.Run("read int8", func(t *testing.T) {
+		readEmptyFunc("FieldInt8", []int8{})
+
+		data := []int8{1, 3, 5, 7, 9, 4, 2, 6, 8}
+		readBatchFunc("FieldInt8", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	t.Run("read int16", func(t *testing.T) {
+		readEmptyFunc("FieldInt16", []int16{})
+
+		data := []int16{21, 13, 35, 47, 59, 34, 12}
+		readBatchFunc("FieldInt16", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	t.Run("read int32", func(t *testing.T) {
+		readEmptyFunc("FieldInt32", []int32{})
+
+		data := []int32{1, 3, 5, 7, 9, 4, 2, 6, 8}
+		readBatchFunc("FieldInt32", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	t.Run("read int64", func(t *testing.T) {
+		readEmptyFunc("FieldInt64", []int64{})
+
+		data := []int64{100, 200}
+		readBatchFunc("FieldInt64", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	t.Run("read float", func(t *testing.T) {
+		readEmptyFunc("FieldFloat", []float32{})
+
+		data := []float32{2.5, 32.2, 53.254, 3.45, 65.23421, 54.8978}
+		readBatchFunc("FieldFloat", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	t.Run("read double", func(t *testing.T) {
+		readEmptyFunc("FieldDouble", []float64{})
+
+		data := []float64{65.24454, 343.4365, 432.6556}
+		readBatchFunc("FieldDouble", data, len(data), func(k int) interface{} { return data[k] })
+	})
+
+	specialReadEmptyFunc := func(fieldName string, data interface{}) {
+		ctx := context.Background()
+		filePath := TempFilesPath + fieldName + ".npy"
+		content, err := CreateNumpyData(data)
+		assert.NoError(t, err)
+		err = cm.Write(ctx, filePath, content)
+		assert.NoError(t, err)
+
+		readers, err := parser.createReaders([]string{filePath})
+		assert.NoError(t, err)
+		assert.Equal(t, 1, len(readers))
+		defer closeReaders(readers)
+
+		// row count 0 is not allowed
+		fieldData, err := parser.readData(readers[0], 0)
+		assert.Error(t, err)
+		assert.Nil(t, fieldData)
+	}
+
+	t.Run("read varchar", func(t *testing.T) {
+		specialReadEmptyFunc("FieldString", []string{"aaa"})
+	})
+
+	t.Run("read binary vector", func(t *testing.T) {
+		
specialReadEmptyFunc("FieldBinaryVector", [][2]uint8{{1, 2}, {3, 4}})
+	})
+
+	t.Run("read float vector", func(t *testing.T) {
+		specialReadEmptyFunc("FieldFloatVector", [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}})
+		specialReadEmptyFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, 5, 6}})
+	})
+}
+
+func Test_NumpyParserPrepareAppendFunctions(t *testing.T) {
+	parser := createNumpyParser(t)
+
+	// succeed
+	appendFuncs, err := parser.prepareAppendFunctions()
+	assert.NoError(t, err)
+	assert.Equal(t, len(sampleSchema().Fields), len(appendFuncs))
+
+	// schema has unsupported data type
+	parser.collectionSchema = &schemapb.CollectionSchema{
+		Name: "schema",
+		Fields: []*schemapb.FieldSchema{
+			{
+				FieldID:      101,
+				Name:         "uid",
+				IsPrimaryKey: true,
+				AutoID:       true,
+				DataType:     schemapb.DataType_Int64,
+			},
+			{
+				FieldID:      102,
+				Name:         "flag",
+				IsPrimaryKey: false,
+				DataType:     schemapb.DataType_None,
+			},
+		},
+	}
+	appendFuncs, err = parser.prepareAppendFunctions()
+	assert.Error(t, err)
+	assert.Nil(t, appendFuncs)
+}
+
+func Test_NumpyParserCheckRowCount(t *testing.T) {
+	err := os.MkdirAll(TempFilesPath, os.ModePerm)
+	assert.Nil(t, err)
+	defer os.RemoveAll(TempFilesPath)
+
+	cm := createLocalChunkManager(t)
+	parser := createNumpyParser(t)
+
+	files := createSampleNumpyFiles(t, cm)
+	readers, err := parser.createReaders(files)
+	assert.NoError(t, err)
+	defer closeReaders(readers)
+
+	// succeed
+	segmentData := make(map[storage.FieldID]storage.FieldData)
+	for _, reader := range readers {
+		fieldData, err := parser.readData(reader, 100)
+		assert.NoError(t, err)
+		segmentData[reader.fieldID] = fieldData
+	}
+
+	rowCount, primaryKey, err := parser.checkRowCount(segmentData)
+	assert.NoError(t, err)
+	assert.Equal(t, 5, rowCount)
+	assert.NotNil(t, primaryKey)
+	assert.Equal(t, "FieldInt64", primaryKey.GetName())
+
+	// field data missing
+	delete(segmentData, 102)
+	rowCount, primaryKey, err = parser.checkRowCount(segmentData)
+	assert.Error(t, err)
+	assert.Zero(t, rowCount)
+	assert.Nil(t, primaryKey)
+
+	// primary key missing
+	parser.collectionSchema = &schemapb.CollectionSchema{
+		Name: "schema",
+		Fields: []*schemapb.FieldSchema{
+			{
+				FieldID:      105,
+				Name:         "FieldInt32",
+				IsPrimaryKey: false,
+				AutoID:       false,
+				DataType:     schemapb.DataType_Int32,
+			},
+		},
+	}
+
+	segmentData[105] = &storage.Int32FieldData{
+		Data: []int32{1, 2, 3, 4},
+	}
+
+	rowCount, primaryKey, err = parser.checkRowCount(segmentData)
+	assert.Error(t, err)
+	assert.Zero(t, rowCount)
+	assert.Nil(t, primaryKey)
+
+	// row count mismatch
+	parser.collectionSchema.Fields = append(parser.collectionSchema.Fields, &schemapb.FieldSchema{
+		FieldID:      106,
+		Name:         "FieldInt64",
+		IsPrimaryKey: true,
+		AutoID:       false,
+		DataType:     schemapb.DataType_Int64,
+	})
+
+	segmentData[106] = &storage.Int64FieldData{
+		Data: []int64{1, 2, 4},
+	}
+
+	rowCount, primaryKey, err = parser.checkRowCount(segmentData)
+	assert.Error(t, err)
+	assert.Zero(t, rowCount)
+	assert.Nil(t, primaryKey)
+}
+
+func Test_NumpyParserSplitFieldsData(t *testing.T) {
+	err := os.MkdirAll(TempFilesPath, os.ModePerm)
+	assert.Nil(t, err)
+	defer os.RemoveAll(TempFilesPath)
+
+	cm := createLocalChunkManager(t)
+	parser := createNumpyParser(t)
+
+	segmentData := make(map[storage.FieldID]storage.FieldData)
+	t.Run("segment data is empty", func(t *testing.T) {
+		err = parser.splitFieldsData(segmentData, nil)
+		assert.Error(t, err)
+	})
+
+	files := createSampleNumpyFiles(t, cm)
+	readers, err := parser.createReaders(files)
+	assert.NoError(t, err)
+	defer 
closeReaders(readers) + + for _, reader := range readers { + fieldData, err := parser.readData(reader, 100) + assert.NoError(t, err) + segmentData[reader.fieldID] = fieldData + } + + shards := make([]map[storage.FieldID]storage.FieldData, 0, parser.shardNum) + t.Run("shards number mismatch", func(t *testing.T) { + err = parser.splitFieldsData(segmentData, shards) + assert.Error(t, err) + }) + + t.Run("checkRowCount returns error", func(t *testing.T) { + parser.collectionSchema = &schemapb.CollectionSchema{ + Name: "schema", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 105, + Name: "FieldInt32", + IsPrimaryKey: false, + AutoID: false, + DataType: schemapb.DataType_Int32, + }, + }, + } + for i := 0; i < int(parser.shardNum); i++ { + shards = append(shards, initSegmentData(parser.collectionSchema)) + } + err = parser.splitFieldsData(segmentData, shards) + assert.Error(t, err) + parser.collectionSchema = sampleSchema() + }) + + t.Run("failed to alloc id", func(t *testing.T) { + ctx := context.Background() + parser.rowIDAllocator = newIDAllocator(ctx, t, errors.New("dummy error")) + err = parser.splitFieldsData(segmentData, shards) + assert.Error(t, err) + parser.rowIDAllocator = newIDAllocator(ctx, t, nil) + }) + + t.Run("primary key auto-generated", func(t *testing.T) { + schema := findSchema(parser.collectionSchema, schemapb.DataType_Int64) + schema.AutoID = true + + shards = make([]map[storage.FieldID]storage.FieldData, 0, parser.shardNum) + for i := 0; i < int(parser.shardNum); i++ { + segmentData := initSegmentData(parser.collectionSchema) + shards = append(shards, segmentData) + } + err = parser.splitFieldsData(segmentData, shards) + assert.NoError(t, err) + assert.NotEmpty(t, parser.autoIDRange) + + totalNum := 0 + for i := 0; i < int(parser.shardNum); i++ { + totalNum += shards[i][106].RowNum() + } + assert.Equal(t, segmentData[106].RowNum(), totalNum) + + // target field data is nil + shards[0][105] = nil + err = parser.splitFieldsData(segmentData, shards) + assert.Error(t, err) + + schema.AutoID = false + }) +} + +func Test_NumpyParserCalcRowCountPerBlock(t *testing.T) { + parser := createNumpyParser(t) + + // succeed + rowCount, err := parser.calcRowCountPerBlock() + assert.NoError(t, err) + assert.Greater(t, rowCount, int64(0)) + + // failed to estimate row size + parser.collectionSchema = &schemapb.CollectionSchema{ + Name: "schema", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 109, + Name: "FieldString", + IsPrimaryKey: false, + Description: "string", + DataType: schemapb.DataType_VarChar, + }, + }, + } + rowCount, err = parser.calcRowCountPerBlock() + assert.Error(t, err) + assert.Zero(t, rowCount) + + // no field + parser.collectionSchema = &schemapb.CollectionSchema{ + Name: "schema", + } + rowCount, err = parser.calcRowCountPerBlock() + assert.Error(t, err) + assert.Zero(t, rowCount) +} + +func Test_NumpyParserConsume(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + cm := createLocalChunkManager(t) + parser := createNumpyParser(t) + + files := createSampleNumpyFiles(t, cm) + readers, err := parser.createReaders(files) + assert.NoError(t, err) + assert.Equal(t, len(sampleSchema().Fields), len(readers)) + + // succeed + err = parser.consume(readers) + assert.NoError(t, err) + closeReaders(readers) + + // row count mismatch + parser.blockSize = 1000 + readers, err = 
parser.createReaders(files) + assert.NoError(t, err) + parser.readData(readers[0], 1) + err = parser.consume(readers) + assert.Error(t, err) + + // invalid schema + parser.collectionSchema = &schemapb.CollectionSchema{ + Name: "schema", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 109, + Name: "dummy", + IsPrimaryKey: false, + DataType: schemapb.DataType_None, + }, + }, + } + err = parser.consume(readers) + assert.Error(t, err) + closeReaders(readers) +} + +func Test_NumpyParserParse(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + parser := createNumpyParser(t) + parser.blockSize = 400 + + t.Run("validate file name failed", func(t *testing.T) { + files := []string{"dummy.npy"} + err = parser.Parse(files) + assert.Error(t, err) + }) + + t.Run("file doesnt exist", func(t *testing.T) { + parser.collectionSchema = perfSchema(4) + files := []string{"ID.npy", "Vector.npy"} + err = parser.Parse(files) + assert.Error(t, err) + parser.collectionSchema = sampleSchema() + }) + + t.Run("succeed", func(t *testing.T) { + cm := createLocalChunkManager(t) + files := createSampleNumpyFiles(t, cm) + + totalRowCount := 0 + parser.callFlushFunc = func(fields map[storage.FieldID]storage.FieldData, shardID int) error { + assert.LessOrEqual(t, int32(shardID), parser.shardNum) + rowCount := 0 + for _, fieldData := range fields { + if rowCount == 0 { + rowCount = fieldData.RowNum() + } else { + assert.Equal(t, rowCount, fieldData.RowNum()) } } - + totalRowCount += rowCount return nil } - checkFunc(data, "FieldBinaryVector", flushFunc) - }) - - t.Run("parse binary vector with float32", func(t *testing.T) { - data := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) - - for i := 0; i < len(data); i++ { - row := field.GetRow(i).([]float32) - for k := 0; k < len(row); k++ { - assert.Equal(t, data[i][k], row[k]) - } - } - - return nil - } - checkFunc(data, "FieldFloatVector", flushFunc) - }) - - t.Run("parse binary vector with float64", func(t *testing.T) { - data := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data), field.RowNum()) - - for i := 0; i < len(data); i++ { - row := field.GetRow(i).([]float32) - for k := 0; k < len(row); k++ { - assert.Equal(t, float32(data[i][k]), row[k]) - } - } - - return nil - } - checkFunc(data, "FieldFloatVector", flushFunc) + err = parser.Parse(files) + assert.NoError(t, err) + assert.Equal(t, 5, totalRowCount) }) } @@ -519,6 +823,8 @@ func Test_NumpyParserParse_perf(t *testing.T) { assert.Nil(t, err) defer os.RemoveAll(TempFilesPath) + cm := createLocalChunkManager(t) + tr := timerecord.NewTimeRecorder("numpy parse performance") // change the parameter to test performance @@ -528,37 +834,58 @@ func Test_NumpyParserParse_perf(t *testing.T) { dim = 128 ) - schema := perfSchema(dim) - - data := make([][dim]float32, 0) + idData := make([]int64, 0) + vecData := make([][dim]float32, 0) for i := 0; i < rowCount; i++ { var row [dim]float32 for k := 0; k < dim; k++ { row[k] = float32(i) + dotValue } - data = append(data, row) + vecData = append(vecData, row) + idData = append(idData, int64(i)) } tr.Record("generate large data") - flushFunc := func(field 
storage.FieldData) error {
-		assert.Equal(t, len(data), field.RowNum())
+	createNpyFile := func(t *testing.T, fieldName string, data interface{}) string {
+		filePath := TempFilesPath + fieldName + ".npy"
+		content, err := CreateNumpyData(data)
+		assert.NoError(t, err)
+		err = cm.Write(ctx, filePath, content)
+		assert.NoError(t, err)
+		return filePath
+	}
+
+	idFilePath := createNpyFile(t, "ID", idData)
+	vecFilePath := createNpyFile(t, "Vector", vecData)
+
+	tr.Record("generate large numpy files")
+
+	shardNum := int32(3)
+	totalRowCount := 0
+	callFlushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
+		assert.LessOrEqual(t, int32(shardID), shardNum)
+		rowCount := 0
+		for _, fieldData := range fields {
+			if rowCount == 0 {
+				rowCount = fieldData.RowNum()
+			} else {
+				assert.Equal(t, rowCount, fieldData.RowNum())
+			}
+		}
+		totalRowCount += rowCount
 		return nil
 	}
-	filePath := TempFilesPath + "perf.npy"
-	err = CreateNumpyFile(filePath, data)
+	idAllocator := newIDAllocator(ctx, t, nil)
+	parser, err := NewNumpyParser(ctx, perfSchema(dim), idAllocator, shardNum, 16*1024*1024, cm, callFlushFunc)
+	assert.NoError(t, err)
+	assert.NotNil(t, parser)
+	parser.collectionSchema = perfSchema(dim)
+
+	err = parser.Parse([]string{idFilePath, vecFilePath})
 	assert.Nil(t, err)
+	assert.Equal(t, rowCount, totalRowCount)

-	tr.Record("generate large numpy file " + filePath)
-
-	file, err := os.Open(filePath)
-	assert.Nil(t, err)
-	defer file.Close()
-
-	parser := NewNumpyParser(ctx, schema, flushFunc)
-	err = parser.Parse(file, "Vector", false)
-	assert.Nil(t, err)
-
-	tr.Record("parse large numpy files: " + filePath)
+	tr.Record("parse large numpy files")
 }
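For reference, a minimal sketch of the calling pattern these tests exercise, using only the NewNumpyParser and Parse signatures visible in this diff; ctx, schema, idAllocator, shardNum, cm, filePaths and the 16 MB block size are assumed to be supplied by the caller (the tests build them with local helpers):

// Sketch only (not part of the patch): drive the reworked NumpyParser end to end.
// One .npy file is supplied per field, named after the field (e.g. "FieldInt64.npy"),
// and each shard of rows is handed back through the flush callback.
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
	// persist one shard's worth of column data here
	return nil
}
parser, err := NewNumpyParser(ctx, schema, idAllocator, shardNum, 16*1024*1024, cm, flushFunc)
if err != nil {
	return err
}
return parser.Parse(filePaths)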