Fix data race in flow graph (#6946)

* Fix data race in flow graph

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>

* add cancel func to flowgraph

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/6985/head
bigsheeper 2021-08-03 22:43:25 +08:00 committed by GitHub
parent c8a1f780c1
commit 07cc449fbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 8 deletions

View File

@ -20,6 +20,7 @@ import (
type TimeTickedFlowGraph struct {
ctx context.Context
cancel context.CancelFunc
nodeCtx map[NodeName]*nodeCtx
}
@ -89,11 +90,14 @@ func (fg *TimeTickedFlowGraph) Close() {
// }
v.Close()
}
fg.cancel()
}
func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph {
ctx1, cancel := context.WithCancel(ctx)
flowGraph := TimeTickedFlowGraph{
ctx: ctx,
ctx: ctx1,
cancel: cancel,
nodeCtx: make(map[string]*nodeCtx),
}

View File

@ -29,7 +29,7 @@ func (inNode *InputNode) IsInputNode() bool {
}
func (inNode *InputNode) Close() {
(*inNode.inStream).Close()
// do nothing
}
func (inNode *InputNode) Name() string {

View File

@ -14,9 +14,12 @@ package flowgraph
import (
"context"
"fmt"
"log"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
)
type Node interface {
@ -49,7 +52,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
// fmt.Println("start InputNode.inStream")
inStream, ok := nodeCtx.node.(*InputNode)
if !ok {
log.Fatal("Invalid inputNode")
log.Error("Invalid inputNode")
}
(*inStream.inStream).Start()
}
@ -57,6 +60,16 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
for {
select {
case <-ctx.Done():
if nodeCtx.node.IsInputNode() {
inStream, ok := nodeCtx.node.(*InputNode)
if !ok {
log.Error("Invalid inputNode")
}
(*inStream.inStream).Close()
log.Debug("message stream closed",
zap.Any("node name", inStream.name),
)
}
wg.Done()
//fmt.Println(nodeCtx.node.Name(), "closed")
return
@ -74,7 +87,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
downstreamLength := len(nodeCtx.downstreamInputChanIdx)
if len(nodeCtx.downstream) < downstreamLength {
log.Println("nodeCtx.downstream length = ", len(nodeCtx.downstream))
log.Warn("", zap.Any("nodeCtx.downstream length", len(nodeCtx.downstream)))
}
if len(res) < downstreamLength {
// log.Println("node result length = ", len(res))
@ -104,7 +117,7 @@ func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int
defer func() {
err := recover()
if err != nil {
log.Println(err)
log.Warn(fmt.Sprintln(err))
}
}()
nodeCtx.inputChannels[inputChanIdx] <- msg
@ -126,7 +139,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
case msg, ok := <-channel:
if !ok {
// TODO: add status
log.Println("input channel closed")
log.Warn("input channel closed")
return
}
nodeCtx.inputMessages[i] = msg
@ -155,7 +168,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
return
case msg, ok := <-channel:
if !ok {
log.Println("input channel closed")
log.Warn("input channel closed")
return
}
nodeCtx.inputMessages[i] = msg