fix: [2.4] Check nodeID before updating channel checkpoint (#31473) (#31507)

Cherry-pick from master
pr: #31473
See also #31470 #31506

This PR adds nodeID assignment verification before updating channel
checkpoints.

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/31509/head
congqixia 2024-03-22 15:39:06 +08:00 committed by GitHub
parent 957765042d
commit 278233391f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 79 additions and 11 deletions

View File

@ -3312,43 +3312,96 @@ func TestDataCoord_UnsetIsImportingState(t *testing.T) {
// TestDataCoordServer_UpdateChannelCheckpoint exercises Server.UpdateChannelCheckpoint
// against a mocked channel manager. It covers two request shapes (a single
// VChannel+Position pair, and a ChannelCheckpoints list) for both a node that
// matches the channel and a node that does not.
//
// NOTE(review): this text looks like a rendered diff with removed and added
// lines interleaved (e.g. consecutive t.Run / SourceID / ChannelName / assert
// lines that differ only slightly); confirm which lines belong to the final file.
func TestDataCoordServer_UpdateChannelCheckpoint(t *testing.T) {
mockVChannel := "fake-by-dev-rootcoord-dml-1-testchannelcp-v0"
mockPChannel := "fake-by-dev-rootcoord-dml-1"
// Matching-node case: the mocked Match call returns true for
// (datanodeID, mockVChannel), and both request shapes are expected to leave
// a checkpoint in meta.
t.Run("UpdateChannelCheckpoint", func(t *testing.T) {
t.Run("UpdateChannelCheckpoint_Success", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
datanodeID := int64(1)
channelManager := NewMockChannelManager(t)
// The mock reports that datanodeID matches mockVChannel.
channelManager.EXPECT().Match(datanodeID, mockVChannel).Return(true)
svr.channelManager = channelManager
// Request shape 1: VChannel and Position set directly on the request.
req := &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: datanodeID,
},
VChannel: mockVChannel,
Position: &msgpb.MsgPosition{
ChannelName: mockPChannel,
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
},
}
resp, err := svr.UpdateChannelCheckpoint(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.ErrorCode)
assert.NoError(t, merr.CheckRPCCall(resp, err))
// The call must have stored a checkpoint for the channel.
cp := svr.meta.GetChannelCheckpoint(mockVChannel)
assert.NotNil(t, cp)
// Remove the stored checkpoint before issuing the second request;
// presumably so the NotNil check below reflects that request alone.
svr.meta.DropChannelCheckpoint(mockVChannel)
// Request shape 2: positions carried in the ChannelCheckpoints list.
req = &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: datanodeID,
},
VChannel: mockVChannel,
ChannelCheckpoints: []*msgpb.MsgPosition{{
ChannelName: mockPChannel,
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
}},
}
resp, err = svr.UpdateChannelCheckpoint(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.ErrorCode)
assert.NoError(t, merr.CheckRPCCall(resp, err))
cp = svr.meta.GetChannelCheckpoint(mockVChannel)
assert.NotNil(t, cp)
})
// Non-matching-node case: the mocked Match call returns false, and neither
// request shape is expected to leave a checkpoint in meta.
t.Run("UpdateChannelCheckpoint_NodeNotMatch", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
datanodeID := int64(1)
channelManager := NewMockChannelManager(t)
// The mock reports that datanodeID does NOT match mockVChannel.
channelManager.EXPECT().Match(datanodeID, mockVChannel).Return(false)
svr.channelManager = channelManager
// Request shape 1: VChannel and Position set directly on the request.
req := &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: datanodeID,
},
VChannel: mockVChannel,
Position: &msgpb.MsgPosition{
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
},
}
resp, err := svr.UpdateChannelCheckpoint(context.TODO(), req)
// This shape is expected to fail with ErrChannelNotFound and store nothing.
assert.Error(t, merr.CheckRPCCall(resp, err))
assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrChannelNotFound)
cp := svr.meta.GetChannelCheckpoint(mockVChannel)
assert.Nil(t, cp)
// Request shape 2: positions carried in the ChannelCheckpoints list.
req = &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: datanodeID,
},
VChannel: mockVChannel,
ChannelCheckpoints: []*msgpb.MsgPosition{{
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
}},
}
resp, err = svr.UpdateChannelCheckpoint(context.TODO(), req)
// Unlike shape 1, this shape is expected to return no error, yet still
// leave no checkpoint stored for the channel.
assert.NoError(t, merr.CheckRPCCall(resp, err))
cp = svr.meta.GetChannelCheckpoint(mockVChannel)
assert.Nil(t, cp)
})
}

View File

@ -1464,8 +1464,14 @@ func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update
return merr.Status(err), nil
}
nodeID := req.GetBase().GetSourceID()
// For compatibility with old client
if req.GetVChannel() != "" && req.GetPosition() != nil {
channel := req.GetVChannel()
if !s.channelManager.Match(nodeID, channel) {
log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID))
return merr.Status(merr.WrapErrChannelNotFound(channel, fmt.Sprintf("from node %d", nodeID))), nil
}
err := s.meta.UpdateChannelCheckpoint(req.GetVChannel(), req.GetPosition())
if err != nil {
log.Warn("failed to UpdateChannelCheckpoint", zap.String("vChannel", req.GetVChannel()), zap.Error(err))
@ -1474,7 +1480,16 @@ func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update
return merr.Success(), nil
}
err := s.meta.UpdateChannelCheckpoints(req.GetChannelCheckpoints())
checkpoints := lo.Filter(req.GetChannelCheckpoints(), func(cp *msgpb.MsgPosition, _ int) bool {
channel := cp.GetChannelName()
matched := s.channelManager.Match(nodeID, channel)
if !matched {
log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID))
}
return matched
})
err := s.meta.UpdateChannelCheckpoints(checkpoints)
if err != nil {
log.Warn("failed to update channel checkpoint", zap.Error(err))
return merr.Status(err), nil