milvus/pkg/streaming/walimpls/test_framework.go

336 lines
8.5 KiB
Go

//go:build test
// +build test
package walimpls
import (
"context"
"fmt"
"math/rand"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/remeh/sizedwaitgroup"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/streaming/util/options"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
)
// letters is the alphabet randString draws from: ASCII lower- and upper-case letters.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// randString returns a string of l runes, each picked uniformly at random from
// letters via the global math/rand source. A non-positive l yields "".
func randString(l int) string {
	var sb strings.Builder
	for remaining := l; remaining > 0; remaining-- {
		sb.WriteRune(letters[rand.Intn(len(letters))])
	}
	return sb.String()
}
// walImplsTestFramework runs the single-pchannel WAL test concurrently on
// several pchannels, all served by one opener built from b.
type walImplsTestFramework struct {
b OpenerBuilderImpls // builds the OpenerImpls under test
t *testing.T // receives all assertion failures
messageCount int // number of messages appended per term on each pchannel
}
// NewWALImplsTestFramework returns a framework that, when Run, tests the
// opener produced by b on multiple pchannels, appending messageCount messages
// per term on each of them and reporting failures through t.
func NewWALImplsTestFramework(t *testing.T, messageCount int, b OpenerBuilderImpls) *walImplsTestFramework {
	fw := new(walImplsTestFramework)
	fw.t = t
	fw.messageCount = messageCount
	fw.b = b
	return fw
}
// Run builds an opener from the configured builder and runs the single-WAL
// test concurrently on three randomly named pchannels, waiting for all of
// them to finish. Nothing is returned: failures are reported through the
// testing.T given to NewWALImplsTestFramework.
func (f walImplsTestFramework) Run() {
	o, err := f.b.Build()
	// Without an opener nothing else can run; bail out instead of calling
	// Close on a nil opener in the deferred statement below.
	if !assert.NoError(f.t, err) || !assert.NotNil(f.t, o) {
		return
	}
	defer o.Close()

	// Test on multiple pchannels in parallel.
	const pchannelCnt = 3
	wg := sync.WaitGroup{}
	wg.Add(pchannelCnt)
	for i := 0; i < pchannelCnt; i++ {
		// The random suffix makes name collisions between runs unlikely.
		name := fmt.Sprintf("test_%d_%s", i, randString(4))
		go func(name string) {
			defer wg.Done()
			newTestOneWALImpls(f.t, o, name, f.messageCount).Run()
		}(name)
	}
	wg.Wait()
}
// newTestOneWALImpls returns a framework that tests the single pchannel named
// pchannel through opener. It starts at term 1 with no previously written
// messages and appends messageCount messages per term.
func newTestOneWALImpls(t *testing.T, opener OpenerImpls, pchannel string, messageCount int) *testOneWALImplsFramework {
	fw := &testOneWALImplsFramework{term: 1}
	fw.t, fw.opener, fw.pchannel = t, opener, pchannel
	fw.messageCount = messageCount
	fw.written = make([]message.ImmutableMessage, 0)
	return fw
}
// testOneWALImplsFramework holds the state of the WAL test for a single
// pchannel across its successive terms.
type testOneWALImplsFramework struct {
t *testing.T // receives all assertion failures
opener OpenerImpls // opens a fresh WAL for each term
written []message.ImmutableMessage // messages written so far across terms, kept in message-ID order
pchannel string // name of the pchannel under test
messageCount int // number of messages appended per term
term int // current term; starts at 1 and is advanced by Run
}
// Run exercises one pchannel across terms 1 through 3. For each term it
// opens a fresh WAL through the opener and checks that the WAL reports the
// requested channel name, server ID and term. It then runs the concurrent
// read/write checks and closes the WAL. Because f.written accumulates across
// terms, each term also verifies that earlier terms' messages remain readable.
func (f *testOneWALImplsFramework) Run() {
	ctx := context.Background()
	for ; f.term <= 3; f.term++ {
		pChannel := types.PChannelInfo{
			Name:     f.pchannel,
			Term:     int64(f.term),
			ServerID: 1,
		}
		w, err := f.opener.Open(ctx, &OpenOption{
			Channel: pChannel,
		})
		// w is unusable if Open failed, and later terms build on the
		// messages this term would have written, so stop here instead of
		// dereferencing a nil WAL.
		if !assert.NoError(f.t, err) || !assert.NotNil(f.t, w) {
			return
		}
		assert.Equal(f.t, pChannel.Name, w.Channel().Name)
		assert.Equal(f.t, pChannel.ServerID, w.Channel().ServerID)
		assert.Equal(f.t, pChannel.Term, w.Channel().Term)
		f.testReadAndWrite(ctx, w)
		w.Close()
	}
}
// testReadAndWrite appends f.messageCount new messages to w while two
// scanners concurrently read the WAL from the beginning. It checks that both
// scanners observed exactly the accumulated written messages (previous terms
// plus this one) in message-ID order. It then verifies the StartFrom,
// StartAfter and Latest deliver policies against that accumulated history.
func (f *testOneWALImplsFramework) testReadAndWrite(ctx context.Context, w WALImpls) {
	wg := sync.WaitGroup{}
	wg.Add(3)
	var newWritten []message.ImmutableMessage
	var read1, read2 []message.ImmutableMessage
	go func() {
		defer wg.Done()
		var err error
		newWritten, err = f.testAppend(ctx, w)
		assert.NoError(f.t, err)
	}()
	go func() {
		defer wg.Done()
		var err error
		read1, err = f.testRead(ctx, w, "scanner1")
		assert.NoError(f.t, err)
	}()
	go func() {
		defer wg.Done()
		var err error
		read2, err = f.testRead(ctx, w, "scanner2")
		assert.NoError(f.t, err)
	}()
	wg.Wait()

	f.assertSortedMessageList(read1)
	f.assertSortedMessageList(read2)
	sort.Sort(sortByMessageID(newWritten))
	f.written = append(f.written, newWritten...)
	f.assertSortedMessageList(f.written)
	f.assertEqualMessageList(f.written, read1)
	f.assertEqualMessageList(f.written, read2)

	// The deliver-policy checks below index into f.written.
	if len(f.written) == 0 {
		return
	}
	readFromIdx := len(f.written) / 2
	readFromMsgID := f.written[readFromIdx].MessageID()

	// StartFrom is inclusive: delivery begins at the anchor message itself.
	f.assertScanMatchesWritten(ctx, w, ReadOption{
		Name:          "scanner_deliver_start_from",
		DeliverPolicy: options.DeliverPolicyStartFrom(readFromMsgID),
	}, readFromIdx)

	// StartAfter is exclusive: delivery begins right after the anchor.
	f.assertScanMatchesWritten(ctx, w, ReadOption{
		Name:          "scanner_deliver_start_after",
		DeliverPolicy: options.DeliverPolicyStartAfter(readFromMsgID),
	}, readFromIdx+1)

	// Latest: nothing is appended after this point, so the scanner must not
	// deliver anything within the timeout.
	s, err := w.Read(ctx, ReadOption{
		Name:          "scanner_deliver_latest",
		DeliverPolicy: options.DeliverPolicyLatest(),
	})
	if !assert.NoError(f.t, err) {
		return
	}
	defer s.Close()
	select {
	case <-s.Chan():
		f.t.Errorf("should be blocked")
	case <-time.After(1 * time.Second):
	}
}

