diff --git a/internal/mq/mqimpl/rocksmq/server/global_rmq.go b/internal/mq/mqimpl/rocksmq/server/global_rmq.go index 291709fa9a..e668c63d8d 100644 --- a/internal/mq/mqimpl/rocksmq/server/global_rmq.go +++ b/internal/mq/mqimpl/rocksmq/server/global_rmq.go @@ -23,7 +23,6 @@ import ( "github.com/cockroachdb/errors" "go.uber.org/zap" - "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/pkg/log" ) @@ -33,13 +32,6 @@ var Rmq *rocksmq // once is used to init global rocksmq var once sync.Once -// InitRmq is deprecate implementation of global rocksmq. will be removed later -func InitRmq(rocksdbName string, idAllocator allocator.Interface) error { - var err error - Rmq, err = NewRocksMQ(rocksdbName, idAllocator) - return err -} - // InitRocksMQ init global rocksmq single instance func InitRocksMQ(path string) error { var finalErr error diff --git a/internal/mq/mqimpl/rocksmq/server/global_rmq_test.go b/internal/mq/mqimpl/rocksmq/server/global_rmq_test.go index 986b3f67dc..ecb3d682c0 100644 --- a/internal/mq/mqimpl/rocksmq/server/global_rmq_test.go +++ b/internal/mq/mqimpl/rocksmq/server/global_rmq_test.go @@ -12,44 +12,13 @@ package server import ( - "log" "os" - "strings" "sync" "testing" "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus/internal/allocator" - etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - "github.com/milvus-io/milvus/pkg/util/etcd" ) -func Test_InitRmq(t *testing.T) { - name := "/tmp/rmq_init" - defer os.RemoveAll("/tmp/rmq_init") - endpoints := os.Getenv("ETCD_ENDPOINTS") - if endpoints == "" { - endpoints = "localhost:2379" - } - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - defer etcdCli.Close() - if err != nil { - log.Fatalf("New clientv3 error = %v", err) - } - etcdKV := etcdkv.NewEtcdKV(etcdCli, "/etcd/test/root") - idAllocator := allocator.NewGlobalIDAllocator("dummy", etcdKV) - _ = idAllocator.Initialize() - - defer os.RemoveAll(name + kvSuffix) - defer os.RemoveAll(name) - err = InitRmq(name, idAllocator) - defer Rmq.stopRetention() - assert.NoError(t, err) - defer CloseRocksMQ() -} - func Test_InitRocksMQ(t *testing.T) { rmqPath := "/tmp/milvus/rdb_data_global" defer os.RemoveAll("/tmp/milvus") diff --git a/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go b/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go index 26104ef97b..35bcf29c76 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go @@ -20,8 +20,6 @@ import ( "context" "fmt" "log" - "os" - "strings" "sync" "testing" @@ -30,267 +28,176 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/internal/allocator" - "github.com/milvus-io/milvus/internal/kv" - etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" - "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" ) -type fixture struct { - t *testing.T - kv kv.MetaKv -} - -type parameters struct { - client mqwrapper.Client -} - -func (f *fixture) setup() []parameters { - rocksdbName := "/tmp/rocksmq_unittest_" + f.t.Name() - endpoints := os.Getenv("ETCD_ENDPOINTS") - if endpoints == "" { - endpoints = "localhost:2379" - } - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - defer etcdCli.Close() - if err != nil { - log.Fatalf("New clientv3 error = %v", err) - } - f.kv = etcdkv.NewEtcdKV(etcdCli, "/etcd/test/root") - idAllocator := allocator.NewGlobalIDAllocator("dummy", f.kv) - _ = idAllocator.Initialize() - err = server.InitRmq(rocksdbName, idAllocator) - if err != nil { - log.Fatalf("InitRmq error = %v", err) - } - - rmqClient, _ := NewClientWithDefaultOptions() - - parameters := []parameters{ - {rmqClient}, - } - return parameters -} - -func (f *fixture) teardown() { - rocksdbName := "/tmp/rocksmq_unittest_" + f.t.Name() - - server.CloseRocksMQ() - f.kv.Close() - _ = os.RemoveAll(rocksdbName) - _ = os.RemoveAll(rocksdbName + "_meta_kv") -} - func Test_NewMqMsgStream(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - _, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) - }(parameters[i].client) - } + _, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) } // TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage func TestMqMsgStream_AsProducer(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // empty channel name - m.AsProducer([]string{""}) - }(parameters[i].client) - } + // empty channel name + m.AsProducer([]string{""}) } // TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage func TestMqMsgStream_AsConsumer(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // repeat calling AsConsumer - m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) - m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) - }(parameters[i].client) - } + // repeat calling AsConsumer + m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) + m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) } func TestMqMsgStream_ComputeProduceChannelIndexes(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // empty parameters - reBucketValues := m.ComputeProduceChannelIndexes([]msgstream.TsMsg{}) - assert.Nil(t, reBucketValues) + // empty parameters + reBucketValues := m.ComputeProduceChannelIndexes([]msgstream.TsMsg{}) + assert.Nil(t, reBucketValues) - // not called AsProducer yet - insertMsg := &msgstream.InsertMsg{ - BaseMsg: generateBaseMsg(), - InsertRequest: msgpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, + // not called AsProducer yet + insertMsg := &msgstream.InsertMsg{ + BaseMsg: generateBaseMsg(), + InsertRequest: msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, - DbName: "test_db", - CollectionName: "test_collection", - PartitionName: "test_partition", - DbID: 4, - CollectionID: 5, - PartitionID: 6, - SegmentID: 7, - ShardName: "test-channel", - Timestamps: []uint64{2, 1, 3}, - RowData: []*commonpb.Blob{}, - }, - } - reBucketValues = m.ComputeProduceChannelIndexes([]msgstream.TsMsg{insertMsg}) - assert.Nil(t, reBucketValues) - }(parameters[i].client) + DbName: "test_db", + CollectionName: "test_collection", + PartitionName: "test_partition", + DbID: 4, + CollectionID: 5, + PartitionID: 6, + SegmentID: 7, + ShardName: "test-channel", + Timestamps: []uint64{2, 1, 3}, + RowData: []*commonpb.Blob{}, + }, } + reBucketValues = m.ComputeProduceChannelIndexes([]msgstream.TsMsg{insertMsg}) + assert.Nil(t, reBucketValues) } func TestMqMsgStream_GetProduceChannels(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // empty if not called AsProducer yet - chs := m.GetProduceChannels() - assert.Equal(t, 0, len(chs)) + // empty if not called AsProducer yet + chs := m.GetProduceChannels() + assert.Equal(t, 0, len(chs)) - // not empty after AsProducer - m.AsProducer([]string{"a"}) - chs = m.GetProduceChannels() - assert.Equal(t, 1, len(chs)) - }(parameters[i].client) - } + // not empty after AsProducer + m.AsProducer([]string{"a"}) + chs = m.GetProduceChannels() + assert.Equal(t, 1, len(chs)) } func TestMqMsgStream_Produce(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // Produce before called AsProducer - insertMsg := &msgstream.InsertMsg{ - BaseMsg: generateBaseMsg(), - InsertRequest: msgpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, + // Produce before called AsProducer + insertMsg := &msgstream.InsertMsg{ + BaseMsg: generateBaseMsg(), + InsertRequest: msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, - DbName: "test_db", - CollectionName: "test_collection", - PartitionName: "test_partition", - DbID: 4, - CollectionID: 5, - PartitionID: 6, - SegmentID: 7, - ShardName: "test-channel", - Timestamps: []uint64{2, 1, 3}, - RowData: []*commonpb.Blob{}, - }, - } - msgPack := &msgstream.MsgPack{ - Msgs: []msgstream.TsMsg{insertMsg}, - } - err = m.Produce(msgPack) - assert.Error(t, err) - }(parameters[i].client) + DbName: "test_db", + CollectionName: "test_collection", + PartitionName: "test_partition", + DbID: 4, + CollectionID: 5, + PartitionID: 6, + SegmentID: 7, + ShardName: "test-channel", + Timestamps: []uint64{2, 1, 3}, + RowData: []*commonpb.Blob{}, + }, } + msgPack := &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{insertMsg}, + } + err = m.Produce(msgPack) + assert.Error(t, err) } func TestMqMsgStream_Broadcast(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // Broadcast nil pointer - _, err = m.Broadcast(nil) - assert.Error(t, err) - }(parameters[i].client) - } + // Broadcast nil pointer + _, err = m.Broadcast(nil) + assert.Error(t, err) } func TestMqMsgStream_Consume(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - // Consume return nil when ctx canceled - var wg sync.WaitGroup - ctx, cancel := context.WithCancel(context.Background()) - m, err := msgstream.NewMqMsgStream(ctx, 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + // Consume return nil when ctx canceled + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + m, err := msgstream.NewMqMsgStream(ctx, 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - wg.Add(1) - go func() { - defer wg.Done() - msgPack := consumer(ctx, m) - assert.Nil(t, msgPack) - }() + wg.Add(1) + go func() { + defer wg.Done() + msgPack := consumer(ctx, m) + assert.Nil(t, msgPack) + }() - cancel() - wg.Wait() - }(parameters[i].client) - } + cancel() + wg.Wait() } func consumer(ctx context.Context, mq msgstream.MsgStream) *msgstream.MsgPack { @@ -308,43 +215,33 @@ func consumer(ctx context.Context, mq msgstream.MsgStream) *msgstream.MsgPack { } func TestMqMsgStream_Chan(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - ch := m.Chan() - assert.NotNil(t, ch) - }(parameters[i].client) - } + ch := m.Chan() + assert.NotNil(t, ch) } func TestMqMsgStream_SeekNotSubscribed(t *testing.T) { - f := &fixture{t: t} - parameters := f.setup() - defer f.teardown() + client, _ := createRmqClient() + defer client.Close() factory := &msgstream.ProtoUDFactory{} - for i := range parameters { - func(client mqwrapper.Client) { - m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) - assert.NoError(t, err) + m, err := msgstream.NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.NoError(t, err) - // seek in not subscribed channel - p := []*msgpb.MsgPosition{ - { - ChannelName: "b", - }, - } - err = m.Seek(p) - assert.Error(t, err) - }(parameters[i].client) + // seek in not subscribed channel + p := []*msgpb.MsgPosition{ + { + ChannelName: "b", + }, } + err = m.Seek(p) + assert.Error(t, err) } func generateBaseMsg() msgstream.BaseMsg { @@ -360,38 +257,6 @@ func generateBaseMsg() msgstream.BaseMsg { /****************************************Rmq test******************************************/ -func initRmq(name string) kv.MetaKv { - endpoints := os.Getenv("ETCD_ENDPOINTS") - if endpoints == "" { - endpoints = "localhost:2379" - } - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - if err != nil { - log.Fatalf("New clientv3 error = %v", err) - } - kv := etcdkv.NewEtcdKV(etcdCli, "/etcd/test/root") - idAllocator := allocator.NewGlobalIDAllocator("dummy", kv) - _ = idAllocator.Initialize() - - err = server.InitRmq(name, idAllocator) - - if err != nil { - log.Fatalf("InitRmq error = %v", err) - } - return kv -} - -func Close(rocksdbName string, intputStream, outputStream msgstream.MsgStream, kv kv.MetaKv) { - server.CloseRocksMQ() - intputStream.Close() - outputStream.Close() - kv.Close() - err := os.RemoveAll(rocksdbName) - _ = os.RemoveAll(rocksdbName + "_meta_kv") - log.Println(err) -} - func initRmqStream(ctx context.Context, producerChannels []string, consumerChannels []string, @@ -449,15 +314,14 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) - rocksdbName := "/tmp/rocksmq_insert" - kv := initRmq(rocksdbName) ctx := context.Background() inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName) err := inputStream.Produce(&msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack.Msgs)) - Close(rocksdbName, inputStream, outputStream, kv) + inputStream.Close() + outputStream.Close() } func TestStream_RmqTtMsgStream_Insert(t *testing.T) { @@ -475,8 +339,6 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) { msgPack2 := msgstream.MsgPack{} msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5)) - rocksdbName := "/tmp/rocksmq_insert_tt" - kv := initRmq(rocksdbName) ctx := context.Background() inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) @@ -490,12 +352,11 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) { require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) - Close(rocksdbName, inputStream, outputStream, kv) + inputStream.Close() + outputStream.Close() } func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { - rocksdbName := "/tmp/rocksmq_tt_msg_seek" - kv := initRmq(rocksdbName) c1 := funcutil.RandomString(8) producerChannels := []string{c1} @@ -550,12 +411,11 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type()) assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type()) - Close(rocksdbName, inputStream, outputStream, kv) + inputStream.Close() + outputStream.Close() } func TestStream_RmqTtMsgStream_Seek(t *testing.T) { - rocksdbName := "/tmp/rocksmq_tt_msg_seek" - kv := initRmq(rocksdbName) c1 := funcutil.RandomString(8) producerChannels := []string{c1} @@ -662,12 +522,11 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) { assert.Equal(t, msg.BeginTs(), uint64(19)) } - Close(rocksdbName, inputStream, outputStream, kv) + inputStream.Close() + outputStream.Close() } func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { - rocksdbName := "/tmp/rocksmq_tt_msg_seekInvalid" - kv := initRmq(rocksdbName) c := funcutil.RandomString(8) producerChannels := []string{c} consumerChannels := []string{c} @@ -721,7 +580,8 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { result := consumer(ctx, outputStream2) assert.Equal(t, result.Msgs[0].ID(), int64(1)) - Close(rocksdbName, inputStream, outputStream2, kv) + inputStream.Close() + outputStream2.Close() } func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { @@ -729,8 +589,6 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { consumerChannels := []string{"insert1"} consumerSubName := "subInsert" - rocksdbName := "/tmp/rocksmq_asconsumer_withpos" - kv := initRmq(rocksdbName) factory := msgstream.ProtoUDFactory{} rmqClient, _ := NewClientWithDefaultOptions() @@ -756,7 +614,8 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { assert.Equal(t, 1, len(pack.Msgs)) assert.EqualValues(t, 1000, pack.Msgs[0].BeginTs()) - Close(rocksdbName, inputStream, outputStream, kv) + inputStream.Close() + outputStream.Close() } func getTimeTickMsgPack(reqID msgstream.UniqueID) *msgstream.MsgPack {