mirror of https://github.com/milvus-io/milvus.git
Fix data race in flow graph (#6946)
* Fix data race in flow graph Signed-off-by: bigsheeper <yihao.dai@zilliz.com> * add cancel func to flowgraph Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/6985/head
parent
c8a1f780c1
commit
07cc449fbf
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
type TimeTickedFlowGraph struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
nodeCtx map[NodeName]*nodeCtx
|
||||
}
|
||||
|
||||
|
@ -89,11 +90,14 @@ func (fg *TimeTickedFlowGraph) Close() {
|
|||
// }
|
||||
v.Close()
|
||||
}
|
||||
fg.cancel()
|
||||
}
|
||||
|
||||
func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
flowGraph := TimeTickedFlowGraph{
|
||||
ctx: ctx,
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
nodeCtx: make(map[string]*nodeCtx),
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ func (inNode *InputNode) IsInputNode() bool {
|
|||
}
|
||||
|
||||
func (inNode *InputNode) Close() {
|
||||
(*inNode.inStream).Close()
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (inNode *InputNode) Name() string {
|
||||
|
|
|
@ -14,9 +14,12 @@ package flowgraph
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
|
@ -49,7 +52,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||
// fmt.Println("start InputNode.inStream")
|
||||
inStream, ok := nodeCtx.node.(*InputNode)
|
||||
if !ok {
|
||||
log.Fatal("Invalid inputNode")
|
||||
log.Error("Invalid inputNode")
|
||||
}
|
||||
(*inStream.inStream).Start()
|
||||
}
|
||||
|
@ -57,6 +60,16 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if nodeCtx.node.IsInputNode() {
|
||||
inStream, ok := nodeCtx.node.(*InputNode)
|
||||
if !ok {
|
||||
log.Error("Invalid inputNode")
|
||||
}
|
||||
(*inStream.inStream).Close()
|
||||
log.Debug("message stream closed",
|
||||
zap.Any("node name", inStream.name),
|
||||
)
|
||||
}
|
||||
wg.Done()
|
||||
//fmt.Println(nodeCtx.node.Name(), "closed")
|
||||
return
|
||||
|
@ -74,7 +87,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||
|
||||
downstreamLength := len(nodeCtx.downstreamInputChanIdx)
|
||||
if len(nodeCtx.downstream) < downstreamLength {
|
||||
log.Println("nodeCtx.downstream length = ", len(nodeCtx.downstream))
|
||||
log.Warn("", zap.Any("nodeCtx.downstream length", len(nodeCtx.downstream)))
|
||||
}
|
||||
if len(res) < downstreamLength {
|
||||
// log.Println("node result length = ", len(res))
|
||||
|
@ -104,7 +117,7 @@ func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int
|
|||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Warn(fmt.Sprintln(err))
|
||||
}
|
||||
}()
|
||||
nodeCtx.inputChannels[inputChanIdx] <- msg
|
||||
|
@ -126,7 +139,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
|
|||
case msg, ok := <-channel:
|
||||
if !ok {
|
||||
// TODO: add status
|
||||
log.Println("input channel closed")
|
||||
log.Warn("input channel closed")
|
||||
return
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
|
@ -155,7 +168,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
|
|||
return
|
||||
case msg, ok := <-channel:
|
||||
if !ok {
|
||||
log.Println("input channel closed")
|
||||
log.Warn("input channel closed")
|
||||
return
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
|
|
Loading…
Reference in New Issue