enhance: use the msg position obj when getting replicate channel position (#35606)

/kind improvement

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/35721/head
SimFG 2024-08-27 10:28:59 +08:00 committed by GitHub
parent 0e7877d413
commit 3e1052f889
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 14 deletions

View File

@ -6046,10 +6046,9 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
}, nil
}
var err error
ctxLog := log.Ctx(ctx)
if req.GetChannelName() == "" {
ctxLog.Warn("channel name is empty")
log.Ctx(ctx).Warn("channel name is empty")
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")),
}, nil
@ -6060,11 +6059,22 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
if req.GetChannelName() == replicateMsgChannel {
msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel)
if err != nil {
ctxLog.Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
log.Ctx(ctx).Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
position := base64.StdEncoding.EncodeToString(msgID)
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil
position := &msgpb.MsgPosition{
ChannelName: replicateMsgChannel,
MsgID: msgID,
}
positionBytes, err := proto.Marshal(position)
if err != nil {
log.Ctx(ctx).Warn("failed to marshal position", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(nil),
Position: base64.StdEncoding.EncodeToString(positionBytes),
}, nil
}
msgPack := &msgstream.MsgPack{
@ -6079,16 +6089,16 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
header := commonpb.MsgHeader{}
err = proto.Unmarshal(msgBytes, &header)
if err != nil {
ctxLog.Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
log.Ctx(ctx).Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if header.GetBase() == nil {
ctxLog.Warn("msg header base is nil", zap.Int("index", i))
log.Ctx(ctx).Warn("msg header base is nil", zap.Int("index", i))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType())
if err != nil {
ctxLog.Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
switch realMsg := tsMsg.(type) {
@ -6096,11 +6106,11 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
if err != nil {
ctxLog.Warn("failed to get segment id", zap.Error(err))
log.Ctx(ctx).Warn("failed to get segment id", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if len(assignedSegmentInfos) == 0 {
ctxLog.Warn("no segment id assigned")
log.Ctx(ctx).Warn("no segment id assigned")
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil
}
for assignSegmentID := range assignedSegmentInfos {
@ -6113,19 +6123,19 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName)
if err != nil {
ctxLog.Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
log.Ctx(ctx).Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(err),
}, nil
}
messageIDsMap, err := msgStream.Broadcast(msgPack)
if err != nil {
ctxLog.Warn("failed to produce msg", zap.Error(err))
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
var position string
if len(messageIDsMap[req.GetChannelName()]) == 0 {
ctxLog.Warn("no message id returned")
log.Ctx(ctx).Warn("no message id returned")
} else {
messageIDs := messageIDsMap[req.GetChannelName()]
position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize())

View File

@ -29,6 +29,7 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -1376,6 +1377,21 @@ func TestProxy_ReplicateMessage(t *testing.T) {
})
t.Run("get latest position", func(t *testing.T) {
base64DecodeMsgPosition := func(position string) (*msgstream.MsgPosition, error) {
decodeBytes, err := base64.StdEncoding.DecodeString(position)
if err != nil {
log.Warn("fail to decode the position", zap.Error(err))
return nil, err
}
msgPosition := &msgstream.MsgPosition{}
err = proto.Unmarshal(decodeBytes, msgPosition)
if err != nil {
log.Warn("fail to unmarshal the position", zap.Error(err))
return nil, err
}
return msgPosition, nil
}
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
@ -1397,7 +1413,11 @@ func TestProxy_ReplicateMessage(t *testing.T) {
})
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock")), resp.GetPosition())
{
p, err := base64DecodeMsgPosition(resp.GetPosition())
assert.NoError(t, err)
assert.Equal(t, []byte("mock"), p.MsgID)
}
factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once()
resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{