Refine bulkload (#19671)

Signed-off-by: yhmo <yihua.mo@zilliz.com>

Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/20137/head
groot 2022-10-27 16:21:34 +08:00 committed by GitHub
parent def5972e01
commit bee66631e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 3177 additions and 2701 deletions

View File

@ -71,7 +71,7 @@ func (h *ServerHandler) GetDataVChanPositions(channel *channel, partitionID Uniq
continue
}
if s.GetIsImporting() {
// Skip bulk load segments.
// Skip bulk insert segments.
continue
}
@ -149,7 +149,7 @@ func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID Uni
continue
}
if s.GetIsImporting() {
// Skip bulk load segments.
// Skip bulk insert segments.
continue
}
segmentInfos[s.GetID()] = s

View File

@ -68,7 +68,7 @@ func putAllocation(a *Allocation) {
type Manager interface {
// AllocSegment allocates rows and record the allocation.
AllocSegment(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error)
// allocSegmentForImport allocates one segment allocation for bulk load.
// allocSegmentForImport allocates one segment allocation for bulk insert.
// TODO: Remove this method and AllocSegment() above instead.
allocSegmentForImport(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64, taskID int64) (*Allocation, error)
// DropSegment drops the segment from manager.
@ -278,7 +278,7 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID
return allocations, nil
}
// allocSegmentForImport allocates one segment allocation for bulk load.
// allocSegmentForImport allocates one segment allocation for bulk insert.
func (s *SegmentManager) allocSegmentForImport(ctx context.Context, collectionID UniqueID,
partitionID UniqueID, channelName string, requestRows int64, importTaskID int64) (*Allocation, error) {
// init allocation

View File

@ -637,7 +637,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
if segment.State != commonpb.SegmentState_Flushed && segment.State != commonpb.SegmentState_Flushing && segment.State != commonpb.SegmentState_Dropped {
continue
}
// Also skip bulk load segments.
// Also skip bulk insert segments.
if segment.GetIsImporting() {
continue
}

View File

@ -1074,16 +1074,19 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
// parse files and generate segments
segmentSize := int64(Params.DataCoordCfg.SegmentMaxSize) * 1024 * 1024
importWrapper := importutil.NewImportWrapper(newCtx, colInfo.GetSchema(), colInfo.GetShardsNum(), segmentSize, node.rowIDAllocator, node.chunkManager,
importFlushReqFunc(node, req, importResult, colInfo.GetSchema(), ts), importResult, reportFunc)
importWrapper := importutil.NewImportWrapper(newCtx, colInfo.GetSchema(), colInfo.GetShardsNum(), segmentSize, node.rowIDAllocator,
node.chunkManager, importResult, reportFunc)
importWrapper.SetCallbackFunctions(assignSegmentFunc(node, req),
createBinLogsFunc(node, req, colInfo.GetSchema(), ts),
saveSegmentFunc(node, req, importResult, ts))
// todo: pass tsStart and tsStart after import_wrapper support
tsStart, tsEnd, err := importutil.ParseTSFromOptions(req.GetImportTask().GetInfos())
if err != nil {
return returnFailFunc(err)
}
log.Info("import time range", zap.Uint64("start_ts", tsStart), zap.Uint64("end_ts", tsEnd))
err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false)
//err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false, tsStart, tsEnd)
err = importWrapper.Import(req.GetImportTask().GetFiles(),
importutil.ImportOptions{OnlyValidate: false, TsStartPoint: tsStart, TsEndPoint: tsEnd})
if err != nil {
return returnFailFunc(err)
}
@ -1183,8 +1186,8 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor
}, nil
}
func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, schema *schemapb.CollectionSchema, ts Timestamp) importutil.ImportFlushFunc {
return func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil.AssignSegmentFunc {
return func(shardID int) (int64, string, error) {
chNames := req.GetImportTask().GetChannelNames()
importTaskID := req.GetImportTask().GetTaskId()
if shardID >= len(chNames) {
@ -1194,53 +1197,91 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
zap.Int("# of channels", len(chNames)),
zap.Strings("channel names", chNames),
)
return fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID)
return 0, "", fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID)
}
tr := timerecord.NewTimeRecorder("import callback function")
defer tr.Elapse("finished")
colID := req.GetImportTask().GetCollectionId()
partID := req.GetImportTask().GetPartitionId()
segmentIDReq := composeAssignSegmentIDRequest(1, shardID, chNames, colID, partID)
targetChName := segmentIDReq.GetSegmentIDRequests()[0].GetChannelName()
log.Info("target channel for the import task",
zap.Int64("task ID", importTaskID),
zap.Int("shard ID", shardID),
zap.String("target channel name", targetChName))
resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq)
if err != nil {
return 0, "", fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return 0, "", fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason)
}
segmentID := resp.SegIDAssignments[0].SegID
log.Info("new segment assigned",
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID),
zap.Int("shard ID", shardID),
zap.String("target channel name", targetChName))
return segmentID, targetChName, nil
}
}
func createBinLogsFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *schemapb.CollectionSchema, ts Timestamp) importutil.CreateBinlogsFunc {
return func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) {
var rowNum int
for _, field := range fields {
rowNum = field.RowNum()
break
}
chNames := req.GetImportTask().GetChannelNames()
importTaskID := req.GetImportTask().GetTaskId()
if rowNum <= 0 {
log.Info("fields data is empty, no need to generate segment",
log.Info("fields data is empty, no need to generate binlog",
zap.Int64("task ID", importTaskID),
zap.Int("shard ID", shardID),
zap.Int("# of channels", len(chNames)),
zap.Strings("channel names", chNames),
)
return nil
return nil, nil, nil
}
colID := req.GetImportTask().GetCollectionId()
partID := req.GetImportTask().GetPartitionId()
segmentIDReq := composeAssignSegmentIDRequest(rowNum, shardID, chNames, colID, partID)
targetChName := segmentIDReq.GetSegmentIDRequests()[0].GetChannelName()
log.Info("target channel for the import task",
zap.Int64("task ID", importTaskID),
zap.String("target channel name", targetChName))
resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq)
if err != nil {
return fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason)
}
segmentID := resp.SegIDAssignments[0].SegID
fieldInsert, fieldStats, err := createBinLogs(rowNum, schema, ts, fields, node, segmentID, colID, partID)
if err != nil {
return err
log.Error("failed to create binlogs",
zap.Int64("task ID", importTaskID),
zap.Int("# of channels", len(chNames)),
zap.Strings("channel names", chNames),
zap.Any("err", err),
)
return nil, nil, err
}
log.Info("new binlog created",
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID),
zap.Int("insert log count", len(fieldInsert)),
zap.Int("stats log count", len(fieldStats)))
return fieldInsert, fieldStats, err
}
}
func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, ts Timestamp) importutil.SaveSegmentFunc {
importTaskID := req.GetImportTask().GetTaskId()
return func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error {
log.Info("adding segment to the correct DataNode flow graph and saving binlog paths",
zap.Int64("segment ID", segmentID),
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID),
zap.String("targetChName", targetChName),
zap.Int64("rowCount", rowCount),
zap.Uint64("ts", ts))
err = retry.Do(context.Background(), func() error {
err := retry.Do(context.Background(), func() error {
// Ask DataCoord to save binlog path and add segment to the corresponding DataNode flow graph.
resp, err := node.dataCoord.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{
Base: commonpbutil.NewMsgBase(
@ -1251,7 +1292,7 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
ChannelName: targetChName,
CollectionId: req.GetImportTask().GetCollectionId(),
PartitionId: req.GetImportTask().GetPartitionId(),
RowNum: int64(rowNum),
RowNum: rowCount,
SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(0),
@ -1261,8 +1302,8 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
),
SegmentID: segmentID,
CollectionID: req.GetImportTask().GetCollectionId(),
Field2BinlogPaths: fieldInsert,
Field2StatslogPaths: fieldStats,
Field2BinlogPaths: fieldsInsert,
Field2StatslogPaths: fieldsStats,
// Set start positions of a SaveBinlogPathRequest explicitly.
StartPositions: []*datapb.SegmentStartPosition{
{
@ -1291,9 +1332,11 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
log.Warn("failed to save import segment", zap.Error(err))
return err
}
log.Info("segment imported and persisted", zap.Int64("segmentID", segmentID))
log.Info("segment imported and persisted",
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID))
res.Segments = append(res.Segments, segmentID)
res.RowCount += int64(rowNum)
res.RowCount += rowCount
return nil
}
}

View File

@ -38,7 +38,6 @@ import (
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
@ -603,30 +602,6 @@ func TestDataNode(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.GetErrorCode())
})
t.Run("Test Import callback func error", func(t *testing.T) {
req := &datapb.ImportTaskRequest{
ImportTask: &datapb.ImportTask{
CollectionId: 100,
PartitionId: 100,
ChannelNames: []string{"ch1", "ch2"},
},
}
importResult := &rootcoordpb.ImportResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
TaskId: 0,
DatanodeId: 0,
State: commonpb.ImportState_ImportStarted,
Segments: make([]int64, 0),
AutoIds: make([]int64, 0),
RowCount: 0,
}
callback := importFlushReqFunc(node, req, importResult, nil, 0)
err := callback(nil, len(req.ImportTask.ChannelNames)+1)
assert.Error(t, err)
})
t.Run("Test BackGroundGC", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)

View File

@ -114,7 +114,7 @@ message SegmentIDRequest {
string channel_name = 2;
int64 collectionID = 3;
int64 partitionID = 4;
bool isImport = 5; // Indicate whether this request comes from a bulk load task.
bool isImport = 5; // Indicate whether this request comes from a bulk insert task.
int64 importTaskID = 6; // Needed for segment lock.
}
@ -269,8 +269,8 @@ message SegmentInfo {
repeated int64 compactionFrom = 15;
uint64 dropped_at = 16; // timestamp when segment marked drop
// A flag indicating if:
// (1) this segment is created by bulk load, and
// (2) the bulk load task that creates this segment has not yet reached `ImportCompleted` state.
// (1) this segment is created by bulk insert, and
// (2) the bulk insert task that creates this segment has not yet reached `ImportCompleted` state.
bool is_importing = 17;
bool is_fake = 18;
}

View File

@ -1598,8 +1598,8 @@ type SegmentInfo struct {
CompactionFrom []int64 `protobuf:"varint,15,rep,packed,name=compactionFrom,proto3" json:"compactionFrom,omitempty"`
DroppedAt uint64 `protobuf:"varint,16,opt,name=dropped_at,json=droppedAt,proto3" json:"dropped_at,omitempty"`
// A flag indicating if:
// (1) this segment is created by bulk load, and
// (2) the bulk load task that creates this segment has not yet reached `ImportCompleted` state.
// (1) this segment is created by bulk insert, and
// (2) the bulk insert task that creates this segment has not yet reached `ImportCompleted` state.
IsImporting bool `protobuf:"varint,17,opt,name=is_importing,json=isImporting,proto3" json:"is_importing,omitempty"`
IsFake bool `protobuf:"varint,18,opt,name=is_fake,json=isFake,proto3" json:"is_fake,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`

View File

@ -3967,7 +3967,8 @@ func unhealthyStatus() *commonpb.Status {
func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
log.Info("received import request",
zap.String("collection name", req.GetCollectionName()),
zap.Bool("row-based", req.GetRowBased()))
zap.String("partition name", req.GetPartitionName()),
zap.Strings("files", req.GetFiles()))
resp := &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -3996,7 +3997,7 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi
respFromRC, err := node.rootCoord.Import(ctx, req)
if err != nil {
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.FailLabel).Inc()
log.Error("failed to execute bulk load request", zap.Error(err))
log.Error("failed to execute bulk insert request", zap.Error(err))
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
resp.Status.Reason = err.Error()
return resp, nil

View File

@ -1698,7 +1698,7 @@ func TestProxy(t *testing.T) {
defer wg.Done()
req := &milvuspb.ImportRequest{
CollectionName: collectionName,
Files: []string{"f1.json", "f2.json", "f3.csv"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
resp, err := proxy.Import(context.TODO(), req)

View File

@ -34,7 +34,9 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/importutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
@ -210,7 +212,6 @@ func (m *importManager) sendOutTasks(ctx context.Context) error {
CollectionId: task.GetCollectionId(),
PartitionId: task.GetPartitionId(),
ChannelNames: task.GetChannelNames(),
RowBased: task.GetRowBased(),
TaskId: task.GetId(),
Files: task.GetFiles(),
Infos: task.GetInfos(),
@ -381,25 +382,39 @@ func (m *importManager) checkIndexingDone(ctx context.Context, collID UniqueID,
return len(allSegmentIDs) == indexedSegmentCount, nil
}
func (m *importManager) isRowbased(files []string) (bool, error) {
isRowBased := false
for _, filePath := range files {
_, fileType := importutil.GetFileNameAndExt(filePath)
if fileType == importutil.JSONFileExt {
isRowBased = true
} else if isRowBased {
log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files))
return isRowBased, fmt.Errorf("row-based data file type must be JSON, file type '%s' is not allowed", fileType)
}
}
return isRowBased, nil
}
// importJob processes the import request, generates import tasks, sends these tasks to DataCoord, and returns
// immediately.
func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportRequest, cID int64, pID int64) *milvuspb.ImportResponse {
if req == nil || len(req.Files) == 0 {
returnErrorFunc := func(reason string) *milvuspb.ImportResponse {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "import request is empty",
Reason: reason,
},
}
}
if m.callImportService == nil {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "import service is not available",
},
if req == nil || len(req.Files) == 0 {
return returnErrorFunc("import request is empty")
}
if m.callImportService == nil {
return returnErrorFunc("import service is not available")
}
resp := &milvuspb.ImportResponse{
@ -409,7 +424,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
Tasks: make([]int64, 0),
}
log.Debug("request received",
log.Debug("receive import job",
zap.String("collection name", req.GetCollectionName()),
zap.Int64("collection ID", cID),
zap.Int64("partition ID", pID))
@ -420,8 +435,13 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
capacity := cap(m.pendingTasks)
length := len(m.pendingTasks)
isRowBased, err := m.isRowbased(req.GetFiles())
if err != nil {
return err
}
taskCount := 1
if req.RowBased {
if isRowBased {
taskCount = len(req.Files)
}
@ -433,7 +453,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
}
// convert import request to import tasks
if req.RowBased {
if isRowBased {
// For row-based importing, each file makes a task.
taskList := make([]int64, len(req.Files))
for i := 0; i < len(req.Files); i++ {
@ -446,7 +466,6 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
CollectionId: cID,
PartitionId: pID,
ChannelNames: req.ChannelNames,
RowBased: req.GetRowBased(),
Files: []string{req.GetFiles()[i]},
CreateTs: time.Now().Unix(),
State: &datapb.ImportTaskState{
@ -485,7 +504,6 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
CollectionId: cID,
PartitionId: pID,
ChannelNames: req.ChannelNames,
RowBased: req.GetRowBased(),
Files: req.GetFiles(),
CreateTs: time.Now().Unix(),
State: &datapb.ImportTaskState{
@ -514,12 +532,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
return nil
}()
if err != nil {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}
return returnErrorFunc(err.Error())
}
if sendOutTasksErr := m.sendOutTasks(ctx); sendOutTasksErr != nil {
log.Error("fail to send out tasks", zap.Error(sendOutTasksErr))
@ -1054,7 +1067,6 @@ func cloneImportTaskInfo(taskInfo *datapb.ImportTaskInfo) *datapb.ImportTaskInfo
CollectionId: taskInfo.GetCollectionId(),
PartitionId: taskInfo.GetPartitionId(),
ChannelNames: taskInfo.GetChannelNames(),
RowBased: taskInfo.GetRowBased(),
Files: taskInfo.GetFiles(),
CreateTs: taskInfo.GetCreateTs(),
State: taskInfo.GetState(),

View File

@ -574,6 +574,8 @@ func TestImportManager_ImportJob(t *testing.T) {
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
// nil request
mgr := newImportManager(context.TODO(), mockKv, idAlloc, nil, callMarkSegmentsDropped, nil, nil, nil, nil)
resp := mgr.importJob(context.TODO(), nil, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
@ -581,27 +583,14 @@ func TestImportManager_ImportJob(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
// nil callImportService
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
colReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: false,
Files: []string{"f1", "f2"},
Options: []*commonpb.KeyValuePair{
{
Key: importutil.Bucket,
Value: "mybucket",
},
},
}
fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
importServiceFunc := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
return &datapb.ImportTaskResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -609,17 +598,27 @@ func TestImportManager_ImportJob(t *testing.T) {
}, nil
}
mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// row-based case, task count equal to file count
// since the importServiceFunc return error, tasks will be kept in pending list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks))
assert.Equal(t, 0, len(mgr.workingTasks))
mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
colReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
Files: []string{"f1.npy", "f2.npy", "f3.npy"},
}
// column-based case, one quest one task
// since the importServiceFunc return error, tasks will be kept in pending list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), colReq, colID, 0)
assert.Equal(t, 1, len(mgr.pendingTasks))
assert.Equal(t, 0, len(mgr.workingTasks))
fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
return &datapb.ImportTaskResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -627,18 +626,20 @@ func TestImportManager_ImportJob(t *testing.T) {
}, nil
}
mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// row-based case, since the importServiceFunc return success, tasks will be sent to woring list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, 0, len(mgr.pendingTasks))
assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks))
mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// column-based case, since the importServiceFunc return success, tasks will be sent to woring list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), colReq, colID, 0)
assert.Equal(t, 0, len(mgr.pendingTasks))
assert.Equal(t, 1, len(mgr.workingTasks))
count := 0
fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
if count >= 2 {
return &datapb.ImportTaskResponse{
Status: &commonpb.Status{
@ -654,13 +655,16 @@ func TestImportManager_ImportJob(t *testing.T) {
}, nil
}
mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// row-based case, since the importServiceFunc return success for 2 tasks
// the 2 tasks are sent to working list, and 1 task left in pending list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, len(rowReq.Files)-2, len(mgr.pendingTasks))
assert.Equal(t, 2, len(mgr.workingTasks))
for i := 0; i <= 32; i++ {
rowReq.Files = append(rowReq.Files, strconv.Itoa(i))
// files count exceed MaxPendingCount, return error
for i := 0; i <= MaxPendingCount; i++ {
rowReq.Files = append(rowReq.Files, strconv.Itoa(i)+".json")
}
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
@ -683,14 +687,12 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
colReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: false,
Files: []string{"f1", "f2"},
Files: []string{"f1.npy", "f2.npy"},
Options: []*commonpb.KeyValuePair{
{
Key: importutil.Bucket,
@ -778,8 +780,7 @@ func TestImportManager_TaskState(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) {
@ -817,8 +818,7 @@ func TestImportManager_TaskState(t *testing.T) {
assert.Equal(t, int64(2), ti.GetId())
assert.Equal(t, int64(100), ti.GetCollectionId())
assert.Equal(t, int64(0), ti.GetPartitionId())
assert.Equal(t, true, ti.GetRowBased())
assert.Equal(t, []string{"f2"}, ti.GetFiles())
assert.Equal(t, []string{"f2.json"}, ti.GetFiles())
assert.Equal(t, commonpb.ImportState_ImportPersisted, ti.GetState().GetStateCode())
assert.Equal(t, int64(1000), ti.GetState().GetRowCount())
@ -876,8 +876,7 @@ func TestImportManager_AllocFail(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) {
@ -917,8 +916,7 @@ func TestImportManager_ListAllTasks(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) {
return &commonpb.Status{
@ -1007,3 +1005,22 @@ func TestImportManager_rearrangeTasks(t *testing.T) {
assert.Equal(t, int64(50), tasks[1].GetId())
assert.Equal(t, int64(100), tasks[2].GetId())
}
func TestImportManager_isRowbased(t *testing.T) {
mgr := &importManager{}
files := []string{"1.json", "2.json"}
rb, err := mgr.isRowbased(files)
assert.Nil(t, err)
assert.True(t, rb)
files = []string{"1.json", "2.npy"}
rb, err = mgr.isRowbased(files)
assert.NotNil(t, err)
assert.True(t, rb)
files = []string{"1.npy", "2.npy"}
rb, err = mgr.isRowbased(files)
assert.Nil(t, err)
assert.False(t, rb)
}

View File

@ -86,8 +86,8 @@ type IMetaTable interface {
ListAliasesByID(collID UniqueID) []string
// TODO: better to accept ctx.
GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) // serve for bulk load.
GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) // serve for bulk load.
GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) // serve for bulk insert.
GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) // serve for bulk insert.
// TODO: better to accept ctx.
AddCredential(credInfo *internalpb.CredentialInfo) error
@ -634,7 +634,7 @@ func (mt *MetaTable) ListAliasesByID(collID UniqueID) []string {
return mt.listAliasesByID(collID)
}
// GetCollectionNameByID serve for bulk load. TODO: why this didn't accept ts?
// GetCollectionNameByID serve for bulk insert. TODO: why this didn't accept ts?
// [Deprecated]
func (mt *MetaTable) GetCollectionNameByID(collID UniqueID) (string, error) {
mt.ddLock.RLock()
@ -648,7 +648,7 @@ func (mt *MetaTable) GetCollectionNameByID(collID UniqueID) (string, error) {
return coll.Name, nil
}
// GetPartitionNameByID serve for bulk load.
// GetPartitionNameByID serve for bulk insert.
func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()
@ -680,7 +680,7 @@ func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID,
return "", fmt.Errorf("partition not exist: %d", partitionID)
}
// GetCollectionIDByName serve for bulk load. TODO: why this didn't accept ts?
// GetCollectionIDByName serve for bulk insert. TODO: why this didn't accept ts?
// [Deprecated]
func (mt *MetaTable) GetCollectionIDByName(name string) (UniqueID, error) {
mt.ddLock.RLock()
@ -693,7 +693,7 @@ func (mt *MetaTable) GetCollectionIDByName(name string) (UniqueID, error) {
return id, nil
}
// GetPartitionByName serve for bulk load.
// GetPartitionByName serve for bulk insert.
func (mt *MetaTable) GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()

View File

@ -1594,7 +1594,6 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus
zap.Strings("virtual channel names", req.GetChannelNames()),
zap.Int64("partition ID", pID),
zap.Int("# of files = ", len(req.GetFiles())),
zap.Bool("row-based", req.GetRowBased()),
)
importJobResp := c.importManager.importJob(ctx, req, cID, pID)
return importJobResp, nil
@ -1692,7 +1691,7 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) (
resendTaskFunc()
// Flush all import data segments.
if err := c.broker.Flush(ctx, ti.GetCollectionId(), ir.GetSegments()); err != nil {
log.Error("failed to call Flush on bulk load segments",
log.Error("failed to call Flush on bulk insert segments",
zap.Int64("task ID", ir.GetTaskId()))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,

View File

@ -17,6 +17,7 @@
package importutil
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -43,29 +44,39 @@ type SegmentFilesHolder struct {
// 1. read insert log of each field, then constructs map[storage.FieldID]storage.FieldData in memory.
// 2. read delta log to remove deleted entities(TimeStampField is used to apply or skip the operation).
// 3. split data according to shard number
// 4. call the callFlushFunc function to flush data into new segment if data size reaches segmentSize.
// 4. call the callFlushFunc function to flush data into new binlog file if data size reaches blockSize.
type BinlogAdapter struct {
ctx context.Context // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
chunkManager storage.ChunkManager // storage interfaces to read binlog files
callFlushFunc ImportFlushFunc // call back function to flush segment
shardNum int32 // sharding number of the collection
segmentSize int64 // maximum size of a segment(unit:byte)
blockSize int64 // maximum size of a read block(unit:byte)
maxTotalSize int64 // maximum size of in-memory segments(unit:byte)
primaryKey storage.FieldID // id of primary key
primaryType schemapb.DataType // data type of primary key
// a timestamp to define the end point of restore, data after this point will be ignored
// a timestamp to define the start time point of restore, data before this time point will be ignored
// set this value to 0, all the data will be imported
// set this value to math.MaxUint64, all the data will be ignored
// the tsStartPoint value must be less/equal than tsEndPoint
tsStartPoint uint64
// a timestamp to define the end time point of restore, data after this time point will be ignored
// set this value to 0, all the data will be ignored
// set this value to math.MaxUint64, all the data will be imported
// the tsEndPoint value must be larger/equal than tsStartPoint
tsEndPoint uint64
}
func NewBinlogAdapter(collectionSchema *schemapb.CollectionSchema,
func NewBinlogAdapter(ctx context.Context,
collectionSchema *schemapb.CollectionSchema,
shardNum int32,
segmentSize int64,
blockSize int64,
maxTotalSize int64,
chunkManager storage.ChunkManager,
flushFunc ImportFlushFunc,
tsStartPoint uint64,
tsEndPoint uint64) (*BinlogAdapter, error) {
if collectionSchema == nil {
log.Error("Binlog adapter: collection schema is nil")
@ -83,18 +94,20 @@ func NewBinlogAdapter(collectionSchema *schemapb.CollectionSchema,
}
adapter := &BinlogAdapter{
ctx: ctx,
collectionSchema: collectionSchema,
chunkManager: chunkManager,
callFlushFunc: flushFunc,
shardNum: shardNum,
segmentSize: segmentSize,
blockSize: blockSize,
maxTotalSize: maxTotalSize,
tsStartPoint: tsStartPoint,
tsEndPoint: tsEndPoint,
}
// amend the segment size to avoid portential OOM risk
if adapter.segmentSize > MaxSegmentSizeInMemory {
adapter.segmentSize = MaxSegmentSizeInMemory
if adapter.blockSize > MaxSegmentSizeInMemory {
adapter.blockSize = MaxSegmentSizeInMemory
}
// find out the primary key ID and its data type
@ -142,7 +155,7 @@ func (p *BinlogAdapter) Read(segmentHolder *SegmentFilesHolder) error {
// b has these binlog files: b_1, b_2, b_3 ...
// Then first round read a_1 and b_1, second round read a_2 and b_2, etc...
// deleted list will be used to remove deleted entities
// if accumulate data exceed segmentSize, call callFlushFunc to generate new segment
// if accumulate data exceed blockSize, call callFlushFunc to generate new binlog file
batchCount := 0
for _, files := range segmentHolder.fieldFiles {
batchCount = len(files)
@ -215,24 +228,30 @@ func (p *BinlogAdapter) Read(segmentHolder *SegmentFilesHolder) error {
// read other insert logs and use the shardList to do sharding
for fieldID, file := range batchFiles {
// outside context might be canceled(service stop, or future enhancement for canceling import task)
if isCanceled(p.ctx) {
log.Error("Binlog adapter: import task was canceled")
return errors.New("import task was canceled")
}
err = p.readInsertlog(fieldID, file, segmentsData, shardList)
if err != nil {
return err
}
}
// flush segment whose size exceed segmentSize
err = p.tryFlushSegments(segmentsData, false)
// flush segment whose size exceed blockSize
err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.callFlushFunc, p.blockSize, p.maxTotalSize, false)
if err != nil {
return err
}
}
// finally, force to flush
return p.tryFlushSegments(segmentsData, true)
return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.callFlushFunc, p.blockSize, p.maxTotalSize, true)
}
// This method verify the schema and binlog files
// verify method verify the schema and binlog files
// 1. each field must has binlog file
// 2. binlog file count of each field must be equal
// 3. the collectionSchema doesn't contain TimeStampField and RowIDField since the import_wrapper excludes them,
@ -284,7 +303,7 @@ func (p *BinlogAdapter) verify(segmentHolder *SegmentFilesHolder) error {
return nil
}
// This method read data from deltalog, and convert to a dict
// readDeltalogs method reads data from deltalog, and convert to a dict
// The deltalog data is a list, to improve performance of next step, we convert it to a dict,
// key is the deleted ID, value is operation timestamp which is used to apply or skip the delete operation.
func (p *BinlogAdapter) readDeltalogs(segmentHolder *SegmentFilesHolder) (map[int64]uint64, map[string]uint64, error) {
@ -318,7 +337,7 @@ func (p *BinlogAdapter) readDeltalogs(segmentHolder *SegmentFilesHolder) (map[in
}
}
// Decode string array(read from delta log) to storage.DeleteLog array
// decodeDeleteLogs decodes string array(read from delta log) to storage.DeleteLog array
func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*storage.DeleteLog, error) {
// step 1: read all delta logs to construct a string array, each string is marshaled from storage.DeleteLog
stringArray := make([]string, 0)
@ -345,8 +364,9 @@ func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*
return nil, err
}
// ignore deletions whose timestamp is larger than the tsEndPoint
if deleteLog.Ts <= p.tsEndPoint {
// only the ts between tsStartPoint and tsEndPoint is effective
// ignore deletions whose timestamp is larger than the tsEndPoint or less than tsStartPoint
if deleteLog.Ts >= p.tsStartPoint && deleteLog.Ts <= p.tsEndPoint {
deleteLogs = append(deleteLogs, deleteLog)
}
}
@ -365,7 +385,7 @@ func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*
return deleteLogs, nil
}
// Decode a string to storage.DeleteLog
// decodeDeleteLog decodes a string to storage.DeleteLog
// Note: the following code is mainly come from data_codec.go, I suppose the code can compatible with old version 2.0
func (p *BinlogAdapter) decodeDeleteLog(deltaStr string) (*storage.DeleteLog, error) {
deleteLog := &storage.DeleteLog{}
@ -399,7 +419,7 @@ func (p *BinlogAdapter) decodeDeleteLog(deltaStr string) (*storage.DeleteLog, er
return deleteLog, nil
}
// Each delta log data type is varchar, marshaled from an array of storage.DeleteLog objects.
// readDeltalog parses a delta log file. Each delta log data type is varchar, marshaled from an array of storage.DeleteLog objects.
func (p *BinlogAdapter) readDeltalog(logPath string) ([]string, error) {
// open the delta log file
binlogFile, err := NewBinlogFile(p.chunkManager)
@ -426,7 +446,7 @@ func (p *BinlogAdapter) readDeltalog(logPath string) ([]string, error) {
return data, nil
}
// This method read data from int64 field, currently we use it to read the timestamp field.
// readTimestamp method reads data from int64 field, currently we use it to read the timestamp field.
func (p *BinlogAdapter) readTimestamp(logPath string) ([]int64, error) {
// open the log file
binlogFile, err := NewBinlogFile(p.chunkManager)
@ -454,7 +474,7 @@ func (p *BinlogAdapter) readTimestamp(logPath string) ([]int64, error) {
return int64List, nil
}
// This method read primary keys from insert log.
// readPrimaryKeys method reads primary keys from insert log.
func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, error) {
// open the delta log file
binlogFile, err := NewBinlogFile(p.chunkManager)
@ -493,7 +513,7 @@ func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, erro
}
}
// This method generate a shard id list by primary key(int64) list and deleted list.
// getShardingListByPrimaryInt64 method generates a shard id list by primary key(int64) list and deleted list.
// For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be:
// [0, 1, -1, 1, 0, 1, -1, 1, 0, 1]
// Compare timestampList with tsEndPoint to skip some rows.
@ -513,10 +533,10 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64,
excluded := 0
shardList := make([]int32, 0, len(primaryKeys))
for i, key := range primaryKeys {
// if this entity's timestamp is greater than the tsEndPoint, set shardID = -1 to skip this entity
// if this entity's timestamp is greater than the tsEndPoint, or less than tsStartPoint, set shardID = -1 to skip this entity
// timestamp is stored as int64 type in log file, actually it is uint64, compare with uint64
ts := timestampList[i]
if uint64(ts) > p.tsEndPoint {
if uint64(ts) > p.tsEndPoint || uint64(ts) < p.tsStartPoint {
shardList = append(shardList, -1)
excluded++
continue
@ -546,7 +566,7 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64,
return shardList, nil
}
// This method generate a shard id list by primary key(varchar) list and deleted list.
// getShardingListByPrimaryVarchar method generates a shard id list by primary key(varchar) list and deleted list.
// For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be:
// [0, 1, -1, 1, 0, 1, -1, 1, 0, 1]
func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string,
@ -565,10 +585,10 @@ func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string,
excluded := 0
shardList := make([]int32, 0, len(primaryKeys))
for i, key := range primaryKeys {
// if this entity's timestamp is greater than the tsEndPoint, set shardID = -1 to skip this entity
// if this entity's timestamp is greater than the tsEndPoint, or less than tsStartPoint, set shardID = -1 to skip this entity
// timestamp is stored as int64 type in log file, actually it is uint64, compare with uint64
ts := timestampList[i]
if uint64(ts) > p.tsEndPoint {
if uint64(ts) > p.tsEndPoint || uint64(ts) < p.tsStartPoint {
shardList = append(shardList, -1)
excluded++
continue
@ -598,7 +618,7 @@ func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string,
return shardList, nil
}
// This method read an insert log, and split the data into different shards according to a shard list
// readInsertlog method reads an insert log, and split the data into different shards according to a shard list
// The shardList is a list to tell which row belong to which shard, returned by getShardingListByPrimaryXXX()
// For deleted rows, we say its shard id is -1.
// For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be:
@ -1000,96 +1020,3 @@ func (p *BinlogAdapter) dispatchFloatVecToShards(data []float32, dim int, memory
return nil
}
// This method do the two things:
// 1. if accumulate data of a segment exceed segmentSize, call callFlushFunc to generate new segment
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest segment
func (p *BinlogAdapter) tryFlushSegments(segmentsData []map[storage.FieldID]storage.FieldData, force bool) error {
totalSize := 0
biggestSize := 0
biggestItem := -1
// 1. if accumulate data of a segment exceed segmentSize, call callFlushFunc to generate new segment
for i := 0; i < len(segmentsData); i++ {
segmentData := segmentsData[i]
// Note: even rowCount is 0, the size is still non-zero
size := 0
rowCount := 0
for _, fieldData := range segmentData {
size += fieldData.GetMemorySize()
rowCount = fieldData.RowNum()
}
// force to flush, called at the end of Read()
if force && rowCount > 0 {
err := p.callFlushFunc(segmentData, i)
if err != nil {
log.Error("Binlog adapter: failed to force flush segment data", zap.Int("shardID", i))
return err
}
log.Info("Binlog adapter: force flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))
segmentsData[i] = initSegmentData(p.collectionSchema)
if segmentsData[i] == nil {
log.Error("Binlog adapter: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
continue
}
// if segment size is larger than predefined segmentSize, flush to create a new segment
// initialize a new FieldData list for next round batch read
if size > int(p.segmentSize) && rowCount > 0 {
err := p.callFlushFunc(segmentData, i)
if err != nil {
log.Error("Binlog adapter: failed to flush segment data", zap.Int("shardID", i))
return err
}
log.Info("Binlog adapter: segment size exceed limit and flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))
segmentsData[i] = initSegmentData(p.collectionSchema)
if segmentsData[i] == nil {
log.Error("Binlog adapter: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
continue
}
// calculate the total size(ignore the flushed segments)
// find out the biggest segment for the step 2
totalSize += size
if size > biggestSize {
biggestSize = size
biggestItem = i
}
}
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest segment
if totalSize > int(p.maxTotalSize) && biggestItem >= 0 {
segmentData := segmentsData[biggestItem]
size := 0
rowCount := 0
for _, fieldData := range segmentData {
size += fieldData.GetMemorySize()
rowCount = fieldData.RowNum()
}
if rowCount > 0 {
err := p.callFlushFunc(segmentData, biggestItem)
if err != nil {
log.Error("Binlog adapter: failed to flush biggest segment data", zap.Int("shardID", biggestItem))
return err
}
log.Info("Binlog adapter: total size exceed limit and flush", zap.Int("rowCount", rowCount),
zap.Int("size", size), zap.Int("totalSize", totalSize), zap.Int("shardID", biggestItem))
segmentsData[biggestItem] = initSegmentData(p.collectionSchema)
if segmentsData[biggestItem] == nil {
log.Error("Binlog adapter: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
}
}
return nil
}

View File

@ -16,8 +16,10 @@
package importutil
import (
"context"
"encoding/json"
"errors"
"math"
"strconv"
"testing"
@ -154,18 +156,20 @@ func createSegmentsData(fieldsData map[storage.FieldID]interface{}, shardNum int
}
func Test_NewBinlogAdapter(t *testing.T) {
ctx := context.Background()
// nil schema
adapter, err := NewBinlogAdapter(nil, 2, 1024, 2048, nil, nil, 0)
adapter, err := NewBinlogAdapter(ctx, nil, 2, 1024, 2048, nil, nil, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)
// nil chunkmanager
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, nil, nil, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, nil, nil, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)
// nil flushfunc
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, nil, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, nil, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)
@ -173,7 +177,7 @@ func Test_NewBinlogAdapter(t *testing.T) {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 2048, 1024, &MockChunkManager{}, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 2048, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -191,16 +195,18 @@ func Test_NewBinlogAdapter(t *testing.T) {
},
},
}
adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)
}
func Test_BinlogAdapterVerify(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -247,6 +253,8 @@ func Test_BinlogAdapterVerify(t *testing.T) {
}
func Test_BinlogAdapterReadDeltalog(t *testing.T) {
ctx := context.Background()
deleteItems := []int64{1001, 1002, 1003}
buf := createDeltalogBuf(t, deleteItems, false)
chunkManager := &MockChunkManager{
@ -259,7 +267,7 @@ func Test_BinlogAdapterReadDeltalog(t *testing.T) {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -283,6 +291,8 @@ func Test_BinlogAdapterReadDeltalog(t *testing.T) {
}
func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
ctx := context.Background()
deleteItems := []int64{1001, 1002, 1003, 1004, 1005}
buf := createDeltalogBuf(t, deleteItems, false)
chunkManager := &MockChunkManager{
@ -295,7 +305,7 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -316,7 +326,7 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
"dummy": createDeltalogBuf(t, []string{"1001", "1002"}, true),
}
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -327,11 +337,13 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
}
func Test_BinlogAdapterDecodeDeleteLog(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -378,6 +390,8 @@ func Test_BinlogAdapterDecodeDeleteLog(t *testing.T) {
}
func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
ctx := context.Background()
deleteItems := []int64{1001, 1002, 1003, 1004, 1005}
buf := createDeltalogBuf(t, deleteItems, false)
chunkManager := &MockChunkManager{
@ -390,7 +404,7 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -432,17 +446,19 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
"dummy": createDeltalogBuf(t, []string{"1001", "1002"}, true),
}
adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
// 2.1 all deletion have been filtered out
adapter.tsStartPoint = baseTimestamp + 2
intDeletions, strDeletions, err = adapter.readDeltalogs(holder)
assert.Nil(t, err)
assert.Nil(t, intDeletions)
assert.Nil(t, strDeletions)
// 2.2 filter the no.1 and no.2 deletion
adapter.tsStartPoint = 0
adapter.tsEndPoint = baseTimestamp + 1
intDeletions, strDeletions, err = adapter.readDeltalogs(holder)
assert.Nil(t, err)
@ -470,7 +486,7 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
},
}
adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -482,10 +498,12 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
}
func Test_BinlogAdapterReadTimestamp(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -515,10 +533,12 @@ func Test_BinlogAdapterReadTimestamp(t *testing.T) {
}
func Test_BinlogAdapterReadPrimaryKeys(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -570,12 +590,14 @@ func Test_BinlogAdapterReadPrimaryKeys(t *testing.T) {
}
func Test_BinlogAdapterShardListInt64(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
shardNum := int32(2)
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -610,12 +632,14 @@ func Test_BinlogAdapterShardListInt64(t *testing.T) {
}
func Test_BinlogAdapterShardListVarchar(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
shardNum := int32(2)
adapter, err := NewBinlogAdapter(strKeySchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, strKeySchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -650,6 +674,8 @@ func Test_BinlogAdapterShardListVarchar(t *testing.T) {
}
func Test_BinlogAdapterReadInt64PK(t *testing.T) {
ctx := context.Background()
chunkManager := &MockChunkManager{}
flushCounter := 0
@ -669,7 +695,7 @@ func Test_BinlogAdapterReadInt64PK(t *testing.T) {
}
shardNum := int32(2)
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
adapter.tsEndPoint = baseTimestamp + 1
@ -741,6 +767,8 @@ func Test_BinlogAdapterReadInt64PK(t *testing.T) {
}
func Test_BinlogAdapterReadVarcharPK(t *testing.T) {
ctx := context.Background()
chunkManager := &MockChunkManager{}
flushCounter := 0
@ -814,7 +842,7 @@ func Test_BinlogAdapterReadVarcharPK(t *testing.T) {
// succeed
shardNum := int32(3)
adapter, err := NewBinlogAdapter(strKeySchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, strKeySchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -825,93 +853,14 @@ func Test_BinlogAdapterReadVarcharPK(t *testing.T) {
assert.Equal(t, rowCount-502, flushRowCount)
}
func Test_BinlogAdapterTryFlush(t *testing.T) {
flushCounter := 0
flushRowCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
flushCounter++
rowCount := 0
for _, v := range fields {
rowCount = v.RowNum()
break
}
flushRowCount += rowCount
for _, v := range fields {
assert.Equal(t, rowCount, v.RowNum())
}
return nil
}
segmentSize := int64(1024)
maxTotalSize := int64(2048)
shardNum := int32(3)
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, segmentSize, maxTotalSize, &MockChunkManager{}, flushFunc, 0)
assert.NotNil(t, adapter)
assert.Nil(t, err)
// prepare flush data, 3 shards, each shard 10 rows
rowCount := 10
fieldsData := createFieldsData(rowCount)
// non-force flush
segmentsData := createSegmentsData(fieldsData, shardNum)
err = adapter.tryFlushSegments(segmentsData, false)
assert.Nil(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
// force flush
err = adapter.tryFlushSegments(segmentsData, true)
assert.Nil(t, err)
assert.Equal(t, int(shardNum), flushCounter)
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
// after force flush, no data left
flushCounter = 0
flushRowCount = 0
err = adapter.tryFlushSegments(segmentsData, true)
assert.Nil(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
// flush when segment size exceeds segmentSize
segmentsData = createSegmentsData(fieldsData, shardNum)
adapter.segmentSize = 100 // segmentSize is 100 bytes, less than the 10 rows size
err = adapter.tryFlushSegments(segmentsData, false)
assert.Nil(t, err)
assert.Equal(t, int(shardNum), flushCounter)
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
flushCounter = 0
flushRowCount = 0
err = adapter.tryFlushSegments(segmentsData, true) // no data left
assert.Nil(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
// flush when segments total size exceeds maxTotalSize
segmentsData = createSegmentsData(fieldsData, shardNum)
adapter.segmentSize = 4096 // segmentSize is 4096 bytes, larger than the 10 rows size
adapter.maxTotalSize = 100 // maxTotalSize is 100 bytes, less than the 30 rows size
err = adapter.tryFlushSegments(segmentsData, false)
assert.Nil(t, err)
assert.Equal(t, 1, flushCounter) // only the max segment is flushed
assert.Equal(t, 10, flushRowCount)
flushCounter = 0
flushRowCount = 0
err = adapter.tryFlushSegments(segmentsData, true) // two segments left
assert.Nil(t, err)
assert.Equal(t, 2, flushCounter)
assert.Equal(t, 20, flushRowCount)
}
func Test_BinlogAdapterDispatch(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
shardNum := int32(3)
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)
@ -1055,10 +1004,12 @@ func Test_BinlogAdapterDispatch(t *testing.T) {
}
func Test_BinlogAdapterReadInsertlog(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

View File

@ -26,7 +26,7 @@ import (
"go.uber.org/zap"
)
// This class is a wrapper of storage.BinlogReader, to read binlog file, block by block.
// BinlogFile class is a wrapper of storage.BinlogReader, to read binlog file, block by block.
// Note: for bulkoad function, we only handle normal insert log and delta log.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// Typically, an insert log file size is 16MB.
@ -72,7 +72,7 @@ func (p *BinlogFile) Open(filePath string) error {
return nil
}
// The outer caller must call this method in defer
// Close close the reader object, outer caller must call this method in defer
func (p *BinlogFile) Close() {
if p.reader != nil {
p.reader.Close()
@ -88,8 +88,8 @@ func (p *BinlogFile) DataType() schemapb.DataType {
return p.reader.PayloadDataType
}
// ReadBool method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadBool() ([]bool, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -131,8 +131,8 @@ func (p *BinlogFile) ReadBool() ([]bool, error) {
return result, nil
}
// ReadInt8 method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadInt8() ([]int8, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -174,8 +174,8 @@ func (p *BinlogFile) ReadInt8() ([]int8, error) {
return result, nil
}
// ReadInt16 method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadInt16() ([]int16, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -217,8 +217,8 @@ func (p *BinlogFile) ReadInt16() ([]int16, error) {
return result, nil
}
// ReadInt32 method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadInt32() ([]int32, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -260,8 +260,8 @@ func (p *BinlogFile) ReadInt32() ([]int32, error) {
return result, nil
}
// ReadInt64 method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadInt64() ([]int64, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -303,8 +303,8 @@ func (p *BinlogFile) ReadInt64() ([]int64, error) {
return result, nil
}
// ReadFloat method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadFloat() ([]float32, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -346,8 +346,8 @@ func (p *BinlogFile) ReadFloat() ([]float32, error) {
return result, nil
}
// ReadDouble method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadDouble() ([]float64, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -389,8 +389,8 @@ func (p *BinlogFile) ReadDouble() ([]float64, error) {
return result, nil
}
// ReadVarchar method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
func (p *BinlogFile) ReadVarchar() ([]string, error) {
if p.reader == nil {
log.Error("Binlog file: binlog reader not yet initialized")
@ -433,8 +433,8 @@ func (p *BinlogFile) ReadVarchar() ([]string, error) {
return result, nil
}
// ReadBinaryVector method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
// return vectors data and the dimension
func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) {
if p.reader == nil {
@ -479,8 +479,8 @@ func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) {
return result, dim, nil
}
// ReadFloatVector method reads all the blocks of a binlog by a data type.
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
// This method read all the blocks of a binlog by a data type.
// return vectors data and the dimension
func (p *BinlogFile) ReadFloatVector() ([]float32, int, error) {
if p.reader == nil {

View File

@ -599,7 +599,7 @@ func Test_BinlogFileDouble(t *testing.T) {
}
func Test_BinlogFileVarchar(t *testing.T) {
source := []string{"a", "b", "c", "d"}
source := []string{"a", "bb", "罗伯特", "d"}
chunkManager := &MockChunkManager{
readBuf: map[string][]byte{
"dummy": createBinlogBuf(t, schemapb.DataType_VarChar, source),

View File

@ -19,6 +19,7 @@ package importutil
import (
"context"
"errors"
"fmt"
"path"
"sort"
"strconv"
@ -30,23 +31,33 @@ import (
)
type BinlogParser struct {
ctx context.Context // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
shardNum int32 // sharding number of the collection
segmentSize int64 // maximum size of a segment(unit:byte)
blockSize int64 // maximum size of a read block(unit:byte)
chunkManager storage.ChunkManager // storage interfaces to browse/read the files
callFlushFunc ImportFlushFunc // call back function to flush segment
// a timestamp to define the end point of restore, data after this point will be ignored
// a timestamp to define the start time point of restore, data before this time point will be ignored
// set this value to 0, all the data will be imported
// set this value to math.MaxUint64, all the data will be ignored
// the tsStartPoint value must be less/equal than tsEndPoint
tsStartPoint uint64
// a timestamp to define the end time point of restore, data after this time point will be ignored
// set this value to 0, all the data will be ignored
// set this value to math.MaxUint64, all the data will be imported
// the tsEndPoint value must be larger/equal than tsStartPoint
tsEndPoint uint64
}
func NewBinlogParser(collectionSchema *schemapb.CollectionSchema,
func NewBinlogParser(ctx context.Context,
collectionSchema *schemapb.CollectionSchema,
shardNum int32,
segmentSize int64,
blockSize int64,
chunkManager storage.ChunkManager,
flushFunc ImportFlushFunc,
tsStartPoint uint64,
tsEndPoint uint64) (*BinlogParser, error) {
if collectionSchema == nil {
log.Error("Binlog parser: collection schema is nil")
@ -63,18 +74,27 @@ func NewBinlogParser(collectionSchema *schemapb.CollectionSchema,
return nil, errors.New("flush function is nil")
}
if tsStartPoint > tsEndPoint {
log.Error("Binlog parser: the tsStartPoint should be less than tsEndPoint",
zap.Uint64("tsStartPoint", tsStartPoint), zap.Uint64("tsEndPoint", tsEndPoint))
return nil, fmt.Errorf("Binlog parser: the tsStartPoint %d should be less than tsEndPoint %d", tsStartPoint, tsEndPoint)
}
v := &BinlogParser{
ctx: ctx,
collectionSchema: collectionSchema,
shardNum: shardNum,
segmentSize: segmentSize,
blockSize: blockSize,
chunkManager: chunkManager,
callFlushFunc: flushFunc,
tsStartPoint: tsStartPoint,
tsEndPoint: tsEndPoint,
}
return v, nil
}
// constructSegmentHolders builds a list of SegmentFilesHolder, each SegmentFilesHolder represents a segment folder
// For instance, the insertlogRoot is "backup/bak1/data/insert_log/435978159196147009/435978159196147010".
// 435978159196147009 is a collection id, 435978159196147010 is a partition id,
// there is a segment(id is 435978159261483009) under this partition.
@ -195,8 +215,8 @@ func (p *BinlogParser) parseSegmentFiles(segmentHolder *SegmentFilesHolder) erro
return errors.New("segment files holder is nil")
}
adapter, err := NewBinlogAdapter(p.collectionSchema, p.shardNum, p.segmentSize,
MaxTotalSizeInMemory, p.chunkManager, p.callFlushFunc, p.tsEndPoint)
adapter, err := NewBinlogAdapter(p.ctx, p.collectionSchema, p.shardNum, p.blockSize,
MaxTotalSizeInMemory, p.chunkManager, p.callFlushFunc, p.tsStartPoint, p.tsEndPoint)
if err != nil {
log.Error("Binlog parser: failed to create binlog adapter", zap.Error(err))
return err
@ -205,7 +225,7 @@ func (p *BinlogParser) parseSegmentFiles(segmentHolder *SegmentFilesHolder) erro
return adapter.Read(segmentHolder)
}
// This functions requires two paths:
// Parse requires two paths:
// 1. the insert log path of a partition
// 2. the delta log path of a partiion (optional)
func (p *BinlogParser) Parse(filePaths []string) error {

View File

@ -16,7 +16,9 @@
package importutil
import (
"context"
"errors"
"math"
"path"
"strconv"
"testing"
@ -27,18 +29,20 @@ import (
)
func Test_NewBinlogParser(t *testing.T) {
ctx := context.Background()
// nil schema
parser, err := NewBinlogParser(nil, 2, 1024, nil, nil, 0)
parser, err := NewBinlogParser(ctx, nil, 2, 1024, nil, nil, 0, math.MaxUint64)
assert.Nil(t, parser)
assert.NotNil(t, err)
// nil chunkmanager
parser, err = NewBinlogParser(sampleSchema(), 2, 1024, nil, nil, 0)
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, nil, nil, 0, math.MaxUint64)
assert.Nil(t, parser)
assert.NotNil(t, err)
// nil flushfunc
parser, err = NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, nil, 0)
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, nil, 0, math.MaxUint64)
assert.Nil(t, parser)
assert.NotNil(t, err)
@ -46,12 +50,19 @@ func Test_NewBinlogParser(t *testing.T) {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
parser, err = NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0)
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, parser)
assert.Nil(t, err)
// tsStartPoint larger than tsEndPoint
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 2, 1)
assert.Nil(t, parser)
assert.NotNil(t, err)
}
func Test_BinlogParserConstructHolders(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
@ -116,7 +127,7 @@ func Test_BinlogParserConstructHolders(t *testing.T) {
"backup/bak1/data/delta_log/435978159196147009/435978159196147010/435978159261483009/434574382554415105",
}
parser, err := NewBinlogParser(sampleSchema(), 2, 1024, chunkManager, flushFunc, 0)
parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, parser)
assert.Nil(t, err)
@ -165,6 +176,8 @@ func Test_BinlogParserConstructHolders(t *testing.T) {
}
func Test_BinlogParserConstructHoldersFailed(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
@ -174,7 +187,7 @@ func Test_BinlogParserConstructHoldersFailed(t *testing.T) {
listResult: make(map[string][]string),
}
parser, err := NewBinlogParser(sampleSchema(), 2, 1024, chunkManager, flushFunc, 0)
parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, parser)
assert.Nil(t, err)
@ -214,11 +227,13 @@ func Test_BinlogParserConstructHoldersFailed(t *testing.T) {
}
func Test_BinlogParserParseFilesFailed(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
parser, err := NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0)
parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, parser)
assert.Nil(t, err)
@ -231,6 +246,8 @@ func Test_BinlogParserParseFilesFailed(t *testing.T) {
}
func Test_BinlogParserParse(t *testing.T) {
ctx := context.Background()
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
@ -251,7 +268,7 @@ func Test_BinlogParserParse(t *testing.T) {
},
}
parser, err := NewBinlogParser(schema, 2, 1024, chunkManager, flushFunc, 0)
parser, err := NewBinlogParser(ctx, schema, 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, parser)
assert.Nil(t, err)

View File

@ -0,0 +1,512 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importutil
import (
"context"
"errors"
"fmt"
"path"
"runtime/debug"
"strconv"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/storage"
)
func isCanceled(ctx context.Context) bool {
// canceled?
select {
case <-ctx.Done():
return true
default:
break
}
return false
}
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
segmentData := make(map[storage.FieldID]storage.FieldData)
// rowID field is a hidden field with fieldID=0, it is always auto-generated by IDAllocator
// if primary key is int64 and autoID=true, primary key field is equal to rowID field
segmentData[common.RowIDField] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
switch schema.DataType {
case schemapb.DataType_Bool:
segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
Data: make([]bool, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Float:
segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Double:
segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
Data: make([]float64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int8:
segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
Data: make([]int8, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int16:
segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
Data: make([]int16, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int32:
segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
Data: make([]int32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int64:
segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_BinaryVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
Data: make([]byte, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_FloatVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
segmentData[schema.GetFieldID()] = &storage.StringFieldData{
Data: make([]string, 0),
NumRows: []int64{0},
}
default:
log.Error("Import util: unsupported data type", zap.Int("DataType", int(schema.DataType)))
return nil
}
}
return segmentData
}
// initValidators constructs valiator methods and data conversion methods
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error {
if collectionSchema == nil {
return errors.New("collection schema is nil")
}
// json decoder parse all the numeric value into float64
numericValidator := func(obj interface{}) error {
switch obj.(type) {
case float64:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal numeric value " + s
return errors.New(msg)
}
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
validators[schema.GetFieldID()] = &Validator{}
validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey()
validators[schema.GetFieldID()].autoID = schema.GetAutoID()
validators[schema.GetFieldID()].fieldName = schema.GetName()
validators[schema.GetFieldID()].isString = false
switch schema.DataType {
case schemapb.DataType_Bool:
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case bool:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal value " + s + " for bool type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(bool)
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
field.(*storage.BoolFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Float:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := float32(obj.(float64))
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
field.(*storage.FloatFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Double:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(float64)
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
field.(*storage.DoubleFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int8(obj.(float64))
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
field.(*storage.Int8FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int16(obj.(float64))
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
field.(*storage.Int16FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int32(obj.(float64))
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
field.(*storage.Int32FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int64(obj.(float64))
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
field.(*storage.Int64FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt)*8 != dim {
msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
t := int(vt[i].(float64))
if t > 255 || t < 0 {
msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := byte(arr[i].(float64))
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value)
}
field.(*storage.BinaryVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt) != dim {
msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for float vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for float vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := float32(arr[i].(float64))
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value)
}
field.(*storage.FloatVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
validators[schema.GetFieldID()].isString = true
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case string:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not a string for string type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(string)
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
field.(*storage.StringFieldData).NumRows[0]++
return nil
}
default:
return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType)))
}
}
return nil
}
func printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) {
stats := make([]zapcore.Field, 0)
for k, v := range fieldsData {
stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum()))
}
if len(files) > 0 {
stats = append(stats, zap.Any("files", files))
}
log.Info(msg, stats...)
}
// GetFileNameAndExt extracts file name and extension
// for example: "/a/b/c.ttt" returns "c" and ".ttt"
func GetFileNameAndExt(filePath string) (string, string) {
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
return fileNameWithoutExt, fileType
}
// getFieldDimension gets dimension of vecotor field
func getFieldDimension(schema *schemapb.FieldSchema) (int, error) {
for _, kvPair := range schema.GetTypeParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
if key == "dim" {
dim, err := strconv.Atoi(value)
if err != nil {
return 0, errors.New("vector dimension is invalid")
}
return dim, nil
}
}
return 0, errors.New("vector dimension is not defined")
}
// triggerGC triggers golang gc to return all free memory back to the underlying system at once,
// Note: this operation is expensive, and can lead to latency spikes as it holds the heap lock through the whole process
func triggerGC() {
debug.FreeOSMemory()
}
// tryFlushBlocks does the two things:
// 1. if accumulate data of a block exceed blockSize, call callFlushFunc to generate new binlog file
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest block
func tryFlushBlocks(ctx context.Context,
blocksData []map[storage.FieldID]storage.FieldData,
collectionSchema *schemapb.CollectionSchema,
callFlushFunc ImportFlushFunc,
blockSize int64,
maxTotalSize int64,
force bool) error {
totalSize := 0
biggestSize := 0
biggestItem := -1
// 1. if accumulate data of a block exceed blockSize, call callFlushFunc to generate new binlog file
for i := 0; i < len(blocksData); i++ {
// outside context might be canceled(service stop, or future enhancement for canceling import task)
if isCanceled(ctx) {
log.Error("Import util: import task was canceled")
return errors.New("import task was canceled")
}
blockData := blocksData[i]
// Note: even rowCount is 0, the size is still non-zero
size := 0
rowCount := 0
for _, fieldData := range blockData {
size += fieldData.GetMemorySize()
rowCount = fieldData.RowNum()
}
// force to flush, called at the end of Read()
if force && rowCount > 0 {
printFieldsDataInfo(blockData, "import util: prepare to force flush a block", nil)
err := callFlushFunc(blockData, i)
if err != nil {
log.Error("Import util: failed to force flush block data", zap.Int("shardID", i))
return err
}
log.Info("Import util: force flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))
blocksData[i] = initSegmentData(collectionSchema)
if blocksData[i] == nil {
log.Error("Import util: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
continue
}
// if segment size is larger than predefined blockSize, flush to create a new binlog file
// initialize a new FieldData list for next round batch read
if size > int(blockSize) && rowCount > 0 {
printFieldsDataInfo(blockData, "import util: prepare to flush block larger than maxBlockSize", nil)
err := callFlushFunc(blockData, i)
if err != nil {
log.Error("Import util: failed to flush block data", zap.Int("shardID", i))
return err
}
log.Info("Import util: block size exceed limit and flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))
blocksData[i] = initSegmentData(collectionSchema)
if blocksData[i] == nil {
log.Error("Import util: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
continue
}
// calculate the total size(ignore the flushed blocks)
// find out the biggest block for the step 2
totalSize += size
if size > biggestSize {
biggestSize = size
biggestItem = i
}
}
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest block
if totalSize > int(maxTotalSize) && biggestItem >= 0 {
// outside context might be canceled(service stop, or future enhancement for canceling import task)
if isCanceled(ctx) {
log.Error("Import util: import task was canceled")
return errors.New("import task was canceled")
}
blockData := blocksData[biggestItem]
// Note: even rowCount is 0, the size is still non-zero
size := 0
rowCount := 0
for _, fieldData := range blockData {
size += fieldData.GetMemorySize()
rowCount = fieldData.RowNum()
}
if rowCount > 0 {
printFieldsDataInfo(blockData, "import util: prepare to flush biggest block", nil)
err := callFlushFunc(blockData, biggestItem)
if err != nil {
log.Error("Import util: failed to flush biggest block data", zap.Int("shardID", biggestItem))
return err
}
log.Info("Import util: total size exceed limit and flush", zap.Int("rowCount", rowCount),
zap.Int("size", size), zap.Int("totalSize", totalSize), zap.Int("shardID", biggestItem))
blocksData[biggestItem] = initSegmentData(collectionSchema)
if blocksData[biggestItem] == nil {
log.Error("Import util: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
}
}
return nil
}
func getTypeName(dt schemapb.DataType) string {
switch dt {
case schemapb.DataType_Bool:
return "Bool"
case schemapb.DataType_Int8:
return "Int8"
case schemapb.DataType_Int16:
return "Int16"
case schemapb.DataType_Int32:
return "Int32"
case schemapb.DataType_Int64:
return "Int64"
case schemapb.DataType_Float:
return "Float"
case schemapb.DataType_Double:
return "Double"
case schemapb.DataType_VarChar:
return "Varchar"
case schemapb.DataType_String:
return "String"
case schemapb.DataType_BinaryVector:
return "BinaryVector"
case schemapb.DataType_FloatVector:
return "FloatVector"
default:
return "InvalidType"
}
}

View File

@ -0,0 +1,460 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importutil
import (
"context"
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/assert"
)
func sampleSchema() *schemapb.CollectionSchema {
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 102,
Name: "field_bool",
IsPrimaryKey: false,
Description: "bool",
DataType: schemapb.DataType_Bool,
},
{
FieldID: 103,
Name: "field_int8",
IsPrimaryKey: false,
Description: "int8",
DataType: schemapb.DataType_Int8,
},
{
FieldID: 104,
Name: "field_int16",
IsPrimaryKey: false,
Description: "int16",
DataType: schemapb.DataType_Int16,
},
{
FieldID: 105,
Name: "field_int32",
IsPrimaryKey: false,
Description: "int32",
DataType: schemapb.DataType_Int32,
},
{
FieldID: 106,
Name: "field_int64",
IsPrimaryKey: true,
AutoID: false,
Description: "int64",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 107,
Name: "field_float",
IsPrimaryKey: false,
Description: "float",
DataType: schemapb.DataType_Float,
},
{
FieldID: 108,
Name: "field_double",
IsPrimaryKey: false,
Description: "double",
DataType: schemapb.DataType_Double,
},
{
FieldID: 109,
Name: "field_string",
IsPrimaryKey: false,
Description: "string",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "128"},
},
},
{
FieldID: 110,
Name: "field_binary_vector",
IsPrimaryKey: false,
Description: "binary_vector",
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "16"},
},
},
{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
return schema
}
func strKeySchema() *schemapb.CollectionSchema {
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: false,
Description: "uid",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "1024"},
},
},
{
FieldID: 102,
Name: "int_scalar",
IsPrimaryKey: false,
Description: "int_scalar",
DataType: schemapb.DataType_Int32,
},
{
FieldID: 103,
Name: "float_scalar",
IsPrimaryKey: false,
Description: "float_scalar",
DataType: schemapb.DataType_Float,
},
{
FieldID: 104,
Name: "string_scalar",
IsPrimaryKey: false,
Description: "string_scalar",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "128"},
},
},
{
FieldID: 105,
Name: "bool_scalar",
IsPrimaryKey: false,
Description: "bool_scalar",
DataType: schemapb.DataType_Bool,
},
{
FieldID: 106,
Name: "vectors",
IsPrimaryKey: false,
Description: "vectors",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
return schema
}
func Test_IsCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
assert.False(t, isCanceled(ctx))
cancel()
assert.True(t, isCanceled(ctx))
}
func Test_InitSegmentData(t *testing.T) {
testFunc := func(schema *schemapb.CollectionSchema) {
fields := initSegmentData(schema)
assert.Equal(t, len(schema.Fields)+1, len(fields))
for _, field := range schema.Fields {
data, ok := fields[field.FieldID]
assert.True(t, ok)
assert.NotNil(t, data)
}
printFieldsDataInfo(fields, "dummy", []string{})
}
testFunc(sampleSchema())
testFunc(strKeySchema())
}
func Test_InitValidators(t *testing.T) {
validators := make(map[storage.FieldID]*Validator)
err := initValidators(nil, validators)
assert.NotNil(t, err)
schema := sampleSchema()
// success case
err = initValidators(schema, validators)
assert.Nil(t, err)
assert.Equal(t, len(schema.Fields), len(validators))
name2ID := make(map[string]storage.FieldID)
for _, field := range schema.Fields {
name2ID[field.GetName()] = field.GetFieldID()
}
checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
id := name2ID[funcName]
v, ok := validators[id]
assert.True(t, ok)
err = v.validateFunc(validVal)
assert.Nil(t, err)
err = v.validateFunc(invalidVal)
assert.NotNil(t, err)
}
// validate functions
var validVal interface{} = true
var invalidVal interface{} = "aa"
checkFunc("field_bool", validVal, invalidVal)
validVal = float64(100)
invalidVal = "aa"
checkFunc("field_int8", validVal, invalidVal)
checkFunc("field_int16", validVal, invalidVal)
checkFunc("field_int32", validVal, invalidVal)
checkFunc("field_int64", validVal, invalidVal)
checkFunc("field_float", validVal, invalidVal)
checkFunc("field_double", validVal, invalidVal)
validVal = "aa"
invalidVal = 100
checkFunc("field_string", validVal, invalidVal)
validVal = []interface{}{float64(100), float64(101)}
invalidVal = "aa"
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(100)}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(100), float64(101), float64(102)}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{true, true}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(255), float64(-1)}
checkFunc("field_binary_vector", validVal, invalidVal)
validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)}
invalidVal = true
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(1), float64(2), float64(3)}
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)}
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{"a", "b", "c", "d"}
checkFunc("field_float_vector", validVal, invalidVal)
// error cases
schema = &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: make([]*schemapb.FieldSchema, 0),
}
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "aa"},
},
})
validators = make(map[storage.FieldID]*Validator)
err = initValidators(schema, validators)
assert.NotNil(t, err)
schema.Fields = make([]*schemapb.FieldSchema, 0)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 110,
Name: "field_binary_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "aa"},
},
})
err = initValidators(schema, validators)
assert.NotNil(t, err)
}
func Test_GetFileNameAndExt(t *testing.T) {
filePath := "aaa/bbb/ccc.txt"
name, ext := GetFileNameAndExt(filePath)
assert.EqualValues(t, "ccc", name)
assert.EqualValues(t, ".txt", ext)
}
func Test_GetFieldDimension(t *testing.T) {
schema := &schemapb.FieldSchema{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
}
dim, err := getFieldDimension(schema)
assert.Nil(t, err)
assert.Equal(t, 4, dim)
schema.TypeParams = []*commonpb.KeyValuePair{
{Key: "dim", Value: "abc"},
}
dim, err = getFieldDimension(schema)
assert.NotNil(t, err)
assert.Equal(t, 0, dim)
schema.TypeParams = []*commonpb.KeyValuePair{}
dim, err = getFieldDimension(schema)
assert.NotNil(t, err)
assert.Equal(t, 0, dim)
}
func Test_TryFlushBlocks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
flushCounter := 0
flushRowCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
flushCounter++
rowCount := 0
for _, v := range fields {
rowCount = v.RowNum()
break
}
flushRowCount += rowCount
for _, v := range fields {
assert.Equal(t, rowCount, v.RowNum())
}
return nil
}
blockSize := int64(1024)
maxTotalSize := int64(2048)
shardNum := int32(3)
// prepare flush data, 3 shards, each shard 10 rows
rowCount := 10
fieldsData := createFieldsData(rowCount)
// non-force flush
segmentsData := createSegmentsData(fieldsData, shardNum)
err := tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false)
assert.Nil(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
// force flush
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true)
assert.Nil(t, err)
assert.Equal(t, int(shardNum), flushCounter)
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
// after force flush, no data left
flushCounter = 0
flushRowCount = 0
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true)
assert.Nil(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
// flush when segment size exceeds blockSize
segmentsData = createSegmentsData(fieldsData, shardNum)
blockSize = 100 // blockSize is 100 bytes, less than the 10 rows size
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false)
assert.Nil(t, err)
assert.Equal(t, int(shardNum), flushCounter)
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
flushCounter = 0
flushRowCount = 0
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) // no data left
assert.Nil(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
// flush when segments total size exceeds maxTotalSize
segmentsData = createSegmentsData(fieldsData, shardNum)
blockSize = 4096 // blockSize is 4096 bytes, larger than the 10 rows size
maxTotalSize = 100 // maxTotalSize is 100 bytes, less than the 30 rows size
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false)
assert.Nil(t, err)
assert.Equal(t, 1, flushCounter) // only the max segment is flushed
assert.Equal(t, 10, flushRowCount)
flushCounter = 0
flushRowCount = 0
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) // two segments left
assert.Nil(t, err)
assert.Equal(t, 2, flushCounter)
assert.Equal(t, 20, flushRowCount)
// canceled
cancel()
flushCounter = 0
flushRowCount = 0
segmentsData = createSegmentsData(fieldsData, shardNum)
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true)
assert.Error(t, err)
assert.Equal(t, 0, flushCounter)
assert.Equal(t, 0, flushRowCount)
}
func Test_GetTypeName(t *testing.T) {
str := getTypeName(schemapb.DataType_Bool)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_Int8)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_Int16)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_Int32)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_Int64)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_Float)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_Double)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_VarChar)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_String)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_BinaryVector)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_FloatVector)
assert.NotEmpty(t, str)
str = getTypeName(schemapb.DataType_None)
assert.Equal(t, "InvalidType", str)
}

View File

@ -19,21 +19,17 @@ package importutil
import (
"bufio"
"context"
"errors"
"fmt"
"math"
"path"
"runtime/debug"
"strconv"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/retry"
@ -45,6 +41,9 @@ const (
JSONFileExt = ".json"
NumpyFileExt = ".npy"
// supposed size of a single block, to control a binlog file size, the max biglog file size is no more than 2*SingleBlockSize
SingleBlockSize = 16 * 1024 * 1024 // 16MB
// this limitation is to avoid this OOM risk:
// for column-based file, we read all its data into memory, if user input a large file, the read() method may
// cost extra memory and lear to OOM.
@ -64,26 +63,60 @@ const (
// ReportImportAttempts is the maximum # of attempts to retry when import fails.
var ReportImportAttempts uint = 10
type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
type AssignSegmentFunc func(shardID int) (int64, string, error)
type CreateBinlogsFunc func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error)
type SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error
type WorkingSegment struct {
segmentID int64 // segment ID
shardID int // shard id
targetChName string // target dml channel
rowCount int64 // accumulate row count
memSize int // total memory size of all binlogs
fieldsInsert []*datapb.FieldBinlog // persisted binlogs
fieldsStats []*datapb.FieldBinlog // stats of persisted binlogs
}
type ImportOptions struct {
OnlyValidate bool
TsStartPoint uint64
TsEndPoint uint64
}
func DefaultImportOptions() ImportOptions {
options := ImportOptions{
OnlyValidate: false,
TsStartPoint: 0,
TsEndPoint: math.MaxUint64,
}
return options
}
type ImportWrapper struct {
ctx context.Context // for canceling parse process
cancel context.CancelFunc // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
shardNum int32 // sharding number of the collection
segmentSize int64 // maximum size of a segment(unit:byte)
segmentSize int64 // maximum size of a segment(unit:byte) defined by dataCoord.segment.maxSize (milvus.yml)
rowIDAllocator *allocator.IDAllocator // autoid allocator
chunkManager storage.ChunkManager
callFlushFunc ImportFlushFunc // call back function to flush a segment
assignSegmentFunc AssignSegmentFunc // function to prepare a new segment
createBinlogsFunc CreateBinlogsFunc // function to create binlog for a segment
saveSegmentFunc SaveSegmentFunc // function to persist a segment
importResult *rootcoordpb.ImportResult // import result
reportFunc func(res *rootcoordpb.ImportResult) error // report import state to rootcoord
workingSegments map[int]*WorkingSegment // a map shard id to working segments
}
func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64,
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc ImportFlushFunc,
importResult *rootcoordpb.ImportResult, reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper {
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, importResult *rootcoordpb.ImportResult,
reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper {
if collectionSchema == nil {
log.Error("import error: collection schema is nil")
log.Error("import wrapper: collection schema is nil")
return nil
}
@ -111,96 +144,143 @@ func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.Collection
shardNum: shardNum,
segmentSize: segmentSize,
rowIDAllocator: idAlloc,
callFlushFunc: flushFunc,
chunkManager: cm,
importResult: importResult,
reportFunc: reportFunc,
workingSegments: make(map[int]*WorkingSegment),
}
return wrapper
}
// this method can be used to cancel parse process
func (p *ImportWrapper) SetCallbackFunctions(assignSegmentFunc AssignSegmentFunc, createBinlogsFunc CreateBinlogsFunc, saveSegmentFunc SaveSegmentFunc) error {
if assignSegmentFunc == nil {
log.Error("import wrapper: callback function AssignSegmentFunc is nil")
return fmt.Errorf("import wrapper: callback function AssignSegmentFunc is nil")
}
if createBinlogsFunc == nil {
log.Error("import wrapper: callback function CreateBinlogsFunc is nil")
return fmt.Errorf("import wrapper: callback function CreateBinlogsFunc is nil")
}
if saveSegmentFunc == nil {
log.Error("import wrapper: callback function SaveSegmentFunc is nil")
return fmt.Errorf("import wrapper: callback function SaveSegmentFunc is nil")
}
p.assignSegmentFunc = assignSegmentFunc
p.createBinlogsFunc = createBinlogsFunc
p.saveSegmentFunc = saveSegmentFunc
return nil
}
// Cancel method can be used to cancel parse process
func (p *ImportWrapper) Cancel() error {
p.cancel()
return nil
}
func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) {
stats := make([]zapcore.Field, 0)
for k, v := range fieldsData {
stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum()))
func (p *ImportWrapper) validateColumnBasedFiles(filePaths []string, collectionSchema *schemapb.CollectionSchema) error {
requiredFieldNames := make(map[string]interface{})
for _, schema := range p.collectionSchema.Fields {
if schema.GetIsPrimaryKey() {
if !schema.GetAutoID() {
requiredFieldNames[schema.GetName()] = nil
}
} else {
requiredFieldNames[schema.GetName()] = nil
}
}
if len(files) > 0 {
stats = append(stats, zap.Any("files", files))
// check redundant file
fileNames := make(map[string]interface{})
for _, filePath := range filePaths {
name, _ := GetFileNameAndExt(filePath)
fileNames[name] = nil
_, ok := requiredFieldNames[name]
if !ok {
log.Error("import wrapper: the file has no corresponding field in collection", zap.String("fieldName", name))
return fmt.Errorf("import wrapper: the file '%s' has no corresponding field in collection", filePath)
}
log.Info(msg, stats...)
}
// check missed file
for name := range requiredFieldNames {
_, ok := fileNames[name]
if !ok {
log.Error("import wrapper: there is no file corresponding to field", zap.String("fieldName", name))
return fmt.Errorf("import wrapper: there is no file corresponding to field '%s'", name)
}
}
return nil
}
func getFileNameAndExt(filePath string) (string, string) {
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
return fileNameWithoutExt, fileType
}
// trigger golang gc to return all free memory back to the underlying system at once,
// Note: this operation is expensive, and can lead to latency spikes as it holds the heap lock through the whole process
func triggerGC() {
debug.FreeOSMemory()
}
func (p *ImportWrapper) fileValidation(filePaths []string, rowBased bool) error {
// fileValidation verify the input paths
// if all the files are json type, return true
// if all the files are numpy type, return false, and not allow duplicate file name
func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
// use this map to check duplicate file name(only for numpy file)
fileNames := make(map[string]struct{})
totalSize := int64(0)
rowBased := false
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
name, fileType := getFileNameAndExt(filePath)
_, ok := fileNames[name]
if ok {
// only check dupliate numpy file
if fileType == NumpyFileExt {
log.Error("import wrapper: duplicate file name", zap.String("fileName", name+"."+fileType))
return errors.New("duplicate file: " + name + "." + fileType)
name, fileType := GetFileNameAndExt(filePath)
// only allow json file or numpy file
if fileType != JSONFileExt && fileType != NumpyFileExt {
log.Error("import wrapper: unsupportted file type", zap.String("filePath", filePath))
return false, fmt.Errorf("import wrapper: unsupportted file type: '%s'", filePath)
}
} else {
fileNames[name] = struct{}{}
// we use the first file to determine row-based or column-based
if i == 0 && fileType == JSONFileExt {
rowBased = true
}
// check file type
// row-based only support json type, column-based can support json and numpy type
// row-based only support json type, column-based only support numpy type
if rowBased {
if fileType != JSONFileExt {
log.Error("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath))
return errors.New("unsupported file type for row-based mode: " + filePath)
return rowBased, fmt.Errorf("import wrapper: unsupported file type for row-based mode: '%s'", filePath)
}
} else {
if fileType != JSONFileExt && fileType != NumpyFileExt {
if fileType != NumpyFileExt {
log.Error("import wrapper: unsupported file type for column-based mode", zap.String("filePath", filePath))
return errors.New("unsupported file type for column-based mode: " + filePath)
return rowBased, fmt.Errorf("import wrapper: unsupported file type for column-based mode: '%s'", filePath)
}
}
// check dupliate file
_, ok := fileNames[name]
if ok {
log.Error("import wrapper: duplicate file name", zap.String("filePath", filePath))
return rowBased, fmt.Errorf("import wrapper: duplicate file: '%s'", filePath)
}
fileNames[name] = struct{}{}
// check file size, single file size cannot exceed MaxFileSize
// TODO add context
size, err := p.chunkManager.Size(context.TODO(), filePath)
if err != nil {
log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Any("err", err))
return errors.New("failed to get file size of " + filePath)
return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath)
}
// empty file
if size == 0 {
log.Error("import wrapper: file path is empty", zap.String("filePath", filePath))
return errors.New("the file " + filePath + " is empty")
log.Error("import wrapper: file size is zero", zap.String("filePath", filePath))
return rowBased, fmt.Errorf("import wrapper: the file '%s' size is zero", filePath)
}
if size > MaxFileSize {
log.Error("import wrapper: file size exceeds the maximum size", zap.String("filePath", filePath),
zap.Int64("fileSize", size), zap.Int64("MaxFileSize", MaxFileSize))
return errors.New("the file " + filePath + " size exceeds the maximum size: " + strconv.FormatInt(MaxFileSize, 10) + " bytes")
return rowBased, fmt.Errorf("import wrapper: the file '%s' size exceeds the maximum size: %d bytes", filePath, MaxFileSize)
}
totalSize += size
}
@ -208,26 +288,36 @@ func (p *ImportWrapper) fileValidation(filePaths []string, rowBased bool) error
// especially for column-base, total size of files cannot exceed MaxTotalSizeInMemory
if totalSize > MaxTotalSizeInMemory {
log.Error("import wrapper: total size of files exceeds the maximum size", zap.Int64("totalSize", totalSize), zap.Int64("MaxTotalSize", MaxTotalSizeInMemory))
return errors.New("the total size of all files exceeds the maximum size: " + strconv.FormatInt(MaxTotalSizeInMemory, 10) + " bytes")
return rowBased, fmt.Errorf("import wrapper: total size(%d bytes) of all files exceeds the maximum size: %d bytes", totalSize, MaxTotalSizeInMemory)
}
return nil
// check redundant files for column-based import
// if the field is primary key and autoid is false, the file is required
// any redundant file is not allowed
if !rowBased {
err := p.validateColumnBasedFiles(filePaths, p.collectionSchema)
if err != nil {
return rowBased, err
}
}
return rowBased, nil
}
// import process entry
// Import is the entry of import operation
// filePath and rowBased are from ImportTask
// if onlyValidate is true, this process only do validation, no data generated, callFlushFunc will not be called
func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate bool) error {
log.Info("import wrapper: filePaths", zap.Any("filePaths", filePaths))
// if onlyValidate is true, this process only do validation, no data generated, flushFunc will not be called
func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error {
log.Info("import wrapper: begin import", zap.Any("filePaths", filePaths), zap.Any("options", options))
// data restore function to import milvus native binlog files(for backup/restore tools)
// the backup/restore tool provide two paths for a partition, the first path is binlog path, the second is deltalog path
if p.isBinlogImport(filePaths) {
// TODO: handle the timestamp end point passed from client side, currently use math.MaxUint64
return p.doBinlogImport(filePaths, math.MaxUint64)
return p.doBinlogImport(filePaths, options.TsStartPoint, options.TsEndPoint)
}
// normal logic for import general data files
err := p.fileValidation(filePaths, rowBased)
rowBased, err := p.fileValidation(filePaths)
if err != nil {
return err
}
@ -235,16 +325,16 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
if rowBased {
// parse and consume row-based files
// for row-based files, the JSONRowConsumer will generate autoid for primary key, and split rows into segments
// according to shard number, so the callFlushFunc will be called in the JSONRowConsumer
// according to shard number, so the flushFunc will be called in the JSONRowConsumer
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
_, fileType := getFileNameAndExt(filePath)
_, fileType := GetFileNameAndExt(filePath)
log.Info("import wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
err = p.parseRowBasedJSON(filePath, onlyValidate)
err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
if err != nil {
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
log.Error("import wrapper: failed to parse row-based json file", zap.Any("err", err), zap.String("filePath", filePath))
return err
}
} // no need to check else, since the fileValidation() already do this
@ -260,7 +350,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
fieldsData := initSegmentData(p.collectionSchema)
if fieldsData == nil {
log.Error("import wrapper: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
return fmt.Errorf("import wrapper: failed to initialize FieldData list")
}
rowCount := 0
@ -271,25 +361,32 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
return nil
}
p.printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
tr := timerecord.NewTimeRecorder("combine field data")
defer tr.Elapse("finished")
for k, v := range fields {
// ignore 0 row field
if v.RowNum() == 0 {
log.Warn("import wrapper: empty FieldData ignored", zap.Int64("fieldID", k))
continue
}
// ignore internal fields: RowIDField and TimeStampField
if k == common.RowIDField || k == common.TimeStampField {
log.Warn("import wrapper: internal fields should not be provided", zap.Int64("fieldID", k))
continue
}
// each column should be only combined once
data, ok := fieldsData[k]
if ok && data.RowNum() > 0 {
return errors.New("the field " + strconv.FormatInt(k, 10) + " is duplicated")
return fmt.Errorf("the field %d is duplicated", k)
}
// check the row count. only count non-zero row fields
if rowCount > 0 && rowCount != v.RowNum() {
return errors.New("the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount))
return fmt.Errorf("the field %d row count %d doesn't equal others row count: %d", k, v.RowNum(), rowCount)
}
rowCount = v.RowNum()
@ -303,20 +400,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
// parse/validate/consume data
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
_, fileType := getFileNameAndExt(filePath)
_, fileType := GetFileNameAndExt(filePath)
log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
err = p.parseColumnBasedJSON(filePath, onlyValidate, combineFunc)
if err != nil {
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
return err
}
} else if fileType == NumpyFileExt {
err = p.parseColumnBasedNumpy(filePath, onlyValidate, combineFunc)
if fileType == NumpyFileExt {
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
if err != nil {
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
log.Error("import wrapper: failed to parse column-based numpy file", zap.Any("err", err), zap.String("filePath", filePath))
return err
}
}
@ -327,7 +418,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
triggerGC()
// split fields data into segments
err := p.splitFieldsData(fieldsData, filePaths)
err := p.splitFieldsData(fieldsData, SingleBlockSize)
if err != nil {
return err
}
@ -339,7 +430,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
return p.reportPersisted()
}
// reportPersisted notify the rootcoord to mark the task state to be ImportPersisted
func (p *ImportWrapper) reportPersisted() error {
// force close all segments
err := p.closeAllWorkingSegments()
if err != nil {
return err
}
// report file process state
p.importResult.State = commonpb.ImportState_ImportPersisted
// persist state task is valuable, retry more times in case fail this task only because of network error
@ -353,6 +451,7 @@ func (p *ImportWrapper) reportPersisted() error {
return nil
}
// isBinlogImport is to judge whether it is binlog import operation
// For internal usage by the restore tool: https://github.com/zilliztech/milvus-backup
// This tool exports data from a milvus service, and call bulkload interface to import native data into another milvus service.
// This tool provides two paths: one is data log path of a partition,the other is delta log path of this partition.
@ -360,16 +459,16 @@ func (p *ImportWrapper) reportPersisted() error {
func (p *ImportWrapper) isBinlogImport(filePaths []string) bool {
// must contains the insert log path, and the delta log path is optional
if len(filePaths) != 1 && len(filePaths) != 2 {
log.Info("import wrapper: paths count is not 1 or 2", zap.Int("len", len(filePaths)))
log.Info("import wrapper: paths count is not 1 or 2, not binlog import", zap.Int("len", len(filePaths)))
return false
}
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
_, fileType := getFileNameAndExt(filePath)
_, fileType := GetFileNameAndExt(filePath)
// contains file extension, is not a path
if len(fileType) != 0 {
log.Info("import wrapper: not a path", zap.String("filePath", filePath), zap.String("fileType", fileType))
log.Info("import wrapper: not a path, not binlog import", zap.String("filePath", filePath), zap.String("fileType", fileType))
return false
}
}
@ -378,12 +477,14 @@ func (p *ImportWrapper) isBinlogImport(filePaths []string) bool {
return true
}
func (p *ImportWrapper) doBinlogImport(filePaths []string, tsEndPoint uint64) error {
// doBinlogImport is the entry of binlog import operation
func (p *ImportWrapper) doBinlogImport(filePaths []string, tsStartPoint uint64, tsEndPoint uint64) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
return p.callFlushFunc(fields, shardID)
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
return p.flushFunc(fields, shardID)
}
parser, err := NewBinlogParser(p.collectionSchema, p.shardNum, p.segmentSize, p.chunkManager, flushFunc, tsEndPoint)
parser, err := NewBinlogParser(p.ctx, p.collectionSchema, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc,
tsStartPoint, tsEndPoint)
if err != nil {
return err
}
@ -396,6 +497,7 @@ func (p *ImportWrapper) doBinlogImport(filePaths []string, tsEndPoint uint64) er
return p.reportPersisted()
}
// parseRowBasedJSON is the entry of row-based json import operation
func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) error {
tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath)
@ -417,14 +519,16 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
if !onlyValidate {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
var filePaths = []string{filePath}
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths)
return p.callFlushFunc(fields, shardID)
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths)
return p.flushFunc(fields, shardID)
}
consumer, err = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc)
consumer, err = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, flushFunc)
if err != nil {
return err
}
}
validator, err := NewJSONRowValidator(p.collectionSchema, consumer)
if err != nil {
return err
@ -444,52 +548,14 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
return nil
}
func (p *ImportWrapper) parseColumnBasedJSON(filePath string, onlyValidate bool,
combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
tr := timerecord.NewTimeRecorder("json column-based parser: " + filePath)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// for minio storage, chunkManager will download file into local memory
// for local storage, chunkManager open the file directly
file, err := p.chunkManager.Reader(ctx, filePath)
if err != nil {
return err
}
defer file.Close()
// parse file
reader := bufio.NewReader(file)
parser := NewJSONParser(p.ctx, p.collectionSchema)
var consumer *JSONColumnConsumer
if !onlyValidate {
consumer, err = NewJSONColumnConsumer(p.collectionSchema, combineFunc)
if err != nil {
return err
}
}
validator, err := NewJSONColumnValidator(p.collectionSchema, consumer)
if err != nil {
return err
}
err = parser.ParseColumns(reader, validator)
if err != nil {
return err
}
tr.Elapse("parsed")
return nil
}
// parseColumnBasedNumpy is the entry of column-based numpy import operation
func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool,
combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
tr := timerecord.NewTimeRecorder("numpy parser: " + filePath)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fileName, _ := getFileNameAndExt(filePath)
fileName, _ := GetFileNameAndExt(filePath)
// for minio storage, chunkManager will download file into local memory
// for local storage, chunkManager open the file directly
@ -532,6 +598,7 @@ func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool
return nil
}
// appendFunc defines the methods to append data to storage.FieldData
func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
switch schema.DataType {
case schemapb.DataType_Bool:
@ -608,13 +675,14 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag
}
}
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, files []string) error {
// splitFieldsData is to split the in-memory data(parsed from column-based files) into blocks, each block save to a binlog file
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, blockSize int64) error {
if len(fieldsData) == 0 {
log.Error("import wrapper: fields data is empty")
return errors.New("import error: fields data is empty")
return fmt.Errorf("import wrapper: fields data is empty")
}
tr := timerecord.NewTimeRecorder("split field data")
tr := timerecord.NewTimeRecorder("import wrapper: split field data")
defer tr.Elapse("finished")
// check existence of each field
@ -633,7 +701,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
v, ok := fieldsData[schema.GetFieldID()]
if !ok {
log.Error("import wrapper: field not provided", zap.String("fieldName", schema.GetName()))
return errors.New("import error: field " + schema.GetName() + " not provided")
return fmt.Errorf("import wrapper: field '%s' not provided", schema.GetName())
}
rowCounter[schema.GetName()] = v.RowNum()
if v.RowNum() > rowCount {
@ -643,27 +711,30 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
}
if primaryKey == nil {
log.Error("import wrapper: primary key field is not found")
return errors.New("import error: primary key field is not found")
return fmt.Errorf("import wrapper: primary key field is not found")
}
for name, count := range rowCounter {
if count != rowCount {
log.Error("import wrapper: field row count is not equal to other fields row count", zap.String("fieldName", name),
zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount))
return errors.New("import error: field " + name + " row count " + strconv.Itoa(count) + " is not equal to other fields row count " + strconv.Itoa(rowCount))
return fmt.Errorf("import wrapper: field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
}
}
log.Info("import wrapper: try to split a block with row count", zap.Int("rowCount", rowCount), zap.Any("rowCountOfEachField", rowCounter))
primaryData, ok := fieldsData[primaryKey.GetFieldID()]
if !ok {
log.Error("import wrapper: primary key field is not provided", zap.String("keyName", primaryKey.GetName()))
return errors.New("import error: primary key field is not provided")
return fmt.Errorf("import wrapper: primary key field is not provided")
}
// generate auto id for primary key and rowid field
var rowIDBegin typeutil.UniqueID
var rowIDEnd typeutil.UniqueID
rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount))
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
if err != nil {
log.Error("import wrapper: failed to alloc row ID", zap.Any("err", err))
return err
}
rowIDField := fieldsData[common.RowIDField]
rowIDFieldArr := rowIDField.(*storage.Int64FieldData)
@ -672,19 +743,25 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
}
if primaryKey.GetAutoID() {
log.Info("import wrapper: generating auto-id", zap.Any("rowCount", rowCount))
log.Info("import wrapper: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin))
primaryDataArr := primaryData.(*storage.Int64FieldData)
// reset the primary keys, as we know, only int64 pk can be auto-generated
primaryDataArr := &storage.Int64FieldData{
NumRows: []int64{int64(rowCount)},
Data: make([]int64, 0, rowCount),
}
for i := rowIDBegin; i < rowIDEnd; i++ {
primaryDataArr.Data = append(primaryDataArr.Data, i)
}
primaryData = primaryDataArr
fieldsData[primaryKey.GetFieldID()] = primaryData
p.importResult.AutoIds = append(p.importResult.AutoIds, rowIDBegin, rowIDEnd)
}
if primaryData.RowNum() <= 0 {
log.Error("import wrapper: primary key not provided", zap.String("keyName", primaryKey.GetName()))
return errors.New("import wrapper: primary key " + primaryKey.GetName() + " not provided")
return fmt.Errorf("import wrapper: the primary key '%s' not provided", primaryKey.GetName())
}
// prepare segemnts
@ -693,7 +770,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
segmentData := initSegmentData(p.collectionSchema)
if segmentData == nil {
log.Error("import wrapper: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
return fmt.Errorf("import wrapper: failed to initialize FieldData list")
}
segmentsData = append(segmentsData, segmentData)
}
@ -705,12 +782,12 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
appendFuncErr := p.appendFunc(schema)
if appendFuncErr == nil {
log.Error("import wrapper: unsupported field data type")
return errors.New("import wrapper: unsupported field data type")
return fmt.Errorf("import wrapper: unsupported field data type")
}
appendFunctions[schema.GetName()] = appendFuncErr
}
// split data into segments
// split data into shards
for i := 0; i < rowCount; i++ {
// hash to a shard number
var shard uint32
@ -723,7 +800,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
intPK, ok := interface{}(pk).(int64)
if !ok {
log.Error("import wrapper: primary key field must be int64 or varchar")
return errors.New("import error: primary key field must be int64 or varchar")
return fmt.Errorf("import wrapper: primary key field must be int64 or varchar")
}
hash, _ := typeutil.Hash32Int64(intPK)
shard = hash % uint32(p.shardNum)
@ -744,18 +821,118 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
return err
}
}
}
// call flush function
for i := 0; i < int(p.shardNum); i++ {
segmentData := segmentsData[i]
p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files)
err := p.callFlushFunc(segmentData, i)
// when the estimated size is close to blockSize, force flush
err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, false)
if err != nil {
log.Error("import wrapper: flush callback function failed", zap.Any("err", err))
return err
}
}
// force flush at the end
return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, true)
}
// flushFunc is the callback function for parsers generate segment and save binlog files
func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error {
// if fields data is empty, do nothing
var rowNum int
memSize := 0
for _, field := range fields {
rowNum = field.RowNum()
memSize += field.GetMemorySize()
break
}
if rowNum <= 0 {
log.Warn("import wrapper: fields data is empty", zap.Int("shardID", shardID))
return nil
}
// if there is no segment for this shard, create a new one
// if the segment exists and its size almost exceed segmentSize, close it and create a new one
var segment *WorkingSegment
segment, ok := p.workingSegments[shardID]
if ok {
// the segment already exists, check its size, if the size exceeds(or almost) segmentSize, close the segment
if int64(segment.memSize)+int64(memSize) >= p.segmentSize {
err := p.closeWorkingSegment(segment)
if err != nil {
return err
}
segment = nil
p.workingSegments[shardID] = nil
}
}
if segment == nil {
// create a new segment
segID, channelName, err := p.assignSegmentFunc(shardID)
if err != nil {
log.Error("import wrapper: failed to assign a new segment", zap.Any("error", err), zap.Int("shardID", shardID))
return err
}
segment = &WorkingSegment{
segmentID: segID,
shardID: shardID,
targetChName: channelName,
rowCount: int64(0),
memSize: 0,
fieldsInsert: make([]*datapb.FieldBinlog, 0),
fieldsStats: make([]*datapb.FieldBinlog, 0),
}
p.workingSegments[shardID] = segment
}
// save binlogs
fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
if err != nil {
log.Error("import wrapper: failed to save binlogs", zap.Any("error", err), zap.Int("shardID", shardID),
zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
return err
}
segment.fieldsInsert = append(segment.fieldsInsert, fieldsInsert...)
segment.fieldsStats = append(segment.fieldsStats, fieldsStats...)
segment.rowCount += int64(rowNum)
segment.memSize += memSize
return nil
}
// closeWorkingSegment marks a segment to be sealed
func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error {
log.Info("import wrapper: adding segment to the correct DataNode flow graph and saving binlog paths",
zap.Int("shardID", segment.shardID),
zap.Int64("segmentID", segment.segmentID),
zap.String("targetChannel", segment.targetChName),
zap.Int64("rowCount", segment.rowCount),
zap.Int("insertLogCount", len(segment.fieldsInsert)),
zap.Int("statsLogCount", len(segment.fieldsStats)))
err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
if err != nil {
log.Error("import wrapper: failed to save segment",
zap.Any("error", err),
zap.Int("shardID", segment.shardID),
zap.Int64("segmentID", segment.segmentID),
zap.String("targetChannel", segment.targetChName))
return err
}
return nil
}
// closeAllWorkingSegments mark all segments to be sealed at the end of import operation
func (p *ImportWrapper) closeAllWorkingSegments() error {
for _, segment := range p.workingSegments {
err := p.closeWorkingSegment(segment)
if err != nil {
return err
}
}
p.workingSegments = make(map[int]*WorkingSegment)
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -19,7 +19,7 @@ package importutil
import (
"errors"
"fmt"
"strconv"
"reflect"
"go.uber.org/zap"
@ -31,33 +31,12 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// interface to process rows data
// JSONRowHandler is the interface to process rows data
type JSONRowHandler interface {
Handle(rows []map[storage.FieldID]interface{}) error
}
// interface to process column data
type JSONColumnHandler interface {
Handle(columns map[storage.FieldID][]interface{}) error
}
// method to get dimension of vecotor field
func getFieldDimension(schema *schemapb.FieldSchema) (int, error) {
for _, kvPair := range schema.GetTypeParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
if key == "dim" {
dim, err := strconv.Atoi(value)
if err != nil {
return 0, errors.New("vector dimension is invalid")
}
return dim, nil
}
}
return 0, errors.New("vector dimension is not defined")
}
// field value validator
// Validator is field value validator
type Validator struct {
validateFunc func(obj interface{}) error // validate data type function
convertFunc func(obj interface{}, field storage.FieldData) error // convert data function
@ -68,210 +47,7 @@ type Validator struct {
fieldName string // field name
}
// method to construct valiator functions
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error {
if collectionSchema == nil {
return errors.New("collection schema is nil")
}
// json decoder parse all the numeric value into float64
numericValidator := func(obj interface{}) error {
switch obj.(type) {
case float64:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal numeric value " + s
return errors.New(msg)
}
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
validators[schema.GetFieldID()] = &Validator{}
validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey()
validators[schema.GetFieldID()].autoID = schema.GetAutoID()
validators[schema.GetFieldID()].fieldName = schema.GetName()
validators[schema.GetFieldID()].isString = false
switch schema.DataType {
case schemapb.DataType_Bool:
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case bool:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal value " + s + " for bool type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(bool)
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
field.(*storage.BoolFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Float:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := float32(obj.(float64))
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
field.(*storage.FloatFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Double:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(float64)
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
field.(*storage.DoubleFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int8(obj.(float64))
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
field.(*storage.Int8FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int16(obj.(float64))
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
field.(*storage.Int16FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int32(obj.(float64))
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
field.(*storage.Int32FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int64(obj.(float64))
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
field.(*storage.Int64FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt)*8 != dim {
msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
t := int(vt[i].(float64))
if t > 255 || t < 0 {
msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := byte(arr[i].(float64))
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value)
}
field.(*storage.BinaryVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt) != dim {
msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for float vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for float vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := float32(arr[i].(float64))
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value)
}
field.(*storage.FloatVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
validators[schema.GetFieldID()].isString = true
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case string:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not a string for string type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(string)
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
field.(*storage.StringFieldData).NumRows[0]++
return nil
}
default:
return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType)))
}
}
return nil
}
// row-based json format validator class
// JSONRowValidator is row-based json format validator class
type JSONRowValidator struct {
downstream JSONRowHandler // downstream processor, typically is a JSONRowComsumer
validators map[storage.FieldID]*Validator // validators for each field
@ -286,7 +62,7 @@ func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream
}
err := initValidators(collectionSchema, v.validators)
if err != nil {
log.Error("JSON column validator: failed to initialize json row-based validator", zap.Error(err))
log.Error("JSON row validator: failed to initialize json row-based validator", zap.Error(err))
return nil, err
}
return v, nil
@ -298,13 +74,14 @@ func (v *JSONRowValidator) ValidateCount() int64 {
func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
log.Error("JSON row validator is not initialized")
return errors.New("JSON row validator is not initialized")
}
// parse completed
if rows == nil {
log.Info("JSON row validation finished")
if v.downstream != nil {
if v.downstream != nil && !reflect.ValueOf(v.downstream).IsNil() {
return v.downstream.Handle(rows)
}
return nil
@ -314,107 +91,42 @@ func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error
row := rows[i]
for id, validator := range v.validators {
value, ok := row[id]
if validator.primaryKey && validator.autoID {
// auto-generated primary key, ignore
// primary key is auto-generated, if user provided it, return error
if ok {
log.Error("JSON row validator: primary key is auto-generated, no need to provide PK value at the row",
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)))
return fmt.Errorf("the primary key '%s' is auto-generated, no need to provide PK value at the row %d",
validator.fieldName, v.rowCounter+int64(i))
}
continue
}
value, ok := row[id]
if !ok {
return errors.New("JSON row validator: field " + validator.fieldName + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
log.Error("JSON row validator: field missed at the row",
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)))
return fmt.Errorf("the field '%s' missed at the row %d", validator.fieldName, v.rowCounter+int64(i))
}
if err := validator.validateFunc(value); err != nil {
return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
log.Error("JSON row validator: invalid value at the row", zap.String("fieldName", validator.fieldName),
zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Any("value", value), zap.Error(err))
return fmt.Errorf("the field '%s' value at the row %d is invalid, error: %s",
validator.fieldName, v.rowCounter+int64(i), err.Error())
}
}
}
v.rowCounter += int64(len(rows))
if v.downstream != nil {
if v.downstream != nil && !reflect.ValueOf(v.downstream).IsNil() {
return v.downstream.Handle(rows)
}
return nil
}
// column-based json format validator class
type JSONColumnValidator struct {
downstream JSONColumnHandler // downstream processor, typically is a JSONColumnComsumer
validators map[storage.FieldID]*Validator // validators for each field
rowCounter map[string]int64 // row count of each field
}
func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) (*JSONColumnValidator, error) {
v := &JSONColumnValidator{
validators: make(map[storage.FieldID]*Validator),
downstream: downstream,
rowCounter: make(map[string]int64),
}
err := initValidators(schema, v.validators)
if err != nil {
log.Error("JSON column validator: fail to initialize json column-based validator", zap.Error(err))
return nil, err
}
return v, nil
}
func (v *JSONColumnValidator) ValidateCount() map[string]int64 {
return v.rowCounter
}
func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column validator is not initialized")
}
// parse completed
if columns == nil {
// compare the row count of columns, should be equal
rowCount := int64(-1)
for k, counter := range v.rowCounter {
if rowCount == -1 {
rowCount = counter
} else if rowCount != counter {
return errors.New("JSON column validator: the field " + k + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields row count" + strconv.Itoa(int(rowCount)))
}
}
// let the downstream know parse is completed
log.Info("JSON column validation finished")
if v.downstream != nil {
return v.downstream.Handle(nil)
}
return nil
}
for id, values := range columns {
validator, ok := v.validators[id]
name := validator.fieldName
if !ok {
// not a valid field name, skip without parsing
break
}
for i := 0; i < len(values); i++ {
if err := validator.validateFunc(values[i]); err != nil {
return errors.New("JSON column validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter[name]+int64(i), 10))
}
}
v.rowCounter[name] += int64(len(values))
}
if v.downstream != nil {
return v.downstream.Handle(columns)
}
return nil
}
type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
// row-based json format consumer class
// JSONRowConsumer is row-based json format consumer class
type JSONRowConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
rowIDAllocator *allocator.IDAllocator // autoid allocator
@ -422,89 +134,14 @@ type JSONRowConsumer struct {
rowCounter int64 // how many rows have been consumed
shardNum int32 // sharding number of the collection
segmentsData []map[storage.FieldID]storage.FieldData // in-memory segments data
segmentSize int64 // maximum size of a segment(unit:byte)
blockSize int64 // maximum size of a read block(unit:byte)
primaryKey storage.FieldID // name of primary key
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
callFlushFunc ImportFlushFunc // call back function to flush segment
}
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
segmentData := make(map[storage.FieldID]storage.FieldData)
// rowID field is a hidden field with fieldID=0, it is always auto-generated by IDAllocator
// if primary key is int64 and autoID=true, primary key field is equal to rowID field
segmentData[common.RowIDField] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
switch schema.DataType {
case schemapb.DataType_Bool:
segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
Data: make([]bool, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Float:
segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Double:
segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
Data: make([]float64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int8:
segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
Data: make([]int8, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int16:
segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
Data: make([]int16, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int32:
segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
Data: make([]int32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int64:
segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_BinaryVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
Data: make([]byte, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_FloatVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
segmentData[schema.GetFieldID()] = &storage.StringFieldData{
Data: make([]string, 0),
NumRows: []int64{0},
}
default:
log.Error("JSON row consumer error: unsupported data type", zap.Int("DataType", int(schema.DataType)))
return nil
}
}
return segmentData
}
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64,
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, blockSize int64,
flushFunc ImportFlushFunc) (*JSONRowConsumer, error) {
if collectionSchema == nil {
log.Error("JSON row consumer: collection schema is nil")
@ -516,7 +153,7 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
rowIDAllocator: idAlloc,
validators: make(map[storage.FieldID]*Validator),
shardNum: shardNum,
segmentSize: segmentSize,
blockSize: blockSize,
rowCounter: 0,
primaryKey: -1,
autoIDRange: make([]int64, 0),
@ -526,15 +163,15 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
err := initValidators(collectionSchema, v.validators)
if err != nil {
log.Error("JSON row consumer: fail to initialize json row-based consumer", zap.Error(err))
return nil, errors.New("fail to initialize json row-based consumer")
return nil, fmt.Errorf("fail to initialize json row-based consumer: %v", err)
}
v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
for i := 0; i < int(shardNum); i++ {
segmentData := initSegmentData(collectionSchema)
if segmentData == nil {
log.Error("JSON row consumer: fail to initialize in-memory segment data", zap.Int32("shardNum", shardNum))
return nil, errors.New("fail to initialize in-memory segment data")
log.Error("JSON row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i))
return nil, fmt.Errorf("fail to initialize in-memory segment data for shardID %d", i)
}
v.segmentsData = append(v.segmentsData, segmentData)
}
@ -554,7 +191,7 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
// primary key is autoid, id generator is required
if v.validators[v.primaryKey].autoID && idAlloc == nil {
log.Error("JSON row consumer: ID allocator is nil")
return nil, errors.New(" ID allocator is nil")
return nil, errors.New("ID allocator is nil")
}
return v, nil
@ -571,8 +208,17 @@ func (v *JSONRowConsumer) flush(force bool) error {
segmentData := v.segmentsData[i]
rowNum := segmentData[v.primaryKey].RowNum()
if rowNum > 0 {
log.Info("JSON row consumer: force flush segment", zap.Int("rows", rowNum))
v.callFlushFunc(segmentData, i)
log.Info("JSON row consumer: force flush binlog", zap.Int("rows", rowNum))
err := v.callFlushFunc(segmentData, i)
if err != nil {
return err
}
v.segmentsData[i] = initSegmentData(v.collectionSchema)
if v.segmentsData[i] == nil {
log.Error("JSON row consumer: fail to initialize in-memory segment data")
return errors.New("fail to initialize in-memory segment data")
}
}
}
@ -587,9 +233,13 @@ func (v *JSONRowConsumer) flush(force bool) error {
for _, field := range segmentData {
memSize += field.GetMemorySize()
}
if memSize >= int(v.segmentSize) && rowNum > 0 {
log.Info("JSON row consumer: flush fulled segment", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum))
v.callFlushFunc(segmentData, i)
if memSize >= int(v.blockSize) && rowNum > 0 {
log.Info("JSON row consumer: flush fulled binlog", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum))
err := v.callFlushFunc(segmentData, i)
if err != nil {
return err
}
v.segmentsData[i] = initSegmentData(v.collectionSchema)
if v.segmentsData[i] == nil {
log.Error("JSON row consumer: fail to initialize in-memory segment data")
@ -603,6 +253,7 @@ func (v *JSONRowConsumer) flush(force bool) error {
func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
log.Error("JSON row consumer is not initialized")
return errors.New("JSON row consumer is not initialized")
}
@ -615,10 +266,11 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
err := v.flush(false)
if err != nil {
return err
log.Error("JSON row consumer: try flush data but failed", zap.Error(err))
return fmt.Errorf("try flush data but failed: %s", err.Error())
}
// prepare autoid
// prepare autoid, no matter int64 or varchar pk, we always generate autoid since the hidden field RowIDField requires them
primaryValidator := v.validators[v.primaryKey]
var rowIDBegin typeutil.UniqueID
var rowIDEnd typeutil.UniqueID
@ -626,13 +278,20 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
var err error
rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
if err != nil {
return errors.New("JSON row consumer: " + err.Error())
log.Error("JSON row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err))
return fmt.Errorf("failed to generate %d primary keys: %s", len(rows), err.Error())
}
if rowIDEnd-rowIDBegin != int64(len(rows)) {
return errors.New("JSON row consumer: failed to allocate ID for " + strconv.Itoa(len(rows)) + " rows")
log.Error("JSON row consumer: try to generate primary keys but allocated ids are not enough",
zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin))
return fmt.Errorf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin)
}
log.Info("JSON row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd))
if !primaryValidator.isString {
// if pk is varchar, no need to record auto-generated row ids
v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
}
}
// consume rows
for i := 0; i < len(rows); i++ {
@ -642,7 +301,8 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
var shard uint32
if primaryValidator.isString {
if primaryValidator.autoID {
return errors.New("JSON row consumer: string type primary key cannot be auto-generated")
log.Error("JSON row consumer: string type primary key cannot be auto-generated")
return errors.New("string type primary key cannot be auto-generated")
}
value := row[v.primaryKey]
@ -662,7 +322,13 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
pk = int64(value.(float64))
}
hash, _ := typeutil.Hash32Int64(pk)
hash, err := typeutil.Hash32Int64(pk)
if err != nil {
log.Error("JSON row consumer: failed to hash primary key at the row",
zap.Int64("key", pk), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err))
return fmt.Errorf("failed to hash primary key %d at the row %d, error: %s", pk, v.rowCounter+int64(i), err.Error())
}
shard = hash % uint32(v.shardNum)
pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData)
pkArray.Data = append(pkArray.Data, pk)
@ -681,7 +347,10 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
}
value := row[name]
if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil {
return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
log.Error("JSON row consumer: failed to convert value for field at the row",
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err))
return fmt.Errorf("failed to convert value for field %s at the row %d, error: %s",
validator.fieldName, v.rowCounter+int64(i), err.Error())
}
}
}
@ -690,115 +359,3 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
return nil
}
type ColumnFlushFunc func(fields map[storage.FieldID]storage.FieldData) error
// column-based json format consumer class
type JSONColumnConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
validators map[storage.FieldID]*Validator // validators for each field
fieldsData map[storage.FieldID]storage.FieldData // in-memory fields data
primaryKey storage.FieldID // name of primary key
callFlushFunc ColumnFlushFunc // call back function to flush segment
}
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, flushFunc ColumnFlushFunc) (*JSONColumnConsumer, error) {
if collectionSchema == nil {
log.Error("JSON column consumer: collection schema is nil")
return nil, errors.New("collection schema is nil")
}
v := &JSONColumnConsumer{
collectionSchema: collectionSchema,
validators: make(map[storage.FieldID]*Validator),
callFlushFunc: flushFunc,
}
err := initValidators(collectionSchema, v.validators)
if err != nil {
log.Error("JSON column consumer: fail to initialize validator", zap.Error(err))
return nil, errors.New("fail to initialize validator")
}
v.fieldsData = initSegmentData(collectionSchema)
if v.fieldsData == nil {
log.Error("JSON column consumer: fail to initialize in-memory segment data")
return nil, errors.New("fail to initialize in-memory segment data")
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetFieldID()
break
}
}
return v, nil
}
func (v *JSONColumnConsumer) flush() error {
// check row count, should be equal
rowCount := 0
for id, field := range v.fieldsData {
// skip the autoid field
if id == v.primaryKey && v.validators[v.primaryKey].autoID {
continue
}
cnt := field.RowNum()
// skip 0 row fields since a data file may only import one column(there are several data files imported)
if cnt == 0 {
continue
}
// only check non-zero row fields
if rowCount == 0 {
rowCount = cnt
} else if rowCount != cnt {
return errors.New("JSON column consumer: " + strconv.FormatInt(id, 10) + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
}
}
if rowCount == 0 {
return errors.New("JSON column consumer: row count is 0")
}
log.Info("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount))
// output the fileds data, let outside split them into segments
return v.callFlushFunc(v.fieldsData)
}
func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column consumer is not initialized")
}
// flush at the end
if columns == nil {
err := v.flush()
log.Info("JSON column consumer finished")
return err
}
// consume columns data
for id, values := range columns {
validator, ok := v.validators[id]
if !ok {
// not a valid field id
break
}
if validator.primaryKey && validator.autoID {
// autoid is no need to provide
break
}
// convert and consume data
for i := 0; i < len(values); i++ {
if err := validator.convertFunc(values[i], v.fieldsData[id]); err != nil {
return errors.New("JSON column consumer: " + err.Error() + " of field " + strconv.FormatInt(id, 10))
}
}
}
return nil
}

