fix: [2.4] channel manager's goroutine run order (#33121)

See also: #33117
pr: #33118

---------

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
pull/33218/head
XuanYang-cn 2024-05-21 14:31:39 +08:00 committed by GitHub
parent a27a2e8021
commit b2f7d7ba4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 41 deletions

View File

@ -32,7 +32,10 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
type releaseFunc func(channel string) type (
releaseFunc func(channel string)
watchFunc func(ctx context.Context, dn *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error)
)
type ChannelManager interface { type ChannelManager interface {
Submit(info *datapb.ChannelWatchInfo) error Submit(info *datapb.ChannelWatchInfo) error
@ -206,7 +209,7 @@ func (m *ChannelManagerImpl) handleOpState(opState *opState) {
} }
func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner { func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner {
runner, loaded := m.opRunners.GetOrInsert(channel, NewOpRunner(channel, m.dn, m.releaseFunc, m.communicateCh)) runner, loaded := m.opRunners.GetOrInsert(channel, NewOpRunner(channel, m.dn, m.releaseFunc, executeWatch, m.communicateCh))
if !loaded { if !loaded {
runner.Start() runner.Start()
} }
@ -228,6 +231,7 @@ type opRunner struct {
channel string channel string
dn *DataNode dn *DataNode
releaseFunc releaseFunc releaseFunc releaseFunc
watchFunc watchFunc
guard sync.RWMutex guard sync.RWMutex
allOps map[UniqueID]*opInfo // opID -> tickler allOps map[UniqueID]*opInfo // opID -> tickler
@ -238,11 +242,12 @@ type opRunner struct {
closeWg sync.WaitGroup closeWg sync.WaitGroup
} }
func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opState) *opRunner { func NewOpRunner(channel string, dn *DataNode, releaseF releaseFunc, watchF watchFunc, resultCh chan *opState) *opRunner {
return &opRunner{ return &opRunner{
channel: channel, channel: channel,
dn: dn, dn: dn,
releaseFunc: f, releaseFunc: releaseF,
watchFunc: watchF,
opsInQueue: make(chan *datapb.ChannelWatchInfo, 10), opsInQueue: make(chan *datapb.ChannelWatchInfo, 10),
allOps: make(map[UniqueID]*opInfo), allOps: make(map[UniqueID]*opInfo),
resultCh: resultCh, resultCh: resultCh,
@ -334,15 +339,15 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
var ( var (
successSig = make(chan struct{}, 1) successSig = make(chan struct{}, 1)
waiter sync.WaitGroup finishWaiter sync.WaitGroup
) )
watchTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) watchTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), watchTimeout) ctx, cancel := context.WithTimeout(context.Background(), watchTimeout)
defer cancel() defer cancel()
startTimer := func(wg *sync.WaitGroup) { startTimer := func(finishWg *sync.WaitGroup) {
defer wg.Done() defer finishWg.Done()
timer := time.NewTimer(watchTimeout) timer := time.NewTimer(watchTimeout)
defer timer.Stop() defer timer.Stop()
@ -377,11 +382,11 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
} }
} }
waiter.Add(2) finishWaiter.Add(2)
go startTimer(&waiter) go startTimer(&finishWaiter)
go func() { go func() {
defer waiter.Done() defer finishWaiter.Done()
fg, err := executeWatch(ctx, r.dn, info, tickler) fg, err := r.watchFunc(ctx, r.dn, info, tickler)
if err != nil { if err != nil {
opState.state = datapb.ChannelWatchState_WatchFailure opState.state = datapb.ChannelWatchState_WatchFailure
} else { } else {
@ -391,7 +396,7 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
} }
}() }()
waiter.Wait() finishWaiter.Wait()
return opState return opState
} }
@ -403,12 +408,13 @@ func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opI
} }
var ( var (
successSig = make(chan struct{}, 1) successSig = make(chan struct{}, 1)
waiter sync.WaitGroup finishWaiter sync.WaitGroup
) )
log := log.With(zap.Int64("opID", opID), zap.String("channel", channel)) log := log.With(zap.Int64("opID", opID), zap.String("channel", channel))
startTimer := func(wg *sync.WaitGroup) { startTimer := func(finishWaiter *sync.WaitGroup) {
defer wg.Done() defer finishWaiter.Done()
releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second)
timer := time.NewTimer(releaseTimeout) timer := time.NewTimer(releaseTimeout)
defer timer.Stop() defer timer.Stop()
@ -435,8 +441,8 @@ func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opI
} }
} }
waiter.Add(1) finishWaiter.Add(1)
go startTimer(&waiter) go startTimer(&finishWaiter)
go func() { go func() {
// TODO: failure should panic this DN, but we're not sure how // TODO: failure should panic this DN, but we're not sure how
// to recover when releaseFunc stuck. // to recover when releaseFunc stuck.
@ -450,7 +456,7 @@ func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opI
successSig <- struct{}{} successSig <- struct{}{}
}() }()
waiter.Wait() finishWaiter.Wait()
return opState return opState
} }

