From 168935f4cf84145088eee81ebcbdaec47ac024c1 Mon Sep 17 00:00:00 2001 From: groot Date: Fri, 13 May 2022 16:07:54 +0800 Subject: [PATCH] Fix bulkload bug for string primary key (#16958) Signed-off-by: groot --- internal/datanode/data_node.go | 9 +- internal/util/importutil/import_wrapper.go | 37 ++- .../util/importutil/import_wrapper_test.go | 67 ++++++ internal/util/importutil/json_handler.go | 61 +++-- internal/util/importutil/json_handler_test.go | 132 ++++++++++- internal/util/importutil/json_parser.go | 18 +- internal/util/importutil/json_parser_test.go | 213 ++++++++++++++++++ 7 files changed, 502 insertions(+), 35 deletions(-) diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 50924bb040..6ad44522a9 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -934,14 +934,7 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root Data: tsFieldData, NumRows: []int64{int64(rowNum)}, } - var pkFieldID int64 - for _, field := range schema.Fields { - if field.IsPrimaryKey { - pkFieldID = field.GetFieldID() - break - } - } - fields[common.RowIDField] = fields[pkFieldID] + if status, _ := node.dataCoord.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{ Stats: []*datapb.SegmentStats{{ SegmentID: segmentID, diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index 0a937aa2f4..9b30c134f2 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -505,12 +505,19 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F return errors.New("import error: primary key field is not provided") } - // generate auto id for primary key + // generate auto id for primary key and rowid field + var rowIDBegin typeutil.UniqueID + var rowIDEnd typeutil.UniqueID + rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount)) + + rowIDField := fieldsData[common.RowIDField] + rowIDFieldArr := rowIDField.(*storage.Int64FieldData) + for i := rowIDBegin; i < rowIDEnd; i++ { + rowIDFieldArr.Data = append(rowIDFieldArr.Data, rowIDBegin+i) + } + if primaryKey.GetAutoID() { log.Info("import wrapper: generating auto-id", zap.Any("rowCount", rowCount)) - var rowIDBegin typeutil.UniqueID - var rowIDEnd typeutil.UniqueID - rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount)) primaryDataArr := primaryData.(*storage.Int64FieldData) for i := rowIDBegin; i < rowIDEnd; i++ { @@ -547,11 +554,27 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F // split data into segments for i := 0; i < rowCount; i++ { - id := primaryData.GetRow(i).(int64) // hash to a shard number - hash, _ := typeutil.Hash32Int64(id) - shard := hash % uint32(p.shardNum) + var shard uint32 + pk := primaryData.GetRow(i) + strPK, ok := interface{}(pk).(string) + if ok { + hash := typeutil.HashString2Uint32(strPK) + shard = hash % uint32(p.shardNum) + } else { + intPK, ok := interface{}(pk).(int64) + if !ok { + return errors.New("import error: primary key field must be int64 or varchar") + } + hash, _ := typeutil.Hash32Int64(intPK) + shard = hash % uint32(p.shardNum) + } + // set rowID field + rowIDField := segmentsData[shard][common.RowIDField].(*storage.Int64FieldData) + rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64)) + + // append row to shard for k := 0; k < len(p.collectionSchema.Fields); k++ { schema := p.collectionSchema.Fields[k] srcData := fieldsData[schema.GetFieldID()] diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 2a8098dfee..64b49559fb 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -308,6 +308,73 @@ func Test_ImportColumnBased_json(t *testing.T) { assert.NotNil(t, err) } +func Test_ImportColumnBased_StringKey(t *testing.T) { + f := dependency.NewDefaultFactory(true) + ctx := context.Background() + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") + + idAllocator := newIDAllocator(ctx, t) + + content := []byte(`{ + "uid": ["Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd"], + "int_scalar": [9070353, 8505288, 4392660, 7927425, 9288807], + "float_scalar": [0.9798043638085004, 0.937913432198687, 0.32381232630490264, 0.31074026464844895, 0.4953578200336135], + "string_scalar": ["ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7"], + "bool_scalar": [true, false, true, false, false], + "vectors": [ + [0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314], + [0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056], + [0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914], + [0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004], + [0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633] + ] + }`) + + filePath := TempFilesPath + "columns_2.json" + err = cm.Write(filePath, content) + assert.NoError(t, err) + + rowCount := 0 + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { + count := 0 + for _, data := range fields { + assert.Less(t, 0, data.RowNum()) + if count == 0 { + count = data.RowNum() + } else { + assert.Equal(t, count, data.RowNum()) + } + } + rowCount += count + return nil + } + + // success case + importResult := &rootcoordpb.ImportResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + TaskId: 1, + DatanodeId: 1, + State: commonpb.ImportState_ImportStarted, + Segments: make([]int64, 0), + AutoIds: make([]int64, 0), + RowCount: 0, + } + reportFunc := func(res *rootcoordpb.ImportResult) error { + return nil + } + wrapper := NewImportWrapper(ctx, strKeySchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + files := make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, false, false) + assert.Nil(t, err) + assert.Equal(t, 5, rowCount) + assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) +} + func Test_ImportColumnBased_numpy(t *testing.T) { f := dependency.NewDefaultFactory(true) ctx := context.Background() diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go index 3db0bf92a5..a1bd45b780 100644 --- a/internal/util/importutil/json_handler.go +++ b/internal/util/importutil/json_handler.go @@ -8,6 +8,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" @@ -46,6 +47,7 @@ type Validator struct { convertFunc func(obj interface{}, field storage.FieldData) error // convert data function primaryKey bool // true for primary key autoID bool // only for primary key field + isString bool // for string field dimension int // only for vector field fieldName string // field name } @@ -76,6 +78,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey() validators[schema.GetFieldID()].autoID = schema.GetAutoID() validators[schema.GetFieldID()].fieldName = schema.GetName() + validators[schema.GetFieldID()].isString = false switch schema.DataType { case schemapb.DataType_Bool: @@ -165,7 +168,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ } t := int(vt[i].(float64)) - if t >= 255 || t < 0 { + if t > 255 || t < 0 { msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName() return errors.New(msg) } @@ -226,6 +229,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ return nil } case schemapb.DataType_String, schemapb.DataType_VarChar: + validators[schema.GetFieldID()].isString = true validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { switch obj.(type) { case string: @@ -301,7 +305,7 @@ func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error } if err := validator.validateFunc(value); err != nil { - return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter, 10)) + return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10)) } } } @@ -405,6 +409,13 @@ type JSONRowConsumer struct { func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData { segmentData := make(map[storage.FieldID]storage.FieldData) + // rowID field is a hidden field with fieldID=0, it is always auto-generated by IDAllocator + // if primary key is int64 and autoID=true, primary key field is equal to rowID field + segmentData[common.RowIDField] = &storage.Int64FieldData{ + Data: make([]int64, 0), + NumRows: []int64{0}, + } + for i := 0; i < len(collectionSchema.Fields); i++ { schema := collectionSchema.Fields[i] switch schema.DataType { @@ -596,21 +607,41 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { for i := 0; i < len(rows); i++ { row := rows[i] - // firstly get/generate the row id - var id int64 - if primaryValidator.autoID { - id = rowIDBegin + int64(i) - } else { + // hash to a shard number + var shard uint32 + if primaryValidator.isString { + if primaryValidator.autoID { + return errors.New("JSON row consumer: string type primary key cannot be auto-generated") + } + value := row[v.primaryKey] - id = int64(value.(float64)) + pk := string(value.(string)) + hash := typeutil.HashString2Uint32(pk) + shard = hash % uint32(v.shardNum) + pkArray := v.segmentsData[shard][v.primaryKey].(*storage.StringFieldData) + pkArray.Data = append(pkArray.Data, pk) + pkArray.NumRows[0]++ + } else { + // get/generate the row id + var pk int64 + if primaryValidator.autoID { + pk = rowIDBegin + int64(i) + } else { + value := row[v.primaryKey] + pk = int64(value.(float64)) + } + + hash, _ := typeutil.Hash32Int64(pk) + shard = hash % uint32(v.shardNum) + pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData) + pkArray.Data = append(pkArray.Data, pk) + pkArray.NumRows[0]++ } - // hash to a shard number - hash, _ := typeutil.Hash32Int64(id) - shard := hash % uint32(v.shardNum) - pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData) - pkArray.Data = append(pkArray.Data, id) - pkArray.NumRows[0]++ + // set rowid field + rowIDField := v.segmentsData[shard][common.RowIDField].(*storage.Int64FieldData) + rowIDField.Data = append(rowIDField.Data, rowIDBegin+int64(i)) + rowIDField.NumRows[0]++ // convert value and consume for name, validator := range v.validators { @@ -619,7 +650,7 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { } value := row[name] if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil { - return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter, 10)) + return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10)) } } } diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go index 144289ea30..f160621463 100644 --- a/internal/util/importutil/json_handler_test.go +++ b/internal/util/importutil/json_handler_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proto/schemapb" @@ -186,7 +187,7 @@ func Test_JSONRowValidator(t *testing.T) { validator := NewJSONRowValidator(schema, nil) err := parser.ParseRows(reader, validator) - assert.Nil(t, err) + assert.NotNil(t, err) assert.Equal(t, int64(0), validator.ValidateCount()) // // missed some fields @@ -247,7 +248,7 @@ func Test_JSONColumnValidator(t *testing.T) { validator := NewJSONColumnValidator(schema, nil) err := parser.ParseColumns(reader, validator) - assert.Nil(t, err) + assert.NotNil(t, err) for _, count := range validator.rowCounter { assert.Equal(t, int64(0), count) } @@ -343,6 +344,128 @@ func Test_JSONRowConsumer(t *testing.T) { assert.Equal(t, 5, totalCount) } +func Test_JSONRowConsumerStringKey(t *testing.T) { + ctx := context.Background() + idAllocator := newIDAllocator(ctx, t) + + schema := strKeySchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + + reader := strings.NewReader(`{ + "rows": [{ + "uid": "Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", + "int_scalar": 9070353, + "float_scalar": 0.9798043638085004, + "string_scalar": "ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", + "bool_scalar": true, + "vectors": [0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314] + }, + { + "uid": "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", + "int_scalar": 8505288, + "float_scalar": 0.937913432198687, + "string_scalar": "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", + "bool_scalar": false, + "vectors": [0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056] + }, + { + "uid": "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", + "int_scalar": 4392660, + "float_scalar": 0.32381232630490264, + "string_scalar": "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", + "bool_scalar": false, + "vectors": [0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914] + }, + { + "uid": "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", + "int_scalar": 7927425, + "float_scalar": 0.31074026464844895, + "string_scalar": "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", + "bool_scalar": true, + "vectors": [0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004] + }, + { + "uid": "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd", + "int_scalar": 9288807, + "float_scalar": 0.4953578200336135, + "string_scalar": "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7", + "bool_scalar": false, + "vectors": [0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633] + }, + { + "uid": "ACIJd7lTXkRgUNmlQk6AbnWIKEEV8Z6OS3vDcm0w9psmt9sH3z1JLg1fNVCqiX3d", + "int_scalar": 1173595, + "float_scalar": 0.9000745450802002, + "string_scalar": "gpj9YctF2ig1l1APkvRzHbVE8PZVKRbk7nvW73qS2uQbY5l7MeIeTPwRBjasbY8z", + "bool_scalar": true, + "vectors": [0.4655121736168688, 0.6195496905333787, 0.5316616196326639, 0.3417492053890768] + }, + { + "uid": "f0wRVZZ9u1bEKrAjLeZj3oliEnUjBiUl6TiermeczceBmGe6M2RHONgz3qEogrd5", + "int_scalar": 3722368, + "float_scalar": 0.7212299175768438, + "string_scalar": "xydiejGUlvS49BfBuy1EuYRKt3v2oKwC6pqy7Ga4dGWn3BnQigV4XAGawixDAGHN", + "bool_scalar": false, + "vectors": [0.6173164237304075, 0.374107748459483, 0.3686321416317251, 0.585725336391797] + }, + { + "uid": "uXq9q96vUqnDebcUISFkRFT27OjD89DWhok6urXIjTuLzaSWnCVTJkrJXxFctSg0", + "int_scalar": 1940731, + "float_scalar": 0.9524404085944204, + "string_scalar": "ZXSNzR5V3t62fjop7b7DHK56ByAF0INYwycKsu6OxGP4p2j0Obs6l0NUqukypGXd", + "bool_scalar": false, + "vectors": [0.07178869784465443, 0.4208459174227864, 0.5882811425075762, 0.6867753592116734] + }, + { + "uid": "EXDDklLvQIfeCJN8cES3b9mdCYDQVhq2iLj8WWA3TPtZ1SZ4Jpidj7OXJidSD7Wn", + "int_scalar": 2158426, + "float_scalar": 0.23770219927963454, + "string_scalar": "9TNeKVSMqTP8Zxs90kaAcB7n6JbIcvFWInzi9JxZQgmYxD5xLYwaCoeUzRiNAxAg", + "bool_scalar": false, + "vectors": [0.5659468293534021, 0.6275816433340369, 0.3978846871291008, 0.3571179679645908] + }, + { + "uid": "mlaXOgYvB88WWRpXNyWv6UqpmvIHrC6pRo03AtaPLMpVymu0L9ioO8GWa1XgGyj0", + "int_scalar": 198279, + "float_scalar": 0.020343767010139513, + "string_scalar": "AblYGRZJiMAwDbMEkungG0yKTeuya4FgyliakWWqSOJ5TvQWB9Ki2WXbnvSsYIDF", + "bool_scalar": true, + "vectors": [0.5374636140212398, 0.7655373567912009, 0.05491796821609715, 0.349384366747262] + } + ] + }`) + + var shardNum int32 = 2 + var callTime int32 + var totalCount int + consumeFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error { + assert.Equal(t, int(callTime), shard) + callTime++ + rowCount := 0 + for _, data := range fields { + if rowCount == 0 { + rowCount = data.RowNum() + } else { + assert.Equal(t, rowCount, data.RowNum()) + } + } + totalCount += rowCount + return nil + } + + consumer := NewJSONRowConsumer(schema, idAllocator, shardNum, 1, consumeFunc) + assert.NotNil(t, consumer) + + validator := NewJSONRowValidator(schema, consumer) + err := parser.ParseRows(reader, validator) + assert.Nil(t, err) + assert.Equal(t, int64(10), validator.ValidateCount()) + + assert.Equal(t, shardNum, callTime) + assert.Equal(t, 10, totalCount) +} + func Test_JSONColumnConsumer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -380,7 +503,10 @@ func Test_JSONColumnConsumer(t *testing.T) { rowCount := 0 consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error { callTime++ - for _, data := range fields { + for id, data := range fields { + if id == common.RowIDField { + continue + } if rowCount == 0 { rowCount = data.RowNum() } else { diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index bd0cb317a5..c463a29063 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -60,13 +60,14 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { t, err := dec.Token() if err != nil { - return p.logError("JSON parse: " + err.Error()) + return p.logError("JSON parse: row count is 0") } if t != json.Delim('{') { return p.logError("JSON parse: invalid JSON format, the content should be started with'{'") } // read the first level + isEmpty := true for dec.More() { // read the key t, err := dec.Token() @@ -114,6 +115,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { buf = append(buf, row) if len(buf) >= int(p.bufSize) { + isEmpty = false if err = handler.Handle(buf); err != nil { return p.logError(err.Error()) } @@ -125,6 +127,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { // some rows in buffer not parsed, parse them if len(buf) > 0 { + isEmpty = false if err = handler.Handle(buf); err != nil { return p.logError(err.Error()) } @@ -153,6 +156,10 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { break } + if isEmpty { + return p.logError("JSON parse: row count is 0") + } + // send nil to notify the handler all have done return handler.Handle(nil) } @@ -166,13 +173,14 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error t, err := dec.Token() if err != nil { - return p.logError("JSON parse: " + err.Error()) + return p.logError("JSON parse: row count is 0") } if t != json.Delim('{') { return p.logError("JSON parse: invalid JSON format, the content should be started with'{'") } // read the first level + isEmpty := true for dec.More() { // read the key t, err := dec.Token() @@ -210,6 +218,7 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error buf[id] = append(buf[id], value) if len(buf[id]) >= int(p.bufSize) { + isEmpty = false if err = handler.Handle(buf); err != nil { return p.logError(err.Error()) } @@ -221,6 +230,7 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error // some values in buffer not parsed, parse them if len(buf[id]) > 0 { + isEmpty = false if err = handler.Handle(buf); err != nil { return p.logError(err.Error()) } @@ -245,6 +255,10 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error } } + if isEmpty { + return p.logError("JSON parse: row count is 0") + } + // send nil to notify the handler all have done return handler.Handle(nil) } diff --git a/internal/util/importutil/json_parser_test.go b/internal/util/importutil/json_parser_test.go index 2de3721550..11359b29a1 100644 --- a/internal/util/importutil/json_parser_test.go +++ b/internal/util/importutil/json_parser_test.go @@ -98,6 +98,66 @@ func sampleSchema() *schemapb.CollectionSchema { return schema } +func strKeySchema() *schemapb.CollectionSchema { + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: false, + Description: "uid", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length_per_row", Value: "1024"}, + }, + }, + { + FieldID: 102, + Name: "int_scalar", + IsPrimaryKey: false, + Description: "int_scalar", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: 103, + Name: "float_scalar", + IsPrimaryKey: false, + Description: "float_scalar", + DataType: schemapb.DataType_Float, + }, + { + FieldID: 104, + Name: "string_scalar", + IsPrimaryKey: false, + Description: "string_scalar", + DataType: schemapb.DataType_VarChar, + }, + { + FieldID: 105, + Name: "bool_scalar", + IsPrimaryKey: false, + Description: "bool_scalar", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: 106, + Name: "vectors", + IsPrimaryKey: false, + Description: "vectors", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + return schema +} + func Test_ParserRows(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -165,6 +225,11 @@ func Test_ParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + reader = strings.NewReader(`{}`) + validator = NewJSONRowValidator(schema, nil) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) + reader = strings.NewReader(``) validator = NewJSONRowValidator(schema, nil) err = parser.ParseRows(reader, validator) @@ -217,12 +282,20 @@ func Test_ParserColumns(t *testing.T) { } reader = strings.NewReader(`{ + "field_int8": [10, 11, 12, 13, 14], "dummy":[1, 2, 3] }`) validator = NewJSONColumnValidator(schema, nil) err = parser.ParseColumns(reader, validator) assert.Nil(t, err) + reader = strings.NewReader(`{ + "dummy":[1, 2, 3] + }`) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + reader = strings.NewReader(`{ "field_bool": }`) @@ -249,8 +322,148 @@ func Test_ParserColumns(t *testing.T) { err = parser.ParseColumns(reader, validator) assert.NotNil(t, err) + reader = strings.NewReader(`{}`) + validator = NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.NotNil(t, err) + reader = strings.NewReader(``) validator = NewJSONColumnValidator(schema, nil) err = parser.ParseColumns(reader, validator) assert.NotNil(t, err) } + +func Test_ParserRowsStringKey(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := strKeySchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + parser.bufSize = 1 + + reader := strings.NewReader(`{ + "rows": [{ + "uid": "Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", + "int_scalar": 9070353, + "float_scalar": 0.9798043638085004, + "string_scalar": "ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", + "bool_scalar": true, + "vectors": [0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314] + }, + { + "uid": "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", + "int_scalar": 8505288, + "float_scalar": 0.937913432198687, + "string_scalar": "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", + "bool_scalar": false, + "vectors": [0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056] + }, + { + "uid": "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", + "int_scalar": 4392660, + "float_scalar": 0.32381232630490264, + "string_scalar": "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", + "bool_scalar": false, + "vectors": [0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914] + }, + { + "uid": "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", + "int_scalar": 7927425, + "float_scalar": 0.31074026464844895, + "string_scalar": "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", + "bool_scalar": true, + "vectors": [0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004] + }, + { + "uid": "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd", + "int_scalar": 9288807, + "float_scalar": 0.4953578200336135, + "string_scalar": "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7", + "bool_scalar": false, + "vectors": [0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633] + }, + { + "uid": "ACIJd7lTXkRgUNmlQk6AbnWIKEEV8Z6OS3vDcm0w9psmt9sH3z1JLg1fNVCqiX3d", + "int_scalar": 1173595, + "float_scalar": 0.9000745450802002, + "string_scalar": "gpj9YctF2ig1l1APkvRzHbVE8PZVKRbk7nvW73qS2uQbY5l7MeIeTPwRBjasbY8z", + "bool_scalar": true, + "vectors": [0.4655121736168688, 0.6195496905333787, 0.5316616196326639, 0.3417492053890768] + }, + { + "uid": "f0wRVZZ9u1bEKrAjLeZj3oliEnUjBiUl6TiermeczceBmGe6M2RHONgz3qEogrd5", + "int_scalar": 3722368, + "float_scalar": 0.7212299175768438, + "string_scalar": "xydiejGUlvS49BfBuy1EuYRKt3v2oKwC6pqy7Ga4dGWn3BnQigV4XAGawixDAGHN", + "bool_scalar": false, + "vectors": [0.6173164237304075, 0.374107748459483, 0.3686321416317251, 0.585725336391797] + }, + { + "uid": "uXq9q96vUqnDebcUISFkRFT27OjD89DWhok6urXIjTuLzaSWnCVTJkrJXxFctSg0", + "int_scalar": 1940731, + "float_scalar": 0.9524404085944204, + "string_scalar": "ZXSNzR5V3t62fjop7b7DHK56ByAF0INYwycKsu6OxGP4p2j0Obs6l0NUqukypGXd", + "bool_scalar": false, + "vectors": [0.07178869784465443, 0.4208459174227864, 0.5882811425075762, 0.6867753592116734] + }, + { + "uid": "EXDDklLvQIfeCJN8cES3b9mdCYDQVhq2iLj8WWA3TPtZ1SZ4Jpidj7OXJidSD7Wn", + "int_scalar": 2158426, + "float_scalar": 0.23770219927963454, + "string_scalar": "9TNeKVSMqTP8Zxs90kaAcB7n6JbIcvFWInzi9JxZQgmYxD5xLYwaCoeUzRiNAxAg", + "bool_scalar": false, + "vectors": [0.5659468293534021, 0.6275816433340369, 0.3978846871291008, 0.3571179679645908] + }, + { + "uid": "mlaXOgYvB88WWRpXNyWv6UqpmvIHrC6pRo03AtaPLMpVymu0L9ioO8GWa1XgGyj0", + "int_scalar": 198279, + "float_scalar": 0.020343767010139513, + "string_scalar": "AblYGRZJiMAwDbMEkungG0yKTeuya4FgyliakWWqSOJ5TvQWB9Ki2WXbnvSsYIDF", + "bool_scalar": true, + "vectors": [0.5374636140212398, 0.7655373567912009, 0.05491796821609715, 0.349384366747262] + } + ] + }`) + + validator := NewJSONRowValidator(schema, nil) + err := parser.ParseRows(reader, validator) + assert.Nil(t, err) + assert.Equal(t, int64(10), validator.ValidateCount()) +} + +func Test_ParserColumnsStrKey(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := strKeySchema() + parser := NewJSONParser(ctx, schema) + assert.NotNil(t, parser) + parser.bufSize = 1 + + reader := strings.NewReader(`{ + "uid": ["Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd"], + "int_scalar": [9070353, 8505288, 4392660, 7927425, 9288807], + "float_scalar": [0.9798043638085004, 0.937913432198687, 0.32381232630490264, 0.31074026464844895, 0.4953578200336135], + "string_scalar": ["ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7"], + "bool_scalar": [true, false, true, false, false], + "vectors": [ + [0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314], + [0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056], + [0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914], + [0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004], + [0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633] + ] + }`) + + err := parser.ParseColumns(reader, nil) + assert.NotNil(t, err) + + validator := NewJSONColumnValidator(schema, nil) + err = parser.ParseColumns(reader, validator) + assert.Nil(t, err) + counter := validator.ValidateCount() + for _, v := range counter { + assert.Equal(t, int64(5), v) + } +}