View File

@ -18,6 +18,7 @@ package importutil
import (
"context"
"errors"
"strings"
"testing"
@ -26,15 +27,15 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/storage"
)
type mockIDAllocator struct {
allocErr error
}
func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) {
func (a *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) {
return &rootcoordpb.AllocIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -42,11 +43,13 @@ func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocI
},
ID: int64(1),
Count: req.Count,
}, nil
}, a.allocErr
}
func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator {
mockIDAllocator := &mockIDAllocator{}
func newIDAllocator(ctx context.Context, t *testing.T, allocErr error) *allocator.IDAllocator {
mockIDAllocator := &mockIDAllocator{
allocErr: allocErr,
}
idAllocator, err := allocator.NewIDAllocator(ctx, mockIDAllocator, int64(1))
assert.Nil(t, err)
@ -56,136 +59,14 @@ func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator {
return idAllocator
}
func Test_GetFieldDimension(t *testing.T) {
schema := &schemapb.FieldSchema{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
}
func Test_NewJSONRowValidator(t *testing.T) {
validator, err := NewJSONRowValidator(nil, nil)
assert.NotNil(t, err)
assert.Nil(t, validator)
dim, err := getFieldDimension(schema)
validator, err = NewJSONRowValidator(sampleSchema(), nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
assert.Equal(t, 4, dim)
schema.TypeParams = []*commonpb.KeyValuePair{
{Key: "dim", Value: "abc"},
}
dim, err = getFieldDimension(schema)
assert.NotNil(t, err)
assert.Equal(t, 0, dim)
schema.TypeParams = []*commonpb.KeyValuePair{}
dim, err = getFieldDimension(schema)
assert.NotNil(t, err)
assert.Equal(t, 0, dim)
}
func Test_InitValidators(t *testing.T) {
validators := make(map[storage.FieldID]*Validator)
err := initValidators(nil, validators)
assert.NotNil(t, err)
schema := sampleSchema()
// success case
err = initValidators(schema, validators)
assert.Nil(t, err)
assert.Equal(t, len(schema.Fields), len(validators))
name2ID := make(map[string]storage.FieldID)
for _, field := range schema.Fields {
name2ID[field.GetName()] = field.GetFieldID()
}
checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
id := name2ID[funcName]
v, ok := validators[id]
assert.True(t, ok)
err = v.validateFunc(validVal)
assert.Nil(t, err)
err = v.validateFunc(invalidVal)
assert.NotNil(t, err)
}
// validate functions
var validVal interface{} = true
var invalidVal interface{} = "aa"
checkFunc("field_bool", validVal, invalidVal)
validVal = float64(100)
invalidVal = "aa"
checkFunc("field_int8", validVal, invalidVal)
checkFunc("field_int16", validVal, invalidVal)
checkFunc("field_int32", validVal, invalidVal)
checkFunc("field_int64", validVal, invalidVal)
checkFunc("field_float", validVal, invalidVal)
checkFunc("field_double", validVal, invalidVal)
validVal = "aa"
invalidVal = 100
checkFunc("field_string", validVal, invalidVal)
validVal = []interface{}{float64(100), float64(101)}
invalidVal = "aa"
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(100)}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(100), float64(101), float64(102)}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{true, true}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(255), float64(-1)}
checkFunc("field_binary_vector", validVal, invalidVal)
validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)}
invalidVal = true
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(1), float64(2), float64(3)}
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)}
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{"a", "b", "c", "d"}
checkFunc("field_float_vector", validVal, invalidVal)
// error cases
schema = &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: make([]*schemapb.FieldSchema, 0),
}
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "aa"},
},
})
validators = make(map[storage.FieldID]*Validator)
err = initValidators(schema, validators)
assert.NotNil(t, err)
schema.Fields = make([]*schemapb.FieldSchema, 0)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 110,
Name: "field_binary_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "aa"},
},
})
err = initValidators(schema, validators)
assert.NotNil(t, err)
}
func Test_JSONRowValidator(t *testing.T) {
@ -209,15 +90,15 @@ func Test_JSONRowValidator(t *testing.T) {
assert.NotNil(t, err)
assert.Equal(t, int64(0), validator.ValidateCount())
// // missed some fields
// reader = strings.NewReader(`{
// "rows":[
// {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
// {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
// ]
// }`)
// err = parser.ParseRows(reader, validator)
// assert.NotNil(t, err)
// missed some fields
reader = strings.NewReader(`{
"rows":[
{"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
{"field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
]
}`)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// invalid dimension
reader = strings.NewReader(`{
@ -241,92 +122,90 @@ func Test_JSONRowValidator(t *testing.T) {
validator.validators = nil
err = validator.Handle(nil)
assert.NotNil(t, err)
}
func Test_JSONColumnValidator(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
// 0 row case
reader := strings.NewReader(`{
"field_bool": [],
"field_int8": [],
"field_int16": [],
"field_int32": [],
"field_int64": [],
"field_float": [],
"field_double": [],
"field_string": [],
"field_binary_vector": [],
"field_float_vector": []
}`)
validator, err := NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
for _, count := range validator.rowCounter {
assert.Equal(t, int64(0), count)
// primary key is auto-generate, but user provide pk value, return error
schema = &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "ID",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
FieldID: 102,
Name: "Age",
IsPrimaryKey: false,
DataType: schemapb.DataType_Int64,
},
},
}
// different row count
reader = strings.NewReader(`{
"field_bool": [true],
"field_int8": [],
"field_int16": [],
"field_int32": [1, 2, 3],
"field_int64": [],
"field_float": [],
"field_double": [],
"field_string": [],
"field_binary_vector": [],
"field_float_vector": []
}`)
validator, err = NewJSONColumnValidator(schema, nil)
validator, err = NewJSONRowValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
// invalid value type
reader = strings.NewReader(`{
"dummy": [],
"field_bool": [true],
"field_int8": [1],
"field_int16": [2],
"field_int32": [3],
"field_int64": [4],
"field_float": [1],
"field_double": [1],
"field_string": [9],
"field_binary_vector": [[254, 1]],
"field_float_vector": [[1.1, 1.2, 1.3, 1.4]]
"rows":[
{"ID": 1, "Age": 2}
]
}`)
parser = NewJSONParser(ctx, schema)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
}
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
func Test_NewJSONRowConsumer(t *testing.T) {
// nil schema
consumer, err := NewJSONRowConsumer(nil, nil, 2, 16, nil)
assert.NotNil(t, err)
assert.Nil(t, consumer)
// wrong schema
schema := &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: false,
DataType: schemapb.DataType_None,
},
},
}
consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil)
assert.NotNil(t, err)
assert.Nil(t, consumer)
// no primary key
schema.Fields[0].IsPrimaryKey = false
schema.Fields[0].DataType = schemapb.DataType_Int64
consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil)
assert.NotNil(t, err)
assert.Nil(t, consumer)
// primary key is autoid, but no IDAllocator
schema.Fields[0].IsPrimaryKey = true
schema.Fields[0].AutoID = true
consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil)
assert.NotNil(t, err)
assert.Nil(t, consumer)
// success
consumer, err = NewJSONRowConsumer(sampleSchema(), nil, 2, 16, nil)
assert.NotNil(t, consumer)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
// init failed
validator.validators = nil
err = validator.Handle(nil)
assert.NotNil(t, err)
}
func Test_JSONRowConsumer(t *testing.T) {
ctx := context.Background()
idAllocator := newIDAllocator(ctx, t)
idAllocator := newIDAllocator(ctx, t, nil)
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
@ -376,9 +255,196 @@ func Test_JSONRowConsumer(t *testing.T) {
assert.Equal(t, 5, totalCount)
}
func Test_JSONRowConsumerFlush(t *testing.T) {
var callTime int32
var totalCount int
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error {
callTime++
field, ok := fields[101]
assert.True(t, ok)
assert.Greater(t, field.RowNum(), 0)
totalCount += field.RowNum()
return nil
}
schema := &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
}
var shardNum int32 = 4
var blockSize int64 = 1
consumer, err := NewJSONRowConsumer(schema, nil, shardNum, blockSize, flushFunc)
assert.NotNil(t, consumer)
assert.Nil(t, err)
// force flush
rowCountEachShard := 100
for i := 0; i < int(shardNum); i++ {
pkFieldData := consumer.segmentsData[i][101].(*storage.Int64FieldData)
for j := 0; j < rowCountEachShard; j++ {
pkFieldData.Data = append(pkFieldData.Data, int64(j))
}
pkFieldData.NumRows = []int64{int64(rowCountEachShard)}
}
err = consumer.flush(true)
assert.Nil(t, err)
assert.Equal(t, shardNum, callTime)
assert.Equal(t, rowCountEachShard*int(shardNum), totalCount)
// execeed block size trigger flush
callTime = 0
totalCount = 0
for i := 0; i < int(shardNum); i++ {
consumer.segmentsData[i] = initSegmentData(schema)
if i%2 == 0 {
continue
}
pkFieldData := consumer.segmentsData[i][101].(*storage.Int64FieldData)
for j := 0; j < rowCountEachShard; j++ {
pkFieldData.Data = append(pkFieldData.Data, int64(j))
}
pkFieldData.NumRows = []int64{int64(rowCountEachShard)}
}
err = consumer.flush(true)
assert.Nil(t, err)
assert.Equal(t, shardNum/2, callTime)
assert.Equal(t, rowCountEachShard*int(shardNum)/2, totalCount)
}
func Test_JSONRowConsumerHandle(t *testing.T) {
ctx := context.Background()
idAllocator := newIDAllocator(ctx, t, errors.New("error"))
var callTime int32
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error {
callTime++
return errors.New("dummy error")
}
schema := &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_Int64,
},
},
}
t.Run("handle int64 pk", func(t *testing.T) {
consumer, err := NewJSONRowConsumer(schema, idAllocator, 1, 1, flushFunc)
assert.NotNil(t, consumer)
assert.Nil(t, err)
pkFieldData := consumer.segmentsData[0][101].(*storage.Int64FieldData)
for i := 0; i < 10; i++ {
pkFieldData.Data = append(pkFieldData.Data, int64(i))
}
pkFieldData.NumRows = []int64{int64(10)}
// nil input will trigger flush
err = consumer.Handle(nil)
assert.NotNil(t, err)
assert.Equal(t, int32(1), callTime)
// optional flush
callTime = 0
rowCount := 100
pkFieldData = consumer.segmentsData[0][101].(*storage.Int64FieldData)
for j := 0; j < rowCount; j++ {
pkFieldData.Data = append(pkFieldData.Data, int64(j))
}
pkFieldData.NumRows = []int64{int64(rowCount)}
input := make([]map[storage.FieldID]interface{}, rowCount)
for j := 0; j < rowCount; j++ {
input[j] = make(map[int64]interface{})
input[j][101] = int64(j)
}
err = consumer.Handle(input)
assert.NotNil(t, err)
assert.Equal(t, int32(1), callTime)
// failed to auto-generate pk
consumer.blockSize = 1024 * 1024
err = consumer.Handle(input)
assert.NotNil(t, err)
// hash int64 pk
consumer.rowIDAllocator = newIDAllocator(ctx, t, nil)
err = consumer.Handle(input)
assert.Nil(t, err)
assert.Equal(t, int64(rowCount), consumer.rowCounter)
assert.Equal(t, 2, len(consumer.autoIDRange))
assert.Equal(t, int64(1), consumer.autoIDRange[0])
assert.Equal(t, int64(1+rowCount), consumer.autoIDRange[1])
})
t.Run("handle varchar pk", func(t *testing.T) {
schema = &schemapb.CollectionSchema{
Name: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: true,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "1024"},
},
},
},
}
idAllocator := newIDAllocator(ctx, t, nil)
consumer, err := NewJSONRowConsumer(schema, idAllocator, 1, 1024*1024, flushFunc)
assert.NotNil(t, consumer)
assert.Nil(t, err)
rowCount := 100
input := make([]map[storage.FieldID]interface{}, rowCount)
for j := 0; j < rowCount; j++ {
input[j] = make(map[int64]interface{})
input[j][101] = "abc"
}
// varchar pk cannot be autoid
err = consumer.Handle(input)
assert.NotNil(t, err)
// hash varchar pk
schema.Fields[0].AutoID = false
consumer, err = NewJSONRowConsumer(schema, idAllocator, 1, 1024*1024, flushFunc)
assert.NotNil(t, consumer)
assert.Nil(t, err)
err = consumer.Handle(input)
assert.Nil(t, err)
assert.Equal(t, int64(rowCount), consumer.rowCounter)
assert.Equal(t, 0, len(consumer.autoIDRange))
})
}
func Test_JSONRowConsumerStringKey(t *testing.T) {
ctx := context.Background()
idAllocator := newIDAllocator(ctx, t)
idAllocator := newIDAllocator(ctx, t, nil)
schema := strKeySchema()
parser := NewJSONParser(ctx, schema)
@ -501,71 +567,3 @@ func Test_JSONRowConsumerStringKey(t *testing.T) {
assert.Equal(t, shardNum, callTime)
assert.Equal(t, 10, totalCount)
}
func Test_JSONColumnConsumer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
reader := strings.NewReader(`{
"field_bool": [true, false, true, true, true],
"field_int8": [10, 11, 12, 13, 14],
"field_int16": [100, 101, 102, 103, 104],
"field_int32": [1000, 1001, 1002, 1003, 1004],
"field_int64": [10000, 10001, 10002, 10003, 10004],
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
"field_string": ["a", "b", "c", "d", "e"],
"field_binary_vector": [
[254, 1],
[253, 2],
[252, 3],
[251, 4],
[250, 5]
],
"field_float_vector": [
[1.1, 1.2, 1.3, 1.4],
[2.1, 2.2, 2.3, 2.4],
[3.1, 3.2, 3.3, 3.4],
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4]
]
}`)
callTime := 0
rowCount := 0
consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
callTime++
for id, data := range fields {
if id == common.RowIDField {
continue
}
if rowCount == 0 {
rowCount = data.RowNum()
} else {
assert.Equal(t, rowCount, data.RowNum())
}
}
return nil
}
consumer, err := NewJSONColumnConsumer(schema, consumeFunc)
assert.NotNil(t, consumer)
assert.Nil(t, err)
validator, err := NewJSONColumnValidator(schema, consumer)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.Nil(t, err)
for _, count := range validator.ValidateCount() {
assert.Equal(t, int64(5), count)
}
assert.Equal(t, 1, callTime)
assert.Equal(t, 5, rowCount)
}

View File

@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
@ -84,28 +85,26 @@ func adjustBufSize(parser *JSONParser, collectionSchema *schemapb.CollectionSche
bufSize = MinBufferSize
}
log.Info("JSON parse: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize))
log.Info("JSON parser: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize))
parser.bufSize = int64(bufSize)
}
func (p *JSONParser) logError(msg string) error {
log.Error(msg)
return errors.New(msg)
}
func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
if handler == nil {
return p.logError("JSON parse handler is nil")
log.Error("JSON parse handler is nil")
return errors.New("JSON parse handler is nil")
}
dec := json.NewDecoder(r)
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: row count is 0")
log.Error("JSON parser: row count is 0")
return errors.New("JSON parser: row count is 0")
}
if t != json.Delim('{') {
return p.logError("JSON parse: invalid JSON format, the content should be started with'{'")
log.Error("JSON parser: invalid JSON format, the content should be started with'{'")
return errors.New("JSON parser: invalid JSON format, the content should be started with'{'")
}
// read the first level
@ -114,24 +113,28 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
// read the key
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
log.Error("JSON parser: read json token error", zap.Any("err", err))
return fmt.Errorf("JSON parser: read json token error: %v", err)
}
key := t.(string)
keyLower := strings.ToLower(key)
// the root key should be RowRootNode
if keyLower != RowRootNode {
return p.logError("JSON parse: invalid row-based JSON format, the key " + key + " is not found")
log.Error("JSON parser: invalid row-based JSON format, the key is not found", zap.String("key", key))
return fmt.Errorf("JSON parser: invalid row-based JSON format, the key %s is not found", key)
}
// started by '['
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
log.Error("JSON parser: read json token error", zap.Any("err", err))
return fmt.Errorf("JSON parser: read json token error: %v", err)
}
if t != json.Delim('[') {
return p.logError("JSON parse: invalid row-based JSON format, rows list should begin with '['")
log.Error("JSON parser: invalid row-based JSON format, rows list should begin with '['")
return errors.New("JSON parser: invalid row-based JSON format, rows list should begin with '['")
}
// read buffer
@ -139,27 +142,36 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
return p.logError("JSON parse: " + err.Error())
log.Error("JSON parser: decode json value error", zap.Any("err", err))
return fmt.Errorf("JSON parser: decode json value error: %v", err)
}
switch value.(type) {
case map[string]interface{}:
break
default:
return p.logError("JSON parse: invalid JSON format, each row should be a key-value map")
log.Error("JSON parser: invalid JSON format, each row should be a key-value map")
return errors.New("JSON parser: invalid JSON format, each row should be a key-value map")
}
row := make(map[storage.FieldID]interface{})
stringMap := value.(map[string]interface{})
for k, v := range stringMap {
row[p.name2FieldID[k]] = v
// if user provided redundant field, return error
fieldID, ok := p.name2FieldID[k]
if !ok {
log.Error("JSON parser: the field is not defined in collection schema", zap.String("fieldName", k))
return fmt.Errorf("JSON parser: the field '%s' is not defined in collection schema", k)
}
row[fieldID] = v
}
buf = append(buf, row)
if len(buf) >= int(p.bufSize) {
isEmpty = false
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
log.Error("JSON parser: parse values error", zap.Any("err", err))
return fmt.Errorf("JSON parser: parse values error: %v", err)
}
// clear the buffer
@ -171,26 +183,27 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
if len(buf) > 0 {
isEmpty = false
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
log.Error("JSON parser: parse values error", zap.Any("err", err))
return fmt.Errorf("JSON parser: parse values error: %v", err)
}
}
// end by ']'
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
log.Error("JSON parser: read json token error", zap.Any("err", err))
return fmt.Errorf("JSON parser: read json token error: %v", err)
}
if t != json.Delim(']') {
return p.logError("JSON parse: invalid column-based JSON format, rows list should end with a ']'")
log.Error("JSON parser: invalid column-based JSON format, rows list should end with a ']'")
return errors.New("JSON parser: invalid column-based JSON format, rows list should end with a ']'")
}
// canceled?
select {
case <-p.ctx.Done():
return p.logError("import task was canceled")
default:
break
// outside context might be canceled(service stop, or future enhancement for canceling import task)
if isCanceled(p.ctx) {
log.Error("JSON parser: import task was canceled")
return errors.New("JSON parser: import task was canceled")
}
// this break means we require the first node must be RowRootNode
@ -199,106 +212,8 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
}
if isEmpty {
return p.logError("JSON parse: row count is 0")
}
// send nil to notify the handler all have done
return handler.Handle(nil)
}
func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error {
if handler == nil {
return p.logError("JSON parse handler is nil")
}
dec := json.NewDecoder(r)
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: row count is 0")
}
if t != json.Delim('{') {
return p.logError("JSON parse: invalid JSON format, the content should be started with'{'")
}
// read the first level
isEmpty := true
for dec.More() {
// read the key
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
key := t.(string)
// not a valid column name, skip
_, isValidField := p.fields[key]
// started by '['
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim('[') {
return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['")
}
id := p.name2FieldID[key]
// read buffer
buf := make(map[storage.FieldID][]interface{})
buf[id] = make([]interface{}, 0, MinBufferSize)
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
return p.logError("JSON parse: " + err.Error())
}
if !isValidField {
continue
}
buf[id] = append(buf[id], value)
if len(buf[id]) >= int(p.bufSize) {
isEmpty = false
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
// clear the buffer
buf[id] = make([]interface{}, 0, MinBufferSize)
}
}
// some values in buffer not parsed, parse them
if len(buf[id]) > 0 {
isEmpty = false
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
}
// end by ']'
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim(']') {
return p.logError("JSON parse: invalid column-based JSON format, each field should end with a ']'")
}
// canceled?
select {
case <-p.ctx.Done():
return p.logError("import task was canceled")
default:
break
}
}
if isEmpty {
return p.logError("JSON parse: row count is 0")
log.Error("JSON parser: row count is 0")
return errors.New("JSON parser: row count is 0")
}
// send nil to notify the handler all have done

View File

@ -27,157 +27,6 @@ import (
"github.com/stretchr/testify/assert"
)
func sampleSchema() *schemapb.CollectionSchema {
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 102,
Name: "field_bool",
IsPrimaryKey: false,
Description: "bool",
DataType: schemapb.DataType_Bool,
},
{
FieldID: 103,
Name: "field_int8",
IsPrimaryKey: false,
Description: "int8",
DataType: schemapb.DataType_Int8,
},
{
FieldID: 104,
Name: "field_int16",
IsPrimaryKey: false,
Description: "int16",
DataType: schemapb.DataType_Int16,
},
{
FieldID: 105,
Name: "field_int32",
IsPrimaryKey: false,
Description: "int32",
DataType: schemapb.DataType_Int32,
},
{
FieldID: 106,
Name: "field_int64",
IsPrimaryKey: true,
AutoID: false,
Description: "int64",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 107,
Name: "field_float",
IsPrimaryKey: false,
Description: "float",
DataType: schemapb.DataType_Float,
},
{
FieldID: 108,
Name: "field_double",
IsPrimaryKey: false,
Description: "double",
DataType: schemapb.DataType_Double,
},
{
FieldID: 109,
Name: "field_string",
IsPrimaryKey: false,
Description: "string",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "128"},
},
},
{
FieldID: 110,
Name: "field_binary_vector",
IsPrimaryKey: false,
Description: "binary_vector",
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "16"},
},
},
{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
return schema
}
func strKeySchema() *schemapb.CollectionSchema {
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 101,
Name: "uid",
IsPrimaryKey: true,
AutoID: false,
Description: "uid",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "1024"},
},
},
{
FieldID: 102,
Name: "int_scalar",
IsPrimaryKey: false,
Description: "int_scalar",
DataType: schemapb.DataType_Int32,
},
{
FieldID: 103,
Name: "float_scalar",
IsPrimaryKey: false,
Description: "float_scalar",
DataType: schemapb.DataType_Float,
},
{
FieldID: 104,
Name: "string_scalar",
IsPrimaryKey: false,
Description: "string_scalar",
DataType: schemapb.DataType_VarChar,
},
{
FieldID: 105,
Name: "bool_scalar",
IsPrimaryKey: false,
Description: "bool_scalar",
DataType: schemapb.DataType_Bool,
},
{
FieldID: 106,
Name: "vectors",
IsPrimaryKey: false,
Description: "vectors",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
return schema
}
func Test_AdjustBufSize(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -234,6 +83,7 @@ func Test_JSONParserParserRows(t *testing.T) {
]
}`)
// handler is nil
err := parser.ParseRows(reader, nil)
assert.NotNil(t, err)
@ -241,10 +91,12 @@ func Test_JSONParserParserRows(t *testing.T) {
assert.NotNil(t, validator)
assert.Nil(t, err)
// success
err = parser.ParseRows(reader, validator)
assert.Nil(t, err)
assert.Equal(t, int64(5), validator.ValidateCount())
// not a row-based format
reader = strings.NewReader(`{
"dummy":[]
}`)
@ -255,6 +107,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// rows is not a list
reader = strings.NewReader(`{
"rows":
}`)
@ -265,6 +118,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// typo
reader = strings.NewReader(`{
"rows": [}
}`)
@ -275,6 +129,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// rows is not a list
reader = strings.NewReader(`{
"rows": {}
}`)
@ -285,6 +140,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// rows is not a list of list
reader = strings.NewReader(`{
"rows": [[]]
}`)
@ -295,6 +151,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// not valid json format
reader = strings.NewReader(`[]`)
validator, err = NewJSONRowValidator(schema, nil)
assert.NotNil(t, validator)
@ -303,6 +160,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// empty content
reader = strings.NewReader(`{}`)
validator, err = NewJSONRowValidator(schema, nil)
assert.NotNil(t, validator)
@ -311,6 +169,7 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// empty content
reader = strings.NewReader(``)
validator, err = NewJSONRowValidator(schema, nil)
assert.NotNil(t, validator)
@ -318,129 +177,23 @@ func Test_JSONParserParserRows(t *testing.T) {
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
}
func Test_JSONParserParserColumns(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
parser.bufSize = 1
reader := strings.NewReader(`{
"field_bool": [true, false, true, true, true],
"field_int8": [10, 11, 12, 13, 14],
"field_int16": [100, 101, 102, 103, 104],
"field_int32": [1000, 1001, 1002, 1003, 1004],
"field_int64": [10000, 10001, 10002, 10003, 10004],
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
"field_string": ["a", "b", "c", "d", "e"],
"field_binary_vector": [
[254, 1],
[253, 2],
[252, 3],
[251, 4],
[250, 5]
],
"field_float_vector": [
[1.1, 1.2, 1.3, 1.4],
[2.1, 2.2, 2.3, 2.4],
[3.1, 3.2, 3.3, 3.4],
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4]
// redundant field
reader = strings.NewReader(`{
"rows":[
{"dummy": 1, "field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
]
}`)
err := parser.ParseColumns(reader, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
validator, err := NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.Nil(t, err)
counter := validator.ValidateCount()
for _, v := range counter {
assert.Equal(t, int64(5), v)
}
// field missed
reader = strings.NewReader(`{
"field_int8": [10, 11, 12, 13, 14],
"dummy":[1, 2, 3]
"rows":[
{"field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
]
}`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.Nil(t, err)
reader = strings.NewReader(`{
"dummy":[1, 2, 3]
}`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"field_bool":
}`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"field_bool":{}
}`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"field_bool":[}
}`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`[]`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{}`)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(``)
validator, err = NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
}
@ -545,42 +298,3 @@ func Test_JSONParserParserRowsStringKey(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, int64(10), validator.ValidateCount())
}
func Test_JSONParserParserColumnsStrKey(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := strKeySchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
parser.bufSize = 1
reader := strings.NewReader(`{
"uid": ["Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd"],
"int_scalar": [9070353, 8505288, 4392660, 7927425, 9288807],
"float_scalar": [0.9798043638085004, 0.937913432198687, 0.32381232630490264, 0.31074026464844895, 0.4953578200336135],
"string_scalar": ["ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7"],
"bool_scalar": [true, false, true, false, false],
"vectors": [
[0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314],
[0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056],
[0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914],
[0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004],
[0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633]
]
}`)
err := parser.ParseColumns(reader, nil)
assert.NotNil(t, err)
validator, err := NewJSONColumnValidator(schema, nil)
assert.NotNil(t, validator)
assert.Nil(t, err)
err = parser.ParseColumns(reader, validator)
assert.Nil(t, err)
counter := validator.ValidateCount()
for _, v := range counter {
assert.Equal(t, int64(5), v)
}
}

View File

@ -20,11 +20,27 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"reflect"
"regexp"
"strconv"
"unicode/utf8"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/log"
"github.com/sbinet/npyio"
"github.com/sbinet/npyio/npy"
"go.uber.org/zap"
)
var (
reStrPre = regexp.MustCompile(`^[|]*?(\d.*)[Sa]$`)
reStrPost = regexp.MustCompile(`^[|]*?[Sa](\d.*)$`)
reUniPre = regexp.MustCompile(`^[<|>]*?(\d.*)U$`)
reUniPost = regexp.MustCompile(`^[<|>]*?U(\d.*)$`)
)
func CreateNumpyFile(path string, data interface{}) error {
@ -52,7 +68,7 @@ func CreateNumpyData(data interface{}) ([]byte, error) {
return buf.Bytes(), nil
}
// a class to expand other numpy lib ability
// NumpyAdapter is the class to expand other numpy lib ability
// we evaluate two go-numpy lins: github.com/kshedden/gonpy and github.com/sbinet/npyio
// the npyio lib read data one by one, the performance is poor, we expand the read methods
// to read data in one batch, the performance is 100X faster
@ -63,6 +79,7 @@ type NumpyAdapter struct {
npyReader *npy.Reader // reader of npyio lib
order binary.ByteOrder // LittleEndian or BigEndian
readPosition int // how many elements have been read
dataType schemapb.DataType // data type parsed from numpy file header
}
func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
@ -70,17 +87,106 @@ func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
if err != nil {
return nil, err
}
dataType, err := convertNumpyType(r.Header.Descr.Type)
if err != nil {
return nil, err
}
adapter := &NumpyAdapter{
reader: reader,
npyReader: r,
readPosition: 0,
dataType: dataType,
}
adapter.setByteOrder()
log.Info("Numpy adapter: numpy header info",
zap.Any("shape", r.Header.Descr.Shape),
zap.String("dType", r.Header.Descr.Type),
zap.Uint8("majorVer", r.Header.Major),
zap.Uint8("minorVer", r.Header.Minor),
zap.String("ByteOrder", adapter.order.String()))
return adapter, err
}
// the logic of this method is copied from npyio lib
// convertNumpyType gets data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector)
func convertNumpyType(typeStr string) (schemapb.DataType, error) {
log.Info("Numpy adapter: parse numpy file dtype", zap.String("dtype", typeStr))
switch typeStr {
case "b1", "<b1", "|b1", "bool":
return schemapb.DataType_Bool, nil
case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
return schemapb.DataType_BinaryVector, nil
case "i1", "<i1", "|i1", ">i1", "int8":
return schemapb.DataType_Int8, nil
case "i2", "<i2", "|i2", ">i2", "int16":
return schemapb.DataType_Int16, nil
case "i4", "<i4", "|i4", ">i4", "int32":
return schemapb.DataType_Int32, nil
case "i8", "<i8", "|i8", ">i8", "int64":
return schemapb.DataType_Int64, nil
case "f4", "<f4", "|f4", ">f4", "float32":
return schemapb.DataType_Float, nil
case "f8", "<f8", "|f8", ">f8", "float64":
return schemapb.DataType_Double, nil
default:
if isStringType(typeStr) {
return schemapb.DataType_VarChar, nil
}
log.Error("Numpy adapter: the numpy file data type not supported", zap.String("dataType", typeStr))
return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr)
}
}
func stringLen(dtype string) (int, bool, error) {
var utf bool
switch {
case reStrPre.MatchString(dtype), reStrPost.MatchString(dtype):
utf = false
case reUniPre.MatchString(dtype), reUniPost.MatchString(dtype):
utf = true
}
if m := reStrPre.FindStringSubmatch(dtype); m != nil {
v, err := strconv.Atoi(m[1])
if err != nil {
return 0, false, err
}
return v, utf, nil
}
if m := reStrPost.FindStringSubmatch(dtype); m != nil {
v, err := strconv.Atoi(m[1])
if err != nil {
return 0, false, err
}
return v, utf, nil
}
if m := reUniPre.FindStringSubmatch(dtype); m != nil {
v, err := strconv.Atoi(m[1])
if err != nil {
return 0, false, err
}
return v, utf, nil
}
if m := reUniPost.FindStringSubmatch(dtype); m != nil {
v, err := strconv.Atoi(m[1])
if err != nil {
return 0, false, err
}
return v, utf, nil
}
return 0, false, fmt.Errorf("Numpy adapter: data type '%s' of numpy file is not varchar data type", dtype)
}
func isStringType(typeStr string) bool {
rt := npyio.TypeFrom(typeStr)
return rt == reflect.TypeOf((*string)(nil)).Elem()
}
// setByteOrder sets BigEndian/LittleEndian, the logic of this method is copied from npyio lib
func (n *NumpyAdapter) setByteOrder() {
var nativeEndian binary.ByteOrder
v := uint16(1)
@ -109,15 +215,15 @@ func (n *NumpyAdapter) NpyReader() *npy.Reader {
return n.npyReader
}
func (n *NumpyAdapter) GetType() string {
return n.npyReader.Header.Descr.Type
func (n *NumpyAdapter) GetType() schemapb.DataType {
return n.dataType
}
func (n *NumpyAdapter) GetShape() []int {
return n.npyReader.Header.Descr.Shape
}
func (n *NumpyAdapter) checkSize(size int) int {
func (n *NumpyAdapter) checkCount(count int) int {
shape := n.GetShape()
// empty file?
@ -135,35 +241,34 @@ func (n *NumpyAdapter) checkSize(size int) int {
}
// overflow?
if size > (total - n.readPosition) {
if count > (total - n.readPosition) {
return total - n.readPosition
}
return size
return count
}
func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read bool data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "b1", "<b1", "|b1", "bool":
default:
return nil, errors.New("numpy data is not bool type")
if n.dataType != schemapb.DataType_Bool {
return nil, errors.New("Numpy adapter: numpy data is not bool type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of bool file, nothing to read")
}
// read data
data := make([]bool, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read bool data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -172,28 +277,30 @@ func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
return data, nil
}
func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read uint8 data with a zero or nagative count")
}
// incorrect type
// here we don't use n.dataType to check because currently milvus has no uint8 type
switch n.npyReader.Header.Descr.Type {
case "u1", "<u1", "|u1", "uint8":
default:
return nil, errors.New("numpy data is not uint8 type")
return nil, errors.New("Numpy adapter: numpy data is not uint8 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of uint8 file, nothing to read")
}
// read data
data := make([]uint8, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read uint8 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -202,28 +309,27 @@ func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
return data, nil
}
func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadInt8(count int) ([]int8, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read int8 data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i1", "<i1", "|i1", ">i1", "int8":
default:
return nil, errors.New("numpy data is not int8 type")
if n.dataType != schemapb.DataType_Int8 {
return nil, errors.New("Numpy adapter: numpy data is not int8 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of int8 file, nothing to read")
}
// read data
data := make([]int8, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read int8 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -232,28 +338,27 @@ func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
return data, nil
}
func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read int16 data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i2", "<i2", "|i2", ">i2", "int16":
default:
return nil, errors.New("numpy data is not int16 type")
if n.dataType != schemapb.DataType_Int16 {
return nil, errors.New("Numpy adapter: numpy data is not int16 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of int16 file, nothing to read")
}
// read data
data := make([]int16, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read int16 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -262,28 +367,27 @@ func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
return data, nil
}
func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read int32 data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i4", "<i4", "|i4", ">i4", "int32":
default:
return nil, errors.New("numpy data is not int32 type")
if n.dataType != schemapb.DataType_Int32 {
return nil, errors.New("Numpy adapter: numpy data is not int32 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of int32 file, nothing to read")
}
// read data
data := make([]int32, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read int32 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -292,28 +396,27 @@ func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
return data, nil
}
func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read int64 data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i8", "<i8", "|i8", ">i8", "int64":
default:
return nil, errors.New("numpy data is not int64 type")
if n.dataType != schemapb.DataType_Int64 {
return nil, errors.New("Numpy adapter: numpy data is not int64 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of int64 file, nothing to read")
}
// read data
data := make([]int64, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read int64 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -322,28 +425,27 @@ func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
return data, nil
}
func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read float32 data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "f4", "<f4", "|f4", ">f4", "float32":
default:
return nil, errors.New("numpy data is not float32 type")
if n.dataType != schemapb.DataType_Float {
return nil, errors.New("Numpy adapter: numpy data is not float32 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of float32 file, nothing to read")
}
// read data
data := make([]float32, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read float32 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
@ -352,28 +454,109 @@ func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
return data, nil
}
func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read float64 data with a zero or nagative count")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "f8", "<f8", "|f8", ">f8", "float64":
default:
return nil, errors.New("numpy data is not float32 type")
if n.dataType != schemapb.DataType_Double {
return nil, errors.New("Numpy adapter: numpy data is not float64 type")
}
// avoid read overflow
readSize := n.checkSize(size)
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("nothing to read")
return nil, errors.New("Numpy adapter: end of float64 file, nothing to read")
}
// read data
data := make([]float64, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
return nil, fmt.Errorf("Numpy adapter: failed to read float64 data with count %d, error: %w", readSize, err)
}
// update read position after successfully read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
if count <= 0 {
return nil, errors.New("Numpy adapter: cannot read varhar data with a zero or nagative count")
}
// incorrect type
if n.dataType != schemapb.DataType_VarChar {
return nil, errors.New("Numpy adapter: numpy data is not varhar type")
}
// varchar length, this is the max length, some item is shorter than this length, but they also occupy bytes of max length
maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type)
if err != nil || maxLen <= 0 {
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Any("err", err))
return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
}
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
// avoid read overflow
readSize := n.checkCount(count)
if readSize <= 0 {
return nil, errors.New("Numpy adapter: end of varhar file, nothing to read")
}
// read data
data := make([]string, 0)
for i := 0; i < readSize; i++ {
if utf {
// in the numpy file, each utf8 character occupy utf8.UTFMax bytes, each string occupys utf8.UTFMax*maxLen bytes
// for example, an ANSI character "a" only uses one byte, but it still occupy utf8.UTFMax bytes
// a chinese character uses three bytes, it also occupy utf8.UTFMax bytes
raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)))
if err != nil {
log.Error("Numpy adapter: failed to read utf8 string from numpy file", zap.Int("i", i), zap.Any("err", err))
return nil, fmt.Errorf("Numpy adapter: failed to read utf8 string from numpy file, error: %w", err)
}
var str string
for len(raw) > 0 {
r, _ := utf8.DecodeRune(raw)
if r == utf8.RuneError {
log.Error("Numpy adapter: failed to decode utf8 string from numpy file", zap.Any("raw", raw[:utf8.UTFMax]))
return nil, fmt.Errorf("Numpy adapter: failed to decode utf8 string from numpy file, error: illegal utf-8 encoding")
}
// only support ascii characters, because the numpy lib encode the utf8 bytes by its internal method,
// the encode/decode logic is not clear now, return error
n := n.order.Uint32(raw)
if n > 127 {
log.Error("Numpy adapter: a string contains non-ascii characters, not support yet", zap.Int32("utf8Code", r))
return nil, fmt.Errorf("Numpy adapter: a string contains non-ascii characters, not support yet")
}
// if a string is shorter than maxLen, the tail characters will be filled with "\u0000"(in utf spec this is Null)
if r > 0 {
str += string(r)
}
raw = raw[utf8.UTFMax:]
}
data = append(data, str)
} else {
buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen)))
if err != nil {
log.Error("Numpy adapter: failed to read string from numpy file", zap.Int("i", i), zap.Any("err", err))
return nil, fmt.Errorf("Numpy adapter: failed to read string from numpy file, error: %w", err)
}
n := bytes.Index(buf, []byte{0})
if n > 0 {
buf = buf[:n]
}
data = append(data, string(buf))
}
}
// update read position after successfully read

View File

@ -22,6 +22,7 @@ import (
"os"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/sbinet/npyio/npy"
"github.com/stretchr/testify/assert"
)
@ -59,6 +60,55 @@ func Test_CreateNumpyData(t *testing.T) {
assert.Nil(t, buf)
}
func Test_ConvertNumpyType(t *testing.T) {
checkFunc := func(inputs []string, output schemapb.DataType) {
for i := 0; i < len(inputs); i++ {
dt, err := convertNumpyType(inputs[i])
assert.Nil(t, err)
assert.Equal(t, output, dt)
}
}
checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
dt, err := convertNumpyType("dummy")
assert.NotNil(t, err)
assert.Equal(t, schemapb.DataType_None, dt)
}
func Test_StringLen(t *testing.T) {
len, utf, err := stringLen("S1")
assert.Equal(t, 1, len)
assert.False(t, utf)
assert.Nil(t, err)
len, utf, err = stringLen("2S")
assert.Equal(t, 2, len)
assert.False(t, utf)
assert.Nil(t, err)
len, utf, err = stringLen("<U3")
assert.Equal(t, 3, len)
assert.True(t, utf)
assert.Nil(t, err)
len, utf, err = stringLen(">4U")
assert.Equal(t, 4, len)
assert.True(t, utf)
assert.Nil(t, err)
len, utf, err = stringLen("dummy")
assert.NotNil(t, err)
assert.Equal(t, 0, len)
assert.False(t, utf)
}
func Test_NumpyAdapterSetByteOrder(t *testing.T) {
adapter := &NumpyAdapter{
reader: nil,
@ -82,32 +132,32 @@ func Test_NumpyAdapterReadError(t *testing.T) {
npyReader: nil,
}
// reader is nil
{
_, err := adapter.ReadBool(1)
// reader size is zero
t.Run("test size is zero", func(t *testing.T) {
_, err := adapter.ReadBool(0)
assert.NotNil(t, err)
_, err = adapter.ReadUint8(1)
_, err = adapter.ReadUint8(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt8(1)
_, err = adapter.ReadInt8(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt16(1)
_, err = adapter.ReadInt16(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt32(1)
_, err = adapter.ReadInt32(0)
assert.NotNil(t, err)
_, err = adapter.ReadInt64(1)
_, err = adapter.ReadInt64(0)
assert.NotNil(t, err)
_, err = adapter.ReadFloat32(1)
_, err = adapter.ReadFloat32(0)
assert.NotNil(t, err)
_, err = adapter.ReadFloat64(1)
_, err = adapter.ReadFloat64(0)
assert.NotNil(t, err)
}
})
adapter = &NumpyAdapter{
reader: &MockReader{},
npyReader: &npy.Reader{},
}
{
t.Run("test read bool", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "bool"
data, err := adapter.ReadBool(1)
assert.Nil(t, data)
@ -117,9 +167,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadBool(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read uint8", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "u1"
data, err := adapter.ReadUint8(1)
assert.Nil(t, data)
@ -129,9 +179,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadUint8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read int8", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i1"
data, err := adapter.ReadInt8(1)
assert.Nil(t, data)
@ -141,9 +191,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadInt8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read int16", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i2"
data, err := adapter.ReadInt16(1)
assert.Nil(t, data)
@ -153,9 +203,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadInt16(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read int32", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i4"
data, err := adapter.ReadInt32(1)
assert.Nil(t, data)
@ -165,9 +215,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadInt32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read int64", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "i8"
data, err := adapter.ReadInt64(1)
assert.Nil(t, data)
@ -177,9 +227,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadInt64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read float", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "f4"
data, err := adapter.ReadFloat32(1)
assert.Nil(t, data)
@ -189,9 +239,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadFloat32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
{
t.Run("test read double", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "f8"
data, err := adapter.ReadFloat64(1)
assert.Nil(t, data)
@ -201,7 +251,19 @@ func Test_NumpyAdapterReadError(t *testing.T) {
data, err = adapter.ReadFloat64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
})
t.Run("test read varchar", func(t *testing.T) {
adapter.npyReader.Header.Descr.Type = "U3"
data, err := adapter.ReadString(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadString(1)
assert.Nil(t, data)
assert.NotNil(t, err)
})
}
func Test_NumpyAdapterRead(t *testing.T) {
@ -209,7 +271,7 @@ func Test_NumpyAdapterRead(t *testing.T) {
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
{
t.Run("test read bool", func(t *testing.T) {
filePath := TempFilesPath + "bool.npy"
data := []bool{true, false, true, false}
err := CreateNumpyFile(filePath, data)
@ -267,9 +329,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
resf8, err := adapter.ReadFloat64(len(data))
assert.NotNil(t, err)
assert.Nil(t, resf8)
}
})
{
t.Run("test read uint8", func(t *testing.T) {
filePath := TempFilesPath + "uint8.npy"
data := []uint8{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -303,9 +365,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
resb, err := adapter.ReadBool(len(data))
assert.NotNil(t, err)
assert.Nil(t, resb)
}
})
{
t.Run("test read int8", func(t *testing.T) {
filePath := TempFilesPath + "int8.npy"
data := []int8{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -334,9 +396,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
res, err = adapter.ReadInt8(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
})
{
t.Run("test read int16", func(t *testing.T) {
filePath := TempFilesPath + "int16.npy"
data := []int16{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -365,9 +427,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
res, err = adapter.ReadInt16(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
})
{
t.Run("test read int32", func(t *testing.T) {
filePath := TempFilesPath + "int32.npy"
data := []int32{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -396,9 +458,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
res, err = adapter.ReadInt32(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
})
{
t.Run("test read int64", func(t *testing.T) {
filePath := TempFilesPath + "int64.npy"
data := []int64{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -427,9 +489,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
res, err = adapter.ReadInt64(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
})
{
t.Run("test read float", func(t *testing.T) {
filePath := TempFilesPath + "float.npy"
data := []float32{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -458,9 +520,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
res, err = adapter.ReadFloat32(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
})
{
t.Run("test read double", func(t *testing.T) {
filePath := TempFilesPath + "double.npy"
data := []float64{1, 2, 3, 4, 5, 6}
err := CreateNumpyFile(filePath, data)
@ -489,5 +551,52 @@ func Test_NumpyAdapterRead(t *testing.T) {
res, err = adapter.ReadFloat64(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("test read ascii characters", func(t *testing.T) {
filePath := TempFilesPath + "varchar1.npy"
data := []string{"a", "bbb", "c", "dd", "eeee", "fff"}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadString(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadString(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadString(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("test read non-ascii", func(t *testing.T) {
filePath := TempFilesPath + "varchar2.npy"
data := []string{"a三百", "马克bbb"}
err := CreateNumpyFile(filePath, data)
assert.Nil(t, err)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadString(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
})
}

View File

@ -19,12 +19,13 @@ package importutil
import (
"context"
"errors"
"fmt"
"io"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/storage"
"go.uber.org/zap"
)
type ColumnDesc struct {
@ -43,7 +44,7 @@ type NumpyParser struct {
callFlushFunc func(field storage.FieldData) error // call back function to output column data
}
// NewNumpyParser helper function to create a NumpyParser
// NewNumpyParser is helper function to create a NumpyParser
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
flushFunc func(field storage.FieldData) error) *NumpyParser {
if collectionSchema == nil || flushFunc == nil {
@ -60,38 +61,10 @@ func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSc
return parser
}
func (p *NumpyParser) logError(msg string) error {
log.Error(msg)
return errors.New(msg)
}
// data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector)
func convertNumpyType(str string) (schemapb.DataType, error) {
switch str {
case "b1", "<b1", "|b1", "bool":
return schemapb.DataType_Bool, nil
case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
return schemapb.DataType_BinaryVector, nil
case "i1", "<i1", "|i1", ">i1", "int8":
return schemapb.DataType_Int8, nil
case "i2", "<i2", "|i2", ">i2", "int16":
return schemapb.DataType_Int16, nil
case "i4", "<i4", "|i4", ">i4", "int32":
return schemapb.DataType_Int32, nil
case "i8", "<i8", "|i8", ">i8", "int64":
return schemapb.DataType_Int64, nil
case "f4", "<f4", "|f4", ">f4", "float32":
return schemapb.DataType_Float, nil
case "f8", "<f8", "|f8", ">f8", "float64":
return schemapb.DataType_Double, nil
default:
return schemapb.DataType_None, errors.New("unsupported data type " + str)
}
}
func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
if adapter == nil {
return errors.New("numpy adapter is nil")
log.Error("Numpy parser: numpy adapter is nil")
return errors.New("Numpy parser: numpy adapter is nil")
}
// check existence of the target field
@ -105,27 +78,32 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
}
if p.columnDesc.name == "" {
return errors.New("the field " + fieldName + " doesn't exist")
log.Error("Numpy parser: Numpy parser: the field is not found in collection schema", zap.String("fieldName", fieldName))
return fmt.Errorf("Numpy parser: the field name '%s' is not found in collection schema", fieldName)
}
p.columnDesc.dt = schema.DataType
elementType, err := convertNumpyType(adapter.GetType())
if err != nil {
return err
}
elementType := adapter.GetType()
shape := adapter.GetShape()
var err error
// 1. field data type should be consist to numpy data type
// 2. vector field dimension should be consist to numpy shape
if schemapb.DataType_FloatVector == schema.DataType {
// float32/float64 numpy file can be used for float vector file, 2 reasons:
// 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit
// 2. for float64 numpy file, the performance is worse than float32 numpy file
if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType),
zap.String("fieldName", fieldName))
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), schema.GetName())
}
// vector field, the shape should be 2
if len(shape) != 2 {
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)),
zap.String("fieldName", fieldName))
return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, schema.GetName())
}
// shape[0] is row count, shape[1] is element count per row
@ -137,16 +115,23 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
}
if shape[1] != p.columnDesc.dimension {
return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", p.columnDesc.dimension))
return fmt.Errorf("Numpy parser: illegal dimension %d of numpy file for float vector field '%s', dimension should be %d",
shape[1], schema.GetName(), p.columnDesc.dimension)
}
} else if schemapb.DataType_BinaryVector == schema.DataType {
if elementType != schemapb.DataType_BinaryVector {
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType),
zap.String("fieldName", fieldName))
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), schema.GetName())
}
// vector field, the shape should be 2
if len(shape) != 2 {
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)),
zap.String("fieldName", fieldName))
return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, schema.GetName())
}
// shape[0] is row count, shape[1] is element count per row
@ -158,16 +143,24 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
}
if shape[1] != p.columnDesc.dimension/8 {
return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", p.columnDesc.dimension))
return fmt.Errorf("Numpy parser: illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d",
shape[1]*8, schema.GetName(), p.columnDesc.dimension)
}
} else {
if elementType != schema.DataType {
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %d",
getTypeName(elementType), schema.GetName(), schema.DataType)
}
// scalar field, the shape should be 1
if len(shape) != 1 {
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)),
zap.String("fieldName", fieldName))
return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, schema.GetName())
}
p.columnDesc.elementCount = shape[0]
@ -176,13 +169,14 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
return nil
}
// this method read numpy data section into a storage.FieldData
// consume method reads numpy data section into a storage.FieldData
// please note it will require a large memory block(the memory size is almost equal to numpy file size)
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
switch p.columnDesc.dt {
case schemapb.DataType_Bool:
data, err := adapter.ReadBool(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -194,6 +188,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
case schemapb.DataType_Int8:
data, err := adapter.ReadInt8(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -204,6 +199,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
case schemapb.DataType_Int16:
data, err := adapter.ReadInt16(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -214,6 +210,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
case schemapb.DataType_Int32:
data, err := adapter.ReadInt32(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -224,6 +221,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
case schemapb.DataType_Int64:
data, err := adapter.ReadInt64(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -234,6 +232,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
case schemapb.DataType_Float:
data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -244,6 +243,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
case schemapb.DataType_Double:
data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -251,9 +251,21 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_VarChar:
data, err := adapter.ReadString(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
p.columnData = &storage.StringFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_BinaryVector:
data, err := adapter.ReadUint8(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -263,24 +275,24 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
Dim: p.columnDesc.dimension,
}
case schemapb.DataType_FloatVector:
// for float vector, we support float32 and float64 numpy file because python float value is 64 bit
// for float64 numpy file, the performance is worse than float32 numpy file
// we don't check overflow here
elementType, err := convertNumpyType(adapter.GetType())
if err != nil {
return err
}
// float32/float64 numpy file can be used for float vector file, 2 reasons:
// 1. for float vector, we support float32 and float64 numpy file because python float value is 64 bit
// 2. for float64 numpy file, the performance is worse than float32 numpy file
elementType := adapter.GetType()
var data []float32
var err error
if elementType == schemapb.DataType_Float {
data, err = adapter.ReadFloat32(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
} else if elementType == schemapb.DataType_Double {
data = make([]float32, 0, p.columnDesc.elementCount)
data64, err := adapter.ReadFloat64(p.columnDesc.elementCount)
if err != nil {
log.Error(err.Error())
return err
}
@ -295,7 +307,8 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
Dim: p.columnDesc.dimension,
}
default:
return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt)))
log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", p.columnDesc.dt), zap.String("fieldName", p.columnDesc.name))
return fmt.Errorf("Numpy parser: unsupported data type %s of field '%s'", getTypeName(p.columnDesc.dt), p.columnDesc.name)
}
return nil
@ -304,13 +317,13 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
adapter, err := NewNumpyAdapter(reader)
if err != nil {
return p.logError("Numpy parse: " + err.Error())
return err
}
// the validation method only check the file header information
err = p.validate(adapter, fieldName)
if err != nil {
return p.logError("Numpy parse: " + err.Error())
return err
}
if onlyValidate {
@ -320,7 +333,7 @@ func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate boo
// read all data from the numpy file
err = p.consume(adapter)
if err != nil {
return p.logError("Numpy parse: " + err.Error())
return err
}
return p.callFlushFunc(p.columnData)

View File

@ -36,28 +36,6 @@ func Test_NewNumpyParser(t *testing.T) {
assert.Nil(t, parser)
}
func Test_ConvertNumpyType(t *testing.T) {
checkFunc := func(inputs []string, output schemapb.DataType) {
for i := 0; i < len(inputs); i++ {
dt, err := convertNumpyType(inputs[i])
assert.Nil(t, err)
assert.Equal(t, output, dt)
}
}
checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
dt, err := convertNumpyType("dummy")
assert.NotNil(t, err)
assert.Equal(t, schemapb.DataType_None, dt)
}
func Test_NumpyParserValidate(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
@ -71,7 +49,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter := &NumpyAdapter{npyReader: &npy.Reader{}}
{
t.Run("not support DataType_String", func(t *testing.T) {
// string type is not supported
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
@ -88,15 +66,14 @@ func Test_NumpyParserValidate(t *testing.T) {
assert.NotNil(t, err)
err = p.validate(adapter, "field_string")
assert.NotNil(t, err)
}
})
// reader is nil
parser := NewNumpyParser(ctx, schema, flushFunc)
err = parser.validate(nil, "")
assert.NotNil(t, err)
// validate scalar data
func() {
t.Run("validate scalar", func(t *testing.T) {
filePath := TempFilesPath + "scalar_1.npy"
data1 := []float64{0, 1, 2, 3, 4, 5}
err := CreateNumpyFile(filePath, data1)
@ -108,6 +85,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_double")
assert.Nil(t, err)
@ -128,6 +106,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_double")
assert.NotNil(t, err)
@ -144,13 +123,13 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_double")
assert.NotNil(t, err)
}()
})
// validate binary vector data
func() {
t.Run("validate binary vector", func(t *testing.T) {
filePath := TempFilesPath + "binary_vector_1.npy"
data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
err := CreateNumpyFile(filePath, data1)
@ -162,6 +141,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_binary_vector")
assert.Nil(t, err)
@ -178,10 +158,8 @@ func Test_NumpyParserValidate(t *testing.T) {
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
assert.Nil(t, adapter)
// shape mismatch
filePath = TempFilesPath + "binary_vector_3.npy"
@ -195,6 +173,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
@ -211,6 +190,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file4)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
@ -228,10 +208,9 @@ func Test_NumpyParserValidate(t *testing.T) {
err = p.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
}()
})
// validate float vector data
func() {
t.Run("validate float vector", func(t *testing.T) {
filePath := TempFilesPath + "float_vector.npy"
data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
err := CreateNumpyFile(filePath, data1)
@ -243,6 +222,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_float_vector")
assert.Nil(t, err)
@ -260,6 +240,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
@ -276,6 +257,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
@ -292,6 +274,7 @@ func Test_NumpyParserValidate(t *testing.T) {
adapter, err = NewNumpyAdapter(file4)
assert.Nil(t, err)
assert.NotNil(t, adapter)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
@ -309,7 +292,7 @@ func Test_NumpyParserValidate(t *testing.T) {
err = p.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
}()
})
}
func Test_NumpyParserParse(t *testing.T) {
@ -355,153 +338,179 @@ func Test_NumpyParserParse(t *testing.T) {
}()
}
// scalar bool
data1 := []bool{true, false, true, false, true}
t.Run("parse scalar bool", func(t *testing.T) {
data := []bool{true, false, true, false, true}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data1), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data1); i++ {
assert.Equal(t, data1[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data1, "field_bool", flushFunc)
checkFunc(data, "field_bool", flushFunc)
})
// scalar int8
data2 := []int8{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar int8", func(t *testing.T) {
data := []int8{1, 2, 3, 4, 5}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data2), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data2); i++ {
assert.Equal(t, data2[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data2, "field_int8", flushFunc)
checkFunc(data, "field_int8", flushFunc)
})
// scalar int16
data3 := []int16{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar int16", func(t *testing.T) {
data := []int16{1, 2, 3, 4, 5}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data3), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data3); i++ {
assert.Equal(t, data3[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data3, "field_int16", flushFunc)
checkFunc(data, "field_int16", flushFunc)
})
// scalar int32
data4 := []int32{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar int32", func(t *testing.T) {
data := []int32{1, 2, 3, 4, 5}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data4), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data4); i++ {
assert.Equal(t, data4[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data4, "field_int32", flushFunc)
checkFunc(data, "field_int32", flushFunc)
})
// scalar int64
data5 := []int64{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar int64", func(t *testing.T) {
data := []int64{1, 2, 3, 4, 5}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data5), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data5); i++ {
assert.Equal(t, data5[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data5, "field_int64", flushFunc)
checkFunc(data, "field_int64", flushFunc)
})
// scalar float
data6 := []float32{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar float", func(t *testing.T) {
data := []float32{1, 2, 3, 4, 5}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data6), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data6); i++ {
assert.Equal(t, data6[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data6, "field_float", flushFunc)
checkFunc(data, "field_float", flushFunc)
})
// scalar double
data7 := []float64{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar double", func(t *testing.T) {
data := []float64{1, 2, 3, 4, 5}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data7), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data7); i++ {
assert.Equal(t, data7[i], field.GetRow(i))
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data7, "field_double", flushFunc)
checkFunc(data, "field_double", flushFunc)
})
// binary vector
data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
flushFunc = func(field storage.FieldData) error {
t.Run("parse scalar varchar", func(t *testing.T) {
data := []string{"abcd", "sdb", "ok", "milvus"}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data8), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data8); i++ {
for i := 0; i < len(data); i++ {
assert.Equal(t, data[i], field.GetRow(i))
}
return nil
}
checkFunc(data, "field_string", flushFunc)
})
t.Run("parse binary vector", func(t *testing.T) {
data := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data); i++ {
row := field.GetRow(i).([]uint8)
for k := 0; k < len(row); k++ {
assert.Equal(t, data8[i][k], row[k])
assert.Equal(t, data[i][k], row[k])
}
}
return nil
}
checkFunc(data8, "field_binary_vector", flushFunc)
checkFunc(data, "field_binary_vector", flushFunc)
})
// double vector(element can be float32 or float64)
data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc = func(field storage.FieldData) error {
t.Run("parse binary vector with float32", func(t *testing.T) {
data := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data9), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data9); i++ {
for i := 0; i < len(data); i++ {
row := field.GetRow(i).([]float32)
for k := 0; k < len(row); k++ {
assert.Equal(t, data9[i][k], row[k])
assert.Equal(t, data[i][k], row[k])
}
}
return nil
}
checkFunc(data9, "field_float_vector", flushFunc)
checkFunc(data, "field_float_vector", flushFunc)
})
data10 := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc = func(field storage.FieldData) error {
t.Run("parse binary vector with float64", func(t *testing.T) {
data := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data10), field.RowNum())
assert.Equal(t, len(data), field.RowNum())
for i := 0; i < len(data10); i++ {
for i := 0; i < len(data); i++ {
row := field.GetRow(i).([]float32)
for k := 0; k < len(row); k++ {
assert.Equal(t, float32(data10[i][k]), row[k])
assert.Equal(t, float32(data[i][k]), row[k])
}
}
return nil
}
checkFunc(data10, "field_float_vector", flushFunc)
checkFunc(data, "field_float_vector", flushFunc)
})
}
func Test_NumpyParserParse_perf(t *testing.T) {