mirror of https://github.com/milvus-io/milvus.git
enhance: use the msg position obj when getting replicate channel position (#35606)
/kind improvement Signed-off-by: SimFG <bang.fu@zilliz.com>pull/35721/head
parent
0e7877d413
commit
3e1052f889
|
@ -6046,10 +6046,9 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
}, nil
|
||||
}
|
||||
var err error
|
||||
ctxLog := log.Ctx(ctx)
|
||||
|
||||
if req.GetChannelName() == "" {
|
||||
ctxLog.Warn("channel name is empty")
|
||||
log.Ctx(ctx).Warn("channel name is empty")
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")),
|
||||
}, nil
|
||||
|
@ -6060,11 +6059,22 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
if req.GetChannelName() == replicateMsgChannel {
|
||||
msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel)
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
|
||||
log.Ctx(ctx).Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
position := base64.StdEncoding.EncodeToString(msgID)
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil
|
||||
position := &msgpb.MsgPosition{
|
||||
ChannelName: replicateMsgChannel,
|
||||
MsgID: msgID,
|
||||
}
|
||||
positionBytes, err := proto.Marshal(position)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to marshal position", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(nil),
|
||||
Position: base64.StdEncoding.EncodeToString(positionBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
msgPack := &msgstream.MsgPack{
|
||||
|
@ -6079,16 +6089,16 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
header := commonpb.MsgHeader{}
|
||||
err = proto.Unmarshal(msgBytes, &header)
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
|
||||
log.Ctx(ctx).Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
if header.GetBase() == nil {
|
||||
ctxLog.Warn("msg header base is nil", zap.Int("index", i))
|
||||
log.Ctx(ctx).Warn("msg header base is nil", zap.Int("index", i))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
|
||||
}
|
||||
tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType())
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
|
||||
log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
|
||||
}
|
||||
switch realMsg := tsMsg.(type) {
|
||||
|
@ -6096,11 +6106,11 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
|
||||
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to get segment id", zap.Error(err))
|
||||
log.Ctx(ctx).Warn("failed to get segment id", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
if len(assignedSegmentInfos) == 0 {
|
||||
ctxLog.Warn("no segment id assigned")
|
||||
log.Ctx(ctx).Warn("no segment id assigned")
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil
|
||||
}
|
||||
for assignSegmentID := range assignedSegmentInfos {
|
||||
|
@ -6113,19 +6123,19 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
|
||||
msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName)
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
|
||||
log.Ctx(ctx).Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
messageIDsMap, err := msgStream.Broadcast(msgPack)
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to produce msg", zap.Error(err))
|
||||
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
var position string
|
||||
if len(messageIDsMap[req.GetChannelName()]) == 0 {
|
||||
ctxLog.Warn("no message id returned")
|
||||
log.Ctx(ctx).Warn("no message id returned")
|
||||
} else {
|
||||
messageIDs := messageIDsMap[req.GetChannelName()]
|
||||
position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize())
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
|
@ -1376,6 +1377,21 @@ func TestProxy_ReplicateMessage(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("get latest position", func(t *testing.T) {
|
||||
base64DecodeMsgPosition := func(position string) (*msgstream.MsgPosition, error) {
|
||||
decodeBytes, err := base64.StdEncoding.DecodeString(position)
|
||||
if err != nil {
|
||||
log.Warn("fail to decode the position", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
msgPosition := &msgstream.MsgPosition{}
|
||||
err = proto.Unmarshal(decodeBytes, msgPosition)
|
||||
if err != nil {
|
||||
log.Warn("fail to unmarshal the position", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return msgPosition, nil
|
||||
}
|
||||
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
|
||||
|
@ -1397,7 +1413,11 @@ func TestProxy_ReplicateMessage(t *testing.T) {
|
|||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock")), resp.GetPosition())
|
||||
{
|
||||
p, err := base64DecodeMsgPosition(resp.GetPosition())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("mock"), p.MsgID)
|
||||
}
|
||||
|
||||
factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once()
|
||||
resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
|
||||
|
|
Loading…
Reference in New Issue