diff --git a/internal/indexcoord/index_coord.go b/internal/indexcoord/index_coord.go index 33782e4405..0284041e96 100644 --- a/internal/indexcoord/index_coord.go +++ b/internal/indexcoord/index_coord.go @@ -55,6 +55,7 @@ type IndexCoord struct { sched *TaskScheduler session *sessionutil.Session + liveCh <-chan bool eventChan <-chan *sessionutil.SessionEvent @@ -106,7 +107,7 @@ func (i *IndexCoord) Register() error { if i.session == nil { return errors.New("failed to initialize session") } - i.session.Init(typeutil.IndexCoordRole, Params.Address, true) + i.liveCh = i.session.Init(typeutil.IndexCoordRole, Params.Address, true) return nil } @@ -236,6 +237,10 @@ func (i *IndexCoord) Start() error { i.loopWg.Add(1) go i.watchMetaLoop() + go i.session.LivenessCheck(i.loopCtx, i.liveCh, func() { + i.Stop() + }) + startErr = i.sched.Start() i.UpdateStateCode(internalpb.StateCode_Healthy) diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index e2ae1373be..9b4945f893 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -63,6 +63,7 @@ type IndexNode struct { kv kv.BaseKV session *sessionutil.Session + liveCh <-chan bool // Add callback functions at different stages startCallbacks []func() @@ -97,7 +98,7 @@ func (i *IndexNode) Register() error { if i.session == nil { return errors.New("failed to initialize session") } - i.session.Init(typeutil.IndexNodeRole, Params.IP+":"+strconv.Itoa(Params.Port), false) + i.liveCh = i.session.Init(typeutil.IndexNodeRole, Params.IP+":"+strconv.Itoa(Params.Port), false) Params.NodeID = i.session.ServerID return nil } @@ -150,6 +151,11 @@ func (i *IndexNode) Init() error { func (i *IndexNode) Start() error { i.sched.Start() + //start liveness check + go i.session.LivenessCheck(i.loopCtx, i.liveCh, func() { + i.Stop() + }) + i.UpdateStateCode(internalpb.StateCode_Healthy) log.Debug("IndexNode", zap.Any("State", i.stateCode.Load())) // Start callbacks diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go index e581038b87..b85a75849f 100644 --- a/internal/querycoord/query_coord.go +++ b/internal/querycoord/query_coord.go @@ -62,6 +62,7 @@ type QueryCoord struct { rootCoordClient types.RootCoord session *sessionutil.Session + liveCh <-chan bool eventChan <-chan *sessionutil.SessionEvent stateCode atomic.Value @@ -75,7 +76,7 @@ type QueryCoord struct { func (qc *QueryCoord) Register() error { log.Debug("query coord session info", zap.String("metaPath", Params.MetaRootPath), zap.Strings("etcdEndPoints", Params.EtcdEndpoints), zap.String("address", Params.Address)) qc.session = sessionutil.NewSession(qc.loopCtx, Params.MetaRootPath, Params.EtcdEndpoints) - qc.session.Init(typeutil.QueryCoordRole, Params.Address, true) + qc.liveCh = qc.session.Init(typeutil.QueryCoordRole, Params.Address, true) Params.NodeID = uint64(qc.session.ServerID) return nil } @@ -130,6 +131,10 @@ func (qc *QueryCoord) Start() error { qc.loopWg.Add(1) go qc.watchMetaLoop() + go qc.session.LivenessCheck(qc.loopCtx, qc.liveCh, func() { + qc.Stop() + }) + return nil } diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index b1988b8bc9..1e7fbce667 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -50,6 +50,9 @@ type QueryNode struct { stateCode atomic.Value + // liveness channel with etcd + liveCh <-chan bool + // internal components historical *historical streaming *streaming @@ -89,7 +92,7 @@ func NewQueryNode(ctx context.Context, factory msgstream.Factory) *QueryNode { func (node *QueryNode) Register() error { log.Debug("query node session info", zap.String("metaPath", Params.MetaRootPath), zap.Strings("etcdEndPoints", Params.EtcdEndpoints)) node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.MetaRootPath, Params.EtcdEndpoints) - node.session.Init(typeutil.QueryNodeRole, Params.QueryNodeIP+":"+strconv.FormatInt(Params.QueryNodePort, 10), false) + node.liveCh = node.session.Init(typeutil.QueryNodeRole, Params.QueryNodeIP+":"+strconv.FormatInt(Params.QueryNodePort, 10), false) Params.QueryNodeID = node.session.ServerID log.Debug("query nodeID", zap.Int64("nodeID", Params.QueryNodeID)) log.Debug("query node address", zap.String("address", node.session.Address)) @@ -173,6 +176,12 @@ func (node *QueryNode) Start() error { // start services go node.historical.start() + + // start liveness check + go node.session.LivenessCheck(node.queryNodeLoopCtx, node.liveCh, func() { + node.Stop() + }) + node.UpdateStateCode(internalpb.StateCode_Healthy) return nil } diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 015392a654..ff7e1dccdb 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -335,3 +335,27 @@ func (s *Session) WatchServices(prefix string, revision int64) (eventChannel <-c }() return eventCh } + +// LivenessCheck performs liveness check with provided context and channel +// ctx controls the liveness check loop +// ch is the liveness signal channel, ch is closed only when the session is expired +// callback is the function to call when ch is closed, note that callback will not be invoked when loop exits due to context +func (s *Session) LivenessCheck(ctx context.Context, ch <-chan bool, callback func()) { + for { + select { + case _, ok := <-ch: + // ok, still alive + if ok { + continue + } + // not ok, connection lost + log.Warn("connection lost detected, shuting down") + if callback != nil { + go callback() + } + return + case <-ctx.Done(): + log.Debug("liveness exits due to context done") + } + } +} diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index 36a914594b..c4bd19fd9f 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -165,3 +165,39 @@ func TestUpdateSessions(t *testing.T) { assert.Equal(t, addEventLen, 10) assert.Equal(t, delEventLen, 10) } + +func TestSessionLivenessCheck(t *testing.T) { + s := &Session{} + ctx := context.Background() + ch := make(chan bool) + signal := make(chan struct{}, 1) + + flag := false + + go s.LivenessCheck(ctx, ch, func() { + flag = true + signal <- struct{}{} + }) + + assert.False(t, flag) + ch <- true + + assert.False(t, flag) + close(ch) + + <-signal + assert.True(t, flag) + + ctx, cancel := context.WithCancel(ctx) + cancel() + ch = make(chan bool) + + flag = false + + go s.LivenessCheck(ctx, ch, func() { + flag = true + signal <- struct{}{} + }) + + assert.False(t, flag) +}