mirror of https://github.com/milvus-io/milvus.git
Refine bulkload (#19671)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
Branch: pull/20137/head
parent def5972e01
commit bee66631e3
@@ -71,7 +71,7 @@ func (h *ServerHandler) GetDataVChanPositions(channel *channel, partitionID Uniq
continue
}
if s.GetIsImporting() {
// Skip bulk load segments.
// Skip bulk insert segments.
continue
}

@@ -149,7 +149,7 @@ func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID Uni
continue
}
if s.GetIsImporting() {
// Skip bulk load segments.
// Skip bulk insert segments.
continue
}
segmentInfos[s.GetID()] = s

@@ -68,7 +68,7 @@ func putAllocation(a *Allocation) {
type Manager interface {
// AllocSegment allocates rows and record the allocation.
AllocSegment(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error)
// allocSegmentForImport allocates one segment allocation for bulk load.
// allocSegmentForImport allocates one segment allocation for bulk insert.
// TODO: Remove this method and AllocSegment() above instead.
allocSegmentForImport(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64, taskID int64) (*Allocation, error)
// DropSegment drops the segment from manager.

@@ -278,7 +278,7 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID
return allocations, nil
}

// allocSegmentForImport allocates one segment allocation for bulk load.
// allocSegmentForImport allocates one segment allocation for bulk insert.
func (s *SegmentManager) allocSegmentForImport(ctx context.Context, collectionID UniqueID,
partitionID UniqueID, channelName string, requestRows int64, importTaskID int64) (*Allocation, error) {
// init allocation
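For orientation, here is a minimal sketch of how a datacoord caller might choose between the two allocation paths declared in the Manager interface above. The helper name and its control flow are assumptions for illustration only; just the two interface methods and their signatures come from the diff.

func allocateForRequest(ctx context.Context, mgr Manager, isImport bool,
	collID, partID UniqueID, channel string, rows int64, importTaskID int64) ([]*Allocation, error) {
	if isImport {
		// bulk insert path: a single allocation tied to the import task ID (used for segment locking)
		alloc, err := mgr.allocSegmentForImport(ctx, collID, partID, channel, rows, importTaskID)
		if err != nil {
			return nil, err
		}
		return []*Allocation{alloc}, nil
	}
	// regular insert path: may split the request across several allocations
	return mgr.AllocSegment(ctx, collID, partID, channel, rows)
}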
@@ -637,7 +637,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
if segment.State != commonpb.SegmentState_Flushed && segment.State != commonpb.SegmentState_Flushing && segment.State != commonpb.SegmentState_Dropped {
continue
}
// Also skip bulk load segments.
// Also skip bulk insert segments.
if segment.GetIsImporting() {
continue
}
@@ -1074,16 +1074,19 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)

// parse files and generate segments
segmentSize := int64(Params.DataCoordCfg.SegmentMaxSize) * 1024 * 1024
importWrapper := importutil.NewImportWrapper(newCtx, colInfo.GetSchema(), colInfo.GetShardsNum(), segmentSize, node.rowIDAllocator, node.chunkManager,
importFlushReqFunc(node, req, importResult, colInfo.GetSchema(), ts), importResult, reportFunc)
importWrapper := importutil.NewImportWrapper(newCtx, colInfo.GetSchema(), colInfo.GetShardsNum(), segmentSize, node.rowIDAllocator,
node.chunkManager, importResult, reportFunc)
importWrapper.SetCallbackFunctions(assignSegmentFunc(node, req),
createBinLogsFunc(node, req, colInfo.GetSchema(), ts),
saveSegmentFunc(node, req, importResult, ts))
// todo: pass tsStart and tsEnd after import_wrapper support
tsStart, tsEnd, err := importutil.ParseTSFromOptions(req.GetImportTask().GetInfos())
if err != nil {
return returnFailFunc(err)
}
log.Info("import time range", zap.Uint64("start_ts", tsStart), zap.Uint64("end_ts", tsEnd))
err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false)
//err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false, tsStart, tsEnd)
err = importWrapper.Import(req.GetImportTask().GetFiles(),
importutil.ImportOptions{OnlyValidate: false, TsStartPoint: tsStart, TsEndPoint: tsEnd})
if err != nil {
return returnFailFunc(err)
}
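To make the new wiring easier to follow, here is a hypothetical helper that restates the calling convention above in one place. It assumes NewImportWrapper returns *importutil.ImportWrapper and that the caller has already prepared the request, schema, result and timestamp exactly as DataNode.Import does; none of this helper exists in the commit itself.

func runBulkInsert(node *DataNode, req *datapb.ImportTaskRequest, wrapper *importutil.ImportWrapper,
	importResult *rootcoordpb.ImportResult, schema *schemapb.CollectionSchema, ts Timestamp) error {
	// the ts window now travels inside ImportOptions instead of extra Import() arguments
	tsStart, tsEnd, err := importutil.ParseTSFromOptions(req.GetImportTask().GetInfos())
	if err != nil {
		return err
	}
	// three callbacks replace the single importFlushReqFunc of the old code
	wrapper.SetCallbackFunctions(
		assignSegmentFunc(node, req),
		createBinLogsFunc(node, req, schema, ts),
		saveSegmentFunc(node, req, importResult, ts))
	return wrapper.Import(req.GetImportTask().GetFiles(),
		importutil.ImportOptions{OnlyValidate: false, TsStartPoint: tsStart, TsEndPoint: tsEnd})
}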
@@ -1183,8 +1186,8 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor
}, nil
}

func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, schema *schemapb.CollectionSchema, ts Timestamp) importutil.ImportFlushFunc {
return func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil.AssignSegmentFunc {
return func(shardID int) (int64, string, error) {
chNames := req.GetImportTask().GetChannelNames()
importTaskID := req.GetImportTask().GetTaskId()
if shardID >= len(chNames) {

@@ -1194,53 +1197,91 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
zap.Int("# of channels", len(chNames)),
zap.Strings("channel names", chNames),
)
return fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID)
return 0, "", fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID)
}

tr := timerecord.NewTimeRecorder("import callback function")
defer tr.Elapse("finished")

colID := req.GetImportTask().GetCollectionId()
partID := req.GetImportTask().GetPartitionId()
segmentIDReq := composeAssignSegmentIDRequest(1, shardID, chNames, colID, partID)
targetChName := segmentIDReq.GetSegmentIDRequests()[0].GetChannelName()
log.Info("target channel for the import task",
zap.Int64("task ID", importTaskID),
zap.Int("shard ID", shardID),
zap.String("target channel name", targetChName))
resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq)
if err != nil {
return 0, "", fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return 0, "", fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason)
}
segmentID := resp.SegIDAssignments[0].SegID
log.Info("new segment assigned",
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID),
zap.Int("shard ID", shardID),
zap.String("target channel name", targetChName))
return segmentID, targetChName, nil
}
}

func createBinLogsFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *schemapb.CollectionSchema, ts Timestamp) importutil.CreateBinlogsFunc {
return func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) {
var rowNum int
for _, field := range fields {
rowNum = field.RowNum()
break
}

chNames := req.GetImportTask().GetChannelNames()
importTaskID := req.GetImportTask().GetTaskId()
if rowNum <= 0 {
log.Info("fields data is empty, no need to generate segment",
log.Info("fields data is empty, no need to generate binlog",
zap.Int64("task ID", importTaskID),
zap.Int("shard ID", shardID),
zap.Int("# of channels", len(chNames)),
zap.Strings("channel names", chNames),
)
return nil
return nil, nil, nil
}

colID := req.GetImportTask().GetCollectionId()
partID := req.GetImportTask().GetPartitionId()
segmentIDReq := composeAssignSegmentIDRequest(rowNum, shardID, chNames, colID, partID)
targetChName := segmentIDReq.GetSegmentIDRequests()[0].GetChannelName()
log.Info("target channel for the import task",
zap.Int64("task ID", importTaskID),
zap.String("target channel name", targetChName))
resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq)
if err != nil {
return fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason)
}
segmentID := resp.SegIDAssignments[0].SegID

fieldInsert, fieldStats, err := createBinLogs(rowNum, schema, ts, fields, node, segmentID, colID, partID)
if err != nil {
return err
log.Error("failed to create binlogs",
zap.Int64("task ID", importTaskID),
zap.Int("# of channels", len(chNames)),
zap.Strings("channel names", chNames),
zap.Any("err", err),
)
return nil, nil, err
}

log.Info("new binlog created",
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID),
zap.Int("insert log count", len(fieldInsert)),
zap.Int("stats log count", len(fieldStats)))

return fieldInsert, fieldStats, err
}
}

func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, ts Timestamp) importutil.SaveSegmentFunc {
importTaskID := req.GetImportTask().GetTaskId()
return func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error {
log.Info("adding segment to the correct DataNode flow graph and saving binlog paths",
zap.Int64("segment ID", segmentID),
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID),
zap.String("targetChName", targetChName),
zap.Int64("rowCount", rowCount),
zap.Uint64("ts", ts))
err = retry.Do(context.Background(), func() error {

err := retry.Do(context.Background(), func() error {
// Ask DataCoord to save binlog path and add segment to the corresponding DataNode flow graph.
resp, err := node.dataCoord.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{
Base: commonpbutil.NewMsgBase(

@@ -1251,7 +1292,7 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
ChannelName: targetChName,
CollectionId: req.GetImportTask().GetCollectionId(),
PartitionId: req.GetImportTask().GetPartitionId(),
RowNum: int64(rowNum),
RowNum: rowCount,
SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(0),

@@ -1261,8 +1302,8 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
),
SegmentID: segmentID,
CollectionID: req.GetImportTask().GetCollectionId(),
Field2BinlogPaths: fieldInsert,
Field2StatslogPaths: fieldStats,
Field2BinlogPaths: fieldsInsert,
Field2StatslogPaths: fieldsStats,
// Set start positions of a SaveBinlogPathRequest explicitly.
StartPositions: []*datapb.SegmentStartPosition{
{

@@ -1291,9 +1332,11 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root
log.Warn("failed to save import segment", zap.Error(err))
return err
}
log.Info("segment imported and persisted", zap.Int64("segmentID", segmentID))
log.Info("segment imported and persisted",
zap.Int64("task ID", importTaskID),
zap.Int64("segmentID", segmentID))
res.Segments = append(res.Segments, segmentID)
res.RowCount += int64(rowNum)
res.RowCount += rowCount
return nil
}
}
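As a reading aid, the sketch below shows how the import wrapper is expected to drive the three callbacks for one shard. The callback signatures are taken from the functions above; the driver itself is hypothetical, since the wrapper's real control flow is not part of this diff.

func flushOneShard(assign importutil.AssignSegmentFunc, create importutil.CreateBinlogsFunc,
	save importutil.SaveSegmentFunc, fields map[storage.FieldID]storage.FieldData,
	shardID int, rowCount int64) error {
	segmentID, channelName, err := assign(shardID) // ask DataCoord for a segment in the shard's channel
	if err != nil {
		return err
	}
	inserts, stats, err := create(fields, segmentID) // encode the accumulated field data into binlogs
	if err != nil {
		return err
	}
	return save(inserts, stats, segmentID, channelName, rowCount) // persist binlog paths and register the segment
}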
@@ -38,7 +38,6 @@ import (
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"

@@ -603,30 +602,6 @@ func TestDataNode(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.GetErrorCode())
})

t.Run("Test Import callback func error", func(t *testing.T) {
req := &datapb.ImportTaskRequest{
ImportTask: &datapb.ImportTask{
CollectionId: 100,
PartitionId: 100,
ChannelNames: []string{"ch1", "ch2"},
},
}
importResult := &rootcoordpb.ImportResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
TaskId: 0,
DatanodeId: 0,
State: commonpb.ImportState_ImportStarted,
Segments: make([]int64, 0),
AutoIds: make([]int64, 0),
RowCount: 0,
}
callback := importFlushReqFunc(node, req, importResult, nil, 0)
err := callback(nil, len(req.ImportTask.ChannelNames)+1)
assert.Error(t, err)
})

t.Run("Test BackGroundGC", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
@@ -114,7 +114,7 @@ message SegmentIDRequest {
string channel_name = 2;
int64 collectionID = 3;
int64 partitionID = 4;
bool isImport = 5; // Indicate whether this request comes from a bulk load task.
bool isImport = 5; // Indicate whether this request comes from a bulk insert task.
int64 importTaskID = 6; // Needed for segment lock.
}

@@ -269,8 +269,8 @@ message SegmentInfo {
repeated int64 compactionFrom = 15;
uint64 dropped_at = 16; // timestamp when segment marked drop
// A flag indicating if:
// (1) this segment is created by bulk load, and
// (2) the bulk load task that creates this segment has not yet reached `ImportCompleted` state.
// (1) this segment is created by bulk insert, and
// (2) the bulk insert task that creates this segment has not yet reached `ImportCompleted` state.
bool is_importing = 17;
bool is_fake = 18;
}
@@ -1598,8 +1598,8 @@ type SegmentInfo struct {
CompactionFrom []int64 `protobuf:"varint,15,rep,packed,name=compactionFrom,proto3" json:"compactionFrom,omitempty"`
DroppedAt uint64 `protobuf:"varint,16,opt,name=dropped_at,json=droppedAt,proto3" json:"dropped_at,omitempty"`
// A flag indicating if:
// (1) this segment is created by bulk load, and
// (2) the bulk load task that creates this segment has not yet reached `ImportCompleted` state.
// (1) this segment is created by bulk insert, and
// (2) the bulk insert task that creates this segment has not yet reached `ImportCompleted` state.
IsImporting bool `protobuf:"varint,17,opt,name=is_importing,json=isImporting,proto3" json:"is_importing,omitempty"`
IsFake bool `protobuf:"varint,18,opt,name=is_fake,json=isFake,proto3" json:"is_fake,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
@@ -3967,7 +3967,8 @@ func unhealthyStatus() *commonpb.Status {
func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
log.Info("received import request",
zap.String("collection name", req.GetCollectionName()),
zap.Bool("row-based", req.GetRowBased()))
zap.String("partition name", req.GetPartitionName()),
zap.Strings("files", req.GetFiles()))
resp := &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,

@@ -3996,7 +3997,7 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi
respFromRC, err := node.rootCoord.Import(ctx, req)
if err != nil {
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.FailLabel).Inc()
log.Error("failed to execute bulk load request", zap.Error(err))
log.Error("failed to execute bulk insert request", zap.Error(err))
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
resp.Status.Reason = err.Error()
return resp, nil
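Illustration only: what a bulk insert request handled by Proxy.Import looks like after this change. Whether the job is row-based is now inferred from the file extensions (.json means row-based, .npy means column-based), so no RowBased flag is set. File names and the proxyClient handle are placeholders, not identifiers from the commit.

req := &milvuspb.ImportRequest{
	CollectionName: "my_collection",
	PartitionName:  "my_partition",
	Files:          []string{"data_1.json", "data_2.json"},
}
resp, err := proxyClient.Import(ctx, req) // proxyClient stands in for whatever handle issues the call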
@@ -1698,7 +1698,7 @@ func TestProxy(t *testing.T) {
defer wg.Done()
req := &milvuspb.ImportRequest{
CollectionName: collectionName,
Files: []string{"f1.json", "f2.json", "f3.csv"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
resp, err := proxy.Import(context.TODO(), req)
@@ -34,7 +34,9 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/importutil"
"github.com/milvus-io/milvus/internal/util/typeutil"

"go.uber.org/zap"
)

@@ -210,7 +212,6 @@ func (m *importManager) sendOutTasks(ctx context.Context) error {
CollectionId: task.GetCollectionId(),
PartitionId: task.GetPartitionId(),
ChannelNames: task.GetChannelNames(),
RowBased: task.GetRowBased(),
TaskId: task.GetId(),
Files: task.GetFiles(),
Infos: task.GetInfos(),

@@ -381,25 +382,39 @@ func (m *importManager) checkIndexingDone(ctx context.Context, collID UniqueID,
return len(allSegmentIDs) == indexedSegmentCount, nil
}

func (m *importManager) isRowbased(files []string) (bool, error) {
isRowBased := false
for _, filePath := range files {
_, fileType := importutil.GetFileNameAndExt(filePath)
if fileType == importutil.JSONFileExt {
isRowBased = true
} else if isRowBased {
log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files))
return isRowBased, fmt.Errorf("row-based data file type must be JSON, file type '%s' is not allowed", fileType)
}
}

return isRowBased, nil
}
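The helper above replaces everything downstream that used to depend on the request's RowBased flag. A quick sketch of the expected outcomes, mirroring the new unit test added later in this commit (the mgr value is assumed to be an *importManager):

mgr := &importManager{}
rb, err := mgr.isRowbased([]string{"1.json", "2.json"}) // rb == true,  err == nil  (row-based)
rb, err = mgr.isRowbased([]string{"1.npy", "2.npy"})    // rb == false, err == nil  (column-based)
rb, err = mgr.isRowbased([]string{"1.json", "2.npy"})   // err != nil: mixed file types are rejected
_ = rb
_ = err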
// importJob processes the import request, generates import tasks, sends these tasks to DataCoord, and returns
// immediately.
func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportRequest, cID int64, pID int64) *milvuspb.ImportResponse {
if req == nil || len(req.Files) == 0 {
returnErrorFunc := func(reason string) *milvuspb.ImportResponse {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "import request is empty",
Reason: reason,
},
}
}

if req == nil || len(req.Files) == 0 {
return returnErrorFunc("import request is empty")
}

if m.callImportService == nil {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "import service is not available",
},
}
return returnErrorFunc("import service is not available")
}

resp := &milvuspb.ImportResponse{

@@ -409,7 +424,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
Tasks: make([]int64, 0),
}

log.Debug("request received",
log.Debug("receive import job",
zap.String("collection name", req.GetCollectionName()),
zap.Int64("collection ID", cID),
zap.Int64("partition ID", pID))

@@ -420,8 +435,13 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
capacity := cap(m.pendingTasks)
length := len(m.pendingTasks)

isRowBased, err := m.isRowbased(req.GetFiles())
if err != nil {
return err
}

taskCount := 1
if req.RowBased {
if isRowBased {
taskCount = len(req.Files)
}
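Worked example of the task-splitting rule above (illustration only, request values are placeholders taken from the test style used later in this commit): a row-based job gets one task per file, while a column-based job keeps all files in a single task.

// three row-based files -> three pending tasks, one file each
rowReq := &milvuspb.ImportRequest{CollectionName: "c1", PartitionName: "p1",
	Files: []string{"f1.json", "f2.json", "f3.json"}}
// two column-based files -> a single pending task holding both files
colReq := &milvuspb.ImportRequest{CollectionName: "c1", PartitionName: "p1",
	Files: []string{"f1.npy", "f2.npy"}}
_ = rowReq
_ = colReq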
@@ -433,7 +453,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
}

// convert import request to import tasks
if req.RowBased {
if isRowBased {
// For row-based importing, each file makes a task.
taskList := make([]int64, len(req.Files))
for i := 0; i < len(req.Files); i++ {

@@ -446,7 +466,6 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
CollectionId: cID,
PartitionId: pID,
ChannelNames: req.ChannelNames,
RowBased: req.GetRowBased(),
Files: []string{req.GetFiles()[i]},
CreateTs: time.Now().Unix(),
State: &datapb.ImportTaskState{

@@ -485,7 +504,6 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
CollectionId: cID,
PartitionId: pID,
ChannelNames: req.ChannelNames,
RowBased: req.GetRowBased(),
Files: req.GetFiles(),
CreateTs: time.Now().Unix(),
State: &datapb.ImportTaskState{

@@ -514,12 +532,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
return nil
}()
if err != nil {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}
return returnErrorFunc(err.Error())
}
if sendOutTasksErr := m.sendOutTasks(ctx); sendOutTasksErr != nil {
log.Error("fail to send out tasks", zap.Error(sendOutTasksErr))

@@ -1054,7 +1067,6 @@ func cloneImportTaskInfo(taskInfo *datapb.ImportTaskInfo) *datapb.ImportTaskInfo
CollectionId: taskInfo.GetCollectionId(),
PartitionId: taskInfo.GetPartitionId(),
ChannelNames: taskInfo.GetChannelNames(),
RowBased: taskInfo.GetRowBased(),
Files: taskInfo.GetFiles(),
CreateTs: taskInfo.GetCreateTs(),
State: taskInfo.GetState(),
@@ -574,6 +574,8 @@ func TestImportManager_ImportJob(t *testing.T) {
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}

// nil request
mgr := newImportManager(context.TODO(), mockKv, idAlloc, nil, callMarkSegmentsDropped, nil, nil, nil, nil)
resp := mgr.importJob(context.TODO(), nil, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)

@@ -581,27 +583,14 @@ func TestImportManager_ImportJob(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}

// nil callImportService
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)

colReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: false,
Files: []string{"f1", "f2"},
Options: []*commonpb.KeyValuePair{
{
Key: importutil.Bucket,
Value: "mybucket",
},
},
}

fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
importServiceFunc := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
return &datapb.ImportTaskResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,

@@ -609,17 +598,27 @@ func TestImportManager_ImportJob(t *testing.T) {
}, nil
}

mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// row-based case, task count equals file count
// since the importServiceFunc returns an error, tasks will be kept in the pending list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks))
assert.Equal(t, 0, len(mgr.workingTasks))

mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
colReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
Files: []string{"f1.npy", "f2.npy", "f3.npy"},
}

// column-based case, one request creates one task
// since the importServiceFunc returns an error, tasks will be kept in the pending list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), colReq, colID, 0)
assert.Equal(t, 1, len(mgr.pendingTasks))
assert.Equal(t, 0, len(mgr.workingTasks))

fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
return &datapb.ImportTaskResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,

@@ -627,18 +626,20 @@ func TestImportManager_ImportJob(t *testing.T) {
}, nil
}

mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// row-based case, since the importServiceFunc returns success, tasks will be sent to the working list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, 0, len(mgr.pendingTasks))
assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks))

mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// column-based case, since the importServiceFunc returns success, tasks will be sent to the working list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), colReq, colID, 0)
assert.Equal(t, 0, len(mgr.pendingTasks))
assert.Equal(t, 1, len(mgr.workingTasks))

count := 0
fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
if count >= 2 {
return &datapb.ImportTaskResponse{
Status: &commonpb.Status{

@@ -654,13 +655,16 @@ func TestImportManager_ImportJob(t *testing.T) {
}, nil
}

mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callMarkSegmentsDropped, nil, nil, nil, nil)
// row-based case, since the importServiceFunc returns success for 2 tasks
// the 2 tasks are sent to the working list, and 1 task is left in the pending list
mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callMarkSegmentsDropped, nil, nil, nil, nil)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, len(rowReq.Files)-2, len(mgr.pendingTasks))
assert.Equal(t, 2, len(mgr.workingTasks))

for i := 0; i <= 32; i++ {
rowReq.Files = append(rowReq.Files, strconv.Itoa(i))
// files count exceeds MaxPendingCount, return error
for i := 0; i <= MaxPendingCount; i++ {
rowReq.Files = append(rowReq.Files, strconv.Itoa(i)+".json")
}
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)

@@ -683,14 +687,12 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
colReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: false,
Files: []string{"f1", "f2"},
Files: []string{"f1.npy", "f2.npy"},
Options: []*commonpb.KeyValuePair{
{
Key: importutil.Bucket,

@@ -778,8 +780,7 @@ func TestImportManager_TaskState(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}

callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) {

@@ -817,8 +818,7 @@ func TestImportManager_TaskState(t *testing.T) {
assert.Equal(t, int64(2), ti.GetId())
assert.Equal(t, int64(100), ti.GetCollectionId())
assert.Equal(t, int64(0), ti.GetPartitionId())
assert.Equal(t, true, ti.GetRowBased())
assert.Equal(t, []string{"f2"}, ti.GetFiles())
assert.Equal(t, []string{"f2.json"}, ti.GetFiles())
assert.Equal(t, commonpb.ImportState_ImportPersisted, ti.GetState().GetStateCode())
assert.Equal(t, int64(1000), ti.GetState().GetRowCount())

@@ -876,8 +876,7 @@ func TestImportManager_AllocFail(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}

callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) {

@@ -917,8 +916,7 @@ func TestImportManager_ListAllTasks(t *testing.T) {
rowReq := &milvuspb.ImportRequest{
CollectionName: "c1",
PartitionName: "p1",
RowBased: true,
Files: []string{"f1", "f2", "f3"},
Files: []string{"f1.json", "f2.json", "f3.json"},
}
callMarkSegmentsDropped := func(ctx context.Context, segIDs []typeutil.UniqueID) (*commonpb.Status, error) {
return &commonpb.Status{

@@ -1007,3 +1005,22 @@ func TestImportManager_rearrangeTasks(t *testing.T) {
assert.Equal(t, int64(50), tasks[1].GetId())
assert.Equal(t, int64(100), tasks[2].GetId())
}

func TestImportManager_isRowbased(t *testing.T) {
mgr := &importManager{}

files := []string{"1.json", "2.json"}
rb, err := mgr.isRowbased(files)
assert.Nil(t, err)
assert.True(t, rb)

files = []string{"1.json", "2.npy"}
rb, err = mgr.isRowbased(files)
assert.NotNil(t, err)
assert.True(t, rb)

files = []string{"1.npy", "2.npy"}
rb, err = mgr.isRowbased(files)
assert.Nil(t, err)
assert.False(t, rb)
}
@@ -86,8 +86,8 @@ type IMetaTable interface {
ListAliasesByID(collID UniqueID) []string

// TODO: better to accept ctx.
GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) // serve for bulk load.
GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) // serve for bulk load.
GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) // serve for bulk insert.
GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) // serve for bulk insert.

// TODO: better to accept ctx.
AddCredential(credInfo *internalpb.CredentialInfo) error

@@ -634,7 +634,7 @@ func (mt *MetaTable) ListAliasesByID(collID UniqueID) []string {
return mt.listAliasesByID(collID)
}

// GetCollectionNameByID serve for bulk load. TODO: why this didn't accept ts?
// GetCollectionNameByID serve for bulk insert. TODO: why this didn't accept ts?
// [Deprecated]
func (mt *MetaTable) GetCollectionNameByID(collID UniqueID) (string, error) {
mt.ddLock.RLock()

@@ -648,7 +648,7 @@ func (mt *MetaTable) GetCollectionNameByID(collID UniqueID) (string, error) {
return coll.Name, nil
}

// GetPartitionNameByID serve for bulk load.
// GetPartitionNameByID serve for bulk insert.
func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID, ts Timestamp) (string, error) {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()

@@ -680,7 +680,7 @@ func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID,
return "", fmt.Errorf("partition not exist: %d", partitionID)
}

// GetCollectionIDByName serve for bulk load. TODO: why this didn't accept ts?
// GetCollectionIDByName serve for bulk insert. TODO: why this didn't accept ts?
// [Deprecated]
func (mt *MetaTable) GetCollectionIDByName(name string) (UniqueID, error) {
mt.ddLock.RLock()

@@ -693,7 +693,7 @@ func (mt *MetaTable) GetCollectionIDByName(name string) (UniqueID, error) {
return id, nil
}

// GetPartitionByName serve for bulk load.
// GetPartitionByName serve for bulk insert.
func (mt *MetaTable) GetPartitionByName(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()
@@ -1594,7 +1594,6 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus
zap.Strings("virtual channel names", req.GetChannelNames()),
zap.Int64("partition ID", pID),
zap.Int("# of files = ", len(req.GetFiles())),
zap.Bool("row-based", req.GetRowBased()),
)
importJobResp := c.importManager.importJob(ctx, req, cID, pID)
return importJobResp, nil

@@ -1692,7 +1691,7 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) (
resendTaskFunc()
// Flush all import data segments.
if err := c.broker.Flush(ctx, ti.GetCollectionId(), ir.GetSegments()); err != nil {
log.Error("failed to call Flush on bulk load segments",
log.Error("failed to call Flush on bulk insert segments",
zap.Int64("task ID", ir.GetTaskId()))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -17,6 +17,7 @@
package importutil

import (
"context"
"encoding/json"
"errors"
"fmt"

@@ -43,29 +44,39 @@ type SegmentFilesHolder struct {
// 1. read insert log of each field, then constructs map[storage.FieldID]storage.FieldData in memory.
// 2. read delta log to remove deleted entities(TimeStampField is used to apply or skip the operation).
// 3. split data according to shard number
// 4. call the callFlushFunc function to flush data into new segment if data size reaches segmentSize.
// 4. call the callFlushFunc function to flush data into new binlog file if data size reaches blockSize.
type BinlogAdapter struct {
ctx context.Context // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
chunkManager storage.ChunkManager // storage interfaces to read binlog files
callFlushFunc ImportFlushFunc // call back function to flush segment
shardNum int32 // sharding number of the collection
segmentSize int64 // maximum size of a segment(unit:byte)
blockSize int64 // maximum size of a read block(unit:byte)
maxTotalSize int64 // maximum size of in-memory segments(unit:byte)
primaryKey storage.FieldID // id of primary key
primaryType schemapb.DataType // data type of primary key

// a timestamp to define the end point of restore, data after this point will be ignored
// a timestamp to define the start time point of restore, data before this time point will be ignored
// set this value to 0, all the data will be imported
// set this value to math.MaxUint64, all the data will be ignored
// the tsStartPoint value must be less than or equal to tsEndPoint
tsStartPoint uint64

// a timestamp to define the end time point of restore, data after this time point will be ignored
// set this value to 0, all the data will be ignored
// set this value to math.MaxUint64, all the data will be imported
// the tsEndPoint value must be greater than or equal to tsStartPoint
tsEndPoint uint64
}

func NewBinlogAdapter(collectionSchema *schemapb.CollectionSchema,
func NewBinlogAdapter(ctx context.Context,
collectionSchema *schemapb.CollectionSchema,
shardNum int32,
segmentSize int64,
blockSize int64,
maxTotalSize int64,
chunkManager storage.ChunkManager,
flushFunc ImportFlushFunc,
tsStartPoint uint64,
tsEndPoint uint64) (*BinlogAdapter, error) {
if collectionSchema == nil {
log.Error("Binlog adapter: collection schema is nil")
@@ -83,18 +94,20 @@ func NewBinlogAdapter(collectionSchema *schemapb.CollectionSchema,
}

adapter := &BinlogAdapter{
ctx: ctx,
collectionSchema: collectionSchema,
chunkManager: chunkManager,
callFlushFunc: flushFunc,
shardNum: shardNum,
segmentSize: segmentSize,
blockSize: blockSize,
maxTotalSize: maxTotalSize,
tsStartPoint: tsStartPoint,
tsEndPoint: tsEndPoint,
}

// amend the segment size to avoid potential OOM risk
if adapter.segmentSize > MaxSegmentSizeInMemory {
adapter.segmentSize = MaxSegmentSizeInMemory
if adapter.blockSize > MaxSegmentSizeInMemory {
adapter.blockSize = MaxSegmentSizeInMemory
}

// find out the primary key ID and its data type
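For reference, here is the updated constructor call in the form exercised by the tests further down: the ts window defaults to the full range (0 to math.MaxUint64) so that every row and delete record is kept. Variable names and the byte sizes are arbitrary placeholders for illustration.

flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
	// receives one block of a shard's data; a real callback would encode it into binlogs
	return nil
}
adapter, err := NewBinlogAdapter(ctx, collectionSchema,
	2,             // shardNum
	16*1024*1024,  // blockSize: flush a shard's in-memory block once it exceeds this many bytes
	256*1024*1024, // maxTotalSize: flush the biggest block when all blocks together exceed this
	chunkManager, flushFunc,
	0, math.MaxUint64) // tsStartPoint, tsEndPoint: keep everything
if err != nil {
	return err
}
err = adapter.Read(segmentHolder)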
@@ -142,7 +155,7 @@ func (p *BinlogAdapter) Read(segmentHolder *SegmentFilesHolder) error {
// b has these binlog files: b_1, b_2, b_3 ...
// Then first round read a_1 and b_1, second round read a_2 and b_2, etc...
// deleted list will be used to remove deleted entities
// if accumulate data exceed segmentSize, call callFlushFunc to generate new segment
// if accumulate data exceed blockSize, call callFlushFunc to generate new binlog file
batchCount := 0
for _, files := range segmentHolder.fieldFiles {
batchCount = len(files)

@@ -215,24 +228,30 @@ func (p *BinlogAdapter) Read(segmentHolder *SegmentFilesHolder) error {

// read other insert logs and use the shardList to do sharding
for fieldID, file := range batchFiles {
// outside context might be canceled(service stop, or future enhancement for canceling import task)
if isCanceled(p.ctx) {
log.Error("Binlog adapter: import task was canceled")
return errors.New("import task was canceled")
}

err = p.readInsertlog(fieldID, file, segmentsData, shardList)
if err != nil {
return err
}
}

// flush segment whose size exceed segmentSize
err = p.tryFlushSegments(segmentsData, false)
// flush segment whose size exceed blockSize
err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.callFlushFunc, p.blockSize, p.maxTotalSize, false)
if err != nil {
return err
}
}

// finally, force to flush
return p.tryFlushSegments(segmentsData, true)
return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.callFlushFunc, p.blockSize, p.maxTotalSize, true)
}

// This method verify the schema and binlog files
// verify method verifies the schema and binlog files
// 1. each field must have binlog file
// 2. binlog file count of each field must be equal
// 3. the collectionSchema doesn't contain TimeStampField and RowIDField since the import_wrapper excludes them,

@@ -284,7 +303,7 @@ func (p *BinlogAdapter) verify(segmentHolder *SegmentFilesHolder) error {
return nil
}

// This method read data from deltalog, and convert to a dict
// readDeltalogs method reads data from deltalog, and converts it to a dict
// The deltalog data is a list, to improve performance of next step, we convert it to a dict,
// key is the deleted ID, value is operation timestamp which is used to apply or skip the delete operation.
func (p *BinlogAdapter) readDeltalogs(segmentHolder *SegmentFilesHolder) (map[int64]uint64, map[string]uint64, error) {

@@ -318,7 +337,7 @@ func (p *BinlogAdapter) readDeltalogs(segmentHolder *SegmentFilesHolder) (map[in
}
}

// Decode string array(read from delta log) to storage.DeleteLog array
// decodeDeleteLogs decodes string array(read from delta log) to storage.DeleteLog array
func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*storage.DeleteLog, error) {
// step 1: read all delta logs to construct a string array, each string is marshaled from storage.DeleteLog
stringArray := make([]string, 0)

@@ -345,8 +364,9 @@ func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*
return nil, err
}

// ignore deletions whose timestamp is larger than the tsEndPoint
if deleteLog.Ts <= p.tsEndPoint {
// only the ts between tsStartPoint and tsEndPoint is effective
// ignore deletions whose timestamp is larger than the tsEndPoint or less than tsStartPoint
if deleteLog.Ts >= p.tsStartPoint && deleteLog.Ts <= p.tsEndPoint {
deleteLogs = append(deleteLogs, deleteLog)
}
}
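The same inclusive window is applied to deletions here and, below, to inserted rows. A minimal sketch of the predicate (illustration only; this helper does not exist in the code):

// keep an operation only if its timestamp falls inside [tsStartPoint, tsEndPoint]
func insideTsWindow(ts, tsStartPoint, tsEndPoint uint64) bool {
	return ts >= tsStartPoint && ts <= tsEndPoint
}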
@@ -365,7 +385,7 @@ func (p *BinlogAdapter) decodeDeleteLogs(segmentHolder *SegmentFilesHolder) ([]*
return deleteLogs, nil
}

// Decode a string to storage.DeleteLog
// decodeDeleteLog decodes a string to storage.DeleteLog
// Note: the following code mainly comes from data_codec.go; it should be compatible with old version 2.0
func (p *BinlogAdapter) decodeDeleteLog(deltaStr string) (*storage.DeleteLog, error) {
deleteLog := &storage.DeleteLog{}

@@ -399,7 +419,7 @@ func (p *BinlogAdapter) decodeDeleteLog(deltaStr string) (*storage.DeleteLog, er
return deleteLog, nil
}

// Each delta log data type is varchar, marshaled from an array of storage.DeleteLog objects.
// readDeltalog parses a delta log file. Each delta log data type is varchar, marshaled from an array of storage.DeleteLog objects.
func (p *BinlogAdapter) readDeltalog(logPath string) ([]string, error) {
// open the delta log file
binlogFile, err := NewBinlogFile(p.chunkManager)

@@ -426,7 +446,7 @@ func (p *BinlogAdapter) readDeltalog(logPath string) ([]string, error) {
return data, nil
}

// This method read data from int64 field, currently we use it to read the timestamp field.
// readTimestamp method reads data from int64 field, currently we use it to read the timestamp field.
func (p *BinlogAdapter) readTimestamp(logPath string) ([]int64, error) {
// open the log file
binlogFile, err := NewBinlogFile(p.chunkManager)

@@ -454,7 +474,7 @@ func (p *BinlogAdapter) readTimestamp(logPath string) ([]int64, error) {
return int64List, nil
}

// This method read primary keys from insert log.
// readPrimaryKeys method reads primary keys from insert log.
func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, error) {
// open the delta log file
binlogFile, err := NewBinlogFile(p.chunkManager)

@@ -493,7 +513,7 @@ func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, erro
}
}

// This method generate a shard id list by primary key(int64) list and deleted list.
// getShardingListByPrimaryInt64 method generates a shard id list by primary key(int64) list and deleted list.
// For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be:
// [0, 1, -1, 1, 0, 1, -1, 1, 0, 1]
// Compare timestampList with tsEndPoint to skip some rows.

@@ -513,10 +533,10 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64,
excluded := 0
shardList := make([]int32, 0, len(primaryKeys))
for i, key := range primaryKeys {
// if this entity's timestamp is greater than the tsEndPoint, set shardID = -1 to skip this entity
// if this entity's timestamp is greater than the tsEndPoint, or less than tsStartPoint, set shardID = -1 to skip this entity
// timestamp is stored as int64 type in log file, actually it is uint64, compare with uint64
ts := timestampList[i]
if uint64(ts) > p.tsEndPoint {
if uint64(ts) > p.tsEndPoint || uint64(ts) < p.tsStartPoint {
shardList = append(shardList, -1)
excluded++
continue
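A hypothetical sketch of the sharding rule described above. The real code hashes the primary key with the project's own hash function and also compares the row timestamp against the matching delete-log timestamp; here both are simplified (plain modulo, presence check only) purely to illustrate the -1 "skip this row" convention.

func shardOf(pk int64, shardNum int32, ts uint64, deleted map[int64]uint64, tsStart, tsEnd uint64) int32 {
	if ts < tsStart || ts > tsEnd {
		return -1 // outside the restore window: skip the row
	}
	if _, ok := deleted[pk]; ok {
		return -1 // row was removed by a delta log entry: skip it
	}
	return int32(uint64(pk) % uint64(shardNum)) // placeholder for the real hash
}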
@@ -546,7 +566,7 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64,
return shardList, nil
}

// This method generate a shard id list by primary key(varchar) list and deleted list.
// getShardingListByPrimaryVarchar method generates a shard id list by primary key(varchar) list and deleted list.
// For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be:
// [0, 1, -1, 1, 0, 1, -1, 1, 0, 1]
func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string,

@@ -565,10 +585,10 @@ func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string,
excluded := 0
shardList := make([]int32, 0, len(primaryKeys))
for i, key := range primaryKeys {
// if this entity's timestamp is greater than the tsEndPoint, set shardID = -1 to skip this entity
// if this entity's timestamp is greater than the tsEndPoint, or less than tsStartPoint, set shardID = -1 to skip this entity
// timestamp is stored as int64 type in log file, actually it is uint64, compare with uint64
ts := timestampList[i]
if uint64(ts) > p.tsEndPoint {
if uint64(ts) > p.tsEndPoint || uint64(ts) < p.tsStartPoint {
shardList = append(shardList, -1)
excluded++
continue

@@ -598,7 +618,7 @@ func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string,
return shardList, nil
}

// This method read an insert log, and split the data into different shards according to a shard list
// readInsertlog method reads an insert log, and split the data into different shards according to a shard list
// The shardList is a list to tell which row belong to which shard, returned by getShardingListByPrimaryXXX()
// For deleted rows, we say its shard id is -1.
// For example, an insert log has 10 rows, the no.3 and no.7 has been deleted, shardNum=2, the shardList could be:

@@ -1000,96 +1020,3 @@ func (p *BinlogAdapter) dispatchFloatVecToShards(data []float32, dim int, memory

return nil
}

// This method do the two things:
// 1. if accumulate data of a segment exceed segmentSize, call callFlushFunc to generate new segment
// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest segment
func (p *BinlogAdapter) tryFlushSegments(segmentsData []map[storage.FieldID]storage.FieldData, force bool) error {
totalSize := 0
biggestSize := 0
biggestItem := -1

// 1. if accumulate data of a segment exceed segmentSize, call callFlushFunc to generate new segment
for i := 0; i < len(segmentsData); i++ {
segmentData := segmentsData[i]
// Note: even rowCount is 0, the size is still non-zero
size := 0
rowCount := 0
for _, fieldData := range segmentData {
size += fieldData.GetMemorySize()
rowCount = fieldData.RowNum()
}

// force to flush, called at the end of Read()
if force && rowCount > 0 {
err := p.callFlushFunc(segmentData, i)
if err != nil {
log.Error("Binlog adapter: failed to force flush segment data", zap.Int("shardID", i))
return err
}
log.Info("Binlog adapter: force flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))

segmentsData[i] = initSegmentData(p.collectionSchema)
if segmentsData[i] == nil {
log.Error("Binlog adapter: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
continue
}

// if segment size is larger than predefined segmentSize, flush to create a new segment
// initialize a new FieldData list for next round batch read
if size > int(p.segmentSize) && rowCount > 0 {
err := p.callFlushFunc(segmentData, i)
if err != nil {
log.Error("Binlog adapter: failed to flush segment data", zap.Int("shardID", i))
return err
}
log.Info("Binlog adapter: segment size exceed limit and flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))

segmentsData[i] = initSegmentData(p.collectionSchema)
if segmentsData[i] == nil {
log.Error("Binlog adapter: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
continue
}

// calculate the total size(ignore the flushed segments)
// find out the biggest segment for the step 2
totalSize += size
if size > biggestSize {
biggestSize = size
biggestItem = i
}
}

// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest segment
if totalSize > int(p.maxTotalSize) && biggestItem >= 0 {
segmentData := segmentsData[biggestItem]
size := 0
rowCount := 0
for _, fieldData := range segmentData {
size += fieldData.GetMemorySize()
rowCount = fieldData.RowNum()
}

if rowCount > 0 {
err := p.callFlushFunc(segmentData, biggestItem)
if err != nil {
log.Error("Binlog adapter: failed to flush biggest segment data", zap.Int("shardID", biggestItem))
return err
}
log.Info("Binlog adapter: total size exceed limit and flush", zap.Int("rowCount", rowCount),
zap.Int("size", size), zap.Int("totalSize", totalSize), zap.Int("shardID", biggestItem))

segmentsData[biggestItem] = initSegmentData(p.collectionSchema)
if segmentsData[biggestItem] == nil {
log.Error("Binlog adapter: failed to initialize FieldData list")
return errors.New("failed to initialize FieldData list")
}
}
}

return nil
}
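The removed tryFlushSegments above is superseded by a shared tryFlushBlocks helper (called from Read earlier in the diff); its body is not part of this commit. The sketch below only restates the flush policy that the removed code implemented, expressed over block sizes, and is not the real helper.

func flushPolicy(blockSizes []int, blockSize, maxTotalSize int, force bool) (toFlush []int) {
	total := 0
	biggest, biggestIdx := 0, -1
	for i, size := range blockSizes {
		if force || size > blockSize {
			toFlush = append(toFlush, i) // rule 1: a full (or force-flushed) block is written out
			continue
		}
		total += size
		if size > biggest {
			biggest, biggestIdx = size, i
		}
	}
	if total > maxTotalSize && biggestIdx >= 0 {
		toFlush = append(toFlush, biggestIdx) // rule 2: relieve memory pressure by flushing the biggest block
	}
	return toFlush
}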
@@ -16,8 +16,10 @@
package importutil

import (
"context"
"encoding/json"
"errors"
"math"
"strconv"
"testing"

@@ -154,18 +156,20 @@ func createSegmentsData(fieldsData map[storage.FieldID]interface{}, shardNum int
}

func Test_NewBinlogAdapter(t *testing.T) {
ctx := context.Background()

// nil schema
adapter, err := NewBinlogAdapter(nil, 2, 1024, 2048, nil, nil, 0)
adapter, err := NewBinlogAdapter(ctx, nil, 2, 1024, 2048, nil, nil, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)

// nil chunkmanager
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, nil, nil, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, nil, nil, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)

// nil flushfunc
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, nil, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, nil, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)

@@ -173,7 +177,7 @@ func Test_NewBinlogAdapter(t *testing.T) {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err = NewBinlogAdapter(sampleSchema(), 2, 2048, 1024, &MockChunkManager{}, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 2048, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -191,16 +195,18 @@ func Test_NewBinlogAdapter(t *testing.T) {
},
},
}
adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.Nil(t, adapter)
assert.NotNil(t, err)
}

func Test_BinlogAdapterVerify(t *testing.T) {
ctx := context.Background()

flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -247,6 +253,8 @@ func Test_BinlogAdapterVerify(t *testing.T) {
}

func Test_BinlogAdapterReadDeltalog(t *testing.T) {
ctx := context.Background()

deleteItems := []int64{1001, 1002, 1003}
buf := createDeltalogBuf(t, deleteItems, false)
chunkManager := &MockChunkManager{

@@ -259,7 +267,7 @@ func Test_BinlogAdapterReadDeltalog(t *testing.T) {
return nil
}

adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -283,6 +291,8 @@ func Test_BinlogAdapterReadDeltalog(t *testing.T) {
}

func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
ctx := context.Background()

deleteItems := []int64{1001, 1002, 1003, 1004, 1005}
buf := createDeltalogBuf(t, deleteItems, false)
chunkManager := &MockChunkManager{

@@ -295,7 +305,7 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
return nil
}

adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -316,7 +326,7 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
"dummy": createDeltalogBuf(t, []string{"1001", "1002"}, true),
}

adapter, err = NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -327,11 +337,13 @@ func Test_BinlogAdapterDecodeDeleteLogs(t *testing.T) {
}

func Test_BinlogAdapterDecodeDeleteLog(t *testing.T) {
ctx := context.Background()

flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}

adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -378,6 +390,8 @@ func Test_BinlogAdapterDecodeDeleteLog(t *testing.T) {
}

func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
ctx := context.Background()

deleteItems := []int64{1001, 1002, 1003, 1004, 1005}
buf := createDeltalogBuf(t, deleteItems, false)
chunkManager := &MockChunkManager{

@@ -390,7 +404,7 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
return nil
}

adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -432,17 +446,19 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
"dummy": createDeltalogBuf(t, []string{"1001", "1002"}, true),
}

adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

// 2.1 all deletions have been filtered out
adapter.tsStartPoint = baseTimestamp + 2
intDeletions, strDeletions, err = adapter.readDeltalogs(holder)
assert.Nil(t, err)
assert.Nil(t, intDeletions)
assert.Nil(t, strDeletions)

// 2.2 filter the no.1 and no.2 deletion
adapter.tsStartPoint = 0
adapter.tsEndPoint = baseTimestamp + 1
intDeletions, strDeletions, err = adapter.readDeltalogs(holder)
assert.Nil(t, err)

@@ -470,7 +486,7 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
},
}

adapter, err = NewBinlogAdapter(schema, 2, 1024, 2048, chunkManager, flushFunc, 0)
adapter, err = NewBinlogAdapter(ctx, schema, 2, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -482,10 +498,12 @@ func Test_BinlogAdapterReadDeltalogs(t *testing.T) {
}

func Test_BinlogAdapterReadTimestamp(t *testing.T) {
ctx := context.Background()

flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -515,10 +533,12 @@ func Test_BinlogAdapterReadTimestamp(t *testing.T) {
}

func Test_BinlogAdapterReadPrimaryKeys(t *testing.T) {
ctx := context.Background()

flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
assert.NotNil(t, adapter)
assert.Nil(t, err)

@@ -570,12 +590,14 @@ func Test_BinlogAdapterReadPrimaryKeys(t *testing.T) {
}

func Test_BinlogAdapterShardListInt64(t *testing.T) {
ctx := context.Background()

flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
return nil
}

shardNum := int32(2)
|
||||
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
|
||||
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -610,12 +632,14 @@ func Test_BinlogAdapterShardListInt64(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogAdapterShardListVarchar(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
shardNum := int32(2)
|
||||
adapter, err := NewBinlogAdapter(strKeySchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
|
||||
adapter, err := NewBinlogAdapter(ctx, strKeySchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -650,6 +674,8 @@ func Test_BinlogAdapterShardListVarchar(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogAdapterReadInt64PK(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
chunkManager := &MockChunkManager{}
|
||||
|
||||
flushCounter := 0
|
||||
|
@ -669,7 +695,7 @@ func Test_BinlogAdapterReadInt64PK(t *testing.T) {
|
|||
}
|
||||
|
||||
shardNum := int32(2)
|
||||
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0)
|
||||
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
adapter.tsEndPoint = baseTimestamp + 1
|
||||
|
@ -741,6 +767,8 @@ func Test_BinlogAdapterReadInt64PK(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogAdapterReadVarcharPK(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
chunkManager := &MockChunkManager{}
|
||||
|
||||
flushCounter := 0
|
||||
|
@ -814,7 +842,7 @@ func Test_BinlogAdapterReadVarcharPK(t *testing.T) {
|
|||
|
||||
// succeed
|
||||
shardNum := int32(3)
|
||||
adapter, err := NewBinlogAdapter(strKeySchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0)
|
||||
adapter, err := NewBinlogAdapter(ctx, strKeySchema(), shardNum, 1024, 2048, chunkManager, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -825,93 +853,14 @@ func Test_BinlogAdapterReadVarcharPK(t *testing.T) {
|
|||
assert.Equal(t, rowCount-502, flushRowCount)
|
||||
}
|
||||
|
||||
func Test_BinlogAdapterTryFlush(t *testing.T) {
|
||||
flushCounter := 0
|
||||
flushRowCount := 0
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
flushCounter++
|
||||
rowCount := 0
|
||||
for _, v := range fields {
|
||||
rowCount = v.RowNum()
|
||||
break
|
||||
}
|
||||
flushRowCount += rowCount
|
||||
for _, v := range fields {
|
||||
assert.Equal(t, rowCount, v.RowNum())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
segmentSize := int64(1024)
|
||||
maxTotalSize := int64(2048)
|
||||
shardNum := int32(3)
|
||||
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, segmentSize, maxTotalSize, &MockChunkManager{}, flushFunc, 0)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// prepare flush data, 3 shards, each shard 10 rows
|
||||
rowCount := 10
|
||||
fieldsData := createFieldsData(rowCount)
|
||||
|
||||
// non-force flush
|
||||
segmentsData := createSegmentsData(fieldsData, shardNum)
|
||||
err = adapter.tryFlushSegments(segmentsData, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
|
||||
// force flush
|
||||
err = adapter.tryFlushSegments(segmentsData, true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int(shardNum), flushCounter)
|
||||
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
|
||||
|
||||
// after force flush, no data left
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
err = adapter.tryFlushSegments(segmentsData, true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
|
||||
// flush when segment size exceeds segmentSize
|
||||
segmentsData = createSegmentsData(fieldsData, shardNum)
|
||||
adapter.segmentSize = 100 // segmentSize is 100 bytes, less than the 10 rows size
|
||||
err = adapter.tryFlushSegments(segmentsData, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int(shardNum), flushCounter)
|
||||
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
|
||||
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
err = adapter.tryFlushSegments(segmentsData, true) // no data left
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
|
||||
// flush when segments total size exceeds maxTotalSize
|
||||
segmentsData = createSegmentsData(fieldsData, shardNum)
|
||||
adapter.segmentSize = 4096 // segmentSize is 4096 bytes, larger than the 10 rows size
|
||||
adapter.maxTotalSize = 100 // maxTotalSize is 100 bytes, less than the 30 rows size
|
||||
err = adapter.tryFlushSegments(segmentsData, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, flushCounter) // only the max segment is flushed
|
||||
assert.Equal(t, 10, flushRowCount)
|
||||
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
err = adapter.tryFlushSegments(segmentsData, true) // two segments left
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, flushCounter)
|
||||
assert.Equal(t, 20, flushRowCount)
|
||||
}
|
||||
|
||||
func Test_BinlogAdapterDispatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
shardNum := int32(3)
|
||||
adapter, err := NewBinlogAdapter(sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
|
||||
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), shardNum, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -1055,10 +1004,12 @@ func Test_BinlogAdapterDispatch(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogAdapterReadInsertlog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
adapter, err := NewBinlogAdapter(sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0)
|
||||
adapter, err := NewBinlogAdapter(ctx, sampleSchema(), 2, 1024, 2048, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import (
|
|||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// This class is a wrapper of storage.BinlogReader, to read binlog file, block by block.
|
||||
// BinlogFile class is a wrapper of storage.BinlogReader, to read binlog file, block by block.
|
||||
// Note: for the bulkinsert function, we only handle normal insert log and delta log.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// Typically, an insert log file size is 16MB.
|
||||
|
@ -72,7 +72,7 @@ func (p *BinlogFile) Open(filePath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// The outer caller must call this method in defer
|
||||
// Close closes the reader object; the outer caller must call this method in defer
|
||||
func (p *BinlogFile) Close() {
|
||||
if p.reader != nil {
|
||||
p.reader.Close()
|
||||
|
@ -88,8 +88,8 @@ func (p *BinlogFile) DataType() schemapb.DataType {
|
|||
return p.reader.PayloadDataType
|
||||
}
|
||||
|
||||
// ReadBool method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadBool() ([]bool, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -131,8 +131,8 @@ func (p *BinlogFile) ReadBool() ([]bool, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadInt8 method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadInt8() ([]int8, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -174,8 +174,8 @@ func (p *BinlogFile) ReadInt8() ([]int8, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadInt16 method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadInt16() ([]int16, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -217,8 +217,8 @@ func (p *BinlogFile) ReadInt16() ([]int16, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadInt32 method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadInt32() ([]int32, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -260,8 +260,8 @@ func (p *BinlogFile) ReadInt32() ([]int32, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadInt64 method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadInt64() ([]int64, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -303,8 +303,8 @@ func (p *BinlogFile) ReadInt64() ([]int64, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadFloat method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadFloat() ([]float32, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -346,8 +346,8 @@ func (p *BinlogFile) ReadFloat() ([]float32, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadDouble method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadDouble() ([]float64, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -389,8 +389,8 @@ func (p *BinlogFile) ReadDouble() ([]float64, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadVarchar method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
func (p *BinlogFile) ReadVarchar() ([]string, error) {
|
||||
if p.reader == nil {
|
||||
log.Error("Binlog file: binlog reader not yet initialized")
|
||||
|
@ -433,8 +433,8 @@ func (p *BinlogFile) ReadVarchar() ([]string, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ReadBinaryVector method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
// return vectors data and the dimension
|
||||
func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) {
|
||||
if p.reader == nil {
|
||||
|
@ -479,8 +479,8 @@ func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) {
|
|||
return result, dim, nil
|
||||
}
|
||||
|
||||
// ReadFloatVector method reads all the blocks of a binlog by a data type.
|
||||
// A binlog is designed to support multiple blocks, but so far each binlog always contains only one block.
|
||||
// This method read all the blocks of a binlog by a data type.
|
||||
// return vectors data and the dimension
|
||||
func (p *BinlogFile) ReadFloatVector() ([]float32, int, error) {
|
||||
if p.reader == nil {
|
||||
|
|
|
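The comment updates above describe BinlogFile as a thin wrapper over storage.BinlogReader that reads one insert log or delta log block by block, with Open/Close around the typed Read methods. A minimal sketch of that call pattern, assuming an already-constructed *BinlogFile (the constructor is not part of this diff) and a hypothetical object path:

// readInt64Binlog is a hypothetical helper showing the intended call pattern;
// "binlogFile" and "logPath" are placeholders, not values from this change.
func readInt64Binlog(binlogFile *BinlogFile, logPath string) ([]int64, error) {
	if err := binlogFile.Open(logPath); err != nil {
		return nil, err
	}
	// Close must be called by the outer caller, typically in defer.
	defer binlogFile.Close()

	if binlogFile.DataType() != schemapb.DataType_Int64 {
		return nil, errors.New("not an int64 binlog")
	}
	// ReadInt64 reads all blocks of the binlog (currently each binlog holds one block).
	return binlogFile.ReadInt64()
}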
@ -599,7 +599,7 @@ func Test_BinlogFileDouble(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogFileVarchar(t *testing.T) {
|
||||
source := []string{"a", "b", "c", "d"}
|
||||
source := []string{"a", "bb", "罗伯特", "d"}
|
||||
chunkManager := &MockChunkManager{
|
||||
readBuf: map[string][]byte{
|
||||
"dummy": createBinlogBuf(t, schemapb.DataType_VarChar, source),
|
||||
|
|
|
@ -19,6 +19,7 @@ package importutil
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -30,23 +31,33 @@ import (
|
|||
)
|
||||
|
||||
type BinlogParser struct {
|
||||
ctx context.Context // for canceling parse process
|
||||
collectionSchema *schemapb.CollectionSchema // collection schema
|
||||
shardNum int32 // sharding number of the collection
|
||||
segmentSize int64 // maximum size of a segment(unit:byte)
|
||||
blockSize int64 // maximum size of a read block(unit:byte)
|
||||
chunkManager storage.ChunkManager // storage interfaces to browse/read the files
|
||||
callFlushFunc ImportFlushFunc // call back function to flush segment
|
||||
|
||||
// a timestamp to define the end point of restore, data after this point will be ignored
|
||||
// a timestamp to define the start time point of restore, data before this time point will be ignored
|
||||
// set this value to 0, all the data will be imported
|
||||
// set this value to math.MaxUint64, all the data will be ignored
|
||||
// the tsStartPoint value must be less than or equal to tsEndPoint
|
||||
tsStartPoint uint64
|
||||
|
||||
// a timestamp to define the end time point of restore, data after this time point will be ignored
|
||||
// set this value to 0, all the data will be ignored
|
||||
// set this value to math.MaxUint64, all the data will be imported
|
||||
// the tsEndPoint value must be larger than or equal to tsStartPoint
|
||||
tsEndPoint uint64
|
||||
}
|
||||
|
||||
func NewBinlogParser(collectionSchema *schemapb.CollectionSchema,
|
||||
func NewBinlogParser(ctx context.Context,
|
||||
collectionSchema *schemapb.CollectionSchema,
|
||||
shardNum int32,
|
||||
segmentSize int64,
|
||||
blockSize int64,
|
||||
chunkManager storage.ChunkManager,
|
||||
flushFunc ImportFlushFunc,
|
||||
tsStartPoint uint64,
|
||||
tsEndPoint uint64) (*BinlogParser, error) {
|
||||
if collectionSchema == nil {
|
||||
log.Error("Binlog parser: collection schema is nil")
|
||||
|
@ -63,18 +74,27 @@ func NewBinlogParser(collectionSchema *schemapb.CollectionSchema,
|
|||
return nil, errors.New("flush function is nil")
|
||||
}
|
||||
|
||||
if tsStartPoint > tsEndPoint {
|
||||
log.Error("Binlog parser: the tsStartPoint should be less than tsEndPoint",
|
||||
zap.Uint64("tsStartPoint", tsStartPoint), zap.Uint64("tsEndPoint", tsEndPoint))
|
||||
return nil, fmt.Errorf("Binlog parser: the tsStartPoint %d should be less than tsEndPoint %d", tsStartPoint, tsEndPoint)
|
||||
}
|
||||
|
||||
v := &BinlogParser{
|
||||
ctx: ctx,
|
||||
collectionSchema: collectionSchema,
|
||||
shardNum: shardNum,
|
||||
segmentSize: segmentSize,
|
||||
blockSize: blockSize,
|
||||
chunkManager: chunkManager,
|
||||
callFlushFunc: flushFunc,
|
||||
tsStartPoint: tsStartPoint,
|
||||
tsEndPoint: tsEndPoint,
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// constructSegmentHolders builds a list of SegmentFilesHolder, each SegmentFilesHolder represents a segment folder
|
||||
// For instance, the insertlogRoot is "backup/bak1/data/insert_log/435978159196147009/435978159196147010".
|
||||
// 435978159196147009 is a collection id, 435978159196147010 is a partition id,
|
||||
// there is a segment(id is 435978159261483009) under this partition.
|
||||
|
@ -195,8 +215,8 @@ func (p *BinlogParser) parseSegmentFiles(segmentHolder *SegmentFilesHolder) erro
|
|||
return errors.New("segment files holder is nil")
|
||||
}
|
||||
|
||||
adapter, err := NewBinlogAdapter(p.collectionSchema, p.shardNum, p.segmentSize,
|
||||
MaxTotalSizeInMemory, p.chunkManager, p.callFlushFunc, p.tsEndPoint)
|
||||
adapter, err := NewBinlogAdapter(p.ctx, p.collectionSchema, p.shardNum, p.blockSize,
|
||||
MaxTotalSizeInMemory, p.chunkManager, p.callFlushFunc, p.tsStartPoint, p.tsEndPoint)
|
||||
if err != nil {
|
||||
log.Error("Binlog parser: failed to create binlog adapter", zap.Error(err))
|
||||
return err
|
||||
|
@ -205,7 +225,7 @@ func (p *BinlogParser) parseSegmentFiles(segmentHolder *SegmentFilesHolder) erro
|
|||
return adapter.Read(segmentHolder)
|
||||
}
|
||||
|
||||
// This functions requires two paths:
|
||||
// Parse requires two paths:
|
||||
// 1. the insert log path of a partition
|
||||
// 2. the delta log path of a partition (optional)
|
||||
func (p *BinlogParser) Parse(filePaths []string) error {
|
||||
|
|
|
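The new tsStartPoint/tsEndPoint comments above define an inclusive restore window: rows and deletions stamped before tsStartPoint or after tsEndPoint are dropped, so the pair (0, math.MaxUint64) imports everything. A minimal sketch of constructing a parser with such a window, mirroring the test calls in the next file; the paths, the shard number 2 and the block size 1024 are placeholders taken from those tests, not a recommended configuration:

// restoreOnePartition is a hypothetical helper; MockChunkManager and sampleSchema
// come from the test helpers in this change, everything else is illustrative.
func restoreOnePartition(insertLogPath, deltaLogPath string) error {
	ctx := context.Background()
	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
		return nil // a real caller would persist the flushed block here
	}
	tsStart := uint64(0)            // 0: keep all data from the beginning
	tsEnd := uint64(math.MaxUint64) // math.MaxUint64: keep all data up to the end
	parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024,
		&MockChunkManager{}, flushFunc, tsStart, tsEnd)
	if err != nil {
		return err
	}
	// Parse expects the insert log path of one partition and, optionally, its delta log path.
	return parser.Parse([]string{insertLogPath, deltaLogPath})
}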
@ -16,7 +16,9 @@
|
|||
package importutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"path"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
@ -27,18 +29,20 @@ import (
|
|||
)
|
||||
|
||||
func Test_NewBinlogParser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// nil schema
|
||||
parser, err := NewBinlogParser(nil, 2, 1024, nil, nil, 0)
|
||||
parser, err := NewBinlogParser(ctx, nil, 2, 1024, nil, nil, 0, math.MaxUint64)
|
||||
assert.Nil(t, parser)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// nil chunkmanager
|
||||
parser, err = NewBinlogParser(sampleSchema(), 2, 1024, nil, nil, 0)
|
||||
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, nil, nil, 0, math.MaxUint64)
|
||||
assert.Nil(t, parser)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// nil flushfunc
|
||||
parser, err = NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, nil, 0)
|
||||
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, nil, 0, math.MaxUint64)
|
||||
assert.Nil(t, parser)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
|
@ -46,12 +50,19 @@ func Test_NewBinlogParser(t *testing.T) {
|
|||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
parser, err = NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0)
|
||||
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, parser)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// tsStartPoint larger than tsEndPoint
|
||||
parser, err = NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 2, 1)
|
||||
assert.Nil(t, parser)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_BinlogParserConstructHolders(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -116,7 +127,7 @@ func Test_BinlogParserConstructHolders(t *testing.T) {
|
|||
"backup/bak1/data/delta_log/435978159196147009/435978159196147010/435978159261483009/434574382554415105",
|
||||
}
|
||||
|
||||
parser, err := NewBinlogParser(sampleSchema(), 2, 1024, chunkManager, flushFunc, 0)
|
||||
parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, parser)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -165,6 +176,8 @@ func Test_BinlogParserConstructHolders(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogParserConstructHoldersFailed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -174,7 +187,7 @@ func Test_BinlogParserConstructHoldersFailed(t *testing.T) {
|
|||
listResult: make(map[string][]string),
|
||||
}
|
||||
|
||||
parser, err := NewBinlogParser(sampleSchema(), 2, 1024, chunkManager, flushFunc, 0)
|
||||
parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, parser)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -214,11 +227,13 @@ func Test_BinlogParserConstructHoldersFailed(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogParserParseFilesFailed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
parser, err := NewBinlogParser(sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0)
|
||||
parser, err := NewBinlogParser(ctx, sampleSchema(), 2, 1024, &MockChunkManager{}, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, parser)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -231,6 +246,8 @@ func Test_BinlogParserParseFilesFailed(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_BinlogParserParse(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -251,7 +268,7 @@ func Test_BinlogParserParse(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
parser, err := NewBinlogParser(schema, 2, 1024, chunkManager, flushFunc, 0)
|
||||
parser, err := NewBinlogParser(ctx, schema, 2, 1024, chunkManager, flushFunc, 0, math.MaxUint64)
|
||||
assert.NotNil(t, parser)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
|
|
@ -0,0 +1,512 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package importutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
)
|
||||
|
||||
func isCanceled(ctx context.Context) bool {
|
||||
// canceled?
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
default:
|
||||
break
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
|
||||
segmentData := make(map[storage.FieldID]storage.FieldData)
|
||||
// rowID field is a hidden field with fieldID=0, it is always auto-generated by IDAllocator
|
||||
// if primary key is int64 and autoID=true, primary key field is equal to rowID field
|
||||
segmentData[common.RowIDField] = &storage.Int64FieldData{
|
||||
Data: make([]int64, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
|
||||
for i := 0; i < len(collectionSchema.Fields); i++ {
|
||||
schema := collectionSchema.Fields[i]
|
||||
switch schema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
|
||||
Data: make([]bool, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
|
||||
Data: make([]float32, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
|
||||
Data: make([]float64, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
|
||||
Data: make([]int8, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
|
||||
Data: make([]int16, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
|
||||
Data: make([]int32, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
|
||||
Data: make([]int64, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
dim, _ := getFieldDimension(schema)
|
||||
segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
|
||||
Data: make([]byte, 0),
|
||||
NumRows: []int64{0},
|
||||
Dim: dim,
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
dim, _ := getFieldDimension(schema)
|
||||
segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
|
||||
Data: make([]float32, 0),
|
||||
NumRows: []int64{0},
|
||||
Dim: dim,
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
segmentData[schema.GetFieldID()] = &storage.StringFieldData{
|
||||
Data: make([]string, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
default:
|
||||
log.Error("Import util: unsupported data type", zap.Int("DataType", int(schema.DataType)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return segmentData
|
||||
}
|
||||
|
||||
// initValidators constructs validator methods and data conversion methods
|
||||
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error {
|
||||
if collectionSchema == nil {
|
||||
return errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
// the json decoder parses all numeric values into float64
|
||||
numericValidator := func(obj interface{}) error {
|
||||
switch obj.(type) {
|
||||
case float64:
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := "illegal numeric value " + s
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for i := 0; i < len(collectionSchema.Fields); i++ {
|
||||
schema := collectionSchema.Fields[i]
|
||||
|
||||
validators[schema.GetFieldID()] = &Validator{}
|
||||
validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey()
|
||||
validators[schema.GetFieldID()].autoID = schema.GetAutoID()
|
||||
validators[schema.GetFieldID()].fieldName = schema.GetName()
|
||||
validators[schema.GetFieldID()].isString = false
|
||||
|
||||
switch schema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch obj.(type) {
|
||||
case bool:
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := "illegal value " + s + " for bool type field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
}
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := obj.(bool)
|
||||
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
|
||||
field.(*storage.BoolFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := float32(obj.(float64))
|
||||
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
|
||||
field.(*storage.FloatFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := obj.(float64)
|
||||
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
|
||||
field.(*storage.DoubleFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int8(obj.(float64))
|
||||
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
|
||||
field.(*storage.Int8FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int16(obj.(float64))
|
||||
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
|
||||
field.(*storage.Int16FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int32(obj.(float64))
|
||||
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
|
||||
field.(*storage.Int32FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int64(obj.(float64))
|
||||
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
|
||||
field.(*storage.Int64FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
dim, err := getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validators[schema.GetFieldID()].dimension = dim
|
||||
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch vt := obj.(type) {
|
||||
case []interface{}:
|
||||
if len(vt)*8 != dim {
|
||||
msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
for i := 0; i < len(vt); i++ {
|
||||
if e := numericValidator(vt[i]); e != nil {
|
||||
msg := e.Error() + " for binary vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
t := int(vt[i].(float64))
|
||||
if t > 255 || t < 0 {
|
||||
msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := s + " is not an array for binary vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
arr := obj.([]interface{})
|
||||
for i := 0; i < len(arr); i++ {
|
||||
value := byte(arr[i].(float64))
|
||||
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value)
|
||||
}
|
||||
|
||||
field.(*storage.BinaryVectorFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
dim, err := getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validators[schema.GetFieldID()].dimension = dim
|
||||
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch vt := obj.(type) {
|
||||
case []interface{}:
|
||||
if len(vt) != dim {
|
||||
msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
for i := 0; i < len(vt); i++ {
|
||||
if e := numericValidator(vt[i]); e != nil {
|
||||
msg := e.Error() + " for float vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := s + " is not an array for float vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
arr := obj.([]interface{})
|
||||
for i := 0; i < len(arr); i++ {
|
||||
value := float32(arr[i].(float64))
|
||||
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value)
|
||||
}
|
||||
field.(*storage.FloatVectorFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
validators[schema.GetFieldID()].isString = true
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch obj.(type) {
|
||||
case string:
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := s + " is not a string for string type field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := obj.(string)
|
||||
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
|
||||
field.(*storage.StringFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) {
|
||||
stats := make([]zapcore.Field, 0)
|
||||
for k, v := range fieldsData {
|
||||
stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum()))
|
||||
}
|
||||
|
||||
if len(files) > 0 {
|
||||
stats = append(stats, zap.Any("files", files))
|
||||
}
|
||||
log.Info(msg, stats...)
|
||||
}
|
||||
|
||||
// GetFileNameAndExt extracts file name and extension
|
||||
// for example: "/a/b/c.ttt" returns "c" and ".ttt"
|
||||
func GetFileNameAndExt(filePath string) (string, string) {
|
||||
fileName := path.Base(filePath)
|
||||
fileType := path.Ext(fileName)
|
||||
fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
|
||||
return fileNameWithoutExt, fileType
|
||||
}
|
||||
|
||||
// getFieldDimension gets the dimension of a vector field
|
||||
func getFieldDimension(schema *schemapb.FieldSchema) (int, error) {
|
||||
for _, kvPair := range schema.GetTypeParams() {
|
||||
key, value := kvPair.GetKey(), kvPair.GetValue()
|
||||
if key == "dim" {
|
||||
dim, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, errors.New("vector dimension is invalid")
|
||||
}
|
||||
return dim, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, errors.New("vector dimension is not defined")
|
||||
}
|
||||
|
||||
// triggerGC triggers golang gc to return all free memory back to the underlying system at once,
|
||||
// Note: this operation is expensive, and can lead to latency spikes as it holds the heap lock through the whole process
|
||||
func triggerGC() {
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
|
||||
// tryFlushBlocks does two things:
|
||||
// 1. if the accumulated data of a block exceeds blockSize, call callFlushFunc to generate a new binlog file
|
||||
// 2. if the total accumulated data exceeds maxTotalSize, call callFlushFunc to flush the biggest block
|
||||
func tryFlushBlocks(ctx context.Context,
|
||||
blocksData []map[storage.FieldID]storage.FieldData,
|
||||
collectionSchema *schemapb.CollectionSchema,
|
||||
callFlushFunc ImportFlushFunc,
|
||||
blockSize int64,
|
||||
maxTotalSize int64,
|
||||
force bool) error {
|
||||
|
||||
totalSize := 0
|
||||
biggestSize := 0
|
||||
biggestItem := -1
|
||||
|
||||
// 1. if the accumulated data of a block exceeds blockSize, call callFlushFunc to generate a new binlog file
|
||||
for i := 0; i < len(blocksData); i++ {
|
||||
// outside context might be canceled(service stop, or future enhancement for canceling import task)
|
||||
if isCanceled(ctx) {
|
||||
log.Error("Import util: import task was canceled")
|
||||
return errors.New("import task was canceled")
|
||||
}
|
||||
|
||||
blockData := blocksData[i]
|
||||
// Note: even if rowCount is 0, the size is still non-zero
|
||||
size := 0
|
||||
rowCount := 0
|
||||
for _, fieldData := range blockData {
|
||||
size += fieldData.GetMemorySize()
|
||||
rowCount = fieldData.RowNum()
|
||||
}
|
||||
|
||||
// force to flush, called at the end of Read()
|
||||
if force && rowCount > 0 {
|
||||
printFieldsDataInfo(blockData, "import util: prepare to force flush a block", nil)
|
||||
err := callFlushFunc(blockData, i)
|
||||
if err != nil {
|
||||
log.Error("Import util: failed to force flush block data", zap.Int("shardID", i))
|
||||
return err
|
||||
}
|
||||
log.Info("Import util: force flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))
|
||||
|
||||
blocksData[i] = initSegmentData(collectionSchema)
|
||||
if blocksData[i] == nil {
|
||||
log.Error("Import util: failed to initialize FieldData list")
|
||||
return errors.New("failed to initialize FieldData list")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// if segment size is larger than predefined blockSize, flush to create a new binlog file
|
||||
// initialize a new FieldData list for next round batch read
|
||||
if size > int(blockSize) && rowCount > 0 {
|
||||
printFieldsDataInfo(blockData, "import util: prepare to flush block larger than maxBlockSize", nil)
|
||||
err := callFlushFunc(blockData, i)
|
||||
if err != nil {
|
||||
log.Error("Import util: failed to flush block data", zap.Int("shardID", i))
|
||||
return err
|
||||
}
|
||||
log.Info("Import util: block size exceed limit and flush", zap.Int("rowCount", rowCount), zap.Int("size", size), zap.Int("shardID", i))
|
||||
|
||||
blocksData[i] = initSegmentData(collectionSchema)
|
||||
if blocksData[i] == nil {
|
||||
log.Error("Import util: failed to initialize FieldData list")
|
||||
return errors.New("failed to initialize FieldData list")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// calculate the total size(ignore the flushed blocks)
|
||||
// find out the biggest block for the step 2
|
||||
totalSize += size
|
||||
if size > biggestSize {
|
||||
biggestSize = size
|
||||
biggestItem = i
|
||||
}
|
||||
}
|
||||
|
||||
// 2. if the total accumulated data exceeds maxTotalSize, call callFlushFunc to flush the biggest block
|
||||
if totalSize > int(maxTotalSize) && biggestItem >= 0 {
|
||||
// outside context might be canceled(service stop, or future enhancement for canceling import task)
|
||||
if isCanceled(ctx) {
|
||||
log.Error("Import util: import task was canceled")
|
||||
return errors.New("import task was canceled")
|
||||
}
|
||||
|
||||
blockData := blocksData[biggestItem]
|
||||
// Note: even if rowCount is 0, the size is still non-zero
|
||||
size := 0
|
||||
rowCount := 0
|
||||
for _, fieldData := range blockData {
|
||||
size += fieldData.GetMemorySize()
|
||||
rowCount = fieldData.RowNum()
|
||||
}
|
||||
|
||||
if rowCount > 0 {
|
||||
printFieldsDataInfo(blockData, "import util: prepare to flush biggest block", nil)
|
||||
err := callFlushFunc(blockData, biggestItem)
|
||||
if err != nil {
|
||||
log.Error("Import util: failed to flush biggest block data", zap.Int("shardID", biggestItem))
|
||||
return err
|
||||
}
|
||||
log.Info("Import util: total size exceed limit and flush", zap.Int("rowCount", rowCount),
|
||||
zap.Int("size", size), zap.Int("totalSize", totalSize), zap.Int("shardID", biggestItem))
|
||||
|
||||
blocksData[biggestItem] = initSegmentData(collectionSchema)
|
||||
if blocksData[biggestItem] == nil {
|
||||
log.Error("Import util: failed to initialize FieldData list")
|
||||
return errors.New("failed to initialize FieldData list")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTypeName(dt schemapb.DataType) string {
|
||||
switch dt {
|
||||
case schemapb.DataType_Bool:
|
||||
return "Bool"
|
||||
case schemapb.DataType_Int8:
|
||||
return "Int8"
|
||||
case schemapb.DataType_Int16:
|
||||
return "Int16"
|
||||
case schemapb.DataType_Int32:
|
||||
return "Int32"
|
||||
case schemapb.DataType_Int64:
|
||||
return "Int64"
|
||||
case schemapb.DataType_Float:
|
||||
return "Float"
|
||||
case schemapb.DataType_Double:
|
||||
return "Double"
|
||||
case schemapb.DataType_VarChar:
|
||||
return "Varchar"
|
||||
case schemapb.DataType_String:
|
||||
return "String"
|
||||
case schemapb.DataType_BinaryVector:
|
||||
return "BinaryVector"
|
||||
case schemapb.DataType_FloatVector:
|
||||
return "FloatVector"
|
||||
default:
|
||||
return "InvalidType"
|
||||
}
|
||||
}
|
|
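initSegmentData and initValidators above work as a pair: the former allocates one empty storage.FieldData per field (plus the hidden row-id field), the latter attaches a validateFunc and a convertFunc to each field so a JSON-decoded value (always float64 for numbers) can be checked and appended. A small sketch of that flow for a single value, using only helpers defined in this file; the field ID and raw value are hypothetical:

// appendOneValue is a hypothetical helper showing how the validators are meant
// to be used: validate the raw JSON value, then convert/append it into the block.
func appendOneValue(schema *schemapb.CollectionSchema, fieldID storage.FieldID, raw interface{}) error {
	block := initSegmentData(schema) // one FieldData per field, plus the hidden row-id field
	if block == nil {
		return errors.New("failed to initialize FieldData list")
	}

	validators := make(map[storage.FieldID]*Validator)
	if err := initValidators(schema, validators); err != nil {
		return err
	}

	v, ok := validators[fieldID]
	if !ok {
		return errors.New("unknown field")
	}
	if err := v.validateFunc(raw); err != nil { // e.g. rejects "aa" for an int64 field
		return err
	}
	return v.convertFunc(raw, block[fieldID]) // appends the value and bumps NumRows[0]
}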
@ -0,0 +1,460 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package importutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func sampleSchema() *schemapb.CollectionSchema {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "field_bool",
|
||||
IsPrimaryKey: false,
|
||||
Description: "bool",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "field_int8",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int8",
|
||||
DataType: schemapb.DataType_Int8,
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "field_int16",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int16",
|
||||
DataType: schemapb.DataType_Int16,
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "field_int32",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int32",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "field_int64",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
Description: "int64",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 107,
|
||||
Name: "field_float",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
{
|
||||
FieldID: 108,
|
||||
Name: "field_double",
|
||||
IsPrimaryKey: false,
|
||||
Description: "double",
|
||||
DataType: schemapb.DataType_Double,
|
||||
},
|
||||
{
|
||||
FieldID: 109,
|
||||
Name: "field_string",
|
||||
IsPrimaryKey: false,
|
||||
Description: "string",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "128"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 110,
|
||||
Name: "field_binary_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "binary_vector",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "16"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 111,
|
||||
Name: "field_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func strKeySchema() *schemapb.CollectionSchema {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
Description: "uid",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "1024"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "int_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int_scalar",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "float_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_scalar",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "string_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "string_scalar",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "128"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "bool_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "bool_scalar",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "vectors",
|
||||
IsPrimaryKey: false,
|
||||
Description: "vectors",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func Test_IsCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
assert.False(t, isCanceled(ctx))
|
||||
cancel()
|
||||
assert.True(t, isCanceled(ctx))
|
||||
}
|
||||
|
||||
func Test_InitSegmentData(t *testing.T) {
|
||||
testFunc := func(schema *schemapb.CollectionSchema) {
|
||||
fields := initSegmentData(schema)
|
||||
assert.Equal(t, len(schema.Fields)+1, len(fields))
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
data, ok := fields[field.FieldID]
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
printFieldsDataInfo(fields, "dummy", []string{})
|
||||
}
|
||||
testFunc(sampleSchema())
|
||||
testFunc(strKeySchema())
|
||||
}
|
||||
|
||||
func Test_InitValidators(t *testing.T) {
|
||||
validators := make(map[storage.FieldID]*Validator)
|
||||
err := initValidators(nil, validators)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
schema := sampleSchema()
|
||||
// success case
|
||||
err = initValidators(schema, validators)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(schema.Fields), len(validators))
|
||||
name2ID := make(map[string]storage.FieldID)
|
||||
for _, field := range schema.Fields {
|
||||
name2ID[field.GetName()] = field.GetFieldID()
|
||||
}
|
||||
|
||||
checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
|
||||
id := name2ID[funcName]
|
||||
v, ok := validators[id]
|
||||
assert.True(t, ok)
|
||||
err = v.validateFunc(validVal)
|
||||
assert.Nil(t, err)
|
||||
err = v.validateFunc(invalidVal)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
// validate functions
|
||||
var validVal interface{} = true
|
||||
var invalidVal interface{} = "aa"
|
||||
checkFunc("field_bool", validVal, invalidVal)
|
||||
|
||||
validVal = float64(100)
|
||||
invalidVal = "aa"
|
||||
checkFunc("field_int8", validVal, invalidVal)
|
||||
checkFunc("field_int16", validVal, invalidVal)
|
||||
checkFunc("field_int32", validVal, invalidVal)
|
||||
checkFunc("field_int64", validVal, invalidVal)
|
||||
checkFunc("field_float", validVal, invalidVal)
|
||||
checkFunc("field_double", validVal, invalidVal)
|
||||
|
||||
validVal = "aa"
|
||||
invalidVal = 100
|
||||
checkFunc("field_string", validVal, invalidVal)
|
||||
|
||||
validVal = []interface{}{float64(100), float64(101)}
|
||||
invalidVal = "aa"
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(100)}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(100), float64(101), float64(102)}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{true, true}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(255), float64(-1)}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
|
||||
validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)}
|
||||
invalidVal = true
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(1), float64(2), float64(3)}
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)}
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{"a", "b", "c", "d"}
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
|
||||
// error cases
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: make([]*schemapb.FieldSchema, 0),
|
||||
}
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 111,
|
||||
Name: "field_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "aa"},
|
||||
},
|
||||
})
|
||||
|
||||
validators = make(map[storage.FieldID]*Validator)
|
||||
err = initValidators(schema, validators)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
schema.Fields = make([]*schemapb.FieldSchema, 0)
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 110,
|
||||
Name: "field_binary_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "aa"},
|
||||
},
|
||||
})
|
||||
|
||||
err = initValidators(schema, validators)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_GetFileNameAndExt(t *testing.T) {
|
||||
filePath := "aaa/bbb/ccc.txt"
|
||||
name, ext := GetFileNameAndExt(filePath)
|
||||
assert.EqualValues(t, "ccc", name)
|
||||
assert.EqualValues(t, ".txt", ext)
|
||||
}
|
||||
|
||||
func Test_GetFieldDimension(t *testing.T) {
|
||||
schema := &schemapb.FieldSchema{
|
||||
FieldID: 111,
|
||||
Name: "field_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
}
|
||||
|
||||
dim, err := getFieldDimension(schema)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, dim)
|
||||
|
||||
schema.TypeParams = []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "abc"},
|
||||
}
|
||||
dim, err = getFieldDimension(schema)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, 0, dim)
|
||||
|
||||
schema.TypeParams = []*commonpb.KeyValuePair{}
|
||||
dim, err = getFieldDimension(schema)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, 0, dim)
|
||||
}
|
||||
|
||||
func Test_TryFlushBlocks(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
flushCounter := 0
|
||||
flushRowCount := 0
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
flushCounter++
|
||||
rowCount := 0
|
||||
for _, v := range fields {
|
||||
rowCount = v.RowNum()
|
||||
break
|
||||
}
|
||||
flushRowCount += rowCount
|
||||
for _, v := range fields {
|
||||
assert.Equal(t, rowCount, v.RowNum())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
blockSize := int64(1024)
|
||||
maxTotalSize := int64(2048)
|
||||
shardNum := int32(3)
|
||||
|
||||
// prepare flush data, 3 shards, each shard 10 rows
|
||||
rowCount := 10
|
||||
fieldsData := createFieldsData(rowCount)
|
||||
|
||||
// non-force flush
|
||||
segmentsData := createSegmentsData(fieldsData, shardNum)
|
||||
err := tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
|
||||
// force flush
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int(shardNum), flushCounter)
|
||||
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
|
||||
|
||||
// after force flush, no data left
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
|
||||
// flush when segment size exceeds blockSize
|
||||
segmentsData = createSegmentsData(fieldsData, shardNum)
|
||||
blockSize = 100 // blockSize is 100 bytes, less than the 10 rows size
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int(shardNum), flushCounter)
|
||||
assert.Equal(t, rowCount*int(shardNum), flushRowCount)
|
||||
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) // no data left
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
|
||||
// flush when segments total size exceeds maxTotalSize
|
||||
segmentsData = createSegmentsData(fieldsData, shardNum)
|
||||
blockSize = 4096 // blockSize is 4096 bytes, larger than the 10 rows size
|
||||
maxTotalSize = 100 // maxTotalSize is 100 bytes, less than the 30 rows size
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, flushCounter) // only the max segment is flushed
|
||||
assert.Equal(t, 10, flushRowCount)
|
||||
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true) // two segments left
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, flushCounter)
|
||||
assert.Equal(t, 20, flushRowCount)
|
||||
|
||||
// canceled
|
||||
cancel()
|
||||
flushCounter = 0
|
||||
flushRowCount = 0
|
||||
segmentsData = createSegmentsData(fieldsData, shardNum)
|
||||
err = tryFlushBlocks(ctx, segmentsData, sampleSchema(), flushFunc, blockSize, maxTotalSize, true)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 0, flushCounter)
|
||||
assert.Equal(t, 0, flushRowCount)
|
||||
}
|
||||
|
||||
func Test_GetTypeName(t *testing.T) {
|
||||
str := getTypeName(schemapb.DataType_Bool)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_Int8)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_Int16)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_Int32)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_Int64)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_Float)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_Double)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_VarChar)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_String)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_BinaryVector)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_FloatVector)
|
||||
assert.NotEmpty(t, str)
|
||||
str = getTypeName(schemapb.DataType_None)
|
||||
assert.Equal(t, "InvalidType", str)
|
||||
}
|
|
@ -19,21 +19,17 @@ package importutil
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
|
@ -45,6 +41,9 @@ const (
|
|||
JSONFileExt = ".json"
|
||||
NumpyFileExt = ".npy"
|
||||
|
||||
// supposed size of a single block, to control the binlog file size; the max binlog file size is no more than 2*SingleBlockSize
|
||||
SingleBlockSize = 16 * 1024 * 1024 // 16MB
|
||||
|
||||
// this limitation is to avoid the following OOM risk:
// for a column-based file, we read all its data into memory; if the user inputs a large file, the read() method may
// cost extra memory and lead to OOM.
|
||||
|
@ -64,26 +63,60 @@ const (
|
|||
// ReportImportAttempts is the maximum # of attempts to retry when import fails.
|
||||
var ReportImportAttempts uint = 10
|
||||
|
||||
type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
type AssignSegmentFunc func(shardID int) (int64, string, error)
type CreateBinlogsFunc func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error)
type SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error
|
||||
|
||||
type WorkingSegment struct {
	segmentID    int64                 // segment ID
	shardID      int                   // shard id
	targetChName string                // target dml channel
	rowCount     int64                 // accumulated row count
	memSize      int                   // total memory size of all binlogs
	fieldsInsert []*datapb.FieldBinlog // persisted binlogs
	fieldsStats  []*datapb.FieldBinlog // stats of persisted binlogs
}
|
||||
|
||||
type ImportOptions struct {
|
||||
OnlyValidate bool
|
||||
TsStartPoint uint64
|
||||
TsEndPoint uint64
|
||||
}
|
||||
|
||||
func DefaultImportOptions() ImportOptions {
|
||||
options := ImportOptions{
|
||||
OnlyValidate: false,
|
||||
TsStartPoint: 0,
|
||||
TsEndPoint: math.MaxUint64,
|
||||
}
|
||||
return options
|
||||
}
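// Usage sketch (not part of this patch): how a caller might build ImportOptions to
// restrict a binlog import to a timestamp window. The variable names are illustrative
// only, and the "rows outside the window are skipped" behaviour is assumed from the
// TsStartPoint/TsEndPoint fields above rather than stated by this change.
func exampleImportOptions(tsStart, tsEnd uint64) ImportOptions {
	options := DefaultImportOptions() // OnlyValidate=false, full timestamp range by default
	options.TsStartPoint = tsStart    // presumably: rows with ts < tsStart are skipped
	options.TsEndPoint = tsEnd        // presumably: rows with ts > tsEnd are skipped
	return options
}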
|
||||
|
||||
type ImportWrapper struct {
|
||||
ctx context.Context // for canceling parse process
|
||||
cancel context.CancelFunc // for canceling parse process
|
||||
collectionSchema *schemapb.CollectionSchema // collection schema
|
||||
shardNum int32 // sharding number of the collection
|
||||
segmentSize int64 // maximum size of a segment(unit:byte)
|
||||
segmentSize int64 // maximum size of a segment(unit:byte) defined by dataCoord.segment.maxSize (milvus.yml)
|
||||
rowIDAllocator *allocator.IDAllocator // autoid allocator
|
||||
chunkManager storage.ChunkManager
|
||||
|
||||
callFlushFunc ImportFlushFunc // call back function to flush a segment
|
||||
assignSegmentFunc AssignSegmentFunc // function to prepare a new segment
|
||||
createBinlogsFunc CreateBinlogsFunc // function to create binlog for a segment
|
||||
saveSegmentFunc SaveSegmentFunc // function to persist a segment
|
||||
|
||||
importResult *rootcoordpb.ImportResult // import result
|
||||
reportFunc func(res *rootcoordpb.ImportResult) error // report import state to rootcoord
|
||||
|
||||
workingSegments map[int]*WorkingSegment // a map shard id to working segments
|
||||
}
|
||||
|
||||
func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64,
|
||||
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc ImportFlushFunc,
|
||||
importResult *rootcoordpb.ImportResult, reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper {
|
||||
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, importResult *rootcoordpb.ImportResult,
|
||||
reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper {
|
||||
if collectionSchema == nil {
|
||||
log.Error("import error: collection schema is nil")
|
||||
log.Error("import wrapper: collection schema is nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -111,96 +144,143 @@ func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.Collection
|
|||
shardNum: shardNum,
|
||||
segmentSize: segmentSize,
|
||||
rowIDAllocator: idAlloc,
|
||||
callFlushFunc: flushFunc,
|
||||
chunkManager: cm,
|
||||
importResult: importResult,
|
||||
reportFunc: reportFunc,
|
||||
workingSegments: make(map[int]*WorkingSegment),
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
// this method can be used to cancel parse process
|
||||
func (p *ImportWrapper) SetCallbackFunctions(assignSegmentFunc AssignSegmentFunc, createBinlogsFunc CreateBinlogsFunc, saveSegmentFunc SaveSegmentFunc) error {
|
||||
if assignSegmentFunc == nil {
|
||||
log.Error("import wrapper: callback function AssignSegmentFunc is nil")
|
||||
return fmt.Errorf("import wrapper: callback function AssignSegmentFunc is nil")
|
||||
}
|
||||
|
||||
if createBinlogsFunc == nil {
|
||||
log.Error("import wrapper: callback function CreateBinlogsFunc is nil")
|
||||
return fmt.Errorf("import wrapper: callback function CreateBinlogsFunc is nil")
|
||||
}
|
||||
|
||||
if saveSegmentFunc == nil {
|
||||
log.Error("import wrapper: callback function SaveSegmentFunc is nil")
|
||||
return fmt.Errorf("import wrapper: callback function SaveSegmentFunc is nil")
|
||||
}
|
||||
|
||||
p.assignSegmentFunc = assignSegmentFunc
|
||||
p.createBinlogsFunc = createBinlogsFunc
|
||||
p.saveSegmentFunc = saveSegmentFunc
|
||||
return nil
|
||||
}
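// A minimal sketch (an assumption, not taken from this patch) of wiring the three
// callbacks for a dry run. Real callers are expected to assign segments, create
// binlogs and save segment meta through DataNode/DataCoord; the segment ID and
// channel name below are hypothetical placeholders.
func exampleWireCallbacks(wrapper *ImportWrapper) error {
	assignFunc := func(shardID int) (int64, string, error) {
		return int64(shardID + 1), "by-dev-rootcoord-dml_0", nil // hypothetical segment ID and channel
	}
	createFunc := func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) {
		return nil, nil, nil // no binlogs produced in this sketch
	}
	saveFunc := func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error {
		return nil // nothing persisted in this sketch
	}
	return wrapper.SetCallbackFunctions(assignFunc, createFunc, saveFunc)
}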
|
||||
|
||||
// Cancel method can be used to cancel parse process
|
||||
func (p *ImportWrapper) Cancel() error {
|
||||
p.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) {
|
||||
stats := make([]zapcore.Field, 0)
|
||||
for k, v := range fieldsData {
|
||||
stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum()))
|
||||
func (p *ImportWrapper) validateColumnBasedFiles(filePaths []string, collectionSchema *schemapb.CollectionSchema) error {
|
||||
requiredFieldNames := make(map[string]interface{})
|
||||
for _, schema := range p.collectionSchema.Fields {
|
||||
if schema.GetIsPrimaryKey() {
|
||||
if !schema.GetAutoID() {
|
||||
requiredFieldNames[schema.GetName()] = nil
|
||||
}
|
||||
} else {
|
||||
requiredFieldNames[schema.GetName()] = nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(files) > 0 {
|
||||
stats = append(stats, zap.Any("files", files))
|
||||
// check redundant file
|
||||
fileNames := make(map[string]interface{})
|
||||
for _, filePath := range filePaths {
|
||||
name, _ := GetFileNameAndExt(filePath)
|
||||
fileNames[name] = nil
|
||||
_, ok := requiredFieldNames[name]
|
||||
if !ok {
|
||||
log.Error("import wrapper: the file has no corresponding field in collection", zap.String("fieldName", name))
|
||||
return fmt.Errorf("import wrapper: the file '%s' has no corresponding field in collection", filePath)
|
||||
}
|
||||
}
|
||||
log.Info(msg, stats...)
|
||||
|
||||
// check missed file
|
||||
for name := range requiredFieldNames {
|
||||
_, ok := fileNames[name]
|
||||
if !ok {
|
||||
log.Error("import wrapper: there is no file corresponding to field", zap.String("fieldName", name))
|
||||
return fmt.Errorf("import wrapper: there is no file corresponding to field '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getFileNameAndExt(filePath string) (string, string) {
|
||||
fileName := path.Base(filePath)
|
||||
fileType := path.Ext(fileName)
|
||||
fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
|
||||
return fileNameWithoutExt, fileType
|
||||
}
|
||||
|
||||
// triggerGC triggers the golang GC to return all free memory back to the underlying system at once,
|
||||
// Note: this operation is expensive, and can lead to latency spikes as it holds the heap lock through the whole process
|
||||
func triggerGC() {
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
|
||||
func (p *ImportWrapper) fileValidation(filePaths []string, rowBased bool) error {
|
||||
// fileValidation verifies the input paths
// if all the files are json type, return true
// if all the files are numpy type, return false; duplicate file names are not allowed
|
||||
func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
|
||||
// use this map to check duplicate file name(only for numpy file)
|
||||
fileNames := make(map[string]struct{})
|
||||
|
||||
totalSize := int64(0)
|
||||
rowBased := false
|
||||
for i := 0; i < len(filePaths); i++ {
|
||||
filePath := filePaths[i]
|
||||
name, fileType := getFileNameAndExt(filePath)
|
||||
_, ok := fileNames[name]
|
||||
if ok {
|
||||
// only check duplicate numpy file
|
||||
if fileType == NumpyFileExt {
|
||||
log.Error("import wrapper: duplicate file name", zap.String("fileName", name+"."+fileType))
|
||||
return errors.New("duplicate file: " + name + "." + fileType)
|
||||
}
|
||||
} else {
|
||||
fileNames[name] = struct{}{}
|
||||
name, fileType := GetFileNameAndExt(filePath)
|
||||
|
||||
// only allow json file or numpy file
|
||||
if fileType != JSONFileExt && fileType != NumpyFileExt {
|
||||
log.Error("import wrapper: unsupportted file type", zap.String("filePath", filePath))
|
||||
return false, fmt.Errorf("import wrapper: unsupportted file type: '%s'", filePath)
|
||||
}
|
||||
|
||||
// we use the first file to determine row-based or column-based
|
||||
if i == 0 && fileType == JSONFileExt {
|
||||
rowBased = true
|
||||
}
|
||||
|
||||
// check file type
|
||||
// row-based only support json type, column-based can support json and numpy type
|
||||
// row-based only supports json type, column-based only supports numpy type
|
||||
if rowBased {
|
||||
if fileType != JSONFileExt {
|
||||
log.Error("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath))
|
||||
return errors.New("unsupported file type for row-based mode: " + filePath)
|
||||
return rowBased, fmt.Errorf("import wrapper: unsupported file type for row-based mode: '%s'", filePath)
|
||||
}
|
||||
} else {
|
||||
if fileType != JSONFileExt && fileType != NumpyFileExt {
|
||||
if fileType != NumpyFileExt {
|
||||
log.Error("import wrapper: unsupported file type for column-based mode", zap.String("filePath", filePath))
|
||||
return errors.New("unsupported file type for column-based mode: " + filePath)
|
||||
return rowBased, fmt.Errorf("import wrapper: unsupported file type for column-based mode: '%s'", filePath)
|
||||
}
|
||||
}
|
||||
|
||||
// check duplicate file
|
||||
_, ok := fileNames[name]
|
||||
if ok {
|
||||
log.Error("import wrapper: duplicate file name", zap.String("filePath", filePath))
|
||||
return rowBased, fmt.Errorf("import wrapper: duplicate file: '%s'", filePath)
|
||||
}
|
||||
fileNames[name] = struct{}{}
|
||||
|
||||
// check file size, single file size cannot exceed MaxFileSize
|
||||
// TODO add context
|
||||
size, err := p.chunkManager.Size(context.TODO(), filePath)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Any("err", err))
|
||||
return errors.New("failed to get file size of " + filePath)
|
||||
return rowBased, fmt.Errorf("import wrapper: failed to get file size of '%s'", filePath)
|
||||
}
|
||||
|
||||
// empty file
|
||||
if size == 0 {
|
||||
log.Error("import wrapper: file path is empty", zap.String("filePath", filePath))
|
||||
return errors.New("the file " + filePath + " is empty")
|
||||
log.Error("import wrapper: file size is zero", zap.String("filePath", filePath))
|
||||
return rowBased, fmt.Errorf("import wrapper: the file '%s' size is zero", filePath)
|
||||
}
|
||||
|
||||
if size > MaxFileSize {
|
||||
log.Error("import wrapper: file size exceeds the maximum size", zap.String("filePath", filePath),
|
||||
zap.Int64("fileSize", size), zap.Int64("MaxFileSize", MaxFileSize))
|
||||
return errors.New("the file " + filePath + " size exceeds the maximum size: " + strconv.FormatInt(MaxFileSize, 10) + " bytes")
|
||||
return rowBased, fmt.Errorf("import wrapper: the file '%s' size exceeds the maximum size: %d bytes", filePath, MaxFileSize)
|
||||
}
|
||||
totalSize += size
|
||||
}
|
||||
|
@ -208,26 +288,36 @@ func (p *ImportWrapper) fileValidation(filePaths []string, rowBased bool) error
|
|||
// especially for column-based, the total size of files cannot exceed MaxTotalSizeInMemory
|
||||
if totalSize > MaxTotalSizeInMemory {
|
||||
log.Error("import wrapper: total size of files exceeds the maximum size", zap.Int64("totalSize", totalSize), zap.Int64("MaxTotalSize", MaxTotalSizeInMemory))
|
||||
return errors.New("the total size of all files exceeds the maximum size: " + strconv.FormatInt(MaxTotalSizeInMemory, 10) + " bytes")
|
||||
return rowBased, fmt.Errorf("import wrapper: total size(%d bytes) of all files exceeds the maximum size: %d bytes", totalSize, MaxTotalSizeInMemory)
|
||||
}
|
||||
|
||||
return nil
|
||||
// check redundant files for column-based import
|
||||
// if a field is the primary key and autoID is false, its file is required
|
||||
// any redundant file is not allowed
|
||||
if !rowBased {
|
||||
err := p.validateColumnBasedFiles(filePaths, p.collectionSchema)
|
||||
if err != nil {
|
||||
return rowBased, err
|
||||
}
|
||||
}
|
||||
|
||||
return rowBased, nil
|
||||
}
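// A hedged usage sketch (the file names are assumptions): fileValidation returns
// rowBased=true when every input is a .json file, and rowBased=false when every input
// is a .npy file whose name matches a schema field; mixing the two types, duplicate
// numpy names, empty files and oversized files are all rejected by the checks above.
func exampleFileValidation(p *ImportWrapper) {
	rowBased, err := p.fileValidation([]string{"rows1.json", "rows2.json"}) // expect rowBased == true
	_, _ = rowBased, err
	rowBased, err = p.fileValidation([]string{"field_int64.npy", "field_float_vector.npy"}) // expect rowBased == false
	_, _ = rowBased, err
}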
|
||||
|
||||
// import process entry
|
||||
// Import is the entry of import operation
|
||||
// filePath and rowBased are from ImportTask
|
||||
// if onlyValidate is true, this process only do validation, no data generated, callFlushFunc will not be called
|
||||
func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate bool) error {
|
||||
log.Info("import wrapper: filePaths", zap.Any("filePaths", filePaths))
|
||||
// if OnlyValidate is true, this process only does validation; no data is generated and flushFunc will not be called
|
||||
func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error {
|
||||
log.Info("import wrapper: begin import", zap.Any("filePaths", filePaths), zap.Any("options", options))
|
||||
// data restore function to import milvus native binlog files (for backup/restore tools)
// the backup/restore tool provides two paths for a partition: the first path is the binlog path, the second is the deltalog path
|
||||
if p.isBinlogImport(filePaths) {
|
||||
// TODO: handle the timestamp end point passed from client side, currently use math.MaxUint64
|
||||
return p.doBinlogImport(filePaths, math.MaxUint64)
|
||||
return p.doBinlogImport(filePaths, options.TsStartPoint, options.TsEndPoint)
|
||||
}
|
||||
|
||||
// normal logic for import general data files
|
||||
err := p.fileValidation(filePaths, rowBased)
|
||||
rowBased, err := p.fileValidation(filePaths)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -235,16 +325,16 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
|
|||
if rowBased {
|
||||
// parse and consume row-based files
|
||||
// for row-based files, the JSONRowConsumer will generate autoid for primary key, and split rows into segments
|
||||
// according to shard number, so the callFlushFunc will be called in the JSONRowConsumer
|
||||
// according to shard number, so the flushFunc will be called in the JSONRowConsumer
|
||||
for i := 0; i < len(filePaths); i++ {
|
||||
filePath := filePaths[i]
|
||||
_, fileType := getFileNameAndExt(filePath)
|
||||
_, fileType := GetFileNameAndExt(filePath)
|
||||
log.Info("import wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
|
||||
|
||||
if fileType == JSONFileExt {
|
||||
err = p.parseRowBasedJSON(filePath, onlyValidate)
|
||||
err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
|
||||
if err != nil {
|
||||
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
|
||||
log.Error("import wrapper: failed to parse row-based json file", zap.Any("err", err), zap.String("filePath", filePath))
|
||||
return err
|
||||
}
|
||||
} // no need to check else, since fileValidation() already does this
|
||||
|
@ -260,7 +350,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
|
|||
fieldsData := initSegmentData(p.collectionSchema)
|
||||
if fieldsData == nil {
|
||||
log.Error("import wrapper: failed to initialize FieldData list")
|
||||
return errors.New("failed to initialize FieldData list")
|
||||
return fmt.Errorf("import wrapper: failed to initialize FieldData list")
|
||||
}
|
||||
|
||||
rowCount := 0
|
||||
|
@ -271,25 +361,32 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
|
|||
return nil
|
||||
}
|
||||
|
||||
p.printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
|
||||
printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
|
||||
tr := timerecord.NewTimeRecorder("combine field data")
|
||||
defer tr.Elapse("finished")
|
||||
|
||||
for k, v := range fields {
|
||||
// ignore 0 row field
|
||||
if v.RowNum() == 0 {
|
||||
log.Warn("import wrapper: empty FieldData ignored", zap.Int64("fieldID", k))
|
||||
continue
|
||||
}
|
||||
|
||||
// ignore internal fields: RowIDField and TimeStampField
|
||||
if k == common.RowIDField || k == common.TimeStampField {
|
||||
log.Warn("import wrapper: internal fields should not be provided", zap.Int64("fieldID", k))
|
||||
continue
|
||||
}
|
||||
|
||||
// each column should be only combined once
|
||||
data, ok := fieldsData[k]
|
||||
if ok && data.RowNum() > 0 {
|
||||
return errors.New("the field " + strconv.FormatInt(k, 10) + " is duplicated")
|
||||
return fmt.Errorf("the field %d is duplicated", k)
|
||||
}
|
||||
|
||||
// check the row count. only count non-zero row fields
|
||||
if rowCount > 0 && rowCount != v.RowNum() {
|
||||
return errors.New("the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount))
|
||||
return fmt.Errorf("the field %d row count %d doesn't equal others row count: %d", k, v.RowNum(), rowCount)
|
||||
}
|
||||
rowCount = v.RowNum()
|
||||
|
||||
|
@ -303,20 +400,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
|
|||
// parse/validate/consume data
|
||||
for i := 0; i < len(filePaths); i++ {
|
||||
filePath := filePaths[i]
|
||||
_, fileType := getFileNameAndExt(filePath)
|
||||
_, fileType := GetFileNameAndExt(filePath)
|
||||
log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
|
||||
|
||||
if fileType == JSONFileExt {
|
||||
err = p.parseColumnBasedJSON(filePath, onlyValidate, combineFunc)
|
||||
if err != nil {
|
||||
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
|
||||
return err
|
||||
}
|
||||
} else if fileType == NumpyFileExt {
|
||||
err = p.parseColumnBasedNumpy(filePath, onlyValidate, combineFunc)
|
||||
if fileType == NumpyFileExt {
|
||||
err = p.parseColumnBasedNumpy(filePath, options.OnlyValidate, combineFunc)
|
||||
|
||||
if err != nil {
|
||||
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
|
||||
log.Error("import wrapper: failed to parse column-based numpy file", zap.Any("err", err), zap.String("filePath", filePath))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -327,7 +418,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
|
|||
triggerGC()
|
||||
|
||||
// split fields data into segments
|
||||
err := p.splitFieldsData(fieldsData, filePaths)
|
||||
err := p.splitFieldsData(fieldsData, SingleBlockSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -339,7 +430,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
|
|||
return p.reportPersisted()
|
||||
}
|
||||
|
||||
// reportPersisted notifies the rootcoord to mark the task state as ImportPersisted
|
||||
func (p *ImportWrapper) reportPersisted() error {
|
||||
// force close all segments
|
||||
err := p.closeAllWorkingSegments()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// report file process state
|
||||
p.importResult.State = commonpb.ImportState_ImportPersisted
|
||||
// the persisted-state report is valuable; retry more times so this task does not fail merely because of a network error
|
||||
|
@ -353,6 +451,7 @@ func (p *ImportWrapper) reportPersisted() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// isBinlogImport is to judge whether it is binlog import operation
|
||||
// For internal usage by the restore tool: https://github.com/zilliztech/milvus-backup
|
||||
// This tool exports data from a milvus service, and calls the bulkload interface to import the native data into another milvus service.
// This tool provides two paths: one is the data log path of a partition, the other is the delta log path of this partition.
|
||||
|
@ -360,16 +459,16 @@ func (p *ImportWrapper) reportPersisted() error {
|
|||
func (p *ImportWrapper) isBinlogImport(filePaths []string) bool {
|
||||
// must contain the insert log path, and the delta log path is optional
|
||||
if len(filePaths) != 1 && len(filePaths) != 2 {
|
||||
log.Info("import wrapper: paths count is not 1 or 2", zap.Int("len", len(filePaths)))
|
||||
log.Info("import wrapper: paths count is not 1 or 2, not binlog import", zap.Int("len", len(filePaths)))
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(filePaths); i++ {
|
||||
filePath := filePaths[i]
|
||||
_, fileType := getFileNameAndExt(filePath)
|
||||
_, fileType := GetFileNameAndExt(filePath)
|
||||
// contains file extension, is not a path
|
||||
if len(fileType) != 0 {
|
||||
log.Info("import wrapper: not a path", zap.String("filePath", filePath), zap.String("fileType", fileType))
|
||||
log.Info("import wrapper: not a path, not binlog import", zap.String("filePath", filePath), zap.String("fileType", fileType))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -378,12 +477,14 @@ func (p *ImportWrapper) isBinlogImport(filePaths []string) bool {
|
|||
return true
|
||||
}
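// A hedged illustration (the paths below are assumptions) of the inputs isBinlogImport
// treats as a binlog import: one or two extension-less directory paths, the first for
// insert logs and the optional second for delta logs.
func exampleIsBinlogImport(p *ImportWrapper) {
	_ = p.isBinlogImport([]string{
		"backup/bak1/insert_log/435978159196147009/435978159196147010", // hypothetical insert log path
		"backup/bak1/delta_log/435978159196147009/435978159196147010",  // hypothetical delta log path
	}) // true: two paths, neither has a file extension
	_ = p.isBinlogImport([]string{"rows.json"}) // false: the path carries a file extension
}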
|
||||
|
||||
func (p *ImportWrapper) doBinlogImport(filePaths []string, tsEndPoint uint64) error {
|
||||
// doBinlogImport is the entry of binlog import operation
|
||||
func (p *ImportWrapper) doBinlogImport(filePaths []string, tsStartPoint uint64, tsEndPoint uint64) error {
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
|
||||
return p.callFlushFunc(fields, shardID)
|
||||
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
|
||||
return p.flushFunc(fields, shardID)
|
||||
}
|
||||
parser, err := NewBinlogParser(p.collectionSchema, p.shardNum, p.segmentSize, p.chunkManager, flushFunc, tsEndPoint)
|
||||
parser, err := NewBinlogParser(p.ctx, p.collectionSchema, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc,
|
||||
tsStartPoint, tsEndPoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -396,6 +497,7 @@ func (p *ImportWrapper) doBinlogImport(filePaths []string, tsEndPoint uint64) er
|
|||
return p.reportPersisted()
|
||||
}
|
||||
|
||||
// parseRowBasedJSON is the entry of row-based json import operation
|
||||
func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) error {
|
||||
tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath)
|
||||
|
||||
|
@ -417,14 +519,16 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
|
|||
if !onlyValidate {
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
var filePaths = []string{filePath}
|
||||
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths)
|
||||
return p.callFlushFunc(fields, shardID)
|
||||
printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths)
|
||||
return p.flushFunc(fields, shardID)
|
||||
}
|
||||
consumer, err = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc)
|
||||
|
||||
consumer, err = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, flushFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
validator, err := NewJSONRowValidator(p.collectionSchema, consumer)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -444,52 +548,14 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er
|
|||
return nil
|
||||
}
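// For reference, a hedged sketch (assumed field names, not taken from this patch) of the
// row-based JSON layout this parser consumes: a single "rows" array, one JSON object per
// row keyed by field name.
const exampleRowBasedJSON = `{
	"rows": [
		{"field_int64": 1, "field_float_vector": [0.1, 0.2, 0.3, 0.4]},
		{"field_int64": 2, "field_float_vector": [0.5, 0.6, 0.7, 0.8]}
	]
}`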
|
||||
|
||||
func (p *ImportWrapper) parseColumnBasedJSON(filePath string, onlyValidate bool,
|
||||
combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
|
||||
tr := timerecord.NewTimeRecorder("json column-based parser: " + filePath)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// for minio storage, chunkManager will download file into local memory
|
||||
// for local storage, chunkManager open the file directly
|
||||
file, err := p.chunkManager.Reader(ctx, filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// parse file
|
||||
reader := bufio.NewReader(file)
|
||||
parser := NewJSONParser(p.ctx, p.collectionSchema)
|
||||
var consumer *JSONColumnConsumer
|
||||
if !onlyValidate {
|
||||
consumer, err = NewJSONColumnConsumer(p.collectionSchema, combineFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
validator, err := NewJSONColumnValidator(p.collectionSchema, consumer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tr.Elapse("parsed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseColumnBasedNumpy is the entry of column-based numpy import operation
|
||||
func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool,
|
||||
combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
|
||||
tr := timerecord.NewTimeRecorder("numpy parser: " + filePath)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
fileName, _ := getFileNameAndExt(filePath)
|
||||
fileName, _ := GetFileNameAndExt(filePath)
|
||||
|
||||
// for minio storage, chunkManager will download file into local memory
|
||||
// for local storage, chunkManager opens the file directly
|
||||
|
@ -532,6 +598,7 @@ func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool
|
|||
return nil
|
||||
}
|
||||
|
||||
// appendFunc defines the methods to append data to storage.FieldData
|
||||
func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
|
||||
switch schema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
|
@ -608,13 +675,14 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag
|
|||
}
|
||||
}
|
||||
|
||||
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, files []string) error {
|
||||
// splitFieldsData splits the in-memory data (parsed from column-based files) into blocks; each block is saved to a binlog file
|
||||
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, blockSize int64) error {
|
||||
if len(fieldsData) == 0 {
|
||||
log.Error("import wrapper: fields data is empty")
|
||||
return errors.New("import error: fields data is empty")
|
||||
return fmt.Errorf("import wrapper: fields data is empty")
|
||||
}
|
||||
|
||||
tr := timerecord.NewTimeRecorder("split field data")
|
||||
tr := timerecord.NewTimeRecorder("import wrapper: split field data")
|
||||
defer tr.Elapse("finished")
|
||||
|
||||
// check existence of each field
|
||||
|
@ -633,7 +701,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
v, ok := fieldsData[schema.GetFieldID()]
|
||||
if !ok {
|
||||
log.Error("import wrapper: field not provided", zap.String("fieldName", schema.GetName()))
|
||||
return errors.New("import error: field " + schema.GetName() + " not provided")
|
||||
return fmt.Errorf("import wrapper: field '%s' not provided", schema.GetName())
|
||||
}
|
||||
rowCounter[schema.GetName()] = v.RowNum()
|
||||
if v.RowNum() > rowCount {
|
||||
|
@ -643,27 +711,30 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
}
|
||||
if primaryKey == nil {
|
||||
log.Error("import wrapper: primary key field is not found")
|
||||
return errors.New("import error: primary key field is not found")
|
||||
return fmt.Errorf("import wrapper: primary key field is not found")
|
||||
}
|
||||
|
||||
for name, count := range rowCounter {
|
||||
if count != rowCount {
|
||||
log.Error("import wrapper: field row count is not equal to other fields row count", zap.String("fieldName", name),
|
||||
zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount))
|
||||
return errors.New("import error: field " + name + " row count " + strconv.Itoa(count) + " is not equal to other fields row count " + strconv.Itoa(rowCount))
|
||||
return fmt.Errorf("import wrapper: field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)
|
||||
}
|
||||
}
|
||||
log.Info("import wrapper: try to split a block with row count", zap.Int("rowCount", rowCount), zap.Any("rowCountOfEachField", rowCounter))
|
||||
|
||||
primaryData, ok := fieldsData[primaryKey.GetFieldID()]
|
||||
if !ok {
|
||||
log.Error("import wrapper: primary key field is not provided", zap.String("keyName", primaryKey.GetName()))
|
||||
return errors.New("import error: primary key field is not provided")
|
||||
return fmt.Errorf("import wrapper: primary key field is not provided")
|
||||
}
|
||||
|
||||
// generate auto id for primary key and rowid field
|
||||
var rowIDBegin typeutil.UniqueID
|
||||
var rowIDEnd typeutil.UniqueID
|
||||
rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount))
|
||||
rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount))
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to alloc row ID", zap.Any("err", err))
|
||||
return err
|
||||
}
|
||||
|
||||
rowIDField := fieldsData[common.RowIDField]
|
||||
rowIDFieldArr := rowIDField.(*storage.Int64FieldData)
|
||||
|
@ -672,19 +743,25 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
}
|
||||
|
||||
if primaryKey.GetAutoID() {
|
||||
log.Info("import wrapper: generating auto-id", zap.Any("rowCount", rowCount))
|
||||
log.Info("import wrapper: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin))
|
||||
|
||||
primaryDataArr := primaryData.(*storage.Int64FieldData)
|
||||
// reset the primary keys, as we know, only int64 pk can be auto-generated
|
||||
primaryDataArr := &storage.Int64FieldData{
|
||||
NumRows: []int64{int64(rowCount)},
|
||||
Data: make([]int64, 0, rowCount),
|
||||
}
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
primaryDataArr.Data = append(primaryDataArr.Data, i)
|
||||
}
|
||||
|
||||
primaryData = primaryDataArr
|
||||
fieldsData[primaryKey.GetFieldID()] = primaryData
|
||||
p.importResult.AutoIds = append(p.importResult.AutoIds, rowIDBegin, rowIDEnd)
|
||||
}
|
||||
|
||||
if primaryData.RowNum() <= 0 {
|
||||
log.Error("import wrapper: primary key not provided", zap.String("keyName", primaryKey.GetName()))
|
||||
return errors.New("import wrapper: primary key " + primaryKey.GetName() + " not provided")
|
||||
return fmt.Errorf("import wrapper: the primary key '%s' not provided", primaryKey.GetName())
|
||||
}
|
||||
|
||||
// prepare segments
|
||||
|
@ -693,7 +770,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
segmentData := initSegmentData(p.collectionSchema)
|
||||
if segmentData == nil {
|
||||
log.Error("import wrapper: failed to initialize FieldData list")
|
||||
return errors.New("failed to initialize FieldData list")
|
||||
return fmt.Errorf("import wrapper: failed to initialize FieldData list")
|
||||
}
|
||||
segmentsData = append(segmentsData, segmentData)
|
||||
}
|
||||
|
@ -705,12 +782,12 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
appendFuncErr := p.appendFunc(schema)
|
||||
if appendFuncErr == nil {
|
||||
log.Error("import wrapper: unsupported field data type")
|
||||
return errors.New("import wrapper: unsupported field data type")
|
||||
return fmt.Errorf("import wrapper: unsupported field data type")
|
||||
}
|
||||
appendFunctions[schema.GetName()] = appendFuncErr
|
||||
}
|
||||
|
||||
// split data into segments
|
||||
// split data into shards
|
||||
for i := 0; i < rowCount; i++ {
|
||||
// hash to a shard number
|
||||
var shard uint32
|
||||
|
@ -723,7 +800,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
intPK, ok := interface{}(pk).(int64)
|
||||
if !ok {
|
||||
log.Error("import wrapper: primary key field must be int64 or varchar")
|
||||
return errors.New("import error: primary key field must be int64 or varchar")
|
||||
return fmt.Errorf("import wrapper: primary key field must be int64 or varchar")
|
||||
}
|
||||
hash, _ := typeutil.Hash32Int64(intPK)
|
||||
shard = hash % uint32(p.shardNum)
|
||||
|
@ -744,18 +821,118 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
|
|||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// call flush function
|
||||
for i := 0; i < int(p.shardNum); i++ {
|
||||
segmentData := segmentsData[i]
|
||||
p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files)
|
||||
err := p.callFlushFunc(segmentData, i)
|
||||
// when the estimated size is close to blockSize, force flush
|
||||
err = tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, false)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: flush callback function failed", zap.Any("err", err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// force flush at the end
|
||||
return tryFlushBlocks(p.ctx, segmentsData, p.collectionSchema, p.flushFunc, blockSize, MaxTotalSizeInMemory, true)
|
||||
}
|
||||
|
||||
// flushFunc is the callback function for parsers to generate segments and save binlog files
|
||||
func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error {
|
||||
// if fields data is empty, do nothing
|
||||
var rowNum int
|
||||
memSize := 0
|
||||
for _, field := range fields {
|
||||
rowNum = field.RowNum()
|
||||
memSize += field.GetMemorySize()
|
||||
break
|
||||
}
|
||||
if rowNum <= 0 {
|
||||
log.Warn("import wrapper: fields data is empty", zap.Int("shardID", shardID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// if there is no segment for this shard, create a new one
|
||||
// if the segment exists and its size is about to exceed segmentSize, close it and create a new one
|
||||
var segment *WorkingSegment
|
||||
segment, ok := p.workingSegments[shardID]
|
||||
if ok {
|
||||
// the segment already exists; check its size, and if the size exceeds (or nearly exceeds) segmentSize, close the segment
|
||||
if int64(segment.memSize)+int64(memSize) >= p.segmentSize {
|
||||
err := p.closeWorkingSegment(segment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
segment = nil
|
||||
p.workingSegments[shardID] = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if segment == nil {
|
||||
// create a new segment
|
||||
segID, channelName, err := p.assignSegmentFunc(shardID)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to assign a new segment", zap.Any("error", err), zap.Int("shardID", shardID))
|
||||
return err
|
||||
}
|
||||
|
||||
segment = &WorkingSegment{
|
||||
segmentID: segID,
|
||||
shardID: shardID,
|
||||
targetChName: channelName,
|
||||
rowCount: int64(0),
|
||||
memSize: 0,
|
||||
fieldsInsert: make([]*datapb.FieldBinlog, 0),
|
||||
fieldsStats: make([]*datapb.FieldBinlog, 0),
|
||||
}
|
||||
p.workingSegments[shardID] = segment
|
||||
}
|
||||
|
||||
// save binlogs
|
||||
fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to save binlogs", zap.Any("error", err), zap.Int("shardID", shardID),
|
||||
zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
|
||||
return err
|
||||
}
|
||||
|
||||
segment.fieldsInsert = append(segment.fieldsInsert, fieldsInsert...)
|
||||
segment.fieldsStats = append(segment.fieldsStats, fieldsStats...)
|
||||
segment.rowCount += int64(rowNum)
|
||||
segment.memSize += memSize
|
||||
|
||||
return nil
|
||||
}
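// A small illustrative helper (an assumption that simply mirrors the check above, not a
// function added by this patch) showing the rollover rule flushFunc applies: a working
// segment is sealed once its accumulated binlog size plus the incoming block would reach
// segmentSize, and a fresh segment is then assigned for the shard.
func exampleShouldSealSegment(accumulatedMemSize, incomingMemSize, segmentSize int64) bool {
	// e.g. 480MB accumulated + 40MB incoming >= 512MB limit -> seal and assign a new segment
	return accumulatedMemSize+incomingMemSize >= segmentSize
}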
|
||||
|
||||
// closeWorkingSegment marks a segment to be sealed
|
||||
func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error {
|
||||
log.Info("import wrapper: adding segment to the correct DataNode flow graph and saving binlog paths",
|
||||
zap.Int("shardID", segment.shardID),
|
||||
zap.Int64("segmentID", segment.segmentID),
|
||||
zap.String("targetChannel", segment.targetChName),
|
||||
zap.Int64("rowCount", segment.rowCount),
|
||||
zap.Int("insertLogCount", len(segment.fieldsInsert)),
|
||||
zap.Int("statsLogCount", len(segment.fieldsStats)))
|
||||
|
||||
err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
|
||||
if err != nil {
|
||||
log.Error("import wrapper: failed to save segment",
|
||||
zap.Any("error", err),
|
||||
zap.Int("shardID", segment.shardID),
|
||||
zap.Int64("segmentID", segment.segmentID),
|
||||
zap.String("targetChannel", segment.targetChName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeAllWorkingSegments marks all segments to be sealed at the end of the import operation
|
||||
func (p *ImportWrapper) closeAllWorkingSegments() error {
|
||||
for _, segment := range p.workingSegments {
|
||||
err := p.closeWorkingSegment(segment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
p.workingSegments = make(map[int]*WorkingSegment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
|
@ -19,7 +19,7 @@ package importutil
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"reflect"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -31,33 +31,12 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
// interface to process rows data
|
||||
// JSONRowHandler is the interface to process rows data
|
||||
type JSONRowHandler interface {
|
||||
Handle(rows []map[storage.FieldID]interface{}) error
|
||||
}
|
||||
|
||||
// interface to process column data
|
||||
type JSONColumnHandler interface {
|
||||
Handle(columns map[storage.FieldID][]interface{}) error
|
||||
}
|
||||
|
||||
// method to get the dimension of a vector field
|
||||
func getFieldDimension(schema *schemapb.FieldSchema) (int, error) {
|
||||
for _, kvPair := range schema.GetTypeParams() {
|
||||
key, value := kvPair.GetKey(), kvPair.GetValue()
|
||||
if key == "dim" {
|
||||
dim, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, errors.New("vector dimension is invalid")
|
||||
}
|
||||
return dim, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, errors.New("vector dimension is not defined")
|
||||
}
|
||||
|
||||
// field value validator
|
||||
// Validator is field value validator
|
||||
type Validator struct {
|
||||
validateFunc func(obj interface{}) error // validate data type function
|
||||
convertFunc func(obj interface{}, field storage.FieldData) error // convert data function
|
||||
|
@ -68,210 +47,7 @@ type Validator struct {
|
|||
fieldName string // field name
|
||||
}
|
||||
|
||||
// method to construct validator functions
|
||||
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error {
|
||||
if collectionSchema == nil {
|
||||
return errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
// the json decoder parses all numeric values into float64
|
||||
numericValidator := func(obj interface{}) error {
|
||||
switch obj.(type) {
|
||||
case float64:
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := "illegal numeric value " + s
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for i := 0; i < len(collectionSchema.Fields); i++ {
|
||||
schema := collectionSchema.Fields[i]
|
||||
|
||||
validators[schema.GetFieldID()] = &Validator{}
|
||||
validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey()
|
||||
validators[schema.GetFieldID()].autoID = schema.GetAutoID()
|
||||
validators[schema.GetFieldID()].fieldName = schema.GetName()
|
||||
validators[schema.GetFieldID()].isString = false
|
||||
|
||||
switch schema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch obj.(type) {
|
||||
case bool:
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := "illegal value " + s + " for bool type field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
}
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := obj.(bool)
|
||||
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
|
||||
field.(*storage.BoolFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := float32(obj.(float64))
|
||||
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
|
||||
field.(*storage.FloatFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := obj.(float64)
|
||||
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
|
||||
field.(*storage.DoubleFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int8(obj.(float64))
|
||||
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
|
||||
field.(*storage.Int8FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int16(obj.(float64))
|
||||
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
|
||||
field.(*storage.Int16FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int32(obj.(float64))
|
||||
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
|
||||
field.(*storage.Int32FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
validators[schema.GetFieldID()].validateFunc = numericValidator
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := int64(obj.(float64))
|
||||
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
|
||||
field.(*storage.Int64FieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
dim, err := getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validators[schema.GetFieldID()].dimension = dim
|
||||
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch vt := obj.(type) {
|
||||
case []interface{}:
|
||||
if len(vt)*8 != dim {
|
||||
msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
for i := 0; i < len(vt); i++ {
|
||||
if e := numericValidator(vt[i]); e != nil {
|
||||
msg := e.Error() + " for binary vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
t := int(vt[i].(float64))
|
||||
if t > 255 || t < 0 {
|
||||
msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := s + " is not an array for binary vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
arr := obj.([]interface{})
|
||||
for i := 0; i < len(arr); i++ {
|
||||
value := byte(arr[i].(float64))
|
||||
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value)
|
||||
}
|
||||
|
||||
field.(*storage.BinaryVectorFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
dim, err := getFieldDimension(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validators[schema.GetFieldID()].dimension = dim
|
||||
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch vt := obj.(type) {
|
||||
case []interface{}:
|
||||
if len(vt) != dim {
|
||||
msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
for i := 0; i < len(vt); i++ {
|
||||
if e := numericValidator(vt[i]); e != nil {
|
||||
msg := e.Error() + " for float vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := s + " is not an array for float vector field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
arr := obj.([]interface{})
|
||||
for i := 0; i < len(arr); i++ {
|
||||
value := float32(arr[i].(float64))
|
||||
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value)
|
||||
}
|
||||
field.(*storage.FloatVectorFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
validators[schema.GetFieldID()].isString = true
|
||||
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
|
||||
switch obj.(type) {
|
||||
case string:
|
||||
return nil
|
||||
default:
|
||||
s := fmt.Sprintf("%v", obj)
|
||||
msg := s + " is not a string for string type field " + schema.GetName()
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
|
||||
value := obj.(string)
|
||||
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
|
||||
field.(*storage.StringFieldData).NumRows[0]++
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
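// A standalone sketch of why numericValidator above accepts float64 for every numeric
// field: encoding/json decodes any JSON number into float64 when the target is
// interface{}, so "100" arrives here as float64(100) and the convertFunc narrows it to
// int8/int16/int32/int64/float32 afterwards. The field name is illustrative, and this
// snippet assumes "encoding/json" and "fmt" are imported.
func exampleJSONNumberDecoding() {
	var row map[string]interface{}
	_ = json.Unmarshal([]byte(`{"field_int64": 100}`), &row)
	_, isFloat64 := row["field_int64"].(float64)
	fmt.Println(isFloat64) // prints: true
}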
|
||||
|
||||
// row-based json format validator class
|
||||
// JSONRowValidator is row-based json format validator class
|
||||
type JSONRowValidator struct {
|
||||
downstream JSONRowHandler // downstream processor, typically a JSONRowConsumer
|
||||
validators map[storage.FieldID]*Validator // validators for each field
|
||||
|
@ -286,7 +62,7 @@ func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream
|
|||
}
|
||||
err := initValidators(collectionSchema, v.validators)
|
||||
if err != nil {
|
||||
log.Error("JSON column validator: failed to initialize json row-based validator", zap.Error(err))
|
||||
log.Error("JSON row validator: failed to initialize json row-based validator", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
|
@ -298,13 +74,14 @@ func (v *JSONRowValidator) ValidateCount() int64 {
|
|||
|
||||
func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error {
|
||||
if v == nil || v.validators == nil || len(v.validators) == 0 {
|
||||
log.Error("JSON row validator is not initialized")
|
||||
return errors.New("JSON row validator is not initialized")
|
||||
}
|
||||
|
||||
// parse completed
|
||||
if rows == nil {
|
||||
log.Info("JSON row validation finished")
|
||||
if v.downstream != nil {
|
||||
if v.downstream != nil && !reflect.ValueOf(v.downstream).IsNil() {
|
||||
return v.downstream.Handle(rows)
|
||||
}
|
||||
return nil
|
||||
|
@ -314,107 +91,42 @@ func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error
|
|||
row := rows[i]
|
||||
|
||||
for id, validator := range v.validators {
|
||||
value, ok := row[id]
|
||||
if validator.primaryKey && validator.autoID {
|
||||
// auto-generated primary key, ignore
|
||||
// the primary key is auto-generated; if the user provided it, return an error
|
||||
if ok {
|
||||
log.Error("JSON row validator: primary key is auto-generated, no need to provide PK value at the row",
|
||||
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)))
|
||||
return fmt.Errorf("the primary key '%s' is auto-generated, no need to provide PK value at the row %d",
|
||||
validator.fieldName, v.rowCounter+int64(i))
|
||||
}
|
||||
continue
|
||||
}
|
||||
value, ok := row[id]
|
||||
if !ok {
|
||||
return errors.New("JSON row validator: field " + validator.fieldName + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
|
||||
log.Error("JSON row validator: field missed at the row",
|
||||
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)))
|
||||
return fmt.Errorf("the field '%s' missed at the row %d", validator.fieldName, v.rowCounter+int64(i))
|
||||
}
|
||||
|
||||
if err := validator.validateFunc(value); err != nil {
|
||||
return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
|
||||
log.Error("JSON row validator: invalid value at the row", zap.String("fieldName", validator.fieldName),
|
||||
zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Any("value", value), zap.Error(err))
|
||||
return fmt.Errorf("the field '%s' value at the row %d is invalid, error: %s",
|
||||
validator.fieldName, v.rowCounter+int64(i), err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v.rowCounter += int64(len(rows))
|
||||
|
||||
if v.downstream != nil {
|
||||
if v.downstream != nil && !reflect.ValueOf(v.downstream).IsNil() {
|
||||
return v.downstream.Handle(rows)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// column-based json format validator class
|
||||
type JSONColumnValidator struct {
|
||||
downstream JSONColumnHandler // downstream processor, typically a JSONColumnConsumer
|
||||
validators map[storage.FieldID]*Validator // validators for each field
|
||||
rowCounter map[string]int64 // row count of each field
|
||||
}
|
||||
|
||||
func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) (*JSONColumnValidator, error) {
|
||||
v := &JSONColumnValidator{
|
||||
validators: make(map[storage.FieldID]*Validator),
|
||||
downstream: downstream,
|
||||
rowCounter: make(map[string]int64),
|
||||
}
|
||||
err := initValidators(schema, v.validators)
|
||||
if err != nil {
|
||||
log.Error("JSON column validator: fail to initialize json column-based validator", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (v *JSONColumnValidator) ValidateCount() map[string]int64 {
|
||||
return v.rowCounter
|
||||
}
|
||||
|
||||
func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error {
|
||||
if v == nil || v.validators == nil || len(v.validators) == 0 {
|
||||
return errors.New("JSON column validator is not initialized")
|
||||
}
|
||||
|
||||
// parse completed
|
||||
if columns == nil {
|
||||
// compare the row count of columns, should be equal
|
||||
rowCount := int64(-1)
|
||||
for k, counter := range v.rowCounter {
|
||||
if rowCount == -1 {
|
||||
rowCount = counter
|
||||
} else if rowCount != counter {
|
||||
return errors.New("JSON column validator: the field " + k + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields row count" + strconv.Itoa(int(rowCount)))
|
||||
}
|
||||
}
|
||||
|
||||
// let the downstream know parse is completed
|
||||
log.Info("JSON column validation finished")
|
||||
if v.downstream != nil {
|
||||
return v.downstream.Handle(nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for id, values := range columns {
|
||||
validator, ok := v.validators[id]
|
||||
name := validator.fieldName
|
||||
if !ok {
|
||||
// not a valid field name, skip without parsing
|
||||
break
|
||||
}
|
||||
|
||||
for i := 0; i < len(values); i++ {
|
||||
if err := validator.validateFunc(values[i]); err != nil {
|
||||
return errors.New("JSON column validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter[name]+int64(i), 10))
|
||||
}
|
||||
}
|
||||
|
||||
v.rowCounter[name] += int64(len(values))
|
||||
}
|
||||
|
||||
if v.downstream != nil {
|
||||
return v.downstream.Handle(columns)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
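The column-based validator above keeps one row counter per field and, when it receives the final nil batch, requires all counters to match. A small stand-alone sketch of that consistency check (field names and counts are illustrative):

```go
package main

import "fmt"

// checkColumnRowCounts returns an error if the columns do not all contain the
// same number of rows, mirroring the check done when parsing completes.
func checkColumnRowCounts(rowCounter map[string]int64) error {
	rowCount := int64(-1)
	for name, counter := range rowCounter {
		if rowCount == -1 {
			rowCount = counter
		} else if rowCount != counter {
			return fmt.Errorf("the field %s row count %d is not equal to other fields row count %d",
				name, counter, rowCount)
		}
	}
	return nil
}

func main() {
	// consistent columns
	fmt.Println(checkColumnRowCounts(map[string]int64{"field_a": 5, "field_b": 5})) // <nil>
	// inconsistent columns
	fmt.Println(checkColumnRowCounts(map[string]int64{"field_a": 5, "field_b": 3})) // error
}
```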
|
||||
type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
|
||||
|
||||
// row-based json format consumer class
|
||||
// JSONRowConsumer is the row-based JSON format consumer class
|
||||
type JSONRowConsumer struct {
|
||||
collectionSchema *schemapb.CollectionSchema // collection schema
|
||||
rowIDAllocator *allocator.IDAllocator // autoid allocator
|
||||
|
@ -422,89 +134,14 @@ type JSONRowConsumer struct {
|
|||
rowCounter int64 // how many rows have been consumed
|
||||
shardNum int32 // sharding number of the collection
|
||||
segmentsData []map[storage.FieldID]storage.FieldData // in-memory segments data
|
||||
segmentSize int64 // maximum size of a segment(unit:byte)
|
||||
blockSize int64 // maximum size of a read block(unit:byte)
|
||||
primaryKey storage.FieldID // name of primary key
|
||||
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
|
||||
|
||||
callFlushFunc ImportFlushFunc // call back function to flush segment
|
||||
}
|
||||
|
||||
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
|
||||
segmentData := make(map[storage.FieldID]storage.FieldData)
|
||||
// rowID field is a hidden field with fieldID=0; it is always auto-generated by the IDAllocator
|
||||
// if the primary key is int64 and autoID=true, the primary key field is equal to the rowID field
|
||||
segmentData[common.RowIDField] = &storage.Int64FieldData{
|
||||
Data: make([]int64, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
|
||||
for i := 0; i < len(collectionSchema.Fields); i++ {
|
||||
schema := collectionSchema.Fields[i]
|
||||
switch schema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
|
||||
Data: make([]bool, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
|
||||
Data: make([]float32, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
|
||||
Data: make([]float64, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
|
||||
Data: make([]int8, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
|
||||
Data: make([]int16, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
|
||||
Data: make([]int32, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
|
||||
Data: make([]int64, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
dim, _ := getFieldDimension(schema)
|
||||
segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
|
||||
Data: make([]byte, 0),
|
||||
NumRows: []int64{0},
|
||||
Dim: dim,
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
dim, _ := getFieldDimension(schema)
|
||||
segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
|
||||
Data: make([]float32, 0),
|
||||
NumRows: []int64{0},
|
||||
Dim: dim,
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
segmentData[schema.GetFieldID()] = &storage.StringFieldData{
|
||||
Data: make([]string, 0),
|
||||
NumRows: []int64{0},
|
||||
}
|
||||
default:
|
||||
log.Error("JSON row consumer error: unsupported data type", zap.Int("DataType", int(schema.DataType)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return segmentData
|
||||
}
|
||||
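initSegmentData builds one empty in-memory buffer per field, keyed by field ID, plus the hidden rowID field with ID 0. The sketch below illustrates the same layout with a trimmed-down fieldData interface and only two field types; the real code uses the storage.*FieldData implementations listed above:

```go
package main

import "fmt"

// fieldData is a trimmed-down stand-in for storage.FieldData.
type fieldData interface {
	RowNum() int
}

type int64FieldData struct{ Data []int64 }

func (f *int64FieldData) RowNum() int { return len(f.Data) }

type floatVectorFieldData struct {
	Data []float32
	Dim  int
}

func (f *floatVectorFieldData) RowNum() int { return len(f.Data) / f.Dim }

const rowIDField = int64(0) // hidden rowID field, always present

// initSegmentData allocates an empty buffer for every field of a toy schema.
func initSegmentData(vectorDim int) map[int64]fieldData {
	segmentData := make(map[int64]fieldData)
	segmentData[rowIDField] = &int64FieldData{}
	segmentData[101] = &int64FieldData{}                     // e.g. an int64 primary key
	segmentData[102] = &floatVectorFieldData{Dim: vectorDim} // e.g. a float vector field
	return segmentData
}

func main() {
	data := initSegmentData(4)
	fmt.Println(len(data), data[rowIDField].RowNum()) // 3 0
}
```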
|
||||
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64,
|
||||
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, blockSize int64,
|
||||
flushFunc ImportFlushFunc) (*JSONRowConsumer, error) {
|
||||
if collectionSchema == nil {
|
||||
log.Error("JSON row consumer: collection schema is nil")
|
||||
|
@ -516,7 +153,7 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
|
|||
rowIDAllocator: idAlloc,
|
||||
validators: make(map[storage.FieldID]*Validator),
|
||||
shardNum: shardNum,
|
||||
segmentSize: segmentSize,
|
||||
blockSize: blockSize,
|
||||
rowCounter: 0,
|
||||
primaryKey: -1,
|
||||
autoIDRange: make([]int64, 0),
|
||||
|
@ -526,15 +163,15 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
|
|||
err := initValidators(collectionSchema, v.validators)
|
||||
if err != nil {
|
||||
log.Error("JSON row consumer: fail to initialize json row-based consumer", zap.Error(err))
|
||||
return nil, errors.New("fail to initialize json row-based consumer")
|
||||
return nil, fmt.Errorf("fail to initialize json row-based consumer: %v", err)
|
||||
}
|
||||
|
||||
v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
|
||||
for i := 0; i < int(shardNum); i++ {
|
||||
segmentData := initSegmentData(collectionSchema)
|
||||
if segmentData == nil {
|
||||
log.Error("JSON row consumer: fail to initialize in-memory segment data", zap.Int32("shardNum", shardNum))
|
||||
return nil, errors.New("fail to initialize in-memory segment data")
|
||||
log.Error("JSON row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i))
|
||||
return nil, fmt.Errorf("fail to initialize in-memory segment data for shardID %d", i)
|
||||
}
|
||||
v.segmentsData = append(v.segmentsData, segmentData)
|
||||
}
|
||||
|
@ -554,7 +191,7 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
|
|||
// primary key is autoID, so an ID generator is required
|
||||
if v.validators[v.primaryKey].autoID && idAlloc == nil {
|
||||
log.Error("JSON row consumer: ID allocator is nil")
|
||||
return nil, errors.New(" ID allocator is nil")
|
||||
return nil, errors.New("ID allocator is nil")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
|
@ -571,8 +208,17 @@ func (v *JSONRowConsumer) flush(force bool) error {
|
|||
segmentData := v.segmentsData[i]
|
||||
rowNum := segmentData[v.primaryKey].RowNum()
|
||||
if rowNum > 0 {
|
||||
log.Info("JSON row consumer: force flush segment", zap.Int("rows", rowNum))
|
||||
v.callFlushFunc(segmentData, i)
|
||||
log.Info("JSON row consumer: force flush binlog", zap.Int("rows", rowNum))
|
||||
err := v.callFlushFunc(segmentData, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.segmentsData[i] = initSegmentData(v.collectionSchema)
|
||||
if v.segmentsData[i] == nil {
|
||||
log.Error("JSON row consumer: fail to initialize in-memory segment data")
|
||||
return errors.New("fail to initialize in-memory segment data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -587,9 +233,13 @@ func (v *JSONRowConsumer) flush(force bool) error {
|
|||
for _, field := range segmentData {
|
||||
memSize += field.GetMemorySize()
|
||||
}
|
||||
if memSize >= int(v.segmentSize) && rowNum > 0 {
|
||||
log.Info("JSON row consumer: flush fulled segment", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum))
|
||||
v.callFlushFunc(segmentData, i)
|
||||
if memSize >= int(v.blockSize) && rowNum > 0 {
|
||||
log.Info("JSON row consumer: flush fulled binlog", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum))
|
||||
err := v.callFlushFunc(segmentData, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.segmentsData[i] = initSegmentData(v.collectionSchema)
|
||||
if v.segmentsData[i] == nil {
|
||||
log.Error("JSON row consumer: fail to initialize in-memory segment data")
|
||||
|
@ -603,6 +253,7 @@ func (v *JSONRowConsumer) flush(force bool) error {
|
|||
|
||||
func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
||||
if v == nil || v.validators == nil || len(v.validators) == 0 {
|
||||
log.Error("JSON row consumer is not initialized")
|
||||
return errors.New("JSON row consumer is not initialized")
|
||||
}
|
||||
|
||||
|
@ -615,10 +266,11 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
|||
|
||||
err := v.flush(false)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Error("JSON row consumer: try flush data but failed", zap.Error(err))
|
||||
return fmt.Errorf("try flush data but failed: %s", err.Error())
|
||||
}
|
||||
|
||||
// prepare autoid
|
||||
// prepare auto-generated ids; no matter whether the pk is int64 or varchar, we always allocate row ids because the hidden RowIDField requires them
|
||||
primaryValidator := v.validators[v.primaryKey]
|
||||
var rowIDBegin typeutil.UniqueID
|
||||
var rowIDEnd typeutil.UniqueID
|
||||
|
@ -626,12 +278,19 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
|||
var err error
|
||||
rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
|
||||
if err != nil {
|
||||
return errors.New("JSON row consumer: " + err.Error())
|
||||
log.Error("JSON row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err))
|
||||
return fmt.Errorf("failed to generate %d primary keys: %s", len(rows), err.Error())
|
||||
}
|
||||
if rowIDEnd-rowIDBegin != int64(len(rows)) {
|
||||
return errors.New("JSON row consumer: failed to allocate ID for " + strconv.Itoa(len(rows)) + " rows")
|
||||
log.Error("JSON row consumer: try to generate primary keys but allocated ids are not enough",
|
||||
zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin))
|
||||
return fmt.Errorf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin)
|
||||
}
|
||||
log.Info("JSON row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd))
|
||||
if !primaryValidator.isString {
|
||||
// if pk is varchar, no need to record auto-generated row ids
|
||||
v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
|
||||
}
|
||||
v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
|
||||
}
|
||||
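To make the bookkeeping above concrete: the allocator returns a begin/end pair covering len(rows) IDs, each row i is assigned rowIDBegin+i, and for a non-varchar primary key the pair is appended to autoIDRange. A self-contained sketch with a fake allocator standing in for the real IDAllocator:

```go
package main

import "fmt"

// fakeAlloc stands in for the real IDAllocator: it returns a contiguous
// block of `count` IDs as a begin/end pair.
type fakeAlloc struct{ next int64 }

func (a *fakeAlloc) Alloc(count uint32) (int64, int64) {
	begin := a.next
	a.next += int64(count)
	return begin, a.next
}

func main() {
	alloc := &fakeAlloc{next: 1}
	rows := make([]map[string]interface{}, 5)

	rowIDBegin, rowIDEnd := alloc.Alloc(uint32(len(rows)))
	autoIDRange := []int64{}
	autoIDRange = append(autoIDRange, rowIDBegin, rowIDEnd) // record the generated range

	// assign rowIDBegin+i to each row, as the consumer does for the hidden rowID field
	for i := range rows {
		rows[i] = map[string]interface{}{"rowID": rowIDBegin + int64(i)}
	}
	fmt.Println(rowIDBegin, rowIDEnd, autoIDRange, rows[4]["rowID"]) // 1 6 [1 6] 5
}
```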
|
||||
// consume rows
|
||||
|
@ -642,7 +301,8 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
|||
var shard uint32
|
||||
if primaryValidator.isString {
|
||||
if primaryValidator.autoID {
|
||||
return errors.New("JSON row consumer: string type primary key cannot be auto-generated")
|
||||
log.Error("JSON row consumer: string type primary key cannot be auto-generated")
|
||||
return errors.New("string type primary key cannot be auto-generated")
|
||||
}
|
||||
|
||||
value := row[v.primaryKey]
|
||||
|
@ -662,7 +322,13 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
|||
pk = int64(value.(float64))
|
||||
}
|
||||
|
||||
hash, _ := typeutil.Hash32Int64(pk)
|
||||
hash, err := typeutil.Hash32Int64(pk)
|
||||
if err != nil {
|
||||
log.Error("JSON row consumer: failed to hash primary key at the row",
|
||||
zap.Int64("key", pk), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err))
|
||||
return fmt.Errorf("failed to hash primary key %d at the row %d, error: %s", pk, v.rowCounter+int64(i), err.Error())
|
||||
}
|
||||
|
||||
shard = hash % uint32(v.shardNum)
|
||||
pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData)
|
||||
pkArray.Data = append(pkArray.Data, pk)
|
||||
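Shard selection for an int64 primary key is hash-then-modulo: the key is hashed to a uint32 and the shard index is hash % shardNum. The sketch below uses the standard library's FNV-1a hash purely to illustrate the distribution step; the hash function is an assumption for the example, not the one Milvus' typeutil package uses:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// shardOfInt64 maps an int64 primary key to a shard in [0, shardNum).
// FNV-1a is used here only for illustration.
func shardOfInt64(pk int64, shardNum uint32) uint32 {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], uint64(pk))
	h := fnv.New32a()
	h.Write(buf[:])
	return h.Sum32() % shardNum
}

func main() {
	counts := make([]int, 4)
	for pk := int64(0); pk < 1000; pk++ {
		counts[shardOfInt64(pk, 4)]++
	}
	fmt.Println(counts) // roughly even distribution across 4 shards
}
```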
|
@ -681,7 +347,10 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
|||
}
|
||||
value := row[name]
|
||||
if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil {
|
||||
return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
|
||||
log.Error("JSON row consumer: failed to convert value for field at the row",
|
||||
zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err))
|
||||
return fmt.Errorf("failed to convert value for field %s at the row %d, error: %s",
|
||||
validator.fieldName, v.rowCounter+int64(i), err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -690,115 +359,3 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ColumnFlushFunc func(fields map[storage.FieldID]storage.FieldData) error
|
||||
|
||||
// column-based json format consumer class
|
||||
type JSONColumnConsumer struct {
|
||||
collectionSchema *schemapb.CollectionSchema // collection schema
|
||||
validators map[storage.FieldID]*Validator // validators for each field
|
||||
fieldsData map[storage.FieldID]storage.FieldData // in-memory fields data
|
||||
primaryKey storage.FieldID // name of primary key
|
||||
|
||||
callFlushFunc ColumnFlushFunc // call back function to flush segment
|
||||
}
|
||||
|
||||
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, flushFunc ColumnFlushFunc) (*JSONColumnConsumer, error) {
|
||||
if collectionSchema == nil {
|
||||
log.Error("JSON column consumer: collection schema is nil")
|
||||
return nil, errors.New("collection schema is nil")
|
||||
}
|
||||
|
||||
v := &JSONColumnConsumer{
|
||||
collectionSchema: collectionSchema,
|
||||
validators: make(map[storage.FieldID]*Validator),
|
||||
callFlushFunc: flushFunc,
|
||||
}
|
||||
err := initValidators(collectionSchema, v.validators)
|
||||
if err != nil {
|
||||
log.Error("JSON column consumer: fail to initialize validator", zap.Error(err))
|
||||
return nil, errors.New("fail to initialize validator")
|
||||
}
|
||||
v.fieldsData = initSegmentData(collectionSchema)
|
||||
if v.fieldsData == nil {
|
||||
log.Error("JSON column consumer: fail to initialize in-memory segment data")
|
||||
return nil, errors.New("fail to initialize in-memory segment data")
|
||||
}
|
||||
|
||||
for i := 0; i < len(collectionSchema.Fields); i++ {
|
||||
schema := collectionSchema.Fields[i]
|
||||
if schema.GetIsPrimaryKey() {
|
||||
v.primaryKey = schema.GetFieldID()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (v *JSONColumnConsumer) flush() error {
|
||||
// check the row counts, they should be equal
|
||||
rowCount := 0
|
||||
for id, field := range v.fieldsData {
|
||||
// skip the autoid field
|
||||
if id == v.primaryKey && v.validators[v.primaryKey].autoID {
|
||||
continue
|
||||
}
|
||||
cnt := field.RowNum()
|
||||
// skip zero-row fields, since a data file may import only one column (several data files may be imported)
|
||||
if cnt == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// only check non-zero row fields
|
||||
if rowCount == 0 {
|
||||
rowCount = cnt
|
||||
} else if rowCount != cnt {
|
||||
return errors.New("JSON column consumer: " + strconv.FormatInt(id, 10) + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
|
||||
}
|
||||
}
|
||||
|
||||
if rowCount == 0 {
|
||||
return errors.New("JSON column consumer: row count is 0")
|
||||
}
|
||||
log.Info("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount))
|
||||
|
||||
// output the fields data and let the caller split them into segments
|
||||
return v.callFlushFunc(v.fieldsData)
|
||||
}
|
||||
|
||||
func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error {
|
||||
if v == nil || v.validators == nil || len(v.validators) == 0 {
|
||||
return errors.New("JSON column consumer is not initialized")
|
||||
}
|
||||
|
||||
// flush at the end
|
||||
if columns == nil {
|
||||
err := v.flush()
|
||||
log.Info("JSON column consumer finished")
|
||||
return err
|
||||
}
|
||||
|
||||
// consume columns data
|
||||
for id, values := range columns {
|
||||
validator, ok := v.validators[id]
|
||||
if !ok {
|
||||
// not a valid field id
|
||||
break
|
||||
}
|
||||
|
||||
if validator.primaryKey && validator.autoID {
|
||||
// the autoID primary key does not need to be provided
|
||||
break
|
||||
}
|
||||
|
||||
// convert and consume data
|
||||
for i := 0; i < len(values); i++ {
|
||||
if err := validator.convertFunc(values[i], v.fieldsData[id]); err != nil {
|
||||
return errors.New("JSON column consumer: " + err.Error() + " of field " + strconv.FormatInt(id, 10))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -18,6 +18,7 @@ package importutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -26,15 +27,15 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
)
|
||||
|
||||
type mockIDAllocator struct {
|
||||
allocErr error
|
||||
}
|
||||
|
||||
func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) {
|
||||
func (a *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) {
|
||||
return &rootcoordpb.AllocIDResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
@ -42,11 +43,13 @@ func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocI
|
|||
},
|
||||
ID: int64(1),
|
||||
Count: req.Count,
|
||||
}, nil
|
||||
}, a.allocErr
|
||||
}
|
||||
|
||||
func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator {
|
||||
mockIDAllocator := &mockIDAllocator{}
|
||||
func newIDAllocator(ctx context.Context, t *testing.T, allocErr error) *allocator.IDAllocator {
|
||||
mockIDAllocator := &mockIDAllocator{
|
||||
allocErr: allocErr,
|
||||
}
|
||||
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, mockIDAllocator, int64(1))
|
||||
assert.Nil(t, err)
|
||||
|
@ -56,136 +59,14 @@ func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator {
|
|||
return idAllocator
|
||||
}
|
||||
|
||||
func Test_GetFieldDimension(t *testing.T) {
|
||||
schema := &schemapb.FieldSchema{
|
||||
FieldID: 111,
|
||||
Name: "field_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
}
|
||||
func Test_NewJSONRowValidator(t *testing.T) {
|
||||
validator, err := NewJSONRowValidator(nil, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, validator)
|
||||
|
||||
dim, err := getFieldDimension(schema)
|
||||
validator, err = NewJSONRowValidator(sampleSchema(), nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, dim)
|
||||
|
||||
schema.TypeParams = []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "abc"},
|
||||
}
|
||||
dim, err = getFieldDimension(schema)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, 0, dim)
|
||||
|
||||
schema.TypeParams = []*commonpb.KeyValuePair{}
|
||||
dim, err = getFieldDimension(schema)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, 0, dim)
|
||||
}
|
||||
|
||||
func Test_InitValidators(t *testing.T) {
|
||||
validators := make(map[storage.FieldID]*Validator)
|
||||
err := initValidators(nil, validators)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
schema := sampleSchema()
|
||||
// success case
|
||||
err = initValidators(schema, validators)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(schema.Fields), len(validators))
|
||||
name2ID := make(map[string]storage.FieldID)
|
||||
for _, field := range schema.Fields {
|
||||
name2ID[field.GetName()] = field.GetFieldID()
|
||||
}
|
||||
|
||||
checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
|
||||
id := name2ID[funcName]
|
||||
v, ok := validators[id]
|
||||
assert.True(t, ok)
|
||||
err = v.validateFunc(validVal)
|
||||
assert.Nil(t, err)
|
||||
err = v.validateFunc(invalidVal)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
// validate functions
|
||||
var validVal interface{} = true
|
||||
var invalidVal interface{} = "aa"
|
||||
checkFunc("field_bool", validVal, invalidVal)
|
||||
|
||||
validVal = float64(100)
|
||||
invalidVal = "aa"
|
||||
checkFunc("field_int8", validVal, invalidVal)
|
||||
checkFunc("field_int16", validVal, invalidVal)
|
||||
checkFunc("field_int32", validVal, invalidVal)
|
||||
checkFunc("field_int64", validVal, invalidVal)
|
||||
checkFunc("field_float", validVal, invalidVal)
|
||||
checkFunc("field_double", validVal, invalidVal)
|
||||
|
||||
validVal = "aa"
|
||||
invalidVal = 100
|
||||
checkFunc("field_string", validVal, invalidVal)
|
||||
|
||||
validVal = []interface{}{float64(100), float64(101)}
|
||||
invalidVal = "aa"
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(100)}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(100), float64(101), float64(102)}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{true, true}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(255), float64(-1)}
|
||||
checkFunc("field_binary_vector", validVal, invalidVal)
|
||||
|
||||
validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)}
|
||||
invalidVal = true
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(1), float64(2), float64(3)}
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)}
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
invalidVal = []interface{}{"a", "b", "c", "d"}
|
||||
checkFunc("field_float_vector", validVal, invalidVal)
|
||||
|
||||
// error cases
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: make([]*schemapb.FieldSchema, 0),
|
||||
}
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 111,
|
||||
Name: "field_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "aa"},
|
||||
},
|
||||
})
|
||||
|
||||
validators = make(map[storage.FieldID]*Validator)
|
||||
err = initValidators(schema, validators)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
schema.Fields = make([]*schemapb.FieldSchema, 0)
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 110,
|
||||
Name: "field_binary_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "aa"},
|
||||
},
|
||||
})
|
||||
|
||||
err = initValidators(schema, validators)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_JSONRowValidator(t *testing.T) {
|
||||
|
@ -209,15 +90,15 @@ func Test_JSONRowValidator(t *testing.T) {
|
|||
assert.NotNil(t, err)
|
||||
assert.Equal(t, int64(0), validator.ValidateCount())
|
||||
|
||||
// // missed some fields
|
||||
// reader = strings.NewReader(`{
|
||||
// "rows":[
|
||||
// {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
|
||||
// {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
|
||||
// ]
|
||||
// }`)
|
||||
// err = parser.ParseRows(reader, validator)
|
||||
// assert.NotNil(t, err)
|
||||
// missed some fields
|
||||
reader = strings.NewReader(`{
|
||||
"rows":[
|
||||
{"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
|
||||
{"field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
|
||||
]
|
||||
}`)
|
||||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// invalid dimension
|
||||
reader = strings.NewReader(`{
|
||||
|
@ -241,92 +122,90 @@ func Test_JSONRowValidator(t *testing.T) {
|
|||
validator.validators = nil
|
||||
err = validator.Handle(nil)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_JSONColumnValidator(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
schema := sampleSchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
|
||||
// 0 row case
|
||||
reader := strings.NewReader(`{
|
||||
"field_bool": [],
|
||||
"field_int8": [],
|
||||
"field_int16": [],
|
||||
"field_int32": [],
|
||||
"field_int64": [],
|
||||
"field_float": [],
|
||||
"field_double": [],
|
||||
"field_string": [],
|
||||
"field_binary_vector": [],
|
||||
"field_float_vector": []
|
||||
}`)
|
||||
|
||||
validator, err := NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
for _, count := range validator.rowCounter {
|
||||
assert.Equal(t, int64(0), count)
|
||||
// primary key is auto-generated, but the user provided a pk value, return an error
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "Age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// different row count
|
||||
reader = strings.NewReader(`{
|
||||
"field_bool": [true],
|
||||
"field_int8": [],
|
||||
"field_int16": [],
|
||||
"field_int32": [1, 2, 3],
|
||||
"field_int64": [],
|
||||
"field_float": [],
|
||||
"field_double": [],
|
||||
"field_string": [],
|
||||
"field_binary_vector": [],
|
||||
"field_float_vector": []
|
||||
}`)
|
||||
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
validator, err = NewJSONRowValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// invalid value type
|
||||
reader = strings.NewReader(`{
|
||||
"dummy": [],
|
||||
"field_bool": [true],
|
||||
"field_int8": [1],
|
||||
"field_int16": [2],
|
||||
"field_int32": [3],
|
||||
"field_int64": [4],
|
||||
"field_float": [1],
|
||||
"field_double": [1],
|
||||
"field_string": [9],
|
||||
"field_binary_vector": [[254, 1]],
|
||||
"field_float_vector": [[1.1, 1.2, 1.3, 1.4]]
|
||||
"rows":[
|
||||
{"ID": 1, "Age": 2}
|
||||
]
|
||||
}`)
|
||||
parser = NewJSONParser(ctx, schema)
|
||||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
func Test_NewJSONRowConsumer(t *testing.T) {
|
||||
// nil schema
|
||||
consumer, err := NewJSONRowConsumer(nil, nil, 2, 16, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
// wrong schema
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_None,
|
||||
},
|
||||
},
|
||||
}
|
||||
consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
// no primary key
|
||||
schema.Fields[0].IsPrimaryKey = false
|
||||
schema.Fields[0].DataType = schemapb.DataType_Int64
|
||||
consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
// primary key is autoid, but no IDAllocator
|
||||
schema.Fields[0].IsPrimaryKey = true
|
||||
schema.Fields[0].AutoID = true
|
||||
consumer, err = NewJSONRowConsumer(schema, nil, 2, 16, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
// success
|
||||
consumer, err = NewJSONRowConsumer(sampleSchema(), nil, 2, 16, nil)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// init failed
|
||||
validator.validators = nil
|
||||
err = validator.Handle(nil)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_JSONRowConsumer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAllocator := newIDAllocator(ctx, t)
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
|
||||
schema := sampleSchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
|
@ -376,9 +255,196 @@ func Test_JSONRowConsumer(t *testing.T) {
|
|||
assert.Equal(t, 5, totalCount)
|
||||
}
|
||||
|
||||
func Test_JSONRowConsumerFlush(t *testing.T) {
|
||||
var callTime int32
|
||||
var totalCount int
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error {
|
||||
callTime++
|
||||
field, ok := fields[101]
|
||||
assert.True(t, ok)
|
||||
assert.Greater(t, field.RowNum(), 0)
|
||||
totalCount += field.RowNum()
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var shardNum int32 = 4
|
||||
var blockSize int64 = 1
|
||||
consumer, err := NewJSONRowConsumer(schema, nil, shardNum, blockSize, flushFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// force flush
|
||||
rowCountEachShard := 100
|
||||
for i := 0; i < int(shardNum); i++ {
|
||||
pkFieldData := consumer.segmentsData[i][101].(*storage.Int64FieldData)
|
||||
for j := 0; j < rowCountEachShard; j++ {
|
||||
pkFieldData.Data = append(pkFieldData.Data, int64(j))
|
||||
}
|
||||
pkFieldData.NumRows = []int64{int64(rowCountEachShard)}
|
||||
}
|
||||
|
||||
err = consumer.flush(true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, shardNum, callTime)
|
||||
assert.Equal(t, rowCountEachShard*int(shardNum), totalCount)
|
||||
|
||||
// exceeding the block size triggers a flush
|
||||
callTime = 0
|
||||
totalCount = 0
|
||||
for i := 0; i < int(shardNum); i++ {
|
||||
consumer.segmentsData[i] = initSegmentData(schema)
|
||||
if i%2 == 0 {
|
||||
continue
|
||||
}
|
||||
pkFieldData := consumer.segmentsData[i][101].(*storage.Int64FieldData)
|
||||
for j := 0; j < rowCountEachShard; j++ {
|
||||
pkFieldData.Data = append(pkFieldData.Data, int64(j))
|
||||
}
|
||||
pkFieldData.NumRows = []int64{int64(rowCountEachShard)}
|
||||
}
|
||||
err = consumer.flush(true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, shardNum/2, callTime)
|
||||
assert.Equal(t, rowCountEachShard*int(shardNum)/2, totalCount)
|
||||
}
|
||||
|
||||
func Test_JSONRowConsumerHandle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAllocator := newIDAllocator(ctx, t, errors.New("error"))
|
||||
|
||||
var callTime int32
|
||||
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error {
|
||||
callTime++
|
||||
return errors.New("dummy error")
|
||||
}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("handle int64 pk", func(t *testing.T) {
|
||||
consumer, err := NewJSONRowConsumer(schema, idAllocator, 1, 1, flushFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
pkFieldData := consumer.segmentsData[0][101].(*storage.Int64FieldData)
|
||||
for i := 0; i < 10; i++ {
|
||||
pkFieldData.Data = append(pkFieldData.Data, int64(i))
|
||||
}
|
||||
pkFieldData.NumRows = []int64{int64(10)}
|
||||
|
||||
// nil input will trigger flush
|
||||
err = consumer.Handle(nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, int32(1), callTime)
|
||||
|
||||
// optional flush
|
||||
callTime = 0
|
||||
rowCount := 100
|
||||
pkFieldData = consumer.segmentsData[0][101].(*storage.Int64FieldData)
|
||||
for j := 0; j < rowCount; j++ {
|
||||
pkFieldData.Data = append(pkFieldData.Data, int64(j))
|
||||
}
|
||||
pkFieldData.NumRows = []int64{int64(rowCount)}
|
||||
|
||||
input := make([]map[storage.FieldID]interface{}, rowCount)
|
||||
for j := 0; j < rowCount; j++ {
|
||||
input[j] = make(map[int64]interface{})
|
||||
input[j][101] = int64(j)
|
||||
}
|
||||
err = consumer.Handle(input)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, int32(1), callTime)
|
||||
|
||||
// failed to auto-generate pk
|
||||
consumer.blockSize = 1024 * 1024
|
||||
err = consumer.Handle(input)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// hash int64 pk
|
||||
consumer.rowIDAllocator = newIDAllocator(ctx, t, nil)
|
||||
err = consumer.Handle(input)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(rowCount), consumer.rowCounter)
|
||||
assert.Equal(t, 2, len(consumer.autoIDRange))
|
||||
assert.Equal(t, int64(1), consumer.autoIDRange[0])
|
||||
assert.Equal(t, int64(1+rowCount), consumer.autoIDRange[1])
|
||||
})
|
||||
|
||||
t.Run("handle varchar pk", func(t *testing.T) {
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "1024"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
consumer, err := NewJSONRowConsumer(schema, idAllocator, 1, 1024*1024, flushFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
rowCount := 100
|
||||
input := make([]map[storage.FieldID]interface{}, rowCount)
|
||||
for j := 0; j < rowCount; j++ {
|
||||
input[j] = make(map[int64]interface{})
|
||||
input[j][101] = "abc"
|
||||
}
|
||||
|
||||
// varchar pk cannot be autoid
|
||||
err = consumer.Handle(input)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// hash varchar pk
|
||||
schema.Fields[0].AutoID = false
|
||||
consumer, err = NewJSONRowConsumer(schema, idAllocator, 1, 1024*1024, flushFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = consumer.Handle(input)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(rowCount), consumer.rowCounter)
|
||||
assert.Equal(t, 0, len(consumer.autoIDRange))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_JSONRowConsumerStringKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAllocator := newIDAllocator(ctx, t)
|
||||
idAllocator := newIDAllocator(ctx, t, nil)
|
||||
|
||||
schema := strKeySchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
|
@ -501,71 +567,3 @@ func Test_JSONRowConsumerStringKey(t *testing.T) {
|
|||
assert.Equal(t, shardNum, callTime)
|
||||
assert.Equal(t, 10, totalCount)
|
||||
}
|
||||
|
||||
func Test_JSONColumnConsumer(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
schema := sampleSchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
|
||||
reader := strings.NewReader(`{
|
||||
"field_bool": [true, false, true, true, true],
|
||||
"field_int8": [10, 11, 12, 13, 14],
|
||||
"field_int16": [100, 101, 102, 103, 104],
|
||||
"field_int32": [1000, 1001, 1002, 1003, 1004],
|
||||
"field_int64": [10000, 10001, 10002, 10003, 10004],
|
||||
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
|
||||
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
|
||||
"field_string": ["a", "b", "c", "d", "e"],
|
||||
"field_binary_vector": [
|
||||
[254, 1],
|
||||
[253, 2],
|
||||
[252, 3],
|
||||
[251, 4],
|
||||
[250, 5]
|
||||
],
|
||||
"field_float_vector": [
|
||||
[1.1, 1.2, 1.3, 1.4],
|
||||
[2.1, 2.2, 2.3, 2.4],
|
||||
[3.1, 3.2, 3.3, 3.4],
|
||||
[4.1, 4.2, 4.3, 4.4],
|
||||
[5.1, 5.2, 5.3, 5.4]
|
||||
]
|
||||
}`)
|
||||
|
||||
callTime := 0
|
||||
rowCount := 0
|
||||
consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
|
||||
callTime++
|
||||
for id, data := range fields {
|
||||
if id == common.RowIDField {
|
||||
continue
|
||||
}
|
||||
if rowCount == 0 {
|
||||
rowCount = data.RowNum()
|
||||
} else {
|
||||
assert.Equal(t, rowCount, data.RowNum())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
consumer, err := NewJSONColumnConsumer(schema, consumeFunc)
|
||||
assert.NotNil(t, consumer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
validator, err := NewJSONColumnValidator(schema, consumer)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.Nil(t, err)
|
||||
for _, count := range validator.ValidateCount() {
|
||||
assert.Equal(t, int64(5), count)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, callTime)
|
||||
assert.Equal(t, 5, rowCount)
|
||||
}
|
||||
@ -20,6 +20,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
|
@ -84,28 +85,26 @@ func adjustBufSize(parser *JSONParser, collectionSchema *schemapb.CollectionSche
|
|||
bufSize = MinBufferSize
|
||||
}
|
||||
|
||||
log.Info("JSON parse: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize))
|
||||
log.Info("JSON parser: reset bufSize", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufSize", bufSize))
|
||||
parser.bufSize = int64(bufSize)
|
||||
}
|
||||
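adjustBufSize trades buffered row count against record width: the larger the estimated size of one record, the fewer rows are buffered before they are handed over, with MinBufferSize as the floor. The constants and the exact formula live elsewhere in this package; the sketch below only illustrates the clamping idea with assumed values:

```go
package main

import "fmt"

const (
	// assumed values for illustration; the real constants are defined in importutil
	minBufferSize  = 1024
	maxBufferBytes = 16 * 1024 * 1024
)

// adjustBufSize returns how many rows to buffer given the estimated size of one record.
func adjustBufSize(sizePerRecord int) int {
	if sizePerRecord <= 0 {
		return minBufferSize
	}
	bufSize := maxBufferBytes / sizePerRecord
	if bufSize < minBufferSize {
		bufSize = minBufferSize
	}
	return bufSize
}

func main() {
	fmt.Println(adjustBufSize(128))     // small records -> large row buffer
	fmt.Println(adjustBufSize(1 << 20)) // huge records -> clamped to minBufferSize
}
```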
|
||||
func (p *JSONParser) logError(msg string) error {
|
||||
log.Error(msg)
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
||||
if handler == nil {
|
||||
return p.logError("JSON parse handler is nil")
|
||||
log.Error("JSON parse handler is nil")
|
||||
return errors.New("JSON parse handler is nil")
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(r)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: row count is 0")
|
||||
log.Error("JSON parser: row count is 0")
|
||||
return errors.New("JSON parser: row count is 0")
|
||||
}
|
||||
if t != json.Delim('{') {
|
||||
return p.logError("JSON parse: invalid JSON format, the content should be started with'{'")
|
||||
log.Error("JSON parser: invalid JSON format, the content should be started with'{'")
|
||||
return errors.New("JSON parser: invalid JSON format, the content should be started with'{'")
|
||||
}
|
||||
|
||||
// read the first level
|
||||
|
@ -114,24 +113,28 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
// read the key
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||
}
|
||||
key := t.(string)
|
||||
keyLower := strings.ToLower(key)
|
||||
|
||||
// the root key should be RowRootNode
|
||||
if keyLower != RowRootNode {
|
||||
return p.logError("JSON parse: invalid row-based JSON format, the key " + key + " is not found")
|
||||
log.Error("JSON parser: invalid row-based JSON format, the key is not found", zap.String("key", key))
|
||||
return fmt.Errorf("JSON parser: invalid row-based JSON format, the key %s is not found", key)
|
||||
}
|
||||
|
||||
// started by '['
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||
}
|
||||
|
||||
if t != json.Delim('[') {
|
||||
return p.logError("JSON parse: invalid row-based JSON format, rows list should begin with '['")
|
||||
log.Error("JSON parser: invalid row-based JSON format, rows list should begin with '['")
|
||||
return errors.New("JSON parser: invalid row-based JSON format, rows list should begin with '['")
|
||||
}
|
||||
|
||||
// read buffer
|
||||
|
@ -139,27 +142,36 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
for dec.More() {
|
||||
var value interface{}
|
||||
if err := dec.Decode(&value); err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
log.Error("JSON parser: decode json value error", zap.Any("err", err))
|
||||
return fmt.Errorf("JSON parser: decode json value error: %v", err)
|
||||
}
|
||||
|
||||
switch value.(type) {
|
||||
case map[string]interface{}:
|
||||
break
|
||||
default:
|
||||
return p.logError("JSON parse: invalid JSON format, each row should be a key-value map")
|
||||
log.Error("JSON parser: invalid JSON format, each row should be a key-value map")
|
||||
return errors.New("JSON parser: invalid JSON format, each row should be a key-value map")
|
||||
}
|
||||
|
||||
row := make(map[storage.FieldID]interface{})
|
||||
stringMap := value.(map[string]interface{})
|
||||
for k, v := range stringMap {
|
||||
row[p.name2FieldID[k]] = v
|
||||
// if the user provided a redundant field, return an error
|
||||
fieldID, ok := p.name2FieldID[k]
|
||||
if !ok {
|
||||
log.Error("JSON parser: the field is not defined in collection schema", zap.String("fieldName", k))
|
||||
return fmt.Errorf("JSON parser: the field '%s' is not defined in collection schema", k)
|
||||
}
|
||||
row[fieldID] = v
|
||||
}
|
||||
|
||||
buf = append(buf, row)
|
||||
if len(buf) >= int(p.bufSize) {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
return p.logError(err.Error())
|
||||
log.Error("JSON parser: parse values error", zap.Any("err", err))
|
||||
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
||||
}
|
||||
|
||||
// clear the buffer
|
||||
|
@ -171,26 +183,27 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
if len(buf) > 0 {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
return p.logError(err.Error())
|
||||
log.Error("JSON parser: parse values error", zap.Any("err", err))
|
||||
return fmt.Errorf("JSON parser: parse values error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// end by ']'
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
log.Error("JSON parser: read json token error", zap.Any("err", err))
|
||||
return fmt.Errorf("JSON parser: read json token error: %v", err)
|
||||
}
|
||||
|
||||
if t != json.Delim(']') {
|
||||
return p.logError("JSON parse: invalid column-based JSON format, rows list should end with a ']'")
|
||||
log.Error("JSON parser: invalid column-based JSON format, rows list should end with a ']'")
|
||||
return errors.New("JSON parser: invalid column-based JSON format, rows list should end with a ']'")
|
||||
}
|
||||
|
||||
// canceled?
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return p.logError("import task was canceled")
|
||||
default:
|
||||
break
|
||||
// the outside context might be canceled (service stop, or a future enhancement for canceling the import task)
|
||||
if isCanceled(p.ctx) {
|
||||
log.Error("JSON parser: import task was canceled")
|
||||
return errors.New("JSON parser: import task was canceled")
|
||||
}
|
||||
|
||||
// this break means we require the first node to be RowRootNode
|
||||
|
@ -199,106 +212,8 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
|
|||
}
|
||||
|
||||
if isEmpty {
|
||||
return p.logError("JSON parse: row count is 0")
|
||||
}
|
||||
|
||||
// send nil to notify the handler that all data has been handled
|
||||
return handler.Handle(nil)
|
||||
}
|
||||
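ParseRows never materializes the whole file: it walks the token stream, requires the first key to be the rows list, decodes one row object at a time, and hands batches of bufSize rows to the handler, finishing with a nil batch. A stripped-down, self-contained sketch of that token-level loop (error handling and field-name-to-ID mapping omitted):

```go
package main

import (
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

// parseRows streams a `{"rows": [ {...}, {...} ]}` document, invoking handle
// for every batch of bufSize rows and once more with nil when parsing is done.
func parseRows(r io.Reader, bufSize int, handle func([]map[string]interface{}) error) error {
	dec := json.NewDecoder(r)
	if t, err := dec.Token(); err != nil || t != json.Delim('{') {
		return fmt.Errorf("the content should start with '{'")
	}
	if t, err := dec.Token(); err != nil || t != "rows" {
		return fmt.Errorf("the root key should be \"rows\"")
	}
	if t, err := dec.Token(); err != nil || t != json.Delim('[') {
		return fmt.Errorf("rows list should begin with '['")
	}

	buf := make([]map[string]interface{}, 0, bufSize)
	for dec.More() {
		var row map[string]interface{}
		if err := dec.Decode(&row); err != nil {
			return err
		}
		buf = append(buf, row)
		if len(buf) >= bufSize {
			if err := handle(buf); err != nil {
				return err
			}
			buf = buf[:0] // clear the buffer
		}
	}
	if len(buf) > 0 {
		if err := handle(buf); err != nil {
			return err
		}
	}
	return handle(nil) // notify the handler that parsing is complete
}

func main() {
	reader := strings.NewReader(`{"rows":[{"uid": 1}, {"uid": 2}, {"uid": 3}]}`)
	err := parseRows(reader, 2, func(rows []map[string]interface{}) error {
		fmt.Println("batch of", len(rows), "rows")
		return nil
	})
	fmt.Println("err:", err)
}
```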
|
||||
func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error {
|
||||
if handler == nil {
|
||||
return p.logError("JSON parse handler is nil")
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(r)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: row count is 0")
|
||||
}
|
||||
if t != json.Delim('{') {
|
||||
return p.logError("JSON parse: invalid JSON format, the content should be started with'{'")
|
||||
}
|
||||
|
||||
// read the first level
|
||||
isEmpty := true
|
||||
for dec.More() {
|
||||
// read the key
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
}
|
||||
key := t.(string)
|
||||
|
||||
// not a valid column name, skip
|
||||
_, isValidField := p.fields[key]
|
||||
|
||||
// started by '['
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
}
|
||||
|
||||
if t != json.Delim('[') {
|
||||
return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['")
|
||||
}
|
||||
|
||||
id := p.name2FieldID[key]
|
||||
// read buffer
|
||||
buf := make(map[storage.FieldID][]interface{})
|
||||
buf[id] = make([]interface{}, 0, MinBufferSize)
|
||||
for dec.More() {
|
||||
var value interface{}
|
||||
if err := dec.Decode(&value); err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
}
|
||||
|
||||
if !isValidField {
|
||||
continue
|
||||
}
|
||||
|
||||
buf[id] = append(buf[id], value)
|
||||
if len(buf[id]) >= int(p.bufSize) {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
return p.logError(err.Error())
|
||||
}
|
||||
|
||||
// clear the buffer
|
||||
buf[id] = make([]interface{}, 0, MinBufferSize)
|
||||
}
|
||||
}
|
||||
|
||||
// some values in buffer not parsed, parse them
|
||||
if len(buf[id]) > 0 {
|
||||
isEmpty = false
|
||||
if err = handler.Handle(buf); err != nil {
|
||||
return p.logError(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// end by ']'
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
return p.logError("JSON parse: " + err.Error())
|
||||
}
|
||||
|
||||
if t != json.Delim(']') {
|
||||
return p.logError("JSON parse: invalid column-based JSON format, each field should end with a ']'")
|
||||
}
|
||||
|
||||
// canceled?
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return p.logError("import task was canceled")
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isEmpty {
|
||||
return p.logError("JSON parse: row count is 0")
|
||||
log.Error("JSON parser: row count is 0")
|
||||
return errors.New("JSON parser: row count is 0")
|
||||
}
|
||||
|
||||
// send nil to notify the handler that all data has been handled
|
||||
@ -27,157 +27,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func sampleSchema() *schemapb.CollectionSchema {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "field_bool",
|
||||
IsPrimaryKey: false,
|
||||
Description: "bool",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "field_int8",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int8",
|
||||
DataType: schemapb.DataType_Int8,
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "field_int16",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int16",
|
||||
DataType: schemapb.DataType_Int16,
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "field_int32",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int32",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "field_int64",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
Description: "int64",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 107,
|
||||
Name: "field_float",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
{
|
||||
FieldID: 108,
|
||||
Name: "field_double",
|
||||
IsPrimaryKey: false,
|
||||
Description: "double",
|
||||
DataType: schemapb.DataType_Double,
|
||||
},
|
||||
{
|
||||
FieldID: 109,
|
||||
Name: "field_string",
|
||||
IsPrimaryKey: false,
|
||||
Description: "string",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "128"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 110,
|
||||
Name: "field_binary_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "binary_vector",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "16"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 111,
|
||||
Name: "field_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func strKeySchema() *schemapb.CollectionSchema {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "uid",
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
Description: "uid",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "1024"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "int_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "int_scalar",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "float_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "float_scalar",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "string_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "string_scalar",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "bool_scalar",
|
||||
IsPrimaryKey: false,
|
||||
Description: "bool_scalar",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "vectors",
|
||||
IsPrimaryKey: false,
|
||||
Description: "vectors",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func Test_AdjustBufSize(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
@ -234,6 +83,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
]
|
||||
}`)
|
||||
|
||||
// handler is nil
|
||||
err := parser.ParseRows(reader, nil)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
|
@ -241,10 +91,12 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// success
|
||||
err = parser.ParseRows(reader, validator)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), validator.ValidateCount())
|
||||
|
||||
// not a row-based format
|
||||
reader = strings.NewReader(`{
|
||||
"dummy":[]
|
||||
}`)
|
||||
|
@ -255,6 +107,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// rows is not a list
|
||||
reader = strings.NewReader(`{
|
||||
"rows":
|
||||
}`)
|
||||
|
@ -265,6 +118,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// typo
|
||||
reader = strings.NewReader(`{
|
||||
"rows": [}
|
||||
}`)
|
||||
|
@ -275,6 +129,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// rows is not a list
|
||||
reader = strings.NewReader(`{
|
||||
"rows": {}
|
||||
}`)
|
||||
|
@ -285,6 +140,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// rows is not a list of list
|
||||
reader = strings.NewReader(`{
|
||||
"rows": [[]]
|
||||
}`)
|
||||
|
@ -295,6 +151,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// not valid json format
|
||||
reader = strings.NewReader(`[]`)
|
||||
validator, err = NewJSONRowValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
|
@ -303,6 +160,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// empty content
|
||||
reader = strings.NewReader(`{}`)
|
||||
validator, err = NewJSONRowValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
|
@ -311,6 +169,7 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// empty content
|
||||
reader = strings.NewReader(``)
|
||||
validator, err = NewJSONRowValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
|
@ -318,129 +177,23 @@ func Test_JSONParserParserRows(t *testing.T) {
|
|||
|
||||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_JSONParserParserColumns(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
schema := sampleSchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
parser.bufSize = 1
|
||||
|
||||
reader := strings.NewReader(`{
|
||||
"field_bool": [true, false, true, true, true],
|
||||
"field_int8": [10, 11, 12, 13, 14],
|
||||
"field_int16": [100, 101, 102, 103, 104],
|
||||
"field_int32": [1000, 1001, 1002, 1003, 1004],
|
||||
"field_int64": [10000, 10001, 10002, 10003, 10004],
|
||||
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
|
||||
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
|
||||
"field_string": ["a", "b", "c", "d", "e"],
|
||||
"field_binary_vector": [
|
||||
[254, 1],
|
||||
[253, 2],
|
||||
[252, 3],
|
||||
[251, 4],
|
||||
[250, 5]
|
||||
],
|
||||
"field_float_vector": [
|
||||
[1.1, 1.2, 1.3, 1.4],
|
||||
[2.1, 2.2, 2.3, 2.4],
|
||||
[3.1, 3.2, 3.3, 3.4],
|
||||
[4.1, 4.2, 4.3, 4.4],
|
||||
[5.1, 5.2, 5.3, 5.4]
|
||||
// redundant field
|
||||
reader = strings.NewReader(`{
|
||||
"rows":[
|
||||
{"dummy": 1, "field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
|
||||
]
|
||||
}`)
|
||||
|
||||
err := parser.ParseColumns(reader, nil)
|
||||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
validator, err := NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.Nil(t, err)
|
||||
counter := validator.ValidateCount()
|
||||
for _, v := range counter {
|
||||
assert.Equal(t, int64(5), v)
|
||||
}
|
||||
|
||||
// field missed
|
||||
reader = strings.NewReader(`{
|
||||
"field_int8": [10, 11, 12, 13, 14],
|
||||
"dummy":[1, 2, 3]
|
||||
"rows":[
|
||||
{"field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
|
||||
]
|
||||
}`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
reader = strings.NewReader(`{
|
||||
"dummy":[1, 2, 3]
|
||||
}`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(`{
|
||||
"field_bool":
|
||||
}`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(`{
|
||||
"field_bool":{}
|
||||
}`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(`{
|
||||
"field_bool":[}
|
||||
}`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(`[]`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(`{}`)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader(``)
|
||||
validator, err = NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
err = parser.ParseRows(reader, validator)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
|
@ -545,42 +298,3 @@ func Test_JSONParserParserRowsStringKey(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(10), validator.ValidateCount())
|
||||
}
|
||||
|
||||
func Test_JSONParserParserColumnsStrKey(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
schema := strKeySchema()
|
||||
parser := NewJSONParser(ctx, schema)
|
||||
assert.NotNil(t, parser)
|
||||
parser.bufSize = 1
|
||||
|
||||
reader := strings.NewReader(`{
|
||||
"uid": ["Dm4aWrbNzhmjwCTEnCJ9LDPO2N09sqysxgVfbH9Zmn3nBzmwsmk0eZN6x7wSAoPQ", "RP50U0d2napRjXu94a8oGikWgklvVsXFurp8RR4tHGw7N0gk1b7opm59k3FCpyPb", "oxhFkQitWPPw0Bjmj7UQcn4iwvS0CU7RLAC81uQFFQjWtOdiB329CPyWkfGSeYfE", "sxoEL4Mpk1LdsyXhbNm059UWJ3CvxURLCQczaVI5xtBD4QcVWTDFUW7dBdye6nbn", "g33Rqq2UQSHPRHw5FvuXxf5uGEhIAetxE6UuXXCJj0hafG8WuJr1ueZftsySCqAd"],
|
||||
"int_scalar": [9070353, 8505288, 4392660, 7927425, 9288807],
|
||||
"float_scalar": [0.9798043638085004, 0.937913432198687, 0.32381232630490264, 0.31074026464844895, 0.4953578200336135],
|
||||
"string_scalar": ["ShQ44OX0z8kGpRPhaXmfSsdH7JHq5DsZzu0e2umS1hrWG0uONH2RIIAdOECaaXir", "Ld4b0avxathBdNvCrtm3QsWO1pYktUVR7WgAtrtozIwrA8vpeactNhJ85CFGQnK5", "EmAlB0xdQcxeBtwlZJQnLgKodiuRinynoQtg0eXrjkq24dQohzSm7Bx3zquHd3kO", "fdY2beCvs1wSws0Gb9ySD92xwfEfJpX5DQgsWoISylBAoYOcXpRaqIJoXYS4g269", "6f8Iv1zQAGksj5XxMbbI5evTrYrB8fSFQ58jl0oU7Z4BpA81VsD2tlWqkhfoBNa7"],
|
||||
"bool_scalar": [true, false, true, false, false],
|
||||
"vectors": [
|
||||
[0.5040062902126952, 0.8297619818664708, 0.20248342801564806, 0.12834786423659314],
|
||||
[0.528232122836893, 0.6916116750653186, 0.41443762522548705, 0.26624344144792056],
|
||||
[0.7978693027281338, 0.12394906726785092, 0.42431962903815285, 0.4098707807351914],
|
||||
[0.3716157812069954, 0.006981281113265229, 0.9007003458552365, 0.22492634316191004],
|
||||
[0.5921374209648096, 0.04234832587925662, 0.7803878096531548, 0.1964045837884633]
|
||||
]
|
||||
}`)
|
||||
|
||||
err := parser.ParseColumns(reader, nil)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
validator, err := NewJSONColumnValidator(schema, nil)
|
||||
assert.NotNil(t, validator)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = parser.ParseColumns(reader, validator)
|
||||
assert.Nil(t, err)
|
||||
counter := validator.ValidateCount()
|
||||
for _, v := range counter {
|
||||
assert.Equal(t, int64(5), v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,11 +20,27 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/sbinet/npyio"
|
||||
"github.com/sbinet/npyio/npy"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
reStrPre = regexp.MustCompile(`^[|]*?(\d.*)[Sa]$`)
|
||||
reStrPost = regexp.MustCompile(`^[|]*?[Sa](\d.*)$`)
|
||||
reUniPre = regexp.MustCompile(`^[<|>]*?(\d.*)U$`)
|
||||
reUniPost = regexp.MustCompile(`^[<|>]*?U(\d.*)$`)
|
||||
)
|
||||
|
||||
func CreateNumpyFile(path string, data interface{}) error {
|
||||
|
@ -52,17 +68,18 @@ func CreateNumpyData(data interface{}) ([]byte, error) {
|
|||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// a class to expand other numpy lib ability
|
||||
// NumpyAdapter is the class to expand other numpy lib ability
|
||||
// we evaluated two go-numpy libs: github.com/kshedden/gonpy and github.com/sbinet/npyio
|
||||
// the npyio lib reads data one element at a time, which is slow, so we extend its read methods
|
||||
// to read data in one batch, which is about 100X faster
|
||||
// the gonpy lib also reads data in one batch, but it has no method to read bool data, and its ability
|
||||
// to handle different data types is not as strong as npyio's, so we choose the npyio lib to extend.
|
||||
type NumpyAdapter struct {
|
||||
reader io.Reader // data source, typically is os.File
|
||||
npyReader *npy.Reader // reader of npyio lib
|
||||
order binary.ByteOrder // LittleEndian or BigEndian
|
||||
readPosition int // how many elements have been read
|
||||
reader io.Reader // data source, typically is os.File
|
||||
npyReader *npy.Reader // reader of npyio lib
|
||||
order binary.ByteOrder // LittleEndian or BigEndian
|
||||
readPosition int // how many elements have been read
|
||||
dataType schemapb.DataType // data type parsed from numpy file header
|
||||
}
|
||||
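To make the batch-read design described in the comment above concrete, here is a minimal usage sketch. It is not part of this patch: it assumes the exported NewNumpyAdapter/ReadFloat32/GetType shown elsewhere in the diff, the package path github.com/milvus-io/milvus/internal/util/importutil, and a made-up file name "vectors.npy".

package main

import (
	"fmt"
	"os"

	"github.com/milvus-io/milvus/internal/util/importutil"
)

func main() {
	// "vectors.npy" is a hypothetical float32 numpy file produced by numpy.save()
	file, err := os.Open("vectors.npy")
	if err != nil {
		panic(err)
	}
	defer file.Close()

	// the adapter parses the numpy header once and then serves batch reads
	adapter, err := importutil.NewNumpyAdapter(file)
	if err != nil {
		panic(err)
	}

	// read up to 1000 float32 elements in a single call instead of one by one
	data, err := adapter.ReadFloat32(1000)
	if err != nil {
		panic(err)
	}
	fmt.Println("read", len(data), "elements, dtype:", adapter.GetType())
}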
|
||||
func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
|
||||
|
@ -70,17 +87,106 @@ func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataType, err := convertNumpyType(r.Header.Descr.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
adapter := &NumpyAdapter{
|
||||
reader: reader,
|
||||
npyReader: r,
|
||||
readPosition: 0,
|
||||
dataType: dataType,
|
||||
}
|
||||
adapter.setByteOrder()
|
||||
|
||||
log.Info("Numpy adapter: numpy header info",
|
||||
zap.Any("shape", r.Header.Descr.Shape),
|
||||
zap.String("dType", r.Header.Descr.Type),
|
||||
zap.Uint8("majorVer", r.Header.Major),
|
||||
zap.Uint8("minorVer", r.Header.Minor),
|
||||
zap.String("ByteOrder", adapter.order.String()))
|
||||
|
||||
return adapter, err
|
||||
}
|
||||
|
||||
// the logic of this method is copied from npyio lib
|
||||
// convertNumpyType converts the data type parsed from the numpy header description; for vector fields the element type is uint8 (binary vector) or float32/float64 (float vector)
|
||||
func convertNumpyType(typeStr string) (schemapb.DataType, error) {
|
||||
log.Info("Numpy adapter: parse numpy file dtype", zap.String("dtype", typeStr))
|
||||
switch typeStr {
|
||||
case "b1", "<b1", "|b1", "bool":
|
||||
return schemapb.DataType_Bool, nil
|
||||
case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
|
||||
return schemapb.DataType_BinaryVector, nil
|
||||
case "i1", "<i1", "|i1", ">i1", "int8":
|
||||
return schemapb.DataType_Int8, nil
|
||||
case "i2", "<i2", "|i2", ">i2", "int16":
|
||||
return schemapb.DataType_Int16, nil
|
||||
case "i4", "<i4", "|i4", ">i4", "int32":
|
||||
return schemapb.DataType_Int32, nil
|
||||
case "i8", "<i8", "|i8", ">i8", "int64":
|
||||
return schemapb.DataType_Int64, nil
|
||||
case "f4", "<f4", "|f4", ">f4", "float32":
|
||||
return schemapb.DataType_Float, nil
|
||||
case "f8", "<f8", "|f8", ">f8", "float64":
|
||||
return schemapb.DataType_Double, nil
|
||||
default:
|
||||
if isStringType(typeStr) {
|
||||
return schemapb.DataType_VarChar, nil
|
||||
}
|
||||
log.Error("Numpy adapter: the numpy file data type not supported", zap.String("dataType", typeStr))
|
||||
return schemapb.DataType_None, fmt.Errorf("Numpy adapter: the numpy file dtype '%s' is not supported", typeStr)
|
||||
}
|
||||
}
|
||||
|
||||
func stringLen(dtype string) (int, bool, error) {
|
||||
var utf bool
|
||||
switch {
|
||||
case reStrPre.MatchString(dtype), reStrPost.MatchString(dtype):
|
||||
utf = false
|
||||
case reUniPre.MatchString(dtype), reUniPost.MatchString(dtype):
|
||||
utf = true
|
||||
}
|
||||
|
||||
if m := reStrPre.FindStringSubmatch(dtype); m != nil {
|
||||
v, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return v, utf, nil
|
||||
}
|
||||
if m := reStrPost.FindStringSubmatch(dtype); m != nil {
|
||||
v, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return v, utf, nil
|
||||
}
|
||||
if m := reUniPre.FindStringSubmatch(dtype); m != nil {
|
||||
v, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return v, utf, nil
|
||||
}
|
||||
if m := reUniPost.FindStringSubmatch(dtype); m != nil {
|
||||
v, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return v, utf, nil
|
||||
}
|
||||
|
||||
return 0, false, fmt.Errorf("Numpy adapter: data type '%s' of numpy file is not varchar data type", dtype)
|
||||
}
|
||||
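As a quick reference for the regular expressions above, the following in-package sketch is illustration only (exampleStringLen is a hypothetical name, not part of the patch; it assumes the unexported stringLen and the existing "fmt" import of this file). It shows what stringLen extracts from a few typical numpy string dtypes, matching the cases exercised by Test_StringLen later in this patch.

func exampleStringLen() {
	if n, utf, err := stringLen("S5"); err == nil {
		fmt.Println(n, utf) // 5 false : fixed-width bytes dtype, length after the 'S'
	}
	if n, utf, err := stringLen("2S"); err == nil {
		fmt.Println(n, utf) // 2 false : fixed-width bytes dtype, length before the 'S'
	}
	if n, utf, err := stringLen("<U3"); err == nil {
		fmt.Println(n, utf) // 3 true : little-endian unicode dtype
	}
	if n, utf, err := stringLen(">4U"); err == nil {
		fmt.Println(n, utf) // 4 true : big-endian unicode dtype, length before the 'U'
	}
	if _, _, err := stringLen("f4"); err != nil {
		fmt.Println(err) // non-string dtypes are rejected with an error
	}
}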
|
||||
func isStringType(typeStr string) bool {
|
||||
rt := npyio.TypeFrom(typeStr)
|
||||
return rt == reflect.TypeOf((*string)(nil)).Elem()
|
||||
}
|
||||
|
||||
// setByteOrder sets BigEndian/LittleEndian, the logic of this method is copied from npyio lib
|
||||
func (n *NumpyAdapter) setByteOrder() {
|
||||
var nativeEndian binary.ByteOrder
|
||||
v := uint16(1)
|
||||
|
@ -109,15 +215,15 @@ func (n *NumpyAdapter) NpyReader() *npy.Reader {
|
|||
return n.npyReader
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) GetType() string {
|
||||
return n.npyReader.Header.Descr.Type
|
||||
func (n *NumpyAdapter) GetType() schemapb.DataType {
|
||||
return n.dataType
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) GetShape() []int {
|
||||
return n.npyReader.Header.Descr.Shape
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) checkSize(size int) int {
|
||||
func (n *NumpyAdapter) checkCount(count int) int {
|
||||
shape := n.GetShape()
|
||||
|
||||
// empty file?
|
||||
|
@ -135,35 +241,34 @@ func (n *NumpyAdapter) checkSize(size int) int {
|
|||
}
|
||||
|
||||
// overflow?
|
||||
if size > (total - n.readPosition) {
|
||||
if count > (total - n.readPosition) {
|
||||
return total - n.readPosition
|
||||
}
|
||||
|
||||
return size
|
||||
return count
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadBool(count int) ([]bool, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read bool data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "b1", "<b1", "|b1", "bool":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not bool type")
|
||||
if n.dataType != schemapb.DataType_Bool {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not bool type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of bool file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]bool, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read bool data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -172,28 +277,30 @@ func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadUint8(count int) ([]uint8, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read uint8 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
// here we don't use n.dataType to check because currently milvus has no uint8 type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "u1", "<u1", "|u1", "uint8":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not uint8 type")
|
||||
return nil, errors.New("Numpy adapter: numpy data is not uint8 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of uint8 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]uint8, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read uint8 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -202,28 +309,27 @@ func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadInt8(count int) ([]int8, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read int8 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "i1", "<i1", "|i1", ">i1", "int8":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not int8 type")
|
||||
if n.dataType != schemapb.DataType_Int8 {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not int8 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of int8 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int8, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read int8 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -232,28 +338,27 @@ func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadInt16(count int) ([]int16, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read int16 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "i2", "<i2", "|i2", ">i2", "int16":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not int16 type")
|
||||
if n.dataType != schemapb.DataType_Int16 {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not int16 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of int16 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int16, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read int16 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -262,28 +367,27 @@ func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadInt32(count int) ([]int32, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read int32 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "i4", "<i4", "|i4", ">i4", "int32":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not int32 type")
|
||||
if n.dataType != schemapb.DataType_Int32 {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not int32 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of int32 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int32, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read int32 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -292,28 +396,27 @@ func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadInt64(count int) ([]int64, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read int64 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "i8", "<i8", "|i8", ">i8", "int64":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not int64 type")
|
||||
if n.dataType != schemapb.DataType_Int64 {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not int64 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of int64 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]int64, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read int64 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -322,28 +425,27 @@ func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadFloat32(count int) ([]float32, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read float32 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "f4", "<f4", "|f4", ">f4", "float32":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not float32 type")
|
||||
if n.dataType != schemapb.DataType_Float {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not float32 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of float32 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]float32, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read float32 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
|
@ -352,28 +454,109 @@ func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) {
|
||||
if n.npyReader == nil {
|
||||
return nil, errors.New("reader is not initialized")
|
||||
func (n *NumpyAdapter) ReadFloat64(count int) ([]float64, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read float64 data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
switch n.npyReader.Header.Descr.Type {
|
||||
case "f8", "<f8", "|f8", ">f8", "float64":
|
||||
default:
|
||||
return nil, errors.New("numpy data is not float32 type")
|
||||
if n.dataType != schemapb.DataType_Double {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not float64 type")
|
||||
}
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkSize(size)
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("nothing to read")
|
||||
return nil, errors.New("Numpy adapter: end of float64 file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]float64, readSize)
|
||||
err := binary.Read(n.reader, n.order, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read float64 data with count %d, error: %w", readSize, err)
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
n.readPosition += readSize
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (n *NumpyAdapter) ReadString(count int) ([]string, error) {
|
||||
if count <= 0 {
|
||||
return nil, errors.New("Numpy adapter: cannot read varhar data with a zero or nagative count")
|
||||
}
|
||||
|
||||
// incorrect type
|
||||
if n.dataType != schemapb.DataType_VarChar {
|
||||
return nil, errors.New("Numpy adapter: numpy data is not varhar type")
|
||||
}
|
||||
|
||||
// varchar length: this is the max length; some items are shorter than this, but each item still occupies maxLen elements
|
||||
maxLen, utf, err := stringLen(n.npyReader.Header.Descr.Type)
|
||||
if err != nil || maxLen <= 0 {
|
||||
log.Error("Numpy adapter: failed to get max length of varchar from numpy file header", zap.Int("maxLen", maxLen), zap.Any("err", err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to get max length %d of varchar from numpy file header, error: %w", maxLen, err)
|
||||
}
|
||||
log.Info("Numpy adapter: get varchar max length from numpy file header", zap.Int("maxLen", maxLen), zap.Bool("utf", utf))
|
||||
|
||||
// avoid read overflow
|
||||
readSize := n.checkCount(count)
|
||||
if readSize <= 0 {
|
||||
return nil, errors.New("Numpy adapter: end of varhar file, nothing to read")
|
||||
}
|
||||
|
||||
// read data
|
||||
data := make([]string, 0)
|
||||
for i := 0; i < readSize; i++ {
|
||||
if utf {
|
||||
// in the numpy file, each utf8 character occupies utf8.UTFMax bytes, so each string occupies utf8.UTFMax*maxLen bytes
|
||||
// for example, an ANSI character "a" only uses one byte, but it still occupies utf8.UTFMax bytes
|
||||
// a Chinese character uses three bytes, but it also occupies utf8.UTFMax bytes
|
||||
raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)))
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read utf8 string from numpy file", zap.Int("i", i), zap.Any("err", err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read utf8 string from numpy file, error: %w", err)
|
||||
}
|
||||
|
||||
var str string
|
||||
for len(raw) > 0 {
|
||||
r, _ := utf8.DecodeRune(raw)
|
||||
if r == utf8.RuneError {
|
||||
log.Error("Numpy adapter: failed to decode utf8 string from numpy file", zap.Any("raw", raw[:utf8.UTFMax]))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to decode utf8 string from numpy file, error: illegal utf-8 encoding")
|
||||
}
|
||||
|
||||
// only ASCII characters are supported, because the numpy lib encodes the utf8 bytes with its own internal method,
|
||||
// and the encode/decode logic is not clear for now, so return an error
|
||||
n := n.order.Uint32(raw)
|
||||
if n > 127 {
|
||||
log.Error("Numpy adapter: a string contains non-ascii characters, not support yet", zap.Int32("utf8Code", r))
|
||||
return nil, fmt.Errorf("Numpy adapter: a string contains non-ascii characters, not support yet")
|
||||
}
|
||||
|
||||
// if a string is shorter than maxLen, the tail characters are filled with "\u0000" (Null in the utf spec)
|
||||
if r > 0 {
|
||||
str += string(r)
|
||||
}
|
||||
|
||||
raw = raw[utf8.UTFMax:]
|
||||
}
|
||||
|
||||
data = append(data, str)
|
||||
} else {
|
||||
buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen)))
|
||||
if err != nil {
|
||||
log.Error("Numpy adapter: failed to read string from numpy file", zap.Int("i", i), zap.Any("err", err))
|
||||
return nil, fmt.Errorf("Numpy adapter: failed to read string from numpy file, error: %w", err)
|
||||
}
|
||||
n := bytes.Index(buf, []byte{0})
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
}
|
||||
data = append(data, string(buf))
|
||||
}
|
||||
}
|
||||
|
||||
// update read position after successfully read
|
||||
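To illustrate the fixed-width layout that ReadString walks through, here is a small self-contained sketch. It is illustration only, not part of the adapter: it hand-writes the 12 bytes a numpy "<U3" cell uses to store the ASCII string "a" (three 4-byte little-endian code points, NUL padded) and decodes one code point per utf8.UTFMax-sized slot, stopping at the padding.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// decodeFixedWidthCell decodes one "<U3"-style cell: maxLen code points,
// 4 bytes (utf8.UTFMax) each, padded with NUL once the string ends.
// This is a hypothetical helper for illustration, not part of the adapter.
func decodeFixedWidthCell(cell []byte, order binary.ByteOrder) string {
	var sb bytes.Buffer
	for off := 0; off+4 <= len(cell); off += 4 {
		code := order.Uint32(cell[off : off+4])
		if code == 0 { // NUL padding marks the end of the string
			break
		}
		sb.WriteRune(rune(code))
	}
	return sb.String()
}

func main() {
	// the string "a" stored in a "<U3" cell: 'a' (0x61) followed by padding
	cell := []byte{0x61, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	fmt.Println(decodeFixedWidthCell(cell, binary.LittleEndian)) // prints "a"
}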
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/sbinet/npyio/npy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -59,6 +60,55 @@ func Test_CreateNumpyData(t *testing.T) {
|
|||
assert.Nil(t, buf)
|
||||
}
|
||||
|
||||
func Test_ConvertNumpyType(t *testing.T) {
|
||||
checkFunc := func(inputs []string, output schemapb.DataType) {
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
dt, err := convertNumpyType(inputs[i])
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, output, dt)
|
||||
}
|
||||
}
|
||||
|
||||
checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
|
||||
checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
|
||||
checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
|
||||
checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
|
||||
checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
|
||||
checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
|
||||
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
|
||||
|
||||
dt, err := convertNumpyType("dummy")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, schemapb.DataType_None, dt)
|
||||
}
|
||||
|
||||
func Test_StringLen(t *testing.T) {
|
||||
len, utf, err := stringLen("S1")
|
||||
assert.Equal(t, 1, len)
|
||||
assert.False(t, utf)
|
||||
assert.Nil(t, err)
|
||||
|
||||
len, utf, err = stringLen("2S")
|
||||
assert.Equal(t, 2, len)
|
||||
assert.False(t, utf)
|
||||
assert.Nil(t, err)
|
||||
|
||||
len, utf, err = stringLen("<U3")
|
||||
assert.Equal(t, 3, len)
|
||||
assert.True(t, utf)
|
||||
assert.Nil(t, err)
|
||||
|
||||
len, utf, err = stringLen(">4U")
|
||||
assert.Equal(t, 4, len)
|
||||
assert.True(t, utf)
|
||||
assert.Nil(t, err)
|
||||
|
||||
len, utf, err = stringLen("dummy")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, 0, len)
|
||||
assert.False(t, utf)
|
||||
}
|
||||
|
||||
func Test_NumpyAdapterSetByteOrder(t *testing.T) {
|
||||
adapter := &NumpyAdapter{
|
||||
reader: nil,
|
||||
|
@ -82,32 +132,32 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
npyReader: nil,
|
||||
}
|
||||
|
||||
// reader is nil
|
||||
{
|
||||
_, err := adapter.ReadBool(1)
|
||||
// reader size is zero
|
||||
t.Run("test size is zero", func(t *testing.T) {
|
||||
_, err := adapter.ReadBool(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadUint8(1)
|
||||
_, err = adapter.ReadUint8(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt8(1)
|
||||
_, err = adapter.ReadInt8(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt16(1)
|
||||
_, err = adapter.ReadInt16(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt32(1)
|
||||
_, err = adapter.ReadInt32(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadInt64(1)
|
||||
_, err = adapter.ReadInt64(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadFloat32(1)
|
||||
_, err = adapter.ReadFloat32(0)
|
||||
assert.NotNil(t, err)
|
||||
_, err = adapter.ReadFloat64(1)
|
||||
_, err = adapter.ReadFloat64(0)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
adapter = &NumpyAdapter{
|
||||
reader: &MockReader{},
|
||||
npyReader: &npy.Reader{},
|
||||
}
|
||||
|
||||
{
|
||||
t.Run("test read bool", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "bool"
|
||||
data, err := adapter.ReadBool(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -117,9 +167,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadBool(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read uint8", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "u1"
|
||||
data, err := adapter.ReadUint8(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -129,9 +179,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadUint8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int8", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i1"
|
||||
data, err := adapter.ReadInt8(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -141,9 +191,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadInt8(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int16", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i2"
|
||||
data, err := adapter.ReadInt16(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -153,9 +203,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadInt16(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int32", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i4"
|
||||
data, err := adapter.ReadInt32(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -165,9 +215,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadInt32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int64", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "i8"
|
||||
data, err := adapter.ReadInt64(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -177,9 +227,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadInt64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read float", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "f4"
|
||||
data, err := adapter.ReadFloat32(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -189,9 +239,9 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadFloat32(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read double", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "f8"
|
||||
data, err := adapter.ReadFloat64(1)
|
||||
assert.Nil(t, data)
|
||||
|
@ -201,7 +251,19 @@ func Test_NumpyAdapterReadError(t *testing.T) {
|
|||
data, err = adapter.ReadFloat64(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test read varchar", func(t *testing.T) {
|
||||
adapter.npyReader.Header.Descr.Type = "U3"
|
||||
data, err := adapter.ReadString(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
adapter.npyReader.Header.Descr.Type = "dummy"
|
||||
data, err = adapter.ReadString(1)
|
||||
assert.Nil(t, data)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_NumpyAdapterRead(t *testing.T) {
|
||||
|
@ -209,7 +271,7 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
defer os.RemoveAll(TempFilesPath)
|
||||
|
||||
{
|
||||
t.Run("test read bool", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "bool.npy"
|
||||
data := []bool{true, false, true, false}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -267,9 +329,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
resf8, err := adapter.ReadFloat64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, resf8)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read uint8", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "uint8.npy"
|
||||
data := []uint8{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -303,9 +365,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
resb, err := adapter.ReadBool(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, resb)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int8", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "int8.npy"
|
||||
data := []int8{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -334,9 +396,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
res, err = adapter.ReadInt8(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int16", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "int16.npy"
|
||||
data := []int16{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -365,9 +427,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
res, err = adapter.ReadInt16(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int32", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "int32.npy"
|
||||
data := []int32{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -396,9 +458,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
res, err = adapter.ReadInt32(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read int64", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "int64.npy"
|
||||
data := []int64{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -427,9 +489,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
res, err = adapter.ReadInt64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read float", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "float.npy"
|
||||
data := []float32{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -458,9 +520,9 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
res, err = adapter.ReadFloat32(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("test read double", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "double.npy"
|
||||
data := []float64{1, 2, 3, 4, 5, 6}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
|
@ -489,5 +551,52 @@ func Test_NumpyAdapterRead(t *testing.T) {
|
|||
res, err = adapter.ReadFloat64(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test read ascii characters", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "varchar1.npy"
|
||||
data := []string{"a", "bbb", "c", "dd", "eeee", "fff"}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
res, err := adapter.ReadString(len(data) - 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(data)-1, len(res))
|
||||
|
||||
for i := 0; i < len(res); i++ {
|
||||
assert.Equal(t, data[i], res[i])
|
||||
}
|
||||
|
||||
res, err = adapter.ReadString(len(data))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, data[len(data)-1], res[0])
|
||||
|
||||
res, err = adapter.ReadString(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
t.Run("test read non-ascii", func(t *testing.T) {
|
||||
filePath := TempFilesPath + "varchar2.npy"
|
||||
data := []string{"a三百", "马克bbb"}
|
||||
err := CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
assert.Nil(t, err)
|
||||
defer file.Close()
|
||||
|
||||
adapter, err := NewNumpyAdapter(file)
|
||||
assert.Nil(t, err)
|
||||
res, err := adapter.ReadString(len(data))
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -19,12 +19,13 @@ package importutil
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ColumnDesc struct {
|
||||
|
@ -43,7 +44,7 @@ type NumpyParser struct {
|
|||
callFlushFunc func(field storage.FieldData) error // call back function to output column data
|
||||
}
|
||||
|
||||
// NewNumpyParser helper function to create a NumpyParser
|
||||
// NewNumpyParser is a helper function to create a NumpyParser
|
||||
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
|
||||
flushFunc func(field storage.FieldData) error) *NumpyParser {
|
||||
if collectionSchema == nil || flushFunc == nil {
|
||||
|
@ -60,38 +61,10 @@ func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSc
|
|||
return parser
|
||||
}
|
||||
|
||||
func (p *NumpyParser) logError(msg string) error {
|
||||
log.Error(msg)
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
// data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector)
|
||||
func convertNumpyType(str string) (schemapb.DataType, error) {
|
||||
switch str {
|
||||
case "b1", "<b1", "|b1", "bool":
|
||||
return schemapb.DataType_Bool, nil
|
||||
case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
|
||||
return schemapb.DataType_BinaryVector, nil
|
||||
case "i1", "<i1", "|i1", ">i1", "int8":
|
||||
return schemapb.DataType_Int8, nil
|
||||
case "i2", "<i2", "|i2", ">i2", "int16":
|
||||
return schemapb.DataType_Int16, nil
|
||||
case "i4", "<i4", "|i4", ">i4", "int32":
|
||||
return schemapb.DataType_Int32, nil
|
||||
case "i8", "<i8", "|i8", ">i8", "int64":
|
||||
return schemapb.DataType_Int64, nil
|
||||
case "f4", "<f4", "|f4", ">f4", "float32":
|
||||
return schemapb.DataType_Float, nil
|
||||
case "f8", "<f8", "|f8", ">f8", "float64":
|
||||
return schemapb.DataType_Double, nil
|
||||
default:
|
||||
return schemapb.DataType_None, errors.New("unsupported data type " + str)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
||||
if adapter == nil {
|
||||
return errors.New("numpy adapter is nil")
|
||||
log.Error("Numpy parser: numpy adapter is nil")
|
||||
return errors.New("Numpy parser: numpy adapter is nil")
|
||||
}
|
||||
|
||||
// check existence of the target field
|
||||
|
@ -105,27 +78,32 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
|||
}
|
||||
|
||||
if p.columnDesc.name == "" {
|
||||
return errors.New("the field " + fieldName + " doesn't exist")
|
||||
log.Error("Numpy parser: Numpy parser: the field is not found in collection schema", zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("Numpy parser: the field name '%s' is not found in collection schema", fieldName)
|
||||
}
|
||||
|
||||
p.columnDesc.dt = schema.DataType
|
||||
elementType, err := convertNumpyType(adapter.GetType())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elementType := adapter.GetType()
|
||||
shape := adapter.GetShape()
|
||||
|
||||
var err error
|
||||
// 1. the field data type should be consistent with the numpy data type
|
||||
// 2. the vector field dimension should be consistent with the numpy shape
|
||||
if schemapb.DataType_FloatVector == schema.DataType {
|
||||
// a float32/float64 numpy file can be used for a float vector field, for 2 reasons:
|
||||
// 1. for float vectors, we support float32 and float64 numpy files because the python float value is 64-bit
|
||||
// 2. for a float64 numpy file, the performance is worse than for a float32 numpy file
|
||||
if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
|
||||
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
|
||||
log.Error("Numpy parser: illegal data type of numpy file for float vector field", zap.Any("dataType", elementType),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for float vector field '%s'", getTypeName(elementType), schema.GetName())
|
||||
}
|
||||
|
||||
// vector field, the shape should be 2
|
||||
if len(shape) != 2 {
|
||||
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
|
||||
log.Error("Numpy parser: illegal shape of numpy file for float vector field, shape should be 2", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for float vector field '%s', shape should be 2", shape, schema.GetName())
|
||||
}
|
||||
|
||||
// shape[0] is row count, shape[1] is element count per row
|
||||
|
@ -137,16 +115,23 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
|||
}
|
||||
|
||||
if shape[1] != p.columnDesc.dimension {
|
||||
return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
|
||||
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
|
||||
zap.Int("numpyDimension", shape[1]), zap.Int("fieldDimension", p.columnDesc.dimension))
|
||||
return fmt.Errorf("Numpy parser: illegal dimension %d of numpy file for float vector field '%s', dimension should be %d",
|
||||
shape[1], schema.GetName(), p.columnDesc.dimension)
|
||||
}
|
||||
} else if schemapb.DataType_BinaryVector == schema.DataType {
|
||||
if elementType != schemapb.DataType_BinaryVector {
|
||||
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
|
||||
log.Error("Numpy parser: illegal data type of numpy file for binary vector field", zap.Any("dataType", elementType),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for binary vector field '%s'", getTypeName(elementType), schema.GetName())
|
||||
}
|
||||
|
||||
// vector field, the shape should be 2
|
||||
if len(shape) != 2 {
|
||||
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
|
||||
log.Error("Numpy parser: illegal shape of numpy file for binary vector field, shape should be 2", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for binary vector field '%s', shape should be 2", shape, schema.GetName())
|
||||
}
|
||||
|
||||
// shape[0] is row count, shape[1] is element count per row
|
||||
|
@ -158,16 +143,24 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
|||
}
|
||||
|
||||
if shape[1] != p.columnDesc.dimension/8 {
|
||||
return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
|
||||
log.Error("Numpy parser: illegal dimension of numpy file for float vector field", zap.String("fieldName", fieldName),
|
||||
zap.Int("numpyDimension", shape[1]*8), zap.Int("fieldDimension", p.columnDesc.dimension))
|
||||
return fmt.Errorf("Numpy parser: illegal dimension %d of numpy file for binary vector field '%s', dimension should be %d",
|
||||
shape[1]*8, schema.GetName(), p.columnDesc.dimension)
|
||||
}
|
||||
} else {
|
||||
if elementType != schema.DataType {
|
||||
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
|
||||
log.Error("Numpy parser: illegal data type of numpy file for scalar field", zap.Any("numpyDataType", elementType),
|
||||
zap.String("fieldName", fieldName), zap.Any("fieldDataType", schema.DataType))
|
||||
return fmt.Errorf("Numpy parser: illegal data type %s of numpy file for scalar field '%s' with type %d",
|
||||
getTypeName(elementType), schema.GetName(), schema.DataType)
|
||||
}
|
||||
|
||||
// scalar field, the shape should be 1
|
||||
if len(shape) != 1 {
|
||||
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
|
||||
log.Error("Numpy parser: illegal shape of numpy file for scalar field, shape should be 1", zap.Int("shape", len(shape)),
|
||||
zap.String("fieldName", fieldName))
|
||||
return fmt.Errorf("Numpy parser: illegal shape %d of numpy file for scalar field '%s', shape should be 1", shape, schema.GetName())
|
||||
}
|
||||
|
||||
p.columnDesc.elementCount = shape[0]
|
||||
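As a worked example of the shape checks above: a 100-row field with dim=16 expects a numpy shape of (100, 2) for a binary vector (8 dimensions packed into one uint8) and (100, 16) for a float vector, while a scalar field expects a one-dimensional shape. The sketch below is a hypothetical helper, not part of the parser, that just makes this arithmetic explicit.

package main

import "fmt"

// expectedVectorShape is a hypothetical helper showing the numpy shape that
// validate() accepts for a vector field, given the row count and the dimension
// declared in the collection schema.
func expectedVectorShape(rows, dim int, binaryVector bool) [2]int {
	if binaryVector {
		// a binary vector packs 8 dimensions into one uint8, so shape[1] = dim/8
		return [2]int{rows, dim / 8}
	}
	// a float vector stores one float per dimension, so shape[1] = dim
	return [2]int{rows, dim}
}

func main() {
	fmt.Println(expectedVectorShape(100, 16, true))  // [100 2]
	fmt.Println(expectedVectorShape(100, 16, false)) // [100 16]
}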
|
@ -176,13 +169,14 @@ func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// this method read numpy data section into a storage.FieldData
|
||||
// consume method reads numpy data section into a storage.FieldData
|
||||
// please note it requires a large memory block (almost equal to the numpy file size)
|
||||
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
||||
switch p.columnDesc.dt {
|
||||
case schemapb.DataType_Bool:
|
||||
data, err := adapter.ReadBool(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -194,6 +188,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
case schemapb.DataType_Int8:
|
||||
data, err := adapter.ReadInt8(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -204,6 +199,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
case schemapb.DataType_Int16:
|
||||
data, err := adapter.ReadInt16(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -214,6 +210,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
case schemapb.DataType_Int32:
|
||||
data, err := adapter.ReadInt32(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -224,6 +221,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
case schemapb.DataType_Int64:
|
||||
data, err := adapter.ReadInt64(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -234,6 +232,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
case schemapb.DataType_Float:
|
||||
data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -244,6 +243,7 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
case schemapb.DataType_Double:
|
||||
data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -251,9 +251,21 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
Data: data,
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
data, err := adapter.ReadString(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
p.columnData = &storage.StringFieldData{
|
||||
NumRows: []int64{int64(p.columnDesc.elementCount)},
|
||||
Data: data,
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
data, err := adapter.ReadUint8(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -263,24 +275,24 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
Dim: p.columnDesc.dimension,
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
// for float vector, we support float32 and float64 numpy file because python float value is 64 bit
|
||||
// for float64 numpy file, the performance is worse than float32 numpy file
|
||||
// we don't check overflow here
|
||||
elementType, err := convertNumpyType(adapter.GetType())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// a float32/float64 numpy file can be used for a float vector field, for 2 reasons:
|
||||
// 1. for float vectors, we support float32 and float64 numpy files because the python float value is 64-bit
|
||||
// 2. for a float64 numpy file, the performance is worse than for a float32 numpy file
|
||||
elementType := adapter.GetType()
|
||||
|
||||
var data []float32
|
||||
var err error
|
||||
if elementType == schemapb.DataType_Float {
|
||||
data, err = adapter.ReadFloat32(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
} else if elementType == schemapb.DataType_Double {
|
||||
data = make([]float32, 0, p.columnDesc.elementCount)
|
||||
data64, err := adapter.ReadFloat64(p.columnDesc.elementCount)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -295,7 +307,8 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
Dim: p.columnDesc.dimension,
|
||||
}
|
||||
default:
|
||||
return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt)))
|
||||
log.Error("Numpy parser: unsupported data type of field", zap.Any("dataType", p.columnDesc.dt), zap.String("fieldName", p.columnDesc.name))
|
||||
return fmt.Errorf("Numpy parser: unsupported data type %s of field '%s'", getTypeName(p.columnDesc.dt), p.columnDesc.name)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -304,13 +317,13 @@ func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
|
|||
func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
|
||||
adapter, err := NewNumpyAdapter(reader)
|
||||
if err != nil {
|
||||
return p.logError("Numpy parse: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// the validation method only check the file header information
|
||||
err = p.validate(adapter, fieldName)
|
||||
if err != nil {
|
||||
return p.logError("Numpy parse: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if onlyValidate {
|
||||
|
@ -320,7 +333,7 @@ func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate boo
|
|||
// read all data from the numpy file
|
||||
err = p.consume(adapter)
|
||||
if err != nil {
|
||||
return p.logError("Numpy parse: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
return p.callFlushFunc(p.columnData)
|
||||
|
|
|
@ -36,28 +36,6 @@ func Test_NewNumpyParser(t *testing.T) {
|
|||
assert.Nil(t, parser)
|
||||
}
|
||||
|
||||
func Test_ConvertNumpyType(t *testing.T) {
|
||||
checkFunc := func(inputs []string, output schemapb.DataType) {
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
dt, err := convertNumpyType(inputs[i])
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, output, dt)
|
||||
}
|
||||
}
|
||||
|
||||
checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
|
||||
checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
|
||||
checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
|
||||
checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
|
||||
checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
|
||||
checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
|
||||
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
|
||||
|
||||
dt, err := convertNumpyType("dummy")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, schemapb.DataType_None, dt)
|
||||
}
|
||||
|

func Test_NumpyParserValidate(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)

@@ -71,7 +49,7 @@ func Test_NumpyParserValidate(t *testing.T) {

	adapter := &NumpyAdapter{npyReader: &npy.Reader{}}

	{
	t.Run("not support DataType_String", func(t *testing.T) {
		// string type is not supported
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{

@@ -88,15 +66,14 @@ func Test_NumpyParserValidate(t *testing.T) {
		assert.NotNil(t, err)
		err = p.validate(adapter, "field_string")
		assert.NotNil(t, err)
	}
	})

	// reader is nil
	parser := NewNumpyParser(ctx, schema, flushFunc)
	err = parser.validate(nil, "")
	assert.NotNil(t, err)

	// validate scalar data
	func() {
	t.Run("validate scalar", func(t *testing.T) {
		filePath := TempFilesPath + "scalar_1.npy"
		data1 := []float64{0, 1, 2, 3, 4, 5}
		err := CreateNumpyFile(filePath, data1)

@@ -108,6 +85,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_double")
		assert.Nil(t, err)

@@ -128,6 +106,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_double")
		assert.NotNil(t, err)

@@ -144,13 +123,13 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_double")
		assert.NotNil(t, err)
	}()
	})

	// validate binary vector data
	func() {
	t.Run("validate binary vector", func(t *testing.T) {
		filePath := TempFilesPath + "binary_vector_1.npy"
		data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
		err := CreateNumpyFile(filePath, data1)

@@ -162,6 +141,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_binary_vector")
		assert.Nil(t, err)

@@ -178,10 +158,8 @@ func Test_NumpyParserValidate(t *testing.T) {
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)
		assert.Nil(t, adapter)

		// shape mismatch
		filePath = TempFilesPath + "binary_vector_3.npy"

@@ -195,6 +173,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

@@ -211,6 +190,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file4)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

@@ -228,10 +208,9 @@ func Test_NumpyParserValidate(t *testing.T) {

		err = p.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)
	}()
	})

	// validate float vector data
	func() {
	t.Run("validate float vector", func(t *testing.T) {
		filePath := TempFilesPath + "float_vector.npy"
		data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
		err := CreateNumpyFile(filePath, data1)

@@ -243,6 +222,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_float_vector")
		assert.Nil(t, err)

@@ -260,6 +240,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

@@ -276,6 +257,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

@@ -292,6 +274,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		adapter, err = NewNumpyAdapter(file4)
		assert.Nil(t, err)
		assert.NotNil(t, adapter)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

@@ -309,7 +292,7 @@ func Test_NumpyParserValidate(t *testing.T) {

		err = p.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)
	}()
	})
}
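These validation tests all build their inputs with the CreateNumpyFile helper; a minimal sketch of such a helper, assuming the sbinet/npyio package already used by the adapter (hypothetical body, not the repository's actual implementation):

// createNumpyFileSketch serializes a slice or array value into a .npy file.
func createNumpyFileSketch(path string, data interface{}) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()
	// npyio.Write emits a standard numpy header followed by the raw payload.
	return npyio.Write(f, data)
}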

func Test_NumpyParserParse(t *testing.T) {

@@ -355,153 +338,179 @@ func Test_NumpyParserParse(t *testing.T) {
	}()
}

	// scalar bool
	data1 := []bool{true, false, true, false, true}
	flushFunc := func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data1), field.RowNum())

		for i := 0; i < len(data1); i++ {
			assert.Equal(t, data1[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data1, "field_bool", flushFunc)

	// scalar int8
	data2 := []int8{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data2), field.RowNum())

		for i := 0; i < len(data2); i++ {
			assert.Equal(t, data2[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data2, "field_int8", flushFunc)

	// scalar int16
	data3 := []int16{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data3), field.RowNum())

		for i := 0; i < len(data3); i++ {
			assert.Equal(t, data3[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data3, "field_int16", flushFunc)

	// scalar int32
	data4 := []int32{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data4), field.RowNum())

		for i := 0; i < len(data4); i++ {
			assert.Equal(t, data4[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data4, "field_int32", flushFunc)

	// scalar int64
	data5 := []int64{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data5), field.RowNum())

		for i := 0; i < len(data5); i++ {
			assert.Equal(t, data5[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data5, "field_int64", flushFunc)

	// scalar float
	data6 := []float32{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data6), field.RowNum())

		for i := 0; i < len(data6); i++ {
			assert.Equal(t, data6[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data6, "field_float", flushFunc)

	// scalar double
	data7 := []float64{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data7), field.RowNum())

		for i := 0; i < len(data7); i++ {
			assert.Equal(t, data7[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data7, "field_double", flushFunc)

	// binary vector
	data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data8), field.RowNum())

		for i := 0; i < len(data8); i++ {
			row := field.GetRow(i).([]uint8)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, data8[i][k], row[k])
			}
		}

		return nil
	}
	checkFunc(data8, "field_binary_vector", flushFunc)

	// double vector(element can be float32 or float64)
	data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data9), field.RowNum())

		for i := 0; i < len(data9); i++ {
			row := field.GetRow(i).([]float32)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, data9[i][k], row[k])
			}
		}

		return nil
	}
	checkFunc(data9, "field_float_vector", flushFunc)

	data10 := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data10), field.RowNum())

		for i := 0; i < len(data10); i++ {
			row := field.GetRow(i).([]float32)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, float32(data10[i][k]), row[k])
			}
		}

		return nil
	}
	checkFunc(data10, "field_float_vector", flushFunc)

	t.Run("parse scalar bool", func(t *testing.T) {
		data := []bool{true, false, true, false, true}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_bool", flushFunc)
	})

	t.Run("parse scalar int8", func(t *testing.T) {
		data := []int8{1, 2, 3, 4, 5}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_int8", flushFunc)
	})

	t.Run("parse scalar int16", func(t *testing.T) {
		data := []int16{1, 2, 3, 4, 5}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_int16", flushFunc)
	})

	t.Run("parse scalar int32", func(t *testing.T) {
		data := []int32{1, 2, 3, 4, 5}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_int32", flushFunc)
	})

	t.Run("parse scalar int64", func(t *testing.T) {
		data := []int64{1, 2, 3, 4, 5}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_int64", flushFunc)
	})

	t.Run("parse scalar float", func(t *testing.T) {
		data := []float32{1, 2, 3, 4, 5}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_float", flushFunc)
	})

	t.Run("parse scalar double", func(t *testing.T) {
		data := []float64{1, 2, 3, 4, 5}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_double", flushFunc)
	})

	t.Run("parse scalar varchar", func(t *testing.T) {
		data := []string{"abcd", "sdb", "ok", "milvus"}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				assert.Equal(t, data[i], field.GetRow(i))
			}

			return nil
		}
		checkFunc(data, "field_string", flushFunc)
	})

	t.Run("parse binary vector", func(t *testing.T) {
		data := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				row := field.GetRow(i).([]uint8)
				for k := 0; k < len(row); k++ {
					assert.Equal(t, data[i][k], row[k])
				}
			}

			return nil
		}
		checkFunc(data, "field_binary_vector", flushFunc)
	})

	t.Run("parse binary vector with float32", func(t *testing.T) {
		data := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				row := field.GetRow(i).([]float32)
				for k := 0; k < len(row); k++ {
					assert.Equal(t, data[i][k], row[k])
				}
			}

			return nil
		}
		checkFunc(data, "field_float_vector", flushFunc)
	})

	t.Run("parse binary vector with float64", func(t *testing.T) {
		data := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
		flushFunc := func(field storage.FieldData) error {
			assert.NotNil(t, field)
			assert.Equal(t, len(data), field.RowNum())

			for i := 0; i < len(data); i++ {
				row := field.GetRow(i).([]float32)
				for k := 0; k < len(row); k++ {
					assert.Equal(t, float32(data[i][k]), row[k])
				}
			}

			return nil
		}
		checkFunc(data, "field_float_vector", flushFunc)
	})
}
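Every subtest above funnels through the shared checkFunc helper; judging from how it is called, it writes the data to a temporary .npy file and runs the parser against it. A rough sketch under that assumption (sampleSchema and the exact flow are placeholders, not the test file's real code):

// runParseCheck sketches a checkFunc-style helper: write data to a .npy file,
// parse it, and let the callback verify the flushed field data.
func runParseCheck(t *testing.T, data interface{}, fieldName string, callback func(field storage.FieldData) error) {
	filePath := TempFilesPath + fieldName + ".npy"
	assert.Nil(t, CreateNumpyFile(filePath, data))

	file, err := os.Open(filePath)
	assert.Nil(t, err)
	defer file.Close()

	parser := NewNumpyParser(context.Background(), sampleSchema(), callback) // sampleSchema: placeholder schema provider
	assert.Nil(t, parser.Parse(file, fieldName, false))
}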

func Test_NumpyParserParse_perf(t *testing.T) {