Fix bug: check message payload before unmarshaling (#12315)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/12303/head
zhenshan.cao 2021-11-29 14:31:18 +08:00 committed by GitHub
parent 3850979308
commit bdd39c0623
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 4 deletions

View File

@ -25,7 +25,6 @@ import (
"time"
"github.com/golang/protobuf/proto"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
@ -34,6 +33,7 @@ import (
"github.com/milvus-io/milvus/internal/util/mqclient"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/opentracing/opentracing-go"
)
var _ MsgStream = (*mqMsgStream)(nil)
@ -485,13 +485,19 @@ func (ms *mqMsgStream) Consume() *MsgPack {
func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqclient.Message) (TsMsg, error) {
header := commonpb.MsgHeader{}
if msg.Payload() == nil {
return nil, fmt.Errorf("failed to unmarshal message header, payload is empty")
}
err := proto.Unmarshal(msg.Payload(), &header)
if err != nil {
return nil, fmt.Errorf("Failed to unmarshal message header, err %s", err.Error())
return nil, fmt.Errorf("failed to unmarshal message header, err %s", err.Error())
}
if header.Base == nil {
return nil, fmt.Errorf("failed to unmarshal message, header is uncomplete")
}
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), header.Base.MsgType)
if err != nil {
return nil, fmt.Errorf("Failed to unmarshal tsMsg, err %s", err.Error())
return nil, fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
}
// set msg info to tsMsg
@ -515,7 +521,10 @@ func (ms *mqMsgStream) receiveMsg(consumer mqclient.Consumer) {
return
}
consumer.Ack(msg)
if msg.Payload() == nil {
log.Warn("MqMsgStream get msg whose payload is nil")
continue
}
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
if err != nil {
log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
@ -579,6 +588,10 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, err
if err != nil {
return nil, err
}
if msg.Payload() == nil {
log.Warn("mqMsgStream reader Next get msg whose payload is nil")
return nil, nil
}
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
if err != nil {
log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
@ -868,6 +881,10 @@ func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqclient.Consumer) {
}
consumer.Ack(msg)
if msg.Payload() == nil {
log.Warn("MqTtMsgStream get msg whose payload is nil")
continue
}
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
if err != nil {
log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))