diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go index 373aab4650..28e5002daf 100644 --- a/internal/util/flowgraph/flow_graph.go +++ b/internal/util/flowgraph/flow_graph.go @@ -27,6 +27,7 @@ type TimeTickedFlowGraph struct { nodeCtx map[NodeName]*nodeCtx stopOnce sync.Once startOnce sync.Once + closeWg *sync.WaitGroup } // AddNode add Node into flowgraph @@ -35,6 +36,7 @@ func (fg *TimeTickedFlowGraph) AddNode(node Node) { node: node, downstreamInputChanIdx: make(map[string]int), closeCh: make(chan struct{}), + closeWg: fg.closeWg, } fg.nodeCtx[node.Name()] = &nodeCtx } @@ -92,6 +94,7 @@ func (fg *TimeTickedFlowGraph) Close() { v.Close() } } + fg.closeWg.Wait() }) } @@ -99,6 +102,7 @@ func (fg *TimeTickedFlowGraph) Close() { func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph { flowGraph := TimeTickedFlowGraph{ nodeCtx: make(map[string]*nodeCtx), + closeWg: &sync.WaitGroup{}, } return &flowGraph diff --git a/internal/util/flowgraph/node.go b/internal/util/flowgraph/node.go index 29cb8a9a3e..997a193c5f 100644 --- a/internal/util/flowgraph/node.go +++ b/internal/util/flowgraph/node.go @@ -59,12 +59,14 @@ type nodeCtx struct { downstreamInputChanIdx map[string]int closeCh chan struct{} // notify work to exit + closeWg *sync.WaitGroup } // Start invoke Node `Start` method and start a worker goroutine func (nodeCtx *nodeCtx) Start() { nodeCtx.node.Start() + nodeCtx.closeWg.Add(1) go nodeCtx.work() } @@ -114,6 +116,7 @@ func (nodeCtx *nodeCtx) work() { // the res decide whether the node should be closed. if isCloseMsg(res) { close(nodeCtx.closeCh) + nodeCtx.closeWg.Done() nodeCtx.node.Close() } diff --git a/internal/util/flowgraph/node_test.go b/internal/util/flowgraph/node_test.go index e700e3b3e0..6bc7957e10 100644 --- a/internal/util/flowgraph/node_test.go +++ b/internal/util/flowgraph/node_test.go @@ -20,6 +20,7 @@ import ( "context" "math" "os" + "sync" "testing" "time" @@ -80,6 +81,7 @@ func TestNodeCtx_Start(t *testing.T) { inputChannels: make([]chan Msg, 2), downstreamInputChanIdx: make(map[string]int), closeCh: make(chan struct{}), + closeWg: &sync.WaitGroup{}, } for i := 0; i < len(node.inputChannels); i++ {