// Source: milvus/internal/streamingnode/server/flusher/flusherimpl/wal_flusher_test.go
// (178 lines, 6.1 KiB, Go)

package flusherimpl
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/mock_storage"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/mock_recovery"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// TestMain sets up package-wide state shared by every test in this package:
// it lowers defaultCollectionNotFoundTolerance to 2 and initializes paramtable
// before running the tests. The process exits with the status from m.Run().
func TestMain(m *testing.M) {
	defaultCollectionNotFoundTolerance = 2
	paramtable.Init()
	os.Exit(m.Run())
}
// TestWALFlusher recovers a WAL flusher on top of mocked dependencies, lets it
// run for a fixed interval and then closes it. Verification relies on the
// mocks built with NewMock*(t), which assert their non-optional expectations
// during test cleanup.
//
// The recovery snapshot lists three vchannels of collection 100. The mixcoord
// mock from newMockMixcoord answers recovery-info requests for vchannel-2 and
// vchannel-3 with error statuses and for vchannel-1 with a success response.
func TestWALFlusher(t *testing.T) {
	streamingutil.SetStreamingServiceEnabled()
	defer streamingutil.UnsetStreamingServiceEnabled()

	// newMockMixcoord already expects DropVirtualChannel with unlimited
	// repeatability. testify dispatches to the first matching expectation, so
	// an identical second registration here would never be used.
	mixcoord := newMockMixcoord(t, false)
	mixcoord.EXPECT().AllocSegment(mock.Anything, mock.Anything).Return(&datapb.AllocSegmentResponse{
		Status: merr.Status(nil),
	}, nil)
	fMixcoord := syncutil.NewFuture[internaltypes.MixCoordClient]()
	fMixcoord.Set(mixcoord)

	rs := mock_recovery.NewMockRecoveryStorage(t)
	rs.EXPECT().GetSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "ID", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "Vector", DataType: schemapb.DataType_FloatVector},
		},
	}, nil)
	rs.EXPECT().ObserveMessage(mock.Anything, mock.Anything).Return(nil)
	rs.EXPECT().Close().Return()

	resource.InitForTest(
		t,
		resource.OptMixCoordClient(fMixcoord),
		resource.OptChunkManager(mock_storage.NewMockChunkManager(t)),
	)

	l := newMockWAL(t, false)

	// All three vchannels belong to the same collection (ID 100).
	vchannels := make(map[string]*streamingpb.VChannelMeta, 3)
	for _, name := range []string{"vchannel-1", "vchannel-2", "vchannel-3"} {
		vchannels[name] = &streamingpb.VChannelMeta{
			CollectionInfo: &streamingpb.CollectionInfoOfVChannel{CollectionId: 100},
		}
	}

	param := &RecoverWALFlusherParam{
		ChannelInfo: l.Channel(),
		WAL:         syncutil.NewFuture[wal.WAL](),
		RecoverySnapshot: &recovery.RecoverySnapshot{
			VChannels:  vchannels,
			Checkpoint: &recovery.WALCheckpoint{TimeTick: 0},
		},
		RecoveryStorage: rs,
	}
	param.WAL.Set(l)

	flusher := RecoverWALFlusher(param)
	// NOTE(review): this fixed sleep presumably gives the flusher's asynchronous
	// work time to run before Close. It makes the test slow and
	// timing-dependent; waiting on an observable condition would be more robust.
	time.Sleep(5 * time.Second)
	flusher.Close()
}
// newMockMixcoord returns a MockMixCoordClient for the flusher tests.
//
// DropVirtualChannel always returns a Success status. GetChannelRecoveryInfo
// answers based on the requested vchannel name:
//   - "vchannel-3": status built from merr.ErrCollectionNotFound
//   - "vchannel-2": status built from merr.ErrChannelNotAvailable
//   - any other name: a response whose seek position carries message ID 1
//     (8 bytes, common.Endian) and a two-field schema (Int64 primary key "ID",
//     FloatVector "Vector").
//
// When maybe is true, both expectations are marked optional. A test that
// never exercises the client then does not fail on unmet expectations, which
// mirrors how newMockWAL applies its maybe flag.
func newMockMixcoord(t *testing.T, maybe bool) *mocks.MockMixCoordClient {
	mixcoord := mocks.NewMockMixCoordClient(t)

	dropCall := mixcoord.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(&datapb.DropVirtualChannelResponse{
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}, nil)

	recoveryCall := mixcoord.EXPECT().GetChannelRecoveryInfo(mock.Anything, mock.Anything).RunAndReturn(
		func(ctx context.Context, request *datapb.GetChannelRecoveryInfoRequest, option ...grpc.CallOption,
		) (*datapb.GetChannelRecoveryInfoResponse, error) {
			switch request.GetVchannel() {
			case "vchannel-3":
				return &datapb.GetChannelRecoveryInfoResponse{
					Status: merr.Status(merr.ErrCollectionNotFound),
				}, nil
			case "vchannel-2":
				return &datapb.GetChannelRecoveryInfoResponse{
					Status: merr.Status(merr.ErrChannelNotAvailable),
				}, nil
			}

			// Every other vchannel gets a success response seeking to message ID 1.
			const messageID = 1
			seekID := make([]byte, 8)
			common.Endian.PutUint64(seekID, uint64(messageID))
			return &datapb.GetChannelRecoveryInfoResponse{
				Info: &datapb.VchannelInfo{
					ChannelName:  request.GetVchannel(),
					SeekPosition: &msgpb.MsgPosition{MsgID: seekID},
				},
				Schema: &schemapb.CollectionSchema{
					Fields: []*schemapb.FieldSchema{
						{FieldID: 100, Name: "ID", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
						{FieldID: 101, Name: "Vector", DataType: schemapb.DataType_FloatVector},
					},
				},
			}, nil
		})

	if maybe {
		dropCall.Maybe()
		recoveryCall.Maybe()
	}
	return mixcoord
}
// newMockWAL returns a MockWAL whose WALName is "rocksmq" and whose Channel is
// the pchannel named "pchannel".
//
// Each Read call builds a fresh MockScanner. Its Chan() yields, in order, a
// CreateCollection, CreateSegment, TimeTickSync and DropCollection test
// message, turned into immutable messages with rmq IDs 105 through 108. The
// channel is buffered to hold all of them and is never closed. Closing the
// scanner closes the MesasgeHandler passed in the read option.
//
// Channel is always optional. WALName and Read become optional only when
// maybe is true.
func newMockWAL(t *testing.T, maybe bool) *mock_wal.MockWAL {
	w := mock_wal.NewMockWAL(t)

	nameCall := w.EXPECT().WALName().Return("rocksmq")
	w.EXPECT().Channel().Return(types.PChannelInfo{Name: "pchannel"}).Maybe()
	readCall := w.EXPECT().Read(mock.Anything, mock.Anything).RunAndReturn(
		func(ctx context.Context, option wal.ReadOption) (wal.Scanner, error) {
			msgHandler := option.MesasgeHandler

			// Composite-literal calls evaluate left to right, so the messages
			// are created in the listed order.
			fixtures := []message.ImmutableMessage{
				message.CreateTestCreateCollectionMessage(t, 2, 100, rmq.NewRmqID(100)).IntoImmutableMessage(rmq.NewRmqID(105)),
				message.CreateTestCreateSegmentMessage(t, 2, 101, rmq.NewRmqID(101)).IntoImmutableMessage(rmq.NewRmqID(106)),
				message.CreateTestTimeTickSyncMessage(t, 2, 102, rmq.NewRmqID(101)).IntoImmutableMessage(rmq.NewRmqID(107)),
				message.CreateTestDropCollectionMessage(t, 2, 103, rmq.NewRmqID(104)).IntoImmutableMessage(rmq.NewRmqID(108)),
			}
			feed := make(chan message.ImmutableMessage, len(fixtures))
			for _, m := range fixtures {
				feed <- m
			}

			scanner := mock_wal.NewMockScanner(t)
			scanner.EXPECT().Chan().RunAndReturn(func() <-chan message.ImmutableMessage {
				return feed
			})
			scanner.EXPECT().Close().RunAndReturn(func() error {
				msgHandler.Close()
				return nil
			})
			return scanner, nil
		})

	if maybe {
		nameCall.Maybe()
		readCall.Maybe()
	}
	return w
}