mirror of https://github.com/milvus-io/milvus.git
Fix bug: check message payload before unmarshaling (#12315)
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>pull/12303/head
parent
3850979308
commit
bdd39c0623
|
@ -25,7 +25,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -34,6 +33,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/mqclient"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
var _ MsgStream = (*mqMsgStream)(nil)
|
||||
|
@ -485,13 +485,19 @@ func (ms *mqMsgStream) Consume() *MsgPack {
|
|||
|
||||
func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqclient.Message) (TsMsg, error) {
|
||||
header := commonpb.MsgHeader{}
|
||||
if msg.Payload() == nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message header, payload is empty")
|
||||
}
|
||||
err := proto.Unmarshal(msg.Payload(), &header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to unmarshal message header, err %s", err.Error())
|
||||
return nil, fmt.Errorf("failed to unmarshal message header, err %s", err.Error())
|
||||
}
|
||||
if header.Base == nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message, header is uncomplete")
|
||||
}
|
||||
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), header.Base.MsgType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to unmarshal tsMsg, err %s", err.Error())
|
||||
return nil, fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
|
||||
}
|
||||
|
||||
// set msg info to tsMsg
|
||||
|
@ -515,7 +521,10 @@ func (ms *mqMsgStream) receiveMsg(consumer mqclient.Consumer) {
|
|||
return
|
||||
}
|
||||
consumer.Ack(msg)
|
||||
|
||||
if msg.Payload() == nil {
|
||||
log.Warn("MqMsgStream get msg whose payload is nil")
|
||||
continue
|
||||
}
|
||||
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
|
||||
if err != nil {
|
||||
log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
|
||||
|
@ -579,6 +588,10 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, err
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg.Payload() == nil {
|
||||
log.Warn("mqMsgStream reader Next get msg whose payload is nil")
|
||||
return nil, nil
|
||||
}
|
||||
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
|
||||
if err != nil {
|
||||
log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
|
||||
|
@ -868,6 +881,10 @@ func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqclient.Consumer) {
|
|||
}
|
||||
consumer.Ack(msg)
|
||||
|
||||
if msg.Payload() == nil {
|
||||
log.Warn("MqTtMsgStream get msg whose payload is nil")
|
||||
continue
|
||||
}
|
||||
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
|
||||
if err != nil {
|
||||
log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
|
||||
|
|
Loading…
Reference in New Issue