mirror of https://github.com/milvus-io/milvus.git
parent
3e596fa474
commit
8020dc2256
|
@ -0,0 +1,97 @@
|
|||
package flowgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Timestamp is an alias of typeutil.Timestamp so that code in this package
// can use the type without the package qualifier.
type Timestamp = typeutil.Timestamp
|
||||
|
||||
type flowGraphStates struct {
|
||||
startTick Timestamp
|
||||
numActiveTasks map[string]int64
|
||||
numCompletedTasks map[string]int64
|
||||
}
|
||||
|
||||
type TimeTickedFlowGraph struct {
|
||||
ctx context.Context
|
||||
states *flowGraphStates
|
||||
nodeCtx map[string]*nodeCtx
|
||||
}
|
||||
|
||||
func (fg *TimeTickedFlowGraph) AddNode(node *Node) {
|
||||
nodeName := (*node).Name()
|
||||
nodeCtx := nodeCtx{
|
||||
node: node,
|
||||
inputChannels: make([]chan *Msg, 0),
|
||||
downstreamInputChanIdx: make(map[string]int),
|
||||
}
|
||||
fg.nodeCtx[nodeName] = &nodeCtx
|
||||
}
|
||||
|
||||
func (fg *TimeTickedFlowGraph) SetEdges(nodeName string, in []string, out []string) error {
|
||||
currentNode, ok := fg.nodeCtx[nodeName]
|
||||
if !ok {
|
||||
errMsg := "Cannot find node:" + nodeName
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// init current node's downstream
|
||||
currentNode.downstream = make([]*nodeCtx, len(out))
|
||||
|
||||
// set in nodes
|
||||
for i, inNodeName := range in {
|
||||
inNode, ok := fg.nodeCtx[inNodeName]
|
||||
if !ok {
|
||||
errMsg := "Cannot find in node:" + inNodeName
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
inNode.downstreamInputChanIdx[nodeName] = i
|
||||
}
|
||||
|
||||
// set out nodes
|
||||
for i, n := range out {
|
||||
outNode, ok := fg.nodeCtx[n]
|
||||
if !ok {
|
||||
errMsg := "Cannot find out node:" + n
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
maxQueueLength := (*outNode.node).MaxQueueLength()
|
||||
outNode.inputChannels = append(outNode.inputChannels, make(chan *Msg, maxQueueLength))
|
||||
currentNode.downstream[i] = outNode
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fg *TimeTickedFlowGraph) Start() {
|
||||
wg := sync.WaitGroup{}
|
||||
for _, v := range fg.nodeCtx {
|
||||
wg.Add(1)
|
||||
go v.Start(fg.ctx, &wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (fg *TimeTickedFlowGraph) Close() error {
|
||||
for _, v := range fg.nodeCtx {
|
||||
v.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph {
|
||||
flowGraph := TimeTickedFlowGraph{
|
||||
ctx: ctx,
|
||||
states: &flowGraphStates{
|
||||
startTick: 0,
|
||||
numActiveTasks: make(map[string]int64),
|
||||
numCompletedTasks: make(map[string]int64),
|
||||
},
|
||||
nodeCtx: make(map[string]*nodeCtx),
|
||||
}
|
||||
|
||||
return &flowGraph
|
||||
}
|
|
@ -0,0 +1,246 @@
|
|||
package flowgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ctxTimeInMillisecond is the lifetime, in milliseconds, of the context used
// by TestTimeTickedFlowGraph_Start.
const ctxTimeInMillisecond = 3000
|
||||
|
||||
type nodeA struct {
|
||||
baseNode
|
||||
a float64
|
||||
}
|
||||
|
||||
type nodeB struct {
|
||||
baseNode
|
||||
b float64
|
||||
}
|
||||
|
||||
type nodeC struct {
|
||||
baseNode
|
||||
c float64
|
||||
}
|
||||
|
||||
type nodeD struct {
|
||||
baseNode
|
||||
d float64
|
||||
resChan chan float64
|
||||
}
|
||||
|
||||
type intMsg struct {
|
||||
num float64
|
||||
t Timestamp
|
||||
}
|
||||
|
||||
func (m *intMsg) TimeTick() Timestamp {
|
||||
return m.t
|
||||
}
|
||||
|
||||
func (m *intMsg) DownStreamNodeIdx() int32 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func intMsg2Msg(in []*intMsg) []*Msg {
|
||||
out := make([]*Msg, 0)
|
||||
for _, msg := range in {
|
||||
var m Msg = msg
|
||||
out = append(out, &m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func msg2IntMsg(in []*Msg) []*intMsg {
|
||||
out := make([]*intMsg, 0)
|
||||
for _, msg := range in {
|
||||
out = append(out, (*msg).(*intMsg))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (a *nodeA) Name() string {
|
||||
return "NodeA"
|
||||
}
|
||||
|
||||
func (a *nodeA) Operate(in []*Msg) []*Msg {
|
||||
return append(in, in...)
|
||||
}
|
||||
|
||||
func (b *nodeB) Name() string {
|
||||
return "NodeB"
|
||||
}
|
||||
|
||||
func (b *nodeB) Operate(in []*Msg) []*Msg {
|
||||
messages := make([]*intMsg, 0)
|
||||
for _, msg := range msg2IntMsg(in) {
|
||||
messages = append(messages, &intMsg{
|
||||
num: math.Pow(msg.num, 2),
|
||||
})
|
||||
}
|
||||
return intMsg2Msg(messages)
|
||||
}
|
||||
|
||||
func (c *nodeC) Name() string {
|
||||
return "NodeC"
|
||||
}
|
||||
|
||||
func (c *nodeC) Operate(in []*Msg) []*Msg {
|
||||
messages := make([]*intMsg, 0)
|
||||
for _, msg := range msg2IntMsg(in) {
|
||||
messages = append(messages, &intMsg{
|
||||
num: math.Sqrt(msg.num),
|
||||
})
|
||||
}
|
||||
return intMsg2Msg(messages)
|
||||
}
|
||||
|
||||
func (d *nodeD) Name() string {
|
||||
return "NodeD"
|
||||
}
|
||||
|
||||
func (d *nodeD) Operate(in []*Msg) []*Msg {
|
||||
messages := make([]*intMsg, 0)
|
||||
outLength := len(in) / 2
|
||||
inMessages := msg2IntMsg(in)
|
||||
for i := 0; i < outLength; i++ {
|
||||
var msg = &intMsg{
|
||||
num: inMessages[i].num + inMessages[i+outLength].num,
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
d.d = messages[0].num
|
||||
d.resChan <- d.d
|
||||
fmt.Println("flow graph result:", d.d)
|
||||
return intMsg2Msg(messages)
|
||||
}
|
||||
|
||||
func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
time.Sleep(time.Millisecond * time.Duration(500))
|
||||
var num = float64(rand.Int() % 100)
|
||||
var msg Msg = &intMsg{num: num}
|
||||
a := nodeA{}
|
||||
fg.nodeCtx[a.Name()].inputChannels[0] <- &msg
|
||||
fmt.Println("send number", num, "to node", a.Name())
|
||||
res, ok := receiveResult(ctx, fg)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// assert result
|
||||
if res != math.Pow(num, 2)+math.Sqrt(num) {
|
||||
fmt.Println(res)
|
||||
fmt.Println(math.Pow(num, 2) + math.Sqrt(num))
|
||||
panic("wrong answer")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func receiveResultFromNodeD(res *float64, fg *TimeTickedFlowGraph, wg *sync.WaitGroup) {
|
||||
d := nodeD{}
|
||||
node := fg.nodeCtx[d.Name()]
|
||||
nd, ok := (*node.node).(*nodeD)
|
||||
if !ok {
|
||||
log.Fatal("not nodeD type")
|
||||
}
|
||||
*res = <-nd.resChan
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
func receiveResult(ctx context.Context, fg *TimeTickedFlowGraph) (float64, bool) {
|
||||
d := nodeD{}
|
||||
node := fg.nodeCtx[d.Name()]
|
||||
nd, ok := (*node.node).(*nodeD)
|
||||
if !ok {
|
||||
log.Fatal("not nodeD type")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, false
|
||||
case res := <-nd.resChan:
|
||||
return res, true
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeTickedFlowGraph_Start(t *testing.T) {
|
||||
duration := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, _ := context.WithDeadline(context.Background(), duration)
|
||||
fg := NewTimeTickedFlowGraph(ctx)
|
||||
|
||||
var a Node = &nodeA{
|
||||
baseNode: baseNode{
|
||||
maxQueueLength: maxQueueLength,
|
||||
},
|
||||
}
|
||||
var b Node = &nodeB{
|
||||
baseNode: baseNode{
|
||||
maxQueueLength: maxQueueLength,
|
||||
},
|
||||
}
|
||||
var c Node = &nodeC{
|
||||
baseNode: baseNode{
|
||||
maxQueueLength: maxQueueLength,
|
||||
},
|
||||
}
|
||||
var d Node = &nodeD{
|
||||
baseNode: baseNode{
|
||||
maxQueueLength: maxQueueLength,
|
||||
},
|
||||
resChan: make(chan float64),
|
||||
}
|
||||
|
||||
fg.AddNode(&a)
|
||||
fg.AddNode(&b)
|
||||
fg.AddNode(&c)
|
||||
fg.AddNode(&d)
|
||||
|
||||
var err = fg.SetEdges(a.Name(),
|
||||
[]string{},
|
||||
[]string{b.Name(), c.Name()},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("set edges failed")
|
||||
}
|
||||
|
||||
err = fg.SetEdges(b.Name(),
|
||||
[]string{a.Name()},
|
||||
[]string{d.Name()},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("set edges failed")
|
||||
}
|
||||
|
||||
err = fg.SetEdges(c.Name(),
|
||||
[]string{a.Name()},
|
||||
[]string{d.Name()},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("set edges failed")
|
||||
}
|
||||
|
||||
err = fg.SetEdges(d.Name(),
|
||||
[]string{b.Name(), c.Name()},
|
||||
[]string{},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("set edges failed")
|
||||
}
|
||||
|
||||
// init node A
|
||||
nodeCtxA := fg.nodeCtx[a.Name()]
|
||||
nodeCtxA.inputChannels = []chan *Msg{make(chan *Msg, 10)}
|
||||
|
||||
go fg.Start()
|
||||
|
||||
sendMsgFromCmd(ctx, fg)
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package flowgraph
|
||||
|
||||
type Msg interface {
|
||||
TimeTick() Timestamp
|
||||
DownStreamNodeIdx() int32
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package flowgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// maxQueueLength is the queue length the tests assign to baseNode; SetEdges
// uses a node's MaxQueueLength as the buffer size of its input channels.
const maxQueueLength = 1024
|
||||
|
||||
type Node interface {
|
||||
Name() string
|
||||
MaxQueueLength() int32
|
||||
MaxParallelism() int32
|
||||
SetPipelineStates(states *flowGraphStates)
|
||||
Operate(in []*Msg) []*Msg
|
||||
}
|
||||
|
||||
type baseNode struct {
|
||||
maxQueueLength int32
|
||||
maxParallelism int32
|
||||
graphStates *flowGraphStates
|
||||
}
|
||||
|
||||
type nodeCtx struct {
|
||||
node *Node
|
||||
inputChannels []chan *Msg
|
||||
inputMessages [][]*Msg
|
||||
downstream []*nodeCtx
|
||||
downstreamInputChanIdx map[string]int
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
wg.Done()
|
||||
return
|
||||
default:
|
||||
if !nodeCtx.allUpstreamDone() {
|
||||
continue
|
||||
}
|
||||
nodeCtx.getMessagesFromChannel()
|
||||
// inputs from inputsMessages for Operate
|
||||
inputs := make([]*Msg, 0)
|
||||
for i := 0; i < len(nodeCtx.inputMessages); i++ {
|
||||
inputs = append(inputs, nodeCtx.inputMessages[i]...)
|
||||
}
|
||||
n := *nodeCtx.node
|
||||
res := n.Operate(inputs)
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < len(nodeCtx.downstreamInputChanIdx); i++ {
|
||||
wg.Add(1)
|
||||
go nodeCtx.downstream[i].ReceiveMsg(&wg, res[i], nodeCtx.downstreamInputChanIdx[(*nodeCtx.downstream[i].node).Name()])
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) Close() {
|
||||
for _, channel := range nodeCtx.inputChannels {
|
||||
close(channel)
|
||||
}
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg *Msg, inputChanIdx int) {
|
||||
nodeCtx.inputChannels[inputChanIdx] <- msg
|
||||
fmt.Println("node:", (*nodeCtx.node).Name(), "receive to input channel ", inputChanIdx)
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) allUpstreamDone() bool {
|
||||
inputsNum := len(nodeCtx.inputChannels)
|
||||
hasInputs := 0
|
||||
for i := 0; i < inputsNum; i++ {
|
||||
channel := nodeCtx.inputChannels[i]
|
||||
if len(channel) > 0 {
|
||||
hasInputs++
|
||||
}
|
||||
}
|
||||
return hasInputs == inputsNum
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) getMessagesFromChannel() {
|
||||
inputsNum := len(nodeCtx.inputChannels)
|
||||
nodeCtx.inputMessages = make([][]*Msg, inputsNum)
|
||||
|
||||
// init inputMessages,
|
||||
// receive messages from inputChannels,
|
||||
// and move them to inputMessages.
|
||||
for i := 0; i < inputsNum; i++ {
|
||||
nodeCtx.inputMessages[i] = make([]*Msg, 0)
|
||||
channel := nodeCtx.inputChannels[i]
|
||||
msg := <-channel
|
||||
nodeCtx.inputMessages[i] = append(nodeCtx.inputMessages[i], msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (node *baseNode) MaxQueueLength() int32 {
|
||||
return node.maxQueueLength
|
||||
}
|
||||
|
||||
func (node *baseNode) MaxParallelism() int32 {
|
||||
return node.maxParallelism
|
||||
}
|
||||
|
||||
func (node *baseNode) SetMaxQueueLength(n int32) {
|
||||
node.maxQueueLength = n
|
||||
}
|
||||
|
||||
func (node *baseNode) SetMaxParallelism(n int32) {
|
||||
node.maxParallelism = n
|
||||
}
|
||||
|
||||
func (node *baseNode) SetPipelineStates(states *flowGraphStates) {
|
||||
node.graphStates = states
|
||||
}
|
Loading…
Reference in New Issue