Streaming read numpy (#21540)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/21505/head
groot 2023-01-09 10:01:37 +08:00 committed by GitHub
parent fb2d7af3e0
commit 18aefb381c
10 changed files with 1958 additions and 1509 deletions

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func isCanceled(ctx context.Context) bool {
@ -519,3 +520,22 @@ func getTypeName(dt schemapb.DataType) string {
return "InvalidType"
}
}
func pkToShard(pk interface{}, shardNum uint32) (uint32, error) {
var shard uint32
strPK, ok := interface{}(pk).(string)
if ok {
hash := typeutil.HashString2Uint32(strPK)
shard = hash % shardNum
} else {
intPK, ok := interface{}(pk).(int64)
if !ok {
log.Error("Numpy parser: primary key field must be int64 or varchar")
return 0, fmt.Errorf("primary key field must be int64 or varchar")
}
hash, _ := typeutil.Hash32Int64(intPK)
shard = hash % shardNum
}
return shard, nil
}
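For illustration only (not part of this diff), here is a minimal sketch of how a caller inside the importutil package might use pkToShard to route primary keys; the key values, the shard count, and the helper name examplePkToShard are made up for the example:

func examplePkToShard() {
	shardNum := uint32(4)
	// a varchar primary key is hashed with typeutil.HashString2Uint32
	strShard, _ := pkToShard("user-123", shardNum)
	// an int64 primary key is hashed with typeutil.Hash32Int64
	intShard, _ := pkToShard(int64(42), shardNum)
	fmt.Println(strShard, intShard)
}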

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
)
@ -611,3 +612,31 @@ func Test_GetTypeName(t *testing.T) {
str = getTypeName(schemapb.DataType_None)
assert.Equal(t, "InvalidType", str)
}
func Test_PkToShard(t *testing.T) {
a := int32(99)
shard, err := pkToShard(a, 2)
assert.Error(t, err)
assert.Zero(t, shard)
s := "abcdef"
shardNum := uint32(3)
shard, err = pkToShard(s, shardNum)
assert.NoError(t, err)
hash := typeutil.HashString2Uint32(s)
assert.Equal(t, hash%shardNum, shard)
pk := int64(100)
shardNum = uint32(4)
shard, err = pkToShard(pk, shardNum)
assert.NoError(t, err)
hash, _ = typeutil.Hash32Int64(pk)
assert.Equal(t, hash%shardNum, shard)
pk = int64(99999)
shardNum = uint32(5)
shard, err = pkToShard(pk, shardNum)
assert.NoError(t, err)
hash, _ = typeutil.Hash32Int64(pk)
assert.Equal(t, hash%shardNum, shard)
}

View File

@ -34,7 +34,6 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
@ -47,7 +46,7 @@ const (
// this limitation is to avoid this OOM risk:
// for a column-based file, we read all its data into memory; if the user inputs a large file, the read() method may
// consume extra memory and lead to OOM.
MaxFileSize = 1 * 1024 * 1024 * 1024 // 1GB
MaxFileSize = 16 * 1024 * 1024 * 1024 // 16GB
// this limitation is to avoid this OOM risk:
// sometimes the system segment max size is a large number, and a single segment's field data might cause OOM.
@ -175,42 +174,6 @@ func (p *ImportWrapper) Cancel() error {
return nil
}
func (p *ImportWrapper) validateColumnBasedFiles(filePaths []string, collectionSchema *schemapb.CollectionSchema) error {
requiredFieldNames := make(map[string]interface{})
for _, schema := range p.collectionSchema.Fields {
if schema.GetIsPrimaryKey() {
if !schema.GetAutoID() {
requiredFieldNames[schema.GetName()] = nil
}
} else {
requiredFieldNames[schema.GetName()] = nil
}
}
// check redundant file
fileNames := make(map[string]interface{})
for _, filePath := range filePaths {
name, _ := GetFileNameAndExt(filePath)
fileNames[name] = nil
_, ok := requiredFieldNames[name]
if !ok {
log.Error("import wrapper: the file has no corresponding field in collection", zap.String("fieldName", name))
return fmt.Errorf("the file '%s' has no corresponding field in collection", filePath)
}
}
// check missed file
for name := range requiredFieldNames {
_, ok := fileNames[name]
if !ok {
log.Error("import wrapper: there is no file corresponding to field", zap.String("fieldName", name))
return fmt.Errorf("there is no file corresponding to field '%s'", name)
}
}
return nil
}
// fileValidation verifies the input paths
// if all the files are json type, return true
// if all the files are numpy type, return false; duplicate file names are not allowed
@ -278,22 +241,6 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
totalSize += size
}
// especially for column-base, total size of files cannot exceed MaxTotalSizeInMemory
if totalSize > MaxTotalSizeInMemory {
log.Error("import wrapper: total size of files exceeds the maximum size", zap.Int64("totalSize", totalSize), zap.Int64("MaxTotalSize", MaxTotalSizeInMemory))
return rowBased, fmt.Errorf("total size(%d bytes) of all files exceeds the maximum size: %d bytes", totalSize, MaxTotalSizeInMemory)
}
// check redundant files for column-based import
// if the field is primary key and autoid is false, the file is required
// any redundant file is not allowed
if !rowBased {
err := p.validateColumnBasedFiles(filePaths, p.collectionSchema)
if err != nil {
return rowBased, err
}
}
return rowBased, nil
}
@ -337,84 +284,26 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
triggerGC()
}
} else {
// parse and consume column-based files
// for column-based files, the XXXColumnConsumer only output map[string]storage.FieldData
// after all columns are parsed/consumed, we need to combine map[string]storage.FieldData into one
// and use splitFieldsData() to split fields data into segments according to shard number
fieldsData := initSegmentData(p.collectionSchema)
if fieldsData == nil {
log.Error("import wrapper: failed to initialize FieldData list")
return fmt.Errorf("failed to initialize FieldData list")
// parse and consume column-based files (currently only numpy is supported)
// for column-based files, the NumpyParser generates auto-ids for the primary key and splits rows into segments
// according to the shard number, so the flushFunc is called inside the NumpyParser
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
return p.flushFunc(fields, shardID)
}
rowCount := 0
// function to combine column data into fieldsData
combineFunc := func(fields map[storage.FieldID]storage.FieldData) error {
if len(fields) == 0 {
return nil
}
printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
for k, v := range fields {
// ignore 0 row field
if v.RowNum() == 0 {
log.Warn("import wrapper: empty FieldData ignored", zap.Int64("fieldID", k))
continue
}
// ignore internal fields: RowIDField and TimeStampField
if k == common.RowIDField || k == common.TimeStampField {
log.Warn("import wrapper: internal fields should not be provided", zap.Int64("fieldID", k))
continue
}
// each column should be only combined once
data, ok := fieldsData[k]
if ok && data.RowNum() > 0 {
return fmt.Errorf("the field %d is duplicated", k)
}
// check the row count. only count non-zero row fields
if rowCount > 0 && rowCount != v.RowNum() {
return fmt.Errorf("the field %d row count %d doesn't equal to others row count: %d", k, v.RowNum(), rowCount)
}
rowCount = v.RowNum()
// assign column data to fieldsData
fieldsData[k] = v
}
return nil
}
// parse/validate/consume data
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
_, fileType := GetFileNameAndExt(filePath)
log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == NumpyFileExt {
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
if err != nil {
log.Error("import wrapper: failed to parse column-based numpy file", zap.Error(err), zap.String("filePath", filePath))
return err
}
}
// no need to check else, since the fileValidation() already do this
}
// trigger after read finished
triggerGC()
// split fields data into segments
err := p.splitFieldsData(fieldsData, SingleBlockSize)
parser, err := NewNumpyParser(p.ctx, p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc)
if err != nil {
return err
}
// trigger after write finished
err = parser.Parse(filePaths)
if err != nil {
return err
}
p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...)
// trigger after parse finished
triggerGC()
}
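For illustration only (not part of this diff), a sketch of the callback contract described in the comments above: the NumpyParser hands each flushed block to a flushFunc together with the shard it belongs to. The helper name newCountingFlushFunc and its row-counting body are assumptions, not Milvus code; only the signature and storage.FieldData.RowNum() come from the change itself:

func newCountingFlushFunc() func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
	return func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
		rows := 0
		for _, fd := range fields {
			// take the largest row count among the fields of this block
			if fd.RowNum() > rows {
				rows = fd.RowNum()
			}
		}
		log.Info("flushed one block", zap.Int("shardID", shardID), zap.Int("rows", rows))
		return nil
	}
}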
@ -437,6 +326,7 @@ func (p *ImportWrapper) reportPersisted(reportAttempts uint, tr *timerecord.Time
// report file process state
p.importResult.State = commonpb.ImportState_ImportPersisted
log.Info("import wrapper: report import result", zap.Any("importResult", p.importResult))
// persisting the import state is important; retry more times so the task does not fail merely because of a transient network error
reportErr := retry.Do(p.ctx, func() error {
return p.reportFunc(p.importResult)
@ -554,297 +444,12 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
}
// for row-based files, auto-id is generated within JSONRowConsumer
if consumer != nil {
p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)
}
p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)
tr.Elapse("parsed")
return nil
}
// parseColumnBasedNumpy is the entry of column-based numpy import operation
func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool,
combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
tr := timerecord.NewTimeRecorder("numpy parser: " + filePath)
fileName, _ := GetFileNameAndExt(filePath)
// for minio storage, chunkManager will download file into local memory
// for local storage, chunkManager open the file directly
file, err := p.chunkManager.Reader(p.ctx, filePath)
if err != nil {
return err
}
defer file.Close()
var id storage.FieldID
var found = false
for _, field := range p.collectionSchema.Fields {
if field.GetName() == fileName {
id = field.GetFieldID()
found = true
break
}
}
// if the numpy file name does not map to a field name, ignore it
if !found {
return nil
}
// the numpy parser return a storage.FieldData, here construct a map[string]storage.FieldData to combine
flushFunc := func(field storage.FieldData) error {
fields := make(map[storage.FieldID]storage.FieldData)
fields[id] = field
return combineFunc(fields)
}
// for a numpy file, the file name (without extension) is taken as the field name
parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
err = parser.Parse(file, fileName, onlyValidate)
if err != nil {
return err
}
tr.Elapse("parsed")
return nil
}
// appendFunc defines the methods to append data to storage.FieldData
func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
switch schema.DataType {
case schemapb.DataType_Bool:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.BoolFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(bool))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Float:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.FloatFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(float32))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Double:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.DoubleFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(float64))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int8FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int8))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int16FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int16))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int32FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int32))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int64FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int64))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.BinaryVectorFieldData)
arr.Data = append(arr.Data, src.GetRow(n).([]byte)...)
arr.NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.FloatVectorFieldData)
arr.Data = append(arr.Data, src.GetRow(n).([]float32)...)
arr.NumRows[0]++
return nil
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.StringFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(string))
return nil
}
default:
return nil
}
}
// splitFieldsData is to split the in-memory data(parsed from column-based files) into blocks, each block save to a binlog file
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, blockSize int64) error {
if len(fieldsData) == 0 {
log.Error("import wrapper: fields data is empty")
return fmt.Errorf("fields data is empty")
}
tr := timerecord.NewTimeRecorder("import wrapper: split field data")
defer tr.Elapse("finished")
// check existence of each field
// check row count, all fields row count must be equal
// firstly get the max row count
rowCount := 0
rowCounter := make(map[string]int)
var primaryKey *schemapb.FieldSchema
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema := p.collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
primaryKey = schema
}
if !schema.GetAutoID() {
v, ok := fieldsData[schema.GetFieldID()]
if !ok {
log.Error("import wrapper: field not provided", zap.String("fieldName", schema.GetName()))
return fmt.Errorf("field '%s' not provided", schema.GetName())
}
rowCounter[schema.GetName()] = v.RowNum()
if v.RowNum() > rowCount {
rowCount = v.RowNum()
}
}
}
if primaryKey == nil {
log.Error("import wrapper: primary key field is not found")
return fmt.Errorf("primary key field is not found")
}
for name, count := range rowCounter {
if count != rowCount {
log.Error("import wrapper: field row count is not equal to other fields row count", zap.String("fieldName", name),
zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount))
return fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
}
}
log.Info("import wrapper: try to split a block with row count", zap.Int("rowCount", rowCount))
primaryData, ok := fieldsData[primaryKey.GetFieldID()]
if !ok {
log.Error("import wrapper: primary key field is not provided", zap.String("keyName", primaryKey.GetName()))
return fmt.Errorf("primary key field is not provided")
}
// generate auto id for primary key and rowid field
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
if err != nil {
log.Error("import wrapper: failed to alloc row ID", zap.Error(err))
return fmt.Errorf("failed to alloc row ID, error: %w", err)
}
rowIDField := fieldsData[common.RowIDField]
rowIDFieldArr := rowIDField.(*storage.Int64FieldData)
for i := rowIDBegin; i < rowIDEnd; i++ {
rowIDFieldArr.Data = append(rowIDFieldArr.Data, i)
}
if primaryKey.GetAutoID() {
log.Info("import wrapper: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin))
// reset the primary keys, as we know, only int64 pk can be auto-generated
primaryDataArr := &storage.Int64FieldData{
NumRows: []int64{int64(rowCount)},
Data: make([]int64, 0, rowCount),
}
for i := rowIDBegin; i < rowIDEnd; i++ {
primaryDataArr.Data = append(primaryDataArr.Data, i)
}
primaryData = primaryDataArr
fieldsData[primaryKey.GetFieldID()] = primaryData
p.importResult.AutoIds = append(p.importResult.AutoIds, rowIDBegin, rowIDEnd)
}
if primaryData.RowNum() <= 0 {
log.Error("import wrapper: primary key is not provided", zap.String("keyName", primaryKey.GetName()))
return fmt.Errorf("the primary key '%s' is not provided", primaryKey.GetName())
}
// prepare segments
segmentsData := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum)
for i := 0; i < int(p.shardNum); i++ {
segmentData := initSegmentData(p.collectionSchema)
if segmentData == nil {
log.Error("import wrapper: failed to initialize FieldData list")
return fmt.Errorf("failed to initialize FieldData list")
}
segmentsData = append(segmentsData, segmentData)
}
// prepare append functions
appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error)
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema := p.collectionSchema.Fields[i]
appendFuncErr := p.appendFunc(schema)
if appendFuncErr == nil {
log.Error("import wrapper: unsupported field data type")
return fmt.Errorf("unsupported field data type: %d", schema.GetDataType())
}
appendFunctions[schema.GetName()] = appendFuncErr
}
// split data into shards
for i := 0; i < rowCount; i++ {
// hash to a shard number
var shard uint32
pk := primaryData.GetRow(i)
strPK, ok := interface{}(pk).(string)
if ok {
hash := typeutil.HashString2Uint32(strPK)
shard = hash % uint32(p.shardNum)
} else {
intPK, ok := interface{}(pk).(int64)
if !ok {
log.Error("import wrapper: primary key field must be int64 or varchar")
return fmt.Errorf("primary key field must be int64 or varchar")
}
hash, _ := typeutil.Hash32Int64(intPK)
shard = hash % uint32(p.shardNum)
}
// set rowID field
rowIDField := segmentsData[shard][common.RowIDField].(*storage.Int64FieldData)
rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64))
// append row to shard
for k := 0; k < len(p.collectionSchema.Fields); k++ {
schema := p.collectionSchema.Fields[k]
srcData := fieldsData[schema.GetFieldID()]
targetData := segmentsData[shard][schema.GetFieldID()]
appendFunc := appendFunctions[schema.GetName()]
err := appendFunc(srcData, i, targetData)
if err != nil {
return err
}
}
// when the estimated size is close to blockSize, force flush
err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, false)
if err != nil {
return err
}
}
// force flush at the end
return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, true)
}
// flushFunc is the callback function for parsers to generate a segment and save binlog files
func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error {
// if fields data is empty, do nothing

View File

@ -579,79 +579,6 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) {
tr.Record("parse large json file " + filePath)
}
func Test_ImportWrapperValidateColumnBasedFiles(t *testing.T) {
ctx := context.Background()
cm := &MockChunkManager{
size: 1,
}
idAllocator := newIDAllocator(ctx, t, nil)
shardNum := 2
segmentSize := 512 // unit: MB
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "ID",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
FieldID: 102,
Name: "Age",
IsPrimaryKey: false,
DataType: schemapb.DataType_Int64,
},
{
FieldID: 103,
Name: "Vector",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "10"},
},
},
},
}
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
// file for PK is redundant
files := []string{"ID.npy", "Age.npy", "Vector.npy"}
err := wrapper.validateColumnBasedFiles(files, schema)
assert.NotNil(t, err)
// file for PK is not redundant
schema.Fields[0].AutoID = false
err = wrapper.validateColumnBasedFiles(files, schema)
assert.Nil(t, err)
// file missed
files = []string{"Age.npy", "Vector.npy"}
err = wrapper.validateColumnBasedFiles(files, schema)
assert.NotNil(t, err)
files = []string{"ID.npy", "Vector.npy"}
err = wrapper.validateColumnBasedFiles(files, schema)
assert.NotNil(t, err)
// redundant file
files = []string{"ID.npy", "Age.npy", "Vector.npy", "dummy.npy"}
err = wrapper.validateColumnBasedFiles(files, schema)
assert.NotNil(t, err)
// correct input
files = []string{"ID.npy", "Age.npy", "Vector.npy"}
err = wrapper.validateColumnBasedFiles(files, schema)
assert.Nil(t, err)
}
func Test_ImportWrapperFileValidation(t *testing.T) {
ctx := context.Background()
@ -668,7 +595,7 @@ func Test_ImportWrapperFileValidation(t *testing.T) {
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: false,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
@ -684,84 +611,76 @@ func Test_ImportWrapperFileValidation(t *testing.T) {
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
// unsupported file type
files := []string{"uid.txt"}
rowBased, err := wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
t.Run("unsupported file type", func(t *testing.T) {
files := []string{"uid.txt"}
rowBased, err := wrapper.fileValidation(files)
assert.Error(t, err)
assert.False(t, rowBased)
})
// file missed
files = []string{"uid.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
t.Run("duplicate files", func(t *testing.T) {
files := []string{"a/1.json", "b/1.json"}
rowBased, err := wrapper.fileValidation(files)
assert.Error(t, err)
assert.True(t, rowBased)
// redundant file
files = []string{"uid.npy", "b/bol.npy", "c/no.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
files = []string{"a/uid.npy", "uid.npy", "b/bol.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.Error(t, err)
assert.False(t, rowBased)
})
// duplicate files
files = []string{"a/1.json", "b/1.json"}
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.True(t, rowBased)
t.Run("unsupported file for row-based", func(t *testing.T) {
files := []string{"a/uid.json", "b/bol.npy"}
rowBased, err := wrapper.fileValidation(files)
assert.Error(t, err)
assert.True(t, rowBased)
})
files = []string{"a/uid.npy", "uid.npy", "b/bol.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
t.Run("unsupported file for column-based", func(t *testing.T) {
files := []string{"a/uid.npy", "b/bol.json"}
rowBased, err := wrapper.fileValidation(files)
assert.Error(t, err)
assert.False(t, rowBased)
})
// unsupported file for row-based
files = []string{"a/uid.json", "b/bol.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.True(t, rowBased)
t.Run("valid cases", func(t *testing.T) {
files := []string{"a/1.json", "b/2.json"}
rowBased, err := wrapper.fileValidation(files)
assert.NoError(t, err)
assert.True(t, rowBased)
// unsupported file for column-based
files = []string{"a/uid.npy", "b/bol.json"}
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
files = []string{"a/uid.npy", "b/bol.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.NoError(t, err)
assert.False(t, rowBased)
})
// valid cases
files = []string{"a/1.json", "b/2.json"}
rowBased, err = wrapper.fileValidation(files)
assert.Nil(t, err)
assert.True(t, rowBased)
t.Run("empty file", func(t *testing.T) {
files := []string{}
cm.size = 0
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
rowBased, err := wrapper.fileValidation(files)
assert.NoError(t, err)
assert.False(t, rowBased)
})
files = []string{"a/uid.npy", "b/bol.npy"}
rowBased, err = wrapper.fileValidation(files)
assert.Nil(t, err)
assert.False(t, rowBased)
t.Run("file size exceed MaxFileSize limit", func(t *testing.T) {
files := []string{"a/1.json"}
cm.size = MaxFileSize + 1
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
rowBased, err := wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.True(t, rowBased)
})
// empty file
cm.size = 0
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
// file size exceed MaxFileSize limit
cm.size = MaxFileSize + 1
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
// total files size exceed MaxTotalSizeInMemory limit
cm.size = MaxFileSize - 1
files = append(files, "3.npy")
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
// failed to get file size
cm.sizeErr = errors.New("error")
rowBased, err = wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.False(t, rowBased)
t.Run("failed to get file size", func(t *testing.T) {
files := []string{"a/1.json"}
cm.sizeErr = errors.New("error")
rowBased, err := wrapper.fileValidation(files)
assert.NotNil(t, err)
assert.True(t, rowBased)
})
}
func Test_ImportWrapperReportFailRowBased(t *testing.T) {
@ -1001,122 +920,6 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) {
assert.Nil(t, err)
}
func Test_ImportWrapperSplitFieldsData(t *testing.T) {
ctx := context.Background()
cm := &MockChunkManager{}
idAllocator := newIDAllocator(ctx, t, nil)
rowCounter := &rowCounterTest{}
assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter)
importResult := &rootcoordpb.ImportResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
TaskId: 1,
DatanodeId: 1,
State: commonpb.ImportState_ImportStarted,
Segments: make([]int64, 0),
AutoIds: make([]int64, 0),
RowCount: 0,
}
reportFunc := func(res *rootcoordpb.ImportResult) error {
return nil
}
schema := &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
FieldID: 102,
Name: "flag",
IsPrimaryKey: false,
DataType: schemapb.DataType_Bool,
},
},
}
wrapper := NewImportWrapper(ctx, schema, 2, 1024*1024, idAllocator, cm, importResult, reportFunc)
wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc)
// nil input
err := wrapper.splitFieldsData(nil, 0)
assert.NotNil(t, err)
// split 100 rows to 4 blocks, success
rowCount := 100
input := initSegmentData(schema)
for j := 0; j < rowCount; j++ {
pkField := input[101].(*storage.Int64FieldData)
pkField.Data = append(pkField.Data, int64(j))
flagField := input[102].(*storage.BoolFieldData)
flagField.Data = append(flagField.Data, true)
}
err = wrapper.splitFieldsData(input, 512)
assert.Nil(t, err)
assert.Equal(t, 2, len(importResult.AutoIds))
assert.Equal(t, 4, rowCounter.callTime)
assert.Equal(t, rowCount, rowCounter.rowCount)
// alloc id failed
wrapper.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error"))
err = wrapper.splitFieldsData(input, 512)
assert.NotNil(t, err)
wrapper.rowIDAllocator = newIDAllocator(ctx, t, nil)
// row count of fields are unequal
schema.Fields[0].AutoID = false
input = initSegmentData(schema)
for j := 0; j < rowCount; j++ {
pkField := input[101].(*storage.Int64FieldData)
pkField.Data = append(pkField.Data, int64(j))
if j%2 == 0 {
continue
}
flagField := input[102].(*storage.BoolFieldData)
flagField.Data = append(flagField.Data, true)
}
err = wrapper.splitFieldsData(input, 512)
assert.NotNil(t, err)
// primary key not found
wrapper.collectionSchema.Fields[0].IsPrimaryKey = false
err = wrapper.splitFieldsData(input, 512)
assert.NotNil(t, err)
wrapper.collectionSchema.Fields[0].IsPrimaryKey = true
// primary key is varchar, success
wrapper.collectionSchema.Fields[0].DataType = schemapb.DataType_VarChar
input = initSegmentData(schema)
for j := 0; j < rowCount; j++ {
pkField := input[101].(*storage.StringFieldData)
pkField.Data = append(pkField.Data, strconv.FormatInt(int64(j), 10))
flagField := input[102].(*storage.BoolFieldData)
flagField.Data = append(flagField.Data, true)
}
rowCounter.callTime = 0
rowCounter.rowCount = 0
importResult.AutoIds = []int64{}
err = wrapper.splitFieldsData(input, 1024)
assert.Nil(t, err)
assert.Equal(t, 0, len(importResult.AutoIds))
assert.Equal(t, 2, rowCounter.callTime)
assert.Equal(t, rowCount, rowCounter.rowCount)
}
func Test_ImportWrapperReportPersisted(t *testing.T) {
ctx := context.Background()
tr := timerecord.NewTimeRecorder("test")

View File

@ -35,15 +35,11 @@ import (
const (
// root field of row-based json format
RowRootNode = "rows"
// minimal size of a buffer
MinBufferSize = 1024
// split file into batches no more than this count
MaxBatchCount = 16
)
type JSONParser struct {
ctx context.Context // for canceling parse process
bufSize int64 // max rows in a buffer
bufRowCount int // max rows in a buffer
fields map[string]int64 // fields need to be parsed
name2FieldID map[string]storage.FieldID
}
@ -69,7 +65,7 @@ func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSch
parser := &JSONParser{
ctx: ctx,
bufSize: MinBufferSize,
bufRowCount: 1024,
fields: fields,
name2FieldID: name2FieldID,
}
@ -84,19 +80,24 @@ func adjustBufSize(parser *JSONParser, collectionSchema *schemapb.CollectionSche
return
}
// split the file into no more than MaxBatchCount batches to parse
// for high dimensional vector, the bufSize is a small value, read few rows each time
// for low dimensional vector, the bufSize is a large value, read more rows each time
maxRows := MaxFileSize / sizePerRecord
bufSize := maxRows / MaxBatchCount
// bufSize should not be less than MinBufferSize
if bufSize < MinBufferSize {
bufSize = MinBufferSize
bufRowCount := parser.bufRowCount
for {
if bufRowCount*sizePerRecord > SingleBlockSize {
bufRowCount--
} else {
break
}
}
log.Info("JSON parser: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize))
parser.bufSize = int64(bufSize)
// at least one row per buffer
if bufRowCount <= 0 {
bufRowCount = 1
}
log.Info("JSON parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount))
parser.bufRowCount = bufRowCount
}
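As a sketch of the arithmetic behind the decrement loop above: the result is simply the current bufRowCount capped so that bufRowCount*sizePerRecord fits into one block, with a floor of one row per buffer. The helper name capBufRowCount is hypothetical, and the snippet assumes sizePerRecord > 0 and that SingleBlockSize is an untyped integer constant, as its use in the loop suggests:

func capBufRowCount(bufRowCount, sizePerRecord int) int {
	// largest row count whose estimated size still fits into one block
	if maxRows := SingleBlockSize / sizePerRecord; bufRowCount > maxRows {
		bufRowCount = maxRows
	}
	// at least one row per buffer
	if bufRowCount < 1 {
		bufRowCount = 1
	}
	return bufRowCount
}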
func (p *JSONParser) verifyRow(raw interface{}) (map[storage.FieldID]interface{}, error) {
@ -185,7 +186,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
}
// read buffer
buf := make([]map[storage.FieldID]interface{}, 0, MinBufferSize)
buf := make([]map[storage.FieldID]interface{}, 0, p.bufRowCount)
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
@ -199,7 +200,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
}
buf = append(buf, row)
if len(buf) >= int(p.bufSize) {
if len(buf) >= p.bufRowCount {
isEmpty = false
if err = handler.Handle(buf); err != nil {
log.Error("JSON parser: failed to convert row value to entity", zap.Error(err))
@ -207,7 +208,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
}
// clear the buffer
buf = make([]map[storage.FieldID]interface{}, 0, MinBufferSize)
buf = make([]map[storage.FieldID]interface{}, 0, p.bufRowCount)
}
}

View File

@ -28,7 +28,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
)
@ -58,11 +57,7 @@ func Test_AdjustBufSize(t *testing.T) {
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
sizePerRecord, err := typeutil.EstimateSizePerRecord(schema)
assert.Nil(t, err)
assert.Greater(t, sizePerRecord, 0)
assert.Equal(t, MaxBatchCount, MaxFileSize/(sizePerRecord*int(parser.bufSize)))
assert.Greater(t, parser.bufRowCount, 0)
// huge row
schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{
@ -70,9 +65,7 @@ func Test_AdjustBufSize(t *testing.T) {
}
parser = NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
sizePerRecord, _ = typeutil.EstimateSizePerRecord(schema)
assert.Equal(t, 7, MaxFileSize/(sizePerRecord*int(parser.bufSize)))
assert.Greater(t, parser.bufRowCount, 0)
// no change
schema = &schemapb.CollectionSchema{
@ -83,8 +76,7 @@ func Test_AdjustBufSize(t *testing.T) {
}
parser = NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
assert.Equal(t, int64(MinBufferSize), parser.bufSize)
assert.Greater(t, parser.bufRowCount, 0)
}
func Test_JSONParserParseRows_IntPK(t *testing.T) {
@ -127,8 +119,8 @@ func Test_JSONParserParseRows_IntPK(t *testing.T) {
}
t.Run("parse success", func(t *testing.T) {
// set bufSize = 4, means call handle() after reading 4 rows
parser.bufSize = 4
// set bufRowCount = 4, means call handle() after reading 4 rows
parser.bufRowCount = 4
err = parser.ParseRows(reader, consumer)
assert.Nil(t, err)
assert.Equal(t, len(content.Rows), len(consumer.rows))
@ -285,12 +277,12 @@ func Test_JSONParserParseRows_IntPK(t *testing.T) {
}`
consumer.handleErr = errors.New("error")
reader = strings.NewReader(content)
parser.bufSize = 2
parser.bufRowCount = 2
err = parser.ParseRows(reader, consumer)
assert.NotNil(t, err)
reader = strings.NewReader(content)
parser.bufSize = 5
parser.bufRowCount = 5
err = parser.ParseRows(reader, consumer)
assert.NotNil(t, err)

View File

@ -86,11 +86,13 @@ type NumpyAdapter struct {
func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
r, err := npyio.NewReader(reader)
if err != nil {
log.Error("Numpy adapter: failed to read numpy header", zap.Error(err))
return nil, err
}
dataType, err := convertNumpyType(r.Header.Descr.Type)
if err != nil {
log.Error("Numpy adapter: failed to detect data type", zap.Error(err))
return nil, err
}
@ -109,12 +111,11 @@ func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
zap.Uint8("minorVer", r.Header.Minor),
zap.String("ByteOrder", adapter.order.String()))
return adapter, err
return adapter, nil
}
// convertNumpyType gets the data type from the numpy header description; for vector fields, the type is int8 (binary vector) or float32 (float vector)
func convertNumpyType(typeStr string) (schemapb.DataType, error) {
log.Info("Numpy adapter: parse numpy file dtype", zap.String("dtype", typeStr))
switch typeStr {
case "b1", "<b1", "|b1", "bool":
return schemapb.DataType_Bool, nil
@ -252,24 +253,29 @@ func (n *NumpyAdapter) checkCount(count int) int {
func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read bool data with a zero or nagative count")
return nil, errors.New("cannot read bool data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Bool {
log.Error("Numpy adapter: numpy data is not bool type")
return nil, errors.New("numpy data is not bool type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of bool file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: bool")
return nil, nil
}
// read data
data := make([]bool, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read bool data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf(" failed to read bool data with count %d, error: %w", readSize, err)
}
@ -281,6 +287,7 @@ func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) {
func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read uint8 data with a zero or nagative count")
return nil, errors.New("cannot read uint8 data with a zero or nagative count")
}
@ -289,19 +296,23 @@ func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
switch n.npyReader.Header.Descr.Type {
case "u1", "<u1", "|u1", "uint8":
default:
log.Error("Numpy adapter: numpy data is not uint8 type")
return nil, errors.New("numpy data is not uint8 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of uint8 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: uint8")
return nil, nil
}
// read data
data := make([]uint8, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read uint8 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read uint8 data with count %d, error: %w", readSize, err)
}
@ -313,24 +324,29 @@ func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
func (n *NumpyAdapter) ReadInt8(count int) ([]int8, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read int8 data with a zero or nagative count")
return nil, errors.New("cannot read int8 data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Int8 {
log.Error("Numpy adapter: numpy data is not int8 type")
return nil, errors.New("numpy data is not int8 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of int8 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: int8")
return nil, nil
}
// read data
data := make([]int8, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read int8 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read int8 data with count %d, error: %w", readSize, err)
}
@ -342,24 +358,29 @@ func (n *NumpyAdapter) ReadInt8(count int) ([]int8, error) {
func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read int16 data with a zero or nagative count")
return nil, errors.New("cannot read int16 data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Int16 {
log.Error("Numpy adapter: numpy data is not int16 type")
return nil, errors.New("numpy data is not int16 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of int16 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: int16")
return nil, nil
}
// read data
data := make([]int16, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read int16 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read int16 data with count %d, error: %w", readSize, err)
}
@ -371,24 +392,29 @@ func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) {
func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read int32 data with a zero or nagative count")
return nil, errors.New("cannot read int32 data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Int32 {
log.Error("Numpy adapter: numpy data is not int32 type")
return nil, errors.New("numpy data is not int32 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of int32 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: int32")
return nil, nil
}
// read data
data := make([]int32, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read int32 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read int32 data with count %d, error: %w", readSize, err)
}
@ -400,24 +426,29 @@ func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) {
func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read int64 data with a zero or nagative count")
return nil, errors.New("cannot read int64 data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Int64 {
log.Error("Numpy adapter: numpy data is not int64 type")
return nil, errors.New("numpy data is not int64 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of int64 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: int64")
return nil, nil
}
// read data
data := make([]int64, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read int64 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read int64 data with count %d, error: %w", readSize, err)
}
@ -429,24 +460,29 @@ func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) {
func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read float32 data with a zero or nagative count")
return nil, errors.New("cannot read float32 data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Float {
log.Error("Numpy adapter: numpy data is not float32 type")
return nil, errors.New("numpy data is not float32 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of float32 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: float32")
return nil, nil
}
// read data
data := make([]float32, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read float32 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read float32 data with count %d, error: %w", readSize, err)
}
@ -458,24 +494,29 @@ func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) {
func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) {
if count <= 0 {
log.Error("Numpy adapter: cannot read float64 data with a zero or nagative count")
return nil, errors.New("cannot read float64 data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_Double {
log.Error("Numpy adapter: numpy data is not float64 type")
return nil, errors.New("numpy data is not float64 type")
}
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of float64 file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: float64")
return nil, nil
}
// read data
data := make([]float64, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
log.Error("Numpy adapter: failed to read float64 data", zap.Int("count", count), zap.Error(err))
return nil, fmt.Errorf("failed to read float64 data with count %d, error: %w", readSize, err)
}
@ -487,12 +528,14 @@ func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) {
func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
if count <= 0 {
return nil, errors.New("cannot read varhar data with a zero or nagative count")
log.Error("Numpy adapter: cannot read varchar data with a zero or nagative count")
return nil, errors.New("cannot read varchar data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_VarChar {
return nil, errors.New("numpy data is not varhar type")
log.Error("Numpy adapter: numpy data is not varchar type")
return nil, errors.New("numpy data is not varchar type")
}
// varchar length: this is the max length; some items are shorter than this, but each still occupies the bytes of the max length
@ -501,12 +544,19 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Error(err))
return nil, fmt.Errorf("failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
}
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
// log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("end of varhar file, nothing to read")
// end of file, nothing to read
log.Info("Numpy adapter: read to end of file, type: varchar")
return nil, nil
}
if n.reader == nil {
log.Error("Numpy adapter: reader is nil")
return nil, errors.New("numpy reader is nil")
}
// read data
@ -545,6 +595,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
if n > 0 {
buf = buf[:n]
}
data = append(data, string(buf))
}
}
@ -557,6 +608,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
if len(src)%4 != 0 {
log.Error("Numpy adapter: invalid utf32 bytes length, the byte array length should be multiple of 4", zap.Int("byteLen", len(src)))
return "", fmt.Errorf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src))
}
@ -586,6 +638,7 @@ func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder()
res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2])
if err != nil {
log.Error("Numpy adapter: failed to decode utf32 binary bytes", zap.Error(err))
return "", fmt.Errorf("failed to decode utf32 binary bytes, error: %w", err)
}
str += string(res)
@ -603,6 +656,7 @@ func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
utf8Code := make([]byte, 4)
utf8.EncodeRune(utf8Code, r)
if r == utf8.RuneError {
log.Error("Numpy adapter: failed to convert 4 bytes unicode to utf8 rune", zap.Uint32("code", x))
return "", fmt.Errorf("failed to convert 4 bytes unicode %d to utf8 rune", x)
}
str += string(utf8Code)
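To show how the new end-of-file behaviour (returning nil data with a nil error once checkCount reports nothing left) enables streaming reads, here is an illustrative caller that drains an int64 column in fixed-size chunks. The function name streamInt64Column and the chunk size are made up; only NewNumpyAdapter and ReadInt64 come from the code above:

func streamInt64Column(r io.Reader) ([]int64, error) {
	adapter, err := NewNumpyAdapter(r)
	if err != nil {
		return nil, err
	}
	var all []int64
	for {
		chunk, err := adapter.ReadInt64(1000) // arbitrary chunk size
		if err != nil {
			return nil, err
		}
		if chunk == nil {
			// nil data with nil error means end of file
			break
		}
		all = append(all, chunk...)
	}
	return all, nil
}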

View File

@ -17,6 +17,7 @@
package importutil
import (
"bytes"
"encoding/binary"
"io"
"os"
@ -40,12 +41,12 @@ func Test_CreateNumpyFile(t *testing.T) {
// directory doesn't exist
data1 := []float32{1, 2, 3, 4, 5}
err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1)
assert.NotNil(t, err)
assert.Error(t, err)
// invalid data type
data2 := make(map[string]int)
err = CreateNumpyFile("/tmp/dummy.npy", data2)
assert.NotNil(t, err)
assert.Error(t, err)
}
func Test_CreateNumpyData(t *testing.T) {
@ -53,12 +54,12 @@ func Test_CreateNumpyData(t *testing.T) {
data1 := []float32{1, 2, 3, 4, 5}
buf, err := CreateNumpyData(data1)
assert.NotNil(t, buf)
assert.Nil(t, err)
assert.NoError(t, err)
// invalid data type
data2 := make(map[string]int)
buf, err = CreateNumpyData(data2)
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, buf)
}
@ -66,7 +67,7 @@ func Test_ConvertNumpyType(t *testing.T) {
checkFunc := func(inputs []string, output schemapb.DataType) {
for i := 0; i < len(inputs); i++ {
dt, err := convertNumpyType(inputs[i])
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, output, dt)
}
}
@ -80,7 +81,7 @@ func Test_ConvertNumpyType(t *testing.T) {
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
dt, err := convertNumpyType("dummy")
assert.NotNil(t, err)
assert.Error(t, err)
assert.Equal(t, schemapb.DataType_None, dt)
}
@ -88,25 +89,25 @@ func Test_StringLen(t *testing.T) {
len, utf, err := stringLen("S1")
assert.Equal(t, 1, len)
assert.False(t, utf)
assert.Nil(t, err)
assert.NoError(t, err)
len, utf, err = stringLen("2S")
assert.Equal(t, 2, len)
assert.False(t, utf)
assert.Nil(t, err)
assert.NoError(t, err)
len, utf, err = stringLen("<U3")
assert.Equal(t, 3, len)
assert.True(t, utf)
assert.Nil(t, err)
assert.NoError(t, err)
len, utf, err = stringLen(">4U")
assert.Equal(t, 4, len)
assert.True(t, utf)
assert.Nil(t, err)
assert.NoError(t, err)
len, utf, err = stringLen("dummy")
assert.NotNil(t, err)
assert.Error(t, err)
assert.Equal(t, 0, len)
assert.False(t, utf)
}
@ -129,207 +130,337 @@ func Test_NumpyAdapterSetByteOrder(t *testing.T) {
}
func Test_NumpyAdapterReadError(t *testing.T) {
adapter := &NumpyAdapter{
reader: nil,
npyReader: nil,
}
// reader size is zero
t.Run("test size is zero", func(t *testing.T) {
_, err := adapter.ReadBool(0)
assert.NotNil(t, err)
_, err = adapter.ReadUint8(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt8(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt16(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt32(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt64(0)
assert.NotNil(t, err)
_, err = adapter.ReadFloat32(0)
assert.NotNil(t, err)
_, err = adapter.ReadFloat64(0)
assert.NotNil(t, err)
})
// t.Run("test size is zero", func(t *testing.T) {
// adapter := &NumpyAdapter{
// reader: nil,
// npyReader: nil,
// }
// _, err := adapter.ReadBool(0)
// assert.Error(t, err)
// _, err = adapter.ReadUint8(0)
// assert.Error(t, err)
// _, err = adapter.ReadInt8(0)
// assert.Error(t, err)
// _, err = adapter.ReadInt16(0)
// assert.Error(t, err)
// _, err = adapter.ReadInt32(0)
// assert.Error(t, err)
// _, err = adapter.ReadInt64(0)
// assert.Error(t, err)
// _, err = adapter.ReadFloat32(0)
// assert.Error(t, err)
// _, err = adapter.ReadFloat64(0)
// assert.Error(t, err)
// })
adapter = &NumpyAdapter{
reader: &MockReader{},
npyReader: &npy.Reader{},
createAdatper := func(dt schemapb.DataType) *NumpyAdapter {
adapter := &NumpyAdapter{
reader: &MockReader{},
npyReader: &npy.Reader{
Header: npy.Header{},
},
dataType: dt,
order: binary.BigEndian,
}
adapter.npyReader.Header.Descr.Shape = []int{1}
return adapter
}
t.Run("test read bool", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "bool"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadBool(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Bool)
data, err = adapter.ReadBool(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1})
data, err = adapter.ReadBool(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadBool(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read uint8", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "u1"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err := adapter.ReadUint8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter.npyReader.Header.Descr.Type = "u1"
data, err = adapter.ReadUint8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1})
data, err = adapter.ReadUint8(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadUint8(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read int8", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i1"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadInt8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Int8)
data, err = adapter.ReadInt8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1})
data, err = adapter.ReadInt8(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadInt8(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read int16", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i2"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadInt16(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Int16)
data, err = adapter.ReadInt16(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1, 2})
data, err = adapter.ReadInt16(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadInt16(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read int32", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i4"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadInt32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Int32)
data, err = adapter.ReadInt32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4})
data, err = adapter.ReadInt32(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadInt32(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read int64", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i8"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadInt64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Int64)
data, err = adapter.ReadInt64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8})
data, err = adapter.ReadInt64(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadInt64(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read float", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "f4"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadFloat32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Float)
data, err = adapter.ReadFloat32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4})
data, err = adapter.ReadFloat32(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadFloat32(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read double", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "f8"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadFloat64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_Double)
data, err = adapter.ReadFloat64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8})
data, err = adapter.ReadFloat64(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadFloat64(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
t.Run("test read varchar", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "U3"
// type mismatch
adapter := createAdatper(schemapb.DataType_None)
data, err := adapter.ReadString(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
// reader is nil, cannot read
adapter = createAdatper(schemapb.DataType_VarChar)
adapter.reader = nil
adapter.npyReader.Header.Descr.Type = "S3"
data, err = adapter.ReadString(1)
assert.Nil(t, data)
assert.NotNil(t, err)
assert.Error(t, err)
// read one element from reader
adapter.reader = strings.NewReader("abc")
data, err = adapter.ReadString(1)
assert.NotEmpty(t, data)
assert.NoError(t, err)
// nothing to read
data, err = adapter.ReadString(1)
assert.Nil(t, data)
assert.NoError(t, err)
})
}
func Test_NumpyAdapterRead(t *testing.T) {
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
assert.NoError(t, err)
defer os.RemoveAll(TempFilesPath)
t.Run("test read bool", func(t *testing.T) {
filePath := TempFilesPath + "bool.npy"
data := []bool{true, false, true, false}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadBool(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadBool(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadBool(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
// incorrect type read
resu1, err := adapter.ReadUint8(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resu1)
resi1, err := adapter.ReadInt8(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resi1)
resi2, err := adapter.ReadInt16(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resi2)
resi4, err := adapter.ReadInt32(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resi4)
resi8, err := adapter.ReadInt64(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resi8)
resf4, err := adapter.ReadFloat32(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resf4)
resf8, err := adapter.ReadFloat64(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resf8)
})
@@ -337,35 +468,38 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "uint8.npy"
data := []uint8{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadUint8(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadUint8(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadUint8(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
// incorrect type read
resb, err := adapter.ReadBool(len(data))
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, resb)
})
@@ -373,30 +507,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "int8.npy"
data := []int8{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadInt8(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadInt8(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadInt8(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -404,30 +541,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "int16.npy"
data := []int16{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadInt16(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadInt16(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadInt16(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -435,30 +575,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "int32.npy"
data := []int32{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadInt32(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadInt32(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadInt32(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -466,30 +609,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "int64.npy"
data := []int64{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadInt64(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadInt64(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadInt64(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -497,30 +643,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "float.npy"
data := []float32{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadFloat32(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadFloat32(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadFloat32(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -528,30 +677,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "double.npy"
data := []float64{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadFloat64(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadFloat64(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadFloat64(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -589,19 +741,19 @@ func Test_NumpyAdapterRead(t *testing.T) {
// count should be greater than 0
res, err := adapter.ReadString(0)
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, res)
// maxLen is zero
npyReader.Header.Descr.Type = "S0"
res, err = adapter.ReadString(1)
assert.NotNil(t, err)
assert.Error(t, err)
assert.Nil(t, res)
npyReader.Header.Descr.Type = "S" + strconv.FormatInt(int64(maxLen), 10)
res, err = adapter.ReadString(len(values) + 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(values), len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, values[i], res[i])
@@ -612,29 +764,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "varchar1.npy"
data := []string{"a ", "bbb", " c", "dd", "eeee", "fff"}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
// partly read
res, err := adapter.ReadString(len(data) - 1)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
// read the remaining data
res, err = adapter.ReadString(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
// nothing to read
res, err = adapter.ReadString(len(data))
assert.NotNil(t, err)
assert.NoError(t, err)
assert.Nil(t, res)
})
@@ -642,16 +798,16 @@ func Test_NumpyAdapterRead(t *testing.T) {
filePath := TempFilesPath + "varchar2.npy"
data := []string{"で と ど ", " 马克bbb", "$(한)삼각*"}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
assert.NoError(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
assert.NoError(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
assert.NoError(t, err)
res, err := adapter.ReadString(len(data))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(data), len(res))
for i := 0; i < len(res); i++ {
@@ -663,7 +819,7 @@ func Test_NumpyAdapterRead(t *testing.T) {
func Test_DecodeUtf32(t *testing.T) {
// wrong input
res, err := decodeUtf32([]byte{1, 2}, binary.LittleEndian)
assert.NotNil(t, err)
assert.Error(t, err)
assert.Empty(t, res)
// this string contains ascii characters and unicode characters
@@ -672,12 +828,12 @@ func Test_DecodeUtf32(t *testing.T) {
// utf32 littleEndian of str
src := []byte{97, 0, 0, 0, 100, 0, 0, 0, 228, 37, 0, 0, 9, 78, 0, 0, 126, 118, 0, 0, 181, 243, 1, 0, 144, 48, 0, 0, 153, 33, 0, 0}
res, err = decodeUtf32(src, binary.LittleEndian)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, str, res)
// utf32 bigEndian of str
src = []byte{0, 0, 0, 97, 0, 0, 0, 100, 0, 0, 37, 228, 0, 0, 78, 9, 0, 0, 118, 126, 0, 1, 243, 181, 0, 0, 48, 144, 0, 0, 33, 153}
res, err = decodeUtf32(src, binary.BigEndian)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, str, res)
}

View File

@@ -20,280 +20,527 @@ import (
"context"
"errors"
"fmt"
"io"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
type ColumnDesc struct {
name string // name of the target column
dt schemapb.DataType // data type of the target column
elementCount int // how many elements need to be read
dimension int // only for vector
type NumpyColumnReader struct {
fieldName string // name of the target column
fieldID storage.FieldID // ID of the target column
dataType schemapb.DataType // data type of the target column
rowCount int // how many rows need to be read
dimension int // only for vector
file storage.FileReader // file to be read
reader *NumpyAdapter // data reader
}
func closeReaders(columnReaders []*NumpyColumnReader) {
for _, reader := range columnReaders {
if reader.file != nil {
err := reader.file.Close()
if err != nil {
log.Error("Numper parser: failed to close numpy file", zap.String("fileName", reader.fieldName+NumpyFileExt))
}
}
}
}
type NumpyParser struct {
ctx context.Context // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
columnDesc *ColumnDesc // description for target column
columnData storage.FieldData // in-memory column data
callFlushFunc func(field storage.FieldData) error // call back function to output column data
rowIDAllocator *allocator.IDAllocator // autoid allocator
shardNum int32 // sharding number of the collection
blockSize int64 // maximum size of a read block(unit:byte)
chunkManager storage.ChunkManager // storage interfaces to browse/read the files
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
callFlushFunc ImportFlushFunc // call back function to flush segment
}
// NewNumpyParser is a helper function to create a NumpyParser
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
flushFunc func(field storage.FieldData) error) *NumpyParser {
if collectionSchema == nil || flushFunc == nil {
return nil
func NewNumpyParser(ctx context.Context,
collectionSchema *schemapb.CollectionSchema,
idAlloc *allocator.IDAllocator,
shardNum int32,
blockSize int64,
chunkManager storage.ChunkManager,
flushFunc ImportFlushFunc) (*NumpyParser, error) {
if collectionSchema == nil {
log.Error("Numper parser: collection schema is nil")
return nil, errors.New("collection schema is nil")
}
if idAlloc == nil {
log.Error("Numper parser: id allocator is nil")
return nil, errors.New("id allocator is nil")
}
if chunkManager == nil {
log.Error("Numper parser: chunk manager pointer is nil")
return nil, errors.New("chunk manager pointer is nil")
}
if flushFunc == nil {
log.Error("Numper parser: flush function is nil")
return nil, errors.New("flush function is nil")
}
parser := &NumpyParser{
ctx: ctx,
collectionSchema: collectionSchema,
columnDesc: &ColumnDesc{},
rowIDAllocator: idAlloc,
shardNum: shardNum,
blockSize: blockSize,
chunkManager: chunkManager,
autoIDRange: make([]int64, 0),
callFlushFunc: flushFunc,
}
return parser
return parser, nil
}
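// Illustrative usage (a minimal sketch only; the schema, allocator, chunk manager, flush callback
// and file names below are assumed to be prepared by the caller, e.g. by the import wrapper):
//
//	parser, err := NewNumpyParser(ctx, collectionSchema, idAlloc, shardNum, 16*1024*1024, chunkManager, flushFunc)
//	if err != nil {
//		return err
//	}
//	if err = parser.Parse([]string{"age.npy", "vector.npy"}); err != nil {
//		return err
//	}
//	idRanges := parser.IDRange() // auto-generated id ranges, if any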
func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
if adapter == nil {
log.Error("Numpy parser: numpy adapter is nil")
return errors.New("numpy adapter is nil")
func (p *NumpyParser) IDRange() []int64 {
return p.autoIDRange
}
// Parse is the function entry
func (p *NumpyParser) Parse(filePaths []string) error {
// check for redundant or missing files in the column-based import
// if a field is the primary key and autoID is false, its file is required
// any redundant file is not allowed
err := p.validateFileNames(filePaths)
if err != nil {
return err
}
// check existence of the target field
var schema *schemapb.FieldSchema
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema = p.collectionSchema.Fields[i]
if schema.GetName() == fieldName {
p.columnDesc.name = fieldName
break
}
// open files and verify file header
readers, err := p.createReaders(filePaths)
// make sure all the files are closed in the end; this defer must be set up before the function returns
defer closeReaders(readers)
if err != nil {
return err
}
if p.columnDesc.name == "" {
log.Error("Numpy parser: Numpy parser: the field is not found in collection schema", zap.String("fieldName", fieldName))
return fmt.Errorf("the field name '%s' is not found in collection schema", fieldName)
}
p.columnDesc.dt = schema.DataType
elementType := adapter.GetType()
shape := adapter.GetShape()
var err error
// 1. field data type should be consistent with the numpy data type
// 2. vector field dimension should be consistent with the numpy shape
if schemapb.DataType_FloatVector == schema.DataType {
// a float vector field accepts both float32 and float64 numpy files, for two reasons:
// 1. python's float type is 64-bit, so users often export float64 arrays for float vectors
// 2. float64 files are accepted, but they perform worse than float32 files
if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType),
zap.String("fieldName", fieldName))
return fmt.Errorf("illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), schema.GetName())
}
// vector field, the shape should be 2
if len(shape) != 2 {
log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)),
zap.String("fieldName", fieldName))
return fmt.Errorf("illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, schema.GetName())
}
// shape[0] is row count, shape[1] is element count per row
p.columnDesc.elementCount = shape[0] * shape[1]
p.columnDesc.dimension, err = getFieldDimension(schema)
if err != nil {
return err
}
if shape[1] != p.columnDesc.dimension {
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", p.columnDesc.dimension))
return fmt.Errorf("illegal dimension %d of numpy file for float vector field '%s', dimension should be %d",
shape[1], schema.GetName(), p.columnDesc.dimension)
}
} else if schemapb.DataType_BinaryVector == schema.DataType {
if elementType != schemapb.DataType_BinaryVector {
log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType),
zap.String("fieldName", fieldName))
return fmt.Errorf("illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), schema.GetName())
}
// vector field, the shape should be 2
if len(shape) != 2 {
log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)),
zap.String("fieldName", fieldName))
return fmt.Errorf("illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, schema.GetName())
}
// shape[0] is row count, shape[1] is element count per row
p.columnDesc.elementCount = shape[0] * shape[1]
p.columnDesc.dimension, err = getFieldDimension(schema)
if err != nil {
return err
}
if shape[1] != p.columnDesc.dimension/8 {
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", p.columnDesc.dimension))
return fmt.Errorf("illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d",
shape[1]*8, schema.GetName(), p.columnDesc.dimension)
}
} else {
if elementType != schema.DataType {
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
return fmt.Errorf("illegal data type %s of numpy file for scalar field '%s' with type %s",
getTypeName(elementType), schema.GetName(), getTypeName(schema.DataType))
}
// scalar field, the shape should be 1
if len(shape) != 1 {
log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)),
zap.String("fieldName", fieldName))
return fmt.Errorf("illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, schema.GetName())
}
p.columnDesc.elementCount = shape[0]
// read all data from the numpy files
err = p.consume(readers)
if err != nil {
return err
}
return nil
}
// validateFileNames checks for redundant and missing files
func (p *NumpyParser) validateFileNames(filePaths []string) error {
requiredFieldNames := make(map[string]interface{})
for _, schema := range p.collectionSchema.Fields {
if schema.GetIsPrimaryKey() {
if !schema.GetAutoID() {
requiredFieldNames[schema.GetName()] = nil
}
} else {
requiredFieldNames[schema.GetName()] = nil
}
}
// check redundant file
fileNames := make(map[string]interface{})
for _, filePath := range filePaths {
name, _ := GetFileNameAndExt(filePath)
fileNames[name] = nil
_, ok := requiredFieldNames[name]
if !ok {
log.Error("Numpy parser: the file has no corresponding field in collection", zap.String("fieldName", name))
return fmt.Errorf("the file '%s' has no corresponding field in collection", filePath)
}
}
// check missed file
for name := range requiredFieldNames {
_, ok := fileNames[name]
if !ok {
log.Error("Numpy parser: there is no file corresponding to field", zap.String("fieldName", name))
return fmt.Errorf("there is no file corresponding to field '%s'", name)
}
}
return nil
}
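// Example (hypothetical schema): for a collection with an auto-generated int64 primary key "id"
// and two other fields "age" and "vector", the expected input is exactly {age.npy, vector.npy};
// a missing "vector.npy" or an extra "id.npy"/"other.npy" is rejected by the checks above.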
// createReaders opens the files and verifies the file headers
func (p *NumpyParser) createReaders(filePaths []string) ([]*NumpyColumnReader, error) {
readers := make([]*NumpyColumnReader, 0)
for _, filePath := range filePaths {
fileName, _ := GetFileNameAndExt(filePath)
// check existence of the target field
var schema *schemapb.FieldSchema
for i := 0; i < len(p.collectionSchema.Fields); i++ {
tmpSchema := p.collectionSchema.Fields[i]
if tmpSchema.GetName() == fileName {
schema = tmpSchema
break
}
}
if schema == nil {
log.Error("Numpy parser: the field is not found in collection schema", zap.String("fileName", fileName))
return nil, fmt.Errorf("the field name '%s' is not found in collection schema", fileName)
}
file, err := p.chunkManager.Reader(p.ctx, filePath)
if err != nil {
log.Error("Numpy parser: failed to read the file", zap.String("filePath", filePath), zap.Error(err))
return nil, fmt.Errorf("failed to read the file '%s', error: %s", filePath, err.Error())
}
adapter, err := NewNumpyAdapter(file)
if err != nil {
log.Error("Numpy parser: failed to read the file header", zap.String("filePath", filePath), zap.Error(err))
return nil, fmt.Errorf("failed to read the file header '%s', error: %s", filePath, err.Error())
}
if file == nil || adapter == nil {
log.Error("Numpy parser: failed to open file", zap.String("filePath", filePath))
return nil, fmt.Errorf("failed to open file '%s'", filePath)
}
dim, _ := getFieldDimension(schema)
columnReader := &NumpyColumnReader{
fieldName: schema.GetName(),
fieldID: schema.GetFieldID(),
dataType: schema.GetDataType(),
dimension: dim,
file: file,
reader: adapter,
}
// the validation method only checks the file header information
err = p.validateHeader(columnReader)
if err != nil {
return nil, err
}
readers = append(readers, columnReader)
}
// all files must contain the same number of rows
if len(readers) > 0 {
firstReader := readers[0]
rowCount := firstReader.rowCount
for i := 1; i < len(readers); i++ {
compareReader := readers[i]
if rowCount != compareReader.rowCount {
log.Error("Numpy parser: the row count of files are not equal",
zap.String("firstFile", firstReader.fieldName), zap.Int("firstRowCount", firstReader.rowCount),
zap.String("compareFile", compareReader.fieldName), zap.Int("compareRowCount", compareReader.rowCount))
return nil, fmt.Errorf("the row count(%d) of file '%s.npy' is not equal to row count(%d) of file '%s.npy'",
firstReader.rowCount, firstReader.fieldName, compareReader.rowCount, compareReader.fieldName)
}
}
}
return readers, nil
}
// validateHeader verifies the numpy file header; the header information must match the field's schema
func (p *NumpyParser) validateHeader(columnReader *NumpyColumnReader) error {
if columnReader == nil || columnReader.reader == nil {
log.Error("Numpy parser: numpy reader is nil")
return errors.New("numpy adapter is nil")
}
elementType := columnReader.reader.GetType()
shape := columnReader.reader.GetShape()
columnReader.rowCount = shape[0]
// 1. field data type should be consistent with the numpy data type
// 2. vector field dimension should be consistent with the numpy shape
if schemapb.DataType_FloatVector == columnReader.dataType {
// a float vector field accepts both float32 and float64 numpy files, for two reasons:
// 1. python's float type is 64-bit, so users often export float64 arrays for float vectors
// 2. float64 files are accepted, but they perform worse than float32 files
if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType),
zap.String("fieldName", columnReader.fieldName))
return fmt.Errorf("illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType),
columnReader.fieldName)
}
// vector field, the shape should be 2
if len(shape) != 2 {
log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)),
zap.String("fieldName", columnReader.fieldName))
return fmt.Errorf("illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape,
columnReader.fieldName)
}
if shape[1] != columnReader.dimension {
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", columnReader.fieldName),
zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", columnReader.dimension))
return fmt.Errorf("illegal dimension %d of numpy file for float vector field '%s', dimension should be %d",
shape[1], columnReader.fieldName, columnReader.dimension)
}
} else if schemapb.DataType_BinaryVector == columnReader.dataType {
if elementType != schemapb.DataType_BinaryVector {
log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType),
zap.String("fieldName", columnReader.fieldName))
return fmt.Errorf("illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType),
columnReader.fieldName)
}
// vector field, the shape should be 2
if len(shape) != 2 {
log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)),
zap.String("fieldName", columnReader.fieldName))
return fmt.Errorf("illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape,
columnReader.fieldName)
}
if shape[1] != columnReader.dimension/8 {
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", columnReader.fieldName),
zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", columnReader.dimension))
return fmt.Errorf("illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d",
shape[1]*8, columnReader.fieldName, columnReader.dimension)
}
} else {
if elementType != columnReader.dataType {
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
zap.String("fieldName", columnReader.fieldName), zap.Any("fieldDataType", columnReader.dataType))
return fmt.Errorf("illegal data type %s of numpy file for scalar field '%s' with type %s",
getTypeName(elementType), columnReader.fieldName, getTypeName(columnReader.dataType))
}
// scalar field, the shape should be 1
if len(shape) != 1 {
log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)),
zap.String("fieldName", columnReader.fieldName))
return fmt.Errorf("illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, columnReader.fieldName)
}
}
return nil
}
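// Illustrative expectations (assumed dim values): a FloatVector field with dim=128 requires a
// 2-d array of shape (N, 128) with float32 or float64 elements; a BinaryVector field with dim=512
// requires shape (N, 64), one byte per 8 dimensions; a scalar field such as Int64 requires a
// 1-d array of shape (N,) whose dtype matches the field type.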
// calcRowCountPerBlock calculates a proper row count for each batch read from the files
func (p *NumpyParser) calcRowCountPerBlock() (int64, error) {
sizePerRecord, err := typeutil.EstimateSizePerRecord(p.collectionSchema)
if err != nil {
log.Error("Numpy parser: failed to estimate size of each row", zap.Error(err))
return 0, fmt.Errorf("failed to estimate size of each row: %s", err.Error())
}
if sizePerRecord <= 0 {
log.Error("Numpy parser: failed to estimate size of each row, the collection schema might be empty")
return 0, fmt.Errorf("failed to estimate size of each row: the collection schema might be empty")
}
// sizePerRecord is an estimated value; if the schema contains a varchar field, the value is not accurate
// data is read block by block, and by default each block size is 16MB
// rowCountPerBlock is the estimated row count for one block
rowCountPerBlock := p.blockSize / int64(sizePerRecord)
if rowCountPerBlock <= 0 {
rowCountPerBlock = 1 // make sure the value is positive
}
log.Info("Numper parser: calculate row count per block to read file", zap.Int64("rowCountPerBlock", rowCountPerBlock),
zap.Int64("blockSize", p.blockSize), zap.Int("sizePerRecord", sizePerRecord))
return rowCountPerBlock, nil
}
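// Arithmetic example (assumed numbers): with a 16 MB block size and an estimated sizePerRecord
// of 4 KB (e.g. a 1024-dim float32 vector plus a few scalars), rowCountPerBlock is
// 16*1024*1024 / 4096 = 4096 rows per batch.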
// consume reads the numpy files block by block, splits each block into shards, and flushes
// shard buffers to binlog once their estimated size reaches the block size
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
switch p.columnDesc.dt {
func (p *NumpyParser) consume(columnReaders []*NumpyColumnReader) error {
rowCountPerBlock, err := p.calcRowCountPerBlock()
if err != nil {
return err
}
// prepare shards
shards := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum)
for i := 0; i < int(p.shardNum); i++ {
segmentData := initSegmentData(p.collectionSchema)
if segmentData == nil {
log.Error("import wrapper: failed to initialize FieldData list")
return fmt.Errorf("failed to initialize FieldData list")
}
shards = append(shards, segmentData)
}
tr := timerecord.NewTimeRecorder("consume performance")
defer tr.Elapse("end")
// read data from files, batch by batch
for {
readRowCount := 0
segmentData := make(map[storage.FieldID]storage.FieldData)
for _, reader := range columnReaders {
fieldData, err := p.readData(reader, int(rowCountPerBlock))
if err != nil {
return err
}
if readRowCount == 0 {
readRowCount = fieldData.RowNum()
} else if readRowCount != fieldData.RowNum() {
log.Error("Numpy parser: data block's row count mismatch", zap.Int("firstBlockRowCount", readRowCount),
zap.Int("thisBlockRowCount", fieldData.RowNum()), zap.Int64("rowCountPerBlock", rowCountPerBlock))
return fmt.Errorf("data block's row count mismatch: %d vs %d", readRowCount, fieldData.RowNum())
}
segmentData[reader.fieldID] = fieldData
}
// nothing to read
if readRowCount == 0 {
break
}
tr.Record("readData")
// split data to shards
err = p.splitFieldsData(segmentData, shards)
if err != nil {
return err
}
tr.Record("splitFieldsData")
// when the estimated size is close to blockSize, save to binlog
err = tryFlushBlocks(p.ctx, shards, p.collectionSchema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, false)
if err != nil {
return err
}
tr.Record("tryFlushBlocks")
}
// force flush at the end
return tryFlushBlocks(p.ctx, shards, p.collectionSchema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, true)
}
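// Illustrative flow (assumed numbers): with 10000 rows in the files and rowCountPerBlock=4096,
// the loop above reads the data in three batches (4096, 4096 and 1808 rows) before observing an
// empty read and exiting; every batch is split into the shard buffers, tryFlushBlocks persists a
// shard buffer whenever its estimated size approaches blockSize, and the final call flushes
// whatever remains.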
// readData reads a batch of rows from the numpy data section into a storage.FieldData
func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (storage.FieldData, error) {
switch columnReader.dataType {
case schemapb.DataType_Bool:
data, err := adapter.ReadBool(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadBool(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read bool array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read bool array: %s", err.Error())
}
p.columnData = &storage.BoolFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.BoolFieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_Int8:
data, err := adapter.ReadInt8(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadInt8(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read int8 array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read int8 array: %s", err.Error())
}
p.columnData = &storage.Int8FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.Int8FieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_Int16:
data, err := adapter.ReadInt16(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadInt16(rowCount)
if err != nil {
log.Error("Numpy parser: failed to int16 bool array", zap.Error(err))
return err
log.Error("Numpy parser: failed to int16 array", zap.Error(err))
return nil, fmt.Errorf("failed to read int16 array: %s", err.Error())
}
p.columnData = &storage.Int16FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.Int16FieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_Int32:
data, err := adapter.ReadInt32(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadInt32(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read int32 array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read int32 array: %s", err.Error())
}
p.columnData = &storage.Int32FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.Int32FieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_Int64:
data, err := adapter.ReadInt64(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadInt64(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read int64 array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read int64 array: %s", err.Error())
}
p.columnData = &storage.Int64FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.Int64FieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_Float:
data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadFloat32(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read float array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read float array: %s", err.Error())
}
p.columnData = &storage.FloatFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.FloatFieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_Double:
data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadFloat64(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read double array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read double array: %s", err.Error())
}
p.columnData = &storage.DoubleFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.DoubleFieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_VarChar:
data, err := adapter.ReadString(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadString(rowCount)
if err != nil {
log.Error("Numpy parser: failed to read varchar array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read varchar array: %s", err.Error())
}
p.columnData = &storage.StringFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.StringFieldData{
NumRows: []int64{int64(len(data))},
Data: data,
}
}, nil
case schemapb.DataType_BinaryVector:
data, err := adapter.ReadUint8(p.columnDesc.elementCount)
data, err := columnReader.reader.ReadUint8(rowCount * (columnReader.dimension / 8))
if err != nil {
log.Error("Numpy parser: failed to read binary vector array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read binary vector array: %s", err.Error())
}
p.columnData = &storage.BinaryVectorFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.BinaryVectorFieldData{
NumRows: []int64{int64(len(data) * 8 / columnReader.dimension)},
Data: data,
Dim: p.columnDesc.dimension,
}
Dim: columnReader.dimension,
}, nil
case schemapb.DataType_FloatVector:
// a float vector field accepts both float32 and float64 numpy files, for two reasons:
// 1. python's float type is 64-bit, so users often export float64 arrays for float vectors
// 2. float64 files are accepted, but they perform worse than float32 files
elementType := adapter.GetType()
elementType := columnReader.reader.GetType()
var data []float32
var err error
if elementType == schemapb.DataType_Float {
data, err = adapter.ReadFloat32(p.columnDesc.elementCount)
data, err = columnReader.reader.ReadFloat32(rowCount * columnReader.dimension)
if err != nil {
log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
}
} else if elementType == schemapb.DataType_Double {
data = make([]float32, 0, p.columnDesc.elementCount)
data64, err := adapter.ReadFloat64(p.columnDesc.elementCount)
data = make([]float32, 0, columnReader.rowCount)
data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension)
if err != nil {
log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
return err
return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
}
for _, f64 := range data64 {
@@ -301,40 +548,255 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
}
}
p.columnData = &storage.FloatVectorFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
return &storage.FloatVectorFieldData{
NumRows: []int64{int64(len(data) / columnReader.dimension)},
Data: data,
Dim: p.columnDesc.dimension,
Dim: columnReader.dimension,
}, nil
default:
log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", columnReader.dataType),
zap.String("fieldName", columnReader.fieldName))
return nil, fmt.Errorf("unsupported data type %s of field '%s'", getTypeName(columnReader.dataType),
columnReader.fieldName)
}
}
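// Example (assumed field): for a FloatVector field with dim=128 and rowCount=1000, ReadFloat32 is
// asked for 128000 values and the returned FieldData reports NumRows = len(data)/dim = 1000;
// a shorter read near the end of the file simply yields a smaller (or empty) batch.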
// appendFunc returns a function that appends one row of the source data to the target storage.FieldData
func (p *NumpyParser) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
switch schema.DataType {
case schemapb.DataType_Bool:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.BoolFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(bool))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Float:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.FloatFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(float32))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Double:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.DoubleFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(float64))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int8FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int8))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int16FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int16))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int32FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int32))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int64FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int64))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.BinaryVectorFieldData)
arr.Data = append(arr.Data, src.GetRow(n).([]byte)...)
arr.NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.FloatVectorFieldData)
arr.Data = append(arr.Data, src.GetRow(n).([]float32)...)
arr.NumRows[0]++
return nil
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.StringFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(string))
return nil
}
default:
log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", p.columnDesc.dt), zap.String("fieldName", p.columnDesc.name))
return fmt.Errorf("unsupported data type %s of field '%s'", getTypeName(p.columnDesc.dt), p.columnDesc.name)
return nil
}
}
func (p *NumpyParser) prepareAppendFunctions() (map[string]func(src storage.FieldData, n int, target storage.FieldData) error, error) {
appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error)
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema := p.collectionSchema.Fields[i]
appendFuncErr := p.appendFunc(schema)
if appendFuncErr == nil {
log.Error("Numpy parser: unsupported field data type")
return nil, fmt.Errorf("unsupported field data type: %d", schema.GetDataType())
}
appendFunctions[schema.GetName()] = appendFuncErr
}
return appendFunctions, nil
}
// checkRowCount checks the existence of each field and verifies that all fields have the same row count,
// then returns the row count and the primary key schema
func (p *NumpyParser) checkRowCount(fieldsData map[storage.FieldID]storage.FieldData) (int, *schemapb.FieldSchema, error) {
rowCount := 0
rowCounter := make(map[string]int)
var primaryKey *schemapb.FieldSchema
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema := p.collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
primaryKey = schema
}
if !schema.GetAutoID() {
v, ok := fieldsData[schema.GetFieldID()]
if !ok {
log.Error("Numpy parser: field not provided", zap.String("fieldName", schema.GetName()))
return 0, nil, fmt.Errorf("field '%s' not provided", schema.GetName())
}
rowCounter[schema.GetName()] = v.RowNum()
if v.RowNum() > rowCount {
rowCount = v.RowNum()
}
}
}
if primaryKey == nil {
log.Error("Numpy parser: primary key field is not found")
return 0, nil, fmt.Errorf("primary key field is not found")
}
for name, count := range rowCounter {
if count != rowCount {
log.Error("Numpy parser: field row count is not equal to other fields row count", zap.String("fieldName", name),
zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount))
return 0, nil, fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
}
}
// log.Info("Numpy parser: try to split a block with row count", zap.Int("rowCount", rowCount))
return rowCount, primaryKey, nil
}
// splitFieldsData splits the in-memory data (parsed from column-based files) into shards
func (p *NumpyParser) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, shards []map[storage.FieldID]storage.FieldData) error {
if len(fieldsData) == 0 {
log.Error("Numpy parser: fields data to split is empty")
return fmt.Errorf("fields data to split is empty")
}
if len(shards) != int(p.shardNum) {
log.Error("Numpy parser: block count is not equal to collection shard number", zap.Int("shardsLen", len(shards)),
zap.Int32("shardNum", p.shardNum))
return fmt.Errorf("block count %d is not equal to collection shard number %d", len(shards), p.shardNum)
}
rowCount, primaryKey, err := p.checkRowCount(fieldsData)
if err != nil {
return err
}
// generate auto id for primary key and rowid field
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
if err != nil {
log.Error("Numpy parser: failed to alloc row ID", zap.Int("rowCount", rowCount), zap.Error(err))
return fmt.Errorf("failed to alloc %d rows ID, error: %w", rowCount, err)
}
rowIDField, ok := fieldsData[common.RowIDField]
if !ok {
rowIDField = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
fieldsData[common.RowIDField] = rowIDField
}
rowIDFieldArr := rowIDField.(*storage.Int64FieldData)
for i := rowIDBegin; i < rowIDEnd; i++ {
rowIDFieldArr.Data = append(rowIDFieldArr.Data, i)
}
// reset the primary keys; note that only an int64 primary key can be auto-generated
if primaryKey.GetAutoID() {
log.Info("Numpy parser: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin))
if primaryKey.GetDataType() != schemapb.DataType_Int64 {
log.Error("Numpy parser: primary key field is auto-generated but the field type is not int64")
return fmt.Errorf("primary key field is auto-generated but the field type is not int64")
}
primaryDataArr := &storage.Int64FieldData{
NumRows: []int64{int64(rowCount)},
Data: make([]int64, 0, rowCount),
}
for i := rowIDBegin; i < rowIDEnd; i++ {
primaryDataArr.Data = append(primaryDataArr.Data, i)
}
fieldsData[primaryKey.GetFieldID()] = primaryDataArr
p.autoIDRange = append(p.autoIDRange, rowIDBegin, rowIDEnd)
}
// if the primary key is not auto-generated and the user doesn't provide it, return an error
primaryData, ok := fieldsData[primaryKey.GetFieldID()]
if !ok || primaryData.RowNum() <= 0 {
log.Error("Numpy parser: primary key field is not provided", zap.String("keyName", primaryKey.GetName()))
return fmt.Errorf("primary key '%s' field data is not provided", primaryKey.GetName())
}
// prepare append functions
appendFunctions, err := p.prepareAppendFunctions()
if err != nil {
return err
}
// split data into shards
for i := 0; i < rowCount; i++ {
// hash to a shard number
pk := primaryData.GetRow(i)
shard, err := pkToShard(pk, uint32(p.shardNum))
if err != nil {
return err
}
// set rowID field
rowIDField := shards[shard][common.RowIDField].(*storage.Int64FieldData)
rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64))
// append row to shard
for k := 0; k < len(p.collectionSchema.Fields); k++ {
schema := p.collectionSchema.Fields[k]
srcData := fieldsData[schema.GetFieldID()]
targetData := shards[shard][schema.GetFieldID()]
if srcData == nil || targetData == nil {
log.Error("Numpy parser: cannot append data since source or target field data is nil",
zap.String("FieldName", schema.GetName()),
zap.Bool("sourceNil", srcData == nil), zap.Bool("targetNil", targetData == nil))
return fmt.Errorf("cannot append data for field '%s' since source or target field data is nil",
schema.GetName())
}
appendFunc := appendFunctions[schema.GetName()]
err := appendFunc(srcData, i, targetData)
if err != nil {
return err
}
}
}
return nil
}
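// Example (assumed values): with rowCount=100 and the allocator returning the range [1000, 1100),
// each row gets a rowID from 1000..1099; if the primary key is auto-generated, the same values are
// used as primary keys and [1000, 1100) is appended to p.autoIDRange.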
func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
adapter, err := NewNumpyAdapter(reader)
if err != nil {
return err
}
// the validation method only check the file header information
err = p.validate(adapter, fieldName)
if err != nil {
return err
}
if onlyValidate {
return nil
}
// read all data from the numpy file
err = p.consume(adapter)
if err != nil {
return err
}
return p.callFlushFunc(p.columnData)
}

File diff suppressed because it is too large