diff --git a/internal/datacoord/handler.go b/internal/datacoord/handler.go index 9a56843501..34b33b5e99 100644 --- a/internal/datacoord/handler.go +++ b/internal/datacoord/handler.go @@ -71,7 +71,7 @@ func (h *ServerHandler) GetDataVChanPositions(channel *channel, partitionID Uniq continue } if s.GetIsImporting() { - // Skip bulk load segments. + // Skip bulk insert segments. continue } @@ -149,7 +149,7 @@ func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID Uni continue } if s.GetIsImporting() { - // Skip bulk load segments. + // Skip bulk insert segments. continue } segmentInfos[s.GetID()] = s diff --git a/internal/datacoord/segment_manager.go b/internal/datacoord/segment_manager.go index 993124d96a..e4f98c33fb 100644 --- a/internal/datacoord/segment_manager.go +++ b/internal/datacoord/segment_manager.go @@ -68,7 +68,7 @@ func putAllocation(a *Allocation) { type Manager interface { // AllocSegment allocates rows and record the allocation. AllocSegment(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error) - // allocSegmentForImport allocates one segment allocation for bulk load. + // allocSegmentForImport allocates one segment allocation for bulk insert. // TODO: Remove this method and AllocSegment() above instead. allocSegmentForImport(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64, taskID int64) (*Allocation, error) // DropSegment drops the segment from manager. @@ -278,7 +278,7 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID return allocations, nil } -// allocSegmentForImport allocates one segment allocation for bulk load. +// allocSegmentForImport allocates one segment allocation for bulk insert. func (s *SegmentManager) allocSegmentForImport(ctx context.Context, collectionID UniqueID, partitionID UniqueID, channelName string, requestRows int64, importTaskID int64) (*Allocation, error) { // init allocation diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 7e2c4bb079..2999f19ab5 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -637,7 +637,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf if segment.State != commonpb.SegmentState_Flushed && segment.State != commonpb.SegmentState_Flushing && segment.State != commonpb.SegmentState_Dropped { continue } - // Also skip bulk load segments. + // Also skip bulk insert segments. 
if segment.GetIsImporting() { continue } diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 5c0f0fee1e..87d959240a 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -1074,16 +1074,19 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) // parse files and generate segments segmentSize := int64(Params.DataCoordCfg.SegmentMaxSize) * 1024 * 1024 - importWrapper := importutil.NewImportWrapper(newCtx, colInfo.GetSchema(), colInfo.GetShardsNum(), segmentSize, node.rowIDAllocator, node.chunkManager, - importFlushReqFunc(node, req, importResult, colInfo.GetSchema(), ts), importResult, reportFunc) + importWrapper := importutil.NewImportWrapper(newCtx, colInfo.GetSchema(), colInfo.GetShardsNum(), segmentSize, node.rowIDAllocator, + node.chunkManager, importResult, reportFunc) + importWrapper.SetCallbackFunctions(assignSegmentFunc(node, req), + createBinLogsFunc(node, req, colInfo.GetSchema(), ts), + saveSegmentFunc(node, req, importResult, ts)) // todo: pass tsStart and tsStart after import_wrapper support tsStart, tsEnd, err := importutil.ParseTSFromOptions(req.GetImportTask().GetInfos()) if err != nil { return returnFailFunc(err) } log.Info("import time range", zap.Uint64("start_ts", tsStart), zap.Uint64("end_ts", tsEnd)) - err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false) - //err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false, tsStart, tsEnd) + err = importWrapper.Import(req.GetImportTask().GetFiles(), + importutil.ImportOptions{OnlyValidate: false, TsStartPoint: tsStart, TsEndPoint: tsEnd}) if err != nil { return returnFailFunc(err) } @@ -1183,8 +1186,8 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor }, nil } -func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, schema *schemapb.CollectionSchema, ts Timestamp) importutil.ImportFlushFunc { - return func(fields map[storage.FieldID]storage.FieldData, shardID int) error { +func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil.AssignSegmentFunc { + return func(shardID int) (int64, string, error) { chNames := req.GetImportTask().GetChannelNames() importTaskID := req.GetImportTask().GetTaskId() if shardID >= len(chNames) { @@ -1194,53 +1197,91 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root zap.Int("# of channels", len(chNames)), zap.Strings("channel names", chNames), ) - return fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID) + return 0, "", fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID) } tr := timerecord.NewTimeRecorder("import callback function") defer tr.Elapse("finished") + colID := req.GetImportTask().GetCollectionId() + partID := req.GetImportTask().GetPartitionId() + segmentIDReq := composeAssignSegmentIDRequest(1, shardID, chNames, colID, partID) + targetChName := segmentIDReq.GetSegmentIDRequests()[0].GetChannelName() + log.Info("target channel for the import task", + zap.Int64("task ID", importTaskID), + zap.Int("shard ID", shardID), + zap.String("target channel name", targetChName)) + resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq) + if err != nil { + return 0, "", fmt.Errorf("syncSegmentID Failed:%w", err) + } + if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + return 0, "", fmt.Errorf("syncSegmentID Failed:%s", 
resp.Status.Reason) + } + segmentID := resp.SegIDAssignments[0].SegID + log.Info("new segment assigned", + zap.Int64("task ID", importTaskID), + zap.Int64("segmentID", segmentID), + zap.Int("shard ID", shardID), + zap.String("target channel name", targetChName)) + return segmentID, targetChName, nil + } +} + +func createBinLogsFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *schemapb.CollectionSchema, ts Timestamp) importutil.CreateBinlogsFunc { + return func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) { var rowNum int for _, field := range fields { rowNum = field.RowNum() break } + chNames := req.GetImportTask().GetChannelNames() + importTaskID := req.GetImportTask().GetTaskId() if rowNum <= 0 { - log.Info("fields data is empty, no need to generate segment", + log.Info("fields data is empty, no need to generate binlog", zap.Int64("task ID", importTaskID), - zap.Int("shard ID", shardID), zap.Int("# of channels", len(chNames)), zap.Strings("channel names", chNames), ) - return nil + return nil, nil, nil } colID := req.GetImportTask().GetCollectionId() partID := req.GetImportTask().GetPartitionId() - segmentIDReq := composeAssignSegmentIDRequest(rowNum, shardID, chNames, colID, partID) - targetChName := segmentIDReq.GetSegmentIDRequests()[0].GetChannelName() - log.Info("target channel for the import task", - zap.Int64("task ID", importTaskID), - zap.String("target channel name", targetChName)) - resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq) - if err != nil { - return fmt.Errorf("syncSegmentID Failed:%w", err) - } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason) - } - segmentID := resp.SegIDAssignments[0].SegID fieldInsert, fieldStats, err := createBinLogs(rowNum, schema, ts, fields, node, segmentID, colID, partID) if err != nil { - return err + log.Error("failed to create binlogs", + zap.Int64("task ID", importTaskID), + zap.Int("# of channels", len(chNames)), + zap.Strings("channel names", chNames), + zap.Any("err", err), + ) + return nil, nil, err } + log.Info("new binlog created", + zap.Int64("task ID", importTaskID), + zap.Int64("segmentID", segmentID), + zap.Int("insert log count", len(fieldInsert)), + zap.Int("stats log count", len(fieldStats))) + + return fieldInsert, fieldStats, err + } +} + +func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, ts Timestamp) importutil.SaveSegmentFunc { + importTaskID := req.GetImportTask().GetTaskId() + return func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error { log.Info("adding segment to the correct DataNode flow graph and saving binlog paths", - zap.Int64("segment ID", segmentID), + zap.Int64("task ID", importTaskID), + zap.Int64("segmentID", segmentID), + zap.String("targetChName", targetChName), + zap.Int64("rowCount", rowCount), zap.Uint64("ts", ts)) - err = retry.Do(context.Background(), func() error { + + err := retry.Do(context.Background(), func() error { // Ask DataCoord to save binlog path and add segment to the corresponding DataNode flow graph. 
resp, err := node.dataCoord.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{ Base: commonpbutil.NewMsgBase( @@ -1251,7 +1292,7 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root ChannelName: targetChName, CollectionId: req.GetImportTask().GetCollectionId(), PartitionId: req.GetImportTask().GetPartitionId(), - RowNum: int64(rowNum), + RowNum: rowCount, SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(0), @@ -1261,8 +1302,8 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root ), SegmentID: segmentID, CollectionID: req.GetImportTask().GetCollectionId(), - Field2BinlogPaths: fieldInsert, - Field2StatslogPaths: fieldStats, + Field2BinlogPaths: fieldsInsert, + Field2StatslogPaths: fieldsStats, // Set start positions of a SaveBinlogPathRequest explicitly. StartPositions: []*datapb.SegmentStartPosition{ { @@ -1291,9 +1332,11 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root log.Warn("failed to save import segment", zap.Error(err)) return err } - log.Info("segment imported and persisted", zap.Int64("segmentID", segmentID)) + log.Info("segment imported and persisted", + zap.Int64("task ID", importTaskID), + zap.Int64("segmentID", segmentID)) res.Segments = append(res.Segments, segmentID) - res.RowCount += int64(rowNum) + res.RowCount += rowCount return nil } } diff --git a/internal/datanode/data_node_test.go b/internal/datanode/data_node_test.go index 2c112850aa..77af7b6bbd 100644 --- a/internal/datanode/data_node_test.go +++ b/internal/datanode/data_node_test.go @@ -38,7 +38,6 @@ import ( "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" @@ -603,30 +602,6 @@ func TestDataNode(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.GetErrorCode()) }) - t.Run("Test Import callback func error", func(t *testing.T) { - req := &datapb.ImportTaskRequest{ - ImportTask: &datapb.ImportTask{ - CollectionId: 100, - PartitionId: 100, - ChannelNames: []string{"ch1", "ch2"}, - }, - } - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - TaskId: 0, - DatanodeId: 0, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, - } - callback := importFlushReqFunc(node, req, importResult, nil, 0) - err := callback(nil, len(req.ImportTask.ChannelNames)+1) - assert.Error(t, err) - }) - t.Run("Test BackGroundGC", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index 54ec62dfc7..16b9aa1aef 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -114,7 +114,7 @@ message SegmentIDRequest { string channel_name = 2; int64 collectionID = 3; int64 partitionID = 4; - bool isImport = 5; // Indicate whether this request comes from a bulk load task. + bool isImport = 5; // Indicate whether this request comes from a bulk insert task. int64 importTaskID = 6; // Needed for segment lock. 
} @@ -269,8 +269,8 @@ message SegmentInfo { repeated int64 compactionFrom = 15; uint64 dropped_at = 16; // timestamp when segment marked drop // A flag indicating if: - // (1) this segment is created by bulk load, and - // (2) the bulk load task that creates this segment has not yet reached `ImportCompleted` state. + // (1) this segment is created by bulk insert, and + // (2) the bulk insert task that creates this segment has not yet reached `ImportCompleted` state. bool is_importing = 17; bool is_fake = 18; } diff --git a/internal/proto/datapb/data_coord.pb.go b/internal/proto/datapb/data_coord.pb.go index 62f9254dba..65988cb76e 100644 --- a/internal/proto/datapb/data_coord.pb.go +++ b/internal/proto/datapb/data_coord.pb.go @@ -1598,8 +1598,8 @@ type SegmentInfo struct { CompactionFrom []int64 `protobuf:"varint,15,rep,packed,name=compactionFrom,proto3" json:"compactionFrom,omitempty"` DroppedAt uint64 `protobuf:"varint,16,opt,name=dropped_at,json=droppedAt,proto3" json:"dropped_at,omitempty"` // A flag indicating if: - // (1) this segment is created by bulk load, and - // (2) the bulk load task that creates this segment has not yet reached `ImportCompleted` state. + // (1) this segment is created by bulk insert, and + // (2) the bulk insert task that creates this segment has not yet reached `ImportCompleted` state. IsImporting bool `protobuf:"varint,17,opt,name=is_importing,json=isImporting,proto3" json:"is_importing,omitempty"` IsFake bool `protobuf:"varint,18,opt,name=is_fake,json=isFake,proto3" json:"is_fake,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index cb62cf0002..4f32aca9b9 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -3967,7 +3967,8 @@ func unhealthyStatus() *commonpb.Status { func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { log.Info("received import request", zap.String("collection name", req.GetCollectionName()), - zap.Bool("row-based", req.GetRowBased())) + zap.String("partition name", req.GetPartitionName()), + zap.Strings("files", req.GetFiles())) resp := &milvuspb.ImportResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -3996,7 +3997,7 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi respFromRC, err := node.rootCoord.Import(ctx, req) if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.FailLabel).Inc() - log.Error("failed to execute bulk load request", zap.Error(err)) + log.Error("failed to execute bulk insert request", zap.Error(err)) resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError resp.Status.Reason = err.Error() return resp, nil diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index b4987b44d2..1085f6ad04 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1698,7 +1698,7 @@ func TestProxy(t *testing.T) { defer wg.Done() req := &milvuspb.ImportRequest{ CollectionName: collectionName, - Files: []string{"f1.json", "f2.json", "f3.csv"}, + Files: []string{"f1.json", "f2.json", "f3.json"}, } proxy.stateCode.Store(commonpb.StateCode_Healthy) resp, err := proxy.Import(context.TODO(), req) diff --git a/internal/rootcoord/import_manager.go b/internal/rootcoord/import_manager.go index 44f5e85416..68f6122bc8 100644 --- a/internal/rootcoord/import_manager.go +++ b/internal/rootcoord/import_manager.go @@ -34,7 +34,9 @@ import ( 
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/importutil" "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" ) @@ -210,7 +212,6 @@ func (m *importManager) sendOutTasks(ctx context.Context) error { CollectionId: task.GetCollectionId(), PartitionId: task.GetPartitionId(), ChannelNames: task.GetChannelNames(), - RowBased: task.GetRowBased(), TaskId: task.GetId(), Files: task.GetFiles(), Infos: task.GetInfos(), @@ -381,25 +382,39 @@ func (m *importManager) checkIndexingDone(ctx context.Context, collID UniqueID, return len(allSegmentIDs) == indexedSegmentCount, nil } +func (m *importManager) isRowbased(files []string) (bool, error) { + isRowBased := false + for _, filePath := range files { + _, fileType := importutil.GetFileNameAndExt(filePath) + if fileType == importutil.JSONFileExt { + isRowBased = true + } else if isRowBased { + log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files)) + return isRowBased, fmt.Errorf("row-based data file type must be JSON, file type '%s' is not allowed", fileType) + } + } + + return isRowBased, nil +} + // importJob processes the import request, generates import tasks, sends these tasks to DataCoord, and returns // immediately. func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportRequest, cID int64, pID int64) *milvuspb.ImportResponse { - if req == nil || len(req.Files) == 0 { + returnErrorFunc := func(reason string) *milvuspb.ImportResponse { return &milvuspb.ImportResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "import request is empty", + Reason: reason, }, } } + if req == nil || len(req.Files) == 0 { + return returnErrorFunc("import request is empty") + } + if m.callImportService == nil { - return &milvuspb.ImportResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "import service is not available", - }, - } + return returnErrorFunc("import service is not available") } resp := &milvuspb.ImportResponse{ @@ -409,7 +424,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque Tasks: make([]int64, 0), } - log.Debug("request received", + log.Debug("receive import job", zap.String("collection name", req.GetCollectionName()), zap.Int64("collection ID", cID), zap.Int64("partition ID", pID)) @@ -420,8 +435,13 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque capacity := cap(m.pendingTasks) length := len(m.pendingTasks) + isRowBased, err := m.isRowbased(req.GetFiles()) + if err != nil { + return err + } + taskCount := 1 - if req.RowBased { + if isRowBased { taskCount = len(req.Files) } @@ -433,7 +453,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque } // convert import request to import tasks - if req.RowBased { + if isRowBased { // For row-based importing, each file makes a task. 
taskList := make([]int64, len(req.Files)) for i := 0; i < len(req.Files); i++ { @@ -446,7 +466,6 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque CollectionId: cID, PartitionId: pID, ChannelNames: req.ChannelNames, - RowBased: req.GetRowBased(), Files: []string{req.GetFiles()[i]}, CreateTs: time.Now().Unix(), State: &datapb.ImportTaskState{ @@ -485,7 +504,6 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque CollectionId: cID, PartitionId: pID, ChannelNames: req.ChannelNames, - RowBased: req.GetRowBased(), Files: req.GetFiles(), CreateTs: time.Now().Unix(), State: &datapb.ImportTaskState{ @@ -514,12 +532,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque return nil }() if err != nil { - return &milvuspb.ImportResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - } + return returnErrorFunc(err.Error()) } if sendOutTasksErr := m.sendOutTasks(ctx); sendOutTasksErr != nil { log.Error("fail to send out tasks", zap.Error(sendOutTasksErr)) @@ -1054,7 +1067,6 @@ func cloneImportTaskInfo(taskInfo *datapb.ImportTaskInfo) *datapb.ImportTaskInfo CollectionId: taskInfo.GetCollectionId(), PartitionId: taskInfo.GetPartitionId(), ChannelNames: taskInfo.GetChannelNames(), - RowBased: taskInfo.GetRowBased(), Files: taskInfo.GetFiles(), CreateTs: taskInfo.GetCreateTs(), State: taskInfo.GetState(), diff --git a/internal/rootcoord/import_manager_test.go b/internal/rootcoord/import_manager_test.go index 3b6521ad50..0139312e8c 100644 --- a/internal/rootcoord/import_manager_test.go +++ b/internal/rootcoord/import_manager_test.go @@ -574,6 +574,8 @@ func TestImportManager_ImportJob(t *testing.T) { ErrorCode: commonpb.ErrorCode_Success, }, nil } + + // nil request mgr := newImportManager(context.TODO(), mockKv, idAlloc, nil, callMarkSegmentsDropped, nil, nil, nil, nil) resp := mgr.importJob(context.TODO(), nil, colID, 0) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) @@ -581,27 +583,14 @@ func TestImportManager_ImportJob(t *testing.T) { rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", - RowBased: true, - Files: []string{"f1", "f2", "f3"}, + Files: []string{"f1.json", "f2.json", "f3.json"}, } + // nil callImportService resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - colReq := &milvuspb.ImportRequest{ - CollectionName: "c1", - PartitionName: "p1", - RowBased: false, - Files: []string{"f1", "f2"}, - Options: []*commonpb.KeyValuePair{ - { - Key: importutil.Bucket, - Value: "mybucket", - }, - }, - } - - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { + importServiceFunc := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -609,17 +598,27 @@ func TestImportManager_ImportJob(t *testing.T) { }, nil } - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil) + // row-based case, task count equal to file count + // since the importServiceFunc return error, tasks will be kept in pending list + mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, 
len(rowReq.Files), len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil) + colReq := &milvuspb.ImportRequest{ + CollectionName: "c1", + PartitionName: "p1", + Files: []string{"f1.npy", "f2.npy", "f3.npy"}, + } + + // column-based case, one request generates one task + // since the importServiceFunc return error, tasks will be kept in pending list + mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil) resp = mgr.importJob(context.TODO(), colReq, colID, 0) assert.Equal(t, 1, len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) - fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { + importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil } - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil) + // row-based case, since the importServiceFunc return success, tasks will be sent to working list + mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks)) - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil) + // column-based case, since the importServiceFunc return success, tasks will be sent to working list + mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil) resp = mgr.importJob(context.TODO(), colReq, colID, 0) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 1, len(mgr.workingTasks)) count := 0 - fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { + importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if count >= 2 { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ @@ -654,13 +655,16 @@ }, nil } - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil) + // row-based case, since the importServiceFunc return success for 2 tasks + // the 2 tasks are sent to working list, and 1 task left in pending list + mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, len(rowReq.Files)-2, len(mgr.pendingTasks)) assert.Equal(t, 2, len(mgr.workingTasks)) - for i := 0; i <= 32; i++ { - rowReq.Files = append(rowReq.Files, strconv.Itoa(i)) + // file count exceeds MaxPendingCount, return error + for i := 0; i <= MaxPendingCount; i++ { + rowReq.Files = append(rowReq.Files, strconv.Itoa(i)+".json") } resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) @@ -683,14 +687,12 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", - RowBased: true, - Files:
[]string{"f1", "f2", "f3"}, + Files: []string{"f1.json", "f2.json", "f3.json"}, } colReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", - RowBased: false, - Files: []string{"f1", "f2"}, + Files: []string{"f1.npy", "f2.npy"}, Options: []*commonpb.KeyValuePair{ { Key: importutil.Bucket, @@ -778,8 +780,7 @@ func TestImportManager_TaskState(t *testing.T) { rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", - RowBased: true, - Files: []string{"f1", "f2", "f3"}, + Files: []string{"f1.json", "f2.json", "f3.json"}, } callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) { @@ -817,8 +818,7 @@ func TestImportManager_TaskState(t *testing.T) { assert.Equal(t, int64(2), ti.GetId()) assert.Equal(t, int64(100), ti.GetCollectionId()) assert.Equal(t, int64(0), ti.GetPartitionId()) - assert.Equal(t, true, ti.GetRowBased()) - assert.Equal(t, []string{"f2"}, ti.GetFiles()) + assert.Equal(t, []string{"f2.json"}, ti.GetFiles()) assert.Equal(t, commonpb.ImportState_ImportPersisted, ti.GetState().GetStateCode()) assert.Equal(t, int64(1000), ti.GetState().GetRowCount()) @@ -876,8 +876,7 @@ func TestImportManager_AllocFail(t *testing.T) { rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", - RowBased: true, - Files: []string{"f1", "f2", "f3"}, + Files: []string{"f1.json", "f2.json", "f3.json"}, } callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) { @@ -917,8 +916,7 @@ func TestImportManager_ListAllTasks(t *testing.T) { rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", - RowBased: true, - Files: []string{"f1", "f2", "f3"}, + Files: []string{"f1.json", "f2.json", "f3.json"}, } callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) { return &commonpb.Status{ @@ -1007,3 +1005,22 @@ func TestImportManager_rearrangeTasks(t *testing.T) { assert.Equal(t, int64(50), tasks[1].GetId()) assert.Equal(t, int64(100), tasks[2].GetId()) } + +func TestImportManager_isRowbased(t *testing.T) { + mgr := &importManager{} + + files := []string{"1.json", "2.json"} + rb, err := mgr.isRowbased(files) + assert.Nil(t, err) + assert.True(t, rb) + + files = []string{"1.json", "2.npy"} + rb, err = mgr.isRowbased(files) + assert.NotNil(t, err) + assert.True(t, rb) + + files = []string{"1.npy", "2.npy"} + rb, err = mgr.isRowbased(files) + assert.Nil(t, err) + assert.False(t, rb) +} diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index 7d1961d4cb..50e0f3d2fe 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -86,8 +86,8 @@ type IMetaTable interface { ListAliasesByID(collID UniqueID) []string // TODO: better to accept ctx. - GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) // serve for bulk load. - GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) // serve for bulk load. + GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) // serve for bulk insert. + GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) // serve for bulk insert. // TODO: better to accept ctx. 
AddCredential(credInfo *internalpb.CredentialInfo) error @@ -634,7 +634,7 @@ func (mt *MetaTable) ListAliasesByID(collID UniqueID) []string { return mt.listAliasesByID(collID) } -// GetCollectionNameByID serve for bulk load. TODO: why this didn't accept ts? +// GetCollectionNameByID serve for bulk insert. TODO: why this didn't accept ts? // [Deprecated] func (mt *MetaTable) GetCollectionNameByID(collID UniqueID) (string, error) { mt.ddLock.RLock() @@ -648,7 +648,7 @@ func (mt *MetaTable) GetCollectionNameByID(collID UniqueID) (string, error) { return coll.Name, nil } -// GetPartitionNameByID serve for bulk load. +// GetPartitionNameByID serve for bulk insert. func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) { mt.ddLock.RLock() defer mt.ddLock.RUnlock() @@ -680,7 +680,7 @@ func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID, return "", fmt.Errorf("partition not exist: %d", partitionID) } -// GetCollectionIDByName serve for bulk load. TODO: why this didn't accept ts? +// GetCollectionIDByName serve for bulk insert. TODO: why this didn't accept ts? // [Deprecated] func (mt *MetaTable) GetCollectionIDByName(name string) (UniqueID, error) { mt.ddLock.RLock() @@ -693,7 +693,7 @@ func (mt *MetaTable) GetCollectionIDByName(name string) (UniqueID, error) { return id, nil } -// GetPartitionByName serve for bulk load. +// GetPartitionByName serve for bulk insert. func (mt *MetaTable) GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) { mt.ddLock.RLock() defer mt.ddLock.RUnlock() diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index e6ae475614..ab59feb6d7 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1594,7 +1594,6 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus zap.Strings("virtual channel names", req.GetChannelNames()), zap.Int64("partition ID", pID), zap.Int("# of files = ", len(req.GetFiles())), - zap.Bool("row-based", req.GetRowBased()), ) importJobResp := c.importManager.importJob(ctx, req, cID, pID) return importJobResp, nil @@ -1692,7 +1691,7 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( resendTaskFunc() // Flush all import data segments. if err := c.broker.Flush(ctx, ti.GetCollectionId(), ir.GetSegments()); err != nil { - log.Error("failed to call Flush on bulk load segments", + log.Error("failed to call Flush on bulk insert segments", zap.Int64("task ID", ir.GetTaskId())) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, diff --git a/internal/util/importutil/binlog_adapter.go b/internal/util/importutil/binlog_adapter.go index 969e6d0b37..ec20170022 100644 --- a/internal/util/importutil/binlog_adapter.go +++ b/internal/util/importutil/binlog_adapter.go @@ -17,6 +17,7 @@ package importutil import ( + "context" "encoding/json" "errors" "fmt" @@ -43,29 +44,39 @@ type SegmentFilesHolder struct { // 1. read insert log of each field, then constructs map[storage.FieldID]storage.FieldData in memory. // 2. read delta log to remove deleted entities(TimeStampField is used to apply or skip the operation). // 3. split data according to shard number -// 4. call the callFlushFunc function to flush data into new segment if data size reaches segmentSize. +// 4. call the callFlushFunc function to flush data into new binlog file if data size reaches blockSize. 
type BinlogAdapter struct { + ctx context.Context // for canceling parse process collectionSchema *schemapb.CollectionSchema // collection schema chunkManager storage.ChunkManager // storage interfaces to read binlog files callFlushFunc ImportFlushFunc // call back function to flush segment shardNum int32 // sharding number of the collection - segmentSize int64 // maximum size of a segment(unit:byte) + blockSize int64 // maximum size of a read block(unit:byte) maxTotalSize int64 // maximum size of in-memory segments(unit:byte) primaryKey storage.FieldID // id of primary key primaryType schemapb.DataType // data type of primary key - // a timestamp to define the end point of restore, data after this point will be ignored + // a timestamp to define the start time point of restore, data before this time point will be ignored + // set this value to 0, all the data will be imported + // set this value to math.MaxUint64, all the data will be ignored + // the tsStartPoint value must be less than or equal to tsEndPoint + tsStartPoint uint64 + + // a timestamp to define the end time point of restore, data after this time point will be ignored // set this value to 0, all the data will be ignored // set this value to math.MaxUint64, all the data will be imported + // the tsEndPoint value must be greater than or equal to tsStartPoint tsEndPoint uint64 } -func NewBinlogAdapter(collectionSchema *schemapb.CollectionSchema, +func NewBinlogAdapter(ctx context.Context, + collectionSchema *schemapb.CollectionSchema, shardNum int32, - segmentSize int64, + blockSize int64, maxTotalSize int64, chunkManager storage.ChunkManager, flushFunc ImportFlushFunc, + tsStartPoint uint64, tsEndPoint uint64) (*BinlogAdapter, error) { if collectionSchema == nil { log.Error("Binlog adapter: collection schema is nil") @@ -83,18 +94,20 @@ func NewBinlogAdapter(collectionSchema *schemapb.CollectionSchema, } adapter := &BinlogAdapter{ + ctx: ctx, collectionSchema: collectionSchema, chunkManager: chunkManager, callFlushFunc: flushFunc, shardNum: shardNum, - segmentSize: segmentSize, + blockSize: blockSize, maxTotalSize: maxTotalSize, + tsStartPoint: tsStartPoint, tsEndPoint: tsEndPoint, } // amend the segment size to avoid portential OOM risk - if adapter.segmentSize > MaxSegmentSizeInMemory { - adapter.segmentSize = MaxSegmentSizeInMemory + if adapter.blockSize > MaxSegmentSizeInMemory { + adapter.blockSize = MaxSegmentSizeInMemory } // find out the primary key ID and its data type @@ -142,7 +155,7 @@ func (p *BinlogAdapter) Read(segmentHolder *SegmentFilesHolder) error { // b has these binlog files: b_1, b_2, b_3 ... // Then first round read a_1 and b_1, second round read a_2 and b_2, etc...
// deleted list will be used to remove deleted entities - // if accumulate data exceed segmentSize, call callFlushFunc to generate new segment + // if accumulated data exceeds blockSize, call callFlushFunc to generate a new binlog file batchCount := 0 for _, files := range segmentHolder.fieldFiles { batchCount = len(files) @@ -215,24 +228,30 @@ // read other insert logs and use the shardList to do sharding for fieldID, file := range batchFiles { + // outside context might be canceled (service stop, or future enhancement for canceling import task) + if isCanceled(p.ctx) { + log.Error("Binlog adapter: import task was canceled") + return errors.New("import task was canceled") + } + err = p.readInsertlog(fieldID, file, segmentsData, shardList) if err != nil { return err } } - // flush segment whose size exceed segmentSize - err = p.tryFlushSegments(segmentsData, false) + // flush segments whose size exceeds blockSize + err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.callFlushFunc, p.blockSize, p.maxTotalSize, false) if err != nil { return err } } // finally, force to flush - return p.tryFlushSegments(segmentsData, true) + return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.callFlushFunc, p.blockSize, p.maxTotalSize, true) } -// This method verify the schema and binlog files +// verify method verifies the schema and binlog files // 1. each field must has binlog file // 2. binlog file count of each field must be equal // 3. the collectionSchema doesn't contain TimeStampField and RowIDField since the import_wrapper excludes them, @@ -284,7 +303,7 @@ return nil } -// This method read data from deltalog, and convert to a dict +// readDeltalogs method reads data from deltalog and converts it to a dict // The deltalog data is a list, to improve performance of next step, we convert it to a dict, // key is the deleted ID, value is operation timestamp which is used to apply or skip the delete operation.
func (p *BinlogAdapter) readDeltalogs(segmentHolder *SegmentFilesHolder) (map[int64]uint64, map[string]uint64, error) { @@ -318,7 +337,7 @@ func (p *BinlogAdapter) readDeltalogs(segmentHolder *SegmentFilesHolder) (map[in } } -// Decode string array(read from delta log) to storage.DeleteLog array +// decodeDeleteLogs decodes string array(read from delta log) to storage.DeleteLog array func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*storage.DeleteLog, error) { // step 1: read all delta logs to construct a string array, each string is marshaled from storage.DeleteLog stringArray := make([]string, 0) @@ -345,8 +364,9 @@ func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]* return nil, err } - // ignore deletions whose timestamp is larger than the tsEndPoint - if deleteLog.Ts <= p.tsEndPoint { + // only the ts between tsStartPoint and tsEndPoint is effective + // ignore deletions whose timestamp is larger than the tsEndPoint or less than tsStartPoint + if deleteLog.Ts >= p.tsStartPoint && deleteLog.Ts <= p.tsEndPoint { deleteLogs = append(deleteLogs, deleteLog) } } @@ -365,7 +385,7 @@ func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]* return deleteLogs, nil } -// Decode a string to storage.DeleteLog +// decodeDeleteLog decodes a string to storage.DeleteLog // Note: the following code is mainly come from data_codec.go, I suppose the code can compatible with old version 2.0 func (p *BinlogAdapter) decodeDeleteLog(deltaStr string) (*storage.DeleteLog, error) { deleteLog := &storage.DeleteLog{} @@ -399,7 +419,7 @@ func (p *BinlogAdapter) decodeDeleteLog(deltaStr string) (*storage.DeleteLog, er return deleteLog, nil } -// Each delta log data type is varchar, marshaled from an array of storage.DeleteLog objects. +// readDeltalog parses a delta log file. Each delta log data type is varchar, marshaled from an array of storage.DeleteLog objects. func (p *BinlogAdapter) readDeltalog(logPath string) ([]string, error) { // open the delta log file binlogFile, err := NewBinlogFile(p.chunkManager) @@ -426,7 +446,7 @@ func (p *BinlogAdapter) readDeltalog(logPath string) ([]string, error) { return data, nil } -// This method read data from int64 field, currently we use it to read the timestamp field. +// readTimestamp method reads data from int64 field, currently we use it to read the timestamp field. func (p *BinlogAdapter) readTimestamp(logPath string) ([]int64, error) { // open the log file binlogFile, err := NewBinlogFile(p.chunkManager) @@ -454,7 +474,7 @@ func (p *BinlogAdapter) readTimestamp(logPath string) ([]int64, error) { return int64List, nil } -// This method read primary keys from insert log. +// readPrimaryKeys method reads primary keys from insert log. func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, error) { // open the delta log file binlogFile, err := NewBinlogFile(p.chunkManager) @@ -493,7 +513,7 @@ func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, erro } } -// This method generate a shard id list by primary key(int64) list and deleted list. +// getShardingListByPrimaryInt64 method generates a shard id list by primary key(int64) list and deleted list. // For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be: // [0, 1, -1, 1, 0, 1, -1, 1, 0, 1] // Compare timestampList with tsEndPoint to skip some rows. 
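// The hunks just below change getShardingListByPrimaryInt64/getShardingListByPrimaryVarchar so that a
// row is kept only when its timestamp falls inside [tsStartPoint, tsEndPoint]. A minimal sketch of that
// shard-list rule follows; exampleShardList is illustrative only, and a plain modulo stands in for the
// primary-key hash actually used by BinlogAdapter (the hashing code is outside these hunks). Rows whose
// timestamp is out of range, and rows found in the deleted-ID map, get shard ID -1 and are skipped later.
func exampleShardList(primaryKeys []int64, timestamps []int64, deleted map[int64]uint64,
	shardNum int32, tsStartPoint uint64, tsEndPoint uint64) []int32 {
	shardList := make([]int32, 0, len(primaryKeys))
	for i, key := range primaryKeys {
		ts := uint64(timestamps[i])
		if ts < tsStartPoint || ts > tsEndPoint {
			shardList = append(shardList, -1) // excluded by the import time range
			continue
		}
		if _, ok := deleted[key]; ok {
			shardList = append(shardList, -1) // entity removed by a delta log entry
			continue
		}
		shardList = append(shardList, int32(uint64(key)%uint64(shardNum))) // modulo as a stand-in hash
	}
	return shardList
}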
@@ -513,10 +533,10 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64, excluded := 0 shardList := make([]int32, 0, len(primaryKeys)) for i, key := range primaryKeys { - // if this entity's timestamp is greater than the tsEndPoint, set shardID = -1 to skip this entity + // if this entity's timestamp is greater than the tsEndPoint, or less than tsStartPoint, set shardID = -1 to skip this entity // timestamp is stored as int64 type in log file, actually it is uint64, compare with uint64 ts := timestampList[i] - if uint64(ts) > p.tsEndPoint { + if uint64(ts) > p.tsEndPoint || uint64(ts) < p.tsStartPoint { shardList = append(shardList, -1) excluded++ continue @@ -546,7 +566,7 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64, return shardList, nil } -// This method generate a shard id list by primary key(varchar) list and deleted list. +// getShardingListByPrimaryVarchar method generates a shard id list by primary key(varchar) list and deleted list. // For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be: // [0, 1, -1, 1, 0, 1, -1, 1, 0, 1] func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string, @@ -565,10 +585,10 @@ func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string, excluded := 0 shardList := make([]int32, 0, len(primaryKeys)) for i, key := range primaryKeys { - // if this entity's timestamp is greater than the tsEndPoint, set shardID = -1 to skip this entity + // if this entity's timestamp is greater than the tsEndPoint, or less than tsStartPoint, set shardID = -1 to skip this entity // timestamp is stored as int64 type in log file, actually it is uint64, compare with uint64 ts := timestampList[i] - if uint64(ts) > p.tsEndPoint { + if uint64(ts) > p.tsEndPoint || uint64(ts) < p.tsStartPoint { shardList = append(shardList, -1) excluded++ continue @@ -598,7 +618,7 @@ func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string, return shardList, nil } -// This method read an insert log, and split the data into different shards according to a shard list +// readInsertlog method reads an insert log, and split the data into different shards according to a shard list // The shardList is a list to tell which row belong to which shard, returned by getShardingListByPrimaryXXX() // For deleted rows, we say its shard id is -1. // For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be: @@ -1000,96 +1020,3 @@ func (p *BinlogAdapter) dispatchFloatVecToShards(data []float32, dim int, memory return nil } - -// This method do the two things: -// 1. if accumulate data of a segment exceed segmentSize, call callFlushFunc to generate new segment -// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest segment -func (p *BinlogAdapter) tryFlushSegments(segmentsData []map[storage.FieldID]storage.FieldData, force bool) error { - totalSize := 0 - biggestSize := 0 - biggestItem := -1 - - // 1. 
if accumulate data of a segment exceed segmentSize, call callFlushFunc to generate new segment - for i := 0; i < len(segmentsData); i++ { - segmentData := segmentsData[i] - // Note: even rowCount is 0, the size is still non-zero - size := 0 - rowCount := 0 - for _, fieldData := range segmentData { - size += fieldData.GetMemorySize() - rowCount = fieldData.RowNum() - } - - // force to flush, called at the end of Read() - if force && rowCount > 0 { - err := p.callFlushFunc(segmentData, i) - if err != nil { - log.Error("Binlog adapter: failed to force flush segment data", zap.Int("shardID", i)) - return err - } - log.Info("Binlog adapter: force flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i)) - - segmentsData[i] = initSegmentData(p.collectionSchema) - if segmentsData[i] == nil { - log.Error("Binlog adapter: failed to initialize FieldData list") - return errors.New("failed to initialize FieldData list") - } - continue - } - - // if segment size is larger than predefined segmentSize, flush to create a new segment - // initialize a new FieldData list for next round batch read - if size > int(p.segmentSize) && rowCount > 0 { - err := p.callFlushFunc(segmentData, i) - if err != nil { - log.Error("Binlog adapter: failed to flush segment data", zap.Int("shardID", i)) - return err - } - log.Info("Binlog adapter: segment size exceed limit and flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i)) - - segmentsData[i] = initSegmentData(p.collectionSchema) - if segmentsData[i] == nil { - log.Error("Binlog adapter: failed to initialize FieldData list") - return errors.New("failed to initialize FieldData list") - } - continue - } - - // calculate the total size(ignore the flushed segments) - // find out the biggest segment for the step 2 - totalSize += size - if size > biggestSize { - biggestSize = size - biggestItem = i - } - } - - // 2. 
if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest segment - if totalSize > int(p.maxTotalSize) && biggestItem >= 0 { - segmentData := segmentsData[biggestItem] - size := 0 - rowCount := 0 - for _, fieldData := range segmentData { - size += fieldData.GetMemorySize() - rowCount = fieldData.RowNum() - } - - if rowCount > 0 { - err := p.callFlushFunc(segmentData, biggestItem) - if err != nil { - log.Error("Binlog adapter: failed to flush biggest segment data", zap.Int("shardID", biggestItem)) - return err - } - log.Info("Binlog adapter: total size exceed limit and flush", zap.Int("rowCount", rowCount), - zap.Int("size", size), zap.Int("totalSize", totalSize), zap.Int("shardID", biggestItem)) - - segmentsData[biggestItem] = initSegmentData(p.collectionSchema) - if segmentsData[biggestItem] == nil { - log.Error("Binlog adapter: failed to initialize FieldData list") - return errors.New("failed to initialize FieldData list") - } - } - } - - return nil -} diff --git a/internal/util/importutil/binlog_adapter_test.go b/internal/util/importutil/binlog_adapter_test.go index 9f60c26343..ce6cf8e7b9 100644 --- a/internal/util/importutil/binlog_adapter_test.go +++ b/internal/util/importutil/binlog_adapter_test.go @@ -16,8 +16,10 @@ package importutil import ( + "context" "encoding/json" "errors" + "math" "strconv" "testing" @@ -154,18 +156,20 @@ func createSegmentsData(fieldsData map[storage.FieldID]interface{}, shardNum int } func Test_NewBinlogAdapter(t *testing.T) { + ctx := context.Background() + // nil schema - adapter, err := NewBinlogAdapter(nil, 2, 1024, 2048, nil, nil, 0) + adapter, err := NewBinlogAdapter(ctx, nil, 2, 1024, 2048, nil, nil, 0, math.MaxUint64) assert.Nil(t, adapter) assert.NotNil(t, err) // nil chunkmanager - adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, nil, nil, 0) + adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, nil, nil, 0, math.MaxUint64) assert.Nil(t, adapter) assert.NotNil(t, err) // nil flushfunc - adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, nil, 0) + adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, nil, 0, math.MaxUint64) assert.Nil(t, adapter) assert.NotNil(t, err) @@ -173,7 +177,7 @@ func Test_NewBinlogAdapter(t *testing.T) { flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter, err = NewBinlogAdapter(sampleSchema(), 2, 2048, 1024, &MockChunkManager{}, flushFunc, 0) + adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 2048, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -191,16 +195,18 @@ func Test_NewBinlogAdapter(t *testing.T) { }, }, } - adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.Nil(t, adapter) assert.NotNil(t, err) } func Test_BinlogAdapterVerify(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -247,6 +253,8 @@ func Test_BinlogAdapterVerify(t *testing.T) { } func 
Test_BinlogAdapterReadDeltalog(t *testing.T) { + ctx := context.Background() + deleteItems := []int64{1001, 1002, 1003} buf := createDeltalogBuf(t, deleteItems, false) chunkManager := &MockChunkManager{ @@ -259,7 +267,7 @@ func Test_BinlogAdapterReadDeltalog(t *testing.T) { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -283,6 +291,8 @@ func Test_BinlogAdapterReadDeltalog(t *testing.T) { } func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) { + ctx := context.Background() + deleteItems := []int64{1001, 1002, 1003, 1004, 1005} buf := createDeltalogBuf(t, deleteItems, false) chunkManager := &MockChunkManager{ @@ -295,7 +305,7 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -316,7 +326,7 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) { "dummy": createDeltalogBuf(t, []string{"1001", "1002"}, true), } - adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -327,11 +337,13 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) { } func Test_BinlogAdapterDecodeDeleteLog(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -378,6 +390,8 @@ func Test_BinlogAdapterDecodeDeleteLog(t *testing.T) { } func Test_BinlogAdapterReadDeltalogs(t *testing.T) { + ctx := context.Background() + deleteItems := []int64{1001, 1002, 1003, 1004, 1005} buf := createDeltalogBuf(t, deleteItems, false) chunkManager := &MockChunkManager{ @@ -390,7 +404,7 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -432,17 +446,19 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) { "dummy": createDeltalogBuf(t, []string{"1001", "1002"}, true), } - adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) // 2.1 all deletion have been filtered out + adapter.tsStartPoint = baseTimestamp + 2 intDeletions, strDeletions, err = adapter.readDeltalogs(holder) assert.Nil(t, err) assert.Nil(t, intDeletions) assert.Nil(t, strDeletions) // 2.2 filter the no.1 and no.2 deletion + adapter.tsStartPoint = 0 adapter.tsEndPoint = baseTimestamp + 1 intDeletions, strDeletions, err = adapter.readDeltalogs(holder) assert.Nil(t, err) @@ -470,7 +486,7 
@@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) { }, } - adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -482,10 +498,12 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) { } func Test_BinlogAdapterReadTimestamp(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -515,10 +533,12 @@ func Test_BinlogAdapterReadTimestamp(t *testing.T) { } func Test_BinlogAdapterReadPrimaryKeys(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -570,12 +590,14 @@ func Test_BinlogAdapterReadPrimaryKeys(t *testing.T) { } func Test_BinlogAdapterShardListInt64(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } shardNum := int32(2) - adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -610,12 +632,14 @@ func Test_BinlogAdapterShardListInt64(t *testing.T) { } func Test_BinlogAdapterShardListVarchar(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } shardNum := int32(2) - adapter, err := NewBinlogAdapter(strKeySchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, strKeySchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -650,6 +674,8 @@ func Test_BinlogAdapterShardListVarchar(t *testing.T) { } func Test_BinlogAdapterReadInt64PK(t *testing.T) { + ctx := context.Background() + chunkManager := &MockChunkManager{} flushCounter := 0 @@ -669,7 +695,7 @@ func Test_BinlogAdapterReadInt64PK(t *testing.T) { } shardNum := int32(2) - adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) adapter.tsEndPoint = baseTimestamp + 1 @@ -741,6 +767,8 @@ func Test_BinlogAdapterReadInt64PK(t *testing.T) { } func Test_BinlogAdapterReadVarcharPK(t *testing.T) { + ctx := context.Background() + chunkManager := &MockChunkManager{} flushCounter := 0 @@ -814,7 +842,7 @@ func Test_BinlogAdapterReadVarcharPK(t *testing.T) { // succeed shardNum := int32(3) - adapter, err := NewBinlogAdapter(strKeySchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0) + adapter, err := 
NewBinlogAdapter(ctx, strKeySchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -825,93 +853,14 @@ func Test_BinlogAdapterReadVarcharPK(t *testing.T) { assert.Equal(t, rowCount-502, flushRowCount) } -func Test_BinlogAdapterTryFlush(t *testing.T) { - flushCounter := 0 - flushRowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { - flushCounter++ - rowCount := 0 - for _, v := range fields { - rowCount = v.RowNum() - break - } - flushRowCount += rowCount - for _, v := range fields { - assert.Equal(t, rowCount, v.RowNum()) - } - return nil - } - - segmentSize := int64(1024) - maxTotalSize := int64(2048) - shardNum := int32(3) - adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, segmentSize, maxTotalSize, &MockChunkManager{}, flushFunc, 0) - assert.NotNil(t, adapter) - assert.Nil(t, err) - - // prepare flush data, 3 shards, each shard 10 rows - rowCount := 10 - fieldsData := createFieldsData(rowCount) - - // non-force flush - segmentsData := createSegmentsData(fieldsData, shardNum) - err = adapter.tryFlushSegments(segmentsData, false) - assert.Nil(t, err) - assert.Equal(t, 0, flushCounter) - assert.Equal(t, 0, flushRowCount) - - // force flush - err = adapter.tryFlushSegments(segmentsData, true) - assert.Nil(t, err) - assert.Equal(t, int(shardNum), flushCounter) - assert.Equal(t, rowCount*int(shardNum), flushRowCount) - - // after force flush, no data left - flushCounter = 0 - flushRowCount = 0 - err = adapter.tryFlushSegments(segmentsData, true) - assert.Nil(t, err) - assert.Equal(t, 0, flushCounter) - assert.Equal(t, 0, flushRowCount) - - // flush when segment size exceeds segmentSize - segmentsData = createSegmentsData(fieldsData, shardNum) - adapter.segmentSize = 100 // segmentSize is 100 bytes, less than the 10 rows size - err = adapter.tryFlushSegments(segmentsData, false) - assert.Nil(t, err) - assert.Equal(t, int(shardNum), flushCounter) - assert.Equal(t, rowCount*int(shardNum), flushRowCount) - - flushCounter = 0 - flushRowCount = 0 - err = adapter.tryFlushSegments(segmentsData, true) // no data left - assert.Nil(t, err) - assert.Equal(t, 0, flushCounter) - assert.Equal(t, 0, flushRowCount) - - // flush when segments total size exceeds maxTotalSize - segmentsData = createSegmentsData(fieldsData, shardNum) - adapter.segmentSize = 4096 // segmentSize is 4096 bytes, larger than the 10 rows size - adapter.maxTotalSize = 100 // maxTotalSize is 100 bytes, less than the 30 rows size - err = adapter.tryFlushSegments(segmentsData, false) - assert.Nil(t, err) - assert.Equal(t, 1, flushCounter) // only the max segment is flushed - assert.Equal(t, 10, flushRowCount) - - flushCounter = 0 - flushRowCount = 0 - err = adapter.tryFlushSegments(segmentsData, true) // two segments left - assert.Nil(t, err) - assert.Equal(t, 2, flushCounter) - assert.Equal(t, 20, flushRowCount) -} - func Test_BinlogAdapterDispatch(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } shardNum := int32(3) - adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) @@ -1055,10 +1004,12 @@ func Test_BinlogAdapterDispatch(t *testing.T) { } func Test_BinlogAdapterReadInsertlog(t *testing.T) { + 
ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0) + adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, adapter) assert.Nil(t, err) diff --git a/internal/util/importutil/binlog_file.go b/internal/util/importutil/binlog_file.go index eff0a3ed48..dfd3b550b0 100644 --- a/internal/util/importutil/binlog_file.go +++ b/internal/util/importutil/binlog_file.go @@ -26,7 +26,7 @@ import ( "go.uber.org/zap" ) -// This class is a wrapper of storage.BinlogReader, to read binlog file, block by block. +// BinlogFile class is a wrapper of storage.BinlogReader, to read binlog file, block by block. // Note: for bulkoad function, we only handle normal insert log and delta log. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. // Typically, an insert log file size is 16MB. @@ -72,7 +72,7 @@ func (p *BinlogFile) Open(filePath string) error { return nil } -// The outer caller must call this method in defer +// Close close the reader object, outer caller must call this method in defer func (p *BinlogFile) Close() { if p.reader != nil { p.reader.Close() @@ -88,8 +88,8 @@ func (p *BinlogFile) DataType() schemapb.DataType { return p.reader.PayloadDataType } +// ReadBool method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadBool() ([]bool, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -131,8 +131,8 @@ func (p *BinlogFile) ReadBool() ([]bool, error) { return result, nil } +// ReadInt8 method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadInt8() ([]int8, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -174,8 +174,8 @@ func (p *BinlogFile) ReadInt8() ([]int8, error) { return result, nil } +// ReadInt16 method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadInt16() ([]int16, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -217,8 +217,8 @@ func (p *BinlogFile) ReadInt16() ([]int16, error) { return result, nil } +// ReadInt32 method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadInt32() ([]int32, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -260,8 +260,8 @@ func (p *BinlogFile) ReadInt32() ([]int32, error) { return result, nil } +// ReadInt64 method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. 
-// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadInt64() ([]int64, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -303,8 +303,8 @@ func (p *BinlogFile) ReadInt64() ([]int64, error) { return result, nil } +// ReadFloat method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadFloat() ([]float32, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -346,8 +346,8 @@ func (p *BinlogFile) ReadFloat() ([]float32, error) { return result, nil } +// ReadDouble method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadDouble() ([]float64, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -389,8 +389,8 @@ func (p *BinlogFile) ReadDouble() ([]float64, error) { return result, nil } +// ReadVarchar method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. func (p *BinlogFile) ReadVarchar() ([]string, error) { if p.reader == nil { log.Error("Binlog file: binlog reader not yet initialized") @@ -433,8 +433,8 @@ func (p *BinlogFile) ReadVarchar() ([]string, error) { return result, nil } +// ReadBinaryVector method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. // return vectors data and the dimension func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) { if p.reader == nil { @@ -479,8 +479,8 @@ func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) { return result, dim, nil } +// ReadFloatVector method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. -// This method read all the blocks of a binlog by a data type. 
// return vectors data and the dimension func (p *BinlogFile) ReadFloatVector() ([]float32, int, error) { if p.reader == nil { diff --git a/internal/util/importutil/binlog_file_test.go b/internal/util/importutil/binlog_file_test.go index 038c5a347d..254351e9c8 100644 --- a/internal/util/importutil/binlog_file_test.go +++ b/internal/util/importutil/binlog_file_test.go @@ -599,7 +599,7 @@ func Test_BinlogFileDouble(t *testing.T) { } func Test_BinlogFileVarchar(t *testing.T) { - source := []string{"a", "b", "c", "d"} + source := []string{"a", "bb", "罗伯特", "d"} chunkManager := &MockChunkManager{ readBuf: map[string][]byte{ "dummy": createBinlogBuf(t, schemapb.DataType_VarChar, source), diff --git a/internal/util/importutil/binlog_parser.go b/internal/util/importutil/binlog_parser.go index 181107d42d..1fb4e96ccd 100644 --- a/internal/util/importutil/binlog_parser.go +++ b/internal/util/importutil/binlog_parser.go @@ -19,6 +19,7 @@ package importutil import ( "context" "errors" + "fmt" "path" "sort" "strconv" @@ -30,23 +31,33 @@ import ( ) type BinlogParser struct { + ctx context.Context // for canceling parse process collectionSchema *schemapb.CollectionSchema // collection schema shardNum int32 // sharding number of the collection - segmentSize int64 // maximum size of a segment(unit:byte) + blockSize int64 // maximum size of a read block(unit:byte) chunkManager storage.ChunkManager // storage interfaces to browse/read the files callFlushFunc ImportFlushFunc // call back function to flush segment - // a timestamp to define the end point of restore, data after this point will be ignored + // a timestamp to define the start time point of restore, data before this time point will be ignored + // set this value to 0, all the data will be imported + // set this value to math.MaxUint64, all the data will be ignored + // the tsStartPoint value must be less/equal than tsEndPoint + tsStartPoint uint64 + + // a timestamp to define the end time point of restore, data after this time point will be ignored // set this value to 0, all the data will be ignored // set this value to math.MaxUint64, all the data will be imported + // the tsEndPoint value must be larger/equal than tsStartPoint tsEndPoint uint64 } -func NewBinlogParser(collectionSchema *schemapb.CollectionSchema, +func NewBinlogParser(ctx context.Context, + collectionSchema *schemapb.CollectionSchema, shardNum int32, - segmentSize int64, + blockSize int64, chunkManager storage.ChunkManager, flushFunc ImportFlushFunc, + tsStartPoint uint64, tsEndPoint uint64) (*BinlogParser, error) { if collectionSchema == nil { log.Error("Binlog parser: collection schema is nil") @@ -63,18 +74,27 @@ func NewBinlogParser(collectionSchema *schemapb.CollectionSchema, return nil, errors.New("flush function is nil") } + if tsStartPoint > tsEndPoint { + log.Error("Binlog parser: the tsStartPoint should be less than tsEndPoint", + zap.Uint64("tsStartPoint", tsStartPoint), zap.Uint64("tsEndPoint", tsEndPoint)) + return nil, fmt.Errorf("Binlog parser: the tsStartPoint %d should be less than tsEndPoint %d", tsStartPoint, tsEndPoint) + } + v := &BinlogParser{ + ctx: ctx, collectionSchema: collectionSchema, shardNum: shardNum, - segmentSize: segmentSize, + blockSize: blockSize, chunkManager: chunkManager, callFlushFunc: flushFunc, + tsStartPoint: tsStartPoint, tsEndPoint: tsEndPoint, } return v, nil } +// constructSegmentHolders builds a list of SegmentFilesHolder, each SegmentFilesHolder represents a segment folder // For instance, the insertlogRoot is 
"backup/bak1/data/insert_log/435978159196147009/435978159196147010". // 435978159196147009 is a collection id, 435978159196147010 is a partition id, // there is a segment(id is 435978159261483009) under this partition. @@ -195,8 +215,8 @@ func (p *BinlogParser) parseSegmentFiles(segmentHolder *SegmentFilesHolder) erro return errors.New("segment files holder is nil") } - adapter, err := NewBinlogAdapter(p.collectionSchema, p.shardNum, p.segmentSize, - MaxTotalSizeInMemory, p.chunkManager, p.callFlushFunc, p.tsEndPoint) + adapter, err := NewBinlogAdapter(p.ctx, p.collectionSchema, p.shardNum, p.blockSize, + MaxTotalSizeInMemory, p.chunkManager, p.callFlushFunc, p.tsStartPoint, p.tsEndPoint) if err != nil { log.Error("Binlog parser: failed to create binlog adapter", zap.Error(err)) return err @@ -205,7 +225,7 @@ func (p *BinlogParser) parseSegmentFiles(segmentHolder *SegmentFilesHolder) erro return adapter.Read(segmentHolder) } -// This functions requires two paths: +// Parse requires two paths: // 1. the insert log path of a partition // 2. the delta log path of a partiion (optional) func (p *BinlogParser) Parse(filePaths []string) error { diff --git a/internal/util/importutil/binlog_parser_test.go b/internal/util/importutil/binlog_parser_test.go index 33e8a10b94..7ad375e766 100644 --- a/internal/util/importutil/binlog_parser_test.go +++ b/internal/util/importutil/binlog_parser_test.go @@ -16,7 +16,9 @@ package importutil import ( + "context" "errors" + "math" "path" "strconv" "testing" @@ -27,18 +29,20 @@ import ( ) func Test_NewBinlogParser(t *testing.T) { + ctx := context.Background() + // nil schema - parser, err := NewBinlogParser(nil, 2, 1024, nil, nil, 0) + parser, err := NewBinlogParser(ctx, nil, 2, 1024, nil, nil, 0, math.MaxUint64) assert.Nil(t, parser) assert.NotNil(t, err) // nil chunkmanager - parser, err = NewBinlogParser(sampleSchema(), 2, 1024, nil, nil, 0) + parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, nil, nil, 0, math.MaxUint64) assert.Nil(t, parser) assert.NotNil(t, err) // nil flushfunc - parser, err = NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, nil, 0) + parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, nil, 0, math.MaxUint64) assert.Nil(t, parser) assert.NotNil(t, err) @@ -46,12 +50,19 @@ func Test_NewBinlogParser(t *testing.T) { flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - parser, err = NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0) + parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, parser) assert.Nil(t, err) + + // tsStartPoint larger than tsEndPoint + parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 2, 1) + assert.Nil(t, parser) + assert.NotNil(t, err) } func Test_BinlogParserConstructHolders(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } @@ -116,7 +127,7 @@ func Test_BinlogParserConstructHolders(t *testing.T) { "backup/bak1/data/delta_log/435978159196147009/435978159196147010/435978159261483009/434574382554415105", } - parser, err := NewBinlogParser(sampleSchema(), 2, 1024, chunkManager, flushFunc, 0) + parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, parser) assert.Nil(t, err) @@ -165,6 +176,8 @@ func 
Test_BinlogParserConstructHolders(t *testing.T) { } func Test_BinlogParserConstructHoldersFailed(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } @@ -174,7 +187,7 @@ func Test_BinlogParserConstructHoldersFailed(t *testing.T) { listResult: make(map[string][]string), } - parser, err := NewBinlogParser(sampleSchema(), 2, 1024, chunkManager, flushFunc, 0) + parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, parser) assert.Nil(t, err) @@ -214,11 +227,13 @@ func Test_BinlogParserConstructHoldersFailed(t *testing.T) { } func Test_BinlogParserParseFilesFailed(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } - parser, err := NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0) + parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64) assert.NotNil(t, parser) assert.Nil(t, err) @@ -231,6 +246,8 @@ func Test_BinlogParserParseFilesFailed(t *testing.T) { } func Test_BinlogParserParse(t *testing.T) { + ctx := context.Background() + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { return nil } @@ -251,7 +268,7 @@ func Test_BinlogParserParse(t *testing.T) { }, } - parser, err := NewBinlogParser(schema, 2, 1024, chunkManager, flushFunc, 0) + parser, err := NewBinlogParser(ctx, schema, 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64) assert.NotNil(t, parser) assert.Nil(t, err) diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go new file mode 100644 index 0000000000..0176416b35 --- /dev/null +++ b/internal/util/importutil/import_util.go @@ -0,0 +1,512 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "context" + "errors" + "fmt" + "path" + "runtime/debug" + "strconv" + "strings" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/storage" +) + +func isCanceled(ctx context.Context) bool { + // canceled? 
+ select { + case <-ctx.Done(): + return true + default: + break + } + return false +} + +func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData { + segmentData := make(map[storage.FieldID]storage.FieldData) + // rowID field is a hidden field with fieldID=0, it is always auto-generated by IDAllocator + // if primary key is int64 and autoID=true, primary key field is equal to rowID field + segmentData[common.RowIDField] = &storage.Int64FieldData{ + Data: make([]int64, 0), + NumRows: []int64{0}, + } + + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + switch schema.DataType { + case schemapb.DataType_Bool: + segmentData[schema.GetFieldID()] = &storage.BoolFieldData{ + Data: make([]bool, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Float: + segmentData[schema.GetFieldID()] = &storage.FloatFieldData{ + Data: make([]float32, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Double: + segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{ + Data: make([]float64, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int8: + segmentData[schema.GetFieldID()] = &storage.Int8FieldData{ + Data: make([]int8, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int16: + segmentData[schema.GetFieldID()] = &storage.Int16FieldData{ + Data: make([]int16, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int32: + segmentData[schema.GetFieldID()] = &storage.Int32FieldData{ + Data: make([]int32, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_Int64: + segmentData[schema.GetFieldID()] = &storage.Int64FieldData{ + Data: make([]int64, 0), + NumRows: []int64{0}, + } + case schemapb.DataType_BinaryVector: + dim, _ := getFieldDimension(schema) + segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{ + Data: make([]byte, 0), + NumRows: []int64{0}, + Dim: dim, + } + case schemapb.DataType_FloatVector: + dim, _ := getFieldDimension(schema) + segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{ + Data: make([]float32, 0), + NumRows: []int64{0}, + Dim: dim, + } + case schemapb.DataType_String, schemapb.DataType_VarChar: + segmentData[schema.GetFieldID()] = &storage.StringFieldData{ + Data: make([]string, 0), + NumRows: []int64{0}, + } + default: + log.Error("Import util: unsupported data type", zap.Int("DataType", int(schema.DataType))) + return nil + } + } + + return segmentData +} + +// initValidators constructs valiator methods and data conversion methods +func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error { + if collectionSchema == nil { + return errors.New("collection schema is nil") + } + + // json decoder parse all the numeric value into float64 + numericValidator := func(obj interface{}) error { + switch obj.(type) { + case float64: + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := "illegal numeric value " + s + return errors.New(msg) + } + + } + + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + + validators[schema.GetFieldID()] = &Validator{} + validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey() + validators[schema.GetFieldID()].autoID = schema.GetAutoID() + validators[schema.GetFieldID()].fieldName = schema.GetName() + validators[schema.GetFieldID()].isString = false + + switch schema.DataType { + case schemapb.DataType_Bool: + validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { + switch 
obj.(type) { + case bool: + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := "illegal value " + s + " for bool type field " + schema.GetName() + return errors.New(msg) + } + + } + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := obj.(bool) + field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value) + field.(*storage.BoolFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Float: + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := float32(obj.(float64)) + field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value) + field.(*storage.FloatFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Double: + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := obj.(float64) + field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) + field.(*storage.DoubleFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int8: + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int8(obj.(float64)) + field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value) + field.(*storage.Int8FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int16: + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int16(obj.(float64)) + field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value) + field.(*storage.Int16FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int32: + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int32(obj.(float64)) + field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value) + field.(*storage.Int32FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_Int64: + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := int64(obj.(float64)) + field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) + field.(*storage.Int64FieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_BinaryVector: + dim, err := getFieldDimension(schema) + if err != nil { + return err + } + validators[schema.GetFieldID()].dimension = dim + + validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { + switch vt := obj.(type) { + case []interface{}: + if len(vt)*8 != dim { + msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName() + return errors.New(msg) + } + for i := 0; i < len(vt); i++ { + if e := numericValidator(vt[i]); e != nil { + msg := e.Error() + " for binary vector field " + schema.GetName() + return errors.New(msg) + } + + t := int(vt[i].(float64)) + if t > 255 || t < 0 { + msg := "illegal value " + 
strconv.Itoa(t) + " for binary vector field " + schema.GetName() + return errors.New(msg) + } + } + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := s + " is not an array for binary vector field " + schema.GetName() + return errors.New(msg) + } + } + + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + arr := obj.([]interface{}) + for i := 0; i < len(arr); i++ { + value := byte(arr[i].(float64)) + field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value) + } + + field.(*storage.BinaryVectorFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_FloatVector: + dim, err := getFieldDimension(schema) + if err != nil { + return err + } + validators[schema.GetFieldID()].dimension = dim + + validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { + switch vt := obj.(type) { + case []interface{}: + if len(vt) != dim { + msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName() + return errors.New(msg) + } + for i := 0; i < len(vt); i++ { + if e := numericValidator(vt[i]); e != nil { + msg := e.Error() + " for float vector field " + schema.GetName() + return errors.New(msg) + } + } + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := s + " is not an array for float vector field " + schema.GetName() + return errors.New(msg) + } + } + + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + arr := obj.([]interface{}) + for i := 0; i < len(arr); i++ { + value := float32(arr[i].(float64)) + field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value) + } + field.(*storage.FloatVectorFieldData).NumRows[0]++ + return nil + } + case schemapb.DataType_String, schemapb.DataType_VarChar: + validators[schema.GetFieldID()].isString = true + validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { + switch obj.(type) { + case string: + return nil + default: + s := fmt.Sprintf("%v", obj) + msg := s + " is not a string for string type field " + schema.GetName() + return errors.New(msg) + } + } + + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { + value := obj.(string) + field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value) + field.(*storage.StringFieldData).NumRows[0]++ + return nil + } + default: + return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType))) + } + } + + return nil +} + +func printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) { + stats := make([]zapcore.Field, 0) + for k, v := range fieldsData { + stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum())) + } + + if len(files) > 0 { + stats = append(stats, zap.Any("files", files)) + } + log.Info(msg, stats...) 
+} + +// GetFileNameAndExt extracts file name and extension +// for example: "/a/b/c.ttt" returns "c" and ".ttt" +func GetFileNameAndExt(filePath string) (string, string) { + fileName := path.Base(filePath) + fileType := path.Ext(fileName) + fileNameWithoutExt := strings.TrimSuffix(fileName, fileType) + return fileNameWithoutExt, fileType +} + +// getFieldDimension gets dimension of vector field +func getFieldDimension(schema *schemapb.FieldSchema) (int, error) { + for _, kvPair := range schema.GetTypeParams() { + key, value := kvPair.GetKey(), kvPair.GetValue() + if key == "dim" { + dim, err := strconv.Atoi(value) + if err != nil { + return 0, errors.New("vector dimension is invalid") + } + return dim, nil + } + } + + return 0, errors.New("vector dimension is not defined") +} + +// triggerGC triggers golang gc to return all free memory back to the underlying system at once, +// Note: this operation is expensive, and can lead to latency spikes as it holds the heap lock through the whole process +func triggerGC() { + debug.FreeOSMemory() +} + +// tryFlushBlocks does two things: +// 1. if accumulated data of a block exceeds blockSize, call callFlushFunc to generate a new binlog file +// 2. if total accumulated data exceeds maxTotalSize, call callFlushFunc to flush the biggest block +func tryFlushBlocks(ctx context.Context, + blocksData []map[storage.FieldID]storage.FieldData, + collectionSchema *schemapb.CollectionSchema, + callFlushFunc ImportFlushFunc, + blockSize int64, + maxTotalSize int64, + force bool) error { + + totalSize := 0 + biggestSize := 0 + biggestItem := -1 + + // 1. if accumulated data of a block exceeds blockSize, call callFlushFunc to generate a new binlog file + for i := 0; i < len(blocksData); i++ { + // outside context might be canceled (service stop, or future enhancement for canceling import task) + if isCanceled(ctx) { + log.Error("Import util: import task was canceled") + return errors.New("import task was canceled") + } + + blockData := blocksData[i] + // Note: even if rowCount is 0, the size is still non-zero + size := 0 + rowCount := 0 + for _, fieldData := range blockData { + size += fieldData.GetMemorySize() + rowCount = fieldData.RowNum() + } + + // force to flush, called at the end of Read() + if force && rowCount > 0 { + printFieldsDataInfo(blockData, "import util: prepare to force flush a block", nil) + err := callFlushFunc(blockData, i) + if err != nil { + log.Error("Import util: failed to force flush block data", zap.Int("shardID", i)) + return err + } + log.Info("Import util: force flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i)) + + blocksData[i] = initSegmentData(collectionSchema) + if blocksData[i] == nil { + log.Error("Import util: failed to initialize FieldData list") + return errors.New("failed to initialize FieldData list") + } + continue + } + + // if segment size is larger than predefined blockSize, flush to create a new binlog file + // initialize a new FieldData list for the next round of batch read + if size > int(blockSize) && rowCount > 0 { + printFieldsDataInfo(blockData, "import util: prepare to flush block larger than maxBlockSize", nil) + err := callFlushFunc(blockData, i) + if err != nil { + log.Error("Import util: failed to flush block data", zap.Int("shardID", i)) + return err + } + log.Info("Import util: block size exceed limit and flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i)) + + blocksData[i] = initSegmentData(collectionSchema) + if blocksData[i] == nil { + 
log.Error("Import util: failed to initialize FieldData list") + return errors.New("failed to initialize FieldData list") + } + continue + } + + // calculate the total size(ignore the flushed blocks) + // find out the biggest block for the step 2 + totalSize += size + if size > biggestSize { + biggestSize = size + biggestItem = i + } + } + + // 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest block + if totalSize > int(maxTotalSize) && biggestItem >= 0 { + // outside context might be canceled(service stop, or future enhancement for canceling import task) + if isCanceled(ctx) { + log.Error("Import util: import task was canceled") + return errors.New("import task was canceled") + } + + blockData := blocksData[biggestItem] + // Note: even rowCount is 0, the size is still non-zero + size := 0 + rowCount := 0 + for _, fieldData := range blockData { + size += fieldData.GetMemorySize() + rowCount = fieldData.RowNum() + } + + if rowCount > 0 { + printFieldsDataInfo(blockData, "import util: prepare to flush biggest block", nil) + err := callFlushFunc(blockData, biggestItem) + if err != nil { + log.Error("Import util: failed to flush biggest block data", zap.Int("shardID", biggestItem)) + return err + } + log.Info("Import util: total size exceed limit and flush", zap.Int("rowCount", rowCount), + zap.Int("size", size), zap.Int("totalSize", totalSize), zap.Int("shardID", biggestItem)) + + blocksData[biggestItem] = initSegmentData(collectionSchema) + if blocksData[biggestItem] == nil { + log.Error("Import util: failed to initialize FieldData list") + return errors.New("failed to initialize FieldData list") + } + } + } + + return nil +} + +func getTypeName(dt schemapb.DataType) string { + switch dt { + case schemapb.DataType_Bool: + return "Bool" + case schemapb.DataType_Int8: + return "Int8" + case schemapb.DataType_Int16: + return "Int16" + case schemapb.DataType_Int32: + return "Int32" + case schemapb.DataType_Int64: + return "Int64" + case schemapb.DataType_Float: + return "Float" + case schemapb.DataType_Double: + return "Double" + case schemapb.DataType_VarChar: + return "Varchar" + case schemapb.DataType_String: + return "String" + case schemapb.DataType_BinaryVector: + return "BinaryVector" + case schemapb.DataType_FloatVector: + return "FloatVector" + default: + return "InvalidType" + } +} diff --git a/internal/util/importutil/import_util_test.go b/internal/util/importutil/import_util_test.go new file mode 100644 index 0000000000..855891d4fd --- /dev/null +++ b/internal/util/importutil/import_util_test.go @@ -0,0 +1,460 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package importutil + +import ( + "context" + "testing" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/stretchr/testify/assert" +) + +func sampleSchema() *schemapb.CollectionSchema { + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 102, + Name: "field_bool", + IsPrimaryKey: false, + Description: "bool", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: 103, + Name: "field_int8", + IsPrimaryKey: false, + Description: "int8", + DataType: schemapb.DataType_Int8, + }, + { + FieldID: 104, + Name: "field_int16", + IsPrimaryKey: false, + Description: "int16", + DataType: schemapb.DataType_Int16, + }, + { + FieldID: 105, + Name: "field_int32", + IsPrimaryKey: false, + Description: "int32", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: 106, + Name: "field_int64", + IsPrimaryKey: true, + AutoID: false, + Description: "int64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 107, + Name: "field_float", + IsPrimaryKey: false, + Description: "float", + DataType: schemapb.DataType_Float, + }, + { + FieldID: 108, + Name: "field_double", + IsPrimaryKey: false, + Description: "double", + DataType: schemapb.DataType_Double, + }, + { + FieldID: 109, + Name: "field_string", + IsPrimaryKey: false, + Description: "string", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "128"}, + }, + }, + { + FieldID: 110, + Name: "field_binary_vector", + IsPrimaryKey: false, + Description: "binary_vector", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "16"}, + }, + }, + { + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + return schema +} + +func strKeySchema() *schemapb.CollectionSchema { + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: false, + Description: "uid", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "1024"}, + }, + }, + { + FieldID: 102, + Name: "int_scalar", + IsPrimaryKey: false, + Description: "int_scalar", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: 103, + Name: "float_scalar", + IsPrimaryKey: false, + Description: "float_scalar", + DataType: schemapb.DataType_Float, + }, + { + FieldID: 104, + Name: "string_scalar", + IsPrimaryKey: false, + Description: "string_scalar", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "128"}, + }, + }, + { + FieldID: 105, + Name: "bool_scalar", + IsPrimaryKey: false, + Description: "bool_scalar", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: 106, + Name: "vectors", + IsPrimaryKey: false, + Description: "vectors", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + return schema +} + +func Test_IsCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + assert.False(t, isCanceled(ctx)) + cancel() + assert.True(t, isCanceled(ctx)) +} + 
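// Editor's note: the sketch below is not part of this patch; it is a minimal, hypothetical example of how the
// helpers introduced in import_util.go are intended to compose: initSegmentData() prepares one field-data buffer
// per shard, an ImportFlushFunc receives a shard's accumulated data, and tryFlushBlocks() decides when a buffer
// is large enough to hand to that callback. SingleBlockSize and MaxTotalSizeInMemory are the package constants
// used by the real callers; the function name demoBlockFlush is illustrative only.
func demoBlockFlush(ctx context.Context, shardNum int) error {
	schema := sampleSchema()

	// one in-memory buffer per shard
	blocks := make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
	for i := 0; i < shardNum; i++ {
		blocks = append(blocks, initSegmentData(schema))
	}

	// the callback receives one shard's accumulated field data at a time
	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
		printFieldsDataInfo(fields, "demo: flush one shard", nil)
		return nil
	}

	// non-forced call: flush only the shards whose buffers exceed the block size,
	// or the biggest shard if the total exceeds the memory cap
	if err := tryFlushBlocks(ctx, blocks, schema, flushFunc, SingleBlockSize, MaxTotalSizeInMemory, false); err != nil {
		return err
	}

	// forced call at the end of reading: flush whatever is left in every shard
	return tryFlushBlocks(ctx, blocks, schema, flushFunc, SingleBlockSize, MaxTotalSizeInMemory, true)
}
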
+func Test_InitSegmentData(t *testing.T) { + testFunc := func(schema *schemapb.CollectionSchema) { + fields := initSegmentData(schema) + assert.Equal(t, len(schema.Fields)+1, len(fields)) + + for _, field := range schema.Fields { + data, ok := fields[field.FieldID] + assert.True(t, ok) + assert.NotNil(t, data) + } + printFieldsDataInfo(fields, "dummy", []string{}) + } + testFunc(sampleSchema()) + testFunc(strKeySchema()) +} + +func Test_InitValidators(t *testing.T) { + validators := make(map[storage.FieldID]*Validator) + err := initValidators(nil, validators) + assert.NotNil(t, err) + + schema := sampleSchema() + // success case + err = initValidators(schema, validators) + assert.Nil(t, err) + assert.Equal(t, len(schema.Fields), len(validators)) + name2ID := make(map[string]storage.FieldID) + for _, field := range schema.Fields { + name2ID[field.GetName()] = field.GetFieldID() + } + + checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) { + id := name2ID[funcName] + v, ok := validators[id] + assert.True(t, ok) + err = v.validateFunc(validVal) + assert.Nil(t, err) + err = v.validateFunc(invalidVal) + assert.NotNil(t, err) + } + + // validate functions + var validVal interface{} = true + var invalidVal interface{} = "aa" + checkFunc("field_bool", validVal, invalidVal) + + validVal = float64(100) + invalidVal = "aa" + checkFunc("field_int8", validVal, invalidVal) + checkFunc("field_int16", validVal, invalidVal) + checkFunc("field_int32", validVal, invalidVal) + checkFunc("field_int64", validVal, invalidVal) + checkFunc("field_float", validVal, invalidVal) + checkFunc("field_double", validVal, invalidVal) + + validVal = "aa" + invalidVal = 100 + checkFunc("field_string", validVal, invalidVal) + + validVal = []interface{}{float64(100), float64(101)} + invalidVal = "aa" + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(100)} + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(100), float64(101), float64(102)} + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{true, true} + checkFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(255), float64(-1)} + checkFunc("field_binary_vector", validVal, invalidVal) + + validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)} + invalidVal = true + checkFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(1), float64(2), float64(3)} + checkFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)} + checkFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{"a", "b", "c", "d"} + checkFunc("field_float_vector", validVal, invalidVal) + + // error cases + schema = &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: make([]*schemapb.FieldSchema, 0), + } + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "aa"}, + }, + }) + + validators = make(map[storage.FieldID]*Validator) + err = initValidators(schema, validators) + assert.NotNil(t, err) + + schema.Fields = make([]*schemapb.FieldSchema, 0) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 110, + Name: 
"field_binary_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "aa"}, + }, + }) + + err = initValidators(schema, validators) + assert.NotNil(t, err) +} + +func Test_GetFileNameAndExt(t *testing.T) { + filePath := "aaa/bbb/ccc.txt" + name, ext := GetFileNameAndExt(filePath) + assert.EqualValues(t, "ccc", name) + assert.EqualValues(t, ".txt", ext) +} + +func Test_GetFieldDimension(t *testing.T) { + schema := &schemapb.FieldSchema{ + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + } + + dim, err := getFieldDimension(schema) + assert.Nil(t, err) + assert.Equal(t, 4, dim) + + schema.TypeParams = []*commonpb.KeyValuePair{ + {Key: "dim", Value: "abc"}, + } + dim, err = getFieldDimension(schema) + assert.NotNil(t, err) + assert.Equal(t, 0, dim) + + schema.TypeParams = []*commonpb.KeyValuePair{} + dim, err = getFieldDimension(schema) + assert.NotNil(t, err) + assert.Equal(t, 0, dim) +} + +func Test_TryFlushBlocks(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + flushCounter := 0 + flushRowCount := 0 + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { + flushCounter++ + rowCount := 0 + for _, v := range fields { + rowCount = v.RowNum() + break + } + flushRowCount += rowCount + for _, v := range fields { + assert.Equal(t, rowCount, v.RowNum()) + } + return nil + } + + blockSize := int64(1024) + maxTotalSize := int64(2048) + shardNum := int32(3) + + // prepare flush data, 3 shards, each shard 10 rows + rowCount := 10 + fieldsData := createFieldsData(rowCount) + + // non-force flush + segmentsData := createSegmentsData(fieldsData, shardNum) + err := tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false) + assert.Nil(t, err) + assert.Equal(t, 0, flushCounter) + assert.Equal(t, 0, flushRowCount) + + // force flush + err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) + assert.Nil(t, err) + assert.Equal(t, int(shardNum), flushCounter) + assert.Equal(t, rowCount*int(shardNum), flushRowCount) + + // after force flush, no data left + flushCounter = 0 + flushRowCount = 0 + err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) + assert.Nil(t, err) + assert.Equal(t, 0, flushCounter) + assert.Equal(t, 0, flushRowCount) + + // flush when segment size exceeds blockSize + segmentsData = createSegmentsData(fieldsData, shardNum) + blockSize = 100 // blockSize is 100 bytes, less than the 10 rows size + err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false) + assert.Nil(t, err) + assert.Equal(t, int(shardNum), flushCounter) + assert.Equal(t, rowCount*int(shardNum), flushRowCount) + + flushCounter = 0 + flushRowCount = 0 + err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) // no data left + assert.Nil(t, err) + assert.Equal(t, 0, flushCounter) + assert.Equal(t, 0, flushRowCount) + + // flush when segments total size exceeds maxTotalSize + segmentsData = createSegmentsData(fieldsData, shardNum) + blockSize = 4096 // blockSize is 4096 bytes, larger than the 10 rows size + maxTotalSize = 100 // maxTotalSize is 100 bytes, less than the 30 rows size + err = 
tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false) + assert.Nil(t, err) + assert.Equal(t, 1, flushCounter) // only the max segment is flushed + assert.Equal(t, 10, flushRowCount) + + flushCounter = 0 + flushRowCount = 0 + err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) // two segments left + assert.Nil(t, err) + assert.Equal(t, 2, flushCounter) + assert.Equal(t, 20, flushRowCount) + + // canceled + cancel() + flushCounter = 0 + flushRowCount = 0 + segmentsData = createSegmentsData(fieldsData, shardNum) + err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) + assert.Error(t, err) + assert.Equal(t, 0, flushCounter) + assert.Equal(t, 0, flushRowCount) +} + +func Test_GetTypeName(t *testing.T) { + str := getTypeName(schemapb.DataType_Bool) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_Int8) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_Int16) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_Int32) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_Int64) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_Float) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_Double) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_VarChar) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_String) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_BinaryVector) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_FloatVector) + assert.NotEmpty(t, str) + str = getTypeName(schemapb.DataType_None) + assert.Equal(t, "InvalidType", str) +} diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index f92c31ecb5..e98d2a860e 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -19,21 +19,17 @@ package importutil import ( "bufio" "context" - "errors" + "fmt" "math" - "path" - "runtime/debug" - "strconv" - "strings" "go.uber.org/zap" - "go.uber.org/zap/zapcore" "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/retry" @@ -45,6 +41,9 @@ const ( JSONFileExt = ".json" NumpyFileExt = ".npy" + // supposed size of a single block, to control a binlog file size, the max biglog file size is no more than 2*SingleBlockSize + SingleBlockSize = 16 * 1024 * 1024 // 16MB + // this limitation is to avoid this OOM risk: // for column-based file, we read all its data into memory, if user input a large file, the read() method may // cost extra memory and lear to OOM. @@ -64,26 +63,60 @@ const ( // ReportImportAttempts is the maximum # of attempts to retry when import fails. 
var ReportImportAttempts uint = 10 +type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error +type AssignSegmentFunc func(shardID int) (int64, string, error) +type CreateBinlogsFunc func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) +type SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error + +type WorkingSegment struct { + segmentID int64 // segment ID + shardID int // shard id + targetChName string // target dml channel + rowCount int64 // accumulate row count + memSize int // total memory size of all binlogs + fieldsInsert []*datapb.FieldBinlog // persisted binlogs + fieldsStats []*datapb.FieldBinlog // stats of persisted binlogs +} + +type ImportOptions struct { + OnlyValidate bool + TsStartPoint uint64 + TsEndPoint uint64 +} + +func DefaultImportOptions() ImportOptions { + options := ImportOptions{ + OnlyValidate: false, + TsStartPoint: 0, + TsEndPoint: math.MaxUint64, + } + return options +} + type ImportWrapper struct { ctx context.Context // for canceling parse process cancel context.CancelFunc // for canceling parse process collectionSchema *schemapb.CollectionSchema // collection schema shardNum int32 // sharding number of the collection - segmentSize int64 // maximum size of a segment(unit:byte) + segmentSize int64 // maximum size of a segment(unit:byte) defined by dataCoord.segment.maxSize (milvus.yml) rowIDAllocator *allocator.IDAllocator // autoid allocator chunkManager storage.ChunkManager - callFlushFunc ImportFlushFunc // call back function to flush a segment + assignSegmentFunc AssignSegmentFunc // function to prepare a new segment + createBinlogsFunc CreateBinlogsFunc // function to create binlog for a segment + saveSegmentFunc SaveSegmentFunc // function to persist a segment importResult *rootcoordpb.ImportResult // import result reportFunc func(res *rootcoordpb.ImportResult) error // report import state to rootcoord + + workingSegments map[int]*WorkingSegment // a map shard id to working segments } func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64, - idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc ImportFlushFunc, - importResult *rootcoordpb.ImportResult, reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper { + idAlloc *allocator.IDAllocator, cm storage.ChunkManager, importResult *rootcoordpb.ImportResult, + reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper { if collectionSchema == nil { - log.Error("import error: collection schema is nil") + log.Error("import wrapper: collection schema is nil") return nil } @@ -111,96 +144,143 @@ func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.Collection shardNum: shardNum, segmentSize: segmentSize, rowIDAllocator: idAlloc, - callFlushFunc: flushFunc, chunkManager: cm, importResult: importResult, reportFunc: reportFunc, + workingSegments: make(map[int]*WorkingSegment), } return wrapper } -// this method can be used to cancel parse process +func (p *ImportWrapper) SetCallbackFunctions(assignSegmentFunc AssignSegmentFunc, createBinlogsFunc CreateBinlogsFunc, saveSegmentFunc SaveSegmentFunc) error { + if assignSegmentFunc == nil { + log.Error("import wrapper: callback function AssignSegmentFunc is nil") + return fmt.Errorf("import wrapper: callback function AssignSegmentFunc is nil") + } + + 
if createBinlogsFunc == nil { + log.Error("import wrapper: callback function CreateBinlogsFunc is nil") + return fmt.Errorf("import wrapper: callback function CreateBinlogsFunc is nil") + } + + if saveSegmentFunc == nil { + log.Error("import wrapper: callback function SaveSegmentFunc is nil") + return fmt.Errorf("import wrapper: callback function SaveSegmentFunc is nil") + } + + p.assignSegmentFunc = assignSegmentFunc + p.createBinlogsFunc = createBinlogsFunc + p.saveSegmentFunc = saveSegmentFunc + return nil +} + +// Cancel method can be used to cancel parse process func (p *ImportWrapper) Cancel() error { p.cancel() return nil } -func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) { - stats := make([]zapcore.Field, 0) - for k, v := range fieldsData { - stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum())) +func (p *ImportWrapper) validateColumnBasedFiles(filePaths []string, collectionSchema *schemapb.CollectionSchema) error { + requiredFieldNames := make(map[string]interface{}) + for _, schema := range p.collectionSchema.Fields { + if schema.GetIsPrimaryKey() { + if !schema.GetAutoID() { + requiredFieldNames[schema.GetName()] = nil + } + } else { + requiredFieldNames[schema.GetName()] = nil + } } - if len(files) > 0 { - stats = append(stats, zap.Any("files", files)) + // check redundant file + fileNames := make(map[string]interface{}) + for _, filePath := range filePaths { + name, _ := GetFileNameAndExt(filePath) + fileNames[name] = nil + _, ok := requiredFieldNames[name] + if !ok { + log.Error("import wrapper: the file has no corresponding field in collection", zap.String("fieldName", name)) + return fmt.Errorf("import wrapper: the file '%s' has no corresponding field in collection", filePath) + } } - log.Info(msg, stats...) + + // check missed file + for name := range requiredFieldNames { + _, ok := fileNames[name] + if !ok { + log.Error("import wrapper: there is no file corresponding to field", zap.String("fieldName", name)) + return fmt.Errorf("import wrapper: there is no file corresponding to field '%s'", name) + } + } + + return nil } -func getFileNameAndExt(filePath string) (string, string) { - fileName := path.Base(filePath) - fileType := path.Ext(fileName) - fileNameWithoutExt := strings.TrimSuffix(fileName, fileType) - return fileNameWithoutExt, fileType -} - -// trigger golang gc to return all free memory back to the underlying system at once, -// Note: this operation is expensive, and can lead to latency spikes as it holds the heap lock through the whole process -func triggerGC() { - debug.FreeOSMemory() -} - -func (p *ImportWrapper) fileValidation(filePaths []string, rowBased bool) error { +// fileValidation verify the input paths +// if all the files are json type, return true +// if all the files are numpy type, return false, and not allow duplicate file name +func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { // use this map to check duplicate file name(only for numpy file) fileNames := make(map[string]struct{}) totalSize := int64(0) + rowBased := false for i := 0; i < len(filePaths); i++ { filePath := filePaths[i] - name, fileType := getFileNameAndExt(filePath) - _, ok := fileNames[name] - if ok { - // only check dupliate numpy file - if fileType == NumpyFileExt { - log.Error("import wrapper: duplicate file name", zap.String("fileName", name+"."+fileType)) - return errors.New("duplicate file: " + name + "." 
+ fileType) - } - } else { - fileNames[name] = struct{}{} + name, fileType := GetFileNameAndExt(filePath) + + // only allow json file or numpy file + if fileType != JSONFileExt && fileType != NumpyFileExt { + log.Error("import wrapper: unsupported file type", zap.String("filePath", filePath)) + return false, fmt.Errorf("import wrapper: unsupported file type: '%s'", filePath) + } + + // we use the first file to determine row-based or column-based + if i == 0 && fileType == JSONFileExt { + rowBased = true } // check file type - // row-based only support json type, column-based can support json and numpy type + // row-based only supports json type, column-based only supports numpy type if rowBased { if fileType != JSONFileExt { log.Error("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath)) - return errors.New("unsupported file type for row-based mode: " + filePath) + return rowBased, fmt.Errorf("import wrapper: unsupported file type for row-based mode: '%s'", filePath) } } else { - if fileType != JSONFileExt && fileType != NumpyFileExt { + if fileType != NumpyFileExt { log.Error("import wrapper: unsupported file type for column-based mode", zap.String("filePath", filePath)) - return errors.New("unsupported file type for column-based mode: " + filePath) + return rowBased, fmt.Errorf("import wrapper: unsupported file type for column-based mode: '%s'", filePath) } } + // check duplicate file + _, ok := fileNames[name] + if ok { + log.Error("import wrapper: duplicate file name", zap.String("filePath", filePath)) + return rowBased, fmt.Errorf("import wrapper: duplicate file: '%s'", filePath) + } + fileNames[name] = struct{}{} + // check file size, single file size cannot exceed MaxFileSize // TODO add context size, err := p.chunkManager.Size(context.TODO(), filePath) if err != nil { log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Any("err", err)) - return errors.New("failed to get file size of " + filePath) + return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath) } + // empty file if size == 0 { - log.Error("import wrapper: file path is empty", zap.String("filePath", filePath)) - return errors.New("the file " + filePath + " is empty") + log.Error("import wrapper: file size is zero", zap.String("filePath", filePath)) + return rowBased, fmt.Errorf("import wrapper: the file '%s' size is zero", filePath) } if size > MaxFileSize { log.Error("import wrapper: file size exceeds the maximum size", zap.String("filePath", filePath), zap.Int64("fileSize", size), zap.Int64("MaxFileSize", MaxFileSize)) - return errors.New("the file " + filePath + " size exceeds the maximum size: " + strconv.FormatInt(MaxFileSize, 10) + " bytes") + return rowBased, fmt.Errorf("import wrapper: the file '%s' size exceeds the maximum size: %d bytes", filePath, MaxFileSize) } totalSize += size } @@ -208,26 +288,36 @@ func (p *ImportWrapper) fileValidation(filePaths []string, rowBased bool) error // especially for column-base, total size of files cannot exceed MaxTotalSizeInMemory if totalSize > MaxTotalSizeInMemory { log.Error("import wrapper: total size of files exceeds the maximum size", zap.Int64("totalSize", totalSize), zap.Int64("MaxTotalSize", MaxTotalSizeInMemory)) - return errors.New("the total size of all files exceeds the maximum size: " + strconv.FormatInt(MaxTotalSizeInMemory, 10) + " bytes") + return rowBased, fmt.Errorf("import wrapper: total size(%d bytes) of all files exceeds the maximum size: %d
bytes", totalSize, MaxTotalSizeInMemory) } - return nil + // check redundant files for column-based import + // if the field is primary key and autoid is false, the file is required + // any redundant file is not allowed + if !rowBased { + err := p.validateColumnBasedFiles(filePaths, p.collectionSchema) + if err != nil { + return rowBased, err + } + } + + return rowBased, nil } -// import process entry +// Import is the entry of import operation // filePath and rowBased are from ImportTask -// if onlyValidate is true, this process only do validation, no data generated, callFlushFunc will not be called -func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate bool) error { - log.Info("import wrapper: filePaths", zap.Any("filePaths", filePaths)) +// if onlyValidate is true, this process only do validation, no data generated, flushFunc will not be called +func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error { + log.Info("import wrapper: begin import", zap.Any("filePaths", filePaths), zap.Any("options", options)) // data restore function to import milvus native binlog files(for backup/restore tools) // the backup/restore tool provide two paths for a partition, the first path is binlog path, the second is deltalog path if p.isBinlogImport(filePaths) { // TODO: handle the timestamp end point passed from client side, currently use math.MaxUint64 - return p.doBinlogImport(filePaths, math.MaxUint64) + return p.doBinlogImport(filePaths, options.TsStartPoint, options.TsEndPoint) } // normal logic for import general data files - err := p.fileValidation(filePaths, rowBased) + rowBased, err := p.fileValidation(filePaths) if err != nil { return err } @@ -235,16 +325,16 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b if rowBased { // parse and consume row-based files // for row-based files, the JSONRowConsumer will generate autoid for primary key, and split rows into segments - // according to shard number, so the callFlushFunc will be called in the JSONRowConsumer + // according to shard number, so the flushFunc will be called in the JSONRowConsumer for i := 0; i < len(filePaths); i++ { filePath := filePaths[i] - _, fileType := getFileNameAndExt(filePath) + _, fileType := GetFileNameAndExt(filePath) log.Info("import wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) if fileType == JSONFileExt { - err = p.parseRowBasedJSON(filePath, onlyValidate) + err = p.parseRowBasedJSON(filePath, options.OnlyValidate) if err != nil { - log.Error("import error: "+err.Error(), zap.String("filePath", filePath)) + log.Error("import wrapper: failed to parse row-based json file", zap.Any("err", err), zap.String("filePath", filePath)) return err } } // no need to check else, since the fileValidation() already do this @@ -260,7 +350,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b fieldsData := initSegmentData(p.collectionSchema) if fieldsData == nil { log.Error("import wrapper: failed to initialize FieldData list") - return errors.New("failed to initialize FieldData list") + return fmt.Errorf("import wrapper: failed to initialize FieldData list") } rowCount := 0 @@ -271,25 +361,32 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b return nil } - p.printFieldsDataInfo(fields, "import wrapper: combine field data", nil) + printFieldsDataInfo(fields, "import wrapper: combine field data", nil) tr := timerecord.NewTimeRecorder("combine field 
data") defer tr.Elapse("finished") for k, v := range fields { // ignore 0 row field if v.RowNum() == 0 { + log.Warn("import wrapper: empty FieldData ignored", zap.Int64("fieldID", k)) + continue + } + + // ignore internal fields: RowIDField and TimeStampField + if k == common.RowIDField || k == common.TimeStampField { + log.Warn("import wrapper: internal fields should not be provided", zap.Int64("fieldID", k)) continue } // each column should be only combined once data, ok := fieldsData[k] if ok && data.RowNum() > 0 { - return errors.New("the field " + strconv.FormatInt(k, 10) + " is duplicated") + return fmt.Errorf("the field %d is duplicated", k) } // check the row count. only count non-zero row fields if rowCount > 0 && rowCount != v.RowNum() { - return errors.New("the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount)) + return fmt.Errorf("the field %d row count %d doesn't equal others row count: %d", k, v.RowNum(), rowCount) } rowCount = v.RowNum() @@ -303,20 +400,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b // parse/validate/consume data for i := 0; i < len(filePaths); i++ { filePath := filePaths[i] - _, fileType := getFileNameAndExt(filePath) + _, fileType := GetFileNameAndExt(filePath) log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) - if fileType == JSONFileExt { - err = p.parseColumnBasedJSON(filePath, onlyValidate, combineFunc) - if err != nil { - log.Error("import error: "+err.Error(), zap.String("filePath", filePath)) - return err - } - } else if fileType == NumpyFileExt { - err = p.parseColumnBasedNumpy(filePath, onlyValidate, combineFunc) + if fileType == NumpyFileExt { + err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc) if err != nil { - log.Error("import error: "+err.Error(), zap.String("filePath", filePath)) + log.Error("import wrapper: failed to parse column-based numpy file", zap.Any("err", err), zap.String("filePath", filePath)) return err } } @@ -327,7 +418,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b triggerGC() // split fields data into segments - err := p.splitFieldsData(fieldsData, filePaths) + err := p.splitFieldsData(fieldsData, SingleBlockSize) if err != nil { return err } @@ -339,7 +430,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b return p.reportPersisted() } +// reportPersisted notify the rootcoord to mark the task state to be ImportPersisted func (p *ImportWrapper) reportPersisted() error { + // force close all segments + err := p.closeAllWorkingSegments() + if err != nil { + return err + } + // report file process state p.importResult.State = commonpb.ImportState_ImportPersisted // persist state task is valuable, retry more times in case fail this task only because of network error @@ -353,6 +451,7 @@ func (p *ImportWrapper) reportPersisted() error { return nil } +// isBinlogImport is to judge whether it is binlog import operation // For internal usage by the restore tool: https://github.com/zilliztech/milvus-backup // This tool exports data from a milvus service, and call bulkload interface to import native data into another milvus service. // This tool provides two paths: one is data log path of a partition,the other is delta log path of this partition. 
@@ -360,16 +459,16 @@ func (p *ImportWrapper) reportPersisted() error { func (p *ImportWrapper) isBinlogImport(filePaths []string) bool { // must contains the insert log path, and the delta log path is optional if len(filePaths) != 1 && len(filePaths) != 2 { - log.Info("import wrapper: paths count is not 1 or 2", zap.Int("len", len(filePaths))) + log.Info("import wrapper: paths count is not 1 or 2, not binlog import", zap.Int("len", len(filePaths))) return false } for i := 0; i < len(filePaths); i++ { filePath := filePaths[i] - _, fileType := getFileNameAndExt(filePath) + _, fileType := GetFileNameAndExt(filePath) // contains file extension, is not a path if len(fileType) != 0 { - log.Info("import wrapper: not a path", zap.String("filePath", filePath), zap.String("fileType", fileType)) + log.Info("import wrapper: not a path, not binlog import", zap.String("filePath", filePath), zap.String("fileType", fileType)) return false } } @@ -378,12 +477,14 @@ func (p *ImportWrapper) isBinlogImport(filePaths []string) bool { return true } -func (p *ImportWrapper) doBinlogImport(filePaths []string, tsEndPoint uint64) error { +// doBinlogImport is the entry of binlog import operation +func (p *ImportWrapper) doBinlogImport(filePaths []string, tsStartPoint uint64, tsEndPoint uint64) error { flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { - p.printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths) - return p.callFlushFunc(fields, shardID) + printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths) + return p.flushFunc(fields, shardID) } - parser, err := NewBinlogParser(p.collectionSchema, p.shardNum, p.segmentSize, p.chunkManager, flushFunc, tsEndPoint) + parser, err := NewBinlogParser(p.ctx, p.collectionSchema, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc, + tsStartPoint, tsEndPoint) if err != nil { return err } @@ -396,6 +497,7 @@ func (p *ImportWrapper) doBinlogImport(filePaths []string, tsEndPoint uint64) er return p.reportPersisted() } +// parseRowBasedJSON is the entry of row-based json import operation func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) error { tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath) @@ -417,14 +519,16 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er if !onlyValidate { flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { var filePaths = []string{filePath} - p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths) - return p.callFlushFunc(fields, shardID) + printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths) + return p.flushFunc(fields, shardID) } - consumer, err = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc) + + consumer, err = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, flushFunc) if err != nil { return err } } + validator, err := NewJSONRowValidator(p.collectionSchema, consumer) if err != nil { return err @@ -444,52 +548,14 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er return nil } -func (p *ImportWrapper) parseColumnBasedJSON(filePath string, onlyValidate bool, - combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error { - tr := timerecord.NewTimeRecorder("json column-based parser: " + filePath) - - ctx, cancel := 
context.WithCancel(context.Background()) - defer cancel() - - // for minio storage, chunkManager will download file into local memory - // for local storage, chunkManager open the file directly - file, err := p.chunkManager.Reader(ctx, filePath) - if err != nil { - return err - } - defer file.Close() - - // parse file - reader := bufio.NewReader(file) - parser := NewJSONParser(p.ctx, p.collectionSchema) - var consumer *JSONColumnConsumer - if !onlyValidate { - consumer, err = NewJSONColumnConsumer(p.collectionSchema, combineFunc) - if err != nil { - return err - } - } - validator, err := NewJSONColumnValidator(p.collectionSchema, consumer) - if err != nil { - return err - } - - err = parser.ParseColumns(reader, validator) - if err != nil { - return err - } - - tr.Elapse("parsed") - return nil -} - +// parseColumnBasedNumpy is the entry of column-based numpy import operation func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool, combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error { tr := timerecord.NewTimeRecorder("numpy parser: " + filePath) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - fileName, _ := getFileNameAndExt(filePath) + fileName, _ := GetFileNameAndExt(filePath) // for minio storage, chunkManager will download file into local memory // for local storage, chunkManager open the file directly @@ -532,6 +598,7 @@ func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool return nil } +// appendFunc defines the methods to append data to storage.FieldData func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { switch schema.DataType { case schemapb.DataType_Bool: @@ -608,13 +675,14 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag } } -func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, files []string) error { +// splitFieldsData is to split the in-memory data(parsed from column-based files) into blocks, each block save to a binlog file +func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, blockSize int64) error { if len(fieldsData) == 0 { log.Error("import wrapper: fields data is empty") - return errors.New("import error: fields data is empty") + return fmt.Errorf("import wrapper: fields data is empty") } - tr := timerecord.NewTimeRecorder("split field data") + tr := timerecord.NewTimeRecorder("import wrapper: split field data") defer tr.Elapse("finished") // check existence of each field @@ -633,7 +701,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F v, ok := fieldsData[schema.GetFieldID()] if !ok { log.Error("import wrapper: field not provided", zap.String("fieldName", schema.GetName())) - return errors.New("import error: field " + schema.GetName() + " not provided") + return fmt.Errorf("import wrapper: field '%s' not provided", schema.GetName()) } rowCounter[schema.GetName()] = v.RowNum() if v.RowNum() > rowCount { @@ -643,27 +711,30 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F } if primaryKey == nil { log.Error("import wrapper: primary key field is not found") - return errors.New("import error: primary key field is not found") + return fmt.Errorf("import wrapper: primary key field is not found") } for name, count := range rowCounter { if count != rowCount { log.Error("import wrapper: field row count is not equal to other 
fields row count", zap.String("fieldName", name), zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount)) - return errors.New("import error: field " + name + " row count " + strconv.Itoa(count) + " is not equal to other fields row count " + strconv.Itoa(rowCount)) + return fmt.Errorf("import wrapper: field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount) } } + log.Info("import wrapper: try to split a block with row count", zap.Int("rowCount", rowCount), zap.Any("rowCountOfEachField", rowCounter)) primaryData, ok := fieldsData[primaryKey.GetFieldID()] if !ok { log.Error("import wrapper: primary key field is not provided", zap.String("keyName", primaryKey.GetName())) - return errors.New("import error: primary key field is not provided") + return fmt.Errorf("import wrapper: primary key field is not provided") } // generate auto id for primary key and rowid field - var rowIDBegin typeutil.UniqueID - var rowIDEnd typeutil.UniqueID - rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount)) + rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount)) + if err != nil { + log.Error("import wrapper: failed to alloc row ID", zap.Any("err", err)) + return err + } rowIDField := fieldsData[common.RowIDField] rowIDFieldArr := rowIDField.(*storage.Int64FieldData) @@ -672,19 +743,25 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F } if primaryKey.GetAutoID() { - log.Info("import wrapper: generating auto-id", zap.Any("rowCount", rowCount)) + log.Info("import wrapper: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin)) - primaryDataArr := primaryData.(*storage.Int64FieldData) + // reset the primary keys, as we know, only int64 pk can be auto-generated + primaryDataArr := &storage.Int64FieldData{ + NumRows: []int64{int64(rowCount)}, + Data: make([]int64, 0, rowCount), + } for i := rowIDBegin; i < rowIDEnd; i++ { primaryDataArr.Data = append(primaryDataArr.Data, i) } + primaryData = primaryDataArr + fieldsData[primaryKey.GetFieldID()] = primaryData p.importResult.AutoIds = append(p.importResult.AutoIds, rowIDBegin, rowIDEnd) } if primaryData.RowNum() <= 0 { log.Error("import wrapper: primary key not provided", zap.String("keyName", primaryKey.GetName())) - return errors.New("import wrapper: primary key " + primaryKey.GetName() + " not provided") + return fmt.Errorf("import wrapper: the primary key '%s' not provided", primaryKey.GetName()) } // prepare segemnts @@ -693,7 +770,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F segmentData := initSegmentData(p.collectionSchema) if segmentData == nil { log.Error("import wrapper: failed to initialize FieldData list") - return errors.New("failed to initialize FieldData list") + return fmt.Errorf("import wrapper: failed to initialize FieldData list") } segmentsData = append(segmentsData, segmentData) } @@ -705,12 +782,12 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F appendFuncErr := p.appendFunc(schema) if appendFuncErr == nil { log.Error("import wrapper: unsupported field data type") - return errors.New("import wrapper: unsupported field data type") + return fmt.Errorf("import wrapper: unsupported field data type") } appendFunctions[schema.GetName()] = appendFuncErr } - // split data into segments + // split data into shards for i := 0; i < rowCount; i++ { // hash to a shard number var shard uint32 @@ -723,7 +800,7 @@ func (p *ImportWrapper) 
splitFieldsData(fieldsData map[storage.FieldID]storage.F intPK, ok := interface{}(pk).(int64) if !ok { log.Error("import wrapper: primary key field must be int64 or varchar") - return errors.New("import error: primary key field must be int64 or varchar") + return fmt.Errorf("import wrapper: primary key field must be int64 or varchar") } hash, _ := typeutil.Hash32Int64(intPK) shard = hash % uint32(p.shardNum) @@ -744,18 +821,118 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F return err } } - } - // call flush function - for i := 0; i < int(p.shardNum); i++ { - segmentData := segmentsData[i] - p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files) - err := p.callFlushFunc(segmentData, i) + // when the estimated size is close to blockSize, force flush + err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, false) if err != nil { - log.Error("import wrapper: flush callback function failed", zap.Any("err", err)) return err } } + // force flush at the end + return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, true) +} + +// flushFunc is the callback function for parsers generate segment and save binlog files +func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error { + // if fields data is empty, do nothing + var rowNum int + memSize := 0 + for _, field := range fields { + rowNum = field.RowNum() + memSize += field.GetMemorySize() + break + } + if rowNum <= 0 { + log.Warn("import wrapper: fields data is empty", zap.Int("shardID", shardID)) + return nil + } + + // if there is no segment for this shard, create a new one + // if the segment exists and its size almost exceed segmentSize, close it and create a new one + var segment *WorkingSegment + segment, ok := p.workingSegments[shardID] + if ok { + // the segment already exists, check its size, if the size exceeds(or almost) segmentSize, close the segment + if int64(segment.memSize)+int64(memSize) >= p.segmentSize { + err := p.closeWorkingSegment(segment) + if err != nil { + return err + } + segment = nil + p.workingSegments[shardID] = nil + } + + } + + if segment == nil { + // create a new segment + segID, channelName, err := p.assignSegmentFunc(shardID) + if err != nil { + log.Error("import wrapper: failed to assign a new segment", zap.Any("error", err), zap.Int("shardID", shardID)) + return err + } + + segment = &WorkingSegment{ + segmentID: segID, + shardID: shardID, + targetChName: channelName, + rowCount: int64(0), + memSize: 0, + fieldsInsert: make([]*datapb.FieldBinlog, 0), + fieldsStats: make([]*datapb.FieldBinlog, 0), + } + p.workingSegments[shardID] = segment + } + + // save binlogs + fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID) + if err != nil { + log.Error("import wrapper: failed to save binlogs", zap.Any("error", err), zap.Int("shardID", shardID), + zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName)) + return err + } + + segment.fieldsInsert = append(segment.fieldsInsert, fieldsInsert...) + segment.fieldsStats = append(segment.fieldsStats, fieldsStats...) 
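+	// accumulate the flushed block into the working segment: memSize is compared against
+	// segmentSize on the next call to decide whether the segment must be sealed and a new
+	// one assigned, and rowCount is reported to saveSegmentFunc when the segment is closed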
+ segment.rowCount += int64(rowNum) + segment.memSize += memSize + + return nil +} + +// closeWorkingSegment marks a segment to be sealed +func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error { + log.Info("import wrapper: adding segment to the correct DataNode flow graph and saving binlog paths", + zap.Int("shardID", segment.shardID), + zap.Int64("segmentID", segment.segmentID), + zap.String("targetChannel", segment.targetChName), + zap.Int64("rowCount", segment.rowCount), + zap.Int("insertLogCount", len(segment.fieldsInsert)), + zap.Int("statsLogCount", len(segment.fieldsStats))) + + err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount) + if err != nil { + log.Error("import wrapper: failed to save segment", + zap.Any("error", err), + zap.Int("shardID", segment.shardID), + zap.Int64("segmentID", segment.segmentID), + zap.String("targetChannel", segment.targetChName)) + return err + } + + return nil +} + +// closeAllWorkingSegments mark all segments to be sealed at the end of import operation +func (p *ImportWrapper) closeAllWorkingSegments() error { + for _, segment := range p.workingSegments { + err := p.closeWorkingSegment(segment) + if err != nil { + return err + } + } + p.workingSegments = make(map[int]*WorkingSegment) + return nil } diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 3d08c04577..2222bb0b54 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -22,18 +22,19 @@ import ( "context" "encoding/json" "errors" + "math" + "os" "strconv" "testing" "time" "github.com/stretchr/testify/assert" - "go.uber.org/zap" "golang.org/x/exp/mmap" "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/common" - "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" @@ -139,12 +140,44 @@ func (mc *MockChunkManager) RemoveWithPrefix(ctx context.Context, prefix string) return nil } +type rowCounterTest struct { + rowCount int + callTime int +} + +func createMockCallbackFunctions(t *testing.T, rowCounter *rowCounterTest) (AssignSegmentFunc, CreateBinlogsFunc, SaveSegmentFunc) { + createBinlogFunc := func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) { + count := 0 + for _, data := range fields { + assert.Less(t, 0, data.RowNum()) + if count == 0 { + count = data.RowNum() + } else { + assert.Equal(t, count, data.RowNum()) + } + } + rowCounter.rowCount += count + rowCounter.callTime++ + return nil, nil, nil + } + + assignSegmentFunc := func(shardID int) (int64, string, error) { + return 100, "ch", nil + } + + saveSegmentFunc := func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error { + return nil + } + + return assignSegmentFunc, createBinlogFunc, saveSegmentFunc +} + func Test_NewImportWrapper(t *testing.T) { f := dependency.NewDefaultFactory(true) ctx := context.Background() cm, err := f.NewPersistentStorageChunkManager(ctx) assert.NoError(t, err) - wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, cm, nil, nil, nil) + wrapper := 
NewImportWrapper(ctx, nil, 2, 1, nil, cm, nil, nil) assert.Nil(t, wrapper) schema := &schemapb.CollectionSchema{ @@ -162,20 +195,43 @@ func Test_NewImportWrapper(t *testing.T) { Description: "int64", DataType: schemapb.DataType_Int64, }) - wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, cm, nil, nil, nil) + wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, cm, nil, nil) assert.NotNil(t, wrapper) + assignSegFunc := func(shardID int) (int64, string, error) { + return 0, "", nil + } + createBinFunc := func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) { + return nil, nil, nil + } + saveBinFunc := func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error { + return nil + } + + err = wrapper.SetCallbackFunctions(assignSegFunc, createBinFunc, saveBinFunc) + assert.Nil(t, err) + err = wrapper.SetCallbackFunctions(assignSegFunc, createBinFunc, nil) + assert.NotNil(t, err) + err = wrapper.SetCallbackFunctions(assignSegFunc, nil, nil) + assert.NotNil(t, err) + err = wrapper.SetCallbackFunctions(nil, nil, nil) + assert.NotNil(t, err) + err = wrapper.Cancel() assert.Nil(t, err) } func Test_ImportWrapperRowBased(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + f := dependency.NewDefaultFactory(true) ctx := context.Background() cm, err := f.NewPersistentStorageChunkManager(ctx) assert.NoError(t, err) - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) content := []byte(`{ "rows":[ @@ -192,20 +248,8 @@ func Test_ImportWrapperRowBased(t *testing.T) { assert.NoError(t, err) defer cm.RemoveWithPrefix(ctx, "") - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) // success case importResult := &rootcoordpb.ImportResult{ @@ -222,12 +266,17 @@ func Test_ImportWrapperRowBased(t *testing.T) { reportFunc := func(res *rootcoordpb.ImportResult) error { return nil } - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) files := make([]string, 0) files = append(files, filePath) - err = wrapper.Import(files, true, false) + err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) assert.Nil(t, err) - assert.Equal(t, 5, rowCount) + assert.Equal(t, 0, rowCounter.rowCount) + + err = wrapper.Import(files, DefaultImportOptions()) + assert.Nil(t, err) + assert.Equal(t, 5, rowCounter.rowCount) assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) // parse error @@ -242,246 +291,113 @@ func Test_ImportWrapperRowBased(t *testing.T) { assert.NoError(t, err) importResult.State = commonpb.ImportState_ImportStarted - wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, importResult, reportFunc) + 
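+	// NewImportWrapper no longer accepts a flush callback, so each test registers the
+	// mock callbacks explicitly before calling Import()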
wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) files = make([]string, 0) files = append(files, filePath) - err = wrapper.Import(files, true, false) + err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) assert.NotNil(t, err) assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State) // file doesn't exist files = make([]string, 0) files = append(files, "/dummy/dummy.json") - err = wrapper.Import(files, true, false) + err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) assert.NotNil(t, err) } -func Test_ImportWrapperColumnBased_json(t *testing.T) { - f := dependency.NewDefaultFactory(true) +func createSampleNumpyFiles(t *testing.T, cm storage.ChunkManager) []string { ctx := context.Background() - cm, err := f.NewPersistentStorageChunkManager(ctx) - assert.NoError(t, err) - defer cm.RemoveWithPrefix(ctx, "") - - idAllocator := newIDAllocator(ctx, t) - - content := []byte(`{ - "field_bool": [true, false, true, true, true], - "field_int8": [10, 11, 12, 13, 14], - "field_int16": [100, 101, 102, 103, 104], - "field_int32": [1000, 1001, 1002, 1003, 1004], - "field_int64": [10000, 10001, 10002, 10003, 10004], - "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], - "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], - "field_string": ["a", "b", "c", "d", "e"], - "field_binary_vector": [ - [254, 1], - [253, 2], - [252, 3], - [251, 4], - [250, 5] - ], - "field_float_vector": [ - [1.1, 1.2, 1.3, 1.4], - [2.1, 2.2, 2.3, 2.4], - [3.1, 3.2, 3.3, 3.4], - [4.1, 4.2, 4.3, 4.4], - [5.1, 5.2, 5.3, 5.4] - ] - }`) - - filePath := TempFilesPath + "columns_1.json" - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } - - // success case - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - TaskId: 1, - DatanodeId: 1, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, - } - reportFunc := func(res *rootcoordpb.ImportResult) error { - return nil - } - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) files := make([]string, 0) - files = append(files, filePath) - err = wrapper.Import(files, false, false) + + filePath := "field_bool.npy" + content, err := CreateNumpyData([]bool{true, false, true, true, true}) assert.Nil(t, err) - assert.Equal(t, 5, rowCount) - assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) - - // parse error - content = []byte(`{ - "field_bool": [true, false, true, true, true] - }`) - - filePath = TempFilesPath + "rows_2.json" err = cm.Write(ctx, filePath, content) assert.NoError(t, err) - - importResult.State = commonpb.ImportState_ImportStarted - wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) - files = make([]string, 0) files = append(files, filePath) - err = wrapper.Import(files, false, false) - assert.NotNil(t, err) - assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State) - // file doesn't exist - files = make([]string, 0) - files = append(files, "/dummy/dummy.json") - err = wrapper.Import(files, false, false) - 
assert.NotNil(t, err) -} - -func Test_ImportWrapperColumnBased_StringKey(t *testing.T) { - f := dependency.NewDefaultFactory(true) - ctx := context.Background() - cm, err := f.NewPersistentStorageChunkManager(ctx) - assert.NoError(t, err) - defer cm.RemoveWithPrefix(ctx, "") - - idAllocator := newIDAllocator(ctx, t) - - content := []byte(`{ - "uid": ["Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd"], - "int_scalar": [9070353, 8505288, 4392660, 7927425, 9288807], - "float_scalar": [0.9798043638085004, 0.937913432198687, 0.32381232630490264, 0.31074026464844895, 0.4953578200336135], - "string_scalar": ["ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7"], - "bool_scalar": [true, false, true, false, false], - "vectors": [ - [0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314], - [0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056], - [0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914], - [0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004], - [0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633] - ] - }`) - - filePath := TempFilesPath + "columns_2.json" + filePath = "field_int8.npy" + content, err = CreateNumpyData([]int8{10, 11, 12, 13, 14}) + assert.Nil(t, err) err = cm.Write(ctx, filePath, content) assert.NoError(t, err) - - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } - - // success case - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - TaskId: 1, - DatanodeId: 1, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, - } - reportFunc := func(res *rootcoordpb.ImportResult) error { - return nil - } - wrapper := NewImportWrapper(ctx, strKeySchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) - files := make([]string, 0) files = append(files, filePath) - err = wrapper.Import(files, false, false) + + filePath = "field_int16.npy" + content, err = CreateNumpyData([]int16{100, 101, 102, 103, 104}) assert.Nil(t, err) - assert.Equal(t, 5, rowCount) - assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_int32.npy" + content, err = CreateNumpyData([]int32{1000, 1001, 1002, 1003, 1004}) + assert.Nil(t, err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_int64.npy" + content, err = CreateNumpyData([]int64{10000, 10001, 10002, 10003, 10004}) + assert.Nil(t, 
err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_float.npy" + content, err = CreateNumpyData([]float32{3.14, 3.15, 3.16, 3.17, 3.18}) + assert.Nil(t, err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_double.npy" + content, err = CreateNumpyData([]float64{5.1, 5.2, 5.3, 5.4, 5.5}) + assert.Nil(t, err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_string.npy" + content, err = CreateNumpyData([]string{"a", "bb", "ccc", "dd", "e"}) + assert.Nil(t, err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_binary_vector.npy" + content, err = CreateNumpyData([][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}) + assert.Nil(t, err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + filePath = "field_float_vector.npy" + content, err = CreateNumpyData([][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}}) + assert.Nil(t, err) + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + files = append(files, filePath) + + return files } func Test_ImportWrapperColumnBased_numpy(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + f := dependency.NewDefaultFactory(true) ctx := context.Background() cm, err := f.NewPersistentStorageChunkManager(ctx) assert.NoError(t, err) defer cm.RemoveWithPrefix(ctx, "") - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) - content := []byte(`{ - "field_bool": [true, false, true, true, true], - "field_int8": [10, 11, 12, 13, 14], - "field_int16": [100, 101, 102, 103, 104], - "field_int32": [1000, 1001, 1002, 1003, 1004], - "field_int64": [10000, 10001, 10002, 10003, 10004], - "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], - "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], - "field_string": ["a", "b", "c", "d", "e"] - }`) - - files := make([]string, 0) - - filePath := TempFilesPath + "scalar_fields.json" - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - files = append(files, filePath) - - filePath = TempFilesPath + "field_binary_vector.npy" - bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}} - content, err = CreateNumpyData(bin) - assert.Nil(t, err) - log.Debug("content", zap.Any("c", content)) - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - files = append(files, filePath) - - filePath = TempFilesPath + "field_float_vector.npy" - flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}} - content, err = CreateNumpyData(flo) - assert.Nil(t, err) - log.Debug("content", zap.Any("c", content)) - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - files = append(files, filePath) - - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) // success case importResult := &rootcoordpb.ImportResult{ @@ -499,35 +415,38 
@@ func Test_ImportWrapperColumnBased_numpy(t *testing.T) { return nil } schema := sampleSchema() - schema.Fields[4].AutoID = true - wrapper := NewImportWrapper(ctx, schema, 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper := NewImportWrapper(ctx, schema, 2, 1, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) - err = wrapper.Import(files, false, false) + files := createSampleNumpyFiles(t, cm) + err = wrapper.Import(files, DefaultImportOptions()) assert.Nil(t, err) - assert.Equal(t, 5, rowCount) + assert.Equal(t, 5, rowCounter.rowCount) assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) // parse error - content = []byte(`{ + content := []byte(`{ "field_bool": [true, false, true, true, true] }`) - filePath = TempFilesPath + "rows_2.json" + filePath := "rows_2.json" err = cm.Write(ctx, filePath, content) assert.NoError(t, err) importResult.State = commonpb.ImportState_ImportStarted - wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) + files = make([]string, 0) files = append(files, filePath) - err = wrapper.Import(files, false, false) + err = wrapper.Import(files, DefaultImportOptions()) assert.NotNil(t, err) assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State) // file doesn't exist files = make([]string, 0) files = append(files, "/dummy/dummy.json") - err = wrapper.Import(files, false, false) + err = wrapper.Import(files, DefaultImportOptions()) assert.NotNil(t, err) } @@ -562,13 +481,17 @@ func perfSchema(dim int) *schemapb.CollectionSchema { } func Test_ImportWrapperRowBased_perf(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + f := dependency.NewDefaultFactory(true) ctx := context.Background() cm, err := f.NewPersistentStorageChunkManager(ctx) assert.NoError(t, err) defer cm.RemoveWithPrefix(ctx, "") - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) tr := timerecord.NewTimeRecorder("row-based parse performance") @@ -605,7 +528,7 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) { tr.Record("generate " + strconv.Itoa(rowCount) + " rows") // generate a json file - filePath := TempFilesPath + "row_perf.json" + filePath := "row_perf.json" func() { var b bytes.Buffer bw := bufio.NewWriter(&b) @@ -620,21 +543,8 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) { }() tr.Record("generate large json file " + filePath) - // parse the json file - parseCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - parseCount += count - return nil - } + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) schema := perfSchema(dim) @@ -652,127 +562,89 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) { reportFunc := func(res *rootcoordpb.ImportResult) error { return nil } - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper := 
NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) + files := make([]string, 0) files = append(files, filePath) - err = wrapper.Import(files, true, false) + err = wrapper.Import(files, DefaultImportOptions()) assert.Nil(t, err) - assert.Equal(t, rowCount, parseCount) + assert.Equal(t, rowCount, rowCounter.rowCount) tr.Record("parse large json file " + filePath) } -func Test_ImportWrapperColumnBased_perf(t *testing.T) { - f := dependency.NewDefaultFactory(true) +func Test_ImportWrapperValidateColumnBasedFiles(t *testing.T) { ctx := context.Background() - cm, err := f.NewPersistentStorageChunkManager(ctx) - assert.NoError(t, err) - defer cm.RemoveWithPrefix(ctx, "") - idAllocator := newIDAllocator(ctx, t) - - tr := timerecord.NewTimeRecorder("column-based parse performance") - - type IDCol struct { - ID []int64 + cm := &MockChunkManager{ + size: 1, } - type VectorCol struct { - Vector [][]float32 - } - - // change these parameters to test different cases - dim := 128 - rowCount := 10000 + idAllocator := newIDAllocator(ctx, t, nil) shardNum := 2 segmentSize := 512 // unit: MB - // generate rows data - ids := &IDCol{ - ID: make([]int64, 0, rowCount), - } - - vectors := &VectorCol{ - Vector: make([][]float32, 0, rowCount), - } - - for i := 0; i < rowCount; i++ { - ids.ID = append(ids.ID, int64(i)) - - vector := make([]float32, 0, dim) - for k := 0; k < dim; k++ { - vector = append(vector, float32(i)+3.1415926) - } - vectors.Vector = append(vectors.Vector, vector) - } - tr.Record("generate " + strconv.Itoa(rowCount) + " rows") - - // generate json files - saveFileFunc := func(filePath string, data interface{}) error { - var b bytes.Buffer - bw := bufio.NewWriter(&b) - - encoder := json.NewEncoder(bw) - err = encoder.Encode(data) - assert.Nil(t, err) - err = bw.Flush() - assert.NoError(t, err) - err = cm.Write(ctx, filePath, b.Bytes()) - assert.NoError(t, err) - return nil - } - - filePath1 := TempFilesPath + "ids.json" - err = saveFileFunc(filePath1, ids) - assert.Nil(t, err) - tr.Record("generate large json file " + filePath1) - - filePath2 := TempFilesPath + "vectors.json" - err = saveFileFunc(filePath2, vectors) - assert.Nil(t, err) - tr.Record("generate large json file " + filePath2) - - // parse the json file - parseCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - parseCount += count - return nil - } - - schema := perfSchema(dim) - - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "ID", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "Age", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 103, + Name: "Vector", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "10"}, + }, + }, }, - TaskId: 1, - DatanodeId: 1, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, } - reportFunc 
:= func(res *rootcoordpb.ImportResult) error { - return nil - } - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, flushFunc, importResult, reportFunc) - files := make([]string, 0) - files = append(files, filePath1) - files = append(files, filePath2) - err = wrapper.Import(files, false, false) - assert.Nil(t, err) - assert.Equal(t, rowCount, parseCount) - tr.Record("parse large json files: " + filePath1 + "," + filePath2) + wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) + + // file for PK is redundant + files := []string{"ID.npy", "Age.npy", "Vector.npy"} + err := wrapper.validateColumnBasedFiles(files, schema) + assert.NotNil(t, err) + + // file for PK is not redundant + schema.Fields[0].AutoID = false + err = wrapper.validateColumnBasedFiles(files, schema) + assert.Nil(t, err) + + // file missed + files = []string{"Age.npy", "Vector.npy"} + err = wrapper.validateColumnBasedFiles(files, schema) + assert.NotNil(t, err) + + files = []string{"ID.npy", "Vector.npy"} + err = wrapper.validateColumnBasedFiles(files, schema) + assert.NotNil(t, err) + + // redundant file + files = []string{"ID.npy", "Age.npy", "Vector.npy", "dummy.npy"} + err = wrapper.validateColumnBasedFiles(files, schema) + assert.NotNil(t, err) + + // correct input + files = []string{"ID.npy", "Age.npy", "Vector.npy"} + err = wrapper.validateColumnBasedFiles(files, schema) + assert.Nil(t, err) } func Test_ImportWrapperFileValidation(t *testing.T) { @@ -782,85 +654,122 @@ func Test_ImportWrapperFileValidation(t *testing.T) { size: 1, } - idAllocator := newIDAllocator(ctx, t) - schema := perfSchema(128) + idAllocator := newIDAllocator(ctx, t, nil) + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "bol", + IsPrimaryKey: false, + DataType: schemapb.DataType_Bool, + }, + }, + } shardNum := 2 segmentSize := 512 // unit: MB - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil, nil) - - // duplicate files - files := []string{"1.npy", "1.npy"} - err := wrapper.fileValidation(files, false) - assert.NotNil(t, err) - err = wrapper.fileValidation(files, true) - assert.NotNil(t, err) - - // unsupported file name - files[0] = "a/1.npy" - files[1] = "b/1.npy" - err = wrapper.fileValidation(files, true) - assert.NotNil(t, err) - - err = wrapper.fileValidation(files, false) - assert.NotNil(t, err) + wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) // unsupported file type - files[0] = "1" - files[1] = "1" - err = wrapper.fileValidation(files, true) + files := []string{"uid.txt"} + rowBased, err := wrapper.fileValidation(files) assert.NotNil(t, err) + assert.False(t, rowBased) - err = wrapper.fileValidation(files, false) + // file missed + files = []string{"uid.npy"} + rowBased, err = wrapper.fileValidation(files) assert.NotNil(t, err) + assert.False(t, rowBased) + + // redundant file + files = []string{"uid.npy", "b/bol.npy", "c/no.npy"} + rowBased, err = wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.False(t, rowBased) + + // duplicate files + files = []string{"a/1.json", "b/1.json"} + rowBased, err = wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.True(t, rowBased) + + files = []string{"a/uid.npy", 
"uid.npy", "b/bol.npy"} + rowBased, err = wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.False(t, rowBased) + + // unsupported file for row-based + files = []string{"a/uid.json", "b/bol.npy"} + rowBased, err = wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.True(t, rowBased) + + // unsupported file for column-based + files = []string{"a/uid.npy", "b/bol.json"} + rowBased, err = wrapper.fileValidation(files) + assert.NotNil(t, err) + assert.False(t, rowBased) // valid cases - files[0] = "1.json" - files[1] = "2.json" - err = wrapper.fileValidation(files, true) + files = []string{"a/1.json", "b/2.json"} + rowBased, err = wrapper.fileValidation(files) assert.Nil(t, err) + assert.True(t, rowBased) - files[1] = "2.npy" - err = wrapper.fileValidation(files, false) + files = []string{"a/uid.npy", "b/bol.npy"} + rowBased, err = wrapper.fileValidation(files) assert.Nil(t, err) + assert.False(t, rowBased) // empty file cm.size = 0 - wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil, nil) - err = wrapper.fileValidation(files, true) - assert.NotNil(t, err) - - err = wrapper.fileValidation(files, false) + wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) + rowBased, err = wrapper.fileValidation(files) assert.NotNil(t, err) + assert.False(t, rowBased) // file size exceed MaxFileSize limit cm.size = MaxFileSize + 1 - wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil, nil) - err = wrapper.fileValidation(files, true) - assert.NotNil(t, err) - - err = wrapper.fileValidation(files, false) + wrapper = NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) + rowBased, err = wrapper.fileValidation(files) assert.NotNil(t, err) + assert.False(t, rowBased) // total files size exceed MaxTotalSizeInMemory limit cm.size = MaxFileSize - 1 files = append(files, "3.npy") - err = wrapper.fileValidation(files, false) + rowBased, err = wrapper.fileValidation(files) assert.NotNil(t, err) + assert.False(t, rowBased) // failed to get file size cm.sizeErr = errors.New("error") - err = wrapper.fileValidation(files, false) + rowBased, err = wrapper.fileValidation(files) assert.NotNil(t, err) + assert.False(t, rowBased) } func Test_ImportWrapperReportFailRowBased(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + f := dependency.NewDefaultFactory(true) ctx := context.Background() cm, err := f.NewPersistentStorageChunkManager(ctx) assert.NoError(t, err) - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) content := []byte(`{ "rows":[ @@ -872,25 +781,13 @@ func Test_ImportWrapperReportFailRowBased(t *testing.T) { ] }`) - filePath := TempFilesPath + "rows_1.json" + filePath := "rows_1.json" err = cm.Write(ctx, filePath, content) assert.NoError(t, err) defer cm.RemoveWithPrefix(ctx, "") - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) // success case importResult := &rootcoordpb.ImportResult{ @@ -907,159 
+804,36 @@ func Test_ImportWrapperReportFailRowBased(t *testing.T) { reportFunc := func(res *rootcoordpb.ImportResult) error { return nil } - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) + files := make([]string, 0) files = append(files, filePath) wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error { return errors.New("mock error") } - err = wrapper.Import(files, true, false) + err = wrapper.Import(files, DefaultImportOptions()) assert.NotNil(t, err) - assert.Equal(t, 5, rowCount) - assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) -} - -func Test_ImportWrapperReportFailColumnBased_json(t *testing.T) { - f := dependency.NewDefaultFactory(true) - ctx := context.Background() - cm, err := f.NewPersistentStorageChunkManager(ctx) - assert.NoError(t, err) - defer cm.RemoveWithPrefix(ctx, "") - - idAllocator := newIDAllocator(ctx, t) - - content := []byte(`{ - "field_bool": [true, false, true, true, true], - "field_int8": [10, 11, 12, 13, 14], - "field_int16": [100, 101, 102, 103, 104], - "field_int32": [1000, 1001, 1002, 1003, 1004], - "field_int64": [10000, 10001, 10002, 10003, 10004], - "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], - "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], - "field_string": ["a", "b", "c", "d", "e"], - "field_binary_vector": [ - [254, 1], - [253, 2], - [252, 3], - [251, 4], - [250, 5] - ], - "field_float_vector": [ - [1.1, 1.2, 1.3, 1.4], - [2.1, 2.2, 2.3, 2.4], - [3.1, 3.2, 3.3, 3.4], - [4.1, 4.2, 4.3, 4.4], - [5.1, 5.2, 5.3, 5.4] - ] - }`) - - filePath := TempFilesPath + "columns_1.json" - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } - - // success case - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - TaskId: 1, - DatanodeId: 1, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, - } - reportFunc := func(res *rootcoordpb.ImportResult) error { - return nil - } - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) - files := make([]string, 0) - files = append(files, filePath) - - wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error { - return errors.New("mock error") - } - err = wrapper.Import(files, false, false) - assert.NotNil(t, err) - assert.Equal(t, 5, rowCount) + assert.Equal(t, 5, rowCounter.rowCount) assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) } func Test_ImportWrapperReportFailColumnBased_numpy(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + f := dependency.NewDefaultFactory(true) ctx := context.Background() cm, err := f.NewPersistentStorageChunkManager(ctx) assert.NoError(t, err) defer cm.RemoveWithPrefix(ctx, "") - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) - content := []byte(`{ - "field_bool": [true, 
false, true, true, true], - "field_int8": [10, 11, 12, 13, 14], - "field_int16": [100, 101, 102, 103, 104], - "field_int32": [1000, 1001, 1002, 1003, 1004], - "field_int64": [10000, 10001, 10002, 10003, 10004], - "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], - "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], - "field_string": ["a", "b", "c", "d", "e"] - }`) - - files := make([]string, 0) - - filePath := TempFilesPath + "scalar_fields.json" - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - files = append(files, filePath) - - filePath = TempFilesPath + "field_binary_vector.npy" - bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}} - content, err = CreateNumpyData(bin) - assert.Nil(t, err) - log.Debug("content", zap.Any("c", content)) - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - files = append(files, filePath) - - filePath = TempFilesPath + "field_float_vector.npy" - flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}} - content, err = CreateNumpyData(flo) - assert.Nil(t, err) - log.Debug("content", zap.Any("c", content)) - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - files = append(files, filePath) - - rowCount := 0 - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - count := 0 - for _, data := range fields { - assert.Less(t, 0, data.RowNum()) - if count == 0 { - count = data.RowNum() - } else { - assert.Equal(t, count, data.RowNum()) - } - } - rowCount += count - return nil - } + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) // success case importResult := &rootcoordpb.ImportResult{ @@ -1077,15 +851,18 @@ func Test_ImportWrapperReportFailColumnBased_numpy(t *testing.T) { return nil } schema := sampleSchema() - schema.Fields[4].AutoID = true - wrapper := NewImportWrapper(ctx, schema, 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + wrapper := NewImportWrapper(ctx, schema, 2, 1, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error { return errors.New("mock error") } - err = wrapper.Import(files, false, false) + + files := createSampleNumpyFiles(t, cm) + + err = wrapper.Import(files, DefaultImportOptions()) assert.NotNil(t, err) - assert.Equal(t, 5, rowCount) + assert.Equal(t, 5, rowCounter.rowCount) assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) } @@ -1096,17 +873,19 @@ func Test_ImportWrapperIsBinlogImport(t *testing.T) { size: 1, } - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) schema := perfSchema(128) shardNum := 2 segmentSize := 512 // unit: MB - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil, nil) + wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) + // empty paths paths := []string{} b := wrapper.isBinlogImport(paths) assert.False(t, b) + // paths count should be 2 paths = []string{ "path1", "path2", @@ -1115,6 +894,7 @@ func Test_ImportWrapperIsBinlogImport(t *testing.T) { b = wrapper.isBinlogImport(paths) assert.False(t, b) + // not path paths = []string{ "path1.txt", "path2.jpg", @@ -1122,6 +902,7 @@ func Test_ImportWrapperIsBinlogImport(t *testing.T) { b = wrapper.isBinlogImport(paths) assert.False(t, b) + // success paths = []string{ "/tmp", 
"/tmp", @@ -1137,12 +918,12 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) { size: 1, } - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) schema := perfSchema(128) shardNum := 2 segmentSize := 512 // unit: MB - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil, nil) + wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, nil, nil) paths := []string{ "/tmp", "/tmp", @@ -1150,21 +931,22 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) { wrapper.chunkManager = nil // failed to create new BinlogParser - err := wrapper.doBinlogImport(paths, 0) + err := wrapper.doBinlogImport(paths, 0, math.MaxUint64) assert.NotNil(t, err) cm.listErr = errors.New("error") wrapper.chunkManager = cm - wrapper.callFlushFunc = func(fields map[storage.FieldID]storage.FieldData, shardID int) error { - return nil - } + + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) // failed to call parser.Parse() - err = wrapper.doBinlogImport(paths, 0) + err = wrapper.doBinlogImport(paths, 0, math.MaxUint64) assert.NotNil(t, err) // Import() failed - err = wrapper.Import(paths, false, false) + err = wrapper.Import(paths, DefaultImportOptions()) assert.NotNil(t, err) cm.listErr = nil @@ -1184,6 +966,91 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) { } // succeed - err = wrapper.doBinlogImport(paths, 0) + err = wrapper.doBinlogImport(paths, 0, math.MaxUint64) assert.Nil(t, err) } + +func Test_ImportWrapperSplitFieldsData(t *testing.T) { + ctx := context.Background() + + cm := &MockChunkManager{} + + idAllocator := newIDAllocator(ctx, t, nil) + + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) + + importResult := &rootcoordpb.ImportResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + TaskId: 1, + DatanodeId: 1, + State: commonpb.ImportState_ImportStarted, + Segments: make([]int64, 0), + AutoIds: make([]int64, 0), + RowCount: 0, + } + reportFunc := func(res *rootcoordpb.ImportResult) error { + return nil + } + + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "flag", + IsPrimaryKey: false, + DataType: schemapb.DataType_Bool, + }, + }, + } + + wrapper := NewImportWrapper(ctx, schema, 2, 1024*1024, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) + + // nil input + err := wrapper.splitFieldsData(nil, 0) + assert.NotNil(t, err) + + // split 100 rows to 4 blocks + rowCount := 100 + input := initSegmentData(schema) + for j := 0; j < rowCount; j++ { + pkField := input[101].(*storage.Int64FieldData) + pkField.Data = append(pkField.Data, int64(j)) + + flagField := input[102].(*storage.BoolFieldData) + flagField.Data = append(flagField.Data, true) + } + + err = wrapper.splitFieldsData(input, 512) + assert.Nil(t, err) + assert.Equal(t, 2, len(importResult.AutoIds)) + assert.Equal(t, 4, rowCounter.callTime) + assert.Equal(t, rowCount, rowCounter.rowCount) + + // row count of fields are unequal + schema.Fields[0].AutoID = false + input = initSegmentData(schema) + for 
j := 0; j < rowCount; j++ { + pkField := input[101].(*storage.Int64FieldData) + pkField.Data = append(pkField.Data, int64(j)) + if j%2 == 0 { + continue + } + flagField := input[102].(*storage.BoolFieldData) + flagField.Data = append(flagField.Data, true) + } + err = wrapper.splitFieldsData(input, 512) + assert.NotNil(t, err) +} diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go index b515a584c6..773faa66c7 100644 --- a/internal/util/importutil/json_handler.go +++ b/internal/util/importutil/json_handler.go @@ -19,7 +19,7 @@ package importutil import ( "errors" "fmt" - "strconv" + "reflect" "go.uber.org/zap" @@ -31,33 +31,12 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) -// interface to process rows data +// JSONRowHandler is the interface to process rows data type JSONRowHandler interface { Handle(rows []map[storage.FieldID]interface{}) error } -// interface to process column data -type JSONColumnHandler interface { - Handle(columns map[storage.FieldID][]interface{}) error -} - -// method to get dimension of vecotor field -func getFieldDimension(schema *schemapb.FieldSchema) (int, error) { - for _, kvPair := range schema.GetTypeParams() { - key, value := kvPair.GetKey(), kvPair.GetValue() - if key == "dim" { - dim, err := strconv.Atoi(value) - if err != nil { - return 0, errors.New("vector dimension is invalid") - } - return dim, nil - } - } - - return 0, errors.New("vector dimension is not defined") -} - -// field value validator +// Validator is field value validator type Validator struct { validateFunc func(obj interface{}) error // validate data type function convertFunc func(obj interface{}, field storage.FieldData) error // convert data function @@ -68,210 +47,7 @@ type Validator struct { fieldName string // field name } -// method to construct valiator functions -func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error { - if collectionSchema == nil { - return errors.New("collection schema is nil") - } - - // json decoder parse all the numeric value into float64 - numericValidator := func(obj interface{}) error { - switch obj.(type) { - case float64: - return nil - default: - s := fmt.Sprintf("%v", obj) - msg := "illegal numeric value " + s - return errors.New(msg) - } - - } - - for i := 0; i < len(collectionSchema.Fields); i++ { - schema := collectionSchema.Fields[i] - - validators[schema.GetFieldID()] = &Validator{} - validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey() - validators[schema.GetFieldID()].autoID = schema.GetAutoID() - validators[schema.GetFieldID()].fieldName = schema.GetName() - validators[schema.GetFieldID()].isString = false - - switch schema.DataType { - case schemapb.DataType_Bool: - validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { - switch obj.(type) { - case bool: - return nil - default: - s := fmt.Sprintf("%v", obj) - msg := "illegal value " + s + " for bool type field " + schema.GetName() - return errors.New(msg) - } - - } - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := obj.(bool) - field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value) - field.(*storage.BoolFieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_Float: - validators[schema.GetFieldID()].validateFunc = numericValidator - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error 
{ - value := float32(obj.(float64)) - field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value) - field.(*storage.FloatFieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_Double: - validators[schema.GetFieldID()].validateFunc = numericValidator - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := obj.(float64) - field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) - field.(*storage.DoubleFieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_Int8: - validators[schema.GetFieldID()].validateFunc = numericValidator - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int8(obj.(float64)) - field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value) - field.(*storage.Int8FieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_Int16: - validators[schema.GetFieldID()].validateFunc = numericValidator - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int16(obj.(float64)) - field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value) - field.(*storage.Int16FieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_Int32: - validators[schema.GetFieldID()].validateFunc = numericValidator - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int32(obj.(float64)) - field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value) - field.(*storage.Int32FieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_Int64: - validators[schema.GetFieldID()].validateFunc = numericValidator - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int64(obj.(float64)) - field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) - field.(*storage.Int64FieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_BinaryVector: - dim, err := getFieldDimension(schema) - if err != nil { - return err - } - validators[schema.GetFieldID()].dimension = dim - - validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { - switch vt := obj.(type) { - case []interface{}: - if len(vt)*8 != dim { - msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName() - return errors.New(msg) - } - for i := 0; i < len(vt); i++ { - if e := numericValidator(vt[i]); e != nil { - msg := e.Error() + " for binary vector field " + schema.GetName() - return errors.New(msg) - } - - t := int(vt[i].(float64)) - if t > 255 || t < 0 { - msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName() - return errors.New(msg) - } - } - return nil - default: - s := fmt.Sprintf("%v", obj) - msg := s + " is not an array for binary vector field " + schema.GetName() - return errors.New(msg) - } - } - - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - arr := obj.([]interface{}) - for i := 0; i < len(arr); i++ { - value := byte(arr[i].(float64)) - field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value) - } - - field.(*storage.BinaryVectorFieldData).NumRows[0]++ - return nil - } - case 
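Note on the numeric validators above: they rely on a property of encoding/json — when decoding into interface{}, every JSON number arrives as float64, which is why each convertFunc casts from float64 to the target type. A minimal standalone check (illustrative snippet, not part of importutil):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var v interface{}
	if err := json.Unmarshal([]byte(`{"field_int8": 10}`), &v); err != nil {
		panic(err)
	}
	n := v.(map[string]interface{})["field_int8"]
	// Prints "float64 10": json.Unmarshal into interface{} always yields float64 for numbers.
	fmt.Printf("%T %v\n", n, n)
	// The validators therefore accept float64 and convert, e.g. int8(n.(float64)).
	fmt.Println(int8(n.(float64)))
}
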
schemapb.DataType_FloatVector: - dim, err := getFieldDimension(schema) - if err != nil { - return err - } - validators[schema.GetFieldID()].dimension = dim - - validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { - switch vt := obj.(type) { - case []interface{}: - if len(vt) != dim { - msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName() - return errors.New(msg) - } - for i := 0; i < len(vt); i++ { - if e := numericValidator(vt[i]); e != nil { - msg := e.Error() + " for float vector field " + schema.GetName() - return errors.New(msg) - } - } - return nil - default: - s := fmt.Sprintf("%v", obj) - msg := s + " is not an array for float vector field " + schema.GetName() - return errors.New(msg) - } - } - - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - arr := obj.([]interface{}) - for i := 0; i < len(arr); i++ { - value := float32(arr[i].(float64)) - field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value) - } - field.(*storage.FloatVectorFieldData).NumRows[0]++ - return nil - } - case schemapb.DataType_String, schemapb.DataType_VarChar: - validators[schema.GetFieldID()].isString = true - validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { - switch obj.(type) { - case string: - return nil - default: - s := fmt.Sprintf("%v", obj) - msg := s + " is not a string for string type field " + schema.GetName() - return errors.New(msg) - } - } - - validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := obj.(string) - field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value) - field.(*storage.StringFieldData).NumRows[0]++ - return nil - } - default: - return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType))) - } - } - - return nil -} - -// row-based json format validator class +// JSONRowValidator is row-based json format validator class type JSONRowValidator struct { downstream JSONRowHandler // downstream processor, typically is a JSONRowComsumer validators map[storage.FieldID]*Validator // validators for each field @@ -286,7 +62,7 @@ func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream } err := initValidators(collectionSchema, v.validators) if err != nil { - log.Error("JSON column validator: failed to initialize json row-based validator", zap.Error(err)) + log.Error("JSON row validator: failed to initialize json row-based validator", zap.Error(err)) return nil, err } return v, nil @@ -298,13 +74,14 @@ func (v *JSONRowValidator) ValidateCount() int64 { func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error { if v == nil || v.validators == nil || len(v.validators) == 0 { + log.Error("JSON row validator is not initialized") return errors.New("JSON row validator is not initialized") } // parse completed if rows == nil { log.Info("JSON row validation finished") - if v.downstream != nil { + if v.downstream != nil && !reflect.ValueOf(v.downstream).IsNil() { return v.downstream.Handle(rows) } return nil @@ -314,107 +91,42 @@ func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error row := rows[i] for id, validator := range v.validators { + value, ok := row[id] if validator.primaryKey && validator.autoID { - // auto-generated primary key, ignore + // primary key is 
auto-generated, if user provided it, return error + if ok { + log.Error("JSON row validator: primary key is auto-generated, no need to provide PK value at the row", + zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i))) + return fmt.Errorf("the primary key '%s' is auto-generated, no need to provide PK value at the row %d", + validator.fieldName, v.rowCounter+int64(i)) + } continue } - value, ok := row[id] if !ok { - return errors.New("JSON row validator: field " + validator.fieldName + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10)) + log.Error("JSON row validator: field missed at the row", + zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i))) + return fmt.Errorf("the field '%s' missed at the row %d", validator.fieldName, v.rowCounter+int64(i)) } if err := validator.validateFunc(value); err != nil { - return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10)) + log.Error("JSON row validator: invalid value at the row", zap.String("fieldName", validator.fieldName), + zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Any("value", value), zap.Error(err)) + return fmt.Errorf("the field '%s' value at the row %d is invalid, error: %s", + validator.fieldName, v.rowCounter+int64(i), err.Error()) } } } v.rowCounter += int64(len(rows)) - if v.downstream != nil { + if v.downstream != nil && !reflect.ValueOf(v.downstream).IsNil() { return v.downstream.Handle(rows) } return nil } -// column-based json format validator class -type JSONColumnValidator struct { - downstream JSONColumnHandler // downstream processor, typically is a JSONColumnComsumer - validators map[storage.FieldID]*Validator // validators for each field - rowCounter map[string]int64 // row count of each field -} - -func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) (*JSONColumnValidator, error) { - v := &JSONColumnValidator{ - validators: make(map[storage.FieldID]*Validator), - downstream: downstream, - rowCounter: make(map[string]int64), - } - err := initValidators(schema, v.validators) - if err != nil { - log.Error("JSON column validator: fail to initialize json column-based validator", zap.Error(err)) - return nil, err - } - return v, nil -} - -func (v *JSONColumnValidator) ValidateCount() map[string]int64 { - return v.rowCounter -} - -func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error { - if v == nil || v.validators == nil || len(v.validators) == 0 { - return errors.New("JSON column validator is not initialized") - } - - // parse completed - if columns == nil { - // compare the row count of columns, should be equal - rowCount := int64(-1) - for k, counter := range v.rowCounter { - if rowCount == -1 { - rowCount = counter - } else if rowCount != counter { - return errors.New("JSON column validator: the field " + k + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields row count" + strconv.Itoa(int(rowCount))) - } - } - - // let the downstream know parse is completed - log.Info("JSON column validation finished") - if v.downstream != nil { - return v.downstream.Handle(nil) - } - return nil - } - - for id, values := range columns { - validator, ok := v.validators[id] - name := validator.fieldName - if !ok { - // not a valid field name, skip without parsing - break - } - - for i := 0; i < len(values); i++ { - if err := validator.validateFunc(values[i]); err != nil { - return 
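Note on the reflect.ValueOf(v.downstream).IsNil() guard added above: it handles the typed-nil case — an interface wrapping a nil pointer is not equal to nil, so the plain nil check alone would still call Handle on a nil receiver. A small self-contained illustration (sketch, not importutil code):

package main

import (
	"fmt"
	"reflect"
)

type handler interface{ Handle() error }

type noopHandler struct{}

func (*noopHandler) Handle() error { return nil }

func main() {
	var p *noopHandler // typed nil pointer
	var h handler = p  // interface now holds (type=*noopHandler, value=nil)

	fmt.Println(h == nil)                   // false: the interface itself is non-nil
	fmt.Println(reflect.ValueOf(h).IsNil()) // true: the wrapped pointer is nil
}
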
errors.New("JSON column validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter[name]+int64(i), 10)) - } - } - - v.rowCounter[name] += int64(len(values)) - } - - if v.downstream != nil { - return v.downstream.Handle(columns) - } - - return nil -} - -type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error - -// row-based json format consumer class +// JSONRowConsumer is row-based json format consumer class type JSONRowConsumer struct { collectionSchema *schemapb.CollectionSchema // collection schema rowIDAllocator *allocator.IDAllocator // autoid allocator @@ -422,89 +134,14 @@ type JSONRowConsumer struct { rowCounter int64 // how many rows have been consumed shardNum int32 // sharding number of the collection segmentsData []map[storage.FieldID]storage.FieldData // in-memory segments data - segmentSize int64 // maximum size of a segment(unit:byte) + blockSize int64 // maximum size of a read block(unit:byte) primaryKey storage.FieldID // name of primary key autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25 callFlushFunc ImportFlushFunc // call back function to flush segment } -func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData { - segmentData := make(map[storage.FieldID]storage.FieldData) - // rowID field is a hidden field with fieldID=0, it is always auto-generated by IDAllocator - // if primary key is int64 and autoID=true, primary key field is equal to rowID field - segmentData[common.RowIDField] = &storage.Int64FieldData{ - Data: make([]int64, 0), - NumRows: []int64{0}, - } - - for i := 0; i < len(collectionSchema.Fields); i++ { - schema := collectionSchema.Fields[i] - switch schema.DataType { - case schemapb.DataType_Bool: - segmentData[schema.GetFieldID()] = &storage.BoolFieldData{ - Data: make([]bool, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_Float: - segmentData[schema.GetFieldID()] = &storage.FloatFieldData{ - Data: make([]float32, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_Double: - segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{ - Data: make([]float64, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_Int8: - segmentData[schema.GetFieldID()] = &storage.Int8FieldData{ - Data: make([]int8, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_Int16: - segmentData[schema.GetFieldID()] = &storage.Int16FieldData{ - Data: make([]int16, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_Int32: - segmentData[schema.GetFieldID()] = &storage.Int32FieldData{ - Data: make([]int32, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_Int64: - segmentData[schema.GetFieldID()] = &storage.Int64FieldData{ - Data: make([]int64, 0), - NumRows: []int64{0}, - } - case schemapb.DataType_BinaryVector: - dim, _ := getFieldDimension(schema) - segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{ - Data: make([]byte, 0), - NumRows: []int64{0}, - Dim: dim, - } - case schemapb.DataType_FloatVector: - dim, _ := getFieldDimension(schema) - segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{ - Data: make([]float32, 0), - NumRows: []int64{0}, - Dim: dim, - } - case schemapb.DataType_String, schemapb.DataType_VarChar: - segmentData[schema.GetFieldID()] = &storage.StringFieldData{ - Data: make([]string, 0), - NumRows: []int64{0}, - } - default: - log.Error("JSON row consumer error: unsupported data type", zap.Int("DataType", int(schema.DataType))) - return nil - } 
- } - - return segmentData -} - -func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64, +func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, blockSize int64, flushFunc ImportFlushFunc) (*JSONRowConsumer, error) { if collectionSchema == nil { log.Error("JSON row consumer: collection schema is nil") @@ -516,7 +153,7 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al rowIDAllocator: idAlloc, validators: make(map[storage.FieldID]*Validator), shardNum: shardNum, - segmentSize: segmentSize, + blockSize: blockSize, rowCounter: 0, primaryKey: -1, autoIDRange: make([]int64, 0), @@ -526,15 +163,15 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al err := initValidators(collectionSchema, v.validators) if err != nil { log.Error("JSON row consumer: fail to initialize json row-based consumer", zap.Error(err)) - return nil, errors.New("fail to initialize json row-based consumer") + return nil, fmt.Errorf("fail to initialize json row-based consumer: %v", err) } v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum) for i := 0; i < int(shardNum); i++ { segmentData := initSegmentData(collectionSchema) if segmentData == nil { - log.Error("JSON row consumer: fail to initialize in-memory segment data", zap.Int32("shardNum", shardNum)) - return nil, errors.New("fail to initialize in-memory segment data") + log.Error("JSON row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i)) + return nil, fmt.Errorf("fail to initialize in-memory segment data for shardID %d", i) } v.segmentsData = append(v.segmentsData, segmentData) } @@ -554,7 +191,7 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al // primary key is autoid, id generator is required if v.validators[v.primaryKey].autoID && idAlloc == nil { log.Error("JSON row consumer: ID allocator is nil") - return nil, errors.New(" ID allocator is nil") + return nil, errors.New("ID allocator is nil") } return v, nil @@ -571,8 +208,17 @@ func (v *JSONRowConsumer) flush(force bool) error { segmentData := v.segmentsData[i] rowNum := segmentData[v.primaryKey].RowNum() if rowNum > 0 { - log.Info("JSON row consumer: force flush segment", zap.Int("rows", rowNum)) - v.callFlushFunc(segmentData, i) + log.Info("JSON row consumer: force flush binlog", zap.Int("rows", rowNum)) + err := v.callFlushFunc(segmentData, i) + if err != nil { + return err + } + + v.segmentsData[i] = initSegmentData(v.collectionSchema) + if v.segmentsData[i] == nil { + log.Error("JSON row consumer: fail to initialize in-memory segment data") + return errors.New("fail to initialize in-memory segment data") + } } } @@ -587,9 +233,13 @@ func (v *JSONRowConsumer) flush(force bool) error { for _, field := range segmentData { memSize += field.GetMemorySize() } - if memSize >= int(v.segmentSize) && rowNum > 0 { - log.Info("JSON row consumer: flush fulled segment", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum)) - v.callFlushFunc(segmentData, i) + if memSize >= int(v.blockSize) && rowNum > 0 { + log.Info("JSON row consumer: flush fulled binlog", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum)) + err := v.callFlushFunc(segmentData, i) + if err != nil { + return err + } + v.segmentsData[i] = initSegmentData(v.collectionSchema) if v.segmentsData[i] == nil { log.Error("JSON row consumer: fail to initialize in-memory segment 
data") @@ -603,6 +253,7 @@ func (v *JSONRowConsumer) flush(force bool) error { func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { if v == nil || v.validators == nil || len(v.validators) == 0 { + log.Error("JSON row consumer is not initialized") return errors.New("JSON row consumer is not initialized") } @@ -615,10 +266,11 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { err := v.flush(false) if err != nil { - return err + log.Error("JSON row consumer: try flush data but failed", zap.Error(err)) + return fmt.Errorf("try flush data but failed: %s", err.Error()) } - // prepare autoid + // prepare autoid, no matter int64 or varchar pk, we always generate autoid since the hidden field RowIDField requires them primaryValidator := v.validators[v.primaryKey] var rowIDBegin typeutil.UniqueID var rowIDEnd typeutil.UniqueID @@ -626,12 +278,19 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { var err error rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows))) if err != nil { - return errors.New("JSON row consumer: " + err.Error()) + log.Error("JSON row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err)) + return fmt.Errorf("failed to generate %d primary keys: %s", len(rows), err.Error()) } if rowIDEnd-rowIDBegin != int64(len(rows)) { - return errors.New("JSON row consumer: failed to allocate ID for " + strconv.Itoa(len(rows)) + " rows") + log.Error("JSON row consumer: try to generate primary keys but allocated ids are not enough", + zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin)) + return fmt.Errorf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin) + } + log.Info("JSON row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd)) + if !primaryValidator.isString { + // if pk is varchar, no need to record auto-generated row ids + v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd) } - v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd) } // consume rows @@ -642,7 +301,8 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { var shard uint32 if primaryValidator.isString { if primaryValidator.autoID { - return errors.New("JSON row consumer: string type primary key cannot be auto-generated") + log.Error("JSON row consumer: string type primary key cannot be auto-generated") + return errors.New("string type primary key cannot be auto-generated") } value := row[v.primaryKey] @@ -662,7 +322,13 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { pk = int64(value.(float64)) } - hash, _ := typeutil.Hash32Int64(pk) + hash, err := typeutil.Hash32Int64(pk) + if err != nil { + log.Error("JSON row consumer: failed to hash primary key at the row", + zap.Int64("key", pk), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err)) + return fmt.Errorf("failed to hash primary key %d at the row %d, error: %s", pk, v.rowCounter+int64(i), err.Error()) + } + shard = hash % uint32(v.shardNum) pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData) pkArray.Data = append(pkArray.Data, pk) @@ -681,7 +347,10 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { } value := row[name] if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil { - return errors.New("JSON row consumer: " + err.Error() + " at the row " + 
strconv.FormatInt(v.rowCounter+int64(i), 10)) + log.Error("JSON row consumer: failed to convert value for field at the row", + zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err)) + return fmt.Errorf("failed to convert value for field %s at the row %d, error: %s", + validator.fieldName, v.rowCounter+int64(i), err.Error()) } } } @@ -690,115 +359,3 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { return nil } - -type ColumnFlushFunc func(fields map[storage.FieldID]storage.FieldData) error - -// column-based json format consumer class -type JSONColumnConsumer struct { - collectionSchema *schemapb.CollectionSchema // collection schema - validators map[storage.FieldID]*Validator // validators for each field - fieldsData map[storage.FieldID]storage.FieldData // in-memory fields data - primaryKey storage.FieldID // name of primary key - - callFlushFunc ColumnFlushFunc // call back function to flush segment -} - -func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, flushFunc ColumnFlushFunc) (*JSONColumnConsumer, error) { - if collectionSchema == nil { - log.Error("JSON column consumer: collection schema is nil") - return nil, errors.New("collection schema is nil") - } - - v := &JSONColumnConsumer{ - collectionSchema: collectionSchema, - validators: make(map[storage.FieldID]*Validator), - callFlushFunc: flushFunc, - } - err := initValidators(collectionSchema, v.validators) - if err != nil { - log.Error("JSON column consumer: fail to initialize validator", zap.Error(err)) - return nil, errors.New("fail to initialize validator") - } - v.fieldsData = initSegmentData(collectionSchema) - if v.fieldsData == nil { - log.Error("JSON column consumer: fail to initialize in-memory segment data") - return nil, errors.New("fail to initialize in-memory segment data") - } - - for i := 0; i < len(collectionSchema.Fields); i++ { - schema := collectionSchema.Fields[i] - if schema.GetIsPrimaryKey() { - v.primaryKey = schema.GetFieldID() - break - } - } - - return v, nil -} - -func (v *JSONColumnConsumer) flush() error { - // check row count, should be equal - rowCount := 0 - for id, field := range v.fieldsData { - // skip the autoid field - if id == v.primaryKey && v.validators[v.primaryKey].autoID { - continue - } - cnt := field.RowNum() - // skip 0 row fields since a data file may only import one column(there are several data files imported) - if cnt == 0 { - continue - } - - // only check non-zero row fields - if rowCount == 0 { - rowCount = cnt - } else if rowCount != cnt { - return errors.New("JSON column consumer: " + strconv.FormatInt(id, 10) + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount)) - } - } - - if rowCount == 0 { - return errors.New("JSON column consumer: row count is 0") - } - log.Info("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount)) - - // output the fileds data, let outside split them into segments - return v.callFlushFunc(v.fieldsData) -} - -func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error { - if v == nil || v.validators == nil || len(v.validators) == 0 { - return errors.New("JSON column consumer is not initialized") - } - - // flush at the end - if columns == nil { - err := v.flush() - log.Info("JSON column consumer finished") - return err - } - - // consume columns data - for id, values := range columns { - validator, ok := v.validators[id] - if !ok { - // not a valid field id - break - } - - if 
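Note on row routing in Handle above: the int64 primary key is hashed and the result is taken modulo shardNum to pick the target shard. A minimal standalone sketch of that routing idea, using FNV-1a only as a stand-in for typeutil.Hash32Int64 (the real hash may differ):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// shardOf maps an int64 primary key onto one of shardNum shards.
// FNV-1a is a stand-in here for typeutil.Hash32Int64.
func shardOf(pk int64, shardNum uint32) uint32 {
	buf := make([]byte, 8)
	binary.LittleEndian.PutUint64(buf, uint64(pk))
	h := fnv.New32a()
	h.Write(buf)
	return h.Sum32() % shardNum
}

func main() {
	for pk := int64(0); pk < 5; pk++ {
		fmt.Printf("pk=%d shard=%d\n", pk, shardOf(pk, 4))
	}
}
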
validator.primaryKey && validator.autoID { - // autoid is no need to provide - break - } - - // convert and consume data - for i := 0; i < len(values); i++ { - if err := validator.convertFunc(values[i], v.fieldsData[id]); err != nil { - return errors.New("JSON column consumer: " + err.Error() + " of field " + strconv.FormatInt(id, 10)) - } - } - } - - return nil -} diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go index 2a30c7aba1..6b5d05b15e 100644 --- a/internal/util/importutil/json_handler_test.go +++ b/internal/util/importutil/json_handler_test.go @@ -18,6 +18,7 @@ package importutil import ( "context" + "errors" "strings" "testing" @@ -26,15 +27,15 @@ import ( "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/allocator" - "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" ) type mockIDAllocator struct { + allocErr error } -func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (a *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { return &rootcoordpb.AllocIDResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -42,11 +43,13 @@ func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocI }, ID: int64(1), Count: req.Count, - }, nil + }, a.allocErr } -func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator { - mockIDAllocator := &mockIDAllocator{} +func newIDAllocator(ctx context.Context, t *testing.T, allocErr error) *allocator.IDAllocator { + mockIDAllocator := &mockIDAllocator{ + allocErr: allocErr, + } idAllocator, err := allocator.NewIDAllocator(ctx, mockIDAllocator, int64(1)) assert.Nil(t, err) @@ -56,136 +59,14 @@ func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator { return idAllocator } -func Test_GetFieldDimension(t *testing.T) { - schema := &schemapb.FieldSchema{ - FieldID: 111, - Name: "field_float_vector", - IsPrimaryKey: false, - Description: "float_vector", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "4"}, - }, - } +func Test_NewJSONRowValidator(t *testing.T) { + validator, err := NewJSONRowValidator(nil, nil) + assert.NotNil(t, err) + assert.Nil(t, validator) - dim, err := getFieldDimension(schema) + validator, err = NewJSONRowValidator(sampleSchema(), nil) + assert.NotNil(t, validator) assert.Nil(t, err) - assert.Equal(t, 4, dim) - - schema.TypeParams = []*commonpb.KeyValuePair{ - {Key: "dim", Value: "abc"}, - } - dim, err = getFieldDimension(schema) - assert.NotNil(t, err) - assert.Equal(t, 0, dim) - - schema.TypeParams = []*commonpb.KeyValuePair{} - dim, err = getFieldDimension(schema) - assert.NotNil(t, err) - assert.Equal(t, 0, dim) -} - -func Test_InitValidators(t *testing.T) { - validators := make(map[storage.FieldID]*Validator) - err := initValidators(nil, validators) - assert.NotNil(t, err) - - schema := sampleSchema() - // success case - err = initValidators(schema, validators) - assert.Nil(t, err) - assert.Equal(t, len(schema.Fields), len(validators)) - name2ID := make(map[string]storage.FieldID) - for _, field := range schema.Fields { - name2ID[field.GetName()] = field.GetFieldID() - } - - checkFunc := func(funcName 
string, validVal interface{}, invalidVal interface{}) { - id := name2ID[funcName] - v, ok := validators[id] - assert.True(t, ok) - err = v.validateFunc(validVal) - assert.Nil(t, err) - err = v.validateFunc(invalidVal) - assert.NotNil(t, err) - } - - // validate functions - var validVal interface{} = true - var invalidVal interface{} = "aa" - checkFunc("field_bool", validVal, invalidVal) - - validVal = float64(100) - invalidVal = "aa" - checkFunc("field_int8", validVal, invalidVal) - checkFunc("field_int16", validVal, invalidVal) - checkFunc("field_int32", validVal, invalidVal) - checkFunc("field_int64", validVal, invalidVal) - checkFunc("field_float", validVal, invalidVal) - checkFunc("field_double", validVal, invalidVal) - - validVal = "aa" - invalidVal = 100 - checkFunc("field_string", validVal, invalidVal) - - validVal = []interface{}{float64(100), float64(101)} - invalidVal = "aa" - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(100)} - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(100), float64(101), float64(102)} - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{true, true} - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(255), float64(-1)} - checkFunc("field_binary_vector", validVal, invalidVal) - - validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)} - invalidVal = true - checkFunc("field_float_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(1), float64(2), float64(3)} - checkFunc("field_float_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)} - checkFunc("field_float_vector", validVal, invalidVal) - invalidVal = []interface{}{"a", "b", "c", "d"} - checkFunc("field_float_vector", validVal, invalidVal) - - // error cases - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: make([]*schemapb.FieldSchema, 0), - } - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 111, - Name: "field_float_vector", - IsPrimaryKey: false, - Description: "float_vector", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "aa"}, - }, - }) - - validators = make(map[storage.FieldID]*Validator) - err = initValidators(schema, validators) - assert.NotNil(t, err) - - schema.Fields = make([]*schemapb.FieldSchema, 0) - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 110, - Name: "field_binary_vector", - IsPrimaryKey: false, - Description: "float_vector", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "aa"}, - }, - }) - - err = initValidators(schema, validators) - assert.NotNil(t, err) } func Test_JSONRowValidator(t *testing.T) { @@ -209,15 +90,15 @@ func Test_JSONRowValidator(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, int64(0), validator.ValidateCount()) - // // missed some fields - // reader = strings.NewReader(`{ - // "rows":[ - // {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, - // {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, 
"field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]} - // ] - // }`) - // err = parser.ParseRows(reader, validator) - // assert.NotNil(t, err) + // missed some fields + reader = strings.NewReader(`{ + "rows":[ + {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + {"field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]} + ] + }`) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) // invalid dimension reader = strings.NewReader(`{ @@ -241,92 +122,90 @@ func Test_JSONRowValidator(t *testing.T) { validator.validators = nil err = validator.Handle(nil) assert.NotNil(t, err) -} -func Test_JSONColumnValidator(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := sampleSchema() - parser := NewJSONParser(ctx, schema) - assert.NotNil(t, parser) - - // 0 row case - reader := strings.NewReader(`{ - "field_bool": [], - "field_int8": [], - "field_int16": [], - "field_int32": [], - "field_int64": [], - "field_float": [], - "field_double": [], - "field_string": [], - "field_binary_vector": [], - "field_float_vector": [] - }`) - - validator, err := NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - for _, count := range validator.rowCounter { - assert.Equal(t, int64(0), count) + // primary key is auto-generate, but user provide pk value, return error + schema = &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "ID", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "Age", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + }, } - // different row count - reader = strings.NewReader(`{ - "field_bool": [true], - "field_int8": [], - "field_int16": [], - "field_int32": [1, 2, 3], - "field_int64": [], - "field_float": [], - "field_double": [], - "field_string": [], - "field_binary_vector": [], - "field_float_vector": [] - }`) - - validator, err = NewJSONColumnValidator(schema, nil) + validator, err = NewJSONRowValidator(schema, nil) assert.NotNil(t, validator) assert.Nil(t, err) - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - // invalid value type reader = strings.NewReader(`{ - "dummy": [], - "field_bool": [true], - "field_int8": [1], - "field_int16": [2], - "field_int32": [3], - "field_int64": [4], - "field_float": [1], - "field_double": [1], - "field_string": [9], - "field_binary_vector": [[254, 1]], - "field_float_vector": [[1.1, 1.2, 1.3, 1.4]] + "rows":[ + {"ID": 1, "Age": 2} + ] }`) + parser = NewJSONParser(ctx, schema) + err = parser.ParseRows(reader, validator) + assert.NotNil(t, err) +} - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) +func Test_NewJSONRowConsumer(t *testing.T) { + // nil schema + consumer, err := NewJSONRowConsumer(nil, nil, 2, 16, nil) + assert.NotNil(t, err) + assert.Nil(t, consumer) + + // wrong schema + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: 
[]*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: false, + DataType: schemapb.DataType_None, + }, + }, + } + consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil) + assert.NotNil(t, err) + assert.Nil(t, consumer) + + // no primary key + schema.Fields[0].IsPrimaryKey = false + schema.Fields[0].DataType = schemapb.DataType_Int64 + consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil) + assert.NotNil(t, err) + assert.Nil(t, consumer) + + // primary key is autoid, but no IDAllocator + schema.Fields[0].IsPrimaryKey = true + schema.Fields[0].AutoID = true + consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil) + assert.NotNil(t, err) + assert.Nil(t, consumer) + + // success + consumer, err = NewJSONRowConsumer(sampleSchema(), nil, 2, 16, nil) + assert.NotNil(t, consumer) assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - // init failed - validator.validators = nil - err = validator.Handle(nil) - assert.NotNil(t, err) } func Test_JSONRowConsumer(t *testing.T) { ctx := context.Background() - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) schema := sampleSchema() parser := NewJSONParser(ctx, schema) @@ -376,9 +255,196 @@ func Test_JSONRowConsumer(t *testing.T) { assert.Equal(t, 5, totalCount) } +func Test_JSONRowConsumerFlush(t *testing.T) { + var callTime int32 + var totalCount int + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error { + callTime++ + field, ok := fields[101] + assert.True(t, ok) + assert.Greater(t, field.RowNum(), 0) + totalCount += field.RowNum() + return nil + } + + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + }, + } + + var shardNum int32 = 4 + var blockSize int64 = 1 + consumer, err := NewJSONRowConsumer(schema, nil, shardNum, blockSize, flushFunc) + assert.NotNil(t, consumer) + assert.Nil(t, err) + + // force flush + rowCountEachShard := 100 + for i := 0; i < int(shardNum); i++ { + pkFieldData := consumer.segmentsData[i][101].(*storage.Int64FieldData) + for j := 0; j < rowCountEachShard; j++ { + pkFieldData.Data = append(pkFieldData.Data, int64(j)) + } + pkFieldData.NumRows = []int64{int64(rowCountEachShard)} + } + + err = consumer.flush(true) + assert.Nil(t, err) + assert.Equal(t, shardNum, callTime) + assert.Equal(t, rowCountEachShard*int(shardNum), totalCount) + + // execeed block size trigger flush + callTime = 0 + totalCount = 0 + for i := 0; i < int(shardNum); i++ { + consumer.segmentsData[i] = initSegmentData(schema) + if i%2 == 0 { + continue + } + pkFieldData := consumer.segmentsData[i][101].(*storage.Int64FieldData) + for j := 0; j < rowCountEachShard; j++ { + pkFieldData.Data = append(pkFieldData.Data, int64(j)) + } + pkFieldData.NumRows = []int64{int64(rowCountEachShard)} + } + err = consumer.flush(true) + assert.Nil(t, err) + assert.Equal(t, shardNum/2, callTime) + assert.Equal(t, rowCountEachShard*int(shardNum)/2, totalCount) +} + +func Test_JSONRowConsumerHandle(t *testing.T) { + ctx := context.Background() + idAllocator := newIDAllocator(ctx, t, errors.New("error")) + + var callTime int32 + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error { + callTime++ + return errors.New("dummy error") + } + + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: 
[]*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + }, + } + + t.Run("handle int64 pk", func(t *testing.T) { + consumer, err := NewJSONRowConsumer(schema, idAllocator, 1, 1, flushFunc) + assert.NotNil(t, consumer) + assert.Nil(t, err) + + pkFieldData := consumer.segmentsData[0][101].(*storage.Int64FieldData) + for i := 0; i < 10; i++ { + pkFieldData.Data = append(pkFieldData.Data, int64(i)) + } + pkFieldData.NumRows = []int64{int64(10)} + + // nil input will trigger flush + err = consumer.Handle(nil) + assert.NotNil(t, err) + assert.Equal(t, int32(1), callTime) + + // optional flush + callTime = 0 + rowCount := 100 + pkFieldData = consumer.segmentsData[0][101].(*storage.Int64FieldData) + for j := 0; j < rowCount; j++ { + pkFieldData.Data = append(pkFieldData.Data, int64(j)) + } + pkFieldData.NumRows = []int64{int64(rowCount)} + + input := make([]map[storage.FieldID]interface{}, rowCount) + for j := 0; j < rowCount; j++ { + input[j] = make(map[int64]interface{}) + input[j][101] = int64(j) + } + err = consumer.Handle(input) + assert.NotNil(t, err) + assert.Equal(t, int32(1), callTime) + + // failed to auto-generate pk + consumer.blockSize = 1024 * 1024 + err = consumer.Handle(input) + assert.NotNil(t, err) + + // hash int64 pk + consumer.rowIDAllocator = newIDAllocator(ctx, t, nil) + err = consumer.Handle(input) + assert.Nil(t, err) + assert.Equal(t, int64(rowCount), consumer.rowCounter) + assert.Equal(t, 2, len(consumer.autoIDRange)) + assert.Equal(t, int64(1), consumer.autoIDRange[0]) + assert.Equal(t, int64(1+rowCount), consumer.autoIDRange[1]) + }) + + t.Run("handle varchar pk", func(t *testing.T) { + schema = &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "1024"}, + }, + }, + }, + } + + idAllocator := newIDAllocator(ctx, t, nil) + consumer, err := NewJSONRowConsumer(schema, idAllocator, 1, 1024*1024, flushFunc) + assert.NotNil(t, consumer) + assert.Nil(t, err) + + rowCount := 100 + input := make([]map[storage.FieldID]interface{}, rowCount) + for j := 0; j < rowCount; j++ { + input[j] = make(map[int64]interface{}) + input[j][101] = "abc" + } + + // varchar pk cannot be autoid + err = consumer.Handle(input) + assert.NotNil(t, err) + + // hash varchar pk + schema.Fields[0].AutoID = false + consumer, err = NewJSONRowConsumer(schema, idAllocator, 1, 1024*1024, flushFunc) + assert.NotNil(t, consumer) + assert.Nil(t, err) + + err = consumer.Handle(input) + assert.Nil(t, err) + assert.Equal(t, int64(rowCount), consumer.rowCounter) + assert.Equal(t, 0, len(consumer.autoIDRange)) + }) +} + func Test_JSONRowConsumerStringKey(t *testing.T) { ctx := context.Background() - idAllocator := newIDAllocator(ctx, t) + idAllocator := newIDAllocator(ctx, t, nil) schema := strKeySchema() parser := NewJSONParser(ctx, schema) @@ -501,71 +567,3 @@ func Test_JSONRowConsumerStringKey(t *testing.T) { assert.Equal(t, shardNum, callTime) assert.Equal(t, 10, totalCount) } - -func Test_JSONColumnConsumer(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := sampleSchema() - parser := NewJSONParser(ctx, schema) - assert.NotNil(t, parser) - - reader := strings.NewReader(`{ - "field_bool": [true, false, true, true, true], - "field_int8": [10, 11, 12, 13, 
14], - "field_int16": [100, 101, 102, 103, 104], - "field_int32": [1000, 1001, 1002, 1003, 1004], - "field_int64": [10000, 10001, 10002, 10003, 10004], - "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], - "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], - "field_string": ["a", "b", "c", "d", "e"], - "field_binary_vector": [ - [254, 1], - [253, 2], - [252, 3], - [251, 4], - [250, 5] - ], - "field_float_vector": [ - [1.1, 1.2, 1.3, 1.4], - [2.1, 2.2, 2.3, 2.4], - [3.1, 3.2, 3.3, 3.4], - [4.1, 4.2, 4.3, 4.4], - [5.1, 5.2, 5.3, 5.4] - ] - }`) - - callTime := 0 - rowCount := 0 - consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error { - callTime++ - for id, data := range fields { - if id == common.RowIDField { - continue - } - if rowCount == 0 { - rowCount = data.RowNum() - } else { - assert.Equal(t, rowCount, data.RowNum()) - } - } - return nil - } - - consumer, err := NewJSONColumnConsumer(schema, consumeFunc) - assert.NotNil(t, consumer) - assert.Nil(t, err) - - validator, err := NewJSONColumnValidator(schema, consumer) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.Nil(t, err) - for _, count := range validator.ValidateCount() { - assert.Equal(t, int64(5), count) - } - - assert.Equal(t, 1, callTime) - assert.Equal(t, 5, rowCount) -} diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index 12dd71bec2..f9a60022bc 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "strings" @@ -84,28 +85,26 @@ func adjustBufSize(parser *JSONParser, collectionSchema *schemapb.CollectionSche bufSize = MinBufferSize } - log.Info("JSON parse: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize)) + log.Info("JSON parser: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize)) parser.bufSize = int64(bufSize) } -func (p *JSONParser) logError(msg string) error { - log.Error(msg) - return errors.New(msg) -} - func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { if handler == nil { - return p.logError("JSON parse handler is nil") + log.Error("JSON parse handler is nil") + return errors.New("JSON parse handler is nil") } dec := json.NewDecoder(r) t, err := dec.Token() if err != nil { - return p.logError("JSON parse: row count is 0") + log.Error("JSON parser: row count is 0") + return errors.New("JSON parser: row count is 0") } if t != json.Delim('{') { - return p.logError("JSON parse: invalid JSON format, the content should be started with'{'") + log.Error("JSON parser: invalid JSON format, the content should be started with'{'") + return errors.New("JSON parser: invalid JSON format, the content should be started with'{'") } // read the first level @@ -114,24 +113,28 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { // read the key t, err := dec.Token() if err != nil { - return p.logError("JSON parse: " + err.Error()) + log.Error("JSON parser: read json token error", zap.Any("err", err)) + return fmt.Errorf("JSON parser: read json token error: %v", err) } key := t.(string) keyLower := strings.ToLower(key) // the root key should be RowRootNode if keyLower != RowRootNode { - return p.logError("JSON parse: invalid row-based JSON format, the key " + key + " is not found") + log.Error("JSON parser: invalid row-based JSON format, the key is not found", zap.String("key", key)) + 
return fmt.Errorf("JSON parser: invalid row-based JSON format, the key %s is not found", key) } // started by '[' t, err = dec.Token() if err != nil { - return p.logError("JSON parse: " + err.Error()) + log.Error("JSON parser: read json token error", zap.Any("err", err)) + return fmt.Errorf("JSON parser: read json token error: %v", err) } if t != json.Delim('[') { - return p.logError("JSON parse: invalid row-based JSON format, rows list should begin with '['") + log.Error("JSON parser: invalid row-based JSON format, rows list should begin with '['") + return errors.New("JSON parser: invalid row-based JSON format, rows list should begin with '['") } // read buffer @@ -139,27 +142,36 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { for dec.More() { var value interface{} if err := dec.Decode(&value); err != nil { - return p.logError("JSON parse: " + err.Error()) + log.Error("JSON parser: decode json value error", zap.Any("err", err)) + return fmt.Errorf("JSON parser: decode json value error: %v", err) } switch value.(type) { case map[string]interface{}: break default: - return p.logError("JSON parse: invalid JSON format, each row should be a key-value map") + log.Error("JSON parser: invalid JSON format, each row should be a key-value map") + return errors.New("JSON parser: invalid JSON format, each row should be a key-value map") } row := make(map[storage.FieldID]interface{}) stringMap := value.(map[string]interface{}) for k, v := range stringMap { - row[p.name2FieldID[k]] = v + // if user provided redundant field, return error + fieldID, ok := p.name2FieldID[k] + if !ok { + log.Error("JSON parser: the field is not defined in collection schema", zap.String("fieldName", k)) + return fmt.Errorf("JSON parser: the field '%s' is not defined in collection schema", k) + } + row[fieldID] = v } buf = append(buf, row) if len(buf) >= int(p.bufSize) { isEmpty = false if err = handler.Handle(buf); err != nil { - return p.logError(err.Error()) + log.Error("JSON parser: parse values error", zap.Any("err", err)) + return fmt.Errorf("JSON parser: parse values error: %v", err) } // clear the buffer @@ -171,26 +183,27 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { if len(buf) > 0 { isEmpty = false if err = handler.Handle(buf); err != nil { - return p.logError(err.Error()) + log.Error("JSON parser: parse values error", zap.Any("err", err)) + return fmt.Errorf("JSON parser: parse values error: %v", err) } } // end by ']' t, err = dec.Token() if err != nil { - return p.logError("JSON parse: " + err.Error()) + log.Error("JSON parser: read json token error", zap.Any("err", err)) + return fmt.Errorf("JSON parser: read json token error: %v", err) } if t != json.Delim(']') { - return p.logError("JSON parse: invalid column-based JSON format, rows list should end with a ']'") + log.Error("JSON parser: invalid column-based JSON format, rows list should end with a ']'") + return errors.New("JSON parser: invalid column-based JSON format, rows list should end with a ']'") } - // canceled? 
- select { - case <-p.ctx.Done(): - return p.logError("import task was canceled") - default: - break + // outside context might be canceled(service stop, or future enhancement for canceling import task) + if isCanceled(p.ctx) { + log.Error("JSON parser: import task was canceled") + return errors.New("JSON parser: import task was canceled") } // this break means we require the first node must be RowRootNode @@ -199,106 +212,8 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { } if isEmpty { - return p.logError("JSON parse: row count is 0") - } - - // send nil to notify the handler all have done - return handler.Handle(nil) -} - -func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error { - if handler == nil { - return p.logError("JSON parse handler is nil") - } - - dec := json.NewDecoder(r) - - t, err := dec.Token() - if err != nil { - return p.logError("JSON parse: row count is 0") - } - if t != json.Delim('{') { - return p.logError("JSON parse: invalid JSON format, the content should be started with'{'") - } - - // read the first level - isEmpty := true - for dec.More() { - // read the key - t, err := dec.Token() - if err != nil { - return p.logError("JSON parse: " + err.Error()) - } - key := t.(string) - - // not a valid column name, skip - _, isValidField := p.fields[key] - - // started by '[' - t, err = dec.Token() - if err != nil { - return p.logError("JSON parse: " + err.Error()) - } - - if t != json.Delim('[') { - return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['") - } - - id := p.name2FieldID[key] - // read buffer - buf := make(map[storage.FieldID][]interface{}) - buf[id] = make([]interface{}, 0, MinBufferSize) - for dec.More() { - var value interface{} - if err := dec.Decode(&value); err != nil { - return p.logError("JSON parse: " + err.Error()) - } - - if !isValidField { - continue - } - - buf[id] = append(buf[id], value) - if len(buf[id]) >= int(p.bufSize) { - isEmpty = false - if err = handler.Handle(buf); err != nil { - return p.logError(err.Error()) - } - - // clear the buffer - buf[id] = make([]interface{}, 0, MinBufferSize) - } - } - - // some values in buffer not parsed, parse them - if len(buf[id]) > 0 { - isEmpty = false - if err = handler.Handle(buf); err != nil { - return p.logError(err.Error()) - } - } - - // end by ']' - t, err = dec.Token() - if err != nil { - return p.logError("JSON parse: " + err.Error()) - } - - if t != json.Delim(']') { - return p.logError("JSON parse: invalid column-based JSON format, each field should end with a ']'") - } - - // canceled? 
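// --- Editor's aside (illustrative sketch, not part of this patch) -----------------
// The select/default blocks removed around here are replaced by isCanceled(p.ctx),
// whose definition is outside this hunk. Assuming it only needs to report whether the
// context has been cancelled, an equivalent helper would look like this:
func isCanceledSketch(ctx context.Context) bool {
	select {
	case <-ctx.Done():
		return true // service stop, or a future "cancel import task" request
	default:
		return false
	}
}
// --- End of editor's aside ----------------------------------------------------------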
- select { - case <-p.ctx.Done(): - return p.logError("import task was canceled") - default: - break - } - } - - if isEmpty { - return p.logError("JSON parse: row count is 0") + log.Error("JSON parser: row count is 0") + return errors.New("JSON parser: row count is 0") } // send nil to notify the handler all have done diff --git a/internal/util/importutil/json_parser_test.go b/internal/util/importutil/json_parser_test.go index 5192ab8fc0..21b91fc7a3 100644 --- a/internal/util/importutil/json_parser_test.go +++ b/internal/util/importutil/json_parser_test.go @@ -27,157 +27,6 @@ import ( "github.com/stretchr/testify/assert" ) -func sampleSchema() *schemapb.CollectionSchema { - schema := &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 102, - Name: "field_bool", - IsPrimaryKey: false, - Description: "bool", - DataType: schemapb.DataType_Bool, - }, - { - FieldID: 103, - Name: "field_int8", - IsPrimaryKey: false, - Description: "int8", - DataType: schemapb.DataType_Int8, - }, - { - FieldID: 104, - Name: "field_int16", - IsPrimaryKey: false, - Description: "int16", - DataType: schemapb.DataType_Int16, - }, - { - FieldID: 105, - Name: "field_int32", - IsPrimaryKey: false, - Description: "int32", - DataType: schemapb.DataType_Int32, - }, - { - FieldID: 106, - Name: "field_int64", - IsPrimaryKey: true, - AutoID: false, - Description: "int64", - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 107, - Name: "field_float", - IsPrimaryKey: false, - Description: "float", - DataType: schemapb.DataType_Float, - }, - { - FieldID: 108, - Name: "field_double", - IsPrimaryKey: false, - Description: "double", - DataType: schemapb.DataType_Double, - }, - { - FieldID: 109, - Name: "field_string", - IsPrimaryKey: false, - Description: "string", - DataType: schemapb.DataType_VarChar, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_length", Value: "128"}, - }, - }, - { - FieldID: 110, - Name: "field_binary_vector", - IsPrimaryKey: false, - Description: "binary_vector", - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "16"}, - }, - }, - { - FieldID: 111, - Name: "field_float_vector", - IsPrimaryKey: false, - Description: "float_vector", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "4"}, - }, - }, - }, - } - return schema -} - -func strKeySchema() *schemapb.CollectionSchema { - schema := &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "uid", - IsPrimaryKey: true, - AutoID: false, - Description: "uid", - DataType: schemapb.DataType_VarChar, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_length", Value: "1024"}, - }, - }, - { - FieldID: 102, - Name: "int_scalar", - IsPrimaryKey: false, - Description: "int_scalar", - DataType: schemapb.DataType_Int32, - }, - { - FieldID: 103, - Name: "float_scalar", - IsPrimaryKey: false, - Description: "float_scalar", - DataType: schemapb.DataType_Float, - }, - { - FieldID: 104, - Name: "string_scalar", - IsPrimaryKey: false, - Description: "string_scalar", - DataType: schemapb.DataType_VarChar, - }, - { - FieldID: 105, - Name: "bool_scalar", - IsPrimaryKey: false, - Description: "bool_scalar", - DataType: schemapb.DataType_Bool, - }, - { - FieldID: 106, - Name: "vectors", - IsPrimaryKey: false, - Description: "vectors", - DataType: schemapb.DataType_FloatVector, - 
TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "4"}, - }, - }, - }, - } - return schema -} - func Test_AdjustBufSize(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -234,6 +83,7 @@ func Test_JSONParserParserRows(t *testing.T) { ] }`) + // handler is nil err := parser.ParseRows(reader, nil) assert.NotNil(t, err) @@ -241,10 +91,12 @@ func Test_JSONParserParserRows(t *testing.T) { assert.NotNil(t, validator) assert.Nil(t, err) + // success err = parser.ParseRows(reader, validator) assert.Nil(t, err) assert.Equal(t, int64(5), validator.ValidateCount()) + // not a row-based format reader = strings.NewReader(`{ "dummy":[] }`) @@ -255,6 +107,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // rows is not a list reader = strings.NewReader(`{ "rows": }`) @@ -265,6 +118,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // typo reader = strings.NewReader(`{ "rows": [} }`) @@ -275,6 +129,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // rows is not a list reader = strings.NewReader(`{ "rows": {} }`) @@ -285,6 +140,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // rows is not a list of list reader = strings.NewReader(`{ "rows": [[]] }`) @@ -295,6 +151,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // not valid json format reader = strings.NewReader(`[]`) validator, err = NewJSONRowValidator(schema, nil) assert.NotNil(t, validator) @@ -303,6 +160,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // empty content reader = strings.NewReader(`{}`) validator, err = NewJSONRowValidator(schema, nil) assert.NotNil(t, validator) @@ -311,6 +169,7 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) + // empty content reader = strings.NewReader(``) validator, err = NewJSONRowValidator(schema, nil) assert.NotNil(t, validator) @@ -318,129 +177,23 @@ func Test_JSONParserParserRows(t *testing.T) { err = parser.ParseRows(reader, validator) assert.NotNil(t, err) -} -func Test_JSONParserParserColumns(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := sampleSchema() - parser := NewJSONParser(ctx, schema) - assert.NotNil(t, parser) - parser.bufSize = 1 - - reader := strings.NewReader(`{ - "field_bool": [true, false, true, true, true], - "field_int8": [10, 11, 12, 13, 14], - "field_int16": [100, 101, 102, 103, 104], - "field_int32": [1000, 1001, 1002, 1003, 1004], - "field_int64": [10000, 10001, 10002, 10003, 10004], - "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], - "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], - "field_string": ["a", "b", "c", "d", "e"], - "field_binary_vector": [ - [254, 1], - [253, 2], - [252, 3], - [251, 4], - [250, 5] - ], - "field_float_vector": [ - [1.1, 1.2, 1.3, 1.4], - [2.1, 2.2, 2.3, 2.4], - [3.1, 3.2, 3.3, 3.4], - [4.1, 4.2, 4.3, 4.4], - [5.1, 5.2, 5.3, 5.4] + // redundant field + reader = strings.NewReader(`{ + "rows":[ + {"dummy": 1, "field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", 
"field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, ] }`) - - err := parser.ParseColumns(reader, nil) + err = parser.ParseRows(reader, validator) assert.NotNil(t, err) - validator, err := NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.Nil(t, err) - counter := validator.ValidateCount() - for _, v := range counter { - assert.Equal(t, int64(5), v) - } - + // field missed reader = strings.NewReader(`{ - "field_int8": [10, 11, 12, 13, 14], - "dummy":[1, 2, 3] + "rows":[ + {"field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + ] }`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.Nil(t, err) - - reader = strings.NewReader(`{ - "dummy":[1, 2, 3] - }`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - reader = strings.NewReader(`{ - "field_bool": - }`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - reader = strings.NewReader(`{ - "field_bool":{} - }`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - reader = strings.NewReader(`{ - "field_bool":[} - }`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - reader = strings.NewReader(`[]`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - reader = strings.NewReader(`{}`) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.NotNil(t, err) - - reader = strings.NewReader(``) - validator, err = NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) + err = parser.ParseRows(reader, validator) assert.NotNil(t, err) } @@ -545,42 +298,3 @@ func Test_JSONParserParserRowsStringKey(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(10), validator.ValidateCount()) } - -func Test_JSONParserParserColumnsStrKey(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := strKeySchema() - parser := NewJSONParser(ctx, schema) - assert.NotNil(t, parser) - parser.bufSize = 1 - - reader := strings.NewReader(`{ - "uid": ["Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd"], - "int_scalar": [9070353, 8505288, 4392660, 7927425, 9288807], - "float_scalar": [0.9798043638085004, 
0.937913432198687, 0.32381232630490264, 0.31074026464844895, 0.4953578200336135], - "string_scalar": ["ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7"], - "bool_scalar": [true, false, true, false, false], - "vectors": [ - [0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314], - [0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056], - [0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914], - [0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004], - [0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633] - ] - }`) - - err := parser.ParseColumns(reader, nil) - assert.NotNil(t, err) - - validator, err := NewJSONColumnValidator(schema, nil) - assert.NotNil(t, validator) - assert.Nil(t, err) - - err = parser.ParseColumns(reader, validator) - assert.Nil(t, err) - counter := validator.ValidateCount() - for _, v := range counter { - assert.Equal(t, int64(5), v) - } -} diff --git a/internal/util/importutil/numpy_adapter.go b/internal/util/importutil/numpy_adapter.go index 60b1ea386b..27bd4dee9e 100644 --- a/internal/util/importutil/numpy_adapter.go +++ b/internal/util/importutil/numpy_adapter.go @@ -20,11 +20,27 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" + "io/ioutil" "os" + "reflect" + "regexp" + "strconv" + "unicode/utf8" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/log" "github.com/sbinet/npyio" "github.com/sbinet/npyio/npy" + "go.uber.org/zap" +) + +var ( + reStrPre = regexp.MustCompile(`^[|]*?(\d.*)[Sa]$`) + reStrPost = regexp.MustCompile(`^[|]*?[Sa](\d.*)$`) + reUniPre = regexp.MustCompile(`^[<|>]*?(\d.*)U$`) + reUniPost = regexp.MustCompile(`^[<|>]*?U(\d.*)$`) ) func CreateNumpyFile(path string, data interface{}) error { @@ -52,17 +68,18 @@ func CreateNumpyData(data interface{}) ([]byte, error) { return buf.Bytes(), nil } -// a class to expand other numpy lib ability +// NumpyAdapter is the class to expand other numpy lib ability // we evaluate two go-numpy lins: github.com/kshedden/gonpy and github.com/sbinet/npyio // the npyio lib read data one by one, the performance is poor, we expand the read methods // to read data in one batch, the performance is 100X faster // the gonpy lib also read data in one batch, but it has no method to read bool data, and the ability // to handle different data type is not strong as the npylib, so we choose the npyio lib to expand. 
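// --- Editor's aside (illustrative sketch, not part of this patch) -----------------
// The four regular expressions added above classify the dtype string in a numpy header:
// "8S"/"|S8" style dtypes are fixed-width byte strings, "4U"/"<U4" style dtypes are
// fixed-width unicode (UTF-32) strings, and the number is the per-element character
// count. A quick sketch of how they feed the stringLen()/convertNumpyType() helpers
// defined below; the commented values are what the current regexes yield.
func dtypeExamples() {
	n, utf, _ := stringLen("|S8") // byte-string dtype: width 8, utf == false
	fmt.Println(n, utf)           // 8 false
	n, utf, _ = stringLen("<U4") // unicode dtype: width 4, stored as 4 bytes per character
	fmt.Println(n, utf)          // 4 true
	dt, _ := convertNumpyType("i8") // scalar dtypes map directly to Milvus data types
	fmt.Println(dt)                 // Int64
}
// --- End of editor's aside ----------------------------------------------------------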
type NumpyAdapter struct { - reader io.Reader // data source, typically is os.File - npyReader *npy.Reader // reader of npyio lib - order binary.ByteOrder // LittleEndian or BigEndian - readPosition int // how many elements have been read + reader io.Reader // data source, typically is os.File + npyReader *npy.Reader // reader of npyio lib + order binary.ByteOrder // LittleEndian or BigEndian + readPosition int // how many elements have been read + dataType schemapb.DataType // data type parsed from numpy file header } func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) { @@ -70,17 +87,106 @@ func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) { if err != nil { return nil, err } + + dataType, err := convertNumpyType(r.Header.Descr.Type) + if err != nil { + return nil, err + } + adapter := &NumpyAdapter{ reader: reader, npyReader: r, readPosition: 0, + dataType: dataType, } adapter.setByteOrder() + log.Info("Numpy adapter: numpy header info", + zap.Any("shape", r.Header.Descr.Shape), + zap.String("dType", r.Header.Descr.Type), + zap.Uint8("majorVer", r.Header.Major), + zap.Uint8("minorVer", r.Header.Minor), + zap.String("ByteOrder", adapter.order.String())) + return adapter, err } -// the logic of this method is copied from npyio lib +// convertNumpyType gets data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector) +func convertNumpyType(typeStr string) (schemapb.DataType, error) { + log.Info("Numpy adapter: parse numpy file dtype", zap.String("dtype", typeStr)) + switch typeStr { + case "b1", "i1", "int8": + return schemapb.DataType_Int8, nil + case "i2", "i2", "int16": + return schemapb.DataType_Int16, nil + case "i4", "i4", "int32": + return schemapb.DataType_Int32, nil + case "i8", "i8", "int64": + return schemapb.DataType_Int64, nil + case "f4", "f4", "float32": + return schemapb.DataType_Float, nil + case "f8", "f8", "float64": + return schemapb.DataType_Double, nil + default: + if isStringType(typeStr) { + return schemapb.DataType_VarChar, nil + } + log.Error("Numpy adapter: the numpy file data type not supported", zap.String("dataType", typeStr)) + return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr) + } +} + +func stringLen(dtype string) (int, bool, error) { + var utf bool + switch { + case reStrPre.MatchString(dtype), reStrPost.MatchString(dtype): + utf = false + case reUniPre.MatchString(dtype), reUniPost.MatchString(dtype): + utf = true + } + + if m := reStrPre.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + if m := reStrPost.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + if m := reUniPre.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + if m := reUniPost.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + + return 0, false, fmt.Errorf("Numpy adapter: data type '%s' of numpy file is not varchar data type", dtype) +} + +func isStringType(typeStr string) bool { + rt := npyio.TypeFrom(typeStr) + return rt == reflect.TypeOf((*string)(nil)).Elem() +} + +// setByteOrder sets BigEndian/LittleEndian, the logic of this method is copied from npyio lib func (n *NumpyAdapter) 
setByteOrder() { var nativeEndian binary.ByteOrder v := uint16(1) @@ -109,15 +215,15 @@ func (n *NumpyAdapter) NpyReader() *npy.Reader { return n.npyReader } -func (n *NumpyAdapter) GetType() string { - return n.npyReader.Header.Descr.Type +func (n *NumpyAdapter) GetType() schemapb.DataType { + return n.dataType } func (n *NumpyAdapter) GetShape() []int { return n.npyReader.Header.Descr.Shape } -func (n *NumpyAdapter) checkSize(size int) int { +func (n *NumpyAdapter) checkCount(count int) int { shape := n.GetShape() // empty file? @@ -135,35 +241,34 @@ } // overflow? - if size > (total - n.readPosition) { + if count > (total - n.readPosition) { return total - n.readPosition } - return size + return count } -func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) { - if n.npyReader == nil { - return nil, errors.New("reader is not initialized") +func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) { + if count <= 0 { + return nil, errors.New("Numpy adapter: cannot read bool data with a zero or negative count") } // incorrect type - switch n.npyReader.Header.Descr.Type { - case "b1", "i1", "int8": - default: - return nil, errors.New("numpy data is not int8 type") + if n.dataType != schemapb.DataType_Int8 { + return nil, errors.New("Numpy adapter: numpy data is not int8 type") } // avoid read overflow - readSize := n.checkSize(size) + readSize := n.checkCount(count) if readSize <= 0 { - return nil, errors.New("nothing to read") + return nil, errors.New("Numpy adapter: end of int8 file, nothing to read") } + // read data data := make([]int8, readSize) err := binary.Read(n.reader, n.order, &data) if err != nil { - return nil, err + return nil, fmt.Errorf("Numpy adapter: failed to read int8 data with count %d, error: %w", readSize, err) } // update read position after successfully read @@ -232,28 +338,27 @@ func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) { return data, nil } -func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) { - if n.npyReader == nil { - return nil, errors.New("reader is not initialized") +func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) { + if count <= 0 { + return nil, errors.New("Numpy adapter: cannot read int16 data with a zero or negative count") } // incorrect type - switch n.npyReader.Header.Descr.Type { - case "i2", "i2", "int16": - default: - return nil, errors.New("numpy data is not int16 type") + if n.dataType != schemapb.DataType_Int16 { + return nil, errors.New("Numpy adapter: numpy data is not int16 type") } // avoid read overflow - readSize := n.checkSize(size) + readSize := n.checkCount(count) if readSize <= 0 { - return nil, errors.New("nothing to read") + return nil, errors.New("Numpy adapter: end of int16 file, nothing to read") } + // read data data := make([]int16, readSize) err := binary.Read(n.reader, n.order, &data) if err != nil { - return nil, err + return nil, fmt.Errorf("Numpy adapter: failed to read int16 data with count %d, error: %w", readSize, err) } // update read position after successfully read @@ -262,28 +367,27 @@ func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) { return data, nil } -func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) { - if n.npyReader == nil { - return nil, errors.New("reader is not initialized") +func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) { + if count <= 0 { + return nil, errors.New("Numpy adapter: cannot read int32 data with a zero or negative count") } // incorrect type - switch
n.npyReader.Header.Descr.Type { - case "i4", "i4", "int32": - default: - return nil, errors.New("numpy data is not int32 type") + if n.dataType != schemapb.DataType_Int32 { + return nil, errors.New("Numpy adapter: numpy data is not int32 type") } // avoid read overflow - readSize := n.checkSize(size) + readSize := n.checkCount(count) if readSize <= 0 { - return nil, errors.New("nothing to read") + return nil, errors.New("Numpy adapter: end of int32 file, nothing to read") } + // read data data := make([]int32, readSize) err := binary.Read(n.reader, n.order, &data) if err != nil { - return nil, err + return nil, fmt.Errorf("Numpy adapter: failed to read int32 data with count %d, error: %w", readSize, err) } // update read position after successfully read @@ -292,28 +396,27 @@ func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) { return data, nil } -func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) { - if n.npyReader == nil { - return nil, errors.New("reader is not initialized") +func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) { + if count <= 0 { + return nil, errors.New("Numpy adapter: cannot read int64 data with a zero or negative count") } // incorrect type - switch n.npyReader.Header.Descr.Type { - case "i8", "i8", "int64": - default: - return nil, errors.New("numpy data is not int64 type") + if n.dataType != schemapb.DataType_Int64 { + return nil, errors.New("Numpy adapter: numpy data is not int64 type") } // avoid read overflow - readSize := n.checkSize(size) + readSize := n.checkCount(count) if readSize <= 0 { - return nil, errors.New("nothing to read") + return nil, errors.New("Numpy adapter: end of int64 file, nothing to read") } + // read data data := make([]int64, readSize) err := binary.Read(n.reader, n.order, &data) if err != nil { - return nil, err + return nil, fmt.Errorf("Numpy adapter: failed to read int64 data with count %d, error: %w", readSize, err) } // update read position after successfully read @@ -322,28 +425,27 @@ func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) { return data, nil } -func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) { - if n.npyReader == nil { - return nil, errors.New("reader is not initialized") +func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) { + if count <= 0 { + return nil, errors.New("Numpy adapter: cannot read float32 data with a zero or negative count") } // incorrect type - switch n.npyReader.Header.Descr.Type { - case "f4", "f4", "float32": - default: - return nil, errors.New("numpy data is not float32 type") + if n.dataType != schemapb.DataType_Float { + return nil, errors.New("Numpy adapter: numpy data is not float32 type") } // avoid read overflow - readSize := n.checkSize(size) + readSize := n.checkCount(count) if readSize <= 0 { - return nil, errors.New("nothing to read") + return nil, errors.New("Numpy adapter: end of float32 file, nothing to read") } + // read data data := make([]float32, readSize) err := binary.Read(n.reader, n.order, &data) if err != nil { - return nil, err + return nil, fmt.Errorf("Numpy adapter: failed to read float32 data with count %d, error: %w", readSize, err) } // update read position after successfully read @@ -352,28 +454,109 @@ func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) { return data, nil } -func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) { - if n.npyReader == nil { - return nil, errors.New("reader is not initialized") +func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) { + if count
<= 0 { + return nil, errors.New("Numpy adapter: cannot read float64 data with a zero or negative count") } // incorrect type - switch n.npyReader.Header.Descr.Type { - case "f8", "f8", "float64": - default: - return nil, errors.New("numpy data is not float32 type") + if n.dataType != schemapb.DataType_Double { + return nil, errors.New("Numpy adapter: numpy data is not float64 type") } // avoid read overflow - readSize := n.checkSize(size) + readSize := n.checkCount(count) if readSize <= 0 { - return nil, errors.New("nothing to read") + return nil, errors.New("Numpy adapter: end of float64 file, nothing to read") } + // read data data := make([]float64, readSize) err := binary.Read(n.reader, n.order, &data) if err != nil { - return nil, err + return nil, fmt.Errorf("Numpy adapter: failed to read float64 data with count %d, error: %w", readSize, err) + } + + // update read position after successfully read + n.readPosition += readSize + + return data, nil +} + +func (n *NumpyAdapter) ReadString(count int) ([]string, error) { + if count <= 0 { + return nil, errors.New("Numpy adapter: cannot read varchar data with a zero or negative count") + } + + // incorrect type + if n.dataType != schemapb.DataType_VarChar { + return nil, errors.New("Numpy adapter: numpy data is not varchar type") + } + + // varchar length, this is the max length, some items are shorter than this length, but they still occupy the bytes of the max length + maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type) + if err != nil || maxLen <= 0 { + log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Any("err", err)) + return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err) + } + log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf)) + + // avoid read overflow + readSize := n.checkCount(count) + if readSize <= 0 { + return nil, errors.New("Numpy adapter: end of varchar file, nothing to read") + } + + // read data + data := make([]string, 0) + for i := 0; i < readSize; i++ { + if utf { + // in the numpy file, each utf8 character occupies utf8.UTFMax bytes, each string occupies utf8.UTFMax*maxLen bytes + // for example, an ASCII character "a" only uses one byte, but it still occupies utf8.UTFMax bytes + // a Chinese character uses three bytes, but it also occupies utf8.UTFMax bytes + raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen))) + if err != nil { + log.Error("Numpy adapter: failed to read utf8 string from numpy file", zap.Int("i", i), zap.Any("err", err)) + return nil, fmt.Errorf("Numpy adapter: failed to read utf8 string from numpy file, error: %w", err) + } + + var str string + for len(raw) > 0 { + r, _ := utf8.DecodeRune(raw) + if r == utf8.RuneError { + log.Error("Numpy adapter: failed to decode utf8 string from numpy file", zap.Any("raw", raw[:utf8.UTFMax])) + return nil, fmt.Errorf("Numpy adapter: failed to decode utf8 string from numpy file, error: illegal utf-8 encoding") + } + + // only ASCII characters are supported, because the numpy lib encodes the utf8 bytes with its own internal method, + // the encode/decode logic is not clear yet, so return an error + n := n.order.Uint32(raw) + if n > 127 { + log.Error("Numpy adapter: a string contains non-ascii characters, not supported yet", zap.Int32("utf8Code", r)) + return nil, fmt.Errorf("Numpy adapter: a string contains non-ascii characters, not supported yet") + } + + // if a string
is shorter than maxLen, the tail characters will be filled with "\u0000"(in utf spec this is Null) + if r > 0 { + str += string(r) + } + + raw = raw[utf8.UTFMax:] + } + + data = append(data, str) + } else { + buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen))) + if err != nil { + log.Error("Numpy adapter: failed to read string from numpy file", zap.Int("i", i), zap.Any("err", err)) + return nil, fmt.Errorf("Numpy adapter: failed to read string from numpy file, error: %w", err) + } + n := bytes.Index(buf, []byte{0}) + if n > 0 { + buf = buf[:n] + } + data = append(data, string(buf)) + } } // update read position after successfully read diff --git a/internal/util/importutil/numpy_adapter_test.go b/internal/util/importutil/numpy_adapter_test.go index 0a8284eaaa..5b3804711a 100644 --- a/internal/util/importutil/numpy_adapter_test.go +++ b/internal/util/importutil/numpy_adapter_test.go @@ -22,6 +22,7 @@ import ( "os" "testing" + "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/sbinet/npyio/npy" "github.com/stretchr/testify/assert" ) @@ -59,6 +60,55 @@ func Test_CreateNumpyData(t *testing.T) { assert.Nil(t, buf) } +func Test_ConvertNumpyType(t *testing.T) { + checkFunc := func(inputs []string, output schemapb.DataType) { + for i := 0; i < len(inputs); i++ { + dt, err := convertNumpyType(inputs[i]) + assert.Nil(t, err) + assert.Equal(t, output, dt) + } + } + + checkFunc([]string{"b1", "i1", "int8"}, schemapb.DataType_Int8) + checkFunc([]string{"i2", "i2", "int16"}, schemapb.DataType_Int16) + checkFunc([]string{"i4", "i4", "int32"}, schemapb.DataType_Int32) + checkFunc([]string{"i8", "i8", "int64"}, schemapb.DataType_Int64) + checkFunc([]string{"f4", "f4", "float32"}, schemapb.DataType_Float) + checkFunc([]string{"f8", "f8", "float64"}, schemapb.DataType_Double) + + dt, err := convertNumpyType("dummy") + assert.NotNil(t, err) + assert.Equal(t, schemapb.DataType_None, dt) +} + +func Test_StringLen(t *testing.T) { + len, utf, err := stringLen("S1") + assert.Equal(t, 1, len) + assert.False(t, utf) + assert.Nil(t, err) + + len, utf, err = stringLen("2S") + assert.Equal(t, 2, len) + assert.False(t, utf) + assert.Nil(t, err) + + len, utf, err = stringLen("4U") + assert.Equal(t, 4, len) + assert.True(t, utf) + assert.Nil(t, err) + + len, utf, err = stringLen("dummy") + assert.NotNil(t, err) + assert.Equal(t, 0, len) + assert.False(t, utf) +} + func Test_NumpyAdapterSetByteOrder(t *testing.T) { adapter := &NumpyAdapter{ reader: nil, @@ -82,32 +132,32 @@ func Test_NumpyAdapterReadError(t *testing.T) { npyReader: nil, } - // reader is nil - { - _, err := adapter.ReadBool(1) + // reader size is zero + t.Run("test size is zero", func(t *testing.T) { + _, err := adapter.ReadBool(0) assert.NotNil(t, err) - _, err = adapter.ReadUint8(1) + _, err = adapter.ReadUint8(0) assert.NotNil(t, err) - _, err = adapter.ReadInt8(1) + _, err = adapter.ReadInt8(0) assert.NotNil(t, err) - _, err = adapter.ReadInt16(1) + _, err = adapter.ReadInt16(0) assert.NotNil(t, err) - _, err = adapter.ReadInt32(1) + _, err = adapter.ReadInt32(0) assert.NotNil(t, err) - _, err = adapter.ReadInt64(1) + _, err = adapter.ReadInt64(0) assert.NotNil(t, err) - _, err = adapter.ReadFloat32(1) + _, err = adapter.ReadFloat32(0) assert.NotNil(t, err) - _, err = adapter.ReadFloat64(1) + _, err = adapter.ReadFloat64(0) assert.NotNil(t, err) - } + }) adapter = &NumpyAdapter{ reader: &MockReader{}, npyReader: &npy.Reader{}, } - { + t.Run("test read bool", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "bool" 
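// --- Editor's aside (illustrative sketch, not part of this patch) -----------------
// End-to-end use of the new ReadString() path, mirroring the "test read ascii
// characters" case added further below. The file name is arbitrary; values shorter
// than the dtype width come back with their trailing "\u0000" padding stripped, and
// strings containing non-ASCII runes are rejected for now.
func exampleReadString(tmpPath string) ([]string, error) {
	if err := CreateNumpyFile(tmpPath, []string{"a", "bbb", "cc"}); err != nil {
		return nil, err
	}
	f, err := os.Open(tmpPath)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	adapter, err := NewNumpyAdapter(f)
	if err != nil {
		return nil, err
	}
	return adapter.ReadString(3) // expected: ["a", "bbb", "cc"]
}
// --- End of editor's aside ----------------------------------------------------------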
data, err := adapter.ReadBool(1) assert.Nil(t, data) @@ -117,9 +167,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadBool(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read uint8", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "u1" data, err := adapter.ReadUint8(1) assert.Nil(t, data) @@ -129,9 +179,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadUint8(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read int8", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "i1" data, err := adapter.ReadInt8(1) assert.Nil(t, data) @@ -141,9 +191,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadInt8(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read int16", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "i2" data, err := adapter.ReadInt16(1) assert.Nil(t, data) @@ -153,9 +203,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadInt16(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read int32", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "i4" data, err := adapter.ReadInt32(1) assert.Nil(t, data) @@ -165,9 +215,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadInt32(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read int64", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "i8" data, err := adapter.ReadInt64(1) assert.Nil(t, data) @@ -177,9 +227,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadInt64(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read float", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "f4" data, err := adapter.ReadFloat32(1) assert.Nil(t, data) @@ -189,9 +239,9 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadFloat32(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) - { + t.Run("test read double", func(t *testing.T) { adapter.npyReader.Header.Descr.Type = "f8" data, err := adapter.ReadFloat64(1) assert.Nil(t, data) @@ -201,7 +251,19 @@ func Test_NumpyAdapterReadError(t *testing.T) { data, err = adapter.ReadFloat64(1) assert.Nil(t, data) assert.NotNil(t, err) - } + }) + + t.Run("test read varchar", func(t *testing.T) { + adapter.npyReader.Header.Descr.Type = "U3" + data, err := adapter.ReadString(1) + assert.Nil(t, data) + assert.NotNil(t, err) + + adapter.npyReader.Header.Descr.Type = "dummy" + data, err = adapter.ReadString(1) + assert.Nil(t, data) + assert.NotNil(t, err) + }) } func Test_NumpyAdapterRead(t *testing.T) { @@ -209,7 +271,7 @@ func Test_NumpyAdapterRead(t *testing.T) { assert.Nil(t, err) defer os.RemoveAll(TempFilesPath) - { + t.Run("test read bool", func(t *testing.T) { filePath := TempFilesPath + "bool.npy" data := []bool{true, false, true, false} err := CreateNumpyFile(filePath, data) @@ -267,9 +329,9 @@ func Test_NumpyAdapterRead(t *testing.T) { resf8, err := adapter.ReadFloat64(len(data)) assert.NotNil(t, err) assert.Nil(t, resf8) - } + }) - { + t.Run("test read uint8", func(t *testing.T) { filePath := TempFilesPath + "uint8.npy" data := []uint8{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -303,9 +365,9 @@ func Test_NumpyAdapterRead(t *testing.T) { resb, err := adapter.ReadBool(len(data)) assert.NotNil(t, err) assert.Nil(t, resb) - } + }) - { + t.Run("test read int8", func(t *testing.T) { filePath := TempFilesPath + "int8.npy" data := []int8{1, 
2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -334,9 +396,9 @@ func Test_NumpyAdapterRead(t *testing.T) { res, err = adapter.ReadInt8(len(data)) assert.NotNil(t, err) assert.Nil(t, res) - } + }) - { + t.Run("test read int16", func(t *testing.T) { filePath := TempFilesPath + "int16.npy" data := []int16{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -365,9 +427,9 @@ func Test_NumpyAdapterRead(t *testing.T) { res, err = adapter.ReadInt16(len(data)) assert.NotNil(t, err) assert.Nil(t, res) - } + }) - { + t.Run("test read int32", func(t *testing.T) { filePath := TempFilesPath + "int32.npy" data := []int32{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -396,9 +458,9 @@ func Test_NumpyAdapterRead(t *testing.T) { res, err = adapter.ReadInt32(len(data)) assert.NotNil(t, err) assert.Nil(t, res) - } + }) - { + t.Run("test read int64", func(t *testing.T) { filePath := TempFilesPath + "int64.npy" data := []int64{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -427,9 +489,9 @@ func Test_NumpyAdapterRead(t *testing.T) { res, err = adapter.ReadInt64(len(data)) assert.NotNil(t, err) assert.Nil(t, res) - } + }) - { + t.Run("test read float", func(t *testing.T) { filePath := TempFilesPath + "float.npy" data := []float32{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -458,9 +520,9 @@ func Test_NumpyAdapterRead(t *testing.T) { res, err = adapter.ReadFloat32(len(data)) assert.NotNil(t, err) assert.Nil(t, res) - } + }) - { + t.Run("test read double", func(t *testing.T) { filePath := TempFilesPath + "double.npy" data := []float64{1, 2, 3, 4, 5, 6} err := CreateNumpyFile(filePath, data) @@ -489,5 +551,52 @@ func Test_NumpyAdapterRead(t *testing.T) { res, err = adapter.ReadFloat64(len(data)) assert.NotNil(t, err) assert.Nil(t, res) - } + }) + + t.Run("test read ascii characters", func(t *testing.T) { + filePath := TempFilesPath + "varchar1.npy" + data := []string{"a", "bbb", "c", "dd", "eeee", "fff"} + err := CreateNumpyFile(filePath, data) + assert.Nil(t, err) + + file, err := os.Open(filePath) + assert.Nil(t, err) + defer file.Close() + + adapter, err := NewNumpyAdapter(file) + assert.Nil(t, err) + res, err := adapter.ReadString(len(data) - 1) + assert.Nil(t, err) + assert.Equal(t, len(data)-1, len(res)) + + for i := 0; i < len(res); i++ { + assert.Equal(t, data[i], res[i]) + } + + res, err = adapter.ReadString(len(data)) + assert.Nil(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, data[len(data)-1], res[0]) + + res, err = adapter.ReadString(len(data)) + assert.NotNil(t, err) + assert.Nil(t, res) + }) + + t.Run("test read non-ascii", func(t *testing.T) { + filePath := TempFilesPath + "varchar2.npy" + data := []string{"a三百", "马克bbb"} + err := CreateNumpyFile(filePath, data) + assert.Nil(t, err) + + file, err := os.Open(filePath) + assert.Nil(t, err) + defer file.Close() + + adapter, err := NewNumpyAdapter(file) + assert.Nil(t, err) + res, err := adapter.ReadString(len(data)) + assert.NotNil(t, err) + assert.Nil(t, res) + }) } diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go index 510e83ffbf..55cd334ad0 100644 --- a/internal/util/importutil/numpy_parser.go +++ b/internal/util/importutil/numpy_parser.go @@ -19,12 +19,13 @@ package importutil import ( "context" "errors" + "fmt" "io" - "strconv" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/storage" + "go.uber.org/zap" ) type ColumnDesc struct { @@ -43,7 +44,7 @@ 
type NumpyParser struct { callFlushFunc func(field storage.FieldData) error // callback function to output column data } -// NewNumpyParser helper function to create a NumpyParser +// NewNumpyParser is a helper function to create a NumpyParser func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema, flushFunc func(field storage.FieldData) error) *NumpyParser { if collectionSchema == nil || flushFunc == nil { @@ -60,38 +61,10 @@ func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSc return parser } -func (p *NumpyParser) logError(msg string) error { - log.Error(msg) - return errors.New(msg) -} - -// data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector) -func convertNumpyType(str string) (schemapb.DataType, error) { - switch str { - case "b1", "i1", "int8": - return schemapb.DataType_Int8, nil - case "i2", "i2", "int16": - return schemapb.DataType_Int16, nil - case "i4", "i4", "int32": - return schemapb.DataType_Int32, nil - case "i8", "i8", "int64": - return schemapb.DataType_Int64, nil - case "f4", "f4", "float32": - return schemapb.DataType_Float, nil - case "f8", "f8", "float64": - return schemapb.DataType_Double, nil - default: - return schemapb.DataType_None, errors.New("unsupported data type " + str) - } -} - func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { if adapter == nil { - return errors.New("numpy adapter is nil") + log.Error("Numpy parser: numpy adapter is nil") + return errors.New("Numpy parser: numpy adapter is nil") } // check existence of the target field @@ -105,27 +78,32 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { } if p.columnDesc.name == "" { - return errors.New("the field " + fieldName + " doesn't exist") + log.Error("Numpy parser: the field is not found in collection schema", zap.String("fieldName", fieldName)) + return fmt.Errorf("Numpy parser: the field name '%s' is not found in collection schema", fieldName) } p.columnDesc.dt = schema.DataType - elementType, err := convertNumpyType(adapter.GetType()) - if err != nil { - return err - } - + elementType := adapter.GetType() shape := adapter.GetShape() + var err error // 1. field data type should be consistent with the numpy data type // 2. vector field dimension should be consistent with the numpy shape if schemapb.DataType_FloatVector == schema.DataType { + // float32/float64 numpy file can be used for float vector file, 2 reasons: + // 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit + // 2.
for float64 numpy file, the performance is worse than float32 numpy file if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double { - return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName()) + log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType), + zap.String("fieldName", fieldName)) + return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), schema.GetName()) } // vector field, the shape should be 2 if len(shape) != 2 { - return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName()) + log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)), + zap.String("fieldName", fieldName)) + return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, schema.GetName()) } // shape[0] is row count, shape[1] is element count per row @@ -137,16 +115,23 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { } if shape[1] != p.columnDesc.dimension { - return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension)) + log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName), + zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", p.columnDesc.dimension)) + return fmt.Errorf("Numpy parser: illegal dimension %d of numpy file for float vector field '%s', dimension should be %d", + shape[1], schema.GetName(), p.columnDesc.dimension) } } else if schemapb.DataType_BinaryVector == schema.DataType { if elementType != schemapb.DataType_BinaryVector { - return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName()) + log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType), + zap.String("fieldName", fieldName)) + return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), schema.GetName()) } // vector field, the shape should be 2 if len(shape) != 2 { - return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName()) + log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)), + zap.String("fieldName", fieldName)) + return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, schema.GetName()) } // shape[0] is row count, shape[1] is element count per row @@ -158,16 +143,24 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { } if shape[1] != p.columnDesc.dimension/8 { - return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension)) + log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName), + zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", p.columnDesc.dimension)) + return fmt.Errorf("Numpy parser: illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d", + shape[1]*8, schema.GetName(), p.columnDesc.dimension) } } else { if 
elementType != schema.DataType { - return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName()) + log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType), + zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType)) + return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %d", + getTypeName(elementType), schema.GetName(), schema.DataType) } // scalar field, the shape should be 1 if len(shape) != 1 { - return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName()) + log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)), + zap.String("fieldName", fieldName)) + return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, schema.GetName()) } p.columnDesc.elementCount = shape[0] @@ -176,13 +169,14 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error { return nil } -// this method read numpy data section into a storage.FieldData +// consume method reads numpy data section into a storage.FieldData // please note it will require a large memory block(the memory size is almost equal to numpy file size) func (p *NumpyParser) consume(adapter *NumpyAdapter) error { switch p.columnDesc.dt { case schemapb.DataType_Bool: data, err := adapter.ReadBool(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -194,6 +188,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { case schemapb.DataType_Int8: data, err := adapter.ReadInt8(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -204,6 +199,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { case schemapb.DataType_Int16: data, err := adapter.ReadInt16(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -214,6 +210,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { case schemapb.DataType_Int32: data, err := adapter.ReadInt32(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -224,6 +221,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { case schemapb.DataType_Int64: data, err := adapter.ReadInt64(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -234,6 +232,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { case schemapb.DataType_Float: data, err := adapter.ReadFloat32(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -244,6 +243,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { case schemapb.DataType_Double: data, err := adapter.ReadFloat64(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -251,9 +251,21 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { NumRows: []int64{int64(p.columnDesc.elementCount)}, Data: data, } + case schemapb.DataType_VarChar: + data, err := adapter.ReadString(p.columnDesc.elementCount) + if err != nil { + log.Error(err.Error()) + return err + } + + p.columnData = &storage.StringFieldData{ + NumRows: []int64{int64(p.columnDesc.elementCount)}, + Data: data, + } case schemapb.DataType_BinaryVector: data, err := adapter.ReadUint8(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -263,24 +275,24 @@ func (p *NumpyParser) consume(adapter 
*NumpyAdapter) error { Dim: p.columnDesc.dimension, } case schemapb.DataType_FloatVector: - // for float vector, we support float32 and float64 numpy file because python float value is 64 bit - // for float64 numpy file, the performance is worse than float32 numpy file - // we don't check overflow here - elementType, err := convertNumpyType(adapter.GetType()) - if err != nil { - return err - } + // float32/float64 numpy file can be used for float vector file, 2 reasons: + // 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit + // 2. for float64 numpy file, the performance is worse than float32 numpy file + elementType := adapter.GetType() var data []float32 + var err error if elementType == schemapb.DataType_Float { data, err = adapter.ReadFloat32(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } } else if elementType == schemapb.DataType_Double { data = make([]float32, 0, p.columnDesc.elementCount) data64, err := adapter.ReadFloat64(p.columnDesc.elementCount) if err != nil { + log.Error(err.Error()) return err } @@ -295,7 +307,8 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { Dim: p.columnDesc.dimension, } default: - return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt))) + log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", p.columnDesc.dt), zap.String("fieldName", p.columnDesc.name)) + return fmt.Errorf("Numpy parser: unsupported data type %s of field '%s'", getTypeName(p.columnDesc.dt), p.columnDesc.name) } return nil @@ -304,13 +317,13 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error { func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error { adapter, err := NewNumpyAdapter(reader) if err != nil { - return p.logError("Numpy parse: " + err.Error()) + return err } // the validation method only check the file header information err = p.validate(adapter, fieldName) if err != nil { - return p.logError("Numpy parse: " + err.Error()) + return err } if onlyValidate { @@ -320,7 +333,7 @@ func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate boo // read all data from the numpy file err = p.consume(adapter) if err != nil { - return p.logError("Numpy parse: " + err.Error()) + return err } return p.callFlushFunc(p.columnData) diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go index 45c29b55f4..9e7d3e75a2 100644 --- a/internal/util/importutil/numpy_parser_test.go +++ b/internal/util/importutil/numpy_parser_test.go @@ -36,28 +36,6 @@ func Test_NewNumpyParser(t *testing.T) { assert.Nil(t, parser) } -func Test_ConvertNumpyType(t *testing.T) { - checkFunc := func(inputs []string, output schemapb.DataType) { - for i := 0; i < len(inputs); i++ { - dt, err := convertNumpyType(inputs[i]) - assert.Nil(t, err) - assert.Equal(t, output, dt) - } - } - - checkFunc([]string{"b1", "i1", "int8"}, schemapb.DataType_Int8) - checkFunc([]string{"i2", "i2", "int16"}, schemapb.DataType_Int16) - checkFunc([]string{"i4", "i4", "int32"}, schemapb.DataType_Int32) - checkFunc([]string{"i8", "i8", "int64"}, schemapb.DataType_Int64) - checkFunc([]string{"f4", "f4", "float32"}, schemapb.DataType_Float) - checkFunc([]string{"f8", "f8", "float64"}, schemapb.DataType_Double) - - dt, err := convertNumpyType("dummy") - assert.NotNil(t, err) - assert.Equal(t, schemapb.DataType_None, dt) -} - func Test_NumpyParserValidate(t *testing.T) { ctx := context.Background() 
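// --- Editor's aside (illustrative sketch, not part of this patch) -----------------
// What validate() enforces, collected in one place: scalar fields need a rank-1 numpy
// array whose dtype matches the field type; float vector fields accept float32 or
// float64 arrays of shape [rows][dim]; binary vector fields expect shape [rows][dim/8].
// A sketch mirroring the validate tests around here; the file name and the no-op flush
// callback are illustrative, and schema is assumed to contain a "field_float_vector"
// FloatVector field with dim=4 (as sampleSchema() does).
func exampleValidateFloatVector(ctx context.Context, schema *schemapb.CollectionSchema) error {
	path := TempFilesPath + "example_float_vector.npy"
	rows := [][4]float32{{1, 2, 3, 4}, {5, 6, 7, 8}} // inner width must equal the field's "dim" type param
	if err := CreateNumpyFile(path, rows); err != nil {
		return err
	}
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	adapter, err := NewNumpyAdapter(f)
	if err != nil {
		return err
	}
	noop := func(field storage.FieldData) error { return nil }
	parser := NewNumpyParser(ctx, schema, noop)
	return parser.validate(adapter, "field_float_vector") // nil when dtype and shape match the schema
}
// --- End of editor's aside ----------------------------------------------------------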
err := os.MkdirAll(TempFilesPath, os.ModePerm) @@ -71,7 +49,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter := &NumpyAdapter{npyReader: &npy.Reader{}} - { + t.Run("not support DataType_String", func(t *testing.T) { // string type is not supported p := NewNumpyParser(ctx, &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ @@ -88,15 +66,14 @@ func Test_NumpyParserValidate(t *testing.T) { assert.NotNil(t, err) err = p.validate(adapter, "field_string") assert.NotNil(t, err) - } + }) // reader is nil parser := NewNumpyParser(ctx, schema, flushFunc) err = parser.validate(nil, "") assert.NotNil(t, err) - // validate scalar data - func() { + t.Run("validate scalar", func(t *testing.T) { filePath := TempFilesPath + "scalar_1.npy" data1 := []float64{0, 1, 2, 3, 4, 5} err := CreateNumpyFile(filePath, data1) @@ -108,6 +85,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err := NewNumpyAdapter(file1) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_double") assert.Nil(t, err) @@ -128,6 +106,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file2) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_double") assert.NotNil(t, err) @@ -144,13 +123,13 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file3) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_double") assert.NotNil(t, err) - }() + }) - // validate binary vector data - func() { + t.Run("validate binary vector", func(t *testing.T) { filePath := TempFilesPath + "binary_vector_1.npy" data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}} err := CreateNumpyFile(filePath, data1) @@ -162,6 +141,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err := NewNumpyAdapter(file1) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_binary_vector") assert.Nil(t, err) @@ -178,10 +158,8 @@ func Test_NumpyParserValidate(t *testing.T) { defer file2.Close() adapter, err = NewNumpyAdapter(file2) - assert.Nil(t, err) - - err = parser.validate(adapter, "field_binary_vector") assert.NotNil(t, err) + assert.Nil(t, adapter) // shape mismatch filePath = TempFilesPath + "binary_vector_3.npy" @@ -195,6 +173,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file3) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_binary_vector") assert.NotNil(t, err) @@ -211,6 +190,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file4) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_binary_vector") assert.NotNil(t, err) @@ -228,10 +208,9 @@ func Test_NumpyParserValidate(t *testing.T) { err = p.validate(adapter, "field_binary_vector") assert.NotNil(t, err) - }() + }) - // validate float vector data - func() { + t.Run("validate float vector", func(t *testing.T) { filePath := TempFilesPath + "float_vector.npy" data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}} err := CreateNumpyFile(filePath, data1) @@ -243,6 +222,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err := NewNumpyAdapter(file1) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_float_vector") assert.Nil(t, err) @@ -260,6 +240,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file2) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, 
"field_float_vector") assert.NotNil(t, err) @@ -276,6 +257,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file3) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_float_vector") assert.NotNil(t, err) @@ -292,6 +274,7 @@ func Test_NumpyParserValidate(t *testing.T) { adapter, err = NewNumpyAdapter(file4) assert.Nil(t, err) + assert.NotNil(t, adapter) err = parser.validate(adapter, "field_float_vector") assert.NotNil(t, err) @@ -309,7 +292,7 @@ func Test_NumpyParserValidate(t *testing.T) { err = p.validate(adapter, "field_float_vector") assert.NotNil(t, err) - }() + }) } func Test_NumpyParserParse(t *testing.T) { @@ -355,153 +338,179 @@ func Test_NumpyParserParse(t *testing.T) { }() } - // scalar bool - data1 := []bool{true, false, true, false, true} - flushFunc := func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data1), field.RowNum()) + t.Run("parse scalar bool", func(t *testing.T) { + data := []bool{true, false, true, false, true} + flushFunc := func(field storage.FieldData) error { + assert.NotNil(t, field) + assert.Equal(t, len(data), field.RowNum()) - for i := 0; i < len(data1); i++ { - assert.Equal(t, data1[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data1, "field_bool", flushFunc) - - // scalar int8 - data2 := []int8{1, 2, 3, 4, 5} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data2), field.RowNum()) - - for i := 0; i < len(data2); i++ { - assert.Equal(t, data2[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data2, "field_int8", flushFunc) - - // scalar int16 - data3 := []int16{1, 2, 3, 4, 5} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data3), field.RowNum()) - - for i := 0; i < len(data3); i++ { - assert.Equal(t, data3[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data3, "field_int16", flushFunc) - - // scalar int32 - data4 := []int32{1, 2, 3, 4, 5} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data4), field.RowNum()) - - for i := 0; i < len(data4); i++ { - assert.Equal(t, data4[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data4, "field_int32", flushFunc) - - // scalar int64 - data5 := []int64{1, 2, 3, 4, 5} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data5), field.RowNum()) - - for i := 0; i < len(data5); i++ { - assert.Equal(t, data5[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data5, "field_int64", flushFunc) - - // scalar float - data6 := []float32{1, 2, 3, 4, 5} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data6), field.RowNum()) - - for i := 0; i < len(data6); i++ { - assert.Equal(t, data6[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data6, "field_float", flushFunc) - - // scalar double - data7 := []float64{1, 2, 3, 4, 5} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data7), field.RowNum()) - - for i := 0; i < len(data7); i++ { - assert.Equal(t, data7[i], field.GetRow(i)) - } - - return nil - } - checkFunc(data7, "field_double", flushFunc) - - // binary vector - data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}} - flushFunc = func(field storage.FieldData) error { - assert.NotNil(t, field) - assert.Equal(t, len(data8), field.RowNum()) - - for i := 0; i < len(data8); i++ { - row := 
-           for k := 0; k < len(row); k++ {
-               assert.Equal(t, data8[i][k], row[k])
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
            }
+
+           return nil
        }
+       checkFunc(data, "field_bool", flushFunc)
+   })
-       return nil
-   }
-   checkFunc(data8, "field_binary_vector", flushFunc)
+   t.Run("parse scalar int8", func(t *testing.T) {
+       data := []int8{1, 2, 3, 4, 5}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
-   // double vector(element can be float32 or float64)
-   data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
-   flushFunc = func(field storage.FieldData) error {
-       assert.NotNil(t, field)
-       assert.Equal(t, len(data9), field.RowNum())
-
-       for i := 0; i < len(data9); i++ {
-           row := field.GetRow(i).([]float32)
-           for k := 0; k < len(row); k++ {
-               assert.Equal(t, data9[i][k], row[k])
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
            }
+
+           return nil
        }
+       checkFunc(data, "field_int8", flushFunc)
+   })
-       return nil
-   }
-   checkFunc(data9, "field_float_vector", flushFunc)
+   t.Run("parse scalar int16", func(t *testing.T) {
+       data := []int16{1, 2, 3, 4, 5}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
-   data10 := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
-   flushFunc = func(field storage.FieldData) error {
-       assert.NotNil(t, field)
-       assert.Equal(t, len(data10), field.RowNum())
-
-       for i := 0; i < len(data10); i++ {
-           row := field.GetRow(i).([]float32)
-           for k := 0; k < len(row); k++ {
-               assert.Equal(t, float32(data10[i][k]), row[k])
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
            }
-       }
-       return nil
-   }
-   checkFunc(data10, "field_float_vector", flushFunc)
+           return nil
+       }
+       checkFunc(data, "field_int16", flushFunc)
+   })
+
+   t.Run("parse scalar int32", func(t *testing.T) {
+       data := []int32{1, 2, 3, 4, 5}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_int32", flushFunc)
+   })
+
+   t.Run("parse scalar int64", func(t *testing.T) {
+       data := []int64{1, 2, 3, 4, 5}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_int64", flushFunc)
+   })
+
+   t.Run("parse scalar float", func(t *testing.T) {
+       data := []float32{1, 2, 3, 4, 5}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_float", flushFunc)
+   })
+
+   t.Run("parse scalar double", func(t *testing.T) {
+       data := []float64{1, 2, 3, 4, 5}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_double", flushFunc)
+   })
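// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the scalar sub-tests above all share one
// shape -- write a slice, parse it back, compare row by row. Assuming a Go 1.18+
// toolchain, a generic helper such as the invented assertRowsEqual below could express
// that shared check; the patch itself keeps one explicit flushFunc per type, which
// avoids the extra abstraction.

package example

import (
    "testing"

    "github.com/stretchr/testify/assert"
)

// rowGetter is a minimal stand-in for the storage.FieldData methods these tests use.
type rowGetter interface {
    RowNum() int
    GetRow(i int) interface{}
}

// assertRowsEqual checks that every row produced by the parser matches the source
// slice that was written into the numpy file.
func assertRowsEqual[T comparable](t *testing.T, expected []T, field rowGetter) {
    assert.Equal(t, len(expected), field.RowNum())
    for i := 0; i < len(expected); i++ {
        assert.Equal(t, expected[i], field.GetRow(i))
    }
}
// ------------------------------------------------------- end of editor's sketch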
+
+   t.Run("parse scalar varchar", func(t *testing.T) {
+       data := []string{"abcd", "sdb", "ok", "milvus"}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               assert.Equal(t, data[i], field.GetRow(i))
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_string", flushFunc)
+   })
+
+   t.Run("parse binary vector", func(t *testing.T) {
+       data := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               row := field.GetRow(i).([]uint8)
+               for k := 0; k < len(row); k++ {
+                   assert.Equal(t, data[i][k], row[k])
+               }
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_binary_vector", flushFunc)
+   })
+
+   t.Run("parse float vector with float32", func(t *testing.T) {
+       data := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               row := field.GetRow(i).([]float32)
+               for k := 0; k < len(row); k++ {
+                   assert.Equal(t, data[i][k], row[k])
+               }
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_float_vector", flushFunc)
+   })
+
+   t.Run("parse float vector with float64", func(t *testing.T) {
+       data := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
+       flushFunc := func(field storage.FieldData) error {
+           assert.NotNil(t, field)
+           assert.Equal(t, len(data), field.RowNum())
+
+           for i := 0; i < len(data); i++ {
+               row := field.GetRow(i).([]float32)
+               for k := 0; k < len(row); k++ {
+                   assert.Equal(t, float32(data[i][k]), row[k])
+               }
+           }
+
+           return nil
+       }
+       checkFunc(data, "field_float_vector", flushFunc)
+   })
 }

 func Test_NumpyParserParse_perf(t *testing.T) {
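// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the CreateNumpyFile helper used throughout
// these tests is assumed to serialize a Go slice into NumPy .npy format so the parser
// has a real file to read. A plausible shape for such a helper, using
// github.com/sbinet/npyio (the same library family whose npy.Reader type appears in
// these tests); writeNumpyFile is an invented name and this is only an assumption
// about what the real helper does.

package main

import (
    "os"

    "github.com/sbinet/npyio"
)

// writeNumpyFile persists data (for example []float64 or []int32) as a .npy file at path.
func writeNumpyFile(path string, data interface{}) error {
    f, err := os.Create(path)
    if err != nil {
        return err
    }
    defer f.Close()
    return npyio.Write(f, data)
}

func main() {
    _ = writeNumpyFile("/tmp/scalar_double.npy", []float64{0, 1, 2, 3, 4, 5})
}
// ------------------------------------------------------- end of editor's sketch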