Add flow graph

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/4973/head^2
bigsheeper 2020-11-02 16:44:54 +08:00 committed by yefu.chen
parent 3e596fa474
commit 8020dc2256
4 changed files with 467 additions and 0 deletions

View File

@ -0,0 +1,97 @@
package flowgraph
import (
"context"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
"sync"
)
type Timestamp = typeutil.Timestamp
type flowGraphStates struct {
startTick Timestamp
numActiveTasks map[string]int64
numCompletedTasks map[string]int64
}
type TimeTickedFlowGraph struct {
ctx context.Context
states *flowGraphStates
nodeCtx map[string]*nodeCtx
}
func (fg *TimeTickedFlowGraph) AddNode(node *Node) {
nodeName := (*node).Name()
nodeCtx := nodeCtx{
node: node,
inputChannels: make([]chan *Msg, 0),
downstreamInputChanIdx: make(map[string]int),
}
fg.nodeCtx[nodeName] = &nodeCtx
}
func (fg *TimeTickedFlowGraph) SetEdges(nodeName string, in []string, out []string) error {
currentNode, ok := fg.nodeCtx[nodeName]
if !ok {
errMsg := "Cannot find node:" + nodeName
return errors.New(errMsg)
}
// init current node's downstream
currentNode.downstream = make([]*nodeCtx, len(out))
// set in nodes
for i, inNodeName := range in {
inNode, ok := fg.nodeCtx[inNodeName]
if !ok {
errMsg := "Cannot find in node:" + inNodeName
return errors.New(errMsg)
}
inNode.downstreamInputChanIdx[nodeName] = i
}
// set out nodes
for i, n := range out {
outNode, ok := fg.nodeCtx[n]
if !ok {
errMsg := "Cannot find out node:" + n
return errors.New(errMsg)
}
maxQueueLength := (*outNode.node).MaxQueueLength()
outNode.inputChannels = append(outNode.inputChannels, make(chan *Msg, maxQueueLength))
currentNode.downstream[i] = outNode
}
return nil
}
func (fg *TimeTickedFlowGraph) Start() {
wg := sync.WaitGroup{}
for _, v := range fg.nodeCtx {
wg.Add(1)
go v.Start(fg.ctx, &wg)
}
wg.Wait()
}
func (fg *TimeTickedFlowGraph) Close() error {
for _, v := range fg.nodeCtx {
v.Close()
}
return nil
}
func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph {
flowGraph := TimeTickedFlowGraph{
ctx: ctx,
states: &flowGraphStates{
startTick: 0,
numActiveTasks: make(map[string]int64),
numCompletedTasks: make(map[string]int64),
},
nodeCtx: make(map[string]*nodeCtx),
}
return &flowGraph
}

View File

