Add CSV file import function (#27149)

Signed-off-by: kuma <675613722@qq.com>
Co-authored-by: kuma <675613722@qq.com>
pull/28051/head
KumaJie 2023-10-31 22:47:23 +08:00 committed by GitHub
parent 0677d2623d
commit e88212ba4b
10 changed files with 2115 additions and 13 deletions


@@ -399,18 +399,18 @@ func (m *importManager) isRowbased(files []string) (bool, error) {
isRowBased := false
for _, filePath := range files {
_, fileType := importutil.GetFileNameAndExt(filePath)
-if fileType == importutil.JSONFileExt {
+if fileType == importutil.JSONFileExt || fileType == importutil.CSVFileExt {
isRowBased = true
} else if isRowBased {
-log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files))
-return isRowBased, fmt.Errorf("row-based data file type must be JSON, file type '%s' is not allowed", fileType)
+log.Error("row-based data file type must be JSON or CSV, mixed file types is not allowed", zap.Strings("files", files))
+return isRowBased, fmt.Errorf("row-based data file type must be JSON or CSV, file type '%s' is not allowed", fileType)
}
}
// for row_based, we only allow one file so that each invocation only generate a task
if isRowBased && len(files) > 1 {
log.Error("row-based import, only allow one JSON file each time", zap.Strings("files", files))
return isRowBased, fmt.Errorf("row-based import, only allow one JSON file each time")
log.Error("row-based import, only allow one JSON or CSV file each time", zap.Strings("files", files))
return isRowBased, fmt.Errorf("row-based import, only allow one JSON or CSV file each time")
}
return isRowBased, nil


@@ -1101,6 +1101,26 @@ func TestImportManager_isRowbased(t *testing.T) {
rb, err = mgr.isRowbased(files)
assert.NoError(t, err)
assert.False(t, rb)
files = []string{"1.csv"}
rb, err = mgr.isRowbased(files)
assert.NoError(t, err)
assert.True(t, rb)
files = []string{"1.csv", "2.csv"}
rb, err = mgr.isRowbased(files)
assert.Error(t, err)
assert.True(t, rb)
files = []string{"1.csv", "2.json"}
rb, err = mgr.isRowbased(files)
assert.Error(t, err)
assert.True(t, rb)
files = []string{"1.csv", "2.npy"}
rb, err = mgr.isRowbased(files)
assert.Error(t, err)
assert.True(t, rb)
}
func TestImportManager_mergeArray(t *testing.T) {


@@ -0,0 +1,446 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importutil
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type CSVRowHandler interface {
Handle(row []map[storage.FieldID]string) error
}
// CSVRowConsumer is row-based csv format consumer class
type CSVRowConsumer struct {
ctx context.Context // for canceling parse process
collectionInfo *CollectionInfo // collection details including schema
rowIDAllocator *allocator.IDAllocator // autoid allocator
validators map[storage.FieldID]*CSVValidator // validators for each field
rowCounter int64 // how many rows have been consumed
shardsData []ShardData // in-memory shards data
blockSize int64 // maximum size of a read block(unit:byte)
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
callFlushFunc ImportFlushFunc // call back function to flush segment
}
func NewCSVRowConsumer(ctx context.Context,
collectionInfo *CollectionInfo,
idAlloc *allocator.IDAllocator,
blockSize int64,
flushFunc ImportFlushFunc,
) (*CSVRowConsumer, error) {
if collectionInfo == nil {
log.Warn("CSV row consumer: collection schema is nil")
return nil, errors.New("collection schema is nil")
}
v := &CSVRowConsumer{
ctx: ctx,
collectionInfo: collectionInfo,
rowIDAllocator: idAlloc,
validators: make(map[storage.FieldID]*CSVValidator, 0),
rowCounter: 0,
shardsData: make([]ShardData, 0, collectionInfo.ShardNum),
blockSize: blockSize,
autoIDRange: make([]int64, 0),
callFlushFunc: flushFunc,
}
if err := v.initValidators(collectionInfo.Schema); err != nil {
log.Warn("CSV row consumer: fail to initialize csv row-based consumer", zap.Error(err))
return nil, fmt.Errorf("fail to initialize csv row-based consumer, error: %w", err)
}
for i := 0; i < int(collectionInfo.ShardNum); i++ {
shardData := initShardData(collectionInfo.Schema, collectionInfo.PartitionIDs)
if shardData == nil {
log.Warn("CSV row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i))
return nil, fmt.Errorf("fail to initialize in-memory segment data for shard id %d", i)
}
v.shardsData = append(v.shardsData, shardData)
}
// primary key is autoid, id generator is required
if v.collectionInfo.PrimaryKey.GetAutoID() && idAlloc == nil {
log.Warn("CSV row consumer: ID allocator is nil")
return nil, errors.New("ID allocator is nil")
}
return v, nil
}
type CSVValidator struct {
convertFunc func(val string, field storage.FieldData) error // convert data function
isString bool // for string field
fieldName string // field name
}
func (v *CSVRowConsumer) initValidators(collectionSchema *schemapb.CollectionSchema) error {
if collectionSchema == nil {
return errors.New("collection schema is nil")
}
validators := v.validators
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
validators[schema.GetFieldID()] = &CSVValidator{}
validators[schema.GetFieldID()].fieldName = schema.GetName()
validators[schema.GetFieldID()].isString = false
switch schema.DataType {
// every CSV cell arrives as a string
case schemapb.DataType_Bool:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
var value bool
if err := json.Unmarshal([]byte(str), &value); err != nil {
return fmt.Errorf("illegal value '%v' for bool type field '%s'", str, schema.GetName())
}
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
return nil
}
case schemapb.DataType_Float:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
value, err := parseFloat(str, 32, schema.GetName())
if err != nil {
return err
}
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, float32(value))
return nil
}
case schemapb.DataType_Double:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
value, err := parseFloat(str, 64, schema.GetName())
if err != nil {
return err
}
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
return nil
}
case schemapb.DataType_Int8:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
value, err := strconv.ParseInt(str, 0, 8)
if err != nil {
return fmt.Errorf("failed to parse value '%v' for int8 field '%s', error: %w", str, schema.GetName(), err)
}
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, int8(value))
return nil
}
case schemapb.DataType_Int16:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
value, err := strconv.ParseInt(str, 0, 16)
if err != nil {
return fmt.Errorf("failed to parse value '%v' for int16 field '%s', error: %w", str, schema.GetName(), err)
}
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, int16(value))
return nil
}
case schemapb.DataType_Int32:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
value, err := strconv.ParseInt(str, 0, 32)
if err != nil {
return fmt.Errorf("failed to parse value '%v' for int32 field '%s', error: %w", str, schema.GetName(), err)
}
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, int32(value))
return nil
}
case schemapb.DataType_Int64:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
value, err := strconv.ParseInt(str, 0, 64)
if err != nil {
return fmt.Errorf("failed to parse value '%v' for int64 field '%s', error: %w", str, schema.GetName(), err)
}
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
return nil
}
case schemapb.DataType_BinaryVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
var arr []interface{}
desc := json.NewDecoder(strings.NewReader(str))
desc.UseNumber()
if err := desc.Decode(&arr); err != nil {
return fmt.Errorf("'%v' is not an array for binary vector field '%s'", str, schema.GetName())
}
// we use uint8 to represent binary vector in csv file, each uint8 value represents 8 dimensions.
if len(arr)*8 != dim {
return fmt.Errorf("bit size %d doesn't equal to vector dimension %d of field '%s'", len(arr)*8, dim, schema.GetName())
}
for i := 0; i < len(arr); i++ {
if num, ok := arr[i].(json.Number); ok {
value, err := strconv.ParseUint(string(num), 0, 8)
if err != nil {
return fmt.Errorf("failed to parse value '%v' for binary vector field '%s', error: %w", num, schema.GetName(), err)
}
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, byte(value))
} else {
return fmt.Errorf("illegal value '%v' for binary vector field '%s'", str, schema.GetName())
}
}
return nil
}
case schemapb.DataType_FloatVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
var arr []interface{}
desc := json.NewDecoder(strings.NewReader(str))
desc.UseNumber()
if err := desc.Decode(&arr); err != nil {
return fmt.Errorf("'%v' is not an array for float vector field '%s'", str, schema.GetName())
}
if len(arr) != dim {
return fmt.Errorf("array size %d doesn't equal to vector dimension %d of field '%s'", len(arr), dim, schema.GetName())
}
for i := 0; i < len(arr); i++ {
if num, ok := arr[i].(json.Number); ok {
value, err := parseFloat(string(num), 32, schema.GetName())
if err != nil {
return err
}
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, float32(value))
} else {
return fmt.Errorf("illegal value '%v' for float vector field '%s'", str, schema.GetName())
}
}
return nil
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
validators[schema.GetFieldID()].isString = true
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, str)
return nil
}
case schemapb.DataType_JSON:
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
var dummy interface{}
if err := json.Unmarshal([]byte(str), &dummy); err != nil {
return fmt.Errorf("failed to parse value '%v' for JSON field '%s', error: %w", str, schema.GetName(), err)
}
field.(*storage.JSONFieldData).Data = append(field.(*storage.JSONFieldData).Data, []byte(str))
return nil
}
default:
return fmt.Errorf("unsupport data type: %s", getTypeName(collectionSchema.Fields[i].DataType))
}
}
return nil
}
func (v *CSVRowConsumer) IDRange() []int64 {
return v.autoIDRange
}
func (v *CSVRowConsumer) RowCount() int64 {
return v.rowCounter
}
func (v *CSVRowConsumer) Handle(rows []map[storage.FieldID]string) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
log.Warn("CSV row consumer is not initialized")
return errors.New("CSV row consumer is not initialized")
}
// if rows is nil, that means read to end of file, force flush all data
if rows == nil {
err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, true)
log.Info("CSV row consumer finished")
return err
}
// rows is not nil, flush if necessary:
// 1. data block size larger than v.blockSize will be flushed
// 2. total data size exceeds MaxTotalSizeInMemory, the largest data block will be flushed
err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, false)
if err != nil {
log.Warn("CSV row consumer: try flush data but failed", zap.Error(err))
return fmt.Errorf("try flush data but failed, error: %w", err)
}
// prepare autoid, no matter int64 or varchar pk, we always generate autoid since the hidden field RowIDField requires them
primaryKeyID := v.collectionInfo.PrimaryKey.FieldID
primaryValidator := v.validators[primaryKeyID]
var rowIDBegin typeutil.UniqueID
var rowIDEnd typeutil.UniqueID
if v.collectionInfo.PrimaryKey.AutoID {
if v.rowIDAllocator == nil {
log.Warn("CSV row consumer: primary keys is auto-generated but IDAllocator is nil")
return fmt.Errorf("primary keys is auto-generated but IDAllocator is nil")
}
var err error
rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
if err != nil {
log.Warn("CSV row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err))
return fmt.Errorf("failed to generate %d primary keys, error: %w", len(rows), err)
}
if rowIDEnd-rowIDBegin != int64(len(rows)) {
log.Warn("CSV row consumer: try to generate primary keys but allocated ids are not enough",
zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin))
return fmt.Errorf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin)
}
log.Info("CSV row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd))
if primaryValidator.isString {
// string type primary key cannot be auto-generated
log.Warn("CSV row consumer: string type primary key cannot be auto-generated")
return errors.New("string type primary key cannot be auto-generated")
}
v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
}
// consume rows
for i := 0; i < len(rows); i++ {
row := rows[i]
rowNumber := v.rowCounter + int64(i)
// hash to a shard number
var shardID uint32
var partitionID int64
if primaryValidator.isString {
pk := row[primaryKeyID]
// hash to shard based on pk, hash to partition if partition key exist
hash := typeutil.HashString2Uint32(pk)
shardID = hash % uint32(v.collectionInfo.ShardNum)
partitionID, err = v.hashToPartition(row, rowNumber)
if err != nil {
return err
}
pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.StringFieldData)
pkArray.Data = append(pkArray.Data, pk)
} else {
var pk int64
if v.collectionInfo.PrimaryKey.AutoID {
pk = rowIDBegin + int64(i)
} else {
pkStr := row[primaryKeyID]
pk, err = strconv.ParseInt(pkStr, 10, 64)
if err != nil {
log.Warn("CSV row consumer: failed to parse primary key at the row",
zap.String("value", pkStr), zap.Int64("rowNumber", rowNumber), zap.Error(err))
return fmt.Errorf("failed to parse primary key '%s' at the row %d, error: %w",
pkStr, rowNumber, err)
}
}
hash, err := typeutil.Hash32Int64(pk)
if err != nil {
log.Warn("CSV row consumer: failed to hash primary key at the row",
zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err))
return fmt.Errorf("failed to hash primary key %d at the row %d, error: %w", pk, rowNumber, err)
}
// hash to shard based on pk, hash to partition if partition key exist
shardID = hash % uint32(v.collectionInfo.ShardNum)
partitionID, err = v.hashToPartition(row, rowNumber)
if err != nil {
return err
}
pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.Int64FieldData)
pkArray.Data = append(pkArray.Data, pk)
}
rowIDField := v.shardsData[shardID][partitionID][common.RowIDField].(*storage.Int64FieldData)
rowIDField.Data = append(rowIDField.Data, rowIDBegin+int64(i))
for fieldID, validator := range v.validators {
if fieldID == v.collectionInfo.PrimaryKey.GetFieldID() {
continue
}
value := row[fieldID]
if err := validator.convertFunc(value, v.shardsData[shardID][partitionID][fieldID]); err != nil {
log.Warn("CSV row consumer: failed to convert value for field at the row",
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", rowNumber), zap.Error(err))
return fmt.Errorf("failed to convert value for field '%s' at the row %d, error: %w",
validator.fieldName, rowNumber, err)
}
}
}
v.rowCounter += int64(len(rows))
return nil
}
// hashToPartition hashes the partition key to get a partition ID, and returns the first partition ID if no partition key exists
// CollectionInfo ensures there is only one partition ID in PartitionIDs when no partition key exists
func (v *CSVRowConsumer) hashToPartition(row map[storage.FieldID]string, rowNumber int64) (int64, error) {
if v.collectionInfo.PartitionKey == nil {
if len(v.collectionInfo.PartitionIDs) != 1 {
return 0, fmt.Errorf("collection '%s' partition list is empty", v.collectionInfo.Schema.Name)
}
// no partition key, directly return the target partition id
return v.collectionInfo.PartitionIDs[0], nil
}
partitionKeyID := v.collectionInfo.PartitionKey.GetFieldID()
partitionKeyValidator := v.validators[partitionKeyID]
value := row[partitionKeyID]
var hashValue uint32
if partitionKeyValidator.isString {
hashValue = typeutil.HashString2Uint32(value)
} else {
// parse the value from a string
pk, err := strconv.ParseInt(value, 10, 64)
if err != nil {
log.Warn("CSV row consumer: failed to parse partition key at the row",
zap.String("value", value), zap.Int64("rowNumber", rowNumber), zap.Error(err))
return 0, fmt.Errorf("failed to parse partition key '%s' at the row %d, error: %w",
value, rowNumber, err)
}
hashValue, err = typeutil.Hash32Int64(pk)
if err != nil {
log.Warn("CSV row consumer: failed to hash partition key at the row",
zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err))
return 0, fmt.Errorf("failed to hash partition key %d at the row %d, error: %w", pk, rowNumber, err)
}
}
index := int64(hashValue % uint32(len(v.collectionInfo.PartitionIDs)))
return v.collectionInfo.PartitionIDs[index], nil
}
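The following is a minimal sketch (not part of this patch) of how the consumer could be driven directly with string-valued rows, assuming it sits inside package importutil; the schema, field IDs, and the sketchCSVRowConsumerUsage name are illustrative assumptions only. The nil batch at the end mirrors how the parser signals end of file to the handler.

package importutil

import (
	"context"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
)

// sketchCSVRowConsumerUsage is a hypothetical helper for illustration only.
func sketchCSVRowConsumerUsage(ctx context.Context) error {
	// a tiny schema: a manually provided int64 primary key plus one float field
	schema := &schemapb.CollectionSchema{
		Name: "sketch",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 101, Name: "pk", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 102, Name: "FieldFloat", DataType: schemapb.DataType_Float},
		},
	}
	collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
	if err != nil {
		return err
	}
	flushedRows := 0
	flushFunc := func(fields BlockData, shard int, partID int64) error {
		// invoked once per shard/partition block that is ready to be persisted
		flushedRows += fields[102].RowNum()
		return nil
	}
	// the primary key is not auto-generated here, so no ID allocator is needed
	consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, SingleBlockSize, flushFunc)
	if err != nil {
		return err
	}
	// every cell is a string; the per-field validators convert them
	// (vectors are JSON-style arrays, a binary vector packs 8 dimensions per uint8)
	rows := []map[storage.FieldID]string{
		{101: "1", 102: "0.5"},
		{101: "2", 102: "1.5"},
	}
	if err := consumer.Handle(rows); err != nil { // buffered, flushed once blockSize is exceeded
		return err
	}
	return consumer.Handle(nil) // nil marks the end of data and force-flushes the remainder
}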


@@ -0,0 +1,760 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importutil
import (
"context"
"strconv"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
)
func Test_CSVRowConsumerNew(t *testing.T) {
ctx := context.Background()
t.Run("nil schema", func(t *testing.T) {
consumer, err := NewCSVRowConsumer(ctx, nil, nil, 16, nil)
assert.Error(t, err)
assert.Nil(t, consumer)
})
t.Run("wrong schema", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
}
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
assert.NoError(t, err)
schema.Fields[0].DataType = schemapb.DataType_None
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
assert.Error(t, err)
assert.Nil(t, consumer)
})
t.Run("primary key is autoid but no IDAllocator", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
},
}
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
assert.NoError(t, err)
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
assert.Error(t, err)
assert.Nil(t, consumer)
})
t.Run("succeed", func(t *testing.T) {
collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1})
assert.NoError(t, err)
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
assert.NoError(t, err)
assert.NotNil(t, consumer)
})
}
func Test_CSVRowConsumerInitValidators(t *testing.T) {
ctx := context.Background()
consumer := &CSVRowConsumer{
ctx: ctx,
validators: make(map[int64]*CSVValidator),
}
collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1})
assert.NoError(t, err)
schema := collectionInfo.Schema
err = consumer.initValidators(schema)
assert.NoError(t, err)
assert.Equal(t, len(schema.Fields), len(consumer.validators))
for _, field := range schema.Fields {
fieldID := field.GetFieldID()
assert.Equal(t, field.GetName(), consumer.validators[fieldID].fieldName)
if field.GetDataType() != schemapb.DataType_VarChar && field.GetDataType() != schemapb.DataType_String {
assert.False(t, consumer.validators[fieldID].isString)
} else {
assert.True(t, consumer.validators[fieldID].isString)
}
}
name2ID := make(map[string]storage.FieldID)
for _, field := range schema.Fields {
name2ID[field.GetName()] = field.GetFieldID()
}
fields := initBlockData(schema)
assert.NotNil(t, fields)
checkConvertFunc := func(funcName string, validVal string, invalidVal string) {
id := name2ID[funcName]
v, ok := consumer.validators[id]
assert.True(t, ok)
fieldData := fields[id]
preNum := fieldData.RowNum()
err = v.convertFunc(validVal, fieldData)
assert.NoError(t, err)
postNum := fieldData.RowNum()
assert.Equal(t, 1, postNum-preNum)
err = v.convertFunc(invalidVal, fieldData)
assert.Error(t, err)
}
t.Run("check convert functions", func(t *testing.T) {
// all val is string type
validVal := "true"
invalidVal := "5"
checkConvertFunc("FieldBool", validVal, invalidVal)
validVal = "100"
invalidVal = "128"
checkConvertFunc("FieldInt8", validVal, invalidVal)
invalidVal = "65536"
checkConvertFunc("FieldInt16", validVal, invalidVal)
invalidVal = "2147483648"
checkConvertFunc("FieldInt32", validVal, invalidVal)
invalidVal = "1.2"
checkConvertFunc("FieldInt64", validVal, invalidVal)
invalidVal = "dummy"
checkConvertFunc("FieldFloat", validVal, invalidVal)
checkConvertFunc("FieldDouble", validVal, invalidVal)
// json type
validVal = `{"x": 5, "y": true, "z": "hello"}`
checkConvertFunc("FieldJSON", validVal, "a")
checkConvertFunc("FieldJSON", validVal, "{")
// the binary vector dimension is 16, so the input should be two uint8 values, each between 0 and 255
validVal = "[100, 101]"
invalidVal = "[100, 1256]"
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
invalidVal = "false"
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
invalidVal = "[100]"
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
invalidVal = "[100.2, 102.5]"
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
// the float vector dimension is 4, each value should be a valid float number
validVal = "[1,2,3,4]"
invalidVal = `[1,2,3,"dummy"]`
checkConvertFunc("FieldFloatVector", validVal, invalidVal)
invalidVal = "true"
checkConvertFunc("FieldFloatVector", validVal, invalidVal)
invalidVal = `[1]`
checkConvertFunc("FieldFloatVector", validVal, invalidVal)
})
t.Run("init error cases", func(t *testing.T) {
// schema is nil
err := consumer.initValidators(nil)
assert.Error(t, err)
schema = &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: make([]*schemapb.FieldSchema, 0),
}
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 111,
Name: "FieldFloatVector",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "aa"},
},
})
consumer.validators = make(map[int64]*CSVValidator)
err = consumer.initValidators(schema)
assert.Error(t, err)
schema.Fields = make([]*schemapb.FieldSchema, 0)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 110,
Name: "FieldBinaryVector",
IsPrimaryKey: false,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "aa"},
},
})
err = consumer.initValidators(schema)
assert.Error(t, err)
// unsupported data type
schema.Fields = make([]*schemapb.FieldSchema, 0)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 110,
Name: "dummy",
IsPrimaryKey: false,
DataType: schemapb.DataType_None,
})
err = consumer.initValidators(schema)
assert.Error(t, err)
})
t.Run("json field", func(t *testing.T) {
schema = &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 102,
Name: "FieldJSON",
DataType: schemapb.DataType_JSON,
},
},
}
consumer.validators = make(map[int64]*CSVValidator)
err = consumer.initValidators(schema)
assert.NoError(t, err)
v, ok := consumer.validators[102]
assert.True(t, ok)
fields := initBlockData(schema)
assert.NotNil(t, fields)
fieldData := fields[102]
err = v.convertFunc("{\"x\": 1, \"y\": 5}", fieldData)
assert.NoError(t, err)
assert.Equal(t, 1, fieldData.RowNum())
err = v.convertFunc("{}", fieldData)
assert.NoError(t, err)
assert.Equal(t, 2, fieldData.RowNum())
err = v.convertFunc("", fieldData)
assert.Error(t, err)
assert.Equal(t, 2, fieldData.RowNum())
})
}
func Test_CSVRowConsumerHandleIntPK(t *testing.T) {
ctx := context.Background()
t.Run("nil input", func(t *testing.T) {
var consumer *CSVRowConsumer
err := consumer.Handle(nil)
assert.Error(t, err)
})
schema := &schemapb.CollectionSchema{
Name: "schema",
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "FieldInt64",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
FieldID: 102,
Name: "FieldVarchar",
DataType: schemapb.DataType_VarChar,
},
{
FieldID: 103,
Name: "FieldFloat",
DataType: schemapb.DataType_Float,
},
},
}
createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer {
collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs)
assert.NoError(t, err)
idAllocator := newIDAllocator(ctx, t, nil)
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc)
assert.NotNil(t, consumer)
assert.NoError(t, err)
return consumer
}
t.Run("auto pk no partition key", func(t *testing.T) {
flushErrFunc := func(fields BlockData, shard int, partID int64) error {
return errors.New("dummy error")
}
// rows to input
inputRowCount := 100
input := make([]map[storage.FieldID]string, inputRowCount)
for i := 0; i < inputRowCount; i++ {
input[i] = map[storage.FieldID]string{
102: "string",
103: "122.5",
}
}
shardNum := int32(2)
partitionID := int64(1)
consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushErrFunc)
consumer.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error"))
waitFlushRowCount := 10
fieldData := createFieldsData(schema, waitFlushRowCount)
consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID})
// nil input will trigger force flush, flushErrFunc returns error
err := consumer.Handle(nil)
assert.Error(t, err)
// optional flush, flushErrFunc returns error
err = consumer.Handle(input)
assert.Error(t, err)
// reset flushFunc
var callTime int32
var flushedRowCount int
consumer.callFlushFunc = func(fields BlockData, shard int, partID int64) error {
callTime++
assert.Less(t, int32(shard), shardNum)
assert.Equal(t, partitionID, partID)
assert.Greater(t, len(fields), 0)
for _, v := range fields {
assert.Greater(t, v.RowNum(), 0)
}
flushedRowCount += fields[102].RowNum()
return nil
}
// optional flush succeed, each shard has 10 rows, idErrAllocator returns error
err = consumer.Handle(input)
assert.Error(t, err)
assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount)
assert.Equal(t, shardNum, callTime)
// optional flush again, large blockSize, nothing flushed, idAllocator returns error
callTime = int32(0)
flushedRowCount = 0
consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID})
consumer.rowIDAllocator = nil
consumer.blockSize = 8 * 1024 * 1024
err = consumer.Handle(input)
assert.Error(t, err)
assert.Equal(t, 0, flushedRowCount)
assert.Equal(t, int32(0), callTime)
// idAllocator is ok, consume 100 rows, the previous shardsData(10 rows per shard) is flushed
callTime = int32(0)
flushedRowCount = 0
consumer.blockSize = 1
consumer.rowIDAllocator = newIDAllocator(ctx, t, nil)
err = consumer.Handle(input)
assert.NoError(t, err)
assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount)
assert.Equal(t, shardNum, callTime)
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
assert.Equal(t, 2, len(consumer.IDRange()))
assert.Equal(t, int64(1), consumer.IDRange()[0])
assert.Equal(t, int64(1+inputRowCount), consumer.IDRange()[1])
// call handle again, the 100 rows are flushed
callTime = int32(0)
flushedRowCount = 0
err = consumer.Handle(nil)
assert.NoError(t, err)
assert.Equal(t, inputRowCount, flushedRowCount)
assert.Equal(t, shardNum, callTime)
})
schema.Fields[0].AutoID = false
t.Run("manual pk no partition key", func(t *testing.T) {
shardNum := int32(1)
partitionID := int64(100)
var callTime int32
var flushedRowCount int
flushFunc := func(fields BlockData, shard int, partID int64) error {
callTime++
assert.Less(t, int32(shard), shardNum)
assert.Equal(t, partitionID, partID)
assert.Greater(t, len(fields), 0)
flushedRowCount += fields[102].RowNum()
return nil
}
consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc)
// failed to convert pk to int value
input := make([]map[storage.FieldID]string, 1)
input[0] = map[int64]string{
101: "abc",
102: "string",
103: "11.11",
}
err := consumer.Handle(input)
assert.Error(t, err)
// failed to hash to partition
input[0] = map[int64]string{
101: "99",
102: "string",
103: "11.11",
}
consumer.collectionInfo.PartitionIDs = nil
err = consumer.Handle(input)
assert.Error(t, err)
consumer.collectionInfo.PartitionIDs = []int64{partitionID}
// failed to convert value
input[0] = map[int64]string{
101: "99",
102: "string",
103: "abc.11",
}
err = consumer.Handle(input)
assert.Error(t, err)
consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID}) // in-memory data is dirty, reset
// succeed, consume 1 row
input[0] = map[int64]string{
101: "99",
102: "string",
103: "11.11",
}
err = consumer.Handle(input)
assert.NoError(t, err)
assert.Equal(t, int64(1), consumer.RowCount())
assert.Equal(t, 0, len(consumer.IDRange()))
// call handle again, the 1 row is flushed
callTime = int32(0)
flushedRowCount = 0
err = consumer.Handle(nil)
assert.NoError(t, err)
assert.Equal(t, 1, flushedRowCount)
assert.Equal(t, shardNum, callTime)
})
schema.Fields[1].IsPartitionKey = true
t.Run("manual pk with partition key", func(t *testing.T) {
// 10 partitions
partitionIDs := make([]int64, 0)
for i := 0; i < 10; i++ {
partitionIDs = append(partitionIDs, int64(i))
}
shardNum := int32(2)
var flushedRowCount int
flushFunc := func(fields BlockData, shard int, partID int64) error {
assert.Less(t, int32(shard), shardNum)
assert.Contains(t, partitionIDs, partID)
assert.Greater(t, len(fields), 0)
flushedRowCount += fields[102].RowNum()
return nil
}
consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc)
// rows to input
inputRowCount := 100
input := make([]map[storage.FieldID]string, inputRowCount)
for i := 0; i < inputRowCount; i++ {
input[i] = map[int64]string{
101: strconv.Itoa(i),
102: "partitionKey_" + strconv.Itoa(i),
103: "6.18",
}
}
// 100 rows are consumed to different partitions
err := consumer.Handle(input)
assert.NoError(t, err)
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
// call handle again, 100 rows are flushed
flushedRowCount = 0
err = consumer.Handle(nil)
assert.NoError(t, err)
assert.Equal(t, inputRowCount, flushedRowCount)
})
}
func Test_CSVRowConsumerHandleVarcharPK(t *testing.T) {
ctx := context.Background()
schema := &schemapb.CollectionSchema{
Name: "schema",
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "FieldVarchar",
IsPrimaryKey: true,
AutoID: false,
DataType: schemapb.DataType_VarChar,
},
{
FieldID: 102,
Name: "FieldInt64",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 103,
Name: "FieldFloat",
DataType: schemapb.DataType_Float,
},
},
}
createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer {
collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs)
assert.NoError(t, err)
idAllocator := newIDAllocator(ctx, t, nil)
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc)
assert.NotNil(t, consumer)
assert.NoError(t, err)
return consumer
}
t.Run("no partition key", func(t *testing.T) {
shardNum := int32(2)
partitionID := int64(1)
var callTime int32
var flushedRowCount int
flushFunc := func(fields BlockData, shard int, partID int64) error {
callTime++
assert.Less(t, int32(shard), shardNum)
assert.Equal(t, partitionID, partID)
assert.Greater(t, len(fields), 0)
for _, v := range fields {
assert.Greater(t, v.RowNum(), 0)
}
flushedRowCount += fields[102].RowNum()
return nil
}
consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc)
consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID})
// string type primary key cannot be auto-generated
input := make([]map[storage.FieldID]string, 1)
input[0] = map[storage.FieldID]string{
101: "primaryKey_0",
102: "1",
103: "1.252",
}
consumer.collectionInfo.PrimaryKey.AutoID = true
err := consumer.Handle(input)
assert.Error(t, err)
consumer.collectionInfo.PrimaryKey.AutoID = false
// failed to hash to partition
consumer.collectionInfo.PartitionIDs = nil
err = consumer.Handle(input)
assert.Error(t, err)
consumer.collectionInfo.PartitionIDs = []int64{partitionID}
// rows to input
inputRowCount := 100
input = make([]map[storage.FieldID]string, inputRowCount)
for i := 0; i < inputRowCount; i++ {
input[i] = map[int64]string{
101: "primaryKey_" + strconv.Itoa(i),
102: strconv.Itoa(i),
103: "6.18",
}
}
err = consumer.Handle(input)
assert.NoError(t, err)
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
assert.Equal(t, 0, len(consumer.IDRange()))
// call handle again, 100 rows are flushed
err = consumer.Handle(nil)
assert.NoError(t, err)
assert.Equal(t, inputRowCount, flushedRowCount)
assert.Equal(t, shardNum, callTime)
})
schema.Fields[1].IsPartitionKey = true
t.Run("has partition key", func(t *testing.T) {
partitionIDs := make([]int64, 0)
for i := 0; i < 10; i++ {
partitionIDs = append(partitionIDs, int64(i))
}
shardNum := int32(2)
var flushedRowCount int
flushFunc := func(fields BlockData, shard int, partID int64) error {
assert.Less(t, int32(shard), shardNum)
assert.Contains(t, partitionIDs, partID)
assert.Greater(t, len(fields), 0)
flushedRowCount += fields[102].RowNum()
return nil
}
consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc)
// rows to input
inputRowCount := 100
input := make([]map[storage.FieldID]string, inputRowCount)
for i := 0; i < inputRowCount; i++ {
input[i] = map[int64]string{
101: "primaryKey_" + strconv.Itoa(i),
102: strconv.Itoa(i),
103: "6.18",
}
}
err := consumer.Handle(input)
assert.NoError(t, err)
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
assert.Equal(t, 0, len(consumer.IDRange()))
// call handle again, 100 rows are flushed
err = consumer.Handle(nil)
assert.NoError(t, err)
assert.Equal(t, inputRowCount, flushedRowCount)
})
}
func Test_CSVRowConsumerHashToPartition(t *testing.T) {
ctx := context.Background()
schema := &schemapb.CollectionSchema{
Name: "schema",
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "ID",
IsPrimaryKey: true,
AutoID: false,
DataType: schemapb.DataType_Int64,
},
{
FieldID: 101,
Name: "FieldVarchar",
DataType: schemapb.DataType_VarChar,
},
{
FieldID: 102,
Name: "FieldInt64",
DataType: schemapb.DataType_Int64,
},
},
}
partitionID := int64(1)
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{partitionID})
assert.NoError(t, err)
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
assert.NoError(t, err)
assert.NotNil(t, consumer)
input := map[int64]string{
100: "1",
101: "abc",
102: "100",
}
t.Run("no partition key", func(t *testing.T) {
partID, err := consumer.hashToPartition(input, 0)
assert.NoError(t, err)
assert.Equal(t, partitionID, partID)
})
t.Run("partition list is empty", func(t *testing.T) {
collectionInfo.PartitionIDs = []int64{}
partID, err := consumer.hashToPartition(input, 0)
assert.Error(t, err)
assert.Equal(t, int64(0), partID)
collectionInfo.PartitionIDs = []int64{partitionID}
})
schema.Fields[1].IsPartitionKey = true
err = collectionInfo.resetSchema(schema)
assert.NoError(t, err)
collectionInfo.PartitionIDs = []int64{1, 2, 3}
t.Run("varchar partition key", func(t *testing.T) {
input = map[int64]string{
100: "1",
101: "abc",
102: "100",
}
partID, err := consumer.hashToPartition(input, 0)
assert.NoError(t, err)
assert.Contains(t, collectionInfo.PartitionIDs, partID)
})
schema.Fields[1].IsPartitionKey = false
schema.Fields[2].IsPartitionKey = true
err = collectionInfo.resetSchema(schema)
assert.NoError(t, err)
t.Run("int64 partition key", func(t *testing.T) {
input = map[int64]string{
100: "1",
101: "abc",
102: "ab0",
}
// parse int failed
partID, err := consumer.hashToPartition(input, 0)
assert.Error(t, err)
assert.Equal(t, int64(0), partID)
// succeed
input[102] = "100"
partID, err = consumer.hashToPartition(input, 0)
assert.NoError(t, err)
assert.Contains(t, collectionInfo.PartitionIDs, partID)
})
}


@@ -0,0 +1,318 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importutil
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type CSVParser struct {
ctx context.Context // for canceling parse process
collectionInfo *CollectionInfo // collection details including schema
bufRowCount int // max rows in a buffer
fieldsName []string // fieldsName(header name) in the csv file
updateProgressFunc func(percent int64) // update working progress percent value
}
func NewCSVParser(ctx context.Context, collectionInfo *CollectionInfo, updateProgressFunc func(percent int64)) (*CSVParser, error) {
if collectionInfo == nil {
log.Warn("CSV parser: collection schema is nil")
return nil, errors.New("collection schema is nil")
}
parser := &CSVParser{
ctx: ctx,
collectionInfo: collectionInfo,
bufRowCount: 1024,
fieldsName: make([]string, 0),
updateProgressFunc: updateProgressFunc,
}
parser.SetBufSize()
return parser, nil
}
func (p *CSVParser) SetBufSize() {
schema := p.collectionInfo.Schema
sizePerRecord, _ := typeutil.EstimateSizePerRecord(schema)
if sizePerRecord <= 0 {
return
}
bufRowCount := p.bufRowCount
for {
if bufRowCount*sizePerRecord > SingleBlockSize {
bufRowCount--
} else {
break
}
}
if bufRowCount <= 0 {
bufRowCount = 1
}
log.Info("CSV parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount))
p.bufRowCount = bufRowCount
}
func (p *CSVParser) combineDynamicRow(dynamicValues map[string]string, row map[storage.FieldID]string) error {
if p.collectionInfo.DynamicField == nil {
return nil
}
dynamicFieldID := p.collectionInfo.DynamicField.GetFieldID()
// combine the dynamic field value
// valid input:
// input columns: id,vector,x,$meta ==>> output columns: id,vector,$meta
// case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}"
// case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}"
// case3: 1,"[]",,"{""x"": 8}"
// case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}"
// case5: 1,"[]",,
value, ok := row[dynamicFieldID]
// ignore empty string field
if value == "" {
ok = false
}
if len(dynamicValues) > 0 {
mp := make(map[string]interface{})
if ok {
// case 1/2
// $meta is JSON type field, we first convert it to map[string]interface{}
// then merge other dynamic field into it
desc := json.NewDecoder(strings.NewReader(value))
desc.UseNumber()
if err := desc.Decode(&mp); err != nil {
log.Warn("CSV parser: illegal value for dynamic field, not a JSON object")
return errors.New("illegal value for dynamic field, not a JSON object")
}
}
// case 4
for k, v := range dynamicValues {
// ignore empty string field
if v == "" {
continue
}
var value interface{}
desc := json.NewDecoder(strings.NewReader(v))
desc.UseNumber()
if err := desc.Decode(&value); err != nil {
// decoding a plain string such as "abcd" causes an error, so keep the raw value
mp[k] = v
continue
}
if num, ok := value.(json.Number); ok {
// Decode may convert "123ab" to 123, so an additional check is needed
if _, err := strconv.ParseFloat(v, 64); err != nil {
mp[k] = v
} else {
mp[k] = num
}
} else if arr, ok := value.([]interface{}); ok {
mp[k] = arr
} else if obj, ok := value.(map[string]interface{}); ok {
mp[k] = obj
} else if b, ok := value.(bool); ok {
mp[k] = b
}
}
bs, err := json.Marshal(mp)
if err != nil {
log.Warn("CSV parser: illegal value for dynamic field, not a JSON object")
return errors.New("illegal value for dynamic field, not a JSON object")
}
row[dynamicFieldID] = string(bs)
} else if !ok && len(dynamicValues) == 0 {
// case 5
row[dynamicFieldID] = "{}"
}
// else case 3
return nil
}
func (p *CSVParser) verifyRow(raw []string) (map[storage.FieldID]string, error) {
row := make(map[storage.FieldID]string)
dynamicValues := make(map[string]string)
for i := 0; i < len(p.fieldsName); i++ {
fieldName := p.fieldsName[i]
fieldID, ok := p.collectionInfo.Name2FieldID[fieldName]
if fieldID == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() {
// primary key is auto-id, no need to provide
log.Warn("CSV parser: the primary key is auto-generated, no need to provide", zap.String("fieldName", fieldName))
return nil, fmt.Errorf("the primary key '%s' is auto-generated, no need to provide", fieldName)
}
if ok {
row[fieldID] = raw[i]
} else if p.collectionInfo.DynamicField != nil {
// the collection has a dynamic field, put the value into dynamicValues
dynamicValues[fieldName] = raw[i]
} else {
// no dynamic field; if the user provides a redundant field, return an error
log.Warn("CSV parser: the field is not defined in collection schema", zap.String("fieldName", fieldName))
return nil, fmt.Errorf("the field '%s' is not defined in collection schema", fieldName)
}
}
// some fields not provided?
if len(row) != len(p.collectionInfo.Name2FieldID) {
for k, v := range p.collectionInfo.Name2FieldID {
if p.collectionInfo.DynamicField != nil && v == p.collectionInfo.DynamicField.GetFieldID() {
// ignore the dynamic field, the user doesn't have to provide values for it
continue
}
if v == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() {
// ignore the auto-generated primary key
continue
}
_, ok := row[v]
if !ok {
// not an auto-id primary key and not the dynamic field, so a value must be provided
log.Warn("CSV parser: a field value is missing", zap.String("fieldName", k))
return nil, fmt.Errorf("value of field '%s' is missing", k)
}
}
}
// combine the redundant pairs into the dynamic field (if any)
err := p.combineDynamicRow(dynamicValues, row)
if err != nil {
log.Warn("CSV parser: failed to combine dynamic values", zap.Error(err))
return nil, err
}
return row, nil
}
func (p *CSVParser) ParseRows(reader *IOReader, handle CSVRowHandler) error {
if reader == nil || handle == nil {
log.Warn("CSV Parser: CSV parse handle is nil")
return errors.New("CSV parse handle is nil")
}
// discard bom in the file
RuneScanner := reader.r.(io.RuneScanner)
bom, _, err := RuneScanner.ReadRune()
if err == io.EOF {
log.Info("CSV Parser: row count is 0")
return nil
}
if err != nil {
return err
}
if bom != '\ufeff' {
RuneScanner.UnreadRune()
}
r := csv.NewReader(reader.r)
oldPercent := int64(0)
updateProgress := func() {
if p.updateProgressFunc != nil && reader.fileSize > 0 {
percent := (r.InputOffset() * ProgressValueForPersist) / reader.fileSize
if percent > oldPercent { // avoid too many logs
log.Debug("CSV parser: working progress", zap.Int64("offset", r.InputOffset()),
zap.Int64("fileSize", reader.fileSize), zap.Int64("percent", percent))
}
oldPercent = percent
p.updateProgressFunc(percent)
}
}
isEmpty := true
for {
// read the header row (field names)
fieldsName, err := r.Read()
if err == io.EOF {
break
} else if err != nil {
log.Warn("CSV Parser: failed to parse the field value", zap.Error(err))
return fmt.Errorf("failed to read the field value, error: %w", err)
}
p.fieldsName = fieldsName
// read buffer
buf := make([]map[storage.FieldID]string, 0, p.bufRowCount)
for {
// read the row value
values, err := r.Read()
if err == io.EOF {
break
} else if err != nil {
log.Warn("CSV parser: failed to parse row value", zap.Error(err))
return fmt.Errorf("failed to parse row value, error: %w", err)
}
row, err := p.verifyRow(values)
if err != nil {
return err
}
updateProgress()
buf = append(buf, row)
if len(buf) >= p.bufRowCount {
isEmpty = false
if err = handle.Handle(buf); err != nil {
log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err))
return fmt.Errorf("failed to convert row value to entity, error: %w", err)
}
// clean the buffer
buf = make([]map[storage.FieldID]string, 0, p.bufRowCount)
}
}
if len(buf) > 0 {
isEmpty = false
if err = handle.Handle(buf); err != nil {
log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err))
return fmt.Errorf("failed to convert row value to entity, error: %w", err)
}
}
// outside context might be canceled (service stop, or a future enhancement for canceling the import task)
if isCanceled(p.ctx) {
log.Warn("CSV parser: import task was canceled")
return errors.New("import task was canceled")
}
// nolint
// this break means we require the first row to be the header (field names)
break
}
// empty file is allowed, don't return error
if isEmpty {
log.Info("CSV Parser: row count is 0")
return nil
}
updateProgress()
// send nil to notify the handler that all rows have been handled
return handle.Handle(nil)
}
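Below is a hedged end-to-end sketch (not part of this patch) of wiring the parser to a handler, again assuming package importutil; the schema layout, field IDs, column names, and the collectRows type are illustrative assumptions that mirror the dynamic-field cases documented in combineDynamicRow above. Note that IOReader's inner reader must implement io.RuneScanner for the BOM check, which strings.Reader does.

package importutil

import (
	"context"
	"strings"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
)

// collectRows is a hypothetical CSVRowHandler that simply accumulates rows.
type collectRows struct {
	rows []map[storage.FieldID]string
}

func (h *collectRows) Handle(rows []map[storage.FieldID]string) error {
	if rows != nil { // a nil batch means the parser reached the end of the file
		h.rows = append(h.rows, rows...)
	}
	return nil
}

// sketchCSVParserUsage is illustrative only, not part of this change.
func sketchCSVParserUsage(ctx context.Context) error {
	schema := &schemapb.CollectionSchema{
		Name:               "sketch",
		EnableDynamicField: true,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "$meta", IsDynamic: true, DataType: schemapb.DataType_JSON},
		},
	}
	collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
	if err != nil {
		return err
	}
	parser, err := NewCSVParser(ctx, collectionInfo, nil)
	if err != nil {
		return err
	}
	// the first row must be the header; "x" is not a schema field, so verifyRow
	// routes it into dynamicValues and combineDynamicRow merges it into "$meta"
	content := "id,x,$meta\n" +
		"1,8,\"{\"\"y\"\": 9}\"\n" + // merged into {"x": 8, "y": 9}
		"2,7,{}\n" // becomes {"x": 7}
	reader := &IOReader{r: strings.NewReader(content), fileSize: int64(len(content))}
	handler := &collectRows{}
	return parser.ParseRows(reader, handler)
}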


@@ -0,0 +1,414 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importutil
import (
"context"
"strings"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
)
type mockCSVRowConsumer struct {
handleErr error
rows []map[storage.FieldID]string
handleCount int
}
func (v *mockCSVRowConsumer) Handle(rows []map[storage.FieldID]string) error {
if v.handleErr != nil {
return v.handleErr
}
if rows != nil {
v.rows = append(v.rows, rows...)
}
v.handleCount++
return nil
}
func Test_CSVParserAdjustBufSize(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
assert.NoError(t, err)
parser, err := NewCSVParser(ctx, collectionInfo, nil)
assert.NoError(t, err)
assert.NotNil(t, parser)
assert.Greater(t, parser.bufRowCount, 0)
// huge row
schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "32768"},
}
parser, err = NewCSVParser(ctx, collectionInfo, nil)
assert.NoError(t, err)
assert.NotNil(t, parser)
assert.Greater(t, parser.bufRowCount, 0)
}
func Test_CSVParserParseRows_IntPK(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
assert.NoError(t, err)
parser, err := NewCSVParser(ctx, collectionInfo, nil)
assert.NoError(t, err)
assert.NotNil(t, parser)
consumer := &mockCSVRowConsumer{
handleErr: nil,
rows: make([]map[int64]string, 0),
handleCount: 0,
}
reader := strings.NewReader(
`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
t.Run("parse success", func(t *testing.T) {
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.NoError(t, err)
// empty file
reader = strings.NewReader(``)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, consumer)
assert.NoError(t, err)
// only have headers no value row
reader = strings.NewReader(`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.NoError(t, err)
// csv file has a BOM
reader = strings.NewReader(`\ufeffFieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.NoError(t, err)
})
t.Run("error cases", func(t *testing.T) {
// handler is nil
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, nil)
assert.Error(t, err)
// csv parse error: wrong number of fields in the row
reader := strings.NewReader(
`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.Error(t, err)
// redundant field
reader = strings.NewReader(
`dummy,FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
1,true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.Error(t, err)
// field missed
reader = strings.NewReader(
`FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.Error(t, err)
// handle() error
content := `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`
consumer.handleErr = errors.New("error")
reader = strings.NewReader(content)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.Error(t, err)
// canceled
consumer.handleErr = nil
cancel()
reader = strings.NewReader(content)
err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
assert.Error(t, err)
})
}
func Test_CSVParserCombineDynamicRow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
EnableDynamicField: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 106,
Name: "FieldID",
IsPrimaryKey: true,
AutoID: false,
Description: "int64",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 113,
Name: "FieldDynamic",
IsPrimaryKey: false,
IsDynamic: true,
Description: "dynamic field",
DataType: schemapb.DataType_JSON,
},
},
}
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
assert.NoError(t, err)
parser, err := NewCSVParser(ctx, collectionInfo, nil)
assert.NoError(t, err)
assert.NotNil(t, parser)
// valid input:
// input columns: id,vector,x,$meta ==>> output columns: id,vector,$meta
// case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}"
// case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}"
// case3: 1,"[]",,"{""x"": 8}"
// case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}"
// case5: 1,"[]",,
t.Run("value combined for dynamic field", func(t *testing.T) {
dynamicValues := map[string]string{
"x": "88",
}
row := map[storage.FieldID]string{
106: "1",
113: `{"y": 8}`,
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
assert.Contains(t, row[113], "x")
assert.Contains(t, row[113], "y")
row = map[storage.FieldID]string{
106: "1",
113: `{}`,
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
assert.Contains(t, row[113], "x")
})
t.Run("JSON format string/object for dynamic field", func(t *testing.T) {
dynamicValues := map[string]string{}
row := map[storage.FieldID]string{
106: "1",
113: `{"x": 8}`,
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
})
t.Run("dynamic field is hidden", func(t *testing.T) {
dynamicValues := map[string]string{
"x": "8",
}
row := map[storage.FieldID]string{
106: "1",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
assert.Contains(t, row, int64(113))
assert.Contains(t, row[113], "x")
})
t.Run("no values for dynamic field", func(t *testing.T) {
dynamicValues := map[string]string{}
row := map[storage.FieldID]string{
106: "1",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
assert.Equal(t, "{}", row[113])
})
t.Run("empty value for dynamic field", func(t *testing.T) {
dynamicValues := map[string]string{
"x": "",
}
row := map[storage.FieldID]string{
106: "1",
113: `{"y": 8}`,
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
assert.Contains(t, row[113], "y")
assert.NotContains(t, row[113], "x")
row = map[storage.FieldID]string{
106: "1",
113: "",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Equal(t, "{}", row[113])
dynamicValues = map[string]string{
"x": "5",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.Contains(t, row[113], "x")
})
t.Run("invalid input for dynamic field", func(t *testing.T) {
dynamicValues := map[string]string{
"x": "8",
}
row := map[storage.FieldID]string{
106: "1",
113: "5",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.Error(t, err)
row = map[storage.FieldID]string{
106: "1",
113: "abc",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.Error(t, err)
})
t.Run("not allow dynamic values if no dynamic field", func(t *testing.T) {
parser.collectionInfo.DynamicField = nil
dynamicValues := map[string]string{
"x": "8",
}
row := map[storage.FieldID]string{
106: "1",
}
err = parser.combineDynamicRow(dynamicValues, row)
assert.NoError(t, err)
assert.NotContains(t, row, int64(113))
})
}
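A minimal standalone sketch of the merge rule exercised by the test above, assuming a hypothetical helper mergeDynamic (an illustration, not the parser's actual code): extra CSV columns are folded into the JSON object held by the dynamic $meta column, and empty cells are skipped.

package main

import (
	"encoding/json"
	"fmt"
)

// mergeDynamic is a hypothetical helper for illustration only: it folds extra
// CSV columns (name -> raw value) into the JSON object carried by the dynamic
// $meta column, mirroring case1/case2 in the comment above.
func mergeDynamic(meta string, extra map[string]string) (string, error) {
	obj := map[string]interface{}{}
	if meta != "" {
		if err := json.Unmarshal([]byte(meta), &obj); err != nil {
			return "", fmt.Errorf("illegal value for dynamic field: %w", err)
		}
	}
	for k, v := range extra {
		if v == "" {
			continue // empty cells are dropped, matching the "empty value" test case
		}
		obj[k] = v // kept as a string here; the real parser preserves typed values
	}
	out, err := json.Marshal(obj)
	if err != nil {
		return "", err
	}
	return string(out), nil
}

func main() {
	merged, _ := mergeDynamic(`{"y": 8}`, map[string]string{"x": "8"})
	fmt.Println(merged) // {"x":"8","y":8}
}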
func Test_CSVParserVerifyRow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
EnableDynamicField: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 106,
Name: "FieldID",
IsPrimaryKey: true,
AutoID: false,
Description: "int64",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 113,
Name: "FieldDynamic",
IsPrimaryKey: false,
IsDynamic: true,
Description: "dynamic field",
DataType: schemapb.DataType_JSON,
},
},
}
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
assert.NoError(t, err)
parser, err := NewCSVParser(ctx, collectionInfo, nil)
assert.NoError(t, err)
assert.NotNil(t, parser)
t.Run("not auto-id, dynamic field provided", func(t *testing.T) {
parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"}
raw := []string{"1", `{"x": 8}`, "true"}
row, err := parser.verifyRow(raw)
assert.NoError(t, err)
assert.Contains(t, row, int64(106))
assert.Contains(t, row, int64(113))
assert.Contains(t, row[113], "x")
assert.Contains(t, row[113], "y")
})
t.Run("not auto-id, dynamic field not provided", func(t *testing.T) {
parser.fieldsName = []string{"FieldID"}
raw := []string{"1"}
row, err := parser.verifyRow(raw)
assert.NoError(t, err)
assert.Contains(t, row, int64(106))
assert.Contains(t, row, int64(113))
assert.Contains(t, "{}", row[113])
})
t.Run("not auto-id, invalid input dynamic field", func(t *testing.T) {
parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"}
raw := []string{"1", "true", "true"}
_, err = parser.verifyRow(raw)
assert.Error(t, err)
})
schema.Fields[0].AutoID = true
err = collectionInfo.resetSchema(schema)
assert.NoError(t, err)
t.Run("no need to provide value for auto-id", func(t *testing.T) {
parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"}
raw := []string{"1", `{"x": 8}`, "true"}
_, err := parser.verifyRow(raw)
assert.Error(t, err)
parser.fieldsName = []string{"FieldDynamic", "y"}
raw = []string{`{"x": 8}`, "true"}
row, err := parser.verifyRow(raw)
assert.NoError(t, err)
assert.Contains(t, row, int64(113))
})
schema.Fields[1].IsDynamic = false
err = collectionInfo.resetSchema(schema)
assert.NoError(t, err)
t.Run("auto id, no dynamic field", func(t *testing.T) {
parser.fieldsName = []string{"FieldDynamic", "y"}
raw := []string{`{"x": 8}`, "true"}
_, err := parser.verifyRow(raw)
assert.Error(t, err)
// miss FieldDynamic
parser.fieldsName = []string{}
raw = []string{}
_, err = parser.verifyRow(raw)
assert.Error(t, err)
})
}
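The verifyRow cases above split the header into known schema columns and dynamic columns, and reject a column for an AutoID primary key. A rough sketch of that split under assumed behavior (splitHeader and its parameters are hypothetical names, not the parser's API):

package main

import (
	"fmt"
)

// splitHeader is a hypothetical illustration of the header check exercised above:
// known schema columns map to their field IDs, unknown columns become dynamic
// values, and a column for an AutoID primary key is rejected.
func splitHeader(header []string, schemaFields map[string]int64, autoIDField string) (map[string]int64, []string, error) {
	known := map[string]int64{}
	dynamic := []string{}
	for _, name := range header {
		if name == autoIDField {
			return nil, nil, fmt.Errorf("column '%s' is auto-generated and must not be provided", name)
		}
		if id, ok := schemaFields[name]; ok {
			known[name] = id
		} else {
			dynamic = append(dynamic, name)
		}
	}
	return known, dynamic, nil
}

func main() {
	schema := map[string]int64{"FieldID": 106, "FieldDynamic": 113}
	known, dynamic, err := splitHeader([]string{"FieldDynamic", "y"}, schema, "FieldID")
	fmt.Println(known, dynamic, err) // map[FieldDynamic:113] [y] <nil>
}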

View File

@ -464,7 +464,7 @@ func fillDynamicData(blockData BlockData, collectionSchema *schemapb.CollectionS
// tryFlushBlocks does the two things:
// 1. if accumulate data of a block exceed blockSize, call callFlushFunc to generate new binlog file
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest block
// 2. if total accumulate data exceed maxTotalSize, call callFlushFunc to flush the biggest block
func tryFlushBlocks(ctx context.Context,
shardsData []ShardData,
collectionSchema *schemapb.CollectionSchema,

View File

@ -38,6 +38,7 @@ import (
const (
JSONFileExt = ".json"
NumpyFileExt = ".npy"
CSVFileExt = ".csv"
// supposed size of a single block, to control a binlog file size, the max biglog file size is no more than 2*SingleBlockSize
SingleBlockSize = 16 * 1024 * 1024 // 16MB
@ -177,21 +178,21 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
filePath := filePaths[i]
name, fileType := GetFileNameAndExt(filePath)
// only allow json file or numpy file
if fileType != JSONFileExt && fileType != NumpyFileExt {
// only allow json file, numpy file and csv file
if fileType != JSONFileExt && fileType != NumpyFileExt && fileType != CSVFileExt {
log.Warn("import wrapper: unsupported file type", zap.String("filePath", filePath))
return false, fmt.Errorf("unsupported file type: '%s'", filePath)
}
// we use the first file to determine row-based or column-based
if i == 0 && fileType == JSONFileExt {
if i == 0 && (fileType == JSONFileExt || fileType == CSVFileExt) {
rowBased = true
}
// check file type
// row-based only support json type, column-based only support numpy type
// row-based only support json and csv type, column-based only support numpy type
if rowBased {
if fileType != JSONFileExt {
if fileType != JSONFileExt && fileType != CSVFileExt {
log.Warn("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath))
return rowBased, fmt.Errorf("unsupported file type for row-based mode: '%s'", filePath)
}
@ -269,6 +270,12 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
log.Warn("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath))
return err
}
} else if fileType == CSVFileExt {
err = p.parseRowBasedCSV(filePath, options.OnlyValidate)
if err != nil {
log.Warn("import wrapper: failed to parse row-based csv file", zap.Error(err), zap.String("filePath", filePath))
return err
}
} // no need to check else, since the fileValidation() already do this
// trigger gc after each file finished
@ -450,6 +457,54 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
return nil
}
func (p *ImportWrapper) parseRowBasedCSV(filePath string, onlyValidate bool) error {
tr := timerecord.NewTimeRecorder("csv row-based parser: " + filePath)
file, err := p.chunkManager.Reader(p.ctx, filePath)
if err != nil {
return err
}
defer file.Close()
size, err := p.chunkManager.Size(p.ctx, filePath)
if err != nil {
return err
}
// csv parser
reader := bufio.NewReader(file)
parser, err := NewCSVParser(p.ctx, p.collectionInfo, p.updateProgressPercent)
if err != nil {
return err
}
// if only validate, we pass an empty flushFunc so that the consumer does nothing but validation.
var flushFunc ImportFlushFunc
if onlyValidate {
flushFunc = func(fields BlockData, shardID int, partitionID int64) error {
return nil
}
} else {
flushFunc = func(fields BlockData, shardID int, partitionID int64) error {
filePaths := []string{filePath}
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths)
return p.flushFunc(fields, shardID, partitionID)
}
}
consumer, err := NewCSVRowConsumer(p.ctx, p.collectionInfo, p.rowIDAllocator, SingleBlockSize, flushFunc)
if err != nil {
return err
}
err = parser.ParseRows(&IOReader{r: reader, fileSize: size}, consumer)
if err != nil {
return err
}
p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)
tr.Elapse("parsed")
return nil
}
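As a rough sketch of the row-based flow that parseRowBasedCSV delegates to NewCSVParser/ParseRows (a simplified stand-in under assumed behavior, not the parser's implementation): read the header record, map each column to a field name, then stream data rows to a consumer.

package main

import (
	"encoding/csv"
	"fmt"
	"io"
	"strings"
)

// parseCSVRows is a simplified stand-in for the row-based CSV path: the first
// record is treated as the header, and every following record is turned into a
// fieldName -> rawValue map and handed to consume.
func parseCSVRows(r io.Reader, consume func(row map[string]string) error) error {
	reader := csv.NewReader(r)
	header, err := reader.Read()
	if err != nil {
		return fmt.Errorf("failed to read CSV header: %w", err)
	}
	for {
		record, err := reader.Read()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		row := make(map[string]string, len(header))
		for i, name := range header {
			row[name] = record[i]
		}
		if err := consume(row); err != nil {
			return err
		}
	}
}

func main() {
	data := "id,vector,x\n1,\"[0.1,0.2]\",8\n"
	_ = parseCSVRows(strings.NewReader(data), func(row map[string]string) error {
		fmt.Println(row)
		return nil
	})
}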
// flushFunc is the callback function for parsers generate segment and save binlog files
func (p *ImportWrapper) flushFunc(fields BlockData, shardID int, partitionID int64) error {
logFields := []zap.Field{

View File

@ -326,6 +326,93 @@ func Test_ImportWrapperRowBased(t *testing.T) {
})
}
func Test_ImportWrapperRowBased_CSV(t *testing.T) {
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.NoError(t, err)
defer os.RemoveAll(TempFilesPath)
paramtable.Init()
// NewDefaultFactory() use "/tmp/milvus" as default root path, and cannot specify root path
// NewChunkManagerFactory() can specify the root path
f := storage.NewChunkManagerFactory("local", storage.RootPath(TempFilesPath))
ctx := context.Background()
cm, err := f.NewPersistentStorageChunkManager(ctx)
assert.NoError(t, err)
defer cm.RemoveWithPrefix(ctx, cm.RootPath())
idAllocator := newIDAllocator(ctx, t, nil)
content := []byte(
`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"
false,11,102,1002,10002,3.15,1.57,No.1,"{""x"": 1}","[201,0]","[0.1,0.2,0.3,0.4]"
true,12,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`)
filePath := TempFilesPath + "rows_1.csv"
err = cm.Write(ctx, filePath, content)
assert.NoError(t, err)
rowCounter := &rowCounterTest{}
assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter)
importResult := &rootcoordpb.ImportResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
TaskId: 1,
DatanodeId: 1,
State: commonpb.ImportState_ImportStarted,
Segments: make([]int64, 0),
AutoIds: make([]int64, 0),
RowCount: 0,
}
reportFunc := func(res *rootcoordpb.ImportResult) error {
return nil
}
collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1})
assert.NoError(t, err)
t.Run("success case", func(t *testing.T) {
wrapper := NewImportWrapper(ctx, collectionInfo, 1, idAllocator, cm, importResult, reportFunc)
wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, ImportOptions{OnlyValidate: true})
assert.NoError(t, err)
assert.Equal(t, 0, rowCounter.rowCount)
err = wrapper.Import(files, DefaultImportOptions())
assert.NoError(t, err)
assert.Equal(t, 3, rowCounter.rowCount)
assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State)
})
t.Run("parse error", func(t *testing.T) {
content := []byte(
`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,false,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`)
filePath = TempFilesPath + "rows_2.csv"
err = cm.Write(ctx, filePath, content)
assert.NoError(t, err)
importResult.State = commonpb.ImportState_ImportStarted
wrapper := NewImportWrapper(ctx, collectionInfo, 1, idAllocator, cm, importResult, reportFunc)
wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, ImportOptions{OnlyValidate: true})
assert.Error(t, err)
assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State)
})
t.Run("file doesn't exist", func(t *testing.T) {
files := make([]string, 0)
files = append(files, "/dummy/dummy.csv")
wrapper := NewImportWrapper(ctx, collectionInfo, 1, idAllocator, cm, importResult, reportFunc)
err = wrapper.Import(files, ImportOptions{OnlyValidate: true})
assert.Error(t, err)
})
}
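The CSV content in the success case above relies on standard RFC 4180 quoting: JSON and vector cells contain commas, so they are wrapped in double quotes with inner quotes doubled. A small illustration (not part of the import path) of how encoding/csv produces such a row:

package main

import (
	"encoding/csv"
	"os"
)

func main() {
	w := csv.NewWriter(os.Stdout)
	// encoding/csv adds the surrounding quotes and doubles inner quotes
	// automatically, producing cells like "{""x"": 0}" and "[0.1,0.2,0.3,0.4]".
	_ = w.Write([]string{"FieldBool", "FieldJSON", "FieldFloatVector"})
	_ = w.Write([]string{"true", `{"x": 0}`, "[0.1,0.2,0.3,0.4]"})
	w.Flush()
}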
func Test_ImportWrapperColumnBased_numpy(t *testing.T) {
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.NoError(t, err)

View File

@ -110,7 +110,9 @@ func (p *JSONParser) combineDynamicRow(dynamicValues map[string]interface{}, row
if value, is := obj.(string); is {
// case 1
mp := make(map[string]interface{})
err := json.Unmarshal([]byte(value), &mp)
desc := json.NewDecoder(strings.NewReader(value))
desc.UseNumber()
err := desc.Decode(&mp)
if err != nil {
// invalid input
return errors.New("illegal value for dynamic field, not a JSON format string")
@ -192,7 +194,7 @@ func (p *JSONParser) verifyRow(raw interface{}) (map[storage.FieldID]interface{}
}
}
// combine the redundant pairs into dunamic field(if has)
// combine the redundant pairs into dynamic field(if has)
err := p.combineDynamicRow(dynamicValues, row)
if err != nil {
log.Warn("JSON parser: failed to combine dynamic values", zap.Error(err))