// assertScanMatchesWritten opens a scanner on w with opt and asserts that it
// delivers f.written[startIdx:] in order, comparing message IDs. It stops at
// the first scanner failure instead of dereferencing a missing message.
func (f *testOneWALImplsFramework) assertScanMatchesWritten(ctx context.Context, w WALImpls, opt ReadOption, startIdx int) {
	s, err := w.Read(ctx, opt)
	if !assert.NoError(f.t, err) {
		return
	}
	defer s.Close()
	for i := startIdx; i < len(f.written); i++ {
		msg, ok := <-s.Chan()
		// A closed channel yields a nil message; stop rather than call methods on it.
		if !assert.True(f.t, ok, "scanner %s closed early", opt.Name) || !assert.NotNil(f.t, msg) {
			return
		}
		assert.True(f.t, msg.MessageID().EQ(f.written[i].MessageID()))
	}
}
// assertSortedMessageList asserts that every message in msgs has a strictly
// smaller message ID than the message that follows it.
func (f *testOneWALImplsFramework) assertSortedMessageList(msgs []message.ImmutableMessage) {
	for cur := range msgs {
		if cur == 0 {
			continue
		}
		prevID := msgs[cur-1].MessageID()
		assert.True(f.t, prevID.LT(msgs[cur].MessageID()))
	}
}
// assertEqualMessageList asserts that msgs1 and msgs2 have the same length.
// Position by position, it asserts that they carry the same message ID and
// the same "id" and "const" properties. Payloads are not compared.
func (f *testOneWALImplsFramework) assertEqualMessageList(msgs1 []message.ImmutableMessage, msgs2 []message.ImmutableMessage) {
	assert.Equal(f.t, len(msgs2), len(msgs1))
	// Compare only the common prefix: indexing past the end of the shorter
	// list would panic instead of reporting the mismatch.
	n := len(msgs1)
	if len(msgs2) < n {
		n = len(msgs2)
	}
	for i := 0; i < n; i++ {
		assert.True(f.t, msgs1[i].MessageID().EQ(msgs2[i].MessageID()))
		// assert.True(f.t, bytes.Equal(msgs1[i].Payload(), msgs2[i].Payload()))
		for _, key := range []string{"id", "const"} {
			v1, ok1 := msgs1[i].Properties().Get(key)
			v2, ok2 := msgs2[i].Properties().Get(key)
			assert.True(f.t, ok1, "message %d of first list lacks property %q", i, key)
			assert.True(f.t, ok2, "message %d of second list lacks property %q", i, key)
			assert.Equal(f.t, v1, v2)
		}
	}
}
// testAppend appends f.messageCount messages to w. On success it returns them
// as immutable messages in index order: element i carries property "id" = i.
// The slice is not sorted by message ID. Messages 0..messageCount-2 are
// appended concurrently (at most 5 in flight) with MessageTypeUnknown. The
// last message is a MessageTypeTimeTick carrying a "term" property. Readers
// stop at it, so it is appended only after every other append has finished.
// If any append fails, the time tick is still sent so readers can terminate.
// The first failure is then returned together with a nil slice, rather than
// half-built messages without IDs. f.messageCount must be at least 1.
func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) ([]message.ImmutableMessage, error) {
	ids := make([]message.ImmutableMessage, f.messageCount)
	var (
		errMu    sync.Mutex
		firstErr error
	)
	swg := sizedwaitgroup.New(5)
	for i := 0; i < f.messageCount-1; i++ {
		swg.Add()
		go func(i int) {
			defer swg.Done()
			msg := message.NewMutableMessageBuilder().
				WithMessageType(message.MessageTypeUnknown).
				WithPayload(marshalTestMsgHeader(int64(i))).
				WithProperties(map[string]string{
					"id":    strconv.Itoa(i),
					"const": "t",
				}).
				BuildMutable()
			id, err := w.Append(ctx, msg)
			if err != nil {
				errMu.Lock()
				if firstErr == nil {
					firstErr = fmt.Errorf("append message %d: %w", i, err)
				}
				errMu.Unlock()
				return
			}
			assert.NotNil(f.t, id)
			// Each goroutine writes a distinct index, so no lock is needed.
			ids[i] = msg.IntoImmutableMessage(id)
		}(i)
	}
	swg.Wait()

	// Send the final hint message that marks the end of this term for readers.
	last := f.messageCount - 1
	msg := message.NewMutableMessageBuilder().
		WithPayload(marshalTestMsgHeader(int64(last))).
		WithProperties(map[string]string{
			"id":    strconv.Itoa(last),
			"const": "t",
			"term":  strconv.FormatInt(int64(f.term), 10),
		}).
		WithMessageType(message.MessageTypeTimeTick).
		BuildMutable()
	id, err := w.Append(ctx, msg)
	if err != nil {
		return nil, fmt.Errorf("append time tick message %d: %w", last, err)
	}
	// swg.Wait above orders every goroutine's write of firstErr before this read.
	if firstErr != nil {
		return nil, firstErr
	}
	ids[last] = msg.IntoImmutableMessage(id)
	return ids, nil
}

