// milvus/internal/datanode/channel/channel_manager_test.go
// (349 lines, 11 KiB, Go)

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package channel
import (
"context"
"os"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datanode/allocator"
"github.com/milvus-io/milvus/internal/flushcommon/broker"
"github.com/milvus-io/milvus/internal/flushcommon/pipeline"
"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
util2 "github.com/milvus-io/milvus/internal/flushcommon/util"
"github.com/milvus-io/milvus/internal/flushcommon/writebuffer"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// TestMain initializes the global paramtable once for the whole package and
// then exits with the status code returned by the test run.
func TestMain(m *testing.M) {
	paramtable.Init()
	os.Exit(m.Run())
}
// TestChannelManagerSuite runs every test method of ChannelManagerSuite.
func TestChannelManagerSuite(t *testing.T) {
	suite.Run(t, &ChannelManagerSuite{})
}
// TestOpRunnerSuite runs every test method of OpRunnerSuite.
func TestOpRunnerSuite(t *testing.T) {
	suite.Run(t, &OpRunnerSuite{})
}
// SetupTest assembles fresh pipeline params for each OpRunnerSuite test.
// Every collaborator is a mock bound to the current test, and every
// expectation registered here is marked Maybe(), i.e. optional.
func (s *OpRunnerSuite) SetupTest() {
	br := broker.NewMockBroker(s.T())
	br.EXPECT().
		GetSegmentInfo(mock.Anything, mock.Anything).
		Return([]*datapb.SegmentInfo{}, nil).
		Maybe()

	bufMgr := writebuffer.NewMockBufferManager(s.T())
	bufMgr.EXPECT().
		Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return(nil).
		Maybe()

	// The dispatcher hands back a fresh channel that nothing in this suite
	// writes to.
	dispatcher := msgdispatcher.NewMockClient(s.T())
	dispatcher.EXPECT().
		Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return(make(chan *msgstream.MsgPack), nil).
		Maybe()
	dispatcher.EXPECT().
		Deregister(mock.Anything).
		Maybe()

	s.pipelineParams = &util2.PipelineParams{
		Ctx:                context.TODO(),
		Session:            &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 0}},
		CheckpointUpdater:  util2.NewChannelCheckpointUpdater(br),
		WriteBufferManager: bufMgr,
		Broker:             br,
		DispClient:         dispatcher,
		SyncMgr:            syncmgr.NewMockSyncManager(s.T()),
		Allocator:          allocator.NewMockAllocator(s.T()),
	}
}
// TestWatchWithTimer enqueues a watch op and then calls watchWithTimer
// directly (the runner loop is never started), using the real executeWatch.
// It checks that a flowgraph is produced for the channel, then finishes the
// enqueued op.
func (s *OpRunnerSuite) TestWatchWithTimer() {
	const channel = "ch-1"
	commuCh := make(chan *opState)

	info := GetWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	mockReleaseFunc := func(string) {
		log.Info("mock release func")
	}

	runner := NewOpRunner(channel, s.pipelineParams, mockReleaseFunc, executeWatch, commuCh)
	s.Require().NoError(runner.Enqueue(info))

	state := runner.watchWithTimer(info)
	s.NotNil(state.fg)
	s.Equal(channel, state.channel)

	// Finish the op using the ID it was enqueued with, rather than a duplicated
	// literal that could drift from the one passed to GetWatchInfoByOpID.
	runner.FinishOp(info.GetOpID())
}
// TestWatchTimeout shrinks WatchTimeoutInterval and installs a watch func
// that blocks until its context is done. It checks that the running opRunner
// then reports WatchFailure for the op.
//
// Both receives are bounded by waitLimit. If the context is never cancelled or
// no result is reported, the test fails instead of hanging the test binary.
func (s *OpRunnerSuite) TestWatchTimeout() {
	const waitLimit = 10 * time.Second

	channel := "by-dev-rootcoord-dml-1000"
	paramtable.Get().Save(paramtable.Get().DataCoordCfg.WatchTimeoutInterval.Key, "0.000001")
	defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.WatchTimeoutInterval.Key)

	info := GetWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	// Buffered so the watch func never blocks on signalling, even if the test
	// has already stopped waiting for it.
	sig := make(chan struct{}, 1)
	commuCh := make(chan *opState)

	mockReleaseFunc := func(string) { log.Info("mock release func") }
	mockWatchFunc := func(ctx context.Context, param *util2.PipelineParams, info *datapb.ChannelWatchInfo, tickler *util2.Tickler) (*pipeline.DataSyncService, error) {
		<-ctx.Done()
		sig <- struct{}{}
		return nil, errors.New("timeout")
	}

	runner := NewOpRunner(channel, s.pipelineParams, mockReleaseFunc, mockWatchFunc, commuCh)
	runner.Start()
	defer runner.Close()

	s.Require().NoError(runner.Enqueue(info))

	select {
	case <-sig:
	case <-time.After(waitLimit):
		s.Require().FailNow("watch func was not cancelled by the watch timeout")
	}

	var state *opState
	select {
	case state = <-commuCh:
	case <-time.After(waitLimit):
		s.Require().FailNow("runner did not report an op state after the watch timeout")
	}
	s.Require().NotNil(state)
	s.Equal(info.GetOpID(), state.opID)
	s.Equal(datapb.ChannelWatchState_WatchFailure, state.state)
}
// OpRunnerSuite tests opRunner on its own, against mocked pipeline
// dependencies.
type OpRunnerSuite struct {
	suite.Suite

	// pipelineParams is reassigned by SetupTest before each test.
	pipelineParams *util2.PipelineParams
}
// ChannelManagerSuite tests ChannelManagerImpl: submitting watch/release ops
// and handling their results.
type ChannelManagerSuite struct {
	suite.Suite

	// Both fields are reassigned by SetupTest before each test; TearDownTest
	// closes manager.
	pipelineParams *util2.PipelineParams
	manager        *ChannelManagerImpl
}
// SetupTest builds fresh pipeline params and a new ChannelManagerImpl for
// each test. The write-buffer manager, broker, sync manager and allocator are
// mocks. The dispatcher client is a real msgdispatcher client created from
// dependency.NewDefaultFactory(true).
func (s *ChannelManagerSuite) SetupTest() {
	factory := dependency.NewDefaultFactory(true)

	wbManager := writebuffer.NewMockBufferManager(s.T())
	wbManager.EXPECT().
		Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return(nil).Maybe()
	wbManager.EXPECT().RemoveChannel(mock.Anything).Maybe()

	// Construct the broker mock through its generated constructor, as
	// OpRunnerSuite.SetupTest does, so it is bound to this test's T instead of
	// being a bare zero-value mock. Mockery-generated New* constructors set the
	// test on the mock and register an AssertExpectations cleanup.
	mockedBroker := broker.NewMockBroker(s.T())
	mockedBroker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return([]*datapb.SegmentInfo{}, nil).Maybe()

	s.pipelineParams = &util2.PipelineParams{
		Ctx:                context.TODO(),
		Session:            &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 0}},
		WriteBufferManager: wbManager,
		Broker:             mockedBroker,
		MsgStreamFactory:   factory,
		DispClient:         msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID()),
		SyncMgr:            syncmgr.NewMockSyncManager(s.T()),
		Allocator:          allocator.NewMockAllocator(s.T()),
	}
	s.manager = NewChannelManager(s.pipelineParams, pipeline.NewFlowgraphManager())
}
// TearDownTest closes the manager built by SetupTest, if there is one.
func (s *ChannelManagerSuite) TearDownTest() {
	if m := s.manager; m != nil {
		m.Close()
	}
}
// TestReleaseStuck watches a channel, then submits a release whose release
// func sleeps for 1s after WatchTimeoutInterval has been lowered to "0.1".
// The release is expected to come back as ReleaseFailure. After handling, the
// channel must be recorded in abnormals under the release op's ID, and
// GetProgress must report ReleaseFailure.
func (s *ChannelManagerSuite) TestReleaseStuck() {
	const channel = "by-dev-rootcoord-dml-2"
	s.manager.releaseFunc = func(string) {
		time.Sleep(1 * time.Second)
	}

	// Step 1: watch the channel with the default timeout.
	watch := GetWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	s.Require().Equal(0, s.manager.opRunners.Len())
	s.Require().NoError(s.manager.Submit(watch))

	watched := <-s.manager.communicateCh
	s.Require().NotNil(watched)
	s.manager.handleOpState(watched)

	// Step 2: lower the timeout, then release.
	release := GetWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
	paramtable.Get().Save(paramtable.Get().DataCoordCfg.WatchTimeoutInterval.Key, "0.1")
	defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.WatchTimeoutInterval.Key)
	s.NoError(s.manager.Submit(release))

	released := <-s.manager.communicateCh
	s.Require().NotNil(released)
	s.Equal(datapb.ChannelWatchState_ReleaseFailure, released.state)
	s.manager.handleOpState(released)

	s.Equal(1, s.manager.abnormals.Len())
	got, found := s.manager.abnormals.Get(release.GetOpID())
	s.True(found)
	s.Equal(channel, got)

	progress := s.manager.GetProgress(release)
	s.Equal(datapb.ChannelWatchState_ReleaseFailure, progress.GetState())
}
// TestSubmitIdempotent submits the same watch info ten times. It expects a
// single op runner for the channel, holding exactly one unfinished op.
func (s *ChannelManagerSuite) TestSubmitIdempotent() {
	const (
		channel     = "by-dev-rootcoord-dml-1"
		submissions = 10
	)
	info := GetWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)

	s.Require().Equal(0, s.manager.opRunners.Len())
	for n := 0; n < submissions; n++ {
		s.NoError(s.manager.Submit(info))
	}

	s.Equal(1, s.manager.opRunners.Len())
	s.True(s.manager.opRunners.Contain(channel))

	runner, found := s.manager.opRunners.Get(channel)
	s.True(found)
	s.Equal(1, runner.UnfinishedOpSize())
}
// TestSubmitSkip completes and handles a watch op, then submits the same
// watch info again. It expects that no op runner is created for the channel.
func (s *ChannelManagerSuite) TestSubmitSkip() {
	const channel = "by-dev-rootcoord-dml-1"
	info := GetWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)

	s.Require().Equal(0, s.manager.opRunners.Len())
	s.NoError(s.manager.Submit(info))
	s.Equal(1, s.manager.opRunners.Len())
	s.True(s.manager.opRunners.Contain(channel))

	result := <-s.manager.communicateCh
	s.NotNil(result)
	s.Equal(datapb.ChannelWatchState_WatchSuccess, result.state)
	s.NotNil(result.fg)
	s.Equal(info.GetOpID(), result.fg.GetOpID())
	s.manager.handleOpState(result)

	// Resubmit the watch that has already completed.
	s.NoError(s.manager.Submit(info))
	runner, found := s.manager.opRunners.Get(channel)
	s.False(found)
	s.Nil(runner)
}
// TestSubmitWatchAndRelease takes one channel through a watch followed by a
// release. After each step it checks:
//   - the reported op state;
//   - GetProgress;
//   - the flowgraph count;
//   - op-runner bookkeeping.
//
// Finally it resubmits the finished release and expects no op runner to be
// created.
func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() {
	const channel = "by-dev-rootcoord-dml-0"

	// ---- watch ----
	watchInfo := GetWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	s.NoError(s.manager.Submit(watchInfo))

	watched := <-s.manager.communicateCh
	s.NotNil(watched)
	s.Equal(datapb.ChannelWatchState_WatchSuccess, watched.state)
	s.NotNil(watched.fg)
	s.Equal(watchInfo.GetOpID(), watched.fg.GetOpID())

	// Before handleOpState runs, progress is expected to report ToWatch.
	progress := s.manager.GetProgress(watchInfo)
	s.Equal(watchInfo.GetOpID(), progress.GetOpID())
	s.Equal(datapb.ChannelWatchState_ToWatch, progress.GetState())

	s.manager.handleOpState(watched)
	s.Equal(1, s.manager.fgManager.GetFlowgraphCount())
	s.False(s.manager.opRunners.Contain(watchInfo.GetVchan().GetChannelName()))
	s.Equal(0, s.manager.opRunners.Len())

	progress = s.manager.GetProgress(watchInfo)
	s.Equal(watchInfo.GetOpID(), progress.GetOpID())
	s.Equal(datapb.ChannelWatchState_WatchSuccess, progress.GetState())

	// ---- release ----
	releaseInfo := GetWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
	s.NoError(s.manager.Submit(releaseInfo))

	released := <-s.manager.communicateCh
	s.NotNil(released)
	s.Equal(datapb.ChannelWatchState_ReleaseSuccess, released.state)
	s.manager.handleOpState(released)

	progress = s.manager.GetProgress(releaseInfo)
	s.Equal(releaseInfo.GetOpID(), progress.GetOpID())
	s.Equal(datapb.ChannelWatchState_ReleaseSuccess, progress.GetState())
	s.Equal(0, s.manager.fgManager.GetFlowgraphCount())
	s.False(s.manager.opRunners.Contain(releaseInfo.GetVchan().GetChannelName()))
	s.Equal(0, s.manager.opRunners.Len())

	// ---- resubmit the finished release ----
	s.NoError(s.manager.Submit(releaseInfo))
	runner, found := s.manager.opRunners.Get(channel)
	s.False(found)
	s.Nil(runner)
}
// GetWatchInfoByOpID builds a ChannelWatchInfo with the given op ID, state
// and channel name, on collection 1. The schema "test_collection" contains:
//   - int64 fields at common.RowIDField and common.TimeStampField;
//   - an int64 primary key "pk" (field 100);
//   - a float vector "vector" (field 101) with dim "128".
func GetWatchInfoByOpID(opID typeutil.UniqueID, channel string, state datapb.ChannelWatchState) *datapb.ChannelWatchInfo {
	fields := []*schemapb.FieldSchema{
		{FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64},
		{FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64},
		{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		{
			FieldID:  101,
			Name:     "vector",
			DataType: schemapb.DataType_FloatVector,
			TypeParams: []*commonpb.KeyValuePair{
				{Key: common.DimKey, Value: "128"},
			},
		},
	}

	return &datapb.ChannelWatchInfo{
		OpID:   opID,
		State:  state,
		Vchan:  &datapb.VchannelInfo{CollectionID: 1, ChannelName: channel},
		Schema: &schemapb.CollectionSchema{Name: "test_collection", Fields: fields},
	}
}