View File

@ -20,6 +20,7 @@ import (
"context" "context"
"testing" "testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -56,7 +57,7 @@ func (s *OpRunnerSuite) TestWatchWithTimer() {
mockReleaseFunc := func(channel string) { mockReleaseFunc := func(channel string) {
log.Info("mock release func") log.Info("mock release func")
} }
runner := NewOpRunner(channel, s.node, mockReleaseFunc, commuCh) runner := NewOpRunner(channel, s.node, mockReleaseFunc, executeWatch, commuCh)
err := runner.Enqueue(info) err := runner.Enqueue(info)
s.Require().NoError(err) s.Require().NoError(err)
@ -67,6 +68,35 @@ func (s *OpRunnerSuite) TestWatchWithTimer() {
runner.FinishOp(100) runner.FinishOp(100)
} }
// TestWatchTimeout runs an opRunner whose watch function blocks until its
// context is done and then returns an error, and verifies that the runner
// publishes a WatchFailure state for that operation on its result channel.
func (s *OpRunnerSuite) TestWatchTimeout() {
	vchan := "by-dev-rootcoord-dml-1000"

	// Shrink the watch timeout so the mocked watch is released almost
	// immediately; the original value is restored when the test returns.
	paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001")
	defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key)

	watchInfo := getWatchInfoByOpID(100, vchan, datapb.ChannelWatchState_ToWatch)

	// Both channels are unbuffered: the mock cannot return before the test
	// has received from watchDone.
	watchDone := make(chan struct{})
	results := make(chan *opState)

	release := func(channel string) { log.Info("mock release func") }
	// watch stands in for the real watch function: it waits for the context
	// to be done, signals the test, and then reports a failure.
	watch := func(ctx context.Context, _ *DataNode, _ *datapb.ChannelWatchInfo, _ *tickler) (*dataSyncService, error) {
		<-ctx.Done()
		watchDone <- struct{}{}
		return nil, errors.New("timeout")
	}

	runner := NewOpRunner(vchan, s.node, release, watch, results)
	runner.Start()
	defer runner.Close()

	s.Require().NoError(runner.Enqueue(watchInfo))

	// Wait for the mock to observe the done context, then read the state
	// the runner publishes for this operation.
	<-watchDone
	got := <-results
	s.Require().NotNil(got)
	s.Equal(watchInfo.GetOpID(), got.opID)
	s.Equal(datapb.ChannelWatchState_WatchFailure, got.state)
}
type OpRunnerSuite struct { type OpRunnerSuite struct {
suite.Suite suite.Suite
node *DataNode node *DataNode
@ -126,26 +156,6 @@ func (s *ChannelManagerSuite) TearDownTest() {
} }
} }
// TestWatchFail submits a ToWatch operation to the channel manager with a very
// small watch timeout and expects a WatchFailure state, both on the manager's
// communicate channel and from GetProgress once that state has been handled.
func (s *ChannelManagerSuite) TestWatchFail() {
	vchan := "by-dev-rootcoord-dml-2"

	// Use a tiny watch timeout for this test; the original value is restored
	// when the test returns.
	paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001")
	defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key)

	watchInfo := getWatchInfoByOpID(100, vchan, datapb.ChannelWatchState_ToWatch)

	// No runner may exist before the operation is submitted.
	s.Require().Equal(0, s.manager.opRunners.Len())
	s.Require().NoError(s.manager.Submit(watchInfo))

	// The state published for the operation must be a failure for this opID.
	result := <-s.manager.communicateCh
	s.Require().NotNil(result)
	s.Equal(watchInfo.GetOpID(), result.opID)
	s.Equal(datapb.ChannelWatchState_WatchFailure, result.state)

	// After the manager handles that state, progress reports the same failure.
	s.manager.handleOpState(result)

	progress := s.manager.GetProgress(watchInfo)
	s.Equal(datapb.ChannelWatchState_WatchFailure, progress.GetState())
}
func (s *ChannelManagerSuite) TestReleaseStuck() { func (s *ChannelManagerSuite) TestReleaseStuck() {
var ( var (
channel = "by-dev-rootcoord-dml-2" channel = "by-dev-rootcoord-dml-2"