mirror of https://github.com/milvus-io/milvus.git

Add CSV file import function (#27149)

Signed-off-by: kuma <675613722@qq.com>
Co-authored-by: kuma <675613722@qq.com>

branch: pull/28051/head
parent: 0677d2623d
commit: e88212ba4b
@@ -399,18 +399,18 @@ func (m *importManager) isRowbased(files []string) (bool, error) {
 	isRowBased := false
 	for _, filePath := range files {
 		_, fileType := importutil.GetFileNameAndExt(filePath)
-		if fileType == importutil.JSONFileExt {
+		if fileType == importutil.JSONFileExt || fileType == importutil.CSVFileExt {
 			isRowBased = true
 		} else if isRowBased {
-			log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files))
-			return isRowBased, fmt.Errorf("row-based data file type must be JSON, file type '%s' is not allowed", fileType)
+			log.Error("row-based data file type must be JSON or CSV, mixed file types is not allowed", zap.Strings("files", files))
+			return isRowBased, fmt.Errorf("row-based data file type must be JSON or CSV, file type '%s' is not allowed", fileType)
 		}
 	}

 	// for row_based, we only allow one file so that each invocation only generate a task
 	if isRowBased && len(files) > 1 {
-		log.Error("row-based import, only allow one JSON file each time", zap.Strings("files", files))
-		return isRowBased, fmt.Errorf("row-based import, only allow one JSON file each time")
+		log.Error("row-based import, only allow one JSON or CSV file each time", zap.Strings("files", files))
+		return isRowBased, fmt.Errorf("row-based import, only allow one JSON or CSV file each time")
 	}

 	return isRowBased, nil
@@ -1101,6 +1101,26 @@ func TestImportManager_isRowbased(t *testing.T) {
 	rb, err = mgr.isRowbased(files)
 	assert.NoError(t, err)
 	assert.False(t, rb)
+
+	files = []string{"1.csv"}
+	rb, err = mgr.isRowbased(files)
+	assert.NoError(t, err)
+	assert.True(t, rb)
+
+	files = []string{"1.csv", "2.csv"}
+	rb, err = mgr.isRowbased(files)
+	assert.Error(t, err)
+	assert.True(t, rb)
+
+	files = []string{"1.csv", "2.json"}
+	rb, err = mgr.isRowbased(files)
+	assert.Error(t, err)
+	assert.True(t, rb)
+
+	files = []string{"1.csv", "2.npy"}
+	rb, err = mgr.isRowbased(files)
+	assert.Error(t, err)
+	assert.True(t, rb)
 }

 func TestImportManager_mergeArray(t *testing.T) {
@@ -0,0 +1,446 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package importutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type CSVRowHandler interface {
|
||||
Handle(row []map[storage.FieldID]string) error
|
||||
}
|
||||
|
||||
// CSVRowConsumer is the row-based CSV format consumer
|
||||
type CSVRowConsumer struct {
|
||||
ctx context.Context // for canceling parse process
|
||||
collectionInfo *CollectionInfo // collection details including schema
|
||||
rowIDAllocator *allocator.IDAllocator // autoid allocator
|
||||
validators map[storage.FieldID]*CSVValidator // validators for each field
|
||||
rowCounter int64 // how many rows have been consumed
|
||||
shardsData []ShardData // in-memory shards data
|
||||
blockSize int64 // maximum size of a read block (unit: byte)
|
||||
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
|
||||
|
||||
callFlushFunc ImportFlushFunc // callback function to flush a segment
|
||||
}
|
||||
|
||||
func NewCSVRowConsumer(ctx context.Context,
|
||||
collectionInfo *CollectionInfo,
|
||||
idAlloc *allocator.IDAllocator,
|
||||
blockSize int64,
|
||||
flushFunc ImportFlushFunc,
|
||||
) (*CSVRowConsumer, error) {
|
||||
if collectionInfo == nil {
|
||||
log.Warn("CSV row consumer: collection schema is nil")
|
||||
return nil, errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
v := &CSVRowConsumer{
|
||||
ctx: ctx,
|
||||
collectionInfo: collectionInfo,
|
||||
rowIDAllocator: idAlloc,
|
||||
validators: make(map[storage.FieldID]*CSVValidator, 0),
|
||||
rowCounter: 0,
|
||||
shardsData: make([]ShardData, 0, collectionInfo.ShardNum),
|
||||
blockSize: blockSize,
|
||||
autoIDRange: make([]int64, 0),
|
||||
callFlushFunc: flushFunc,
|
||||
}
|
||||
|
||||
if err := v.initValidators(collectionInfo.Schema); err != nil {
|
||||
log.Warn("CSV row consumer: fail to initialize csv row-based consumer", zap.Error(err))
|
||||
return nil, fmt.Errorf("fail to initialize csv row-based consumer, error: %w", err)
|
||||
}
|
||||
|
||||
for i := 0; i < int(collectionInfo.ShardNum); i++ {
|
||||
shardData := initShardData(collectionInfo.Schema, collectionInfo.PartitionIDs)
|
||||
if shardData == nil {
|
||||
log.Warn("CSV row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i))
|
||||
return nil, fmt.Errorf("fail to initialize in-memory segment data for shard id %d", i)
|
||||
}
|
||||
v.shardsData = append(v.shardsData, shardData)
|
||||
}
|
||||
|
||||
// if the primary key is auto-id, an ID generator is required
|
||||
if v.collectionInfo.PrimaryKey.GetAutoID() && idAlloc == nil {
|
||||
log.Warn("CSV row consumer: ID allocator is nil")
|
||||
return nil, errors.New("ID allocator is nil")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
type CSVValidator struct {
|
||||
convertFunc func(val string, field storage.FieldData) error // convert data function
|
||||
isString bool // for string field
|
||||
fieldName string // field name
|
||||
}
|
||||
|
||||
func (v *CSVRowConsumer) initValidators(collectionSchema *schemapb.CollectionSchema) error {
|
||||
if collectionSchema == nil {
|
||||
return errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
validators := v.validators
|
||||
|
||||
for i := 0; i < len(collectionSchema.Fields); i++ {
|
||||
schema := collectionSchema.Fields[i]
|
||||
|
||||
validators[schema.GetFieldID()] = &CSVValidator{}
|
||||
validators[schema.GetFieldID()].fieldName = schema.GetName()
|
||||
validators[schema.GetFieldID()].isString = false
|
||||
|
||||
switch schema.DataType {
|
||||
// all CSV values are passed in as strings
|
||||
case schemapb.DataType_Bool:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
var value bool
|
||||
if err := json.Unmarshal([]byte(str), &value); err != nil {
|
||||
return fmt.Errorf("illegal value '%v' for bool type field '%s'", str, schema.GetName())
|
||||
}
|
||||
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
value, err := parseFloat(str, 32, schema.GetName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, float32(value))
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
value, err := parseFloat(str, 64, schema.GetName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
value, err := strconv.ParseInt(str, 0, 8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse value '%v' for int8 field '%s', error: %w", str, schema.GetName(), err)
|
||||
}
|
||||
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, int8(value))
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
value, err := strconv.ParseInt(str, 0, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse value '%v' for int16 field '%s', error: %w", str, schema.GetName(), err)
|
||||
}
|
||||
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, int16(value))
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
value, err := strconv.ParseInt(str, 0, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse value '%v' for int32 field '%s', error: %w", str, schema.GetName(), err)
|
||||
}
|
||||
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, int32(value))
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
value, err := strconv.ParseInt(str, 0, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse value '%v' for int64 field '%s', error: %w", str, schema.GetName(), err)
|
||||
}
|
||||
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
dim, err := getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
var arr []interface{}
|
||||
desc := json.NewDecoder(strings.NewReader(str))
|
||||
desc.UseNumber()
|
||||
if err := desc.Decode(&arr); err != nil {
|
||||
return fmt.Errorf("'%v' is not an array for binary vector field '%s'", str, schema.GetName())
|
||||
}
|
||||
|
||||
// we use uint8 to represent binary vector in csv file, each uint8 value represents 8 dimensions.
|
||||
if len(arr)*8 != dim {
|
||||
return fmt.Errorf("bit size %d doesn't equal to vector dimension %d of field '%s'", len(arr)*8, dim, schema.GetName())
|
||||
}
|
||||
|
||||
for i := 0; i < len(arr); i++ {
|
||||
if num, ok := arr[i].(json.Number); ok {
|
||||
value, err := strconv.ParseUint(string(num), 0, 8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse value '%v' for binary vector field '%s', error: %w", num, schema.GetName(), err)
|
||||
}
|
||||
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, byte(value))
|
||||
} else {
|
||||
return fmt.Errorf("illegal value '%v' for binary vector field '%s'", str, schema.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
dim, err := getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
var arr []interface{}
|
||||
desc := json.NewDecoder(strings.NewReader(str))
|
||||
desc.UseNumber()
|
||||
if err := desc.Decode(&arr); err != nil {
|
||||
return fmt.Errorf("'%v' is not an array for float vector field '%s'", str, schema.GetName())
|
||||
}
|
||||
|
||||
if len(arr) != dim {
|
||||
return fmt.Errorf("array size %d doesn't equal to vector dimension %d of field '%s'", len(arr), dim, schema.GetName())
|
||||
}
|
||||
|
||||
for i := 0; i < len(arr); i++ {
|
||||
if num, ok := arr[i].(json.Number); ok {
|
||||
value, err := parseFloat(string(num), 32, schema.GetName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, float32(value))
|
||||
} else {
|
||||
return fmt.Errorf("illegal value '%v' for float vector field '%s'", str, schema.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
validators[schema.GetFieldID()].isString = true
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, str)
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_JSON:
|
||||
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
|
||||
var dummy interface{}
|
||||
if err := json.Unmarshal([]byte(str), &dummy); err != nil {
|
||||
return fmt.Errorf("failed to parse value '%v' for JSON field '%s', error: %w", str, schema.GetName(), err)
|
||||
}
|
||||
field.(*storage.JSONFieldData).Data = append(field.(*storage.JSONFieldData).Data, []byte(str))
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupport data type: %s", getTypeName(collectionSchema.Fields[i].DataType))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *CSVRowConsumer) IDRange() []int64 {
|
||||
return v.autoIDRange
|
||||
}
|
||||
|
||||
func (v *CSVRowConsumer) RowCount() int64 {
|
||||
return v.rowCounter
|
||||
}
|
||||
|
||||
func (v *CSVRowConsumer) Handle(rows []map[storage.FieldID]string) error {
|
||||
if v == nil || v.validators == nil || len(v.validators) == 0 {
|
||||
log.Warn("CSV row consumer is not initialized")
|
||||
return errors.New("CSV row consumer is not initialized")
|
||||
}
|
||||
// if rows is nil, that means read to end of file, force flush all data
|
||||
if rows == nil {
|
||||
err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, true)
|
||||
log.Info("CSV row consumer finished")
|
||||
return err
|
||||
}
|
||||
|
||||
// rows is not nil, flush if necessary:
|
||||
// 1. data block size larger than v.blockSize will be flushed
|
||||
// 2. total data size exceeds MaxTotalSizeInMemory, the largest data block will be flushed
|
||||
err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, false)
|
||||
if err != nil {
|
||||
log.Warn("CSV row consumer: try flush data but failed", zap.Error(err))
|
||||
return fmt.Errorf("try flush data but failed, error: %w", err)
|
||||
}
|
||||
|
||||
// prepare autoid, no matter int64 or varchar pk, we always generate autoid since the hidden field RowIDField requires them
|
||||
primaryKeyID := v.collectionInfo.PrimaryKey.FieldID
|
||||
primaryValidator := v.validators[primaryKeyID]
|
||||
var rowIDBegin typeutil.UniqueID
|
||||
var rowIDEnd typeutil.UniqueID
|
||||
if v.collectionInfo.PrimaryKey.AutoID {
|
||||
if v.rowIDAllocator == nil {
|
||||
log.Warn("CSV row consumer: primary keys is auto-generated but IDAllocator is nil")
|
||||
return fmt.Errorf("primary keys is auto-generated but IDAllocator is nil")
|
||||
}
|
||||
var err error
|
||||
rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
|
||||
if err != nil {
|
||||
log.Warn("CSV row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err))
|
||||
return fmt.Errorf("failed to generate %d primary keys, error: %w", len(rows), err)
|
||||
}
|
||||
if rowIDEnd-rowIDBegin != int64(len(rows)) {
|
||||
log.Warn("CSV row consumer: try to generate primary keys but allocated ids are not enough",
|
||||
zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin))
|
||||
return fmt.Errorf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin)
|
||||
}
|
||||
log.Info("CSV row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd))
|
||||
if primaryValidator.isString {
|
||||
// string type primary key cannot be auto-generated
log.Warn("CSV row consumer: string type primary key cannot be auto-generated")
return errors.New("string type primary key cannot be auto-generated")
|
||||
}
|
||||
v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
|
||||
}
|
||||
|
||||
// consume rows
|
||||
for i := 0; i < len(rows); i++ {
|
||||
row := rows[i]
|
||||
rowNumber := v.rowCounter + int64(i)
|
||||
|
||||
// hash to a shard number
|
||||
var shardID uint32
|
||||
var partitionID int64
|
||||
if primaryValidator.isString {
|
||||
pk := row[primaryKeyID]
|
||||
|
||||
// hash to shard based on pk, hash to partition if partition key exist
|
||||
hash := typeutil.HashString2Uint32(pk)
|
||||
shardID = hash % uint32(v.collectionInfo.ShardNum)
|
||||
partitionID, err = v.hashToPartition(row, rowNumber)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.StringFieldData)
|
||||
pkArray.Data = append(pkArray.Data, pk)
|
||||
} else {
|
||||
var pk int64
|
||||
if v.collectionInfo.PrimaryKey.AutoID {
|
||||
pk = rowIDBegin + int64(i)
|
||||
} else {
|
||||
pkStr := row[primaryKeyID]
|
||||
pk, err = strconv.ParseInt(pkStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Warn("CSV row consumer: failed to parse primary key at the row",
|
||||
zap.String("value", pkStr), zap.Int64("rowNumber", rowNumber), zap.Error(err))
|
||||
return fmt.Errorf("failed to parse primary key '%s' at the row %d, error: %w",
|
||||
pkStr, rowNumber, err)
|
||||
}
|
||||
}
|
||||
|
||||
hash, err := typeutil.Hash32Int64(pk)
|
||||
if err != nil {
|
||||
log.Warn("CSV row consumer: failed to hash primary key at the row",
|
||||
zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err))
|
||||
return fmt.Errorf("failed to hash primary key %d at the row %d, error: %w", pk, rowNumber, err)
|
||||
}
|
||||
|
||||
// hash to shard based on pk, hash to partition if partition key exist
|
||||
shardID = hash % uint32(v.collectionInfo.ShardNum)
|
||||
partitionID, err = v.hashToPartition(row, rowNumber)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.Int64FieldData)
|
||||
pkArray.Data = append(pkArray.Data, pk)
|
||||
}
|
||||
rowIDField := v.shardsData[shardID][partitionID][common.RowIDField].(*storage.Int64FieldData)
|
||||
rowIDField.Data = append(rowIDField.Data, rowIDBegin+int64(i))
|
||||
|
||||
for fieldID, validator := range v.validators {
|
||||
if fieldID == v.collectionInfo.PrimaryKey.GetFieldID() {
|
||||
continue
|
||||
}
|
||||
|
||||
value := row[fieldID]
|
||||
if err := validator.convertFunc(value, v.shardsData[shardID][partitionID][fieldID]); err != nil {
|
||||
log.Warn("CSV row consumer: failed to convert value for field at the row",
|
||||
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", rowNumber), zap.Error(err))
|
||||
return fmt.Errorf("failed to convert value for field '%s' at the row %d, error: %w",
|
||||
validator.fieldName, rowNumber, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v.rowCounter += int64(len(rows))
|
||||
return nil
|
||||
}
|
||||
|
||||
// hashToPartition hashes the partition key to get a partition ID; it returns the first partition ID if no partition key exists
// CollectionInfo ensures there is only one partition ID in PartitionIDs when no partition key exists
|
||||
func (v *CSVRowConsumer) hashToPartition(row map[storage.FieldID]string, rowNumber int64) (int64, error) {
|
||||
if v.collectionInfo.PartitionKey == nil {
|
||||
if len(v.collectionInfo.PartitionIDs) != 1 {
|
||||
return 0, fmt.Errorf("collection '%s' partition list is empty", v.collectionInfo.Schema.Name)
|
||||
}
|
||||
// no partition key, directly return the target partition id
|
||||
return v.collectionInfo.PartitionIDs[0], nil
|
||||
}
|
||||
|
||||
partitionKeyID := v.collectionInfo.PartitionKey.GetFieldID()
|
||||
partitionKeyValidator := v.validators[partitionKeyID]
|
||||
value := row[partitionKeyID]
|
||||
|
||||
var hashValue uint32
|
||||
if partitionKeyValidator.isString {
|
||||
hashValue = typeutil.HashString2Uint32(value)
|
||||
} else {
|
||||
// parse the value from a string
|
||||
pk, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
log.Warn("CSV row consumer: failed to parse partition key at the row",
|
||||
zap.String("value", value), zap.Int64("rowNumber", rowNumber), zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to parse partition key '%s' at the row %d, error: %w",
|
||||
value, rowNumber, err)
|
||||
}
|
||||
|
||||
hashValue, err = typeutil.Hash32Int64(pk)
|
||||
if err != nil {
|
||||
log.Warn("CSV row consumer: failed to hash partition key at the row",
|
||||
zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to hash partition key %d at the row %d, error: %w", pk, rowNumber, err)
|
||||
}
|
||||
}
|
||||
|
||||
index := int64(hashValue % uint32(len(v.collectionInfo.PartitionIDs)))
|
||||
return v.collectionInfo.PartitionIDs[index], nil
|
||||
}
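The CSVRowHandler contract above is small: ParseRows delivers batches of rows as maps keyed by field ID, and a final nil batch signals end of input so the handler can flush whatever it has buffered. Below is a minimal sketch of a custom handler under that contract; the name countingRowHandler is illustrative and not part of this commit, and the mock consumer in the test file further down plays the same role.

package importutil

import "github.com/milvus-io/milvus/internal/storage"

// countingRowHandler is a hypothetical CSVRowHandler that only counts rows,
// similar in spirit to the mock consumer defined in the tests below.
type countingRowHandler struct {
	rowCount int
	finished bool
}

// Handle receives row batches from the CSV parser; a nil batch marks end of
// input, which is when a real handler (like CSVRowConsumer) flushes its data.
func (h *countingRowHandler) Handle(rows []map[storage.FieldID]string) error {
	if rows == nil {
		h.finished = true
		return nil
	}
	h.rowCount += len(rows)
	return nil
}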
@@ -0,0 +1,760 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package importutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
func Test_CSVRowConsumerNew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("nil schema", func(t *testing.T) {
|
||||
consumer, err := NewCSVRowConsumer(ctx, nil, nil, 16, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
})
|
||||
|
||||
t.Run("wrong schema", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
|
||||
assert.NoError(t, err)
|
||||
|
||||
schema.Fields[0].DataType = schemapb.DataType_None
|
||||
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
})
|
||||
|
||||
t.Run("primary key is autoid but no IDAllocator", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
|
||||
assert.NoError(t, err)
|
||||
|
||||
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
})
|
||||
|
||||
t.Run("succeed", func(t *testing.T) {
|
||||
collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1})
|
||||
assert.NoError(t, err)
|
||||
|
||||
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, consumer)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_CSVRowConsumerInitValidators(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
consumer := &CSVRowConsumer{
|
||||
ctx: ctx,
|
||||
validators: make(map[int64]*CSVValidator),
|
||||
}
|
||||
|
||||
collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1})
|
||||
assert.NoError(t, err)
|
||||
schema := collectionInfo.Schema
|
||||
err = consumer.initValidators(schema)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(schema.Fields), len(consumer.validators))
|
||||
for _, field := range schema.Fields {
|
||||
fieldID := field.GetFieldID()
|
||||
assert.Equal(t, field.GetName(), consumer.validators[fieldID].fieldName)
|
||||
if field.GetDataType() != schemapb.DataType_VarChar && field.GetDataType() != schemapb.DataType_String {
|
||||
assert.False(t, consumer.validators[fieldID].isString)
|
||||
} else {
|
||||
assert.True(t, consumer.validators[fieldID].isString)
|
||||
}
|
||||
}
|
||||
|
||||
name2ID := make(map[string]storage.FieldID)
|
||||
for _, field := range schema.Fields {
|
||||
name2ID[field.GetName()] = field.GetFieldID()
|
||||
}
|
||||
|
||||
fields := initBlockData(schema)
|
||||
assert.NotNil(t, fields)
|
||||
|
||||
checkConvertFunc := func(funcName string, validVal string, invalidVal string) {
|
||||
id := name2ID[funcName]
|
||||
v, ok := consumer.validators[id]
|
||||
assert.True(t, ok)
|
||||
|
||||
fieldData := fields[id]
|
||||
preNum := fieldData.RowNum()
|
||||
err = v.convertFunc(validVal, fieldData)
|
||||
assert.NoError(t, err)
|
||||
postNum := fieldData.RowNum()
|
||||
assert.Equal(t, 1, postNum-preNum)
|
||||
|
||||
err = v.convertFunc(invalidVal, fieldData)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
t.Run("check convert functions", func(t *testing.T) {
|
||||
// all values are string type
|
||||
validVal := "true"
|
||||
invalidVal := "5"
|
||||
checkConvertFunc("FieldBool", validVal, invalidVal)
|
||||
|
||||
validVal = "100"
|
||||
invalidVal = "128"
|
||||
checkConvertFunc("FieldInt8", validVal, invalidVal)
|
||||
|
||||
invalidVal = "65536"
|
||||
checkConvertFunc("FieldInt16", validVal, invalidVal)
|
||||
|
||||
invalidVal = "2147483648"
|
||||
checkConvertFunc("FieldInt32", validVal, invalidVal)
|
||||
|
||||
invalidVal = "1.2"
|
||||
checkConvertFunc("FieldInt64", validVal, invalidVal)
|
||||
|
||||
invalidVal = "dummy"
|
||||
checkConvertFunc("FieldFloat", validVal, invalidVal)
|
||||
checkConvertFunc("FieldDouble", validVal, invalidVal)
|
||||
|
||||
// json type
|
||||
validVal = `{"x": 5, "y": true, "z": "hello"}`
|
||||
checkConvertFunc("FieldJSON", validVal, "a")
|
||||
checkConvertFunc("FieldJSON", validVal, "{")
|
||||
|
||||
// the binary vector dimension is 16, should input two uint8 values, each value should be between 0~255
|
||||
validVal = "[100, 101]"
|
||||
invalidVal = "[100, 1256]"
|
||||
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
|
||||
|
||||
invalidVal = "false"
|
||||
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
|
||||
invalidVal = "[100]"
|
||||
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
|
||||
invalidVal = "[100.2, 102.5]"
|
||||
checkConvertFunc("FieldBinaryVector", validVal, invalidVal)
|
||||
|
||||
// the float vector dimension is 4, each value should be valid float number
|
||||
validVal = "[1,2,3,4]"
|
||||
invalidVal = `[1,2,3,"dummy"]`
|
||||
checkConvertFunc("FieldFloatVector", validVal, invalidVal)
|
||||
invalidVal = "true"
|
||||
checkConvertFunc("FieldFloatVector", validVal, invalidVal)
|
||||
invalidVal = `[1]`
|
||||
checkConvertFunc("FieldFloatVector", validVal, invalidVal)
|
||||
})
|
||||
|
||||
t.Run("init error cases", func(t *testing.T) {
|
||||
// schema is nil
|
||||
err := consumer.initValidators(nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: make([]*schemapb.FieldSchema, 0),
|
||||
}
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 111,
|
||||
Name: "FieldFloatVector",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: common.DimKey, Value: "aa"},
|
||||
},
|
||||
})
|
||||
consumer.validators = make(map[int64]*CSVValidator)
|
||||
err = consumer.initValidators(schema)
|
||||
assert.Error(t, err)
|
||||
|
||||
schema.Fields = make([]*schemapb.FieldSchema, 0)
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 110,
|
||||
Name: "FieldBinaryVector",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: common.DimKey, Value: "aa"},
|
||||
},
|
||||
})
|
||||
|
||||
err = consumer.initValidators(schema)
|
||||
assert.Error(t, err)
|
||||
|
||||
// unsupported data type
|
||||
schema.Fields = make([]*schemapb.FieldSchema, 0)
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 110,
|
||||
Name: "dummy",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_None,
|
||||
})
|
||||
|
||||
err = consumer.initValidators(schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("json field", func(t *testing.T) {
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "FieldJSON",
|
||||
DataType: schemapb.DataType_JSON,
|
||||
},
|
||||
},
|
||||
}
|
||||
consumer.validators = make(map[int64]*CSVValidator)
|
||||
err = consumer.initValidators(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
v, ok := consumer.validators[102]
|
||||
assert.True(t, ok)
|
||||
|
||||
fields := initBlockData(schema)
|
||||
assert.NotNil(t, fields)
|
||||
fieldData := fields[102]
|
||||
|
||||
err = v.convertFunc("{\"x\": 1, \"y\": 5}", fieldData)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, fieldData.RowNum())
|
||||
|
||||
err = v.convertFunc("{}", fieldData)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, fieldData.RowNum())
|
||||
|
||||
err = v.convertFunc("", fieldData)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 2, fieldData.RowNum())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_CSVRowConsumerHandleIntPK(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
var consumer *CSVRowConsumer
|
||||
err := consumer.Handle(nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "FieldInt64",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "FieldVarchar",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "FieldFloat",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
},
|
||||
}
|
||||
createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer {
|
||||
collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs)
|
||||
assert.NoError(t, err)
|
||||
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return consumer
|
||||
}
|
||||
|
||||
t.Run("auto pk no partition key", func(t *testing.T) {
|
||||
flushErrFunc := func(fields BlockData, shard int, partID int64) error {
|
||||
return errors.New("dummy error")
|
||||
}
|
||||
|
||||
// rows to input
|
||||
inputRowCount := 100
|
||||
input := make([]map[storage.FieldID]string, inputRowCount)
|
||||
for i := 0; i < inputRowCount; i++ {
|
||||
input[i] = map[storage.FieldID]string{
|
||||
102: "string",
|
||||
103: "122.5",
|
||||
}
|
||||
}
|
||||
|
||||
shardNum := int32(2)
|
||||
partitionID := int64(1)
|
||||
consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushErrFunc)
|
||||
consumer.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error"))
|
||||
|
||||
waitFlushRowCount := 10
|
||||
fieldData := createFieldsData(schema, waitFlushRowCount)
|
||||
consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID})
|
||||
|
||||
// nil input will trigger force flush, flushErrFunc returns error
|
||||
err := consumer.Handle(nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
// optional flush, flushErrFunc returns error
|
||||
err = consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
|
||||
// reset flushFunc
|
||||
var callTime int32
|
||||
var flushedRowCount int
|
||||
consumer.callFlushFunc = func(fields BlockData, shard int, partID int64) error {
|
||||
callTime++
|
||||
assert.Less(t, int32(shard), shardNum)
|
||||
assert.Equal(t, partitionID, partID)
|
||||
assert.Greater(t, len(fields), 0)
|
||||
for _, v := range fields {
|
||||
assert.Greater(t, v.RowNum(), 0)
|
||||
}
|
||||
flushedRowCount += fields[102].RowNum()
|
||||
return nil
|
||||
}
|
||||
// optional flush succeeds, each shard has 10 rows, the error-returning idAllocator returns error
|
||||
err = consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount)
|
||||
assert.Equal(t, shardNum, callTime)
|
||||
|
||||
// optional flush again, large blockSize, nothing flushed, idAllocator returns error
|
||||
callTime = int32(0)
|
||||
flushedRowCount = 0
|
||||
consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID})
|
||||
consumer.rowIDAllocator = nil
|
||||
consumer.blockSize = 8 * 1024 * 1024
|
||||
err = consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 0, flushedRowCount)
|
||||
assert.Equal(t, int32(0), callTime)
|
||||
|
||||
// idAllocator is ok, consume 100 rows, the previous shardsData(10 rows per shard) is flushed
|
||||
callTime = int32(0)
|
||||
flushedRowCount = 0
|
||||
consumer.blockSize = 1
|
||||
consumer.rowIDAllocator = newIDAllocator(ctx, t, nil)
|
||||
err = consumer.Handle(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount)
|
||||
assert.Equal(t, shardNum, callTime)
|
||||
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
|
||||
assert.Equal(t, 2, len(consumer.IDRange()))
|
||||
assert.Equal(t, int64(1), consumer.IDRange()[0])
|
||||
assert.Equal(t, int64(1+inputRowCount), consumer.IDRange()[1])
|
||||
|
||||
// call handle again, the 100 rows are flushed
|
||||
callTime = int32(0)
|
||||
flushedRowCount = 0
|
||||
err = consumer.Handle(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, inputRowCount, flushedRowCount)
|
||||
assert.Equal(t, shardNum, callTime)
|
||||
})
|
||||
|
||||
schema.Fields[0].AutoID = false
|
||||
|
||||
t.Run("manual pk no partition key", func(t *testing.T) {
|
||||
shardNum := int32(1)
|
||||
partitionID := int64(100)
|
||||
|
||||
var callTime int32
|
||||
var flushedRowCount int
|
||||
flushFunc := func(fields BlockData, shard int, partID int64) error {
|
||||
callTime++
|
||||
assert.Less(t, int32(shard), shardNum)
|
||||
assert.Equal(t, partitionID, partID)
|
||||
assert.Greater(t, len(fields), 0)
|
||||
flushedRowCount += fields[102].RowNum()
|
||||
return nil
|
||||
}
|
||||
|
||||
consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc)
|
||||
|
||||
// failed to convert pk to int value
|
||||
input := make([]map[storage.FieldID]string, 1)
|
||||
input[0] = map[int64]string{
|
||||
101: "abc",
|
||||
102: "string",
|
||||
103: "11.11",
|
||||
}
|
||||
|
||||
err := consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
|
||||
// failed to hash to partition
|
||||
input[0] = map[int64]string{
|
||||
101: "99",
|
||||
102: "string",
|
||||
103: "11.11",
|
||||
}
|
||||
consumer.collectionInfo.PartitionIDs = nil
|
||||
err = consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
consumer.collectionInfo.PartitionIDs = []int64{partitionID}
|
||||
|
||||
// failed to convert value
|
||||
input[0] = map[int64]string{
|
||||
101: "99",
|
||||
102: "string",
|
||||
103: "abc.11",
|
||||
}
|
||||
err = consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID}) // in-memory data is dirty, reset
|
||||
|
||||
// succeed, consume 1 row
|
||||
input[0] = map[int64]string{
|
||||
101: "99",
|
||||
102: "string",
|
||||
103: "11.11",
|
||||
}
|
||||
err = consumer.Handle(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), consumer.RowCount())
|
||||
assert.Equal(t, 0, len(consumer.IDRange()))
|
||||
|
||||
// call handle again, the 1 row is flushed
|
||||
callTime = int32(0)
|
||||
flushedRowCount = 0
|
||||
err = consumer.Handle(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, flushedRowCount)
|
||||
assert.Equal(t, shardNum, callTime)
|
||||
})
|
||||
|
||||
schema.Fields[1].IsPartitionKey = true
|
||||
|
||||
t.Run("manual pk with partition key", func(t *testing.T) {
|
||||
// 10 partitions
|
||||
partitionIDs := make([]int64, 0)
|
||||
for i := 0; i < 10; i++ {
|
||||
partitionIDs = append(partitionIDs, int64(i))
|
||||
}
|
||||
|
||||
shardNum := int32(2)
|
||||
var flushedRowCount int
|
||||
flushFunc := func(fields BlockData, shard int, partID int64) error {
|
||||
assert.Less(t, int32(shard), shardNum)
|
||||
assert.Contains(t, partitionIDs, partID)
|
||||
assert.Greater(t, len(fields), 0)
|
||||
flushedRowCount += fields[102].RowNum()
|
||||
return nil
|
||||
}
|
||||
|
||||
consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc)
|
||||
|
||||
// rows to input
|
||||
inputRowCount := 100
|
||||
input := make([]map[storage.FieldID]string, inputRowCount)
|
||||
for i := 0; i < inputRowCount; i++ {
|
||||
input[i] = map[int64]string{
|
||||
101: strconv.Itoa(i),
|
||||
102: "partitionKey_" + strconv.Itoa(i),
|
||||
103: "6.18",
|
||||
}
|
||||
}
|
||||
|
||||
// 100 rows are consumed to different partitions
|
||||
err := consumer.Handle(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
|
||||
|
||||
// call handle again, 100 rows are flushed
|
||||
flushedRowCount = 0
|
||||
err = consumer.Handle(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, inputRowCount, flushedRowCount)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_CSVRowConsumerHandleVarcharPK(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "FieldVarchar",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "FieldInt64",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "FieldFloat",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer {
|
||||
collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs)
|
||||
assert.NoError(t, err)
|
||||
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return consumer
|
||||
}
|
||||
|
||||
t.Run("no partition key", func(t *testing.T) {
|
||||
shardNum := int32(2)
|
||||
partitionID := int64(1)
|
||||
var callTime int32
|
||||
var flushedRowCount int
|
||||
flushFunc := func(fields BlockData, shard int, partID int64) error {
|
||||
callTime++
|
||||
assert.Less(t, int32(shard), shardNum)
|
||||
assert.Equal(t, partitionID, partID)
|
||||
assert.Greater(t, len(fields), 0)
|
||||
for _, v := range fields {
|
||||
assert.Greater(t, v.RowNum(), 0)
|
||||
}
|
||||
flushedRowCount += fields[102].RowNum()
|
||||
return nil
|
||||
}
|
||||
|
||||
consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc)
|
||||
consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID})
|
||||
|
||||
// string type primary key cannot be auto-generated
|
||||
input := make([]map[storage.FieldID]string, 1)
|
||||
input[0] = map[storage.FieldID]string{
|
||||
101: "primaryKey_0",
|
||||
102: "1",
|
||||
103: "1.252",
|
||||
}
|
||||
|
||||
consumer.collectionInfo.PrimaryKey.AutoID = true
|
||||
err := consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
consumer.collectionInfo.PrimaryKey.AutoID = false
|
||||
|
||||
// failed to hash to partition
|
||||
consumer.collectionInfo.PartitionIDs = nil
|
||||
err = consumer.Handle(input)
|
||||
assert.Error(t, err)
|
||||
consumer.collectionInfo.PartitionIDs = []int64{partitionID}
|
||||
|
||||
// rows to input
|
||||
inputRowCount := 100
|
||||
input = make([]map[storage.FieldID]string, inputRowCount)
|
||||
for i := 0; i < inputRowCount; i++ {
|
||||
input[i] = map[int64]string{
|
||||
101: "primaryKey_" + strconv.Itoa(i),
|
||||
102: strconv.Itoa(i),
|
||||
103: "6.18",
|
||||
}
|
||||
}
|
||||
|
||||
err = consumer.Handle(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
|
||||
assert.Equal(t, 0, len(consumer.IDRange()))
|
||||
|
||||
// call handle again, 100 rows are flushed
|
||||
err = consumer.Handle(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, inputRowCount, flushedRowCount)
|
||||
assert.Equal(t, shardNum, callTime)
|
||||
})
|
||||
|
||||
schema.Fields[1].IsPartitionKey = true
|
||||
t.Run("has partition key", func(t *testing.T) {
|
||||
partitionIDs := make([]int64, 0)
|
||||
for i := 0; i < 10; i++ {
|
||||
partitionIDs = append(partitionIDs, int64(i))
|
||||
}
|
||||
|
||||
shardNum := int32(2)
|
||||
var flushedRowCount int
|
||||
flushFunc := func(fields BlockData, shard int, partID int64) error {
|
||||
assert.Less(t, int32(shard), shardNum)
|
||||
assert.Contains(t, partitionIDs, partID)
|
||||
assert.Greater(t, len(fields), 0)
|
||||
flushedRowCount += fields[102].RowNum()
|
||||
return nil
|
||||
}
|
||||
|
||||
consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc)
|
||||
|
||||
// rows to input
|
||||
inputRowCount := 100
|
||||
input := make([]map[storage.FieldID]string, inputRowCount)
|
||||
for i := 0; i < inputRowCount; i++ {
|
||||
input[i] = map[int64]string{
|
||||
101: "primaryKey_" + strconv.Itoa(i),
|
||||
102: strconv.Itoa(i),
|
||||
103: "6.18",
|
||||
}
|
||||
}
|
||||
|
||||
err := consumer.Handle(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(inputRowCount), consumer.RowCount())
|
||||
assert.Equal(t, 0, len(consumer.IDRange()))
|
||||
|
||||
// call handle again, 100 rows are flushed
|
||||
err = consumer.Handle(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, inputRowCount, flushedRowCount)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_CSVRowConsumerHashToPartition(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "FieldVarchar",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "FieldInt64",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
partitionID := int64(1)
|
||||
collectionInfo, err := NewCollectionInfo(schema, 2, []int64{partitionID})
|
||||
assert.NoError(t, err)
|
||||
consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, consumer)
|
||||
input := map[int64]string{
|
||||
100: "1",
|
||||
101: "abc",
|
||||
102: "100",
|
||||
}
|
||||
t.Run("no partition key", func(t *testing.T) {
|
||||
partID, err := consumer.hashToPartition(input, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionID, partID)
|
||||
})
|
||||
|
||||
t.Run("partition list is empty", func(t *testing.T) {
|
||||
collectionInfo.PartitionIDs = []int64{}
|
||||
partID, err := consumer.hashToPartition(input, 0)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, int64(0), partID)
|
||||
collectionInfo.PartitionIDs = []int64{partitionID}
|
||||
})
|
||||
|
||||
schema.Fields[1].IsPartitionKey = true
|
||||
err = collectionInfo.resetSchema(schema)
|
||||
assert.NoError(t, err)
|
||||
collectionInfo.PartitionIDs = []int64{1, 2, 3}
|
||||
|
||||
t.Run("varchar partition key", func(t *testing.T) {
|
||||
input = map[int64]string{
|
||||
100: "1",
|
||||
101: "abc",
|
||||
102: "100",
|
||||
}
|
||||
|
||||
partID, err := consumer.hashToPartition(input, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, collectionInfo.PartitionIDs, partID)
|
||||
})
|
||||
|
||||
schema.Fields[1].IsPartitionKey = false
|
||||
schema.Fields[2].IsPartitionKey = true
|
||||
err = collectionInfo.resetSchema(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("int64 partition key", func(t *testing.T) {
|
||||
input = map[int64]string{
|
||||
100: "1",
|
||||
101: "abc",
|
||||
102: "ab0",
|
||||
}
|
||||
// parse int failed
|
||||
partID, err := consumer.hashToPartition(input, 0)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, int64(0), partID)
|
||||
|
||||
// succeed
|
||||
input[102] = "100"
|
||||
partID, err = consumer.hashToPartition(input, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, collectionInfo.PartitionIDs, partID)
|
||||
})
|
||||
}
@@ -0,0 +1,318 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package importutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type CSVParser struct {
|
||||
ctx context.Context // for canceling parse process
|
||||
collectionInfo *CollectionInfo // collection details including schema
|
||||
bufRowCount int // max rows in a buffer
|
||||
fieldsName []string // field names (header) in the CSV file
|
||||
updateProgressFunc func(percent int64) // update working progress percent value
|
||||
}
|
||||
|
||||
func NewCSVParser(ctx context.Context, collectionInfo *CollectionInfo, updateProgressFunc func(percent int64)) (*CSVParser, error) {
|
||||
if collectionInfo == nil {
|
||||
log.Warn("CSV parser: collection schema is nil")
|
||||
return nil, errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
parser := &CSVParser{
|
||||
ctx: ctx,
|
||||
collectionInfo: collectionInfo,
|
||||
bufRowCount: 1024,
|
||||
fieldsName: make([]string, 0),
|
||||
updateProgressFunc: updateProgressFunc,
|
||||
}
|
||||
parser.SetBufSize()
|
||||
return parser, nil
|
||||
}
|
||||
|
||||
func (p *CSVParser) SetBufSize() {
|
||||
schema := p.collectionInfo.Schema
|
||||
sizePerRecord, _ := typeutil.EstimateSizePerRecord(schema)
|
||||
if sizePerRecord <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
bufRowCount := p.bufRowCount
|
||||
for {
|
||||
if bufRowCount*sizePerRecord > SingleBlockSize {
|
||||
bufRowCount--
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if bufRowCount <= 0 {
|
||||
bufRowCount = 1
|
||||
}
|
||||
log.Info("CSV parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount))
|
||||
p.bufRowCount = bufRowCount
|
||||
}
|
||||
|
||||
func (p *CSVParser) combineDynamicRow(dynamicValues map[string]string, row map[storage.FieldID]string) error {
|
||||
if p.collectionInfo.DynamicField == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dynamicFieldID := p.collectionInfo.DynamicField.GetFieldID()
|
||||
// combine the dynamic field value
|
||||
// valid input:
|
||||
// id,vector,x,$meta  or  id,vector,$meta
|
||||
// case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}"
|
||||
// case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}"
|
||||
// case3: 1,"[]",,"{""x"": 8}"
|
||||
// case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}"
|
||||
// case5: 1,"[]",,
|
||||
value, ok := row[dynamicFieldID]
|
||||
// ignore empty string field
|
||||
if value == "" {
|
||||
ok = false
|
||||
}
|
||||
if len(dynamicValues) > 0 {
|
||||
mp := make(map[string]interface{})
|
||||
if ok {
|
||||
// case 1/2
|
||||
// $meta is JSON type field, we first convert it to map[string]interface{}
|
||||
// then merge other dynamic field into it
|
||||
desc := json.NewDecoder(strings.NewReader(value))
|
||||
desc.UseNumber()
|
||||
if err := desc.Decode(&mp); err != nil {
|
||||
log.Warn("CSV parser: illegal value for dynamic field, not a JSON object")
|
||||
return errors.New("illegal value for dynamic field, not a JSON object")
|
||||
}
|
||||
}
|
||||
// case 4
|
||||
for k, v := range dynamicValues {
|
||||
// ignore empty string field
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
var value interface{}
|
||||
|
||||
desc := json.NewDecoder(strings.NewReader(v))
|
||||
desc.UseNumber()
|
||||
if err := desc.Decode(&value); err != nil {
|
||||
// decoding a plain string like "abcd" causes an error
|
||||
mp[k] = v
|
||||
continue
|
||||
}
|
||||
|
||||
if num, ok := value.(json.Number); ok {
|
||||
// Decode may convert "123ab" to 123, so an additional check is needed
|
||||
if _, err := strconv.ParseFloat(v, 64); err != nil {
|
||||
mp[k] = v
|
||||
} else {
|
||||
mp[k] = num
|
||||
}
|
||||
} else if arr, ok := value.([]interface{}); ok {
|
||||
mp[k] = arr
|
||||
} else if obj, ok := value.(map[string]interface{}); ok {
|
||||
mp[k] = obj
|
||||
} else if b, ok := value.(bool); ok {
|
||||
mp[k] = b
|
||||
}
|
||||
}
|
||||
bs, err := json.Marshal(mp)
|
||||
if err != nil {
|
||||
log.Warn("CSV parser: illegal value for dynamic field, not a JSON object")
|
||||
return errors.New("illegal value for dynamic field, not a JSON object")
|
||||
}
|
||||
row[dynamicFieldID] = string(bs)
|
||||
} else if !ok && len(dynamicValues) == 0 {
|
||||
// case 5
|
||||
row[dynamicFieldID] = "{}"
|
||||
}
|
||||
// else case 3
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *CSVParser) verifyRow(raw []string) (map[storage.FieldID]string, error) {
|
||||
row := make(map[storage.FieldID]string)
|
||||
dynamicValues := make(map[string]string)
|
||||
|
||||
for i := 0; i < len(p.fieldsName); i++ {
|
||||
fieldName := p.fieldsName[i]
|
||||
fieldID, ok := p.collectionInfo.Name2FieldID[fieldName]
|
||||
|
||||
if fieldID == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() {
|
||||
// primary key is auto-id, no need to provide
|
||||
log.Warn("CSV parser: the primary key is auto-generated, no need to provide", zap.String("fieldName", fieldName))
|
||||
return nil, fmt.Errorf("the primary key '%s' is auto-generated, no need to provide", fieldName)
|
||||
}
|
||||
|
||||
if ok {
|
||||
row[fieldID] = raw[i]
|
||||
} else if p.collectionInfo.DynamicField != nil {
|
||||
// the collection has a dynamic field, put it into dynamicValues
|
||||
dynamicValues[fieldName] = raw[i]
|
||||
} else {
|
||||
// no dynamic field; if the user provided a redundant field, return an error
|
||||
log.Warn("CSV parser: the field is not defined in collection schema", zap.String("fieldName", fieldName))
|
||||
return nil, fmt.Errorf("the field '%s' is not defined in collection schema", fieldName)
|
||||
}
|
||||
}
|
||||
// some fields not provided?
|
||||
if len(row) != len(p.collectionInfo.Name2FieldID) {
|
||||
for k, v := range p.collectionInfo.Name2FieldID {
|
||||
if p.collectionInfo.DynamicField != nil && v == p.collectionInfo.DynamicField.GetFieldID() {
|
||||
// ignore the dynamic field, the user doesn't have to provide values for it
|
||||
continue
|
||||
}
|
||||
|
||||
if v == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() {
|
||||
// ignore auto-generated primary key
|
||||
continue
|
||||
}
|
||||
_, ok := row[v]
|
||||
if !ok {
|
||||
// not auto-id primary key, no dynamic field, must provide value
|
||||
log.Warn("CSV parser: a field value is missed", zap.String("fieldName", k))
|
||||
return nil, fmt.Errorf("value of field '%s' is missed", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
// combine the redundant pairs into dynamic field(if has)
|
||||
err := p.combineDynamicRow(dynamicValues, row)
|
||||
if err != nil {
|
||||
log.Warn("CSV parser: failed to combine dynamic values", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return row, nil
|
||||
}
|
||||
|
||||
func (p *CSVParser) ParseRows(reader *IOReader, handle CSVRowHandler) error {
|
||||
if reader == nil || handle == nil {
|
||||
log.Warn("CSV Parser: CSV parse handle is nil")
|
||||
return errors.New("CSV parse handle is nil")
|
||||
}
|
||||
// discard the BOM at the beginning of the file, if present
|
||||
RuneScanner := reader.r.(io.RuneScanner)
|
||||
bom, _, err := RuneScanner.ReadRune()
|
||||
if err == io.EOF {
|
||||
log.Info("CSV Parser: row count is 0")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bom != '\ufeff' {
|
||||
RuneScanner.UnreadRune()
|
||||
}
|
||||
r := csv.NewReader(reader.r)
|
||||
|
||||
oldPercent := int64(0)
|
||||
updateProgress := func() {
|
||||
if p.updateProgressFunc != nil && reader.fileSize > 0 {
|
||||
percent := (r.InputOffset() * ProgressValueForPersist) / reader.fileSize
|
||||
if percent > oldPercent { // avoid too many logs
|
||||
log.Debug("CSV parser: working progress", zap.Int64("offset", r.InputOffset()),
|
||||
zap.Int64("fileSize", reader.fileSize), zap.Int64("percent", percent))
|
||||
}
|
||||
oldPercent = percent
|
||||
p.updateProgressFunc(percent)
|
||||
}
|
||||
}
|
||||
isEmpty := true
|
||||
for {
|
||||
// read the fields value
|
||||
fieldsName, err := r.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
log.Warn("CSV Parser: failed to parse the field value", zap.Error(err))
|
||||
return fmt.Errorf("failed to read the field value, error: %w", err)
|
||||
}
|
||||
p.fieldsName = fieldsName
|
||||
// read buffer
|
||||
buf := make([]map[storage.FieldID]string, 0, p.bufRowCount)
|
||||
for {
|
||||
// read the row value
|
||||
values, err := r.Read()
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
log.Warn("CSV parser: failed to parse row value", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse row value, error: %w", err)
|
||||
}
|
||||
|
||||
row, err := p.verifyRow(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateProgress()
|
||||
|
||||
buf = append(buf, row)
|
||||
if len(buf) >= p.bufRowCount {
|
||||
isEmpty = false
|
||||
if err = handle.Handle(buf); err != nil {
|
||||
log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err))
|
||||
return fmt.Errorf("failed to convert row value to entity, error: %w", err)
|
||||
}
|
||||
// clean the buffer
|
||||
buf = make([]map[storage.FieldID]string, 0, p.bufRowCount)
|
||||
}
|
||||
}
|
||||
if len(buf) > 0 {
|
||||
isEmpty = false
|
||||
if err = handle.Handle(buf); err != nil {
|
||||
log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err))
|
||||
return fmt.Errorf("failed to convert row value to entity, error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// outside context might be canceled(service stop, or future enhancement for canceling import task)
|
||||
if isCanceled(p.ctx) {
|
||||
log.Warn("CSV parser: import task was canceled")
|
||||
return errors.New("import task was canceled")
|
||||
}
|
||||
// nolint
|
||||
// this break means we require the first row must be fieldsName
|
||||
break
|
||||
}
|
||||
|
||||
// empty file is allowed, don't return error
|
||||
if isEmpty {
|
||||
log.Info("CSV Parser: row count is 0")
|
||||
return nil
|
||||
}
|
||||
|
||||
updateProgress()
|
||||
|
||||
// send nil to notify the handler all have done
|
||||
return handle.Handle(nil)
|
||||
}
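Taken together, ParseRows treats the first CSV record as the header, verifies every following record against the schema, hands rows to the handler in batches of bufRowCount, and finally calls Handle(nil) to signal completion. Below is a minimal sketch of driving it, reusing identifiers that appear in this package and its tests (NewCollectionInfo, NewCSVParser, IOReader, storage.FieldID, sampleSchema); the handler type and helper function are hypothetical, the snippet assumes it lives in the importutil package with the usual context/strings imports, and it is illustrative only, not part of the commit.

// countingHandler is a toy CSVRowHandler: it counts rows until the final nil batch arrives.
type countingHandler struct{ count int }

func (h *countingHandler) Handle(rows []map[storage.FieldID]string) error {
	if rows == nil {
		return nil // nil means all rows have been delivered
	}
	h.count += len(rows)
	return nil
}

// parseSampleCSV is a hypothetical helper showing the call sequence.
func parseSampleCSV(ctx context.Context) (int, error) {
	collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) // sampleSchema comes from the package tests
	if err != nil {
		return 0, err
	}
	parser, err := NewCSVParser(ctx, collectionInfo, nil) // nil: no progress callback
	if err != nil {
		return 0, err
	}
	data := `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`
	handler := &countingHandler{}
	err = parser.ParseRows(&IOReader{r: strings.NewReader(data), fileSize: int64(len(data))}, handler)
	return handler.count, err
}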

@ -0,0 +1,414 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importutil

import (
	"context"
	"strings"
	"testing"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/common"
)

type mockCSVRowConsumer struct {
	handleErr   error
	rows        []map[storage.FieldID]string
	handleCount int
}

func (v *mockCSVRowConsumer) Handle(rows []map[storage.FieldID]string) error {
	if v.handleErr != nil {
		return v.handleErr
	}
	if rows != nil {
		v.rows = append(v.rows, rows...)
	}
	v.handleCount++
	return nil
}

func Test_CSVParserAdjustBufSize(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	schema := sampleSchema()
	collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
	assert.NoError(t, err)
	parser, err := NewCSVParser(ctx, collectionInfo, nil)
	assert.NoError(t, err)
	assert.NotNil(t, parser)
	assert.Greater(t, parser.bufRowCount, 0)
	// huge row
	schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{
		{Key: common.DimKey, Value: "32768"},
	}
	parser, err = NewCSVParser(ctx, collectionInfo, nil)
	assert.NoError(t, err)
	assert.NotNil(t, parser)
	assert.Greater(t, parser.bufRowCount, 0)
}

func Test_CSVParserParseRows_IntPK(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	schema := sampleSchema()
	collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
	assert.NoError(t, err)
	parser, err := NewCSVParser(ctx, collectionInfo, nil)
	assert.NoError(t, err)
	assert.NotNil(t, parser)

	consumer := &mockCSVRowConsumer{
		handleErr:   nil,
		rows:        make([]map[int64]string, 0),
		handleCount: 0,
	}

	reader := strings.NewReader(
		`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)

	t.Run("parse success", func(t *testing.T) {
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.NoError(t, err)

		// empty file
		reader = strings.NewReader(``)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, consumer)
		assert.NoError(t, err)

		// only have headers no value row
		reader = strings.NewReader(`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.NoError(t, err)

		// csv file have bom
		reader = strings.NewReader(`\ufeffFieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.NoError(t, err)
	})

	t.Run("error cases", func(t *testing.T) {
		// handler is nil

		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, nil)
		assert.Error(t, err)

		// csv parse error, fields len error
		reader := strings.NewReader(
			`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.Error(t, err)

		// redundant field
		reader = strings.NewReader(
			`dummy,FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
1,true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.Error(t, err)

		// field missed
		reader = strings.NewReader(
			`FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.Error(t, err)

		// handle() error
		content := `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`
		consumer.handleErr = errors.New("error")
		reader = strings.NewReader(content)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.Error(t, err)

		// canceled
		consumer.handleErr = nil
		cancel()
		reader = strings.NewReader(content)
		err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer)
		assert.Error(t, err)
	})
}

func Test_CSVParserCombineDynamicRow(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	schema := &schemapb.CollectionSchema{
		Name:               "schema",
		Description:        "schema",
		EnableDynamicField: true,
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:      106,
				Name:         "FieldID",
				IsPrimaryKey: true,
				AutoID:       false,
				Description:  "int64",
				DataType:     schemapb.DataType_Int64,
			},
			{
				FieldID:      113,
				Name:         "FieldDynamic",
				IsPrimaryKey: false,
				IsDynamic:    true,
				Description:  "dynamic field",
				DataType:     schemapb.DataType_JSON,
			},
		},
	}

	collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
	assert.NoError(t, err)
	parser, err := NewCSVParser(ctx, collectionInfo, nil)
	assert.NoError(t, err)
	assert.NotNil(t, parser)

	// valid input:
	//         id,vector,x,$meta                id,vector,$meta
	// case1: 1,"[]",8,"{""y"": 8}"    ==>>  1,"[]","{""y"": 8, ""x"": 8}"
	// case2: 1,"[]",8,"{}"            ==>>  1,"[]","{""x"": 8}"
	// case3: 1,"[]",,"{""x"": 8}"
	// case4: 1,"[]",8,                ==>>  1,"[]","{""x"": 8}"
	// case5: 1,"[]",,

	t.Run("value combined for dynamic field", func(t *testing.T) {
		dynamicValues := map[string]string{
			"x": "88",
		}
		row := map[storage.FieldID]string{
			106: "1",
			113: `{"y": 8}`,
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
		assert.Contains(t, row[113], "x")
		assert.Contains(t, row[113], "y")

		row = map[storage.FieldID]string{
			106: "1",
			113: `{}`,
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
		assert.Contains(t, row[113], "x")
	})

	t.Run("JSON format string/object for dynamic field", func(t *testing.T) {
		dynamicValues := map[string]string{}
		row := map[storage.FieldID]string{
			106: "1",
			113: `{"x": 8}`,
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
	})

	t.Run("dynamic field is hidden", func(t *testing.T) {
		dynamicValues := map[string]string{
			"x": "8",
		}
		row := map[storage.FieldID]string{
			106: "1",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
		assert.Contains(t, row, int64(113))
		assert.Contains(t, row[113], "x")
	})

	t.Run("no values for dynamic field", func(t *testing.T) {
		dynamicValues := map[string]string{}
		row := map[storage.FieldID]string{
			106: "1",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
		assert.Equal(t, "{}", row[113])
	})

	t.Run("empty value for dynamic field", func(t *testing.T) {
		dynamicValues := map[string]string{
			"x": "",
		}
		row := map[storage.FieldID]string{
			106: "1",
			113: `{"y": 8}`,
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
		assert.Contains(t, row[113], "y")
		assert.NotContains(t, row[113], "x")

		row = map[storage.FieldID]string{
			106: "1",
			113: "",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Equal(t, "{}", row[113])

		dynamicValues = map[string]string{
			"x": "5",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.Contains(t, row[113], "x")
	})

	t.Run("invalid input for dynamic field", func(t *testing.T) {
		dynamicValues := map[string]string{
			"x": "8",
		}
		row := map[storage.FieldID]string{
			106: "1",
			113: "5",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.Error(t, err)

		row = map[storage.FieldID]string{
			106: "1",
			113: "abc",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.Error(t, err)
	})

	t.Run("not allow dynamic values if no dynamic field", func(t *testing.T) {
		parser.collectionInfo.DynamicField = nil
		dynamicValues := map[string]string{
			"x": "8",
		}
		row := map[storage.FieldID]string{
			106: "1",
		}
		err = parser.combineDynamicRow(dynamicValues, row)
		assert.NoError(t, err)
		assert.NotContains(t, row, int64(113))
	})
}
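The cases above pin down the merge semantics for the dynamic field: leftover CSV columns are folded into the JSON object stored under the dynamic field ID, empty cells are dropped, and a value that is not a JSON object is rejected. The following standalone sketch illustrates that folding logic with encoding/json; it is not the package's combineDynamicRow, and the function name and main program are made up for illustration.

package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// mergeDynamic folds extra column values into a JSON object string, mirroring
// the behaviour the test cases above expect. Illustrative stand-in only.
func mergeDynamic(dynamicJSON string, extras map[string]string) (string, error) {
	obj := make(map[string]interface{})
	if dynamicJSON != "" {
		if err := json.Unmarshal([]byte(dynamicJSON), &obj); err != nil {
			return "", errors.New("illegal value for dynamic field, not a JSON format string")
		}
	}
	for k, v := range extras {
		if v == "" {
			continue // empty cells are ignored, as in the "empty value" case
		}
		var parsed interface{}
		if err := json.Unmarshal([]byte(v), &parsed); err != nil {
			parsed = v // not valid JSON on its own, keep it as a plain string
		}
		obj[k] = parsed
	}
	out, err := json.Marshal(obj)
	return string(out), err
}

func main() {
	merged, _ := mergeDynamic(`{"y": 8}`, map[string]string{"x": "88"})
	fmt.Println(merged) // {"x":88,"y":8}
}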

func Test_CSVParserVerifyRow(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	schema := &schemapb.CollectionSchema{
		Name:               "schema",
		Description:        "schema",
		EnableDynamicField: true,
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:      106,
				Name:         "FieldID",
				IsPrimaryKey: true,
				AutoID:       false,
				Description:  "int64",
				DataType:     schemapb.DataType_Int64,
			},
			{
				FieldID:      113,
				Name:         "FieldDynamic",
				IsPrimaryKey: false,
				IsDynamic:    true,
				Description:  "dynamic field",
				DataType:     schemapb.DataType_JSON,
			},
		},
	}

	collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1})
	assert.NoError(t, err)
	parser, err := NewCSVParser(ctx, collectionInfo, nil)
	assert.NoError(t, err)
	assert.NotNil(t, parser)

	t.Run("not auto-id, dynamic field provided", func(t *testing.T) {
		parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"}
		raw := []string{"1", `{"x": 8}`, "true"}
		row, err := parser.verifyRow(raw)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(106))
		assert.Contains(t, row, int64(113))
		assert.Contains(t, row[113], "x")
		assert.Contains(t, row[113], "y")
	})

	t.Run("not auto-id, dynamic field not provided", func(t *testing.T) {
		parser.fieldsName = []string{"FieldID"}
		raw := []string{"1"}
		row, err := parser.verifyRow(raw)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(106))
		assert.Contains(t, row, int64(113))
		assert.Contains(t, "{}", row[113])
	})

	t.Run("not auto-id, invalid input dynamic field", func(t *testing.T) {
		parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"}
		raw := []string{"1", "true", "true"}
		_, err = parser.verifyRow(raw)
		assert.Error(t, err)
	})

	schema.Fields[0].AutoID = true
	err = collectionInfo.resetSchema(schema)
	assert.NoError(t, err)
	t.Run("no need to provide value for auto-id", func(t *testing.T) {
		parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"}
		raw := []string{"1", `{"x": 8}`, "true"}
		_, err := parser.verifyRow(raw)
		assert.Error(t, err)

		parser.fieldsName = []string{"FieldDynamic", "y"}
		raw = []string{`{"x": 8}`, "true"}
		row, err := parser.verifyRow(raw)
		assert.NoError(t, err)
		assert.Contains(t, row, int64(113))
	})

	schema.Fields[1].IsDynamic = false
	err = collectionInfo.resetSchema(schema)
	assert.NoError(t, err)
	t.Run("auto id, no dynamic field", func(t *testing.T) {
		parser.fieldsName = []string{"FieldDynamic", "y"}
		raw := []string{`{"x": 8}`, "true"}
		_, err := parser.verifyRow(raw)
		assert.Error(t, err)

		// miss FieldDynamic
		parser.fieldsName = []string{}
		raw = []string{}
		_, err = parser.verifyRow(raw)
		assert.Error(t, err)
	})
}

@ -464,7 +464,7 @@ func fillDynamicData(blockData BlockData, collectionSchema *schemapb.CollectionS

// tryFlushBlocks does the two things:
// 1. if accumulate data of a block exceed blockSize, call callFlushFunc to generate new binlog file
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest block
// 2. if total accumulate data exceed maxTotalSize, call callFlushFunc to flush the biggest block
func tryFlushBlocks(ctx context.Context,
	shardsData []ShardData,
	collectionSchema *schemapb.CollectionSchema,

@ -38,6 +38,7 @@ import (
const (
	JSONFileExt  = ".json"
	NumpyFileExt = ".npy"
	CSVFileExt   = ".csv"

	// supposed size of a single block, to control a binlog file size, the max biglog file size is no more than 2*SingleBlockSize
	SingleBlockSize = 16 * 1024 * 1024 // 16MB

@ -177,21 +178,21 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
		filePath := filePaths[i]
		name, fileType := GetFileNameAndExt(filePath)

		// only allow json file or numpy file
		if fileType != JSONFileExt && fileType != NumpyFileExt {
		// only allow json file, numpy file and csv file
		if fileType != JSONFileExt && fileType != NumpyFileExt && fileType != CSVFileExt {
			log.Warn("import wrapper: unsupported file type", zap.String("filePath", filePath))
			return false, fmt.Errorf("unsupported file type: '%s'", filePath)
		}

		// we use the first file to determine row-based or column-based
		if i == 0 && fileType == JSONFileExt {
		if i == 0 && (fileType == JSONFileExt || fileType == CSVFileExt) {
			rowBased = true
		}

		// check file type
		// row-based only support json type, column-based only support numpy type
		// row-based only support json and csv type, column-based only support numpy type
		if rowBased {
			if fileType != JSONFileExt {
			if fileType != JSONFileExt && fileType != CSVFileExt {
				log.Warn("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath))
				return rowBased, fmt.Errorf("unsupported file type for row-based mode: '%s'", filePath)
			}

@ -269,6 +270,12 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error
				log.Warn("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath))
				return err
			}
		} else if fileType == CSVFileExt {
			err = p.parseRowBasedCSV(filePath, options.OnlyValidate)
			if err != nil {
				log.Warn("import wrapper: failed to parse row-based csv file", zap.Error(err), zap.String("filePath", filePath))
				return err
			}
		} // no need to check else, since the fileValidation() already do this

		// trigger gc after each file finished

@ -450,6 +457,54 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
	return nil
}

func (p *ImportWrapper) parseRowBasedCSV(filePath string, onlyValidate bool) error {
	tr := timerecord.NewTimeRecorder("csv row-based parser: " + filePath)

	file, err := p.chunkManager.Reader(p.ctx, filePath)
	if err != nil {
		return err
	}
	defer file.Close()
	size, err := p.chunkManager.Size(p.ctx, filePath)
	if err != nil {
		return err
	}
	// csv parser
	reader := bufio.NewReader(file)
	parser, err := NewCSVParser(p.ctx, p.collectionInfo, p.updateProgressPercent)
	if err != nil {
		return err
	}

	// if only validating, input an empty flushFunc so that the consumer does nothing but validation
	var flushFunc ImportFlushFunc
	if onlyValidate {
		flushFunc = func(fields BlockData, shardID int, partitionID int64) error {
			return nil
		}
	} else {
		flushFunc = func(fields BlockData, shardID int, partitionID int64) error {
			filePaths := []string{filePath}
			printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths)
			return p.flushFunc(fields, shardID, partitionID)
		}
	}

	consumer, err := NewCSVRowConsumer(p.ctx, p.collectionInfo, p.rowIDAllocator, SingleBlockSize, flushFunc)
	if err != nil {
		return err
	}

	err = parser.ParseRows(&IOReader{r: reader, fileSize: size}, consumer)
	if err != nil {
		return err
	}
	p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)

	tr.Elapse("parsed")
	return nil
}
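Note that the only-validate branch installs a flush function that discards blocks, so the file is fully parsed and verified without writing any binlog. A small sketch of the resulting two-pass usage follows; it assumes NewImportWrapper returns a *ImportWrapper as it is used in the tests further below, and the helper name is made up.

// validateThenImport is a hypothetical helper: a dry run first, then the real import.
func validateThenImport(wrapper *ImportWrapper, files []string) error {
	// first pass: parse and verify only, nothing is flushed
	if err := wrapper.Import(files, ImportOptions{OnlyValidate: true}); err != nil {
		return err
	}
	// second pass: re-parse the files and actually flush segments
	return wrapper.Import(files, DefaultImportOptions())
}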

// flushFunc is the callback function for parsers to generate segments and save binlog files
func (p *ImportWrapper) flushFunc(fields BlockData, shardID int, partitionID int64) error {
	logFields := []zap.Field{

@ -326,6 +326,93 @@ func Test_ImportWrapperRowBased(t *testing.T) {
	})
}

func Test_ImportWrapperRowBased_CSV(t *testing.T) {
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.NoError(t, err)
	defer os.RemoveAll(TempFilesPath)
	paramtable.Init()

	// NewDefaultFactory() use "/tmp/milvus" as default root path, and cannot specify root path
	// NewChunkManagerFactory() can specify the root path
	f := storage.NewChunkManagerFactory("local", storage.RootPath(TempFilesPath))
	ctx := context.Background()
	cm, err := f.NewPersistentStorageChunkManager(ctx)
	assert.NoError(t, err)
	defer cm.RemoveWithPrefix(ctx, cm.RootPath())

	idAllocator := newIDAllocator(ctx, t, nil)
	content := []byte(
		`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"
false,11,102,1002,10002,3.15,1.57,No.1,"{""x"": 1}","[201,0]","[0.1,0.2,0.3,0.4]"
true,12,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`)

	filePath := TempFilesPath + "rows_1.csv"
	err = cm.Write(ctx, filePath, content)
	assert.NoError(t, err)
	rowCounter := &rowCounterTest{}
	assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter)
	importResult := &rootcoordpb.ImportResult{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		},
		TaskId:     1,
		DatanodeId: 1,
		State:      commonpb.ImportState_ImportStarted,
		Segments:   make([]int64, 0),
		AutoIds:    make([]int64, 0),
		RowCount:   0,
	}

	reportFunc := func(res *rootcoordpb.ImportResult) error {
		return nil
	}
	collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1})
	assert.NoError(t, err)

	t.Run("success case", func(t *testing.T) {
		wrapper := NewImportWrapper(ctx, collectionInfo, 1, idAllocator, cm, importResult, reportFunc)
		wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc)
		files := make([]string, 0)
		files = append(files, filePath)
		err = wrapper.Import(files, ImportOptions{OnlyValidate: true})
		assert.NoError(t, err)
		assert.Equal(t, 0, rowCounter.rowCount)

		err = wrapper.Import(files, DefaultImportOptions())
		assert.NoError(t, err)
		assert.Equal(t, 3, rowCounter.rowCount)
		assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State)
	})

	t.Run("parse error", func(t *testing.T) {
		content := []byte(
			`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector
true,false,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`)

		filePath = TempFilesPath + "rows_2.csv"
		err = cm.Write(ctx, filePath, content)
		assert.NoError(t, err)

		importResult.State = commonpb.ImportState_ImportStarted
		wrapper := NewImportWrapper(ctx, collectionInfo, 1, idAllocator, cm, importResult, reportFunc)
		wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc)
		files := make([]string, 0)
		files = append(files, filePath)
		err = wrapper.Import(files, ImportOptions{OnlyValidate: true})
		assert.Error(t, err)
		assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State)
	})

	t.Run("file doesn't exist", func(t *testing.T) {
		files := make([]string, 0)
		files = append(files, "/dummy/dummy.csv")
		wrapper := NewImportWrapper(ctx, collectionInfo, 1, idAllocator, cm, importResult, reportFunc)
		err = wrapper.Import(files, ImportOptions{OnlyValidate: true})
		assert.Error(t, err)
	})
}

func Test_ImportWrapperColumnBased_numpy(t *testing.T) {
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.NoError(t, err)

@ -110,7 +110,9 @@ func (p *JSONParser) combineDynamicRow(dynamicValues map[string]interface{}, row
	if value, is := obj.(string); is {
		// case 1
		mp := make(map[string]interface{})
		err := json.Unmarshal([]byte(value), &mp)
		desc := json.NewDecoder(strings.NewReader(value))
		desc.UseNumber()
		err := desc.Decode(&mp)
		if err != nil {
			// invalid input
			return errors.New("illegal value for dynamic field, not a JSON format string")
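The replacement of json.Unmarshal with a json.Decoder plus UseNumber in the hunk above is about integer precision: plain Unmarshal decodes every JSON number into float64, which cannot represent integers beyond 2^53 exactly, while UseNumber keeps the literal as a json.Number string. A self-contained illustration (not part of the commit):

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

func main() {
	const doc = `{"x": 99999999999999999}`

	var viaUnmarshal map[string]interface{}
	_ = json.Unmarshal([]byte(doc), &viaUnmarshal)
	fmt.Printf("%T %v\n", viaUnmarshal["x"], viaUnmarshal["x"]) // float64 1e+17, precision lost

	var viaDecoder map[string]interface{}
	dec := json.NewDecoder(strings.NewReader(doc))
	dec.UseNumber()
	_ = dec.Decode(&viaDecoder)
	fmt.Printf("%T %v\n", viaDecoder["x"], viaDecoder["x"]) // json.Number 99999999999999999, literal preserved
}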

@ -192,7 +194,7 @@ func (p *JSONParser) verifyRow(raw interface{}) (map[storage.FieldID]interface{}
		}
	}

	// combine the redundant pairs into dunamic field(if has)
	// combine the redundant pairs into dynamic field(if has)
	err := p.combineDynamicRow(dynamicValues, row)
	if err != nil {
		log.Warn("JSON parser: failed to combine dynamic values", zap.Error(err))