// marshalTestMsgHeader returns a serialized commonpb.MsgHeader (an Insert
// header with the given MsgID) for use as a test message payload. The rocksmq
// implementation's property handling does not work unless the payload is a
// commonpb.MsgHeader, so every test message carries one.
func marshalTestMsgHeader(msgID int64) []byte {
	header := commonpb.MsgHeader{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_Insert,
			MsgID:   msgID,
		},
	}
	payload, err := proto.Marshal(&header)
	if err != nil {
		// Marshaling this fixed, well-formed header is not expected to fail.
		panic(err)
	}
	return payload
}
// testRead scans w from the beginning (DeliverPolicyAll) with a scanner named
// name. It returns every message received, up to and including the first
// time-tick message whose "term" property equals f.term. Time ticks from
// earlier terms are returned as ordinary history. An error is returned if any
// of the following happens:
//   - the scanner cannot be created;
//   - its channel closes or yields a nil message before the current term's tick;
//   - ctx is done;
//   - a time-tick message lacks a parsable "term" property.
func (f *testOneWALImplsFramework) testRead(ctx context.Context, w WALImpls, name string) ([]message.ImmutableMessage, error) {
	s, err := w.Read(ctx, ReadOption{
		Name:                name,
		DeliverPolicy:       options.DeliverPolicyAll(),
		ReadAheadBufferSize: 128,
	})
	if err != nil {
		return nil, fmt.Errorf("open scanner %q: %w", name, err)
	}
	defer s.Close()
	assert.Equal(f.t, name, s.Name())

	// Expect every message from previous terms plus this term's appends.
	expectedCnt := f.messageCount + len(f.written)
	msgs := make([]message.ImmutableMessage, 0, expectedCnt)
	for {
		var msg message.ImmutableMessage
		var ok bool
		select {
		case <-ctx.Done():
			return nil, fmt.Errorf("scanner %q: %w", name, ctx.Err())
		case msg, ok = <-s.Chan():
		}
		if !ok {
			return nil, fmt.Errorf("scanner %q closed before the time tick of term %d", name, f.term)
		}
		if msg == nil {
			return nil, fmt.Errorf("scanner %q delivered a nil message", name)
		}
		msgs = append(msgs, msg)
		if msg.MessageType() != message.MessageTypeTimeTick {
			continue
		}
		termString, found := msg.Properties().Get("term")
		if !found {
			return nil, fmt.Errorf("scanner %q: time tick message lost term properties", name)
		}
		term, err := strconv.ParseInt(termString, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("scanner %q: parse term property %q: %w", name, termString, err)
		}
		if int(term) == f.term {
			return msgs, nil
		}
	}
}
// sortByMessageID implements sort.Interface, ordering immutable messages by
// ascending message ID as defined by MessageID.LT.
type sortByMessageID []message.ImmutableMessage

// Len reports how many messages are in the list.
func (s sortByMessageID) Len() int { return len(s) }

// Swap exchanges the messages at positions i and j.
func (s sortByMessageID) Swap(i, j int) { s[j], s[i] = s[i], s[j] }

// Less reports whether the message at i has a smaller ID than the one at j.
func (s sortByMessageID) Less(i, j int) bool {
	lhs, rhs := s[i].MessageID(), s[j].MessageID()
	return lhs.LT(rhs)
}