Fix test structure for rmq (#24876)

Signed-off-by: fan.yang <julyyangfan@gmail.com>
pull/24831/head
fan yang 2023-06-13 19:28:37 -07:00 committed by GitHub
parent 561021ec5a
commit 893c3c0409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 139 additions and 319 deletions

View File

@ -23,7 +23,6 @@ import (
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/pkg/log"
)
@ -33,13 +32,6 @@ var Rmq *rocksmq
// once is used to init global rocksmq
var once sync.Once
// InitRmq is deprecate implementation of global rocksmq. will be removed later
func InitRmq(rocksdbName string, idAllocator allocator.Interface) error {
var err error
Rmq, err = NewRocksMQ(rocksdbName, idAllocator)
return err
}
// InitRocksMQ init global rocksmq single instance
func InitRocksMQ(path string) error {
var finalErr error

View File

@ -12,44 +12,13 @@
package server
import (
"log"
"os"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/allocator"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/pkg/util/etcd"
)
func Test_InitRmq(t *testing.T) {
name := "/tmp/rmq_init"
defer os.RemoveAll("/tmp/rmq_init")
endpoints := os.Getenv("ETCD_ENDPOINTS")
if endpoints == "" {
endpoints = "localhost:2379"
}
etcdEndpoints := strings.Split(endpoints, ",")
etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
defer etcdCli.Close()
if err != nil {
log.Fatalf("New clientv3 error = %v", err)
}
etcdKV := etcdkv.NewEtcdKV(etcdCli, "/etcd/test/root")
idAllocator := allocator.NewGlobalIDAllocator("dummy", etcdKV)
_ = idAllocator.Initialize()
defer os.RemoveAll(name + kvSuffix)
defer os.RemoveAll(name)
err = InitRmq(name, idAllocator)
defer Rmq.stopRetention()
assert.NoError(t, err)
defer CloseRocksMQ()
}
func Test_InitRocksMQ(t *testing.T) {
rmqPath := "/tmp/milvus/rdb_data_global"
defer os.RemoveAll("/tmp/milvus")

View File

@ -20,8 +20,6 @@ import (
"context"
"fmt"
"log"
"os"
"strings"
"sync"
"testing"
@ -30,267 +28,176 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
type fixture struct {
t *testing.T
kv kv.MetaKv
}
type parameters struct {
client mqwrapper.Client
}
func (f *fixture) setup() []parameters {
rocksdbName := "/tmp/rocksmq_unittest_" + f.t.Name()
endpoints := os.Getenv("ETCD_ENDPOINTS")
if endpoints == "" {
endpoints = "localhost:2379"
}
etcdEndpoints := strings.Split(endpoints, ",")
etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
defer etcdCli.Close()
if err != nil {
log.Fatalf("New clientv3 error = %v", err)
}
f.kv = etcdkv.NewEtcdKV(etcdCli, "/etcd/test/root")
idAllocator := allocator.NewGlobalIDAllocator("dummy", f.kv)
_ = idAllocator.Initialize()
err = server.InitRmq(rocksdbName, idAllocator)
if err != nil {
log.Fatalf("InitRmq error = %v", err)
}
rmqClient, _ := NewClientWithDefaultOptions()
parameters := []parameters{
{rmqClient},
}
return parameters
}
func (f *fixture) teardown() {
rocksdbName := "/tmp/rocksmq_unittest_" + f.t.Name()
server.CloseRocksMQ()
f.kv.Close()
_ = os.RemoveAll(rocksdbName)
_ = os.RemoveAll(rocksdbName + "_meta_kv")
}
func Test_NewMqMsgStream(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
_, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
}(parameters[i].client)
}
_, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
}
// TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage
func TestMqMsgStream_AsProducer(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// empty channel name
m.AsProducer([]string{""})
}(parameters[i].client)
}
// empty channel name
m.AsProducer([]string{""})
}
// TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage
func TestMqMsgStream_AsConsumer(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// repeat calling AsConsumer
m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
}(parameters[i].client)
}
// repeat calling AsConsumer
m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
}
func TestMqMsgStream_ComputeProduceChannelIndexes(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// empty parameters
reBucketValues := m.ComputeProduceChannelIndexes([]msgstream.TsMsg{})
assert.Nil(t, reBucketValues)
// empty parameters
reBucketValues := m.ComputeProduceChannelIndexes([]msgstream.TsMsg{})
assert.Nil(t, reBucketValues)
// not called AsProducer yet
insertMsg := &msgstream.InsertMsg{
BaseMsg: generateBaseMsg(),
InsertRequest: msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
// not called AsProducer yet
insertMsg := &msgstream.InsertMsg{
BaseMsg: generateBaseMsg(),
InsertRequest: msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
DbName: "test_db",
CollectionName: "test_collection",
PartitionName: "test_partition",
DbID: 4,
CollectionID: 5,
PartitionID: 6,
SegmentID: 7,
ShardName: "test-channel",
Timestamps: []uint64{2, 1, 3},
RowData: []*commonpb.Blob{},
},
}
reBucketValues = m.ComputeProduceChannelIndexes([]msgstream.TsMsg{insertMsg})
assert.Nil(t, reBucketValues)
}(parameters[i].client)
DbName: "test_db",
CollectionName: "test_collection",
PartitionName: "test_partition",
DbID: 4,
CollectionID: 5,
PartitionID: 6,
SegmentID: 7,
ShardName: "test-channel",
Timestamps: []uint64{2, 1, 3},
RowData: []*commonpb.Blob{},
},
}
reBucketValues = m.ComputeProduceChannelIndexes([]msgstream.TsMsg{insertMsg})
assert.Nil(t, reBucketValues)
}
func TestMqMsgStream_GetProduceChannels(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// empty if not called AsProducer yet
chs := m.GetProduceChannels()
assert.Equal(t, 0, len(chs))
// empty if not called AsProducer yet
chs := m.GetProduceChannels()
assert.Equal(t, 0, len(chs))
// not empty after AsProducer
m.AsProducer([]string{"a"})
chs = m.GetProduceChannels()
assert.Equal(t, 1, len(chs))
}(parameters[i].client)
}
// not empty after AsProducer
m.AsProducer([]string{"a"})
chs = m.GetProduceChannels()
assert.Equal(t, 1, len(chs))
}
func TestMqMsgStream_Produce(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// Produce before called AsProducer
insertMsg := &msgstream.InsertMsg{
BaseMsg: generateBaseMsg(),
InsertRequest: msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
// Produce before called AsProducer
insertMsg := &msgstream.InsertMsg{
BaseMsg: generateBaseMsg(),
InsertRequest: msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
DbName: "test_db",
CollectionName: "test_collection",
PartitionName: "test_partition",
DbID: 4,
CollectionID: 5,
PartitionID: 6,
SegmentID: 7,
ShardName: "test-channel",
Timestamps: []uint64{2, 1, 3},
RowData: []*commonpb.Blob{},
},
}
msgPack := &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{insertMsg},
}
err = m.Produce(msgPack)
assert.Error(t, err)
}(parameters[i].client)
DbName: "test_db",
CollectionName: "test_collection",
PartitionName: "test_partition",
DbID: 4,
CollectionID: 5,
PartitionID: 6,
SegmentID: 7,
ShardName: "test-channel",
Timestamps: []uint64{2, 1, 3},
RowData: []*commonpb.Blob{},
},
}
msgPack := &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{insertMsg},
}
err = m.Produce(msgPack)
assert.Error(t, err)
}
func TestMqMsgStream_Broadcast(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// Broadcast nil pointer
_, err = m.Broadcast(nil)
assert.Error(t, err)
}(parameters[i].client)
}
// Broadcast nil pointer
_, err = m.Broadcast(nil)
assert.Error(t, err)
}
func TestMqMsgStream_Consume(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
// Consume return nil when ctx canceled
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
m, err := msgstream.NewMqMsgStream(ctx, 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// Consume return nil when ctx canceled
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
m, err := msgstream.NewMqMsgStream(ctx, 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
wg.Add(1)
go func() {
defer wg.Done()
msgPack := consumer(ctx, m)
assert.Nil(t, msgPack)
}()
wg.Add(1)
go func() {
defer wg.Done()
msgPack := consumer(ctx, m)
assert.Nil(t, msgPack)
}()
cancel()
wg.Wait()
}(parameters[i].client)
}
cancel()
wg.Wait()
}
func consumer(ctx context.Context, mq msgstream.MsgStream) *msgstream.MsgPack {
@ -308,43 +215,33 @@ func consumer(ctx context.Context, mq msgstream.MsgStream) *msgstream.MsgPack {
}
func TestMqMsgStream_Chan(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
ch := m.Chan()
assert.NotNil(t, ch)
}(parameters[i].client)
}
ch := m.Chan()
assert.NotNil(t, ch)
}
func TestMqMsgStream_SeekNotSubscribed(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
defer f.teardown()
client, _ := createRmqClient()
defer client.Close()
factory := &msgstream.ProtoUDFactory{}
for i := range parameters {
func(client mqwrapper.Client) {
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher())
assert.NoError(t, err)
// seek in not subscribed channel
p := []*msgpb.MsgPosition{
{
ChannelName: "b",
},
}
err = m.Seek(p)
assert.Error(t, err)
}(parameters[i].client)
// seek in not subscribed channel
p := []*msgpb.MsgPosition{
{
ChannelName: "b",
},
}
err = m.Seek(p)
assert.Error(t, err)
}
func generateBaseMsg() msgstream.BaseMsg {
@ -360,38 +257,6 @@ func generateBaseMsg() msgstream.BaseMsg {
/****************************************Rmq test******************************************/
func initRmq(name string) kv.MetaKv {
endpoints := os.Getenv("ETCD_ENDPOINTS")
if endpoints == "" {
endpoints = "localhost:2379"
}
etcdEndpoints := strings.Split(endpoints, ",")
etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
if err != nil {
log.Fatalf("New clientv3 error = %v", err)
}
kv := etcdkv.NewEtcdKV(etcdCli, "/etcd/test/root")
idAllocator := allocator.NewGlobalIDAllocator("dummy", kv)
_ = idAllocator.Initialize()
err = server.InitRmq(name, idAllocator)
if err != nil {
log.Fatalf("InitRmq error = %v", err)
}
return kv
}
func Close(rocksdbName string, intputStream, outputStream msgstream.MsgStream, kv kv.MetaKv) {
server.CloseRocksMQ()
intputStream.Close()
outputStream.Close()
kv.Close()
err := os.RemoveAll(rocksdbName)
_ = os.RemoveAll(rocksdbName + "_meta_kv")
log.Println(err)
}
func initRmqStream(ctx context.Context,
producerChannels []string,
consumerChannels []string,
@ -449,15 +314,14 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
rocksdbName := "/tmp/rocksmq_insert"
kv := initRmq(rocksdbName)
ctx := context.Background()
inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
Close(rocksdbName, inputStream, outputStream, kv)
inputStream.Close()
outputStream.Close()
}
func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
@ -475,8 +339,6 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
msgPack2 := msgstream.MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))
rocksdbName := "/tmp/rocksmq_insert_tt"
kv := initRmq(rocksdbName)
ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
@ -490,12 +352,11 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
Close(rocksdbName, inputStream, outputStream, kv)
inputStream.Close()
outputStream.Close()
}
func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
rocksdbName := "/tmp/rocksmq_tt_msg_seek"
kv := initRmq(rocksdbName)
c1 := funcutil.RandomString(8)
producerChannels := []string{c1}
@ -550,12 +411,11 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type())
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type())
Close(rocksdbName, inputStream, outputStream, kv)
inputStream.Close()
outputStream.Close()
}
func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
rocksdbName := "/tmp/rocksmq_tt_msg_seek"
kv := initRmq(rocksdbName)
c1 := funcutil.RandomString(8)
producerChannels := []string{c1}
@ -662,12 +522,11 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
assert.Equal(t, msg.BeginTs(), uint64(19))
}
Close(rocksdbName, inputStream, outputStream, kv)
inputStream.Close()
outputStream.Close()
}
func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
rocksdbName := "/tmp/rocksmq_tt_msg_seekInvalid"
kv := initRmq(rocksdbName)
c := funcutil.RandomString(8)
producerChannels := []string{c}
consumerChannels := []string{c}
@ -721,7 +580,8 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
Close(rocksdbName, inputStream, outputStream2, kv)
inputStream.Close()
outputStream2.Close()
}
func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
@ -729,8 +589,6 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
consumerChannels := []string{"insert1"}
consumerSubName := "subInsert"
rocksdbName := "/tmp/rocksmq_asconsumer_withpos"
kv := initRmq(rocksdbName)
factory := msgstream.ProtoUDFactory{}
rmqClient, _ := NewClientWithDefaultOptions()
@ -756,7 +614,8 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
assert.Equal(t, 1, len(pack.Msgs))
assert.EqualValues(t, 1000, pack.Msgs[0].BeginTs())
Close(rocksdbName, inputStream, outputStream, kv)
inputStream.Close()
outputStream.Close()
}
func getTimeTickMsgPack(reqID msgstream.UniqueID) *msgstream.MsgPack {