mirror of https://github.com/milvus-io/milvus.git
commit 18aefb381c (parent fb2d7af3e0)
@ -34,6 +34,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func isCanceled(ctx context.Context) bool {
|
||||
|
@ -519,3 +520,22 @@ func getTypeName(dt schemapb.DataType) string {
|
|||
return "InvalidType"
|
||||
}
|
||||
}

func pkToShard(pk interface{}, shardNum uint32) (uint32, error) {
	var shard uint32
	strPK, ok := interface{}(pk).(string)
	if ok {
		hash := typeutil.HashString2Uint32(strPK)
		shard = hash % shardNum
	} else {
		intPK, ok := interface{}(pk).(int64)
		if !ok {
			log.Error("Numpy parser: primary key field must be int64 or varchar")
			return 0, fmt.Errorf("primary key field must be int64 or varchar")
		}
		hash, _ := typeutil.Hash32Int64(intPK)
		shard = hash % shardNum
	}

	return shard, nil
}
@ -25,6 +25,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -611,3 +612,31 @@ func Test_GetTypeName(t *testing.T) {
|
|||
str = getTypeName(schemapb.DataType_None)
|
||||
assert.Equal(t, "InvalidType", str)
|
||||
}
|
||||
|
||||
func Test_PkToShard(t *testing.T) {
|
||||
a := int32(99)
|
||||
shard, err := pkToShard(a, 2)
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, shard)
|
||||
|
||||
s := "abcdef"
|
||||
shardNum := uint32(3)
|
||||
shard, err = pkToShard(s, shardNum)
|
||||
assert.NoError(t, err)
|
||||
hash := typeutil.HashString2Uint32(s)
|
||||
assert.Equal(t, hash%shardNum, shard)
|
||||
|
||||
pk := int64(100)
|
||||
shardNum = uint32(4)
|
||||
shard, err = pkToShard(pk, shardNum)
|
||||
assert.NoError(t, err)
|
||||
hash, _ = typeutil.Hash32Int64(pk)
|
||||
assert.Equal(t, hash%shardNum, shard)
|
||||
|
||||
pk = int64(99999)
|
||||
shardNum = uint32(5)
|
||||
shard, err = pkToShard(pk, shardNum)
|
||||
assert.NoError(t, err)
|
||||
hash, _ = typeutil.Hash32Int64(pk)
|
||||
assert.Equal(t, hash%shardNum, shard)
|
||||
}
|
||||
|
|
|
@ -34,7 +34,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -47,7 +46,7 @@ const (
|
|||
// this limitation is to avoid this OOM risk:
|
||||
// for a column-based file, we read all its data into memory; if the user inputs a large file, the read() method may
// cost extra memory and lead to OOM.
|
||||
MaxFileSize = 1 * 1024 * 1024 * 1024 // 1GB
|
||||
MaxFileSize = 16 * 1024 * 1024 * 1024 // 16GB
|
||||
|
||||
// this limitation is to avoid this OOM risk:
|
||||
// sometimes the system segment max size is a large number; a single segment's fields data might cause OOM.
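A rough sketch of the guard these comments describe, assuming the MaxFileSize and MaxTotalSizeInMemory constants defined in this file; the helper name checkImportSizes is illustrative and not part of this change:

// checkImportSizes is an illustrative helper: reject any single file larger than
// MaxFileSize, and reject the whole batch once the running total exceeds
// MaxTotalSizeInMemory (the column-based reader loads every file into memory).
func checkImportSizes(fileSizes []int64) error {
	var totalSize int64
	for _, size := range fileSizes {
		if size > MaxFileSize {
			return fmt.Errorf("file size %d bytes exceeds the maximum size: %d bytes", size, MaxFileSize)
		}
		totalSize += size
	}
	if totalSize > MaxTotalSizeInMemory {
		return fmt.Errorf("total size %d bytes of all files exceeds the maximum size: %d bytes", totalSize, MaxTotalSizeInMemory)
	}
	return nil
}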
|
||||
|
@ -175,42 +174,6 @@ func (p *ImportWrapper) Cancel() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *ImportWrapper) validateColumnBasedFiles(filePaths []string, collectionSchema *schemapb.CollectionSchema) error {
|
||||
requiredFieldNames := make(map[string]interface{})
|
||||
for _, schema := range p.collectionSchema.Fields {
|
||||
if schema.GetIsPrimaryKey() {
|
||||
if !schema.GetAutoID() {
|
||||
requiredFieldNames[schema.GetName()] = nil
|
||||
}
|
||||
} else {
|
||||
requiredFieldNames[schema.GetName()] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// check redundant file
|
||||
fileNames := make(map[string]interface{})
|
||||
for _, filePath := range filePaths {
|
||||
name, _ := GetFileNameAndExt(filePath)
|
||||
fileNames[name] = nil
|
||||
_, ok := requiredFieldNames[name]
|
||||
if !ok {
|
||||
log.Error("import wrapper: the file has no corresponding field in collection", zap.String("fieldName", name))
|
||||
return fmt.Errorf("the file '%s' has no corresponding field in collection", filePath)
|
||||
}
|
||||
}
|
||||
|
||||
// check missed file
|
||||
for name := range requiredFieldNames {
|
||||
_, ok := fileNames[name]
|
||||
if !ok {
|
||||
log.Error("import wrapper: there is no file corresponding to field", zap.String("fieldName", name))
|
||||
return fmt.Errorf("there is no file corresponding to field '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fileValidation verifies the input paths
// if all the files are JSON type, return true
// if all the files are numpy type, return false; duplicate file names are not allowed
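A condensed sketch of that rule, reusing the GetFileNameAndExt helper that appears elsewhere in this change; the function name classifyFiles and the literal extension string are illustrative assumptions:

// classifyFiles is illustrative only: JSON files imply a row-based import and numpy
// files imply a column-based import; mixing the two or repeating a file name is rejected.
// (Unsupported extensions are ignored here; the real validation rejects them.)
func classifyFiles(filePaths []string) (rowBased bool, err error) {
	names := make(map[string]struct{})
	for i, path := range filePaths {
		name, ext := GetFileNameAndExt(path)
		isJSON := ext == ".json" // assumed extension literal
		if i == 0 {
			rowBased = isJSON
		} else if isJSON != rowBased {
			return rowBased, fmt.Errorf("the file '%s' does not match the import type of the other files", path)
		}
		if _, dup := names[name]; dup {
			return rowBased, fmt.Errorf("duplicate file name '%s'", name)
		}
		names[name] = struct{}{}
	}
	return rowBased, nil
}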
|
||||
|
@ -278,22 +241,6 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
|
|||
totalSize += size
|
||||
}
|
||||
|
||||
// especially for column-base, total size of files cannot exceed MaxTotalSizeInMemory
|
||||
if totalSize > MaxTotalSizeInMemory {
|
||||
log.Error("import wrapper: total size of files exceeds the maximum size", zap.Int64("totalSize", totalSize), zap.Int64("MaxTotalSize", MaxTotalSizeInMemory))
|
||||
return rowBased, fmt.Errorf("total size(%d bytes) of all files exceeds the maximum size: %d bytes", totalSize, MaxTotalSizeInMemory)
|
||||
}
|
||||
|
||||
// check redundant files for column-based import
|
||||
// if the field is primary key and autoid is false, the file is required
|
||||
// any redundant file is not allowed
|
||||
if !rowBased {
|
||||
err := p.validateColumnBasedFiles(filePaths, p.collectionSchema)
|
||||
if err != nil {
|
||||
return rowBased, err
|
||||
}
|
||||
}
|
||||
|
||||
return rowBased, nil
|
||||
}
|
||||
|
||||
|
@ -337,84 +284,26 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
|
|||
triggerGC()
|
||||
}
|
||||
} else {
|
||||
// parse and consume column-based files
|
||||
// for column-based files, the XXXColumnConsumer only output map[string]storage.FieldData
|
||||
// after all columns are parsed/consumed, we need to combine map[string]storage.FieldData into one
|
||||
// and use splitFieldsData() to split fields data into segments according to shard number
|
||||
fieldsData := initSegmentData(p.collectionSchema)
|
||||
if fieldsData == nil {
|
||||
log.Error("import wrapper: failed to initialize FieldData list")
|
||||
return fmt.Errorf("failed to initialize FieldData list")
|
||||
// parse and consume column-based files(currently support numpy)
|
||||
// for column-based files, the NumpyParser will generate autoid for primary key, and split rows into segments
|
||||
// according to shard number, so the flushFunc will be called in the NumpyParser
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
|
||||
return p.flushFunc(fields, shardID)
|
||||
}
|
||||
|
||||
rowCount := 0
|
||||
|
||||
// function to combine column data into fieldsData
|
||||
combineFunc := func(fields map[storage.FieldID]storage.FieldData) error {
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
|
||||
for k, v := range fields {
|
||||
// ignore 0 row field
|
||||
if v.RowNum() == 0 {
|
||||
log.Warn("import wrapper: empty FieldData ignored", zap.Int64("fieldID", k))
|
||||
continue
|
||||
}
|
||||
|
||||
// ignore internal fields: RowIDField and TimeStampField
|
||||
if k == common.RowIDField || k == common.TimeStampField {
|
||||
log.Warn("import wrapper: internal fields should not be provided", zap.Int64("fieldID", k))
|
||||
continue
|
||||
}
|
||||
|
||||
// each column should be only combined once
|
||||
data, ok := fieldsData[k]
|
||||
if ok && data.RowNum() > 0 {
|
||||
return fmt.Errorf("the field %d is duplicated", k)
|
||||
}
|
||||
|
||||
// check the row count. only count non-zero row fields
|
||||
if rowCount > 0 && rowCount != v.RowNum() {
|
||||
return fmt.Errorf("the field %d row count %d doesn't equal to others row count: %d", k, v.RowNum(), rowCount)
|
||||
}
|
||||
rowCount = v.RowNum()
|
||||
|
||||
// assign column data to fieldsData
|
||||
fieldsData[k] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse/validate/consume data
|
||||
for i := 0; i < len(filePaths); i++ {
|
||||
filePath := filePaths[i]
|
||||
_, fileType := GetFileNameAndExt(filePath)
|
||||
log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
|
||||
|
||||
if fileType == NumpyFileExt {
|
||||
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
|
||||
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to parse column-based numpy file", zap.Error(err), zap.String("filePath", filePath))
|
||||
return err
|
||||
}
|
||||
}
|
||||
// no need to check else, since the fileValidation() already do this
|
||||
}
|
||||
|
||||
// trigger after read finished
|
||||
triggerGC()
|
||||
|
||||
// split fields data into segments
|
||||
err := p.splitFieldsData(fieldsData, SingleBlockSize)
|
||||
parser, err := NewNumpyParser(p.ctx, p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// trigger after write finished
|
||||
err = parser.Parse(filePaths)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...)
|
||||
|
||||
// trigger after parse finished
|
||||
triggerGC()
|
||||
}
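As the comments above note, the new NumpyParser takes over auto-id generation and shard splitting and invokes flushFunc per shard. A minimal sketch of that routing, reusing the pkToShard helper added in this change; splitByPK and its row-index payload are illustrative, not the parser's actual internals:

// splitByPK is illustrative only: hash every primary key to a shard with pkToShard,
// group the row indices per shard, then hand each group to the flush callback,
// mirroring how flushFunc is invoked with a shardID.
func splitByPK(pks []interface{}, shardNum uint32, flush func(rowIndices []int, shardID int) error) error {
	shards := make([][]int, shardNum)
	for i, pk := range pks {
		shard, err := pkToShard(pk, shardNum)
		if err != nil {
			return err
		}
		shards[shard] = append(shards[shard], i)
	}
	for shardID, rows := range shards {
		if len(rows) == 0 {
			continue // nothing to flush for this shard
		}
		if err := flush(rows, shardID); err != nil {
			return err
		}
	}
	return nil
}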
|
||||
|
||||
|
@ -437,6 +326,7 @@ func (p *ImportWrapper) reportPersisted(reportAttempts uint, tr *timerecord.Time
|
|||
|
||||
// report file process state
|
||||
p.importResult.State = commonpb.ImportState_ImportPersisted
|
||||
log.Info("import wrapper: report import result", zap.Any("importResult", p.importResult))
|
||||
// the persist-state task is valuable; retry more times to avoid failing it only because of a network error
|
||||
reportErr := retry.Do(p.ctx, func() error {
|
||||
return p.reportFunc(p.importResult)
|
||||
|
@ -554,297 +444,12 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
|
|||
}
|
||||
|
||||
// for row-based files, auto-id is generated within JSONRowConsumer
|
||||
if consumer != nil {
|
||||
p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)
|
||||
}
|
||||
p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)
|
||||
|
||||
tr.Elapse("parsed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseColumnBasedNumpy is the entry of column-based numpy import operation
|
||||
func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool,
|
||||
combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
|
||||
tr := timerecord.NewTimeRecorder("numpy parser: " + filePath)
|
||||
|
||||
fileName, _ := GetFileNameAndExt(filePath)
|
||||
|
||||
// for minio storage, chunkManager will download file into local memory
|
||||
// for local storage, chunkManager open the file directly
|
||||
file, err := p.chunkManager.Reader(p.ctx, filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var id storage.FieldID
|
||||
var found = false
|
||||
for _, field := range p.collectionSchema.Fields {
|
||||
if field.GetName() == fileName {
|
||||
id = field.GetFieldID()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// if the numpy file name is not mapping to a field name, ignore it
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
// the numpy parser return a storage.FieldData, here construct a map[string]storage.FieldData to combine
|
||||
flushFunc := func(field storage.FieldData) error {
|
||||
fields := make(map[storage.FieldID]storage.FieldData)
|
||||
fields[id] = field
|
||||
return combineFunc(fields)
|
||||
}
|
||||
|
||||
// for a numpy file, the file name (without extension) is treated as the field name
|
||||
parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
|
||||
err = parser.Parse(file, fileName, onlyValidate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tr.Elapse("parsed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// appendFunc defines the methods to append data to storage.FieldData
|
||||
func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
switch schema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.BoolFieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(bool))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.FloatFieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(float32))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.DoubleFieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(float64))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.Int8FieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(int8))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.Int16FieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(int16))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.Int32FieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(int32))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.Int64FieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(int64))
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.BinaryVectorFieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).([]byte)...)
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.FloatVectorFieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).([]float32)...)
|
||||
arr.NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
return func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
arr := target.(*storage.StringFieldData)
|
||||
arr.Data = append(arr.Data, src.GetRow(n).(string))
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// splitFieldsData splits the in-memory data (parsed from column-based files) into blocks; each block is saved to a binlog file
|
||||
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, blockSize int64) error {
|
||||
if len(fieldsData) == 0 {
|
||||
log.Error("import wrapper: fields data is empty")
|
||||
return fmt.Errorf("fields data is empty")
|
||||
}
|
||||
|
||||
tr := timerecord.NewTimeRecorder("import wrapper: split field data")
|
||||
defer tr.Elapse("finished")
|
||||
|
||||
// check existence of each field
|
||||
// check row count, all fields row count must be equal
|
||||
// firstly get the max row count
|
||||
rowCount := 0
|
||||
rowCounter := make(map[string]int)
|
||||
var primaryKey *schemapb.FieldSchema
|
||||
for i := 0; i < len(p.collectionSchema.Fields); i++ {
|
||||
schema := p.collectionSchema.Fields[i]
|
||||
if schema.GetIsPrimaryKey() {
|
||||
primaryKey = schema
|
||||
}
|
||||
|
||||
if !schema.GetAutoID() {
|
||||
v, ok := fieldsData[schema.GetFieldID()]
|
||||
if !ok {
|
||||
log.Error("import wrapper: field not provided", zap.String("fieldName", schema.GetName()))
|
||||
return fmt.Errorf("field '%s' not provided", schema.GetName())
|
||||
}
|
||||
rowCounter[schema.GetName()] = v.RowNum()
|
||||
if v.RowNum() > rowCount {
|
||||
rowCount = v.RowNum()
|
||||
}
|
||||
}
|
||||
}
|
||||
if primaryKey == nil {
|
||||
log.Error("import wrapper: primary key field is not found")
|
||||
return fmt.Errorf("primary key field is not found")
|
||||
}
|
||||
|
||||
for name, count := range rowCounter {
|
||||
if count != rowCount {
|
||||
log.Error("import wrapper: field row count is not equal to other fields row count", zap.String("fieldName", name),
|
||||
zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount))
|
||||
return fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
|
||||
}
|
||||
}
|
||||
log.Info("import wrapper: try to split a block with row count", zap.Int("rowCount", rowCount))
|
||||
|
||||
primaryData, ok := fieldsData[primaryKey.GetFieldID()]
|
||||
if !ok {
|
||||
log.Error("import wrapper: primary key field is not provided", zap.String("keyName", primaryKey.GetName()))
|
||||
return fmt.Errorf("primary key field is not provided")
|
||||
}
|
||||
|
||||
// generate auto id for primary key and rowid field
|
||||
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to alloc row ID", zap.Error(err))
|
||||
return fmt.Errorf("failed to alloc row ID, error: %w", err)
|
||||
}
|
||||
|
||||
rowIDField := fieldsData[common.RowIDField]
|
||||
rowIDFieldArr := rowIDField.(*storage.Int64FieldData)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
rowIDFieldArr.Data = append(rowIDFieldArr.Data, i)
|
||||
}
|
||||
|
||||
if primaryKey.GetAutoID() {
|
||||
log.Info("import wrapper: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin))
|
||||
|
||||
// reset the primary keys, as we know, only int64 pk can be auto-generated
|
||||
primaryDataArr := &storage.Int64FieldData{
|
||||
NumRows: []int64{int64(rowCount)},
|
||||
Data: make([]int64, 0, rowCount),
|
||||
}
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
primaryDataArr.Data = append(primaryDataArr.Data, i)
|
||||
}
|
||||
|
||||
primaryData = primaryDataArr
|
||||
fieldsData[primaryKey.GetFieldID()] = primaryData
|
||||
p.importResult.AutoIds = append(p.importResult.AutoIds, rowIDBegin, rowIDEnd)
|
||||
}
|
||||
|
||||
if primaryData.RowNum() <= 0 {
|
||||
log.Error("import wrapper: primary key is not provided", zap.String("keyName", primaryKey.GetName()))
|
||||
return fmt.Errorf("the primary key '%s' is not provided", primaryKey.GetName())
|
||||
}
|
||||
|
||||
// prepare segments
|
||||
segmentsData := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum)
|
||||
for i := 0; i < int(p.shardNum); i++ {
|
||||
segmentData := initSegmentData(p.collectionSchema)
|
||||
if segmentData == nil {
|
||||
log.Error("import wrapper: failed to initialize FieldData list")
|
||||
return fmt.Errorf("failed to initialize FieldData list")
|
||||
}
|
||||
segmentsData = append(segmentsData, segmentData)
|
||||
}
|
||||
|
||||
// prepare append functions
|
||||
appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error)
|
||||
for i := 0; i < len(p.collectionSchema.Fields); i++ {
|
||||
schema := p.collectionSchema.Fields[i]
|
||||
appendFuncErr := p.appendFunc(schema)
|
||||
if appendFuncErr == nil {
|
||||
log.Error("import wrapper: unsupported field data type")
|
||||
return fmt.Errorf("unsupported field data type: %d", schema.GetDataType())
|
||||
}
|
||||
appendFunctions[schema.GetName()] = appendFuncErr
|
||||
}
|
||||
|
||||
// split data into shards
|
||||
for i := 0; i < rowCount; i++ {
|
||||
// hash to a shard number
|
||||
var shard uint32
|
||||
pk := primaryData.GetRow(i)
|
||||
strPK, ok := interface{}(pk).(string)
|
||||
if ok {
|
||||
hash := typeutil.HashString2Uint32(strPK)
|
||||
shard = hash % uint32(p.shardNum)
|
||||
} else {
|
||||
intPK, ok := interface{}(pk).(int64)
|
||||
if !ok {
|
||||
log.Error("import wrapper: primary key field must be int64 or varchar")
|
||||
return fmt.Errorf("primary key field must be int64 or varchar")
|
||||
}
|
||||
hash, _ := typeutil.Hash32Int64(intPK)
|
||||
shard = hash % uint32(p.shardNum)
|
||||
}
|
||||
|
||||
// set rowID field
|
||||
rowIDField := segmentsData[shard][common.RowIDField].(*storage.Int64FieldData)
|
||||
rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64))
|
||||
|
||||
// append row to shard
|
||||
for k := 0; k < len(p.collectionSchema.Fields); k++ {
|
||||
schema := p.collectionSchema.Fields[k]
|
||||
srcData := fieldsData[schema.GetFieldID()]
|
||||
targetData := segmentsData[shard][schema.GetFieldID()]
|
||||
appendFunc := appendFunctions[schema.GetName()]
|
||||
err := appendFunc(srcData, i, targetData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// when the estimated size is close to blockSize, force flush
|
||||
err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// force flush at the end
|
||||
return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, true)
|
||||
}
|
||||
|
||||
// flushFunc is the callback function for parsers to generate segments and save binlog files
|
||||
func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
// if fields data is empty, do nothing
|
||||
|
|
|
@ -579,79 +579,6 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) {
|
|||
tr.Record("parse large json file " + filePath)
|
||||
}
|
||||
|
||||
func Test_ImportWrapperValidateColumnBasedFiles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cm := &MockChunkManager{
|
||||
size: 1,
|
||||
}
|
||||
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
shardNum := 2
|
||||
segmentSize := 512 // unit: MB
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "Age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "Vector",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "10"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
|
||||
|
||||
// file for PK is redundant
|
||||
files := []string{"ID.npy", "Age.npy", "Vector.npy"}
|
||||
err := wrapper.validateColumnBasedFiles(files, schema)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// file for PK is not redundant
|
||||
schema.Fields[0].AutoID = false
|
||||
err = wrapper.validateColumnBasedFiles(files, schema)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// file missed
|
||||
files = []string{"Age.npy", "Vector.npy"}
|
||||
err = wrapper.validateColumnBasedFiles(files, schema)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
files = []string{"ID.npy", "Vector.npy"}
|
||||
err = wrapper.validateColumnBasedFiles(files, schema)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// redundant file
|
||||
files = []string{"ID.npy", "Age.npy", "Vector.npy", "dummy.npy"}
|
||||
err = wrapper.validateColumnBasedFiles(files, schema)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// correct input
|
||||
files = []string{"ID.npy", "Age.npy", "Vector.npy"}
|
||||
err = wrapper.validateColumnBasedFiles(files, schema)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func Test_ImportWrapperFileValidation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -668,7 +595,7 @@ func Test_ImportWrapperFileValidation(t *testing.T) {
|
|||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
|
@ -684,84 +611,76 @@ func Test_ImportWrapperFileValidation(t *testing.T) {
|
|||
|
||||
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
|
||||
|
||||
// unsupported file type
|
||||
files := []string{"uid.txt"}
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
t.Run("unsupported file type", func(t *testing.T) {
|
||||
files := []string{"uid.txt"}
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, rowBased)
|
||||
})
|
||||
|
||||
// file missed
|
||||
files = []string{"uid.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
t.Run("duplicate files", func(t *testing.T) {
|
||||
files := []string{"a/1.json", "b/1.json"}
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, rowBased)
|
||||
|
||||
// redundant file
|
||||
files = []string{"uid.npy", "b/bol.npy", "c/no.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
files = []string{"a/uid.npy", "uid.npy", "b/bol.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, rowBased)
|
||||
})
|
||||
|
||||
// duplicate files
|
||||
files = []string{"a/1.json", "b/1.json"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, rowBased)
|
||||
t.Run("unsupported file for row-based", func(t *testing.T) {
|
||||
files := []string{"a/uid.json", "b/bol.npy"}
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, rowBased)
|
||||
})
|
||||
|
||||
files = []string{"a/uid.npy", "uid.npy", "b/bol.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
t.Run("unsupported file for column-based", func(t *testing.T) {
|
||||
files := []string{"a/uid.npy", "b/bol.json"}
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, rowBased)
|
||||
})
|
||||
|
||||
// unsupported file for row-based
|
||||
files = []string{"a/uid.json", "b/bol.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, rowBased)
|
||||
t.Run("valid cases", func(t *testing.T) {
|
||||
files := []string{"a/1.json", "b/2.json"}
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, rowBased)
|
||||
|
||||
// unsupported file for column-based
|
||||
files = []string{"a/uid.npy", "b/bol.json"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
files = []string{"a/uid.npy", "b/bol.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, rowBased)
|
||||
})
|
||||
|
||||
// valid cases
|
||||
files = []string{"a/1.json", "b/2.json"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, rowBased)
|
||||
t.Run("empty file", func(t *testing.T) {
|
||||
files := []string{}
|
||||
cm.size = 0
|
||||
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, rowBased)
|
||||
})
|
||||
|
||||
files = []string{"a/uid.npy", "b/bol.npy"}
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
t.Run("file size exceed MaxFileSize limit", func(t *testing.T) {
|
||||
files := []string{"a/1.json"}
|
||||
cm.size = MaxFileSize + 1
|
||||
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, rowBased)
|
||||
})
|
||||
|
||||
// empty file
|
||||
cm.size = 0
|
||||
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
|
||||
// file size exceed MaxFileSize limit
|
||||
cm.size = MaxFileSize + 1
|
||||
wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil)
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
|
||||
// total files size exceed MaxTotalSizeInMemory limit
|
||||
cm.size = MaxFileSize - 1
|
||||
files = append(files, "3.npy")
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
|
||||
// failed to get file size
|
||||
cm.sizeErr = errors.New("error")
|
||||
rowBased, err = wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, rowBased)
|
||||
t.Run("failed to get file size", func(t *testing.T) {
|
||||
files := []string{"a/1.json"}
|
||||
cm.sizeErr = errors.New("error")
|
||||
rowBased, err := wrapper.fileValidation(files)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, rowBased)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_ImportWrapperReportFailRowBased(t *testing.T) {
|
||||
|
@ -1001,122 +920,6 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func Test_ImportWrapperSplitFieldsData(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cm := &MockChunkManager{}
|
||||
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
|
||||
rowCounter := &rowCounterTest{}
|
||||
assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter)
|
||||
|
||||
importResult := &rootcoordpb.ImportResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
TaskId: 1,
|
||||
DatanodeId: 1,
|
||||
State: commonpb.ImportState_ImportStarted,
|
||||
Segments: make([]int64, 0),
|
||||
AutoIds: make([]int64, 0),
|
||||
RowCount: 0,
|
||||
}
|
||||
reportFunc := func(res *rootcoordpb.ImportResult) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "flag",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_Bool,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wrapper := NewImportWrapper(ctx, schema, 2, 1024*1024, idAllocator, cm, importResult, reportFunc)
|
||||
wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc)
|
||||
|
||||
// nil input
|
||||
err := wrapper.splitFieldsData(nil, 0)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// split 100 rows to 4 blocks, success
|
||||
rowCount := 100
|
||||
input := initSegmentData(schema)
|
||||
for j := 0; j < rowCount; j++ {
|
||||
pkField := input[101].(*storage.Int64FieldData)
|
||||
pkField.Data = append(pkField.Data, int64(j))
|
||||
|
||||
flagField := input[102].(*storage.BoolFieldData)
|
||||
flagField.Data = append(flagField.Data, true)
|
||||
}
|
||||
|
||||
err = wrapper.splitFieldsData(input, 512)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(importResult.AutoIds))
|
||||
assert.Equal(t, 4, rowCounter.callTime)
|
||||
assert.Equal(t, rowCount, rowCounter.rowCount)
|
||||
|
||||
// alloc id failed
|
||||
wrapper.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error"))
|
||||
err = wrapper.splitFieldsData(input, 512)
|
||||
assert.NotNil(t, err)
|
||||
wrapper.rowIDAllocator = newIDAllocator(ctx, t, nil)
|
||||
|
||||
// row count of fields are unequal
|
||||
schema.Fields[0].AutoID = false
|
||||
input = initSegmentData(schema)
|
||||
for j := 0; j < rowCount; j++ {
|
||||
pkField := input[101].(*storage.Int64FieldData)
|
||||
pkField.Data = append(pkField.Data, int64(j))
|
||||
if j%2 == 0 {
|
||||
continue
|
||||
}
|
||||
flagField := input[102].(*storage.BoolFieldData)
|
||||
flagField.Data = append(flagField.Data, true)
|
||||
}
|
||||
err = wrapper.splitFieldsData(input, 512)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// primary key not found
|
||||
wrapper.collectionSchema.Fields[0].IsPrimaryKey = false
|
||||
err = wrapper.splitFieldsData(input, 512)
|
||||
assert.NotNil(t, err)
|
||||
wrapper.collectionSchema.Fields[0].IsPrimaryKey = true
|
||||
|
||||
// primary key is varchar, success
|
||||
wrapper.collectionSchema.Fields[0].DataType = schemapb.DataType_VarChar
|
||||
input = initSegmentData(schema)
|
||||
for j := 0; j < rowCount; j++ {
|
||||
pkField := input[101].(*storage.StringFieldData)
|
||||
pkField.Data = append(pkField.Data, strconv.FormatInt(int64(j), 10))
|
||||
|
||||
flagField := input[102].(*storage.BoolFieldData)
|
||||
flagField.Data = append(flagField.Data, true)
|
||||
}
|
||||
rowCounter.callTime = 0
|
||||
rowCounter.rowCount = 0
|
||||
importResult.AutoIds = []int64{}
|
||||
err = wrapper.splitFieldsData(input, 1024)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(importResult.AutoIds))
|
||||
assert.Equal(t, 2, rowCounter.callTime)
|
||||
assert.Equal(t, rowCount, rowCounter.rowCount)
|
||||
}
|
||||
|
||||
func Test_ImportWrapperReportPersisted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tr := timerecord.NewTimeRecorder("test")
|
||||
|
|
|
@ -35,15 +35,11 @@ import (
|
|||
const (
|
||||
// root field of row-based json format
|
||||
RowRootNode = "rows"
|
||||
// minimal size of a buffer
|
||||
MinBufferSize = 1024
|
||||
// split file into batches no more than this count
|
||||
MaxBatchCount = 16
|
||||
)
|
||||
|
||||
type JSONParser struct {
|
||||
ctx context.Context // for canceling parse process
|
||||
bufSize int64 // max rows in a buffer
|
||||
bufRowCount int // max rows in a buffer
|
||||
fields map[string]int64 // fields need to be parsed
|
||||
name2FieldID map[string]storage.FieldID
|
||||
}
|
||||
|
@ -69,7 +65,7 @@ func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSch
|
|||
|
||||
parser := &JSONParser{
|
||||
ctx: ctx,
|
||||
bufSize: MinBufferSize,
|
||||
bufRowCount: 1024,
|
||||
fields: fields,
|
||||
name2FieldID: name2FieldID,
|
||||
}
|
||||
|
@ -84,19 +80,24 @@ func adjustBufSize(parser *JSONParser, collectionSchema *schemapb.CollectionSche
|
|||
return
|
||||
}
|
||||
|
||||
// split the file into no more than MaxBatchCount batches to parse
|
||||
// for high dimensional vector, the bufSize is a small value, read few rows each time
|
||||
// for low dimensional vector, the bufSize is a large value, read more rows each time
|
||||
maxRows := MaxFileSize / sizePerRecord
|
||||
bufSize := maxRows / MaxBatchCount
|
||||
|
||||
// bufSize should not be less than MinBufferSize
|
||||
if bufSize < MinBufferSize {
|
||||
bufSize = MinBufferSize
|
||||
bufRowCount := parser.bufRowCount
|
||||
for {
|
||||
if bufRowCount*sizePerRecord > SingleBlockSize {
|
||||
bufRowCount--
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("JSON parser: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize))
|
||||
parser.bufSize = int64(bufSize)
|
||||
// at least one row per buffer
|
||||
if bufRowCount <= 0 {
|
||||
bufRowCount = 1
|
||||
}
|
||||
|
||||
log.Info("JSON parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount))
|
||||
parser.bufRowCount = bufRowCount
|
||||
}
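The loop above lowers bufRowCount until the estimated batch size fits in SingleBlockSize, which amounts to an integer division with a floor of one row. A small equivalent sketch, assuming the default of 1024 rows set in NewJSONParser; the resulting row count simply scales inversely with sizePerRecord:

// equivalentBufRowCount mirrors the adjustment loop: keep the default row count
// if a batch already fits in one block, otherwise take singleBlockSize/sizePerRecord,
// and never go below one row per buffer.
func equivalentBufRowCount(defaultRows, sizePerRecord, singleBlockSize int) int {
	rows := defaultRows
	if sizePerRecord > 0 && rows*sizePerRecord > singleBlockSize {
		rows = singleBlockSize / sizePerRecord
	}
	if rows <= 0 {
		rows = 1
	}
	return rows
}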
|
||||
|
||||
func (p *JSONParser) verifyRow(raw interface{}) (map[storage.FieldID]interface{}, error) {
|
||||
|
@ -185,7 +186,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
}
|
||||
|
||||
// read buffer
|
||||
buf := make([]map[storage.FieldID]interface{}, 0, MinBufferSize)
|
||||
buf := make([]map[storage.FieldID]interface{}, 0, p.bufRowCount)
|
||||
for dec.More() {
|
||||
var value interface{}
|
||||
if err := dec.Decode(&value); err != nil {
|
||||
|
@ -199,7 +200,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
}
|
||||
|
||||
buf = append(buf, row)
|
||||
if len(buf) >= int(p.bufSize) {
|
||||
if len(buf) >= p.bufRowCount {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
log.Error("JSON parser: failed to convert row value to entity", zap.Error(err))
|
||||
|
@ -207,7 +208,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
}
|
||||
|
||||
// clear the buffer
|
||||
buf = make([]map[storage.FieldID]interface{}, 0, MinBufferSize)
|
||||
buf = make([]map[storage.FieldID]interface{}, 0, p.bufRowCount)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -58,11 +57,7 @@ func Test_AdjustBufSize(t *testing.T) {
|
|||
schema := sampleSchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
|
||||
sizePerRecord, err := typeutil.EstimateSizePerRecord(schema)
|
||||
assert.Nil(t, err)
|
||||
assert.Greater(t, sizePerRecord, 0)
|
||||
assert.Equal(t, MaxBatchCount, MaxFileSize/(sizePerRecord*int(parser.bufSize)))
|
||||
assert.Greater(t, parser.bufRowCount, 0)
|
||||
|
||||
// huge row
|
||||
schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{
|
||||
|
@ -70,9 +65,7 @@ func Test_AdjustBufSize(t *testing.T) {
|
|||
}
|
||||
parser = NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
sizePerRecord, _ = typeutil.EstimateSizePerRecord(schema)
|
||||
|
||||
assert.Equal(t, 7, MaxFileSize/(sizePerRecord*int(parser.bufSize)))
|
||||
assert.Greater(t, parser.bufRowCount, 0)
|
||||
|
||||
// no change
|
||||
schema = &schemapb.CollectionSchema{
|
||||
|
@ -83,8 +76,7 @@ func Test_AdjustBufSize(t *testing.T) {
|
|||
}
|
||||
parser = NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
|
||||
assert.Equal(t, int64(MinBufferSize), parser.bufSize)
|
||||
assert.Greater(t, parser.bufRowCount, 0)
|
||||
}
|
||||
|
||||
func Test_JSONParserParseRows_IntPK(t *testing.T) {
|
||||
|
@ -127,8 +119,8 @@ func Test_JSONParserParseRows_IntPK(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("parse success", func(t *testing.T) {
|
||||
// set bufSize = 4, means call handle() after reading 4 rows
|
||||
parser.bufSize = 4
|
||||
// set bufRowCount = 4, means call handle() after reading 4 rows
|
||||
parser.bufRowCount = 4
|
||||
err = parser.ParseRows(reader, consumer)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(content.Rows), len(consumer.rows))
|
||||
|
@ -285,12 +277,12 @@ func Test_JSONParserParseRows_IntPK(t *testing.T) {
|
|||
}`
|
||||
consumer.handleErr = errors.New("error")
|
||||
reader = strings.NewReader(content)
|
||||
parser.bufSize = 2
|
||||
parser.bufRowCount = 2
|
||||
err = parser.ParseRows(reader, consumer)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(content)
|
||||
parser.bufSize = 5
|
||||
parser.bufRowCount = 5
|
||||
err = parser.ParseRows(reader, consumer)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
|
|
|
@ -86,11 +86,13 @@ type NumpyAdapter struct {
|
|||
func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
|
||||
r, err := npyio.NewReader(reader)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read numpy header", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataType, err := convertNumpyType(r.Header.Descr.Type)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to detect data type", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -109,12 +111,11 @@ func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
|
|||
zap.Uint8("minorVer", r.Header.Minor),
|
||||
zap.String("ByteOrder", adapter.order.String()))
|
||||
|
||||
return adapter, err
|
||||
return adapter, nil
|
||||
}
|
||||
|
||||
// convertNumpyType gets data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector)
|
||||
func convertNumpyType(typeStr string) (schemapb.DataType, error) {
|
||||
log.Info("Numpy adapter: parse numpy file dtype", zap.String("dtype", typeStr))
|
||||
switch typeStr {
|
||||
case "b1", "<b1", "|b1", "bool":
|
||||
return schemapb.DataType_Bool, nil
|
||||
|
@ -252,24 +253,29 @@ func (n *NumpyAdapter) checkCount(count int) int {
|
|||
|
||||
func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read bool data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read bool data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Bool {
|
||||
log.Error("Numpy adapter: numpy data is not bool type")
|
||||
return nil, errors.New("numpy data is not bool type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of bool file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: bool")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]bool, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read bool data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf(" failed to read bool data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -281,6 +287,7 @@ func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read uint8 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read uint8 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
|
@ -289,19 +296,23 @@ func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
|
|||
switch n.npyReader.Header.Descr.Type {
|
||||
case "u1", "<u1", "|u1", "uint8":
|
||||
default:
|
||||
log.Error("Numpy adapter: numpy data is not uint8 type")
|
||||
return nil, errors.New("numpy data is not uint8 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of uint8 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: uint8")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]uint8, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read uint8 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read uint8 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -313,24 +324,29 @@ func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadInt8(count int) ([]int8, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read int8 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read int8 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Int8 {
|
||||
log.Error("Numpy adapter: numpy data is not int8 type")
|
||||
return nil, errors.New("numpy data is not int8 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of int8 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: int8")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int8, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read int8 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read int8 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -342,24 +358,29 @@ func (n *NumpyAdapter) ReadInt8(count int) ([]int8, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read int16 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read int16 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Int16 {
|
||||
log.Error("Numpy adapter: numpy data is not int16 type")
|
||||
return nil, errors.New("numpy data is not int16 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of int16 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: int16")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int16, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read int16 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read int16 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -371,24 +392,29 @@ func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read int32 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read int32 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Int32 {
|
||||
log.Error("Numpy adapter: numpy data is not int32 type")
|
||||
return nil, errors.New("numpy data is not int32 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of int32 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: int32")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int32, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read int32 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read int32 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -400,24 +426,29 @@ func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read int64 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read int64 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Int64 {
|
||||
log.Error("Numpy adapter: numpy data is not int64 type")
|
||||
return nil, errors.New("numpy data is not int64 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of int64 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: int64")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int64, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read int64 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read int64 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -429,24 +460,29 @@ func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read float32 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read float32 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Float {
|
||||
log.Error("Numpy adapter: numpy data is not float32 type")
|
||||
return nil, errors.New("numpy data is not float32 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of float32 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: float32")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]float32, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read float32 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read float32 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -458,24 +494,29 @@ func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) {
|
||||
if count <= 0 {
|
||||
log.Error("Numpy adapter: cannot read float64 data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read float64 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_Double {
|
||||
log.Error("Numpy adapter: numpy data is not float64 type")
|
||||
return nil, errors.New("numpy data is not float64 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of float64 file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: float64")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]float64, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read float64 data", zap.Int("count", count), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read float64 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
|
@ -487,12 +528,14 @@ func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) {
|
|||
|
||||
func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("cannot read varhar data with a zero or nagative count")
|
||||
log.Error("Numpy adapter: cannot read varchar data with a zero or nagative count")
|
||||
return nil, errors.New("cannot read varchar data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_VarChar {
|
||||
return nil, errors.New("numpy data is not varhar type")
|
||||
log.Error("Numpy adapter: numpy data is not varchar type")
|
||||
return nil, errors.New("numpy data is not varchar type")
|
||||
}
|
||||
|
||||
// varchar length, this is the max length, some item is shorter than this length, but they also occupy bytes of max length
|
||||
|
@ -501,12 +544,19 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
|||
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
|
||||
}
|
||||
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
|
||||
// log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("end of varhar file, nothing to read")
|
||||
// end of file, nothing to read
|
||||
log.Info("Numpy adapter: read to end of file, type: varchar")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n.reader == nil {
|
||||
log.Error("Numpy adapter: reader is nil")
|
||||
return nil, errors.New("numpy reader is nil")
|
||||
}
|
||||
|
||||
// read data
|
||||
|
@ -545,6 +595,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
|||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
}
|
||||
|
||||
data = append(data, string(buf))
|
||||
}
|
||||
}
|
||||
|
@ -557,6 +608,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
|||
|
||||
func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
|
||||
if len(src)%4 != 0 {
|
||||
log.Error("Numpy adapter: invalid utf32 bytes length, the byte array length should be multiple of 4", zap.Int("byteLen", len(src)))
|
||||
return "", fmt.Errorf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src))
|
||||
}
|
||||
|
||||
|
@ -586,6 +638,7 @@ func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
|
|||
decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder()
|
||||
res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2])
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to decode utf32 binary bytes", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to decode utf32 binary bytes, error: %w", err)
|
||||
}
|
||||
str += string(res)
|
||||
|
@ -603,6 +656,7 @@ func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) {
|
|||
utf8Code := make([]byte, 4)
|
||||
utf8.EncodeRune(utf8Code, r)
|
||||
if r == utf8.RuneError {
|
||||
log.Error("Numpy adapter: failed to convert 4 bytes unicode to utf8 rune", zap.Uint32("code", x))
|
||||
return "", fmt.Errorf("failed to convert 4 bytes unicode %d to utf8 rune", x)
|
||||
}
|
||||
str += string(utf8Code)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package importutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -40,12 +41,12 @@ func Test_CreateNumpyFile(t *testing.T) {
|
|||
// directory doesn't exist
|
||||
data1 := []float32{1, 2, 3, 4, 5}
|
||||
err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// invalid data type
|
||||
data2 := make(map[string]int)
|
||||
err = CreateNumpyFile("/tmp/dummy.npy", data2)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateNumpyData(t *testing.T) {
|
||||
|
@ -53,12 +54,12 @@ func Test_CreateNumpyData(t *testing.T) {
|
|||
data1 := []float32{1, 2, 3, 4, 5}
|
||||
buf, err := CreateNumpyData(data1)
|
||||
assert.NotNil(t, buf)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// invalid data type
|
||||
data2 := make(map[string]int)
|
||||
buf, err = CreateNumpyData(data2)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, buf)
|
||||
}
|
||||
|
||||
|
@ -66,7 +67,7 @@ func Test_ConvertNumpyType(t *testing.T) {
|
|||
checkFunc := func(inputs []string, output schemapb.DataType) {
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
dt, err := convertNumpyType(inputs[i])
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, output, dt)
|
||||
}
|
||||
}
|
||||
|
@ -80,7 +81,7 @@ func Test_ConvertNumpyType(t *testing.T) {
|
|||
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
|
||||
|
||||
dt, err := convertNumpyType("dummy")
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, schemapb.DataType_None, dt)
|
||||
}
|
||||
|
||||
|
@ -88,25 +89,25 @@ func Test_StringLen(t *testing.T) {
|
|||
len, utf, err := stringLen("S1")
|
||||
assert.Equal(t, 1, len)
|
||||
assert.False(t, utf)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
len, utf, err = stringLen("2S")
|
||||
assert.Equal(t, 2, len)
|
||||
assert.False(t, utf)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
len, utf, err = stringLen("<U3")
|
||||
assert.Equal(t, 3, len)
|
||||
assert.True(t, utf)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
len, utf, err = stringLen(">4U")
|
||||
assert.Equal(t, 4, len)
|
||||
assert.True(t, utf)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
len, utf, err = stringLen("dummy")
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 0, len)
|
||||
assert.False(t, utf)
|
||||
}
|
||||
|
@ -129,207 +130,337 @@ func Test_NumpyAdapterSetByteOrder(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NumpyAdapterReadError(t *testing.T) {
|
||||
adapter := &NumpyAdapter{
|
||||
reader: nil,
|
||||
npyReader: nil,
|
||||
}
|
||||
|
||||
// reader size is zero
|
||||
t.Run("test size is zero", func(t *testing.T) {
|
||||
_, err := adapter.ReadBool(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadUint8(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt8(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt16(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt32(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt64(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadFloat32(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadFloat64(0)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
// t.Run("test size is zero", func(t *testing.T) {
|
||||
// adapter := &NumpyAdapter{
|
||||
// reader: nil,
|
||||
// npyReader: nil,
|
||||
// }
|
||||
// _, err := adapter.ReadBool(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadUint8(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadInt8(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadInt16(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadInt32(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadInt64(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadFloat32(0)
|
||||
// assert.Error(t, err)
|
||||
// _, err = adapter.ReadFloat64(0)
|
||||
// assert.Error(t, err)
|
||||
// })
|
||||
|
||||
adapter = &NumpyAdapter{
|
||||
reader: &MockReader{},
|
||||
npyReader: &npy.Reader{},
|
||||
createAdatper := func(dt schemapb.DataType) *NumpyAdapter {
|
||||
adapter := &NumpyAdapter{
|
||||
reader: &MockReader{},
|
||||
npyReader: &npy.Reader{
|
||||
Header: npy.Header{},
|
||||
},
|
||||
dataType: dt,
|
||||
order: binary.BigEndian,
|
||||
}
|
||||
adapter.npyReader.Header.Descr.Shape = []int{1}
|
||||
return adapter
|
||||
}
|
||||
|
||||
t.Run("test read bool", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "bool"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadBool(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Bool)
|
||||
data, err = adapter.ReadBool(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1})
|
||||
data, err = adapter.ReadBool(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadBool(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read uint8", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "u1"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
data, err := adapter.ReadUint8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter.npyReader.Header.Descr.Type = "u1"
|
||||
data, err = adapter.ReadUint8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1})
|
||||
data, err = adapter.ReadUint8(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadUint8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read int8", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i1"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadInt8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Int8)
|
||||
data, err = adapter.ReadInt8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1})
|
||||
data, err = adapter.ReadInt8(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadInt8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read int16", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i2"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadInt16(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Int16)
|
||||
data, err = adapter.ReadInt16(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1, 2})
|
||||
data, err = adapter.ReadInt16(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadInt16(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read int32", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i4"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadInt32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Int32)
|
||||
data, err = adapter.ReadInt32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4})
|
||||
data, err = adapter.ReadInt32(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadInt32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read int64", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i8"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadInt64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Int64)
|
||||
data, err = adapter.ReadInt64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
data, err = adapter.ReadInt64(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadInt64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read float", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "f4"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadFloat32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Float)
|
||||
data, err = adapter.ReadFloat32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4})
|
||||
data, err = adapter.ReadFloat32(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadFloat32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read double", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "f8"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadFloat64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_Double)
|
||||
data, err = adapter.ReadFloat64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
data, err = adapter.ReadFloat64(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadFloat64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test read varchar", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "U3"
|
||||
// type mismatch
|
||||
adapter := createAdatper(schemapb.DataType_None)
|
||||
data, err := adapter.ReadString(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
// reader is nil, cannot read
|
||||
adapter = createAdatper(schemapb.DataType_VarChar)
|
||||
adapter.reader = nil
|
||||
adapter.npyReader.Header.Descr.Type = "S3"
|
||||
data, err = adapter.ReadString(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// read one element from reader
|
||||
adapter.reader = strings.NewReader("abc")
|
||||
data, err = adapter.ReadString(1)
|
||||
assert.NotEmpty(t, data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nothing to read
|
||||
data, err = adapter.ReadString(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_NumpyAdapterRead(t *testing.T) {
|
||||
err := os.MkdirAll(TempFilesPath, os.ModePerm)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(TempFilesPath)
|
||||
|
||||
t.Run("test read bool", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "bool.npy"
|
||||
data := []bool{true, false, true, false}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadBool(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadBool(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadBool(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
// incorrect type read
|
||||
resu1, err := adapter.ReadUint8(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resu1)
|
||||
|
||||
resi1, err := adapter.ReadInt8(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resi1)
|
||||
|
||||
resi2, err := adapter.ReadInt16(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resi2)
|
||||
|
||||
resi4, err := adapter.ReadInt32(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resi4)
|
||||
|
||||
resi8, err := adapter.ReadInt64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resi8)
|
||||
|
||||
resf4, err := adapter.ReadFloat32(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resf4)
|
||||
|
||||
resf8, err := adapter.ReadFloat64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resf8)
|
||||
})
|
||||
|
||||
|
@ -337,35 +468,38 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "uint8.npy"
|
||||
data := []uint8{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadUint8(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadUint8(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadUint8(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
// incorrect type read
|
||||
resb, err := adapter.ReadBool(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resb)
|
||||
})
|
||||
|
||||
|
@ -373,30 +507,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "int8.npy"
|
||||
data := []int8{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadInt8(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadInt8(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadInt8(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -404,30 +541,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "int16.npy"
|
||||
data := []int16{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadInt16(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadInt16(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadInt16(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -435,30 +575,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "int32.npy"
|
||||
data := []int32{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadInt32(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadInt32(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadInt32(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -466,30 +609,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "int64.npy"
|
||||
data := []int64{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadInt64(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadInt64(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadInt64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -497,30 +643,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "float.npy"
|
||||
data := []float32{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadFloat32(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadFloat32(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadFloat32(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -528,30 +677,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "double.npy"
|
||||
data := []float64{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadFloat64(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadFloat64(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadFloat64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -589,19 +741,19 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
|
||||
// count should be greater than 0
|
||||
res, err := adapter.ReadString(0)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
// maxLen is zero
|
||||
npyReader.Header.Descr.Type = "S0"
|
||||
res, err = adapter.ReadString(1)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
npyReader.Header.Descr.Type = "S" + strconv.FormatInt(int64(maxLen), 10)
|
||||
|
||||
res, err = adapter.ReadString(len(values) + 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(values), len(res))
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, values[i], res[i])
|
||||
|
@ -612,29 +764,33 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "varchar1.npy"
|
||||
data := []string{"a ", "bbb", " c", "dd", "eeee", "fff"}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// partly read
|
||||
res, err := adapter.ReadString(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
// read the left data
|
||||
res, err = adapter.ReadString(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
// nothing to read
|
||||
res, err = adapter.ReadString(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
|
@ -642,16 +798,16 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
filePath := TempFilesPath + "varchar2.npy"
|
||||
data := []string{"で と ど ", " 马克bbb", "$(한)삼각*"}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
res, err := adapter.ReadString(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data), len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
|
@ -663,7 +819,7 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
func Test_DecodeUtf32(t *testing.T) {
|
||||
// wrong input
|
||||
res, err := decodeUtf32([]byte{1, 2}, binary.LittleEndian)
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, res)
|
||||
|
||||
// this string contains ascii characters and unicode characters
|
||||
|
@ -672,12 +828,12 @@ func Test_DecodeUtf32(t *testing.T) {
|
|||
// utf32 littleEndian of str
|
||||
src := []byte{97, 0, 0, 0, 100, 0, 0, 0, 228, 37, 0, 0, 9, 78, 0, 0, 126, 118, 0, 0, 181, 243, 1, 0, 144, 48, 0, 0, 153, 33, 0, 0}
|
||||
res, err = decodeUtf32(src, binary.LittleEndian)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, str, res)
|
||||
|
||||
// utf32 bigEndian of str
|
||||
src = []byte{0, 0, 0, 97, 0, 0, 0, 100, 0, 0, 37, 228, 0, 0, 78, 9, 0, 0, 118, 126, 0, 1, 243, 181, 0, 0, 48, 144, 0, 0, 33, 153}
|
||||
res, err = decodeUtf32(src, binary.BigEndian)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, str, res)
|
||||
}
|
||||
|
|
|
@ -20,280 +20,527 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ColumnDesc struct {
|
||||
name string // name of the target column
|
||||
dt schemapb.DataType // data type of the target column
|
||||
elementCount int // how many elements need to be read
|
||||
dimension int // only for vector
|
||||
type NumpyColumnReader struct {
|
||||
fieldName string // name of the target column
|
||||
fieldID storage.FieldID // ID of the target column
|
||||
dataType schemapb.DataType // data type of the target column
|
||||
rowCount int // how many rows need to be read
|
||||
dimension int // only for vector
|
||||
file storage.FileReader // file to be read
|
||||
reader *NumpyAdapter // data reader
|
||||
}
|
||||
|
||||
func closeReaders(columnReaders []*NumpyColumnReader) {
|
||||
for _, reader := range columnReaders {
|
||||
if reader.file != nil {
|
||||
err := reader.file.Close()
|
||||
if err != nil {
|
||||
log.Error("Numper parser: failed to close numpy file", zap.String("fileName", reader.fieldName+NumpyFileExt))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type NumpyParser struct {
|
||||
ctx context.Context // for canceling parse process
|
||||
collectionSchema *schemapb.CollectionSchema // collection schema
|
||||
columnDesc *ColumnDesc // description for target column
|
||||
|
||||
columnData storage.FieldData // in-memory column data
|
||||
callFlushFunc func(field storage.FieldData) error // call back function to output column data
|
||||
rowIDAllocator *allocator.IDAllocator // autoid allocator
|
||||
shardNum int32 // sharding number of the collection
|
||||
blockSize int64 // maximum size of a read block(unit:byte)
|
||||
chunkManager storage.ChunkManager // storage interfaces to browse/read the files
|
||||
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
|
||||
callFlushFunc ImportFlushFunc // call back function to flush segment
|
||||
}
|
||||
|
||||
// NewNumpyParser is a helper function to create a NumpyParser
|
||||
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
|
||||
flushFunc func(field storage.FieldData) error) *NumpyParser {
|
||||
if collectionSchema == nil || flushFunc == nil {
|
||||
return nil
|
||||
func NewNumpyParser(ctx context.Context,
|
||||
collectionSchema *schemapb.CollectionSchema,
|
||||
idAlloc *allocator.IDAllocator,
|
||||
shardNum int32,
|
||||
blockSize int64,
|
||||
chunkManager storage.ChunkManager,
|
||||
flushFunc ImportFlushFunc) (*NumpyParser, error) {
|
||||
if collectionSchema == nil {
|
||||
log.Error("Numper parser: collection schema is nil")
|
||||
return nil, errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
if idAlloc == nil {
|
||||
log.Error("Numper parser: id allocator is nil")
|
||||
return nil, errors.New("id allocator is nil")
|
||||
}
|
||||
|
||||
if chunkManager == nil {
|
||||
log.Error("Numper parser: chunk manager pointer is nil")
|
||||
return nil, errors.New("chunk manager pointer is nil")
|
||||
}
|
||||
|
||||
if flushFunc == nil {
|
||||
log.Error("Numper parser: flush function is nil")
|
||||
return nil, errors.New("flush function is nil")
|
||||
}
|
||||
|
||||
parser := &NumpyParser{
|
||||
ctx: ctx,
|
||||
collectionSchema: collectionSchema,
|
||||
columnDesc: &ColumnDesc{},
|
||||
rowIDAllocator: idAlloc,
|
||||
shardNum: shardNum,
|
||||
blockSize: blockSize,
|
||||
chunkManager: chunkManager,
|
||||
autoIDRange: make([]int64, 0),
|
||||
callFlushFunc: flushFunc,
|
||||
}
|
||||
|
||||
return parser
|
||||
return parser, nil
|
||||
}
|
||||
|
||||
func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
||||
if adapter == nil {
|
||||
log.Error("Numpy parser: numpy adapter is nil")
|
||||
return errors.New("numpy adapter is nil")
|
||||
func (p *NumpyParser) IDRange() []int64 {
|
||||
return p.autoIDRange
|
||||
}
|
||||
|
||||
// Parse is the function entry
|
||||
func (p *NumpyParser) Parse(filePaths []string) error {
|
||||
// check redundant files for column-based import
|
||||
// if the field is primary key and autoid is false, the file is required
|
||||
// any redundant file is not allowed
|
||||
err := p.validateFileNames(filePaths)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check existence of the target field
|
||||
var schema *schemapb.FieldSchema
|
||||
for i := 0; i < len(p.collectionSchema.Fields); i++ {
|
||||
schema = p.collectionSchema.Fields[i]
|
||||
if schema.GetName() == fieldName {
|
||||
p.columnDesc.name = fieldName
|
||||
break
|
||||
}
|
||||
// open files and verify file header
|
||||
readers, err := p.createReaders(filePaths)
|
||||
// make sure all the files are finally closed; this must be deferred before the function returns
|
||||
defer closeReaders(readers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.columnDesc.name == "" {
|
||||
log.Error("Numpy parser: Numpy parser: the field is not found in collection schema", zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("the field name '%s' is not found in collection schema", fieldName)
|
||||
}
|
||||
|
||||
p.columnDesc.dt = schema.DataType
|
||||
elementType := adapter.GetType()
|
||||
shape := adapter.GetShape()
|
||||
|
||||
var err error
|
||||
// 1. field data type should be consistent with the numpy data type
|
||||
// 2. vector field dimension should be consistent with the numpy shape
|
||||
if schemapb.DataType_FloatVector == schema.DataType {
|
||||
// float32/float64 numpy file can be used for float vector file, 2 reasons:
|
||||
// 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit
|
||||
// 2. for float64 numpy file, the performance is worse than float32 numpy file
|
||||
if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), schema.GetName())
|
||||
}
|
||||
|
||||
// vector field, the shape should be 2
|
||||
if len(shape) != 2 {
|
||||
log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, schema.GetName())
|
||||
}
|
||||
|
||||
// shape[0] is row count, shape[1] is element count per row
|
||||
p.columnDesc.elementCount = shape[0] * shape[1]
|
||||
|
||||
p.columnDesc.dimension, err = getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if shape[1] != p.columnDesc.dimension {
|
||||
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
|
||||
zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", p.columnDesc.dimension))
|
||||
return fmt.Errorf("illegal dimension %d of numpy file for float vector field '%s', dimension should be %d",
|
||||
shape[1], schema.GetName(), p.columnDesc.dimension)
|
||||
}
|
||||
} else if schemapb.DataType_BinaryVector == schema.DataType {
|
||||
if elementType != schemapb.DataType_BinaryVector {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), schema.GetName())
|
||||
}
|
||||
|
||||
// vector field, the shape should be 2
|
||||
if len(shape) != 2 {
|
||||
log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, schema.GetName())
|
||||
}
|
||||
|
||||
// shape[0] is row count, shape[1] is element count per row
|
||||
p.columnDesc.elementCount = shape[0] * shape[1]
|
||||
|
||||
p.columnDesc.dimension, err = getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if shape[1] != p.columnDesc.dimension/8 {
|
||||
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
|
||||
zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", p.columnDesc.dimension))
|
||||
return fmt.Errorf("illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d",
|
||||
shape[1]*8, schema.GetName(), p.columnDesc.dimension)
|
||||
}
|
||||
} else {
|
||||
if elementType != schema.DataType {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
|
||||
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
|
||||
return fmt.Errorf("illegal data type %s of numpy file for scalar field '%s' with type %s",
|
||||
getTypeName(elementType), schema.GetName(), getTypeName(schema.DataType))
|
||||
}
|
||||
|
||||
// scalar field, the shape should be 1
|
||||
if len(shape) != 1 {
|
||||
log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, schema.GetName())
|
||||
}
|
||||
|
||||
p.columnDesc.elementCount = shape[0]
|
||||
// read all data from the numpy files
|
||||
err = p.consume(readers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFileNames checks for redundant and missing files
|
||||
func (p *NumpyParser) validateFileNames(filePaths []string) error {
|
||||
requiredFieldNames := make(map[string]interface{})
|
||||
for _, schema := range p.collectionSchema.Fields {
|
||||
if schema.GetIsPrimaryKey() {
|
||||
if !schema.GetAutoID() {
|
||||
requiredFieldNames[schema.GetName()] = nil
|
||||
}
|
||||
} else {
|
||||
requiredFieldNames[schema.GetName()] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// check redundant file
|
||||
fileNames := make(map[string]interface{})
|
||||
for _, filePath := range filePaths {
|
||||
name, _ := GetFileNameAndExt(filePath)
|
||||
fileNames[name] = nil
|
||||
_, ok := requiredFieldNames[name]
|
||||
if !ok {
|
||||
log.Error("Numpy parser: the file has no corresponding field in collection", zap.String("fieldName", name))
|
||||
return fmt.Errorf("the file '%s' has no corresponding field in collection", filePath)
|
||||
}
|
||||
}
|
||||
|
||||
// check missed file
|
||||
for name := range requiredFieldNames {
|
||||
_, ok := fileNames[name]
|
||||
if !ok {
|
||||
log.Error("Numpy parser: there is no file corresponding to field", zap.String("fieldName", name))
|
||||
return fmt.Errorf("there is no file corresponding to field '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
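The rule enforced by validateFileNames can be illustrated with a small self-contained sketch. The field names, paths and the helper requiredVsProvided below are made up for illustration; the patch itself maps file names with GetFileNameAndExt and also excludes auto-ID primary key fields from the required set.

package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// requiredVsProvided mirrors the validateFileNames idea: every provided .npy file
// must map to a field, and every required field must have a file.
func requiredVsProvided(requiredFields []string, filePaths []string) error {
	required := make(map[string]struct{})
	for _, f := range requiredFields {
		required[f] = struct{}{}
	}
	provided := make(map[string]struct{})
	for _, p := range filePaths {
		name := strings.TrimSuffix(filepath.Base(p), filepath.Ext(p))
		if _, ok := required[name]; !ok {
			return fmt.Errorf("the file '%s' has no corresponding field in collection", p)
		}
		provided[name] = struct{}{}
	}
	for f := range required {
		if _, ok := provided[f]; !ok {
			return fmt.Errorf("there is no file corresponding to field '%s'", f)
		}
	}
	return nil
}

func main() {
	err := requiredVsProvided([]string{"pk", "vector"}, []string{"/data/pk.npy", "/data/vector.npy"})
	fmt.Println(err) // <nil>
}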
|
||||
// createReaders opens the files and verifies each file header
|
||||
func (p *NumpyParser) createReaders(filePaths []string) ([]*NumpyColumnReader, error) {
|
||||
readers := make([]*NumpyColumnReader, 0)
|
||||
|
||||
for _, filePath := range filePaths {
|
||||
fileName, _ := GetFileNameAndExt(filePath)
|
||||
|
||||
// check existence of the target field
|
||||
var schema *schemapb.FieldSchema
|
||||
for i := 0; i < len(p.collectionSchema.Fields); i++ {
|
||||
tmpSchema := p.collectionSchema.Fields[i]
|
||||
if tmpSchema.GetName() == fileName {
|
||||
schema = tmpSchema
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if schema == nil {
|
||||
log.Error("Numpy parser: the field is not found in collection schema", zap.String("fileName", fileName))
|
||||
return nil, fmt.Errorf("the field name '%s' is not found in collection schema", fileName)
|
||||
}
|
||||
|
||||
file, err := p.chunkManager.Reader(p.ctx, filePath)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read the file", zap.String("filePath", filePath), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read the file '%s', error: %s", filePath, err.Error())
|
||||
}
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read the file header", zap.String("filePath", filePath), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read the file header '%s', error: %s", filePath, err.Error())
|
||||
}
|
||||
|
||||
if file == nil || adapter == nil {
|
||||
log.Error("Numpy parser: failed to open file", zap.String("filePath", filePath))
|
||||
return nil, fmt.Errorf("failed to open file '%s'", filePath)
|
||||
}
|
||||
|
||||
dim, _ := getFieldDimension(schema)
|
||||
columnReader := &NumpyColumnReader{
|
||||
fieldName: schema.GetName(),
|
||||
fieldID: schema.GetFieldID(),
|
||||
dataType: schema.GetDataType(),
|
||||
dimension: dim,
|
||||
file: file,
|
||||
reader: adapter,
|
||||
}
|
||||
|
||||
// the validation method only checks the file header information
|
||||
err = p.validateHeader(columnReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
readers = append(readers, columnReader)
|
||||
}
|
||||
|
||||
// row count of each file should be equal
|
||||
if len(readers) > 0 {
|
||||
firstReader := readers[0]
|
||||
rowCount := firstReader.rowCount
|
||||
for i := 1; i < len(readers); i++ {
|
||||
compareReader := readers[i]
|
||||
if rowCount != compareReader.rowCount {
|
||||
log.Error("Numpy parser: the row count of files are not equal",
|
||||
zap.String("firstFile", firstReader.fieldName), zap.Int("firstRowCount", firstReader.rowCount),
|
||||
zap.String("compareFile", compareReader.fieldName), zap.Int("compareRowCount", compareReader.rowCount))
|
||||
return nil, fmt.Errorf("the row count(%d) of file '%s.npy' is not equal to row count(%d) of file '%s.npy'",
|
||||
firstReader.rowCount, firstReader.fieldName, compareReader.rowCount, compareReader.fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return readers, nil
|
||||
}
|
||||
|
||||
// validateHeader verifies the numpy file header; the header information should match the field's schema
|
||||
func (p *NumpyParser) validateHeader(columnReader *NumpyColumnReader) error {
|
||||
if columnReader == nil || columnReader.reader == nil {
|
||||
log.Error("Numpy parser: numpy reader is nil")
|
||||
return errors.New("numpy adapter is nil")
|
||||
}
|
||||
|
||||
elementType := columnReader.reader.GetType()
|
||||
shape := columnReader.reader.GetShape()
|
||||
columnReader.rowCount = shape[0]
|
||||
|
||||
// 1. field data type should be consistent with the numpy data type
|
||||
// 2. vector field dimension should be consistent with the numpy shape
|
||||
if schemapb.DataType_FloatVector == columnReader.dataType {
|
||||
// float32/float64 numpy file can be used for float vector file, 2 reasons:
|
||||
// 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit
|
||||
// 2. for float64 numpy file, the performance is worse than float32 numpy file
|
||||
if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType),
|
||||
zap.String("fieldName", columnReader.fieldName))
|
||||
return fmt.Errorf("illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType),
|
||||
columnReader.fieldName)
|
||||
}
|
||||
|
||||
// vector field, the shape should be 2
|
||||
if len(shape) != 2 {
|
||||
log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", columnReader.fieldName))
|
||||
return fmt.Errorf("illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape,
|
||||
columnReader.fieldName)
|
||||
}
|
||||
|
||||
if shape[1] != columnReader.dimension {
|
||||
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", columnReader.fieldName),
|
||||
zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", columnReader.dimension))
|
||||
return fmt.Errorf("illegal dimension %d of numpy file for float vector field '%s', dimension should be %d",
|
||||
shape[1], columnReader.fieldName, columnReader.dimension)
|
||||
}
|
||||
} else if schemapb.DataType_BinaryVector == columnReader.dataType {
|
||||
if elementType != schemapb.DataType_BinaryVector {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType),
|
||||
zap.String("fieldName", columnReader.fieldName))
|
||||
return fmt.Errorf("illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType),
|
||||
columnReader.fieldName)
|
||||
}
|
||||
|
||||
// vector field, the shape should be 2
|
||||
if len(shape) != 2 {
|
||||
log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", columnReader.fieldName))
|
||||
return fmt.Errorf("illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape,
|
||||
columnReader.fieldName)
|
||||
}
|
||||
|
||||
if shape[1] != columnReader.dimension/8 {
|
||||
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", columnReader.fieldName),
|
||||
zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", columnReader.dimension))
|
||||
return fmt.Errorf("illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d",
|
||||
shape[1]*8, columnReader.fieldName, columnReader.dimension)
|
||||
}
|
||||
} else {
|
||||
if elementType != columnReader.dataType {
|
||||
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
|
||||
zap.String("fieldName", columnReader.fieldName), zap.Any("fieldDataType", columnReader.dataType))
|
||||
return fmt.Errorf("illegal data type %s of numpy file for scalar field '%s' with type %s",
|
||||
getTypeName(elementType), columnReader.fieldName, getTypeName(columnReader.dataType))
|
||||
}
|
||||
|
||||
// scalar field, the shape should be 1
|
||||
if len(shape) != 1 {
|
||||
log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", columnReader.fieldName))
|
||||
return fmt.Errorf("illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, columnReader.fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
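For a quick reference, the shape rules validateHeader enforces are: scalar fields expect a one-dimensional array [rows], float vector fields expect [rows, dim], and binary vector fields expect [rows, dim/8]. The tiny helper below is hypothetical and only makes the rule explicit; it is not part of the patch.

package main

import "fmt"

// expectedShape returns the numpy shape a field of the given kind should have.
// The kind strings stand in for the schemapb data types used above.
func expectedShape(rows, dim int, kind string) []int {
	switch kind {
	case "FloatVector":
		return []int{rows, dim} // one float per dimension
	case "BinaryVector":
		return []int{rows, dim / 8} // 8 dimensions packed into one byte
	default:
		return []int{rows} // scalar field
	}
}

func main() {
	fmt.Println(expectedShape(1000, 128, "FloatVector"))  // [1000 128]
	fmt.Println(expectedShape(1000, 128, "BinaryVector")) // [1000 16]
	fmt.Println(expectedShape(1000, 0, "Int64"))          // [1000]
}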
|
||||
// calcRowCountPerBlock calculates a proper batch row count for reading the files
|
||||
func (p *NumpyParser) calcRowCountPerBlock() (int64, error) {
|
||||
sizePerRecord, err := typeutil.EstimateSizePerRecord(p.collectionSchema)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to estimate size of each row", zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to estimate size of each row: %s", err.Error())
|
||||
}
|
||||
|
||||
if sizePerRecord <= 0 {
|
||||
log.Error("Numpy parser: failed to estimate size of each row, the collection schema might be empty")
|
||||
return 0, fmt.Errorf("failed to estimate size of each row: the collection schema might be empty")
|
||||
}
|
||||
|
||||
// the sizePerRecord is an estimated value; if the schema contains a varchar field, the value is not accurate
|
||||
// we will read data block by block, by default, each block size is 16MB
|
||||
// rowCountPerBlock is the estimated row count for a block
|
||||
rowCountPerBlock := p.blockSize / int64(sizePerRecord)
|
||||
if rowCountPerBlock <= 0 {
|
||||
rowCountPerBlock = 1 // make sure the value is positive
|
||||
}
|
||||
|
||||
log.Info("Numper parser: calculate row count per block to read file", zap.Int64("rowCountPerBlock", rowCountPerBlock),
|
||||
zap.Int64("blockSize", p.blockSize), zap.Int("sizePerRecord", sizePerRecord))
|
||||
return rowCountPerBlock, nil
|
||||
}
|
||||
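As a worked example of the calculation above (the numbers are illustrative, not taken from the patch): with the default 16 MB block size and an estimated 4 KB per row, for instance a 1024-dimensional float vector, the parser reads roughly 16*1024*1024 / 4096 = 4096 rows per block; if sizePerRecord exceeded the block size, the result is clamped to 1.

package main

import "fmt"

func main() {
	// Illustrative only: the same clamp-to-one division performed by calcRowCountPerBlock.
	blockSize := int64(16 * 1024 * 1024) // 16 MB
	sizePerRecord := int64(4096)         // ~4 KB per row, e.g. a dim-1024 float vector
	rows := blockSize / sizePerRecord
	if rows <= 0 {
		rows = 1
	}
	fmt.Println(rows) // 4096
}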
|
||||
// consume method reads numpy data section into a storage.FieldData
|
||||
// please note it will require a large memory block(the memory size is almost equal to numpy file size)
|
||||
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
||||
switch p.columnDesc.dt {
|
||||
func (p *NumpyParser) consume(columnReaders []*NumpyColumnReader) error {
|
||||
rowCountPerBlock, err := p.calcRowCountPerBlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// prepare shards
|
||||
shards := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum)
|
||||
for i := 0; i < int(p.shardNum); i++ {
|
||||
segmentData := initSegmentData(p.collectionSchema)
|
||||
if segmentData == nil {
|
||||
log.Error("import wrapper: failed to initialize FieldData list")
|
||||
return fmt.Errorf("failed to initialize FieldData list")
|
||||
}
|
||||
shards = append(shards, segmentData)
|
||||
}
|
||||
tr := timerecord.NewTimeRecorder("consume performance")
|
||||
defer tr.Elapse("end")
|
||||
// read data from files, batch by batch
|
||||
for {
|
||||
readRowCount := 0
|
||||
segmentData := make(map[storage.FieldID]storage.FieldData)
|
||||
for _, reader := range columnReaders {
|
||||
fieldData, err := p.readData(reader, int(rowCountPerBlock))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if readRowCount == 0 {
|
||||
readRowCount = fieldData.RowNum()
|
||||
} else if readRowCount != fieldData.RowNum() {
|
||||
log.Error("Numpy parser: data block's row count mismatch", zap.Int("firstBlockRowCount", readRowCount),
|
||||
zap.Int("thisBlockRowCount", fieldData.RowNum()), zap.Int64("rowCountPerBlock", rowCountPerBlock))
|
||||
return fmt.Errorf("data block's row count mismatch: %d vs %d", readRowCount, fieldData.RowNum())
|
||||
}
|
||||
|
||||
segmentData[reader.fieldID] = fieldData
|
||||
}
|
||||
|
||||
// nothing to read
|
||||
if readRowCount == 0 {
|
||||
break
|
||||
}
|
||||
tr.Record("readData")
|
||||
// split data to shards
|
||||
err = p.splitFieldsData(segmentData, shards)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tr.Record("splitFieldsData")
|
||||
// when the estimated size is close to blockSize, save to binlog
|
||||
err = tryFlushBlocks(p.ctx, shards, p.collectionSchema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tr.Record("tryFlushBlocks")
|
||||
}
|
||||
|
||||
// force flush at the end
|
||||
return tryFlushBlocks(p.ctx, shards, p.collectionSchema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, true)
|
||||
}
|
||||
|
||||
// readData method reads numpy data section into a storage.FieldData
|
||||
func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (storage.FieldData, error) {
|
||||
switch columnReader.dataType {
|
||||
case schemapb.DataType_Bool:
|
||||
data, err := adapter.ReadBool(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadBool(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read bool array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read bool array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.BoolFieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.BoolFieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
|
||||
}, nil
|
||||
case schemapb.DataType_Int8:
|
||||
data, err := adapter.ReadInt8(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadInt8(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read int8 array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read int8 array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.Int8FieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.Int8FieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_Int16:
|
||||
data, err := adapter.ReadInt16(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadInt16(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to int16 bool array", zap.Error(err))
|
||||
return err
|
||||
log.Error("Numpy parser: failed to int16 array", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read int16 array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.Int16FieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.Int16FieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_Int32:
|
||||
data, err := adapter.ReadInt32(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadInt32(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read int32 array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read int32 array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.Int32FieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.Int32FieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_Int64:
|
||||
data, err := adapter.ReadInt64(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadInt64(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read int64 array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read int64 array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.Int64FieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.Int64FieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_Float:
|
||||
data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadFloat32(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read float array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read float array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.FloatFieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.FloatFieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_Double:
|
||||
data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadFloat64(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read double array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read double array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.DoubleFieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.DoubleFieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_VarChar:
|
||||
data, err := adapter.ReadString(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadString(rowCount)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read varchar array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read varchar array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.StringFieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.StringFieldData{
|
||||
NumRows: []int64{int64(len(data))},
|
||||
Data: data,
|
||||
}
|
||||
}, nil
|
||||
case schemapb.DataType_BinaryVector:
|
||||
data, err := adapter.ReadUint8(p.columnDesc.elementCount)
|
||||
data, err := columnReader.reader.ReadUint8(rowCount * (columnReader.dimension / 8))
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: failed to read binary vector array", zap.Error(err))
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read binary vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
p.columnData = &storage.BinaryVectorFieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
return &storage.BinaryVectorFieldData{
|
||||
NumRows: []int64{int64(len(data) * 8 / columnReader.dimension)},
|
||||
Data: data,
|
||||
Dim: p.columnDesc.dimension,
|
||||
}
|
||||
Dim: columnReader.dimension,
|
||||
}, nil
	case schemapb.DataType_FloatVector:
		// float32/float64 numpy file can be used for float vector file, 2 reasons:
		// 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit
		// 2. for float64 numpy file, the performance is worse than float32 numpy file
		elementType := adapter.GetType()
		elementType := columnReader.reader.GetType()

		var data []float32
		var err error
		if elementType == schemapb.DataType_Float {
			data, err = adapter.ReadFloat32(p.columnDesc.elementCount)
			data, err = columnReader.reader.ReadFloat32(rowCount * columnReader.dimension)
			if err != nil {
				log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
				return err
				return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
			}
		} else if elementType == schemapb.DataType_Double {
			data = make([]float32, 0, p.columnDesc.elementCount)
			data64, err := adapter.ReadFloat64(p.columnDesc.elementCount)
			data = make([]float32, 0, columnReader.rowCount)
			data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension)
			if err != nil {
				log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
				return err
				return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
			}

			for _, f64 := range data64 {

@@ -301,40 +548,255 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
			}
		}

		p.columnData = &storage.FloatVectorFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
		return &storage.FloatVectorFieldData{
			NumRows: []int64{int64(len(data) / columnReader.dimension)},
			Data:    data,
			Dim:     p.columnDesc.dimension,
			Dim:     columnReader.dimension,
		}, nil
	default:
		log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", columnReader.dataType),
			zap.String("fieldName", columnReader.fieldName))
		return nil, fmt.Errorf("unsupported data type %s of field '%s'", getTypeName(columnReader.dataType),
			columnReader.fieldName)
	}
}
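For illustration only, not part of the patch: as the FloatVector comment above notes, a float64 numpy file is accepted for float vector fields and narrowed element by element to float32 (the loop doing so is split across the hunk boundary). A minimal standalone sketch of the same down-conversion, assuming nothing beyond the standard library:

package main

import "fmt"

// float64ToFloat32 narrows a float64 column, as read from a float64 .npy file,
// into the float32 slice that a FloatVector field stores. Precision beyond
// float32 is dropped, which is why float32 files are the cheaper choice.
func float64ToFloat32(data64 []float64) []float32 {
    data := make([]float32, 0, len(data64))
    for _, f64 := range data64 {
        data = append(data, float32(f64))
    }
    return data
}

func main() {
    fmt.Println(float64ToFloat32([]float64{0.1, 0.25, 0.5}))
}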

// appendFunc defines the methods to append data to storage.FieldData
func (p *NumpyParser) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
	switch schema.DataType {
	case schemapb.DataType_Bool:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.BoolFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(bool))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Float:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.FloatFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(float32))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Double:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.DoubleFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(float64))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int8:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int8FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int8))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int16:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int16FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int16))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int32:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int32FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int32))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_Int64:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.Int64FieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(int64))
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_BinaryVector:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.BinaryVectorFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).([]byte)...)
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_FloatVector:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.FloatVectorFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).([]float32)...)
			arr.NumRows[0]++
			return nil
		}
	case schemapb.DataType_String, schemapb.DataType_VarChar:
		return func(src storage.FieldData, n int, target storage.FieldData) error {
			arr := target.(*storage.StringFieldData)
			arr.Data = append(arr.Data, src.GetRow(n).(string))
			return nil
		}
	default:
		log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", p.columnDesc.dt), zap.String("fieldName", p.columnDesc.name))
		return fmt.Errorf("unsupported data type %s of field '%s'", getTypeName(p.columnDesc.dt), p.columnDesc.name)
		return nil
	}
}
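Aside, not part of the patch: appendFunc returns one closure per data type; each closure type-asserts the target FieldData, appends row n taken from the source, and bumps the target's row counter. The following standalone sketch shows the same pattern with a trimmed-down stand-in for storage.FieldData (all names here are illustrative only):

package main

import "fmt"

// fieldData is a trimmed-down stand-in for storage.FieldData, just enough to
// show the closure-per-type dispatch that appendFunc uses.
type fieldData interface {
    GetRow(n int) interface{}
    RowNum() int
}

type int64Field struct{ data []int64 }

func (f *int64Field) GetRow(n int) interface{} { return f.data[n] }
func (f *int64Field) RowNum() int              { return len(f.data) }

// makeInt64Appender mirrors the DataType_Int64 branch: the returned closure
// type-asserts the row value and appends it to the target field.
func makeInt64Appender() func(src fieldData, n int, target *int64Field) error {
    return func(src fieldData, n int, target *int64Field) error {
        target.data = append(target.data, src.GetRow(n).(int64))
        return nil
    }
}

func main() {
    src := &int64Field{data: []int64{10, 20, 30}}
    dst := &int64Field{}
    appendRow := makeInt64Appender()
    for i := 0; i < src.RowNum(); i++ {
        if err := appendRow(src, i, dst); err != nil {
            panic(err)
        }
    }
    fmt.Println(dst.data) // [10 20 30]
}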

func (p *NumpyParser) prepareAppendFunctions() (map[string]func(src storage.FieldData, n int, target storage.FieldData) error, error) {
	appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error)
	for i := 0; i < len(p.collectionSchema.Fields); i++ {
		schema := p.collectionSchema.Fields[i]
		appendFuncErr := p.appendFunc(schema)
		if appendFuncErr == nil {
			log.Error("Numpy parser: unsupported field data type")
			return nil, fmt.Errorf("unsupported field data type: %d", schema.GetDataType())
		}
		appendFunctions[schema.GetName()] = appendFuncErr
	}
	return appendFunctions, nil
}

// checkRowCount checks existence of each field, and returns the primary key schema
// check row count, all fields row count must be equal
func (p *NumpyParser) checkRowCount(fieldsData map[storage.FieldID]storage.FieldData) (int, *schemapb.FieldSchema, error) {
	rowCount := 0
	rowCounter := make(map[string]int)
	var primaryKey *schemapb.FieldSchema
	for i := 0; i < len(p.collectionSchema.Fields); i++ {
		schema := p.collectionSchema.Fields[i]
		if schema.GetIsPrimaryKey() {
			primaryKey = schema
		}

		if !schema.GetAutoID() {
			v, ok := fieldsData[schema.GetFieldID()]
			if !ok {
				log.Error("Numpy parser: field not provided", zap.String("fieldName", schema.GetName()))
				return 0, nil, fmt.Errorf("field '%s' not provided", schema.GetName())
			}
			rowCounter[schema.GetName()] = v.RowNum()
			if v.RowNum() > rowCount {
				rowCount = v.RowNum()
			}
		}
	}
	if primaryKey == nil {
		log.Error("Numpy parser: primary key field is not found")
		return 0, nil, fmt.Errorf("primary key field is not found")
	}

	for name, count := range rowCounter {
		if count != rowCount {
			log.Error("Numpy parser: field row count is not equal to other fields row count", zap.String("fieldName", name),
				zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount))
			return 0, nil, fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
		}
	}
	// log.Info("Numpy parser: try to split a block with row count", zap.Int("rowCount", rowCount))

	return rowCount, primaryKey, nil
}
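A toy illustration, not part of the patch: checkRowCount enforces that every provided field reports the same RowNum before splitting begins. The sketch below replays that invariant with made-up field names and counts:

package main

import "fmt"

// checkEqualRowCounts mimics the validation loop: remember the largest row
// count seen, then reject any field whose count differs from it.
func checkEqualRowCounts(rowCounter map[string]int) (int, error) {
    rowCount := 0
    for _, count := range rowCounter {
        if count > rowCount {
            rowCount = count
        }
    }
    for name, count := range rowCounter {
        if count != rowCount {
            return 0, fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
        }
    }
    return rowCount, nil
}

func main() {
    fmt.Println(checkEqualRowCounts(map[string]int{"age": 100, "vector": 100})) // 100 <nil>
    fmt.Println(checkEqualRowCounts(map[string]int{"age": 100, "vector": 99}))  // error
}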

// splitFieldsData is to split the in-memory data (parsed from column-based files) into shards
func (p *NumpyParser) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, shards []map[storage.FieldID]storage.FieldData) error {
	if len(fieldsData) == 0 {
		log.Error("Numpy parser: fields data to split is empty")
		return fmt.Errorf("fields data to split is empty")
	}

	if len(shards) != int(p.shardNum) {
		log.Error("Numpy parser: block count is not equal to collection shard number", zap.Int("shardsLen", len(shards)),
			zap.Int32("shardNum", p.shardNum))
		return fmt.Errorf("block count %d is not equal to collection shard number %d", len(shards), p.shardNum)
	}

	rowCount, primaryKey, err := p.checkRowCount(fieldsData)
	if err != nil {
		return err
	}

	// generate auto id for primary key and rowid field
	rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
	if err != nil {
		log.Error("Numpy parser: failed to alloc row ID", zap.Int("rowCount", rowCount), zap.Error(err))
		return fmt.Errorf("failed to alloc %d rows ID, error: %w", rowCount, err)
	}

	rowIDField, ok := fieldsData[common.RowIDField]
	if !ok {
		rowIDField = &storage.Int64FieldData{
			Data:    make([]int64, 0),
			NumRows: []int64{0},
		}
		fieldsData[common.RowIDField] = rowIDField
	}
	rowIDFieldArr := rowIDField.(*storage.Int64FieldData)
	for i := rowIDBegin; i < rowIDEnd; i++ {
		rowIDFieldArr.Data = append(rowIDFieldArr.Data, i)
	}

	// reset the primary keys, as we know, only int64 pk can be auto-generated
	if primaryKey.GetAutoID() {
		log.Info("Numpy parser: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin))
		if primaryKey.GetDataType() != schemapb.DataType_Int64 {
			log.Error("Numpy parser: primary key field is auto-generated but the field type is not int64")
			return fmt.Errorf("primary key field is auto-generated but the field type is not int64")
		}

		primaryDataArr := &storage.Int64FieldData{
			NumRows: []int64{int64(rowCount)},
			Data:    make([]int64, 0, rowCount),
		}
		for i := rowIDBegin; i < rowIDEnd; i++ {
			primaryDataArr.Data = append(primaryDataArr.Data, i)
		}

		fieldsData[primaryKey.GetFieldID()] = primaryDataArr
		p.autoIDRange = append(p.autoIDRange, rowIDBegin, rowIDEnd)
	}

	// if the primary key is not auto-generated and the user doesn't provide it, return an error
	primaryData, ok := fieldsData[primaryKey.GetFieldID()]
	if !ok || primaryData.RowNum() <= 0 {
		log.Error("Numpy parser: primary key field is not provided", zap.String("keyName", primaryKey.GetName()))
		return fmt.Errorf("primary key '%s' field data is not provided", primaryKey.GetName())
	}

	// prepare append functions
	appendFunctions, err := p.prepareAppendFunctions()
	if err != nil {
		return err
	}

	// split data into shards
	for i := 0; i < rowCount; i++ {
		// hash to a shard number
		pk := primaryData.GetRow(i)
		shard, err := pkToShard(pk, uint32(p.shardNum))
		if err != nil {
			return err
		}

		// set rowID field
		rowIDField := shards[shard][common.RowIDField].(*storage.Int64FieldData)
		rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64))

		// append row to shard
		for k := 0; k < len(p.collectionSchema.Fields); k++ {
			schema := p.collectionSchema.Fields[k]
			srcData := fieldsData[schema.GetFieldID()]
			targetData := shards[shard][schema.GetFieldID()]
			if srcData == nil || targetData == nil {
				log.Error("Numpy parser: cannot append data since source or target field data is nil",
					zap.String("FieldName", schema.GetName()),
					zap.Bool("sourceNil", srcData == nil), zap.Bool("targetNil", targetData == nil))
				return fmt.Errorf("cannot append data for field '%s' since source or target field data is nil",
					primaryKey.GetName())
			}
			appendFunc := appendFunctions[schema.GetName()]
			err := appendFunc(srcData, i, targetData)
			if err != nil {
				return err
			}
		}
	}

	return nil
}
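Illustration only, not part of the patch: splitFieldsData routes each row by hashing its primary key to a shard index (hash % shardNum) via pkToShard, then appends that row's value from every field into the matching per-shard bucket. A simplified, dependency-free sketch of the routing step, using FNV purely as a stand-in for the real hash helper:

package main

import (
    "fmt"
    "hash/fnv"
)

// toShard stands in for pkToShard: hash an int64 primary key and take the
// remainder by the shard count. FNV keeps this sketch self-contained; the
// parser uses its own hashing utilities.
func toShard(pk int64, shardNum uint32) uint32 {
    h := fnv.New32a()
    var buf [8]byte
    for i := 0; i < 8; i++ {
        buf[i] = byte(pk >> (8 * i))
    }
    h.Write(buf[:])
    return h.Sum32() % shardNum
}

func main() {
    const shardNum = 2
    pks := []int64{1, 2, 3, 4, 5}
    shards := make([][]int64, shardNum)
    for _, pk := range pks {
        s := toShard(pk, shardNum)
        // the real code appends every field's row into shards[s] here
        shards[s] = append(shards[s], pk)
    }
    fmt.Println(shards)
}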

func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
	adapter, err := NewNumpyAdapter(reader)
	if err != nil {
		return err
	}

	// the validation method only checks the file header information
	err = p.validate(adapter, fieldName)
	if err != nil {
		return err
	}

	if onlyValidate {
		return nil
	}

	// read all data from the numpy file
	err = p.consume(adapter)
	if err != nil {
		return err
	}

	return p.callFlushFunc(p.columnData)
}