@ -0,0 +1,246 @@
package flowgraph
import (
"context"
"fmt"
"log"
"math"
"math/rand"
"sync"
"testing"
"time"
)
const ctxTimeInMillisecond = 3000
type nodeA struct {
baseNode
a float64
}
type nodeB struct {
baseNode
b float64
}
type nodeC struct {
baseNode
c float64
}
type nodeD struct {
baseNode
d float64
resChan chan float64
}
type intMsg struct {
num float64
t Timestamp
}
func (m *intMsg) TimeTick() Timestamp {
return m.t
}
func (m *intMsg) DownStreamNodeIdx() int32 {
return 1
}
func intMsg2Msg(in []*intMsg) []*Msg {
out := make([]*Msg, 0)
for _, msg := range in {
var m Msg = msg
out = append(out, &m)
}
return out
}
func msg2IntMsg(in []*Msg) []*intMsg {
out := make([]*intMsg, 0)
for _, msg := range in {
out = append(out, (*msg).(*intMsg))
}
return out
}
func (a *nodeA) Name() string {
return "NodeA"
}
func (a *nodeA) Operate(in []*Msg) []*Msg {
return append(in, in...)
}
func (b *nodeB) Name() string {
return "NodeB"
}
func (b *nodeB) Operate(in []*Msg) []*Msg {
messages := make([]*intMsg, 0)
for _, msg := range msg2IntMsg(in) {
messages = append(messages, &intMsg{
num: math.Pow(msg.num, 2),
})
}
return intMsg2Msg(messages)
}
func (c *nodeC) Name() string {
return "NodeC"
}
func (c *nodeC) Operate(in []*Msg) []*Msg {
messages := make([]*intMsg, 0)
for _, msg := range msg2IntMsg(in) {
messages = append(messages, &intMsg{
num: math.Sqrt(msg.num),
})
}
return intMsg2Msg(messages)
}
func (d *nodeD) Name() string {
return "NodeD"
}
func (d *nodeD) Operate(in []*Msg) []*Msg {
messages := make([]*intMsg, 0)
outLength := len(in) / 2
inMessages := msg2IntMsg(in)
for i := 0; i < outLength; i++ {
var msg = &intMsg{
num: inMessages[i].num + inMessages[i+outLength].num,
}
messages = append(messages, msg)
}
d.d = messages[0].num
d.resChan <- d.d
fmt.Println("flow graph result:", d.d)
return intMsg2Msg(messages)
}
func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
for {
select {
case <-ctx.Done():
return
default:
time.Sleep(time.Millisecond * time.Duration(500))
var num = float64(rand.Int() % 100)
var msg Msg = &intMsg{num: num}
a := nodeA{}
fg.nodeCtx[a.Name()].inputChannels[0] <- &msg
fmt.Println("send number", num, "to node", a.Name())
res, ok := receiveResult(ctx, fg)
if !ok {
return
}
// assert result
if res != math.Pow(num, 2)+math.Sqrt(num) {
fmt.Println(res)
fmt.Println(math.Pow(num, 2) + math.Sqrt(num))
panic("wrong answer")
}
}
}
}
func receiveResultFromNodeD(res *float64, fg *TimeTickedFlowGraph, wg *sync.WaitGroup) {
d := nodeD{}
node := fg.nodeCtx[d.Name()]
nd, ok := (*node.node).(*nodeD)
if !ok {
log.Fatal("not nodeD type")
}
*res = <-nd.resChan
wg.Done()
}
func receiveResult(ctx context.Context, fg *TimeTickedFlowGraph) (float64, bool) {
d := nodeD{}
node := fg.nodeCtx[d.Name()]
nd, ok := (*node.node).(*nodeD)
if !ok {
log.Fatal("not nodeD type")
}
select {
case <-ctx.Done():
return 0, false
case res := <-nd.resChan:
return res, true
}
}
func TestTimeTickedFlowGraph_Start(t *testing.T) {
duration := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, _ := context.WithDeadline(context.Background(), duration)
fg := NewTimeTickedFlowGraph(ctx)
var a Node = &nodeA{
baseNode: baseNode{
maxQueueLength: maxQueueLength,
},
}
var b Node = &nodeB{
baseNode: baseNode{
maxQueueLength: maxQueueLength,
},
}
var c Node = &nodeC{
baseNode: baseNode{
maxQueueLength: maxQueueLength,
},
}
var d Node = &nodeD{
baseNode: baseNode{
maxQueueLength: maxQueueLength,
},
resChan: make(chan float64),
}
fg.AddNode(&a)
fg.AddNode(&b)
fg.AddNode(&c)
fg.AddNode(&d)
var err = fg.SetEdges(a.Name(),
[]string{},
[]string{b.Name(), c.Name()},
)
if err != nil {
log.Fatal("set edges failed")
}
err = fg.SetEdges(b.Name(),
[]string{a.Name()},
[]string{d.Name()},
)
if err != nil {
log.Fatal("set edges failed")
}
err = fg.SetEdges(c.Name(),
[]string{a.Name()},
[]string{d.Name()},
)
if err != nil {
log.Fatal("set edges failed")
}
err = fg.SetEdges(d.Name(),
[]string{b.Name(), c.Name()},
[]string{},
)
if err != nil {
log.Fatal("set edges failed")
}
// init node A
nodeCtxA := fg.nodeCtx[a.Name()]
nodeCtxA.inputChannels = []chan *Msg{make(chan *Msg, 10)}
go fg.Start()
sendMsgFromCmd(ctx, fg)
}

