diff --git a/internal/mq/msgstream/mq_msgstream.go b/internal/mq/msgstream/mq_msgstream.go index 9e25338e14..9a1242e5c3 100644 --- a/internal/mq/msgstream/mq_msgstream.go +++ b/internal/mq/msgstream/mq_msgstream.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/trace" + "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/opentracing/opentracing-go" ) @@ -756,10 +757,21 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { ms.chanMsgBufMutex.Unlock() ms.consumerLock.Unlock() + idset := make(typeutil.UniqueSet) + uniqueMsgs := make([]TsMsg, 0, len(timeTickBuf)) + for _, msg := range timeTickBuf { + if idset.Contain(msg.ID()) { + log.Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.ID())) + continue + } + idset.Insert(msg.ID()) + uniqueMsgs = append(uniqueMsgs, msg) + } + msgPack := MsgPack{ BeginTs: ms.lastTimeStamp, EndTs: currTs, - Msgs: timeTickBuf, + Msgs: uniqueMsgs, StartPositions: startMsgPosition, EndPositions: endMsgPositions, } diff --git a/internal/mq/msgstream/mq_msgstream_test.go b/internal/mq/msgstream/mq_msgstream_test.go index c402108172..0dee723f1f 100644 --- a/internal/mq/msgstream/mq_msgstream_test.go +++ b/internal/mq/msgstream/mq_msgstream_test.go @@ -30,6 +30,7 @@ import ( "unsafe" "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" + "go.uber.org/atomic" "github.com/apache/pulsar-client-go/pulsar" pulsarwrapper "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper/pulsar" @@ -1527,6 +1528,62 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) { Close(rocksdbName, inputStream, outputStream, etcdKV) } +func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { + rocksdbName := "/tmp/rocksmq_tt_msg_seek" + etcdKV := initRmq(rocksdbName) + + c1 := funcutil.RandomString(8) + producerChannels := []string{c1} + consumerChannels := []string{c1} + consumerSubName := funcutil.RandomString(8) + + msgPack0 := MsgPack{} + msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0)) + + msgPack1 := MsgPack{} + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) + + msgPack2 := MsgPack{} + msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(15)) + + ctx := context.Background() + inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) + + err := inputStream.Broadcast(&msgPack0) + assert.Nil(t, err) + err = inputStream.Produce(&msgPack1) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack2) + assert.Nil(t, err) + + receivedMsg := consumer(ctx, outputStream) + assert.Equal(t, len(receivedMsg.Msgs), 1) + assert.Equal(t, receivedMsg.BeginTs, uint64(0)) + assert.Equal(t, receivedMsg.EndTs, uint64(15)) + + outputStream.Close() + + factory := ProtoUDFactory{} + + rmqClient, _ := rmq.NewClientWithDefaultOptions() + outputStream, _ = NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) + consumerSubName = funcutil.RandomString(8) + outputStream.AsConsumer(consumerChannels, consumerSubName) + + outputStream.Seek(receivedMsg.StartPositions) + outputStream.Start() + seekMsg := consumer(ctx, outputStream) + assert.Equal(t, len(seekMsg.Msgs), 1) + for _, msg := range seekMsg.Msgs { + assert.EqualValues(t, msg.BeginTs(), 1) + } + + Close(rocksdbName, inputStream, outputStream, etcdKV) + +} + func TestStream_RmqTtMsgStream_Seek(t *testing.T) { rocksdbName := "/tmp/rocksmq_tt_msg_seek" etcdKV := initRmq(rocksdbName) @@ -1965,7 +2022,7 @@ func getRandInsertMsgPack(num int, start int, end int) *MsgPack { _, ok := set[reqID] if !ok { set[reqID] = true - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, int64(reqID))) + msgPack.Msgs = append(msgPack.Msgs, getInsertMsgUniqueID(int64(reqID))) //getTsMsg(commonpb.MsgType_Insert, int64(reqID))) } } return &msgPack @@ -1974,11 +2031,45 @@ func getRandInsertMsgPack(num int, start int, end int) *MsgPack { func getInsertMsgPack(ts []int) *MsgPack { msgPack := MsgPack{} for i := 0; i < len(ts); i++ { - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, int64(ts[i]))) + msgPack.Msgs = append(msgPack.Msgs, getInsertMsgUniqueID(int64(ts[i]))) //getTsMsg(commonpb.MsgType_Insert, int64(ts[i]))) } return &msgPack } +var idCounter atomic.Int64 + +func getInsertMsgUniqueID(ts UniqueID) TsMsg { + hashValue := uint32(ts) + time := uint64(ts) + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{hashValue}, + } + + insertRequest := internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: idCounter.Inc(), + Timestamp: time, + SourceID: ts, + }, + CollectionName: "Collection", + PartitionName: "Partition", + SegmentID: 1, + ShardName: "0", + Timestamps: []Timestamp{time}, + RowIDs: []int64{1}, + RowData: []*commonpb.Blob{{}}, + } + insertMsg := &InsertMsg{ + BaseMsg: baseMsg, + InsertRequest: insertRequest, + } + return insertMsg + +} + func getTimeTickMsgPack(reqID UniqueID) *MsgPack { msgPack := MsgPack{} msgPack.Msgs = append(msgPack.Msgs, getTimeTickMsg(reqID))