Make bulk load fully work (#16512)

issue: #15604

/kind enhancement

Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com>
pull/16470/head
Ten Thousand Leaves 2022-04-20 14:03:40 +08:00 committed by GitHub
parent 3a1b2cedd2
commit 289e468a7a
29 changed files with 735 additions and 355 deletions


@ -240,6 +240,7 @@ func (m *meta) UpdateFlushSegmentsInfo(
segmentID UniqueID,
flushed bool,
dropped bool,
importing bool,
binlogs, statslogs, deltalogs []*datapb.FieldBinlog,
checkpoints []*datapb.CheckPoint,
startPositions []*datapb.SegmentStartPosition,
@ -248,13 +249,16 @@ func (m *meta) UpdateFlushSegmentsInfo(
defer m.Unlock()
segment := m.segments.GetSegment(segmentID)
if importing {
m.segments.SetRowCount(segmentID, segment.currRows)
segment = m.segments.GetSegment(segmentID)
}
if segment == nil || !isSegmentHealthy(segment) {
return nil
}
clonedSegment := segment.Clone()
kv := make(map[string]string)
modSegments := make(map[UniqueID]*SegmentInfo)
if flushed {
@ -352,6 +356,7 @@ func (m *meta) UpdateFlushSegmentsInfo(
modSegments[cp.GetSegmentID()] = s
}
kv := make(map[string]string)
for _, segment := range modSegments {
segBytes, err := proto.Marshal(segment.SegmentInfo)
if err != nil {
@ -366,6 +371,7 @@ func (m *meta) UpdateFlushSegmentsInfo(
}
if err := m.saveKvTxn(kv); err != nil {
log.Error("failed to store flush segment info into Etcd", zap.Error(err))
return err
}
oldSegmentState := segment.GetState()
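The updated signature takes the new importing flag as the fourth positional argument, which is easy to misread in the test calls below. A labeled sketch of a call that persists an import-generated segment (all argument values here are placeholders, not taken from this change):

    // Hedged sketch: with importing=true, meta adopts segment.currRows as the
    // persisted row count before the cloned segment is written to etcd.
    err := m.UpdateFlushSegmentsInfo(
        segmentID, // UniqueID of the imported segment
        true,      // flushed
        false,     // dropped
        true,      // importing (new in this change)
        binlogs, statslogs, deltalogs,
        checkpoints, startPositions,
    )
    if err != nil {
        log.Error("failed to update flush segment info", zap.Error(err))
    }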


@ -241,7 +241,7 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) {
err = meta.AddSegment(segment1)
assert.Nil(t, err)
err = meta.UpdateFlushSegmentsInfo(1, true, false, []*datapb.FieldBinlog{getFieldBinlogPaths(1, "binlog1")},
err = meta.UpdateFlushSegmentsInfo(1, true, false, true, []*datapb.FieldBinlog{getFieldBinlogPaths(1, "binlog1")},
[]*datapb.FieldBinlog{getFieldBinlogPaths(1, "statslog1")},
[]*datapb.FieldBinlog{{Binlogs: []*datapb.Binlog{{EntriesNum: 1, TimestampFrom: 100, TimestampTo: 200, LogSize: 1000}}}},
[]*datapb.CheckPoint{{SegmentID: 1, NumOfRows: 10}}, []*datapb.SegmentStartPosition{{SegmentID: 1, StartPosition: &internalpb.MsgPosition{MsgID: []byte{1, 2, 3}}}})
@ -262,7 +262,7 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) {
meta, err := newMeta(memkv.NewMemoryKV())
assert.Nil(t, err)
err = meta.UpdateFlushSegmentsInfo(1, false, false, nil, nil, nil, nil, nil)
err = meta.UpdateFlushSegmentsInfo(1, false, false, false, nil, nil, nil, nil, nil)
assert.Nil(t, err)
})
@ -274,7 +274,7 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) {
err = meta.AddSegment(segment1)
assert.Nil(t, err)
err = meta.UpdateFlushSegmentsInfo(1, false, false, nil, nil, nil, []*datapb.CheckPoint{{SegmentID: 2, NumOfRows: 10}},
err = meta.UpdateFlushSegmentsInfo(1, false, false, false, nil, nil, nil, []*datapb.CheckPoint{{SegmentID: 2, NumOfRows: 10}},
[]*datapb.SegmentStartPosition{{SegmentID: 2, StartPosition: &internalpb.MsgPosition{MsgID: []byte{1, 2, 3}}}})
assert.Nil(t, err)
@ -296,7 +296,7 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) {
}
meta.segments.SetSegment(1, segmentInfo)
err = meta.UpdateFlushSegmentsInfo(1, true, false, []*datapb.FieldBinlog{getFieldBinlogPaths(1, "binlog")},
err = meta.UpdateFlushSegmentsInfo(1, true, false, false, []*datapb.FieldBinlog{getFieldBinlogPaths(1, "binlog")},
[]*datapb.FieldBinlog{getFieldBinlogPaths(1, "statslog")},
[]*datapb.FieldBinlog{{Binlogs: []*datapb.Binlog{{EntriesNum: 1, TimestampFrom: 100, TimestampTo: 200, LogSize: 1000}}}},
[]*datapb.CheckPoint{{SegmentID: 1, NumOfRows: 10}}, []*datapb.SegmentStartPosition{{SegmentID: 1, StartPosition: &internalpb.MsgPosition{MsgID: []byte{1, 2, 3}}}})


@ -2445,6 +2445,34 @@ func TestImport(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.GetErrorCode())
assert.Equal(t, msgDataCoordIsUnhealthy(Params.DataCoordCfg.NodeID), resp.Status.GetReason())
})
t.Run("test update segment stat", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
status, err := svr.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{
Stats: []*datapb.SegmentStats{{
SegmentID: 100,
NumRows: int64(1),
}},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode())
})
t.Run("test update segment stat w/ closed server", func(t *testing.T) {
svr := newTestServer(t, nil)
closeTestServer(t, svr)
status, err := svr.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{
Stats: []*datapb.SegmentStats{{
SegmentID: 100,
NumRows: int64(1),
}},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
})
}
// https://github.com/milvus-io/milvus/issues/15659


@ -348,6 +348,7 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath
req.GetSegmentID(),
req.GetFlushed(),
req.GetDropped(),
req.GetImporting(),
req.GetField2BinlogPaths(),
req.GetField2StatslogPaths(),
req.GetDeltalogs(),
@ -1032,6 +1033,23 @@ func (s *Server) Import(ctx context.Context, itr *datapb.ImportTaskRequest) (*da
return resp, nil
}
// UpdateSegmentStatistics updates a segment's stats.
func (s *Server) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
resp := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "",
}
if s.isClosed() {
log.Warn("failed to update segment stat for closed server")
resp.Reason = msgDataCoordIsUnhealthy(Params.DataCoordCfg.NodeID)
return resp, nil
}
s.updateSegmentStatistics(req.GetStats())
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
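The handler delegates to an internal s.updateSegmentStatistics helper whose body is not part of this diff. As a loose mental model only (an assumption, not the actual implementation), it amounts to recording the reported row counts on the in-memory segment stats that UpdateFlushSegmentsInfo later promotes when importing is true:

    // Hypothetical shape of the helper; setCurrRows is a stand-in name, not a
    // method shown in this change.
    func (s *Server) updateSegmentStatistics(stats []*datapb.SegmentStats) {
        for _, stat := range stats {
            s.meta.setCurrRows(stat.GetSegmentID(), stat.GetNumRows())
        }
    }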
// getDiff returns the difference of base and remove. i.e. all items that are in `base` but not in `remove`.
func getDiff(base, remove []int64) []int64 {
mb := make(map[int64]struct{}, len(remove))


@ -820,7 +820,18 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
}, nil
}
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, Params.DataNodeCfg.NodeID)
importWrapper := importutil.NewImportWrapper(ctx, schema, 2, Params.DataNodeCfg.FlushInsertBufferSize/(1024*1024), idAllocator, node.chunkManager, importFlushReqFunc(node, req, schema, ts))
importResult := &rootcoordpb.ImportResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
TaskId: req.GetImportTask().TaskId,
DatanodeId: node.NodeID,
State: commonpb.ImportState_ImportPersisted,
Segments: make([]int64, 0),
RowCount: 0,
}
importWrapper := importutil.NewImportWrapper(ctx, schema, 2, Params.DataNodeCfg.FlushInsertBufferSize, idAllocator, node.chunkManager,
importFlushReqFunc(node, req, importResult, schema, ts))
err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false)
if err != nil {
return &commonpb.Status{
@ -829,22 +840,38 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
}, nil
}
// report to RootCoord that the import task has finished
_, err = node.rootCoord.ReportImport(ctx, importResult)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, nil
}
resp := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
return resp, nil
}
type importFlushFunc func(fields map[storage.FieldID]storage.FieldData) error
func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *schemapb.CollectionSchema, ts Timestamp) importFlushFunc {
return func(fields map[storage.FieldID]storage.FieldData) error {
func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, schema *schemapb.CollectionSchema, ts Timestamp) importutil.ImportFlushFunc {
return func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
if shardNum >= len(req.GetImportTask().GetChannelNames()) {
log.Error("import task returns invalid shard number",
zap.Int("# of shards", shardNum),
zap.Int("# of channels", len(req.GetImportTask().GetChannelNames())),
zap.Any("channel names", req.GetImportTask().GetChannelNames()),
)
return fmt.Errorf("syncSegmentID Failed: invalid shard number %d", shardNum)
}
log.Info("import task flush segment", zap.Any("ChannelNames", req.ImportTask.ChannelNames), zap.Int("shardNum", shardNum))
segReqs := []*datapb.SegmentIDRequest{
{
ChannelName: "test-channel",
ChannelName: req.ImportTask.ChannelNames[shardNum],
Count: 1,
CollectionID: req.GetImportTask().GetCollectionId(),
PartitionID: req.GetImportTask().GetCollectionId(),
PartitionID: req.GetImportTask().GetPartitionId(),
},
}
segmentIDReq := &datapb.AssignSegmentIDRequest{
@ -884,6 +911,15 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *s
}
}
fields[common.RowIDField] = fields[pkFieldID]
if status, _ := node.dataCoord.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{
Stats: []*datapb.SegmentStats{{
SegmentID: segmentID,
NumRows: int64(rowNum),
}},
}); status.GetErrorCode() != commonpb.ErrorCode_Success {
// TODO: report the failure via ReportImport.
return fmt.Errorf(status.GetReason())
}
data := BufferData{buffer: &InsertData{
Data: fields,
@ -985,6 +1021,7 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *s
Field2BinlogPaths: fieldInsert,
Field2StatslogPaths: fieldStats,
Importing: true,
Flushed: true,
}
err = retry.Do(context.Background(), func() error {
@ -1005,7 +1042,9 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *s
return err
}
log.Info("segment imported and persisted", zap.Int64("segmentID", segmentID))
res.Segments = append(res.Segments, segmentID)
res.RowCount += int64(rowNum)
return nil
}
}


@ -52,6 +52,10 @@ import (
"go.uber.org/zap"
)
const returnError = "ReturnError"
type ctxKey struct{}
func TestMain(t *testing.M) {
rand.Seed(time.Now().Unix())
path := "/tmp/milvus_ut/rdb_data"
@ -321,6 +325,10 @@ func TestDataNode(t *testing.T) {
})
t.Run("Test Import", func(t *testing.T) {
node.rootCoord = &RootCoordFactory{
collectionID: 100,
pkType: schemapb.DataType_Int64,
}
content := []byte(`{
"rows":[
{"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]},
@ -338,13 +346,47 @@ func TestDataNode(t *testing.T) {
ImportTask: &datapb.ImportTask{
CollectionId: 100,
PartitionId: 100,
ChannelNames: []string{"ch1", "ch2"},
Files: []string{filePath},
RowBased: true,
},
}
stat, err := node.Import(node.ctx, req)
stat, err := node.Import(context.WithValue(ctx, ctxKey{}, ""), req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stat.ErrorCode)
assert.Equal(t, commonpb.ErrorCode_Success, stat.GetErrorCode())
assert.Equal(t, "", stat.GetReason())
})
t.Run("Test Import report import error", func(t *testing.T) {
node.rootCoord = &RootCoordFactory{
collectionID: 100,
pkType: schemapb.DataType_Int64,
}
content := []byte(`{
"rows":[
{"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]},
{"bool_field": false, "int8_field": 11, "int16_field": 102, "int32_field": 1002, "int64_field": 10002, "float32_field": 3.15, "float64_field": 2.56, "varChar_field": "hello world", "binary_vector_field": [253, 0, 253, 0], "float_vector_field": [2.1, 2.2]},
{"bool_field": true, "int8_field": 12, "int16_field": 103, "int32_field": 1003, "int64_field": 10003, "float32_field": 3.16, "float64_field": 3.56, "varChar_field": "hello world", "binary_vector_field": [252, 0, 252, 0], "float_vector_field": [3.1, 3.2]},
{"bool_field": false, "int8_field": 13, "int16_field": 104, "int32_field": 1004, "int64_field": 10004, "float32_field": 3.17, "float64_field": 4.56, "varChar_field": "hello world", "binary_vector_field": [251, 0, 251, 0], "float_vector_field": [4.1, 4.2]},
{"bool_field": true, "int8_field": 14, "int16_field": 105, "int32_field": 1005, "int64_field": 10005, "float32_field": 3.18, "float64_field": 5.56, "varChar_field": "hello world", "binary_vector_field": [250, 0, 250, 0], "float_vector_field": [5.1, 5.2]}
]
}`)
filePath := "import/rows_1.json"
err = node.chunkManager.Write(filePath, content)
assert.NoError(t, err)
req := &datapb.ImportTaskRequest{
ImportTask: &datapb.ImportTask{
CollectionId: 100,
PartitionId: 100,
ChannelNames: []string{"ch1", "ch2"},
Files: []string{filePath},
RowBased: true,
},
}
stat, err := node.Import(context.WithValue(node.ctx, ctxKey{}, returnError), req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.GetErrorCode())
})
t.Run("Test Import error", func(t *testing.T) {
@ -355,7 +397,7 @@ func TestDataNode(t *testing.T) {
PartitionId: 100,
},
}
stat, err := node.Import(node.ctx, req)
stat, err := node.Import(context.WithValue(ctx, ctxKey{}, ""), req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.ErrorCode)
})


@ -283,6 +283,10 @@ func newDDNode(ctx context.Context, collID UniqueID, vchanInfo *datapb.VchannelI
return nil
}
pChannelName := funcutil.ToPhysicalChannel(vchanInfo.ChannelName)
log.Info("ddNode add flushed segment",
zap.String("channelName", vchanInfo.ChannelName),
zap.String("pChannelName", pChannelName),
)
deltaChannelName, err := funcutil.ConvertChannelName(pChannelName, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta)
if err != nil {
log.Error(err.Error())


@ -21,6 +21,7 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"math"
"math/rand"
"sync"
@ -226,6 +227,12 @@ func (ds *DataCoordFactory) DropVirtualChannel(ctx context.Context, req *datapb.
}, nil
}
func (ds *DataCoordFactory) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
func (mf *MetaFactory) GetCollectionMeta(collectionID UniqueID, collectionName string, pkDataType schemapb.DataType) *etcdpb.CollectionMeta {
sch := schemapb.CollectionSchema{
Name: collectionName,
@ -930,6 +937,16 @@ func (m *RootCoordFactory) GetComponentStates(ctx context.Context) (*internalpb.
}, nil
}
func (m *RootCoordFactory) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) {
v := ctx.Value(ctxKey{}).(string)
if v == returnError {
return nil, fmt.Errorf("injected error")
}
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
// FailMessageStreamFactory mock MessageStreamFactory failure
type FailMessageStreamFactory struct {
dependency.Factory


@ -514,3 +514,17 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*da
}
return ret.(*datapb.ImportTaskResponse), err
}
// UpdateSegmentStatistics is the client side caller of UpdateSegmentStatistics.
func (c *Client) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.(datapb.DataCoordClient).UpdateSegmentStatistics(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*commonpb.Status), err
}
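A minimal caller-side sketch of the new method, assuming cli is an already-connected *Client from this package; the segment ID and row count are illustrative:

    status, err := cli.UpdateSegmentStatistics(ctx, &datapb.UpdateSegmentStatisticsRequest{
        Stats: []*datapb.SegmentStats{{
            SegmentID: 100,  // illustrative segment ID
            NumRows:   1024, // rows imported so far
        }},
    })
    if err != nil || status.GetErrorCode() != commonpb.ErrorCode_Success {
        // retry or mark the import task as failed
    }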


@ -120,6 +120,15 @@ func Test_NewClient(t *testing.T) {
r21, err := client.DropVirtualChannel(ctx, nil)
retCheck(retNotNil, r21, err)
r22, err := client.SetSegmentState(ctx, nil)
retCheck(retNotNil, r22, err)
r23, err := client.Import(ctx, nil)
retCheck(retNotNil, r23, err)
r24, err := client.UpdateSegmentStatistics(ctx, nil)
retCheck(retNotNil, r24, err)
}
client.grpcClient = &mock.ClientBase{


@ -335,3 +335,8 @@ func (s *Server) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStat
func (s *Server) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) {
return s.dataCoord.Import(ctx, req)
}
// UpdateSegmentStatistics is the dataCoord service caller of UpdateSegmentStatistics.
func (s *Server) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
return s.dataCoord.UpdateSegmentStatistics(ctx, req)
}


@ -57,6 +57,7 @@ type MockDataCoord struct {
dropVChanResp *datapb.DropVirtualChannelResponse
setSegmentStateResp *datapb.SetSegmentStateResponse
importResp *datapb.ImportTaskResponse
updateSegStatResp *commonpb.Status
}
func (m *MockDataCoord) Init() error {
@ -174,6 +175,10 @@ func (m *MockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskReques
return m.importResp, m.err
}
func (m *MockDataCoord) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
return m.updateSegStatResp, m.err
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func Test_NewServer(t *testing.T) {
ctx := context.Background()
@ -393,7 +398,7 @@ func Test_NewServer(t *testing.T) {
assert.NotNil(t, resp)
})
t.Run("Import", func(t *testing.T) {
t.Run("import", func(t *testing.T) {
server.dataCoord = &MockDataCoord{
importResp: &datapb.ImportTaskResponse{
Status: &commonpb.Status{},
@ -404,6 +409,17 @@ func Test_NewServer(t *testing.T) {
assert.NotNil(t, resp)
})
t.Run("update seg stat", func(t *testing.T) {
server.dataCoord = &MockDataCoord{
updateSegStatResp: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}
resp, err := server.UpdateSegmentStatistics(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
err := server.Stop()
assert.Nil(t, err)
}


@ -498,6 +498,10 @@ func (m *MockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskReques
return nil, nil
}
func (m *MockDataCoord) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
return nil, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockProxy struct {
MockBase


@ -9,6 +9,9 @@ import "internal.proto";
import "milvus.proto";
import "schema.proto";
// TODO: import google/protobuf/empty.proto
message Empty {}
service DataCoord {
rpc GetComponentStates(internal.GetComponentStatesRequest) returns (internal.ComponentStates) {}
rpc GetTimeTickChannel(internal.GetTimeTickChannelRequest) returns(milvus.StringResponse) {}
@ -45,6 +48,7 @@ service DataCoord {
rpc SetSegmentState(SetSegmentStateRequest) returns (SetSegmentStateResponse) {}
// https://wiki.lfaidata.foundation/display/MIL/MEP+24+--+Support+bulk+load
rpc Import(ImportTaskRequest) returns (ImportTaskResponse) {}
rpc UpdateSegmentStatistics(UpdateSegmentStatisticsRequest) returns (common.Status) {}
}
service DataNode {
@ -473,3 +477,8 @@ message ImportTaskRequest {
ImportTask import_task = 2; // Target import task.
repeated int64 working_nodes = 3; // DataNodes that are currently working.
}
message UpdateSegmentStatisticsRequest {
common.MsgBase base = 1;
repeated SegmentStats stats = 2;
}

File diff suppressed because it is too large.


@ -183,7 +183,7 @@ func (mgr *singleTypeChannelsMgr) getLatestVID(collectionID UniqueID) (int, erro
ids, ok := mgr.collectionID2VIDs[collectionID]
if !ok || len(ids) <= 0 {
return 0, fmt.Errorf("collection %d not found", collectionID)
return 0, fmt.Errorf("v-channel ID is not found for collection %d", collectionID)
}
return ids[len(ids)-1], nil


@ -214,6 +214,13 @@ func (coord *DataCoordMock) Import(ctx context.Context, req *datapb.ImportTaskRe
return &datapb.ImportTaskResponse{}, nil
}
func (coord *DataCoordMock) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func NewDataCoordMock() *DataCoordMock {
return &DataCoordMock{
nodeID: typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),


@ -3922,7 +3922,9 @@ func unhealthyStatus() *commonpb.Status {
// Import data files (json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments
func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
log.Info("received import request")
log.Info("received import request",
zap.String("collection name", req.GetCollectionName()),
zap.Bool("row-based", req.GetRowBased()))
resp := &milvuspb.ImportResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -3945,20 +3947,21 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi
}
chNames, err := node.chMgr.getVChannels(collID)
if err != nil {
log.Error("get vChannels failed",
zap.Int64("collection ID", collID),
zap.Error(err))
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
resp.Status.Reason = err.Error()
return resp, err
err = node.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
chNames, err = node.chMgr.getVChannels(collID)
if err != nil {
return nil, err
}
}
req.ChannelNames = chNames
if req.GetPartitionName() == "" {
req.PartitionName = Params.CommonCfg.DefaultPartitionName
}
// Call rootCoord to finish import.
resp, err = node.rootCoord.Import(ctx, req)
log.Info("received import response",
zap.String("collection name", req.GetCollectionName()),
zap.Any("resp", resp),
zap.Error(err))
return resp, err
}
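From the client's point of view the request stays the same; the proxy now fills ChannelNames itself (creating the DML message stream on demand) and falls back to the default partition when none is given. An illustrative request (the collection name, file path, and proxyClient handle are made up for the sketch):

    req := &milvuspb.ImportRequest{
        CollectionName: "books", // made-up collection
        // PartitionName omitted: Params.CommonCfg.DefaultPartitionName is used
        RowBased: true,
        Files:    []string{"import/rows_1.json"}, // path on MinIO/S3
    }
    resp, err := proxyClient.Import(ctx, req) // proxyClient stands in for any Milvus client
    if err == nil {
        log.Info("import submitted", zap.Int64s("task IDs", resp.GetTasks()))
    }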


@ -3035,6 +3035,8 @@ func TestProxy_GetComponentStates_state_code(t *testing.T) {
func TestProxy_Import(t *testing.T) {
rc := NewRootCoordMock()
master := newMockGetChannelsService()
msgStreamFactory := newSimpleMockMsgStreamFactory()
rc.Start()
defer rc.Stop()
err := InitMetaCache(rc)
@ -3052,19 +3054,20 @@ func TestProxy_Import(t *testing.T) {
defer cancel()
factory := dependency.NewDefaultFactory(localMsg)
proxy, err := NewProxy(ctx, factory)
proxy.rootCoord = rc
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
t.Run("test import get vChannel failed", func(t *testing.T) {
t.Run("test import get vChannel failed (the first one)", func(t *testing.T) {
defer wg.Done()
proxy.stateCode.Store(internalpb.StateCode_Healthy)
proxy.chMgr = newChannelsMgrImpl(nil, nil, nil, nil, nil)
proxy.chMgr = newChannelsMgrImpl(master.GetChannels, nil, nil, nil, msgStreamFactory)
resp, err := proxy.Import(context.TODO(),
&milvuspb.ImportRequest{
CollectionName: "import_collection",
})
assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode)
assert.Error(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.NoError(t, err)
})
wg.Add(1)
t.Run("test import with unhealthy", func(t *testing.T) {


@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
@ -37,6 +38,7 @@ import (
const (
Bucket = "bucket"
FailedReason = "failed_reason"
Files = "files"
MaxPendingCount = 32
delimiter = "/"
taskExpiredMsgPrefix = "task has expired after "
@ -187,7 +189,7 @@ func (m *importManager) genReqID() int64 {
// importJob processes the import request, generates import tasks, sends these tasks to DataCoord, and returns
// immediately.
func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportRequest, cID int64) *milvuspb.ImportResponse {
func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportRequest, cID int64, pID int64) *milvuspb.ImportResponse {
if req == nil || len(req.Files) == 0 {
return &milvuspb.ImportResponse{
Status: &commonpb.Status{
@ -210,11 +212,13 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Tasks: make([]int64, 0),
}
log.Debug("request received",
zap.String("collection name", req.GetCollectionName()),
zap.Int64("collection ID", cID))
zap.Int64("collection ID", cID),
zap.Int64("partition ID", pID))
func() {
m.pendingLock.Lock()
defer m.pendingLock.Unlock()
@ -254,6 +258,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
Id: m.nextTaskID,
RequestId: reqID,
CollectionId: cID,
PartitionId: pID,
ChannelNames: req.ChannelNames,
Bucket: bucket,
RowBased: req.GetRowBased(),
@ -263,6 +268,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
StateCode: commonpb.ImportState_ImportPending,
},
}
resp.Tasks = append(resp.Tasks, newTask.GetId())
taskList[i] = newTask.GetId()
m.nextTaskID++
log.Info("new task created as pending task", zap.Int64("task ID", newTask.GetId()))
@ -277,6 +283,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
Id: m.nextTaskID,
RequestId: reqID,
CollectionId: cID,
PartitionId: pID,
ChannelNames: req.ChannelNames,
Bucket: bucket,
RowBased: req.GetRowBased(),
@ -286,6 +293,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
StateCode: commonpb.ImportState_ImportPending,
},
}
resp.Tasks = append(resp.Tasks, newTask.GetId())
m.nextTaskID++
log.Info("new task created as pending task", zap.Int64("task ID", newTask.GetId()))
m.pendingTasks = append(m.pendingTasks, newTask)
@ -345,6 +353,7 @@ func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "import task id doesn't exist",
},
Infos: make([]*commonpb.KeyValuePair, 0),
}
log.Debug("getting import task state", zap.Int64("taskID", tID))
@ -358,6 +367,7 @@ func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse
ErrorCode: commonpb.ErrorCode_Success,
}
resp.State = commonpb.ImportState_ImportPending
resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(m.pendingTasks[i].GetFiles(), ",")})
found = true
break
}
@ -378,6 +388,7 @@ func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse
resp.State = v.GetState().GetStateCode()
resp.RowCount = v.GetState().GetRowCount()
resp.IdList = v.GetState().GetRowIds()
resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(v.GetFiles(), ",")})
resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{
Key: FailedReason,
Value: v.GetState().GetErrorMessage(),
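With the new Files key, a state query response now carries the task's input files alongside the failure reason. A hedged sketch of reading those infos, assuming resp comes from a GetImportState call:

    for _, kv := range resp.GetInfos() {
        switch kv.GetKey() {
        case Files: // "files", added by this change
            log.Info("import task files", zap.String("files", kv.GetValue()))
        case FailedReason: // "failed_reason"
            if kv.GetValue() != "" {
                log.Warn("import task failed", zap.String("reason", kv.GetValue()))
            }
        }
    }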


@ -128,7 +128,7 @@ func TestImportManager_ImportJob(t *testing.T) {
mockKv := &kv.MockMetaKV{}
mockKv.InMemKv = make(map[string]string)
mgr := newImportManager(context.TODO(), mockKv, nil)
resp := mgr.importJob(context.TODO(), nil, colID)
resp := mgr.importJob(context.TODO(), nil, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
rowReq := &milvuspb.ImportRequest{
@ -138,7 +138,7 @@ func TestImportManager_ImportJob(t *testing.T) {
Files: []string{"f1", "f2", "f3"},
}
resp = mgr.importJob(context.TODO(), rowReq, colID)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
colReq := &milvuspb.ImportRequest{
@ -163,12 +163,12 @@ func TestImportManager_ImportJob(t *testing.T) {
}
mgr = newImportManager(context.TODO(), mockKv, fn)
resp = mgr.importJob(context.TODO(), rowReq, colID)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks))
assert.Equal(t, 0, len(mgr.workingTasks))
mgr = newImportManager(context.TODO(), mockKv, fn)
resp = mgr.importJob(context.TODO(), colReq, colID)
resp = mgr.importJob(context.TODO(), colReq, colID, 0)
assert.Equal(t, 1, len(mgr.pendingTasks))
assert.Equal(t, 0, len(mgr.workingTasks))
@ -181,12 +181,12 @@ func TestImportManager_ImportJob(t *testing.T) {
}
mgr = newImportManager(context.TODO(), mockKv, fn)
resp = mgr.importJob(context.TODO(), rowReq, colID)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, 0, len(mgr.pendingTasks))
assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks))
mgr = newImportManager(context.TODO(), mockKv, fn)
resp = mgr.importJob(context.TODO(), colReq, colID)
resp = mgr.importJob(context.TODO(), colReq, colID, 0)
assert.Equal(t, 0, len(mgr.pendingTasks))
assert.Equal(t, 1, len(mgr.workingTasks))
@ -208,7 +208,7 @@ func TestImportManager_ImportJob(t *testing.T) {
}
mgr = newImportManager(context.TODO(), mockKv, fn)
resp = mgr.importJob(context.TODO(), rowReq, colID)
resp = mgr.importJob(context.TODO(), rowReq, colID, 0)
assert.Equal(t, len(rowReq.Files)-2, len(mgr.pendingTasks))
assert.Equal(t, 2, len(mgr.workingTasks))
}
@ -234,7 +234,7 @@ func TestImportManager_TaskState(t *testing.T) {
}
mgr := newImportManager(context.TODO(), mockKv, fn)
mgr.importJob(context.TODO(), rowReq, colID)
mgr.importJob(context.TODO(), rowReq, colID, 0)
state := &rootcoordpb.ImportResult{
TaskId: 10000,


@ -2246,13 +2246,20 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus
zap.String("collection name", req.GetCollectionName()))
return nil, fmt.Errorf("collection ID not found for collection name %s", req.GetCollectionName())
}
var pID int64
var err error
if pID, err = c.MetaTable.getPartitionByName(cID, req.GetPartitionName(), 0); err != nil {
return nil, err
}
log.Info("receive import request",
zap.String("collection name", req.GetCollectionName()),
zap.Int64("collection ID", cID),
zap.String("partition name", req.GetPartitionName()),
zap.Int64("partition ID", pID),
zap.Int("# of files = ", len(req.GetFiles())),
zap.Bool("row-based", req.GetRowBased()),
)
resp := c.importManager.importJob(ctx, req, cID)
resp := c.importManager.importJob(ctx, req, cID, pID)
return resp, nil
}
@ -2317,9 +2324,10 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) (
zap.Int64("task ID", ir.GetTaskId()))
}()
// TODO: Resurrect index check when ready.
// Start a loop to check segments' index states periodically.
c.wg.Add(1)
go c.checkCompleteIndexLoop(ctx, ti, colName, ir.Segments)
// c.wg.Add(1)
// go c.checkCompleteIndexLoop(ctx, ti, colName, ir.Segments)
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -2458,8 +2466,8 @@ func (c *Core) bringSegmentsOnline(ctx context.Context, segIDs []UniqueID) {
log.Info("bringing import task's segments online!", zap.Any("segment IDs", segIDs))
// TODO: Make update on segment states atomic.
for _, id := range segIDs {
// Explicitly mark segment states `flushed`.
c.CallUpdateSegmentStateService(ctx, id, commonpb.SegmentState_Flushed)
// Explicitly mark segment states `flushing`.
c.CallUpdateSegmentStateService(ctx, id, commonpb.SegmentState_Flushing)
}
}


@ -1368,10 +1368,15 @@ func TestRootCoord_Base(t *testing.T) {
Files: []string{"f1", "f2", "f3"},
}
core.MetaTable.collName2ID["new"+collName] = 123
core.MetaTable.collID2Meta[123] = etcdpb.CollectionInfo{
ID: 123,
PartitionIDs: []int64{456},
PartitionNames: []string{"testPartition"}}
rsp, err := core.Import(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
delete(core.MetaTable.collName2ID, "new"+collName)
delete(core.MetaTable.collID2Meta, 123)
reqIR := &rootcoordpb.ImportResult{
TaskId: 3,


@ -278,6 +278,9 @@ type DataCoord interface {
// the `tasks` in `ImportResponse` return an id list of tasks.
// error is always nil
Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error)
// UpdateSegmentStatistics updates a segment's stats.
UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error)
}
// DataCoordComponent defines the interface of DataCoord component.


@ -29,15 +29,15 @@ type ImportWrapper struct {
cancel context.CancelFunc // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
shardNum int32 // sharding number of the collection
segmentSize int64 // maximum size of a segment in MB
segmentSize int64 // maximum size of a segment(unit:byte)
rowIDAllocator *allocator.IDAllocator // autoid allocator
chunkManager storage.ChunkManager
callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush a segment
callFlushFunc ImportFlushFunc // call back function to flush a segment
}
func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64,
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *ImportWrapper {
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc ImportFlushFunc) *ImportWrapper {
if collectionSchema == nil {
log.Error("import error: collection schema is nil")
return nil
@ -89,7 +89,7 @@ func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[storage.FieldID]stora
if len(files) > 0 {
stats = append(stats, zap.Any("files", files))
}
log.Debug(msg, stats...)
log.Info(msg, stats...)
}
func getFileNameAndExt(filePath string) (string, string) {
@ -110,7 +110,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
_, fileType := getFileNameAndExt(filePath)
log.Debug("imprort wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
log.Info("imprort wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
err := func() error {
@ -124,9 +124,9 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
parser := NewJSONParser(p.ctx, p.collectionSchema)
var consumer *JSONRowConsumer
if !onlyValidate {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths)
return p.callFlushFunc(fields)
return p.callFlushFunc(fields, shardNum)
}
consumer = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc)
}
@ -147,7 +147,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
}
}
} else {
// parse and consume row-based files
// parse and consume column-based files
// for column-based files, the XXXColumnConsumer only output map[string]storage.FieldData
// after all columns are parsed/consumed, we need to combine map[string]storage.FieldData into one
// and use splitFieldsData() to split fields data into segments according to shard number
@ -193,7 +193,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
fileName, fileType := getFileNameAndExt(filePath)
log.Debug("imprort wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
log.Info("imprort wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
err := func() error {
@ -434,7 +434,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F
for i := 0; i < int(p.shardNum); i++ {
segmentData := segmentsData[i]
p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files)
err := p.callFlushFunc(segmentData)
err := p.callFlushFunc(segmentData, i)
if err != nil {
return err
}
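With segmentSize now denominated in bytes and the flush callback shard-aware, direct users of importutil wire things up roughly as below; ctx, schema, idAllocator, chunkManager, and files are assumed to be set up already, and the 512 MB limit and no-op callback are illustrative:

    flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
        // a real callback would assign a segment on the shard's channel and
        // persist binlogs; this sketch only logs the shard index
        log.Info("flush segment", zap.Int("shard", shardNum))
        return nil
    }
    wrapper := importutil.NewImportWrapper(ctx, schema, 2 /* shardNum */, 512*1024*1024, /* bytes */
        idAllocator, chunkManager, flushFunc)
    if err := wrapper.Import(files, true /* rowBased */, false /* onlyValidate */); err != nil {
        log.Error("bulk load failed", zap.Error(err))
    }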


@ -78,7 +78,7 @@ func Test_ImportRowBased(t *testing.T) {
defer cm.RemoveWithPrefix("")
rowCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -122,7 +122,6 @@ func Test_ImportRowBased(t *testing.T) {
files = append(files, "/dummy/dummy.json")
err = wrapper.Import(files, true, false)
assert.NotNil(t, err)
}
func Test_ImportColumnBased_json(t *testing.T) {
@ -164,7 +163,7 @@ func Test_ImportColumnBased_json(t *testing.T) {
assert.NoError(t, err)
rowCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -254,7 +253,7 @@ func Test_ImportColumnBased_numpy(t *testing.T) {
files = append(files, filePath)
rowCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -388,7 +387,7 @@ func Test_ImportRowBased_perf(t *testing.T) {
// parse the json file
parseCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -486,7 +485,7 @@ func Test_ImportColumnBased_perf(t *testing.T) {
// parse the json file
parseCount := 0
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())


@ -272,13 +272,13 @@ func (v *JSONRowValidator) ValidateCount() int64 {
}
func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row validator is not initialized")
}
// parse completed
if rows == nil {
log.Debug("JSON row validation finished")
log.Info("JSON row validation finished")
if v.downstream != nil {
return v.downstream.Handle(rows)
}
@ -287,6 +287,7 @@ func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error
for i := 0; i < len(rows); i++ {
row := rows[i]
for id, validator := range v.validators {
if validator.primaryKey && validator.autoID {
// auto-generated primary key, ignore
@ -335,7 +336,7 @@ func (v *JSONColumnValidator) ValidateCount() map[storage.FieldID]int64 {
}
func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column validator is not initialized")
}
@ -352,7 +353,7 @@ func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{})
}
// let the downstream know parse is completed
log.Debug("JSON column validation finished")
log.Info("JSON column validation finished")
if v.downstream != nil {
return v.downstream.Handle(nil)
}
@ -382,6 +383,8 @@ func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{})
return nil
}
type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardNum int) error
// row-based json format consumer class
type JSONRowConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
@ -390,10 +393,10 @@ type JSONRowConsumer struct {
rowCounter int64 // how many rows have been consumed
shardNum int32 // sharding number of the collection
segmentsData []map[storage.FieldID]storage.FieldData // in-memory segments data
segmentSize int64 // maximum size of a segment in MB
segmentSize int64 // maximum size of a segment(unit:byte)
primaryKey storage.FieldID // name of primary key
callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush segment
callFlushFunc ImportFlushFunc // call back function to flush segment
}
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
@ -465,7 +468,7 @@ func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.Fi
}
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64,
flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *JSONRowConsumer {
flushFunc ImportFlushFunc) *JSONRowConsumer {
if collectionSchema == nil {
log.Error("JSON row consumer: collection schema is nil")
return nil
@ -521,8 +524,8 @@ func (v *JSONRowConsumer) flush(force bool) error {
segmentData := v.segmentsData[i]
rowNum := segmentData[v.primaryKey].RowNum()
if rowNum > 0 {
log.Debug("JSON row consumer: force flush segment", zap.Int("rows", rowNum))
v.callFlushFunc(segmentData)
log.Info("JSON row consumer: force flush segment", zap.Int("rows", rowNum))
v.callFlushFunc(segmentData, i)
}
}
@ -532,13 +535,14 @@ func (v *JSONRowConsumer) flush(force bool) error {
// segment size can be flushed
for i := 0; i < len(v.segmentsData); i++ {
segmentData := v.segmentsData[i]
rowNum := segmentData[v.primaryKey].RowNum()
memSize := 0
for _, field := range segmentData {
memSize += field.GetMemorySize()
}
if memSize >= int(v.segmentSize)*1024*1024 {
log.Debug("JSON row consumer: flush fulled segment", zap.Int("bytes", memSize))
v.callFlushFunc(segmentData)
if memSize >= int(v.segmentSize) && rowNum > 0 {
log.Info("JSON row consumer: flush fulled segment", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum))
v.callFlushFunc(segmentData, i)
v.segmentsData[i] = initSegmentData(v.collectionSchema)
}
}
@ -547,14 +551,14 @@ func (v *JSONRowConsumer) flush(force bool) error {
}
func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row consumer is not initialized")
}
// flush if necessary
if rows == nil {
err := v.flush(true)
log.Debug("JSON row consumer finished")
log.Info("JSON row consumer finished")
return err
}
@ -596,6 +600,7 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
shard := hash % uint32(v.shardNum)
pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData)
pkArray.Data = append(pkArray.Data, id)
pkArray.NumRows[0]++
// convert value and consume
for name, validator := range v.validators {
@ -614,6 +619,8 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
return nil
}
type ColumnFlushFunc func(fields map[storage.FieldID]storage.FieldData) error
// column-based json format consumer class
type JSONColumnConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
@ -621,11 +628,10 @@ type JSONColumnConsumer struct {
fieldsData map[storage.FieldID]storage.FieldData // in-memory fields data
primaryKey storage.FieldID // name of primary key
callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush segment
callFlushFunc ColumnFlushFunc // call back function to flush segment
}
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *JSONColumnConsumer {
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, flushFunc ColumnFlushFunc) *JSONColumnConsumer {
if collectionSchema == nil {
return nil
}
@ -674,21 +680,21 @@ func (v *JSONColumnConsumer) flush() error {
if rowCount == 0 {
return errors.New("JSON column consumer: row count is 0")
}
log.Debug("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount))
log.Info("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount))
// output the fields data, letting the caller split them into segments
return v.callFlushFunc(v.fieldsData)
}
func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column consumer is not initialized")
}
// flush at the end
if columns == nil {
err := v.flush()
log.Debug("JSON column consumer finished")
log.Info("JSON column consumer finished")
return err
}


@ -313,9 +313,11 @@ func Test_JSONRowConsumer(t *testing.T) {
]
}`)
var shardNum int32 = 2
var callTime int32
var totalCount int
consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
consumeFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error {
assert.Equal(t, int(callTime), shard)
callTime++
rowCount := 0
for _, data := range fields {
@ -329,7 +331,6 @@ func Test_JSONRowConsumer(t *testing.T) {
return nil
}
var shardNum int32 = 2
consumer := NewJSONRowConsumer(schema, idAllocator, shardNum, 1, consumeFunc)
assert.NotNil(t, consumer)


@ -126,3 +126,7 @@ func (m *DataCoordClient) SetSegmentState(ctx context.Context, req *datapb.SetSe
func (m *DataCoordClient) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*datapb.ImportTaskResponse, error) {
return &datapb.ImportTaskResponse{}, m.Err
}
func (m *DataCoordClient) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}