View File

@ -0,0 +1,6 @@
package flowgraph
type Msg interface {
TimeTick() Timestamp
DownStreamNodeIdx() int32
}

View File

@ -0,0 +1,118 @@
package flowgraph
import (
"context"
"fmt"
"sync"
)
const maxQueueLength = 1024
type Node interface {
Name() string
MaxQueueLength() int32
MaxParallelism() int32
SetPipelineStates(states *flowGraphStates)
Operate(in []*Msg) []*Msg
}
type baseNode struct {
maxQueueLength int32
maxParallelism int32
graphStates *flowGraphStates
}
type nodeCtx struct {
node *Node
inputChannels []chan *Msg
inputMessages [][]*Msg
downstream []*nodeCtx
downstreamInputChanIdx map[string]int
}
func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
for {
select {
case <-ctx.Done():
wg.Done()
return
default:
if !nodeCtx.allUpstreamDone() {
continue
}
nodeCtx.getMessagesFromChannel()
// inputs from inputsMessages for Operate
inputs := make([]*Msg, 0)
for i := 0; i < len(nodeCtx.inputMessages); i++ {
inputs = append(inputs, nodeCtx.inputMessages[i]...)
}
n := *nodeCtx.node
res := n.Operate(inputs)
wg := sync.WaitGroup{}
for i := 0; i < len(nodeCtx.downstreamInputChanIdx); i++ {
wg.Add(1)
go nodeCtx.downstream[i].ReceiveMsg(&wg, res[i], nodeCtx.downstreamInputChanIdx[(*nodeCtx.downstream[i].node).Name()])
}
wg.Wait()
}
}
}
func (nodeCtx *nodeCtx) Close() {
for _, channel := range nodeCtx.inputChannels {
close(channel)
}
}
func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg *Msg, inputChanIdx int) {
nodeCtx.inputChannels[inputChanIdx] <- msg
fmt.Println("node:", (*nodeCtx.node).Name(), "receive to input channel ", inputChanIdx)
wg.Done()
}
func (nodeCtx *nodeCtx) allUpstreamDone() bool {
inputsNum := len(nodeCtx.inputChannels)
hasInputs := 0
for i := 0; i < inputsNum; i++ {
channel := nodeCtx.inputChannels[i]
if len(channel) > 0 {
hasInputs++
}
}
return hasInputs == inputsNum
}
func (nodeCtx *nodeCtx) getMessagesFromChannel() {
inputsNum := len(nodeCtx.inputChannels)
nodeCtx.inputMessages = make([][]*Msg, inputsNum)
// init inputMessages,
// receive messages from inputChannels,
// and move them to inputMessages.
for i := 0; i < inputsNum; i++ {
nodeCtx.inputMessages[i] = make([]*Msg, 0)
channel := nodeCtx.inputChannels[i]
msg := <-channel
nodeCtx.inputMessages[i] = append(nodeCtx.inputMessages[i], msg)
}
}
func (node *baseNode) MaxQueueLength() int32 {
return node.maxQueueLength
}
func (node *baseNode) MaxParallelism() int32 {
return node.maxParallelism
}
func (node *baseNode) SetMaxQueueLength(n int32) {
node.maxQueueLength = n
}
func (node *baseNode) SetMaxParallelism(n int32) {
node.maxParallelism = n
}
func (node *baseNode) SetPipelineStates(states *flowGraphStates) {
node.graphStates = states
}