Add msgDispatcher to support sharing msgs for different vChannel (#21917)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
bigsheeper 2023-02-13 16:38:33 +08:00 committed by GitHub
parent a2435cfc4f
commit d2667064bb
35 changed files with 1601 additions and 351 deletions
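
For orientation, here is a minimal usage sketch of the new msgdispatcher.Client introduced by this commit. It is not part of the diff; the function name, vchannel value, and control flow are illustrative only. Vchannels of the same pchannel registered through one client can end up sharing a single underlying consumer (the main dispatcher), which is the point of this change.

package example

import (
	"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

// consumeVChannel registers one vchannel with the shared dispatcher client and
// drains its per-vchannel output channel in the background.
func consumeVChannel(factory msgstream.Factory, nodeID int64, vchannel string) error {
	client := msgdispatcher.NewClient(factory, typeutil.DataNodeRole, nodeID)
	// pos == nil: no checkpoint to seek to, so subscribe from the earliest position.
	input, err := client.Register(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
	if err != nil {
		return err
	}
	go func() {
		for pack := range input { // the channel is closed by Deregister
			_ = pack // each *msgstream.MsgPack contains only msgs belonging to this vchannel
		}
	}()
	// ... later, when this vchannel is released:
	client.Deregister(vchannel)
	return nil
}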

View File

@ -39,13 +39,13 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
allocator2 "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
@ -127,7 +127,8 @@ type DataNode struct {
closer io.Closer
factory dependency.Factory
dispClient msgdispatcher.Client
factory dependency.Factory
}
// NewDataNode will return a DataNode with abnormal state.
@ -249,6 +250,9 @@ func (node *DataNode) Init() error {
}
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, paramtable.GetNodeID())
log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", paramtable.GetNodeID()))
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
if err != nil {
log.Error("failed to create id allocator",

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/datapb"
@ -49,6 +50,7 @@ type dataSyncService struct {
resendTTCh chan resendTTMsg // chan to ask for resending DataNode time tick message.
channel Channel // channel stores meta of channel
idAllocator allocatorInterface // id/timestamp allocator
dispClient msgdispatcher.Client
msFactory msgstream.Factory
collectionID UniqueID // collection id of vchan for which this data sync service serves
vchannelName string
@ -71,6 +73,7 @@ func newDataSyncService(ctx context.Context,
resendTTCh chan resendTTMsg,
channel Channel,
alloc allocatorInterface,
dispClient msgdispatcher.Client,
factory msgstream.Factory,
vchan *datapb.VchannelInfo,
clearSignal chan<- string,
@ -101,6 +104,7 @@ func newDataSyncService(ctx context.Context,
resendTTCh: resendTTCh,
channel: channel,
idAllocator: alloc,
dispClient: dispClient,
msFactory: factory,
collectionID: vchan.GetCollectionID(),
vchannelName: vchan.GetChannelName(),
@ -156,6 +160,7 @@ func (dsService *dataSyncService) close() {
if dsService.fg != nil {
log.Info("dataSyncService closing flowgraph", zap.Int64("collectionID", dsService.collectionID),
zap.String("vChanName", dsService.vchannelName))
dsService.dispClient.Deregister(dsService.vchannelName)
dsService.fg.Close()
metrics.DataNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(2) // timeTickChannel + deltaChannel
@ -287,7 +292,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro
}
var dmStreamNode Node
dmStreamNode, err = newDmInputNode(dsService.ctx, vchanInfo.GetSeekPosition(), c)
dmStreamNode, err = newDmInputNode(dsService.dispClient, vchanInfo.GetSeekPosition(), c)
if err != nil {
return err
}

View File

@ -33,6 +33,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@ -40,10 +41,15 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
var dataSyncServiceTestDir = "/tmp/milvus_test/data_sync_service"
func init() {
Params.Init()
}
func getVchanInfo(info *testInfo) *datapb.VchannelInfo {
var ufs []*datapb.SegmentInfo
var fs []*datapb.SegmentInfo
@ -160,12 +166,14 @@ func TestDataSyncService_newDataSyncService(te *testing.T) {
if test.channelNil {
channel = nil
}
dispClient := msgdispatcher.NewClient(test.inMsgFactory, typeutil.DataNodeRole, paramtable.GetNodeID())
ds, err := newDataSyncService(ctx,
make(chan flushMsg),
make(chan resendTTMsg),
channel,
NewAllocatorFactory(),
dispClient,
test.inMsgFactory,
getVchanInfo(test),
make(chan string),
@ -217,6 +225,7 @@ func TestDataSyncService_Start(t *testing.T) {
allocFactory := NewAllocatorFactory(1)
factory := dependency.NewDefaultFactory(true)
dispClient := msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
defer os.RemoveAll("/tmp/milvus")
paramtable.Get().Save(Params.DataNodeCfg.FlushInsertBufferSize.Key, "1")
@ -270,7 +279,7 @@ func TestDataSyncService_Start(t *testing.T) {
},
}
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, dispClient, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
assert.Nil(t, err)
sync.flushListener = make(chan *segmentFlushPack)
@ -399,6 +408,7 @@ func TestDataSyncService_Close(t *testing.T) {
allocFactory = NewAllocatorFactory(1)
factory = dependency.NewDefaultFactory(true)
dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
mockDataCoord = &DataCoordFactory{}
)
mockDataCoord.UserSegmentInfo = map[int64]*datapb.SegmentInfo{
@ -421,7 +431,7 @@ func TestDataSyncService_Close(t *testing.T) {
paramtable.Get().Reset(Params.DataNodeCfg.FlushInsertBufferSize.Key)
channel := newChannel(insertChannelName, collMeta.ID, collMeta.GetSchema(), mockRootCoord, cm)
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), 0)
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, dispClient, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), 0)
assert.Nil(t, err)
sync.flushListener = make(chan *segmentFlushPack, 10)

View File

@ -220,6 +220,12 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
for i := int64(0); i < dmsg.NumRows; i++ {
dmsg.HashValues = append(dmsg.HashValues, uint32(0))
}
deltaVChannel, err := funcutil.ConvertChannelName(dmsg.ShardName, Params.CommonCfg.RootCoordDml.GetValue(), Params.CommonCfg.RootCoordDelta.GetValue())
if err != nil {
log.Error("convert dmlVChannel to deltaVChannel failed", zap.String("vchannel", ddn.vChannelName), zap.Error(err))
panic(err)
}
dmsg.ShardName = deltaVChannel
forwardMsgs = append(forwardMsgs, dmsg)
if dmsg.CollectionID != ddn.collectionID {
log.Warn("filter invalid DeleteMsg, collection mis-match",
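
For context, a minimal sketch of the channel-name conversion used above, assuming the default dml/delta channel prefixes (the real code reads them from Params.CommonCfg.RootCoordDml and Params.CommonCfg.RootCoordDelta, so the literals below are illustrative only):

package example

import "github.com/milvus-io/milvus/internal/util/funcutil"

// toDeltaVChannel mirrors the conversion in ddNode: it swaps the dml prefix for
// the delta prefix inside a virtual channel name, e.g.
// "by-dev-rootcoord-dml_0_123v0" -> "by-dev-rootcoord-delta_0_123v0".
func toDeltaVChannel(dmlVChannel string) (string, error) {
	return funcutil.ConvertChannelName(dmlVChannel,
		"by-dev-rootcoord-dml", "by-dev-rootcoord-delta")
}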

View File

@ -278,6 +278,7 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) {
},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete},
ShardName: "by-dev-rootcoord-dml-mock-0",
CollectionID: test.inMsgCollID,
},
}

View File

@ -17,7 +17,6 @@
package datanode
import (
"context"
"fmt"
"time"
@ -25,10 +24,11 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/flowgraph"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
@ -37,50 +37,32 @@ import (
// DmInputNode receives messages from message streams, packs messages between two timeticks, and passes all
// messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is
// flowgraph ddNode.
func newDmInputNode(ctx context.Context, seekPos *internalpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) {
// subName should be unique, since pchannelName is shared among several collections
// use vchannel in case of reuse pchannel for same collection
consumeSubName := fmt.Sprintf("%s-%d-%s", Params.CommonCfg.DataNodeSubName.GetValue(), paramtable.GetNodeID(), dmNodeConfig.vChannelName)
insertStream, err := dmNodeConfig.msFactory.NewTtMsgStream(ctx)
if err != nil {
return nil, err
}
// MsgStream needs a physical channel name, but the channel name in the seek position from DataCoord
// is a virtual channel name, so we need to convert the vchannel name into a pchannel name here.
pchannelName := funcutil.ToPhysicalChannel(dmNodeConfig.vChannelName)
if seekPos != nil {
insertStream.AsConsumer([]string{pchannelName}, consumeSubName, mqwrapper.SubscriptionPositionUnknown)
seekPos.ChannelName = pchannelName
cpTs, _ := tsoutil.ParseTS(seekPos.Timestamp)
start := time.Now()
log.Info("datanode begin to seek",
zap.ByteString("seek msgID", seekPos.GetMsgID()),
zap.String("pchannel", seekPos.GetChannelName()),
zap.String("vchannel", dmNodeConfig.vChannelName),
zap.Time("position", cpTs),
zap.Duration("tsLag", time.Since(cpTs)),
zap.Int64("collection ID", dmNodeConfig.collectionID))
err = insertStream.Seek([]*internalpb.MsgPosition{seekPos})
func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *internalpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) {
log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("collection ID", dmNodeConfig.collectionID),
zap.String("vchannel", dmNodeConfig.vChannelName))
var err error
var input <-chan *msgstream.MsgPack
if seekPos != nil && len(seekPos.MsgID) != 0 {
input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown)
if err != nil {
return nil, err
}
log.Info("datanode seek successfully",
zap.ByteString("seek msgID", seekPos.GetMsgID()),
zap.String("pchannel", seekPos.GetChannelName()),
zap.String("vchannel", dmNodeConfig.vChannelName),
zap.Time("position", cpTs),
zap.Duration("tsLag", time.Since(cpTs)),
zap.Int64("collection ID", dmNodeConfig.collectionID),
zap.Duration("elapse", time.Since(start)))
log.Info("datanode seek successfully when register to msgDispatcher",
zap.ByteString("msgID", seekPos.GetMsgID()),
zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())),
zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
} else {
insertStream.AsConsumer([]string{pchannelName}, consumeSubName, mqwrapper.SubscriptionPositionEarliest)
input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest)
if err != nil {
return nil, err
}
log.Info("datanode consume successfully when register to msgDispatcher")
}
metrics.DataNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
log.Info("datanode AsConsumer", zap.String("physical channel", pchannelName), zap.String("subName", consumeSubName), zap.Int64("collection ID", dmNodeConfig.collectionID))
name := fmt.Sprintf("dmInputNode-data-%d-%s", dmNodeConfig.collectionID, dmNodeConfig.vChannelName)
node := flowgraph.NewInputNode(insertStream, name, dmNodeConfig.maxQueueLength, dmNodeConfig.maxParallelism,
node := flowgraph.NewInputNode(input, name, dmNodeConfig.maxQueueLength, dmNodeConfig.maxParallelism,
typeutil.DataNodeRole, paramtable.GetNodeID(), dmNodeConfig.collectionID, metrics.AllLabel)
return node, nil
}

View File

@ -21,12 +21,14 @@ import (
"errors"
"testing"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type mockMsgStreamFactory struct {
@ -93,7 +95,10 @@ func (mtm *mockTtMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID,
}
func TestNewDmInputNode(t *testing.T) {
ctx := context.Background()
_, err := newDmInputNode(ctx, new(internalpb.MsgPosition), &nodeConfig{msFactory: &mockMsgStreamFactory{}})
client := msgdispatcher.NewClient(&mockMsgStreamFactory{}, typeutil.DataNodeRole, paramtable.GetNodeID())
_, err := newDmInputNode(client, new(internalpb.MsgPosition), &nodeConfig{
msFactory: &mockMsgStreamFactory{},
vChannelName: "mock_vchannel_0",
})
assert.Nil(t, err)
}

View File

@ -48,7 +48,7 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo
var alloc allocatorInterface = newAllocator(dn.rootCoord)
dataSyncService, err := newDataSyncService(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel,
alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
alloc, dn.dispClient, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
if err != nil {
log.Warn("new data sync service fail", zap.String("vChannelName", vchan.GetChannelName()), zap.Error(err))
return err

View File

@ -27,29 +27,29 @@ import (
"sync"
"time"
"github.com/milvus-io/milvus/internal/util/metautil"
"go.uber.org/zap"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
s "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
s "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/metautil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const ctxTimeInMillisecond = 5000
@ -81,6 +81,7 @@ func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNod
factory := dependency.NewDefaultFactory(true)
node := NewDataNode(ctx, factory)
node.SetSession(&sessionutil.Session{ServerID: 1})
node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
rc := &RootCoordFactory{
ID: 0,

View File

@ -0,0 +1,93 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type (
Pos = internalpb.MsgPosition
MsgPack = msgstream.MsgPack
SubPos = mqwrapper.SubscriptionInitialPosition
)
type Client interface {
Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
Deregister(vchannel string)
}
var _ Client = (*client)(nil)
type client struct {
role string
nodeID int64
managers *typeutil.ConcurrentMap[string, DispatcherManager] // pchannel -> DispatcherManager
factory msgstream.Factory
}
func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
return &client{
role: role,
nodeID: nodeID,
managers: typeutil.NewConcurrentMap[string, DispatcherManager](),
factory: factory,
}
}
func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
pchannel := funcutil.ToPhysicalChannel(vchannel)
managers, ok := c.managers.Get(pchannel)
if !ok {
managers = NewDispatcherManager(pchannel, c.role, c.nodeID, c.factory)
go managers.Run()
old, exist := c.managers.GetOrInsert(pchannel, managers)
if exist {
managers.Close()
managers = old
}
}
ch, err := managers.Add(vchannel, pos, subPos)
if err != nil {
log.Error("register failed", zap.Error(err))
return nil, err
}
log.Info("register done")
return ch, nil
}
func (c *client) Deregister(vchannel string) {
pchannel := funcutil.ToPhysicalChannel(vchannel)
if managers, ok := c.managers.Get(pchannel); ok {
managers.Remove(vchannel)
if managers.Num() == 0 {
managers.Close()
c.managers.GetAndRemove(pchannel)
}
log.Info("deregister done", zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
}
}

View File

@ -0,0 +1,58 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"fmt"
"math/rand"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestClient(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
assert.NotNil(t, client)
_, err := client.Register("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.NotPanics(t, func() {
client.Deregister("mock_vchannel_0")
})
}
func TestClient_Concurrency(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
assert.NotNil(t, client)
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
wg.Add(1)
go func() {
for j := 0; j < 10; j++ {
_, err := client.Register(vchannel, nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
client.Deregister(vchannel)
}
wg.Done()
}()
}
wg.Wait()
}

View File

@ -0,0 +1,244 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type signal int32
const (
start signal = 0
pause signal = 1
resume signal = 2
terminate signal = 3
)
var signalString = map[int32]string{
0: "start",
1: "pause",
2: "resume",
3: "terminate",
}
func (s signal) String() string {
return signalString[int32(s)]
}
type Dispatcher struct {
done chan struct{}
wg sync.WaitGroup
once sync.Once
isMain bool // indicates if it's a main dispatcher
pchannel string
curTs atomic.Uint64
lagNotifyChan chan struct{}
lagTargets *sync.Map // vchannel -> *target
// vchannel -> *target, lock free since we guarantee that
// it's modified only after the dispatcher is paused or terminated
targets map[string]*target
stream msgstream.MsgStream
}
func NewDispatcher(factory msgstream.Factory,
isMain bool,
pchannel string,
position *Pos,
subName string,
subPos SubPos,
lagNotifyChan chan struct{},
lagTargets *sync.Map,
) (*Dispatcher, error) {
log := log.With(zap.String("pchannel", pchannel),
zap.String("subName", subName), zap.Bool("isMain", isMain))
log.Info("creating dispatcher...")
stream, err := factory.NewTtMsgStream(context.Background())
if err != nil {
return nil, err
}
if position != nil && len(position.MsgID) != 0 {
position.ChannelName = funcutil.ToPhysicalChannel(position.ChannelName)
stream.AsConsumer([]string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown)
err = stream.Seek([]*Pos{position})
if err != nil {
log.Error("seek failed", zap.Error(err))
return nil, err
}
posTime := tsoutil.PhysicalTime(position.GetTimestamp())
log.Info("seek successfully", zap.Time("posTime", posTime),
zap.Duration("tsLag", time.Since(posTime)))
} else {
stream.AsConsumer([]string{pchannel}, subName, subPos)
log.Info("asConsumer successfully")
}
d := &Dispatcher{
done: make(chan struct{}, 1),
isMain: isMain,
pchannel: pchannel,
lagNotifyChan: lagNotifyChan,
lagTargets: lagTargets,
targets: make(map[string]*target),
stream: stream,
}
return d, nil
}
func (d *Dispatcher) CurTs() typeutil.Timestamp {
return d.curTs.Load()
}
func (d *Dispatcher) AddTarget(t *target) {
log := log.With(zap.String("vchannel", t.vchannel), zap.Bool("isMain", d.isMain))
if _, ok := d.targets[t.vchannel]; ok {
log.Warn("target exists")
return
}
d.targets[t.vchannel] = t
log.Info("add new target")
}
func (d *Dispatcher) GetTarget(vchannel string) (*target, error) {
if t, ok := d.targets[vchannel]; ok {
return t, nil
}
return nil, fmt.Errorf("cannot find target, vchannel=%s, isMain=%t", vchannel, d.isMain)
}
func (d *Dispatcher) CloseTarget(vchannel string) {
log := log.With(zap.String("vchannel", vchannel), zap.Bool("isMain", d.isMain))
if t, ok := d.targets[vchannel]; ok {
t.close()
delete(d.targets, vchannel)
log.Info("closed target")
} else {
log.Warn("target not exist")
}
}
func (d *Dispatcher) TargetNum() int {
return len(d.targets)
}
func (d *Dispatcher) Handle(signal signal) {
log := log.With(zap.String("pchannel", d.pchannel),
zap.String("signal", signal.String()), zap.Bool("isMain", d.isMain))
log.Info("get signal")
switch signal {
case start:
d.wg.Add(1)
go d.work()
case pause:
d.done <- struct{}{}
d.wg.Wait()
case resume:
d.wg.Add(1)
go d.work()
case terminate:
d.done <- struct{}{}
d.wg.Wait()
d.once.Do(func() {
d.stream.Close()
})
}
log.Info("handle signal done")
}
func (d *Dispatcher) work() {
log := log.With(zap.String("pchannel", d.pchannel), zap.Bool("isMain", d.isMain))
log.Info("begin to work")
defer d.wg.Done()
for {
select {
case <-d.done:
log.Info("stop working")
return
case pack := <-d.stream.Chan():
if pack == nil || len(pack.EndPositions) != 1 {
log.Error("consumed invalid msgPack")
continue
}
d.curTs.Store(pack.EndPositions[0].GetTimestamp())
// init packs for all targets; even if there are no msgs in the pack,
// we still need to dispatch time ticks to the targets.
targetPacks := make(map[string]*MsgPack)
for vchannel := range d.targets {
targetPacks[vchannel] = &MsgPack{
BeginTs: pack.BeginTs,
EndTs: pack.EndTs,
Msgs: make([]msgstream.TsMsg, 0),
StartPositions: pack.StartPositions,
EndPositions: pack.EndPositions,
}
}
// group messages by vchannel
for _, msg := range pack.Msgs {
if msg.VChannel() == "" {
// for non-dml msgs, such as CreateCollection, DropCollection, ...,
// we need to dispatch them to all the vchannels.
for k := range targetPacks {
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
}
continue
}
if _, ok := targetPacks[msg.VChannel()]; !ok {
continue
}
targetPacks[msg.VChannel()].Msgs = append(targetPacks[msg.VChannel()].Msgs, msg)
}
// dispatch messages, and split the target out if its send blocks
for vchannel, p := range targetPacks {
t := d.targets[vchannel]
if err := t.send(p); err != nil {
t.pos = pack.StartPositions[0]
d.lagTargets.LoadOrStore(t.vchannel, t)
d.nonBlockingNotify()
delete(d.targets, vchannel)
log.Warn("lag target notified", zap.Error(err))
}
}
}
}
}
func (d *Dispatcher) nonBlockingNotify() {
select {
case d.lagNotifyChan <- struct{}{}:
default:
}
}

View File

@ -0,0 +1,128 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
)
func TestDispatcher(t *testing.T) {
t.Run("test base", func(t *testing.T) {
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(t, err)
assert.NotPanics(t, func() {
d.Handle(start)
d.Handle(pause)
d.Handle(resume)
d.Handle(terminate)
})
pos := &msgstream.MsgPosition{
ChannelName: "mock_vchannel_0",
MsgGroup: "mock_msg_group",
Timestamp: 100,
}
d.curTs.Store(pos.GetTimestamp())
curTs := d.CurTs()
assert.Equal(t, pos.Timestamp, curTs)
})
t.Run("test target", func(t *testing.T) {
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(t, err)
output := make(chan *msgstream.MsgPack, 1024)
d.AddTarget(&target{
vchannel: "mock_vchannel_0",
pos: nil,
ch: output,
})
d.AddTarget(&target{
vchannel: "mock_vchannel_1",
pos: nil,
ch: nil,
})
num := d.TargetNum()
assert.Equal(t, 2, num)
target, err := d.GetTarget("mock_vchannel_0")
assert.NoError(t, err)
assert.Equal(t, cap(output), cap(target.ch))
d.CloseTarget("mock_vchannel_0")
select {
case <-time.After(1 * time.Second):
assert.Fail(t, "timeout, didn't receive close message")
case _, ok := <-target.ch:
assert.False(t, ok)
}
num = d.TargetNum()
assert.Equal(t, 1, num)
})
t.Run("test concurrent send and close", func(t *testing.T) {
for i := 0; i < 100; i++ {
output := make(chan *msgstream.MsgPack, 1024)
target := &target{
vchannel: "mock_vchannel_0",
pos: nil,
ch: output,
}
assert.Equal(t, cap(output), cap(target.ch))
wg := &sync.WaitGroup{}
for j := 0; j < 100; j++ {
wg.Add(1)
go func() {
err := target.send(&MsgPack{})
assert.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
target.close()
wg.Done()
}()
}
wg.Wait()
}
})
}
func BenchmarkDispatcher_handle(b *testing.B) {
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
d.Handle(start)
d.Handle(pause)
d.Handle(resume)
d.Handle(terminate)
}
// BenchmarkDispatcher_handle-12 9568 122123 ns/op
// PASS
}

View File

@ -0,0 +1,240 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/retry"
)
var (
CheckPeriod = 1 * time.Second // TODO: dyh, move to config
)
type DispatcherManager interface {
Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
Remove(vchannel string)
Num() int
Run()
Close()
}
var _ DispatcherManager = (*dispatcherManager)(nil)
type dispatcherManager struct {
role string
nodeID int64
pchannel string
lagNotifyChan chan struct{}
lagTargets *sync.Map // vchannel -> *target
mu sync.RWMutex // guards mainDispatcher and soloDispatchers
mainDispatcher *Dispatcher
soloDispatchers map[string]*Dispatcher // vchannel -> *Dispatcher
factory msgstream.Factory
closeChan chan struct{}
closeOnce sync.Once
}
func NewDispatcherManager(pchannel string, role string, nodeID int64, factory msgstream.Factory) DispatcherManager {
log.Info("create new dispatcherManager", zap.String("role", role),
zap.Int64("nodeID", nodeID), zap.String("pchannel", pchannel))
c := &dispatcherManager{
role: role,
nodeID: nodeID,
pchannel: pchannel,
lagNotifyChan: make(chan struct{}, 1),
lagTargets: &sync.Map{},
soloDispatchers: make(map[string]*Dispatcher),
factory: factory,
closeChan: make(chan struct{}),
}
return c
}
func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) string {
return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
}
func (c *dispatcherManager) Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.mu.Lock()
defer c.mu.Unlock()
isMain := c.mainDispatcher == nil
d, err := NewDispatcher(c.factory, isMain, c.pchannel, pos,
c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets)
if err != nil {
return nil, err
}
t := newTarget(vchannel, pos)
d.AddTarget(t)
if isMain {
c.mainDispatcher = d
log.Info("add main dispatcher")
} else {
c.soloDispatchers[vchannel] = d
log.Info("add solo dispatcher")
}
d.Handle(start)
return t.ch, nil
}
func (c *dispatcherManager) Remove(vchannel string) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.mu.Lock()
defer c.mu.Unlock()
if c.mainDispatcher != nil {
c.mainDispatcher.Handle(pause)
c.mainDispatcher.CloseTarget(vchannel)
if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 {
c.mainDispatcher.Handle(terminate)
c.mainDispatcher = nil
} else {
c.mainDispatcher.Handle(resume)
}
}
if _, ok := c.soloDispatchers[vchannel]; ok {
c.soloDispatchers[vchannel].Handle(terminate)
c.soloDispatchers[vchannel].CloseTarget(vchannel)
delete(c.soloDispatchers, vchannel)
log.Info("remove soloDispatcher done")
}
c.lagTargets.Delete(vchannel)
}
func (c *dispatcherManager) Num() int {
c.mu.RLock()
defer c.mu.RUnlock()
var res int
if c.mainDispatcher != nil {
res++
}
return res + len(c.soloDispatchers)
}
func (c *dispatcherManager) Close() {
c.closeOnce.Do(func() {
c.closeChan <- struct{}{}
})
}
func (c *dispatcherManager) Run() {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
log.Info("dispatcherManager is running...")
ticker := time.NewTicker(CheckPeriod)
defer ticker.Stop()
for {
select {
case <-c.closeChan:
log.Info("dispatcherManager exited")
return
case <-ticker.C:
c.tryMerge()
case <-c.lagNotifyChan:
c.mu.Lock()
c.lagTargets.Range(func(vchannel, t any) bool {
c.split(t.(*target))
c.lagTargets.Delete(vchannel)
return true
})
c.mu.Unlock()
}
}
}
func (c *dispatcherManager) tryMerge() {
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID))
c.mu.Lock()
defer c.mu.Unlock()
if c.mainDispatcher == nil {
return
}
candidates := make(map[string]struct{})
for vchannel, sd := range c.soloDispatchers {
if sd.CurTs() == c.mainDispatcher.CurTs() {
candidates[vchannel] = struct{}{}
}
}
if len(candidates) == 0 {
return
}
log.Info("start merging...", zap.Any("vchannel", candidates))
c.mainDispatcher.Handle(pause)
for vchannel := range candidates {
c.soloDispatchers[vchannel].Handle(pause)
// after pausing, check the alignment again; if it is still not aligned, resume it and retry the merge next time
if c.mainDispatcher.CurTs() != c.soloDispatchers[vchannel].CurTs() {
c.soloDispatchers[vchannel].Handle(resume)
delete(candidates, vchannel)
}
}
for vchannel := range candidates {
t, err := c.soloDispatchers[vchannel].GetTarget(vchannel)
if err == nil {
c.mainDispatcher.AddTarget(t)
}
c.soloDispatchers[vchannel].Handle(terminate)
delete(c.soloDispatchers, vchannel)
}
c.mainDispatcher.Handle(resume)
log.Info("merge done", zap.Any("vchannel", candidates))
}
func (c *dispatcherManager) split(t *target) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", t.vchannel))
log.Info("start splitting...")
// remove the stale soloDispatcher if one exists
if _, ok := c.soloDispatchers[t.vchannel]; ok {
c.soloDispatchers[t.vchannel].Handle(terminate)
delete(c.soloDispatchers, t.vchannel)
}
var newSolo *Dispatcher
err := retry.Do(context.Background(), func() error {
var err error
newSolo, err = NewDispatcher(c.factory, false, c.pchannel, t.pos,
c.constructSubName(t.vchannel, false), mqwrapper.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets)
return err
}, retry.Attempts(10))
if err != nil {
log.Error("split failed", zap.Error(err))
panic(err)
}
newSolo.AddTarget(t)
c.soloDispatchers[t.vchannel] = newSolo
newSolo.Handle(start)
log.Info("split done")
}

View File

@ -0,0 +1,354 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"math/rand"
"reflect"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestManager(t *testing.T) {
t.Run("test add and remove dispatcher", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
assert.Equal(t, 0, c.Num())
var offset int
for i := 0; i < 100; i++ {
r := rand.Intn(10) + 1
for j := 0; j < r; j++ {
offset++
_, err := c.Add(fmt.Sprintf("mock_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.Equal(t, offset, c.Num())
}
for j := 0; j < rand.Intn(r); j++ {
c.Remove(fmt.Sprintf("mock_vchannel_%d", offset))
offset--
assert.Equal(t, offset, c.Num())
}
}
})
t.Run("test merge and split", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.Equal(t, 3, c.Num())
c.(*dispatcherManager).tryMerge()
assert.Equal(t, 1, c.Num())
info := &target{
vchannel: "mock_vchannel_2",
pos: nil,
ch: nil,
}
c.(*dispatcherManager).split(info)
assert.Equal(t, 2, c.Num())
})
t.Run("test run and close", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.Equal(t, 3, c.Num())
CheckPeriod = 10 * time.Millisecond
go c.Run()
time.Sleep(15 * time.Millisecond)
assert.Equal(t, 1, c.Num()) // expected merged
assert.NotPanics(t, func() {
c.Close()
})
})
}
type vchannelHelper struct {
output <-chan *msgstream.MsgPack
pubInsMsgNum int
pubDelMsgNum int
pubDDLMsgNum int
pubPackNum int
subInsMsgNum int
subDelMsgNum int
subDDLMsgNum int
subPackNum int
}
type SimulationSuite struct {
suite.Suite
testVchannelNum int
manager DispatcherManager
pchannel string
vchannels map[string]*vchannelHelper
producer msgstream.MsgStream
factory msgstream.Factory
}
func (suite *SimulationSuite) SetupSuite() {
suite.factory = newMockFactory()
}
func (suite *SimulationSuite) SetupTest() {
suite.pchannel = fmt.Sprintf("by-dev-rootcoord-dispatcher-simulation-dml-%d-%d", rand.Int(), time.Now().UnixNano())
producer, err := newMockProducer(suite.factory, suite.pchannel)
assert.NoError(suite.T(), err)
suite.producer = producer
suite.manager = NewDispatcherManager(suite.pchannel, typeutil.DataNodeRole, 0, suite.factory)
CheckPeriod = 10 * time.Millisecond
go suite.manager.Run()
}
func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup) {
defer wg.Done()
const timeTickCount = 200
var uniqueMsgID int64
vchannelKeys := reflect.ValueOf(suite.vchannels).MapKeys()
for i := 1; i <= timeTickCount; i++ {
// produce random insert
insNum := rand.Intn(10)
for j := 0; j < insNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
uniqueMsgID++
suite.vchannels[vchannel].pubInsMsgNum++
}
// produce random delete
delNum := rand.Intn(2)
for j := 0; j < delNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
uniqueMsgID++
suite.vchannels[vchannel].pubDelMsgNum++
}
// produce random ddl
ddlNum := rand.Intn(2)
for j := 0; j < ddlNum; j++ {
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection)},
})
assert.NoError(suite.T(), err)
for k := range suite.vchannels {
suite.vchannels[k].pubDDLMsgNum++
}
}
// produce time tick
ts := uint64(i * 100)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
for k := range suite.vchannels {
suite.vchannels[k].pubPackNum++
}
}
suite.T().Logf("[%s] produce %d msgPack for %s done", time.Now(), timeTickCount, suite.pchannel)
}
func (suite *SimulationSuite) consumeMsg(ctx context.Context, wg *sync.WaitGroup, vchannel string) {
defer wg.Done()
var lastTs typeutil.Timestamp
for {
select {
case <-ctx.Done():
return
case <-time.After(2000 * time.Millisecond): // no message to consume
return
case pack := <-suite.vchannels[vchannel].output:
assert.Greater(suite.T(), pack.EndTs, lastTs)
lastTs = pack.EndTs
helper := suite.vchannels[vchannel]
helper.subPackNum++
for _, msg := range pack.Msgs {
switch msg.Type() {
case commonpb.MsgType_Insert:
helper.subInsMsgNum++
case commonpb.MsgType_Delete:
helper.subDelMsgNum++
case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection,
commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition:
helper.subDDLMsgNum++
}
}
}
}
}
func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
var tt = 1
for {
select {
case <-ctx.Done():
return
case <-time.After(10 * time.Millisecond):
ts := uint64(tt * 1000)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
tt++
}
}
}
func (suite *SimulationSuite) TestDispatchToVchannels() {
const vchannelNum = 20
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}
wg.Add(1)
go suite.produceMsg(wg)
wg.Wait()
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(context.Background(), wg, vchannel)
}
wg.Wait()
for _, helper := range suite.vchannels {
assert.Equal(suite.T(), helper.pubInsMsgNum, helper.subInsMsgNum)
assert.Equal(suite.T(), helper.pubDelMsgNum, helper.subDelMsgNum)
assert.Equal(suite.T(), helper.pubDDLMsgNum, helper.subDDLMsgNum)
assert.Equal(suite.T(), helper.pubPackNum, helper.subPackNum)
}
}
func (suite *SimulationSuite) TestMerge() {
ctx, cancel := context.WithCancel(context.Background())
go suite.produceTimeTickOnly(ctx)
const vchannelNum = 20
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
positions, err := getSeekPositions(suite.factory, suite.pchannel, 200)
assert.NoError(suite.T(), err)
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(vchannel, positions[rand.Intn(len(positions))],
mqwrapper.SubscriptionPositionUnknown) // seek from random position
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(ctx, wg, vchannel)
}
suite.Eventually(func() bool {
suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.Num())
return suite.manager.Num() == 1 // expected all merged, only mainDispatcher exist
}, 10*time.Second, 100*time.Millisecond)
cancel()
wg.Wait()
}
func (suite *SimulationSuite) TestSplit() {
ctx, cancel := context.WithCancel(context.Background())
go suite.produceTimeTickOnly(ctx)
const vchannelNum = 10
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
DefaultTargetChanSize = 10
MaxTolerantLag = 500 * time.Millisecond
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
const splitNum = 3
wg := &sync.WaitGroup{}
counter := 0
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(ctx, wg, vchannel)
counter++
if counter >= len(suite.vchannels)-splitNum {
break
}
}
suite.Eventually(func() bool {
suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.Num(), splitNum+1)
return suite.manager.Num() == splitNum+1 // expected 1 mainDispatcher and `splitNum` soloDispatchers
}, 10*time.Second, 100*time.Millisecond)
cancel()
wg.Wait()
}
func (suite *SimulationSuite) TearDownTest() {
for vchannel := range suite.vchannels {
suite.manager.Remove(vchannel)
}
suite.manager.Close()
}
func (suite *SimulationSuite) TearDownSuite() {
}
func TestSimulation(t *testing.T) {
suite.Run(t, new(SimulationSuite))
}

View File

@ -0,0 +1,213 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
dim = 128
)
func newMockFactory() msgstream.Factory {
paramtable.Init()
return msgstream.NewRmsFactory("/tmp/milvus/rocksmq/")
}
func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgStream, error) {
stream, err := factory.NewMsgStream(context.Background())
if err != nil {
return nil, err
}
stream.AsProducer([]string{pchannel})
stream.SetRepackFunc(defaultInsertRepackFunc)
return stream, nil
}
func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([]*msgstream.MsgPosition, error) {
stream, err := factory.NewTtMsgStream(context.Background())
if err != nil {
return nil, err
}
defer stream.Close()
stream.AsConsumer([]string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest)
positions := make([]*msgstream.MsgPosition, 0)
for {
select {
case <-time.After(100 * time.Millisecond): // no message to consume
return positions, nil
case pack := <-stream.Chan():
positions = append(positions, pack.EndPositions[0])
if len(positions) >= maxNum {
return positions, nil
}
}
}
}
func genPKs(numRows int) []typeutil.IntPrimaryKey {
ids := make([]typeutil.IntPrimaryKey, numRows)
for i := 0; i < numRows; i++ {
ids[i] = typeutil.IntPrimaryKey(i)
}
return ids
}
func genTimestamps(numRows int) []typeutil.Timestamp {
ts := make([]typeutil.Timestamp, numRows)
for i := 0; i < numRows; i++ {
ts[i] = typeutil.Timestamp(i + 1)
}
return ts
}
func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.InsertMsg {
floatVec := make([]float32, numRows*dim)
for i := 0; i < numRows*dim; i++ {
floatVec[i] = rand.Float32()
}
hashValues := make([]uint32, numRows)
for i := 0; i < numRows; i++ {
hashValues[i] = uint32(1)
}
return &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{HashValues: hashValues},
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID},
ShardName: vchannel,
Timestamps: genTimestamps(numRows),
RowIDs: genPKs(numRows),
FieldsData: []*schemapb.FieldData{{
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: floatVec}},
},
},
}},
NumRows: uint64(numRows),
Version: internalpb.InsertDataVersion_ColumnBased,
},
}
}
func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.DeleteMsg {
return &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{HashValues: make([]uint32, numRows)},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID},
ShardName: vchannel,
PrimaryKeys: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: genPKs(numRows),
},
},
},
Timestamps: genTimestamps(numRows),
NumRows: int64(numRows),
},
}
}
func genDDLMsg(msgType commonpb.MsgType) msgstream.TsMsg {
switch msgType {
case commonpb.MsgType_CreateCollection:
return &msgstream.CreateCollectionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
CreateCollectionRequest: internalpb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
},
}
case commonpb.MsgType_DropCollection:
return &msgstream.DropCollectionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
DropCollectionRequest: internalpb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
},
}
case commonpb.MsgType_CreatePartition:
return &msgstream.CreatePartitionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
CreatePartitionRequest: internalpb.CreatePartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition},
},
}
case commonpb.MsgType_DropPartition:
return &msgstream.DropPartitionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
DropPartitionRequest: internalpb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
},
}
}
return nil
}
func genTimeTickMsg(ts typeutil.Timestamp) *msgstream.TimeTickMsg {
return &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
TimeTickMsg: internalpb.TimeTickMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_TimeTick,
Timestamp: ts,
},
},
}
}
// defaultInsertRepackFunc repacks the dml messages.
func defaultInsertRepackFunc(
tsMsgs []msgstream.TsMsg,
hashKeys [][]int32,
) (map[int32]*msgstream.MsgPack, error) {
if len(hashKeys) < len(tsMsgs) {
return nil, fmt.Errorf(
"the length of hash keys (%d) is less than the length of messages (%d)",
len(hashKeys),
len(tsMsgs),
)
}
// after assigning segment IDs to msgs, tsMsgs have already been re-bucketed
pack := make(map[int32]*msgstream.MsgPack)
for idx, msg := range tsMsgs {
if len(hashKeys[idx]) <= 0 {
return nil, fmt.Errorf("no hash key for %dth message", idx)
}
key := hashKeys[idx][0]
_, ok := pack[key]
if !ok {
pack[key] = &msgstream.MsgPack{}
}
pack[key].Msgs = append(pack[key].Msgs, msg)
}
return pack, nil
}

View File

@ -0,0 +1,72 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"fmt"
"sync"
"time"
)
// TODO: dyh, move to config
var (
MaxTolerantLag = 3 * time.Second
DefaultTargetChanSize = 1024
)
type target struct {
vchannel string
ch chan *MsgPack
pos *Pos
closeMu sync.Mutex
closeOnce sync.Once
closed bool
}
func newTarget(vchannel string, pos *Pos) *target {
t := &target{
vchannel: vchannel,
ch: make(chan *MsgPack, DefaultTargetChanSize),
pos: pos,
}
t.closed = false
return t
}
func (t *target) close() {
t.closeMu.Lock()
defer t.closeMu.Unlock()
t.closeOnce.Do(func() {
t.closed = true
close(t.ch)
})
}
func (t *target) send(pack *MsgPack) error {
t.closeMu.Lock()
defer t.closeMu.Unlock()
if t.closed {
return nil
}
select {
case <-time.After(MaxTolerantLag):
return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s", t.vchannel, MaxTolerantLag)
case t.ch <- pack:
return nil
}
}

View File

@ -53,6 +53,7 @@ type TsMsg interface {
Unmarshal(MarshalType) (TsMsg, error)
Position() *MsgPosition
SetPosition(*MsgPosition)
VChannel() string
}
// BaseMsg is a basic structure that contains begin timestamp, end timestamp and the position of msgstream
@ -62,6 +63,7 @@ type BaseMsg struct {
EndTimestamp Timestamp
HashValues []uint32
MsgPosition *MsgPosition
Vchannel string
}
// TraceCtx returns the context of opentracing
@ -99,6 +101,10 @@ func (bm *BaseMsg) SetPosition(position *MsgPosition) {
bm.MsgPosition = position
}
func (bm *BaseMsg) VChannel() string {
return bm.Vchannel
}
func convertToByteArray(input interface{}) ([]byte, error) {
switch output := input.(type) {
case []byte:
@ -170,6 +176,7 @@ func (it *InsertMsg) Unmarshal(input MarshalType) (TsMsg, error) {
insertMsg.BeginTimestamp = timestamp
}
}
insertMsg.Vchannel = insertMsg.ShardName
return insertMsg, nil
}
@ -278,6 +285,7 @@ func (it *InsertMsg) IndexMsg(index int) *InsertMsg {
Ctx: it.TraceCtx(),
BeginTimestamp: it.BeginTimestamp,
EndTimestamp: it.EndTimestamp,
Vchannel: it.Vchannel,
HashValues: it.HashValues,
MsgPosition: it.MsgPosition,
},
@ -361,7 +369,7 @@ func (dt *DeleteMsg) Unmarshal(input MarshalType) (TsMsg, error) {
deleteMsg.BeginTimestamp = timestamp
}
}
deleteMsg.Vchannel = deleteRequest.ShardName
return deleteMsg, nil
}

View File

@ -148,6 +148,7 @@ func assignSegmentID(ctx context.Context, insertMsg *msgstream.InsertMsg, result
msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])
msg.BaseMsg.Vchannel = channelName
msg.NumRows++
requestSize += curRowMessageSize
}

View File

@ -268,6 +268,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
partitionName := dt.deleteMsg.PartitionName
proxyID := dt.deleteMsg.Base.SourceID
for index, key := range dt.deleteMsg.HashValues {
vchannel := channelNames[key]
ts := dt.deleteMsg.Timestamps[index]
_, ok := result[key]
if !ok {
@ -297,6 +298,8 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index)
curMsg.NumRows++
curMsg.ShardName = vchannel
curMsg.Vchannel = vchannel
}
// send delete request to log broker
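As a rough sketch of what this loop achieves (hypothetical, simplified types): rows are bucketed by their hash value, and each bucket's message is stamped with the vchannel that hash maps to, so downstream consumers can route the message without re-hashing the keys.

```go
package main

import "fmt"

type delMsg struct {
	Vchannel string
	Keys     []int64
}

func main() {
	channels := []string{"dml_v0", "dml_v1"}
	keys := []int64{10, 11, 12, 13}
	hash := func(k int64) uint32 { return uint32(k) % uint32(len(channels)) }

	byShard := make(map[uint32]*delMsg)
	for _, k := range keys {
		h := hash(k)
		m, ok := byShard[h]
		if !ok {
			m = &delMsg{Vchannel: channels[h]} // stamp the vchannel once per bucket
			byShard[h] = m
		}
		m.Keys = append(m.Keys, k)
	}
	for h, m := range byShard {
		fmt.Printf("hash=%d vchannel=%s keys=%v\n", h, m.Vchannel, m.Keys)
	}
}
```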

View File

@ -439,6 +439,7 @@ func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgP
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
curMsg.NumRows++
curMsg.ShardName = channelNames[key]
}
// send delete request to log broker

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
@ -40,6 +41,7 @@ type dataSyncService struct {
metaReplica ReplicaInterface
tSafeReplica TSafeReplicaInterface
dispClient msgdispatcher.Client
msFactory msgstream.Factory
}
@ -51,7 +53,7 @@ func (dsService *dataSyncService) getFlowGraphNum() int {
}
// addFlowGraphsForDMLChannels adds flowGraphs to dmlChannel2FlowGraph
func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels []string) (map[string]*queryNodeFlowGraph, error) {
func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels map[string]*msgstream.MsgPosition) (map[string]*queryNodeFlowGraph, error) {
dsService.mu.Lock()
defer dsService.mu.Unlock()
@ -61,7 +63,7 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
}
results := make(map[string]*queryNodeFlowGraph)
for _, channel := range dmlChannels {
for channel, position := range dmlChannels {
if _, ok := dsService.dmlChannel2FlowGraph[channel]; ok {
log.Warn("dml flow graph has been existed",
zap.Any("collectionID", collectionID),
@ -74,7 +76,8 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
dsService.metaReplica,
dsService.tSafeReplica,
channel,
dsService.msFactory)
position,
dsService.dispClient)
if err != nil {
for _, fg := range results {
fg.flowGraph.Close()
@ -128,7 +131,7 @@ func (dsService *dataSyncService) addFlowGraphsForDeltaChannels(collectionID Uni
dsService.metaReplica,
dsService.tSafeReplica,
channel,
dsService.msFactory)
dsService.dispClient)
if err != nil {
for channel, fg := range results {
fg.flowGraph.Close()
@ -291,6 +294,7 @@ func (dsService *dataSyncService) removeEmptyFlowGraphByChannel(collectionID int
func newDataSyncService(ctx context.Context,
metaReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface,
dispClient msgdispatcher.Client,
factory msgstream.Factory) *dataSyncService {
return &dataSyncService{
@ -299,6 +303,7 @@ func newDataSyncService(ctx context.Context,
deltaChannel2FlowGraph: make(map[Channel]*queryNodeFlowGraph),
metaReplica: metaReplica,
tSafeReplica: tSafeReplica,
dispClient: dispClient,
msFactory: factory,
}
}
@ -308,6 +313,7 @@ func (dsService *dataSyncService) close() {
// close DML flow graphs
for channel, nodeFG := range dsService.dmlChannel2FlowGraph {
if nodeFG != nil {
dsService.dispClient.Deregister(channel)
nodeFG.flowGraph.Close()
}
delete(dsService.dmlChannel2FlowGraph, channel)
@ -315,6 +321,7 @@ func (dsService *dataSyncService) close() {
// close delta flow graphs
for channel, nodeFG := range dsService.deltaChannel2FlowGraph {
if nodeFG != nil {
dsService.dispClient.Deregister(channel)
nodeFG.flowGraph.Close()
}
delete(dsService.deltaChannel2FlowGraph, channel)
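The error handling above follows a build-all-or-roll-back pattern: if creating any flow graph fails, the ones already built for this call are closed, which in turn deregisters their vchannels from the dispatcher. A minimal standalone sketch of that pattern, with hypothetical names:

```go
package main

import (
	"errors"
	"fmt"
)

type flowGraph struct{ channel string }

func (f *flowGraph) Close() { fmt.Println("closed", f.channel) }

func build(channel string) (*flowGraph, error) {
	if channel == "bad" {
		return nil, errors.New("register failed for " + channel)
	}
	return &flowGraph{channel: channel}, nil
}

func addAll(channels []string) (map[string]*flowGraph, error) {
	results := make(map[string]*flowGraph)
	for _, ch := range channels {
		fg, err := build(ch)
		if err != nil {
			// roll back everything built earlier in this call
			for _, done := range results {
				done.Close()
			}
			return nil, err
		}
		results[ch] = fg
	}
	return results, nil
}

func main() {
	_, err := addAll([]string{"v0", "v1", "bad"})
	fmt.Println("err:", err)
}
```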

View File

@ -21,14 +21,20 @@ import (
"fmt"
"testing"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func init() {
rateCol, _ = newRateCollector()
Params.Init()
}
func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
@ -40,17 +46,18 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
fac := genFactory()
assert.NoError(t, err)
dispClient := msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
tSafe := newTSafeReplica()
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
dataSyncService := newDataSyncService(ctx, replica, tSafe, dispClient, fac)
assert.NotNil(t, dataSyncService)
t.Run("test DMLFlowGraphs", func(t *testing.T) {
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err)
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err)
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
@ -68,7 +75,7 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
assert.Nil(t, fg)
assert.Error(t, err)
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err)
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
@ -88,7 +95,7 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
t.Run("test addFlowGraphsForDMLChannels checkReplica Failed", func(t *testing.T) {
err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.Error(t, err)
dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
})
@ -103,9 +110,10 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
fac := genFactory()
assert.NoError(t, err)
dispClient := msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
tSafe := newTSafeReplica()
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
dataSyncService := newDataSyncService(ctx, replica, tSafe, dispClient, fac)
assert.NotNil(t, dataSyncService)
t.Run("test DeltaFlowGraphs", func(t *testing.T) {
@ -160,12 +168,14 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
type DataSyncServiceSuite struct {
suite.Suite
factory dependency.Factory
dsService *dataSyncService
dispClient msgdispatcher.Client
factory dependency.Factory
dsService *dataSyncService
}
func (s *DataSyncServiceSuite) SetupSuite() {
s.factory = genFactory()
s.dispClient = msgdispatcher.NewClient(s.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
}
func (s *DataSyncServiceSuite) SetupTest() {
@ -176,7 +186,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
s.Require().NoError(err)
tSafe := newTSafeReplica()
s.dsService = newDataSyncService(ctx, replica, tSafe, s.factory)
s.dsService = newDataSyncService(ctx, replica, tSafe, s.dispClient, s.factory)
s.Require().NoError(err)
}

View File

@ -18,22 +18,20 @@ package querynode
import (
"context"
"errors"
"fmt"
"time"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/flowgraph"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type (
@ -49,9 +47,9 @@ type queryNodeFlowGraph struct {
collectionID UniqueID
vchannel Channel
flowGraph *flowgraph.TimeTickedFlowGraph
dmlStream msgstream.MsgStream
tSafeReplica TSafeReplicaInterface
consumerCnt int
dispClient msgdispatcher.Client
}
// newQueryNodeFlowGraph returns a new queryNodeFlowGraph
@ -60,16 +58,18 @@ func newQueryNodeFlowGraph(ctx context.Context,
metaReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface,
vchannel Channel,
factory msgstream.Factory) (*queryNodeFlowGraph, error) {
pos *msgstream.MsgPosition,
dispClient msgdispatcher.Client) (*queryNodeFlowGraph, error) {
q := &queryNodeFlowGraph{
collectionID: collectionID,
vchannel: vchannel,
tSafeReplica: tSafeReplica,
flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx),
dispClient: dispClient,
}
dmStreamNode, err := q.newDmInputNode(ctx, factory, collectionID, vchannel, metrics.InsertLabel)
dmStreamNode, err := q.newDmInputNode(collectionID, vchannel, pos, metrics.InsertLabel)
if err != nil {
return nil, err
}
@ -123,16 +123,18 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
metaReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface,
vchannel Channel,
factory msgstream.Factory) (*queryNodeFlowGraph, error) {
dispClient msgdispatcher.Client) (*queryNodeFlowGraph, error) {
q := &queryNodeFlowGraph{
collectionID: collectionID,
vchannel: vchannel,
tSafeReplica: tSafeReplica,
flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx),
dispClient: dispClient,
}
dmStreamNode, err := q.newDmInputNode(ctx, factory, collectionID, vchannel, metrics.DeleteLabel)
// use nil position, let deltaFlowGraph consume from latest.
dmStreamNode, err := q.newDmInputNode(collectionID, vchannel, nil, metrics.DeleteLabel)
if err != nil {
return nil, err
}
@ -184,84 +186,45 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
}
// newDmInputNode returns a new inputNode
func (q *queryNodeFlowGraph) newDmInputNode(ctx context.Context, factory msgstream.Factory, collectionID UniqueID, vchannel Channel, dataType string) (*flowgraph.InputNode, error) {
insertStream, err := factory.NewTtMsgStream(ctx)
if err != nil {
return nil, err
func (q *queryNodeFlowGraph) newDmInputNode(collectionID UniqueID, vchannel Channel, pos *msgstream.MsgPosition, dataType string) (*flowgraph.InputNode, error) {
log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("collection ID", collectionID),
zap.String("vchannel", vchannel))
var err error
var input <-chan *msgstream.MsgPack
tsBegin := time.Now()
if pos != nil && len(pos.MsgID) != 0 {
input, err = q.dispClient.Register(vchannel, pos, mqwrapper.SubscriptionPositionUnknown)
if err != nil {
return nil, err
}
log.Info("QueryNode seek successfully when register to msgDispatcher",
zap.ByteString("msgID", pos.GetMsgID()),
zap.Time("tsTime", tsoutil.PhysicalTime(pos.GetTimestamp())),
zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(pos.GetTimestamp()))),
zap.Duration("timeTaken", time.Since(tsBegin)))
} else {
input, err = q.dispClient.Register(vchannel, nil, mqwrapper.SubscriptionPositionLatest)
if err != nil {
return nil, err
}
log.Info("QueryNode consume successfully when register to msgDispatcher",
zap.Duration("timeTaken", time.Since(tsBegin)))
}
q.dmlStream = insertStream
maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism.GetAsInt32()
name := fmt.Sprintf("dmInputNode-query-%d-%s", collectionID, vchannel)
node := flowgraph.NewInputNode(insertStream, name, maxQueueLength, maxParallelism, typeutil.QueryNodeRole,
node := flowgraph.NewInputNode(input, name, maxQueueLength, maxParallelism, typeutil.QueryNodeRole,
paramtable.GetNodeID(), collectionID, dataType)
return node, nil
}
// consumeFlowGraph would consume by channel and subName
func (q *queryNodeFlowGraph) consumeFlowGraph(channel Channel, subName ConsumeSubName) error {
if q.dmlStream == nil {
return errors.New("null dml message stream in flow graph")
}
q.dmlStream.AsConsumer([]string{channel}, subName, mqwrapper.SubscriptionPositionUnknown)
log.Info("query node flow graph consumes from PositionUnknown",
zap.Int64("collectionID", q.collectionID),
zap.String("pchannel", channel),
zap.String("vchannel", q.vchannel),
zap.String("subName", subName),
)
q.consumerCnt++
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return nil
}
// consumeFlowGraphFromLatest would consume from latest by channel and subName
func (q *queryNodeFlowGraph) consumeFlowGraphFromLatest(channel Channel, subName ConsumeSubName) error {
if q.dmlStream == nil {
return errors.New("null dml message stream in flow graph")
}
q.dmlStream.AsConsumer([]string{channel}, subName, mqwrapper.SubscriptionPositionLatest)
log.Info("query node flow graph consumes from latest",
zap.Int64("collectionID", q.collectionID),
zap.String("pchannel", channel),
zap.String("vchannel", q.vchannel),
zap.String("subName", subName),
)
q.consumerCnt++
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return nil
}
// seekQueryNodeFlowGraph would seek by position
func (q *queryNodeFlowGraph) consumeFlowGraphFromPosition(position *internalpb.MsgPosition) error {
q.dmlStream.AsConsumer([]string{position.ChannelName}, position.MsgGroup, mqwrapper.SubscriptionPositionUnknown)
start := time.Now()
err := q.dmlStream.Seek([]*internalpb.MsgPosition{position})
// setup first ts
q.tSafeReplica.setTSafe(q.vchannel, position.GetTimestamp())
ts, _ := tsoutil.ParseTS(position.GetTimestamp())
log.Info("query node flow graph seeks from position",
zap.Int64("collectionID", q.collectionID),
zap.String("pchannel", position.ChannelName),
zap.String("vchannel", q.vchannel),
zap.Time("checkpointTs", ts),
zap.Duration("tsLag", time.Since(ts)),
zap.Duration("elapse", time.Since(start)),
)
q.consumerCnt++
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return err
}
// close would close queryNodeFlowGraph
func (q *queryNodeFlowGraph) close() {
q.dispClient.Deregister(q.vchannel)
q.flowGraph.Close()
if q.dmlStream != nil && q.consumerCnt > 0 {
if q.consumerCnt > 0 {
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(float64(q.consumerCnt))
}
log.Info("stop query node flow graph",

View File

@ -1,82 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tSafe := newTSafeReplica()
streamingReplica, err := genSimpleReplica()
assert.NoError(t, err)
fac := genFactory()
fg, err := newQueryNodeFlowGraph(ctx,
defaultCollectionID,
streamingReplica,
tSafe,
defaultDMLChannel,
fac)
assert.NoError(t, err)
err = fg.consumeFlowGraph(defaultDMLChannel, defaultSubName)
assert.NoError(t, err)
fg.close()
}
func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streamingReplica, err := genSimpleReplica()
assert.NoError(t, err)
fac := genFactory()
tSafe := newTSafeReplica()
fg, err := newQueryNodeFlowGraph(ctx,
defaultCollectionID,
streamingReplica,
tSafe,
defaultDMLChannel,
fac)
assert.NoError(t, err)
position := &internalpb.MsgPosition{
ChannelName: defaultDMLChannel,
MsgID: []byte{},
MsgGroup: defaultSubName,
Timestamp: 0,
}
err = fg.consumeFlowGraphFromPosition(position)
assert.Error(t, err)
fg.close()
}

View File

@ -26,7 +26,6 @@ import (
"github.com/milvus-io/milvus/internal/log"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/samber/lo"
)
@ -189,31 +188,6 @@ func (l *loadSegmentsTask) watchDeltaChannel(deltaChannels []string) error {
}
}
}()
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID())
// channels as consumer
for channel, fg := range channel2FlowGraph {
pchannel := VPDeltaChannels[channel]
// use pChannel to consume
err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName)
if err != nil {
log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
break
}
}
if err != nil {
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
for _, fg := range channel2FlowGraph {
fg.flowGraph.Close()
}
gcChannels := make([]Channel, 0)
for channel := range channel2FlowGraph {
gcChannels = append(gcChannels, channel)
}
l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels)
return err
}
log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/common"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
@ -1702,6 +1703,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue())
node.etcdKV = etcdKV
node.dispClient = msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
node.tSafeReplica = newTSafeReplica()
replica, err := genSimpleReplicaWithSealSegment(ctx)
@ -1711,7 +1713,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
node.tSafeReplica.addTSafe(defaultDMLChannel)
node.tSafeReplica.addTSafe(defaultDeltaChannel)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.factory)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.dispClient, node.factory)
node.metaReplica = replica

View File

@ -30,6 +30,7 @@ import "C"
import (
"context"
"fmt"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"os"
"path"
"runtime"
@ -106,8 +107,9 @@ type QueryNode struct {
etcdCli *clientv3.Client
address string
factory dependency.Factory
scheduler *taskScheduler
dispClient msgdispatcher.Client
factory dependency.Factory
scheduler *taskScheduler
sessionMu sync.Mutex
session *sessionutil.Session
@ -256,6 +258,9 @@ func (node *QueryNode) Init() error {
}
log.Info("QueryNode init rateCollector done", zap.Int64("nodeID", paramtable.GetNodeID()))
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
log.Info("QueryNode init dispatcher client done", zap.Int64("nodeID", paramtable.GetNodeID()))
node.vectorStorage, err = node.factory.NewPersistentStorageChunkManager(node.queryNodeLoopCtx)
if err != nil {
log.Error("QueryNode init vector storage failed", zap.Error(err))
@ -283,7 +288,7 @@ func (node *QueryNode) Init() error {
node.vectorStorage,
node.factory)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.factory)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.dispClient, node.factory)
node.InitSegcore()

View File

@ -27,13 +27,14 @@ import (
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
var embedetcdServer *embed.Etcd
@ -98,10 +99,9 @@ func newQueryNodeMock() *QueryNode {
factory := newMessageStreamFactory()
svr := NewQueryNode(ctx, factory)
tsReplica := newTSafeReplica()
replica := newCollectionReplica()
svr.metaReplica = replica
svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, factory)
svr.dispClient = msgdispatcher.NewClient(factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
svr.metaReplica = newCollectionReplica()
svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, svr.dispClient, factory)
svr.vectorStorage, err = factory.NewPersistentStorageChunkManager(ctx)
if err != nil {
panic(err)

View File

@ -458,7 +458,7 @@ func TestTask_releasePartitionTask(t *testing.T) {
req: genReleasePartitionsRequest(),
node: node,
}
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err)
err = task.Execute(ctx)
assert.NoError(t, err)
@ -534,7 +534,7 @@ func TestTask_releasePartitionTask(t *testing.T) {
req: genReleasePartitionsRequest(),
node: node,
}
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err)
err = task.Execute(ctx)
assert.NoError(t, err)

View File

@ -122,7 +122,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
}
}()
channel2FlowGraph, err := w.initFlowGraph(ctx, collectionID, vChannels, VPChannels)
channel2FlowGraph, err := w.initFlowGraph(collectionID, vChannels)
if err != nil {
return err
}
@ -243,21 +243,16 @@ func (w *watchDmChannelsTask) LoadGrowingSegments(ctx context.Context, collectio
return unFlushedSegmentIDs, nil
}
func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID UniqueID, vChannels []Channel, VPChannels map[string]string) (map[string]*queryNodeFlowGraph, error) {
func (w *watchDmChannelsTask) initFlowGraph(collectionID UniqueID, vChannels []Channel) (map[string]*queryNodeFlowGraph, error) {
// So far, we don't support to enable each node with two different channel
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID())
// group channels by to seeking or consuming
// group channels by seek position
channel2SeekPosition := make(map[string]*internalpb.MsgPosition)
// for channel with no position
channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition)
for _, info := range w.req.Infos {
if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 {
channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition
continue
if info.SeekPosition != nil && len(info.SeekPosition.MsgID) != 0 {
info.SeekPosition.MsgGroup = consumeSubName
}
info.SeekPosition.MsgGroup = consumeSubName
channel2SeekPosition[info.ChannelName] = info.SeekPosition
}
log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID))
@ -333,49 +328,11 @@ func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID Un
)
// add flow graph
channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels)
channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, channel2SeekPosition)
if err != nil {
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
return nil, err
}
log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels))
// channels as consumer
for channel, fg := range channel2FlowGraph {
if _, ok := channel2AsConsumerPosition[channel]; ok {
// use pChannel to consume
err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName)
if err != nil {
log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
break
}
}
if pos, ok := channel2SeekPosition[channel]; ok {
pos.MsgGroup = consumeSubName
// use pChannel to seek
pos.ChannelName = VPChannels[channel]
err = fg.consumeFlowGraphFromPosition(pos)
if err != nil {
log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
break
}
}
}
if err != nil {
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
for _, fg := range channel2FlowGraph {
fg.flowGraph.Close()
}
gcChannels := make([]Channel, 0)
for channel := range channel2FlowGraph {
gcChannels = append(gcChannels, channel)
}
w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels)
return nil, err
}
log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
return channel2FlowGraph, nil
}

View File

@ -36,7 +36,7 @@ import (
// InputNode is the entry point of the flowgraph
type InputNode struct {
BaseNode
inStream msgstream.MsgStream
input <-chan *msgstream.MsgPack
lastMsg *msgstream.MsgPack
name string
role string
@ -51,17 +51,6 @@ func (inNode *InputNode) IsInputNode() bool {
return true
}
// Start is used to start input msgstream
func (inNode *InputNode) Start() {
}
// Close implements node
func (inNode *InputNode) Close() {
inNode.closeOnce.Do(func() {
inNode.inStream.Close()
})
}
func (inNode *InputNode) IsValidInMsg(in []Msg) bool {
return true
}
@ -71,16 +60,11 @@ func (inNode *InputNode) Name() string {
return inNode.name
}
// InStream returns the internal MsgStream
func (inNode *InputNode) InStream() msgstream.MsgStream {
return inNode.inStream
}
// Operate consume a message pack from msgstream and return
func (inNode *InputNode) Operate(in []Msg) []Msg {
msgPack, ok := <-inNode.inStream.Chan()
msgPack, ok := <-inNode.input
if !ok {
log.Warn("MsgStream closed", zap.Any("input node", inNode.Name()))
log.Warn("input closed", zap.Any("input node", inNode.Name()))
if inNode.lastMsg != nil {
log.Info("trigger force sync", zap.Int64("collection", inNode.collectionID), zap.Any("position", inNode.lastMsg))
return []Msg{&MsgStreamMsg{
@ -151,15 +135,15 @@ func (inNode *InputNode) Operate(in []Msg) []Msg {
return []Msg{msgStreamMsg}
}
// NewInputNode composes an InputNode with provided MsgStream, name and parameters
func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32, role string, nodeID int64, collectionID int64, dataType string) *InputNode {
// NewInputNode composes an InputNode with provided input channel, name and parameters
func NewInputNode(input <-chan *msgstream.MsgPack, nodeName string, maxQueueLength int32, maxParallelism int32, role string, nodeID int64, collectionID int64, dataType string) *InputNode {
baseNode := BaseNode{}
baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism)
return &InputNode{
BaseNode: baseNode,
inStream: inStream,
input: input,
name: nodeName,
role: role,
nodeID: nodeID,

View File

@ -40,7 +40,7 @@ func TestInputNode(t *testing.T) {
produceStream.Produce(&msgPack)
nodeName := "input_node"
inputNode := NewInputNode(msgStream, nodeName, 100, 100, "", 0, 0, "")
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
defer inputNode.Close()
isInputNode := inputNode.IsInputNode()
@ -49,9 +49,6 @@ func TestInputNode(t *testing.T) {
name := inputNode.Name()
assert.Equal(t, name, nodeName)
stream := inputNode.InStream()
assert.NotNil(t, stream)
output := inputNode.Operate(nil)
assert.NotNil(t, output)
msg, ok := output[0].(*MsgStreamMsg)

View File

@ -76,6 +76,10 @@ func (bm *MockMsg) SetPosition(position *MsgPosition) {
}
func (bm *MockMsg) VChannel() string {
return ""
}
func Test_GenerateMsgStreamMsg(t *testing.T) {
messages := make([]msgstream.TsMsg, 1)
messages[0] = &MockMsg{

View File

@ -74,7 +74,7 @@ func TestNodeCtx_Start(t *testing.T) {
produceStream.Produce(&msgPack)
nodeName := "input_node"
inputNode := NewInputNode(msgStream, nodeName, 100, 100, "", 0, 0, "")
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
node := &nodeCtx{
node: inputNode,