mirror of https://github.com/milvus-io/milvus.git
Add msgDispatcher to support sharing msgs for different vChannel (#21917)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/22150/head
parent
a2435cfc4f
commit
d2667064bb
|
@ -39,13 +39,13 @@ import (
|
|||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
|
||||
allocator2 "github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/kv"
|
||||
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
|
@ -127,7 +127,8 @@ type DataNode struct {
|
|||
|
||||
closer io.Closer
|
||||
|
||||
factory dependency.Factory
|
||||
dispClient msgdispatcher.Client
|
||||
factory dependency.Factory
|
||||
}
|
||||
|
||||
// NewDataNode will return a DataNode with abnormal state.
|
||||
|
@ -249,6 +250,9 @@ func (node *DataNode) Init() error {
|
|||
}
|
||||
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
|
||||
|
||||
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, paramtable.GetNodeID())
|
||||
log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", paramtable.GetNodeID()))
|
||||
|
||||
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
|
||||
if err != nil {
|
||||
log.Error("failed to create id allocator",
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
|
@ -49,6 +50,7 @@ type dataSyncService struct {
|
|||
resendTTCh chan resendTTMsg // chan to ask for resending DataNode time tick message.
|
||||
channel Channel // channel stores meta of channel
|
||||
idAllocator allocatorInterface // id/timestamp allocator
|
||||
dispClient msgdispatcher.Client
|
||||
msFactory msgstream.Factory
|
||||
collectionID UniqueID // collection id of vchan for which this data sync service serves
|
||||
vchannelName string
|
||||
|
@ -71,6 +73,7 @@ func newDataSyncService(ctx context.Context,
|
|||
resendTTCh chan resendTTMsg,
|
||||
channel Channel,
|
||||
alloc allocatorInterface,
|
||||
dispClient msgdispatcher.Client,
|
||||
factory msgstream.Factory,
|
||||
vchan *datapb.VchannelInfo,
|
||||
clearSignal chan<- string,
|
||||
|
@ -101,6 +104,7 @@ func newDataSyncService(ctx context.Context,
|
|||
resendTTCh: resendTTCh,
|
||||
channel: channel,
|
||||
idAllocator: alloc,
|
||||
dispClient: dispClient,
|
||||
msFactory: factory,
|
||||
collectionID: vchan.GetCollectionID(),
|
||||
vchannelName: vchan.GetChannelName(),
|
||||
|
@ -156,6 +160,7 @@ func (dsService *dataSyncService) close() {
|
|||
if dsService.fg != nil {
|
||||
log.Info("dataSyncService closing flowgraph", zap.Int64("collectionID", dsService.collectionID),
|
||||
zap.String("vChanName", dsService.vchannelName))
|
||||
dsService.dispClient.Deregister(dsService.vchannelName)
|
||||
dsService.fg.Close()
|
||||
metrics.DataNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
|
||||
metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(2) // timeTickChannel + deltaChannel
|
||||
|
@ -287,7 +292,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro
|
|||
}
|
||||
|
||||
var dmStreamNode Node
|
||||
dmStreamNode, err = newDmInputNode(dsService.ctx, vchanInfo.GetSeekPosition(), c)
|
||||
dmStreamNode, err = newDmInputNode(dsService.dispClient, vchanInfo.GetSeekPosition(), c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
|
@ -40,10 +41,15 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
var dataSyncServiceTestDir = "/tmp/milvus_test/data_sync_service"
|
||||
|
||||
func init() {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func getVchanInfo(info *testInfo) *datapb.VchannelInfo {
|
||||
var ufs []*datapb.SegmentInfo
|
||||
var fs []*datapb.SegmentInfo
|
||||
|
@ -160,12 +166,14 @@ func TestDataSyncService_newDataSyncService(te *testing.T) {
|
|||
if test.channelNil {
|
||||
channel = nil
|
||||
}
|
||||
dispClient := msgdispatcher.NewClient(test.inMsgFactory, typeutil.DataNodeRole, paramtable.GetNodeID())
|
||||
|
||||
ds, err := newDataSyncService(ctx,
|
||||
make(chan flushMsg),
|
||||
make(chan resendTTMsg),
|
||||
channel,
|
||||
NewAllocatorFactory(),
|
||||
dispClient,
|
||||
test.inMsgFactory,
|
||||
getVchanInfo(test),
|
||||
make(chan string),
|
||||
|
@ -217,6 +225,7 @@ func TestDataSyncService_Start(t *testing.T) {
|
|||
|
||||
allocFactory := NewAllocatorFactory(1)
|
||||
factory := dependency.NewDefaultFactory(true)
|
||||
dispClient := msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
|
||||
defer os.RemoveAll("/tmp/milvus")
|
||||
paramtable.Get().Save(Params.DataNodeCfg.FlushInsertBufferSize.Key, "1")
|
||||
|
||||
|
@ -270,7 +279,7 @@ func TestDataSyncService_Start(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
|
||||
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, dispClient, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
|
||||
assert.Nil(t, err)
|
||||
|
||||
sync.flushListener = make(chan *segmentFlushPack)
|
||||
|
@ -399,6 +408,7 @@ func TestDataSyncService_Close(t *testing.T) {
|
|||
|
||||
allocFactory = NewAllocatorFactory(1)
|
||||
factory = dependency.NewDefaultFactory(true)
|
||||
dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
|
||||
mockDataCoord = &DataCoordFactory{}
|
||||
)
|
||||
mockDataCoord.UserSegmentInfo = map[int64]*datapb.SegmentInfo{
|
||||
|
@ -421,7 +431,7 @@ func TestDataSyncService_Close(t *testing.T) {
|
|||
paramtable.Get().Reset(Params.DataNodeCfg.FlushInsertBufferSize.Key)
|
||||
|
||||
channel := newChannel(insertChannelName, collMeta.ID, collMeta.GetSchema(), mockRootCoord, cm)
|
||||
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), 0)
|
||||
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, dispClient, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), 0)
|
||||
assert.Nil(t, err)
|
||||
|
||||
sync.flushListener = make(chan *segmentFlushPack, 10)
|
||||
|
|
|
@ -220,6 +220,12 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
|
|||
for i := int64(0); i < dmsg.NumRows; i++ {
|
||||
dmsg.HashValues = append(dmsg.HashValues, uint32(0))
|
||||
}
|
||||
deltaVChannel, err := funcutil.ConvertChannelName(dmsg.ShardName, Params.CommonCfg.RootCoordDml.GetValue(), Params.CommonCfg.RootCoordDelta.GetValue())
|
||||
if err != nil {
|
||||
log.Error("convert dmlVChannel to deltaVChannel failed", zap.String("vchannel", ddn.vChannelName), zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
dmsg.ShardName = deltaVChannel
|
||||
forwardMsgs = append(forwardMsgs, dmsg)
|
||||
if dmsg.CollectionID != ddn.collectionID {
|
||||
log.Warn("filter invalid DeleteMsg, collection mis-match",
|
||||
|
|
|
@ -278,6 +278,7 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) {
|
|||
},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete},
|
||||
ShardName: "by-dev-rootcoord-dml-mock-0",
|
||||
CollectionID: test.inMsgCollID,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package datanode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -25,10 +24,11 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/flowgraph"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
@ -37,50 +37,32 @@ import (
|
|||
// DmInputNode receives messages from message streams, packs messages between two timeticks, and passes all
|
||||
// messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is
|
||||
// flowgraph ddNode.
|
||||
func newDmInputNode(ctx context.Context, seekPos *internalpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) {
|
||||
// subName should be unique, since pchannelName is shared among several collections
|
||||
// use vchannel in case of reuse pchannel for same collection
|
||||
consumeSubName := fmt.Sprintf("%s-%d-%s", Params.CommonCfg.DataNodeSubName.GetValue(), paramtable.GetNodeID(), dmNodeConfig.vChannelName)
|
||||
insertStream, err := dmNodeConfig.msFactory.NewTtMsgStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// MsgStream needs a physical channel name, but the channel name in seek position from DataCoord
|
||||
// is virtual channel name, so we need to convert vchannel name into pchannel neme here.
|
||||
pchannelName := funcutil.ToPhysicalChannel(dmNodeConfig.vChannelName)
|
||||
if seekPos != nil {
|
||||
insertStream.AsConsumer([]string{pchannelName}, consumeSubName, mqwrapper.SubscriptionPositionUnknown)
|
||||
seekPos.ChannelName = pchannelName
|
||||
cpTs, _ := tsoutil.ParseTS(seekPos.Timestamp)
|
||||
start := time.Now()
|
||||
log.Info("datanode begin to seek",
|
||||
zap.ByteString("seek msgID", seekPos.GetMsgID()),
|
||||
zap.String("pchannel", seekPos.GetChannelName()),
|
||||
zap.String("vchannel", dmNodeConfig.vChannelName),
|
||||
zap.Time("position", cpTs),
|
||||
zap.Duration("tsLag", time.Since(cpTs)),
|
||||
zap.Int64("collection ID", dmNodeConfig.collectionID))
|
||||
err = insertStream.Seek([]*internalpb.MsgPosition{seekPos})
|
||||
func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *internalpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) {
|
||||
log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
|
||||
zap.Int64("collection ID", dmNodeConfig.collectionID),
|
||||
zap.String("vchannel", dmNodeConfig.vChannelName))
|
||||
var err error
|
||||
var input <-chan *msgstream.MsgPack
|
||||
if seekPos != nil && len(seekPos.MsgID) != 0 {
|
||||
input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("datanode seek successfully",
|
||||
zap.ByteString("seek msgID", seekPos.GetMsgID()),
|
||||
zap.String("pchannel", seekPos.GetChannelName()),
|
||||
zap.String("vchannel", dmNodeConfig.vChannelName),
|
||||
zap.Time("position", cpTs),
|
||||
zap.Duration("tsLag", time.Since(cpTs)),
|
||||
zap.Int64("collection ID", dmNodeConfig.collectionID),
|
||||
zap.Duration("elapse", time.Since(start)))
|
||||
log.Info("datanode seek successfully when register to msgDispatcher",
|
||||
zap.ByteString("msgID", seekPos.GetMsgID()),
|
||||
zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())),
|
||||
zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
|
||||
} else {
|
||||
insertStream.AsConsumer([]string{pchannelName}, consumeSubName, mqwrapper.SubscriptionPositionEarliest)
|
||||
input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("datanode consume successfully when register to msgDispatcher")
|
||||
}
|
||||
metrics.DataNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
|
||||
log.Info("datanode AsConsumer", zap.String("physical channel", pchannelName), zap.String("subName", consumeSubName), zap.Int64("collection ID", dmNodeConfig.collectionID))
|
||||
|
||||
name := fmt.Sprintf("dmInputNode-data-%d-%s", dmNodeConfig.collectionID, dmNodeConfig.vChannelName)
|
||||
node := flowgraph.NewInputNode(insertStream, name, dmNodeConfig.maxQueueLength, dmNodeConfig.maxParallelism,
|
||||
node := flowgraph.NewInputNode(input, name, dmNodeConfig.maxQueueLength, dmNodeConfig.maxParallelism,
|
||||
typeutil.DataNodeRole, paramtable.GetNodeID(), dmNodeConfig.collectionID, metrics.AllLabel)
|
||||
return node, nil
|
||||
}
|
||||
|
|
|
@ -21,12 +21,14 @@ import (
|
|||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type mockMsgStreamFactory struct {
|
||||
|
@ -93,7 +95,10 @@ func (mtm *mockTtMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID,
|
|||
}
|
||||
|
||||
func TestNewDmInputNode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := newDmInputNode(ctx, new(internalpb.MsgPosition), &nodeConfig{msFactory: &mockMsgStreamFactory{}})
|
||||
client := msgdispatcher.NewClient(&mockMsgStreamFactory{}, typeutil.DataNodeRole, paramtable.GetNodeID())
|
||||
_, err := newDmInputNode(client, new(internalpb.MsgPosition), &nodeConfig{
|
||||
msFactory: &mockMsgStreamFactory{},
|
||||
vChannelName: "mock_vchannel_0",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo
|
|||
var alloc allocatorInterface = newAllocator(dn.rootCoord)
|
||||
|
||||
dataSyncService, err := newDataSyncService(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel,
|
||||
alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
|
||||
alloc, dn.dispClient, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
|
||||
if err != nil {
|
||||
log.Warn("new data sync service fail", zap.String("vChannelName", vchan.GetChannelName()), zap.Error(err))
|
||||
return err
|
||||
|
|
|
@ -27,29 +27,29 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/metautil"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
s "github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
s "github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/metautil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
const ctxTimeInMillisecond = 5000
|
||||
|
@ -81,6 +81,7 @@ func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNod
|
|||
factory := dependency.NewDefaultFactory(true)
|
||||
node := NewDataNode(ctx, factory)
|
||||
node.SetSession(&sessionutil.Session{ServerID: 1})
|
||||
node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
|
||||
|
||||
rc := &RootCoordFactory{
|
||||
ID: 0,
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type (
|
||||
Pos = internalpb.MsgPosition
|
||||
MsgPack = msgstream.MsgPack
|
||||
SubPos = mqwrapper.SubscriptionInitialPosition
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
|
||||
Deregister(vchannel string)
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
|
||||
type client struct {
|
||||
role string
|
||||
nodeID int64
|
||||
managers *typeutil.ConcurrentMap[string, DispatcherManager] // pchannel -> DispatcherManager
|
||||
factory msgstream.Factory
|
||||
}
|
||||
|
||||
func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
|
||||
return &client{
|
||||
role: role,
|
||||
nodeID: nodeID,
|
||||
managers: typeutil.NewConcurrentMap[string, DispatcherManager](),
|
||||
factory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
|
||||
pchannel := funcutil.ToPhysicalChannel(vchannel)
|
||||
managers, ok := c.managers.Get(pchannel)
|
||||
if !ok {
|
||||
managers = NewDispatcherManager(pchannel, c.role, c.nodeID, c.factory)
|
||||
go managers.Run()
|
||||
old, exist := c.managers.GetOrInsert(pchannel, managers)
|
||||
if exist {
|
||||
managers.Close()
|
||||
managers = old
|
||||
}
|
||||
}
|
||||
ch, err := managers.Add(vchannel, pos, subPos)
|
||||
if err != nil {
|
||||
log.Error("register failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
log.Info("register done")
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *client) Deregister(vchannel string) {
|
||||
pchannel := funcutil.ToPhysicalChannel(vchannel)
|
||||
if managers, ok := c.managers.Get(pchannel); ok {
|
||||
managers.Remove(vchannel)
|
||||
if managers.Num() == 0 {
|
||||
managers.Close()
|
||||
c.managers.GetAndRemove(pchannel)
|
||||
}
|
||||
log.Info("deregister done", zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
|
||||
assert.NotNil(t, client)
|
||||
_, err := client.Register("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.NotPanics(t, func() {
|
||||
client.Deregister("mock_vchannel_0")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_Concurrency(t *testing.T) {
|
||||
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
|
||||
assert.NotNil(t, client)
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for j := 0; j < 10; j++ {
|
||||
_, err := client.Register(vchannel, nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
client.Deregister(vchannel)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,244 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type signal int32
|
||||
|
||||
const (
|
||||
start signal = 0
|
||||
pause signal = 1
|
||||
resume signal = 2
|
||||
terminate signal = 3
|
||||
)
|
||||
|
||||
var signalString = map[int32]string{
|
||||
0: "start",
|
||||
1: "pause",
|
||||
2: "resume",
|
||||
3: "terminate",
|
||||
}
|
||||
|
||||
func (s signal) String() string {
|
||||
return signalString[int32(s)]
|
||||
}
|
||||
|
||||
type Dispatcher struct {
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
|
||||
isMain bool // indicates if it's a main dispatcher
|
||||
pchannel string
|
||||
curTs atomic.Uint64
|
||||
|
||||
lagNotifyChan chan struct{}
|
||||
lagTargets *sync.Map // vchannel -> *target
|
||||
|
||||
// vchannel -> *target, lock free since we guarantee that
|
||||
// it's modified only after dispatcher paused or terminated
|
||||
targets map[string]*target
|
||||
|
||||
stream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func NewDispatcher(factory msgstream.Factory,
|
||||
isMain bool,
|
||||
pchannel string,
|
||||
position *Pos,
|
||||
subName string,
|
||||
subPos SubPos,
|
||||
lagNotifyChan chan struct{},
|
||||
lagTargets *sync.Map,
|
||||
) (*Dispatcher, error) {
|
||||
log := log.With(zap.String("pchannel", pchannel),
|
||||
zap.String("subName", subName), zap.Bool("isMain", isMain))
|
||||
log.Info("creating dispatcher...")
|
||||
stream, err := factory.NewTtMsgStream(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if position != nil && len(position.MsgID) != 0 {
|
||||
position.ChannelName = funcutil.ToPhysicalChannel(position.ChannelName)
|
||||
stream.AsConsumer([]string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown)
|
||||
err = stream.Seek([]*Pos{position})
|
||||
if err != nil {
|
||||
log.Error("seek failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
posTime := tsoutil.PhysicalTime(position.GetTimestamp())
|
||||
log.Info("seek successfully", zap.Time("posTime", posTime),
|
||||
zap.Duration("tsLag", time.Since(posTime)))
|
||||
} else {
|
||||
stream.AsConsumer([]string{pchannel}, subName, subPos)
|
||||
log.Info("asConsumer successfully")
|
||||
}
|
||||
|
||||
d := &Dispatcher{
|
||||
done: make(chan struct{}, 1),
|
||||
isMain: isMain,
|
||||
pchannel: pchannel,
|
||||
lagNotifyChan: lagNotifyChan,
|
||||
lagTargets: lagTargets,
|
||||
targets: make(map[string]*target),
|
||||
stream: stream,
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Dispatcher) CurTs() typeutil.Timestamp {
|
||||
return d.curTs.Load()
|
||||
}
|
||||
|
||||
func (d *Dispatcher) AddTarget(t *target) {
|
||||
log := log.With(zap.String("vchannel", t.vchannel), zap.Bool("isMain", d.isMain))
|
||||
if _, ok := d.targets[t.vchannel]; ok {
|
||||
log.Warn("target exists")
|
||||
return
|
||||
}
|
||||
d.targets[t.vchannel] = t
|
||||
log.Info("add new target")
|
||||
}
|
||||
|
||||
func (d *Dispatcher) GetTarget(vchannel string) (*target, error) {
|
||||
if t, ok := d.targets[vchannel]; ok {
|
||||
return t, nil
|
||||
}
|
||||
return nil, fmt.Errorf("cannot find target, vchannel=%s, isMain=%t", vchannel, d.isMain)
|
||||
}
|
||||
|
||||
func (d *Dispatcher) CloseTarget(vchannel string) {
|
||||
log := log.With(zap.String("vchannel", vchannel), zap.Bool("isMain", d.isMain))
|
||||
if t, ok := d.targets[vchannel]; ok {
|
||||
t.close()
|
||||
delete(d.targets, vchannel)
|
||||
log.Info("closed target")
|
||||
} else {
|
||||
log.Warn("target not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) TargetNum() int {
|
||||
return len(d.targets)
|
||||
}
|
||||
|
||||
func (d *Dispatcher) Handle(signal signal) {
|
||||
log := log.With(zap.String("pchannel", d.pchannel),
|
||||
zap.String("signal", signal.String()), zap.Bool("isMain", d.isMain))
|
||||
log.Info("get signal")
|
||||
switch signal {
|
||||
case start:
|
||||
d.wg.Add(1)
|
||||
go d.work()
|
||||
case pause:
|
||||
d.done <- struct{}{}
|
||||
d.wg.Wait()
|
||||
case resume:
|
||||
d.wg.Add(1)
|
||||
go d.work()
|
||||
case terminate:
|
||||
d.done <- struct{}{}
|
||||
d.wg.Wait()
|
||||
d.once.Do(func() {
|
||||
d.stream.Close()
|
||||
})
|
||||
}
|
||||
log.Info("handle signal done")
|
||||
}
|
||||
|
||||
func (d *Dispatcher) work() {
|
||||
log := log.With(zap.String("pchannel", d.pchannel), zap.Bool("isMain", d.isMain))
|
||||
log.Info("begin to work")
|
||||
defer d.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-d.done:
|
||||
log.Info("stop working")
|
||||
return
|
||||
case pack := <-d.stream.Chan():
|
||||
if pack == nil || len(pack.EndPositions) != 1 {
|
||||
log.Error("consumed invalid msgPack")
|
||||
continue
|
||||
}
|
||||
d.curTs.Store(pack.EndPositions[0].GetTimestamp())
|
||||
|
||||
// init packs for all targets, even though there's no msg in pack,
|
||||
// but we still need to dispatch time ticks to the targets.
|
||||
targetPacks := make(map[string]*MsgPack)
|
||||
for vchannel := range d.targets {
|
||||
targetPacks[vchannel] = &MsgPack{
|
||||
BeginTs: pack.BeginTs,
|
||||
EndTs: pack.EndTs,
|
||||
Msgs: make([]msgstream.TsMsg, 0),
|
||||
StartPositions: pack.StartPositions,
|
||||
EndPositions: pack.EndPositions,
|
||||
}
|
||||
}
|
||||
|
||||
// group messages by vchannel
|
||||
for _, msg := range pack.Msgs {
|
||||
if msg.VChannel() == "" {
|
||||
// for non-dml msg, such as CreateCollection, DropCollection, ...
|
||||
// we need to dispatch it to all the vchannels.
|
||||
for k := range targetPacks {
|
||||
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if _, ok := targetPacks[msg.VChannel()]; !ok {
|
||||
continue
|
||||
}
|
||||
targetPacks[msg.VChannel()].Msgs = append(targetPacks[msg.VChannel()].Msgs, msg)
|
||||
}
|
||||
|
||||
// dispatch messages, split target if block
|
||||
for vchannel, p := range targetPacks {
|
||||
t := d.targets[vchannel]
|
||||
if err := t.send(p); err != nil {
|
||||
t.pos = pack.StartPositions[0]
|
||||
d.lagTargets.LoadOrStore(t.vchannel, t)
|
||||
d.nonBlockingNotify()
|
||||
delete(d.targets, vchannel)
|
||||
log.Warn("lag target notified", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) nonBlockingNotify() {
|
||||
select {
|
||||
case d.lagNotifyChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
)
|
||||
|
||||
func TestDispatcher(t *testing.T) {
|
||||
t.Run("test base", func(t *testing.T) {
|
||||
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
|
||||
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotPanics(t, func() {
|
||||
d.Handle(start)
|
||||
d.Handle(pause)
|
||||
d.Handle(resume)
|
||||
d.Handle(terminate)
|
||||
})
|
||||
|
||||
pos := &msgstream.MsgPosition{
|
||||
ChannelName: "mock_vchannel_0",
|
||||
MsgGroup: "mock_msg_group",
|
||||
Timestamp: 100,
|
||||
}
|
||||
d.curTs.Store(pos.GetTimestamp())
|
||||
curTs := d.CurTs()
|
||||
assert.Equal(t, pos.Timestamp, curTs)
|
||||
})
|
||||
|
||||
t.Run("test target", func(t *testing.T) {
|
||||
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
|
||||
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
output := make(chan *msgstream.MsgPack, 1024)
|
||||
d.AddTarget(&target{
|
||||
vchannel: "mock_vchannel_0",
|
||||
pos: nil,
|
||||
ch: output,
|
||||
})
|
||||
d.AddTarget(&target{
|
||||
vchannel: "mock_vchannel_1",
|
||||
pos: nil,
|
||||
ch: nil,
|
||||
})
|
||||
num := d.TargetNum()
|
||||
assert.Equal(t, 2, num)
|
||||
|
||||
target, err := d.GetTarget("mock_vchannel_0")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cap(output), cap(target.ch))
|
||||
|
||||
d.CloseTarget("mock_vchannel_0")
|
||||
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
assert.Fail(t, "timeout, didn't receive close message")
|
||||
case _, ok := <-target.ch:
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
num = d.TargetNum()
|
||||
assert.Equal(t, 1, num)
|
||||
})
|
||||
|
||||
t.Run("test concurrent send and close", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
output := make(chan *msgstream.MsgPack, 1024)
|
||||
target := &target{
|
||||
vchannel: "mock_vchannel_0",
|
||||
pos: nil,
|
||||
ch: output,
|
||||
}
|
||||
assert.Equal(t, cap(output), cap(target.ch))
|
||||
wg := &sync.WaitGroup{}
|
||||
for j := 0; j < 100; j++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
err := target.send(&MsgPack{})
|
||||
assert.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
target.close()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDispatcher_handle(b *testing.B) {
|
||||
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
|
||||
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.Handle(start)
|
||||
d.Handle(pause)
|
||||
d.Handle(resume)
|
||||
d.Handle(terminate)
|
||||
}
|
||||
// BenchmarkDispatcher_handle-12 9568 122123 ns/op
|
||||
// PASS
|
||||
}
|
|
@ -0,0 +1,240 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
)
|
||||
|
||||
var (
|
||||
CheckPeriod = 1 * time.Second // TODO: dyh, move to config
|
||||
)
|
||||
|
||||
type DispatcherManager interface {
|
||||
Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
|
||||
Remove(vchannel string)
|
||||
Num() int
|
||||
Run()
|
||||
Close()
|
||||
}
|
||||
|
||||
var _ DispatcherManager = (*dispatcherManager)(nil)
|
||||
|
||||
type dispatcherManager struct {
|
||||
role string
|
||||
nodeID int64
|
||||
pchannel string
|
||||
|
||||
lagNotifyChan chan struct{}
|
||||
lagTargets *sync.Map // vchannel -> *target
|
||||
|
||||
mu sync.RWMutex // guards mainDispatcher and soloDispatchers
|
||||
mainDispatcher *Dispatcher
|
||||
soloDispatchers map[string]*Dispatcher // vchannel -> *Dispatcher
|
||||
|
||||
factory msgstream.Factory
|
||||
closeChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewDispatcherManager(pchannel string, role string, nodeID int64, factory msgstream.Factory) DispatcherManager {
|
||||
log.Info("create new dispatcherManager", zap.String("role", role),
|
||||
zap.Int64("nodeID", nodeID), zap.String("pchannel", pchannel))
|
||||
c := &dispatcherManager{
|
||||
role: role,
|
||||
nodeID: nodeID,
|
||||
pchannel: pchannel,
|
||||
lagNotifyChan: make(chan struct{}, 1),
|
||||
lagTargets: &sync.Map{},
|
||||
soloDispatchers: make(map[string]*Dispatcher),
|
||||
factory: factory,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) string {
|
||||
return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
isMain := c.mainDispatcher == nil
|
||||
d, err := NewDispatcher(c.factory, isMain, c.pchannel, pos,
|
||||
c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := newTarget(vchannel, pos)
|
||||
d.AddTarget(t)
|
||||
if isMain {
|
||||
c.mainDispatcher = d
|
||||
log.Info("add main dispatcher")
|
||||
} else {
|
||||
c.soloDispatchers[vchannel] = d
|
||||
log.Info("add solo dispatcher")
|
||||
}
|
||||
d.Handle(start)
|
||||
return t.ch, nil
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) Remove(vchannel string) {
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mainDispatcher != nil {
|
||||
c.mainDispatcher.Handle(pause)
|
||||
c.mainDispatcher.CloseTarget(vchannel)
|
||||
if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 {
|
||||
c.mainDispatcher.Handle(terminate)
|
||||
c.mainDispatcher = nil
|
||||
} else {
|
||||
c.mainDispatcher.Handle(resume)
|
||||
}
|
||||
}
|
||||
if _, ok := c.soloDispatchers[vchannel]; ok {
|
||||
c.soloDispatchers[vchannel].Handle(terminate)
|
||||
c.soloDispatchers[vchannel].CloseTarget(vchannel)
|
||||
delete(c.soloDispatchers, vchannel)
|
||||
log.Info("remove soloDispatcher done")
|
||||
}
|
||||
c.lagTargets.Delete(vchannel)
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) Num() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
var res int
|
||||
if c.mainDispatcher != nil {
|
||||
res++
|
||||
}
|
||||
return res + len(c.soloDispatchers)
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeChan <- struct{}{}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) Run() {
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
|
||||
log.Info("dispatcherManager is running...")
|
||||
ticker := time.NewTicker(CheckPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
log.Info("dispatcherManager exited")
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.tryMerge()
|
||||
case <-c.lagNotifyChan:
|
||||
c.mu.Lock()
|
||||
c.lagTargets.Range(func(vchannel, t any) bool {
|
||||
c.split(t.(*target))
|
||||
c.lagTargets.Delete(vchannel)
|
||||
return true
|
||||
})
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) tryMerge() {
|
||||
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID))
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mainDispatcher == nil {
|
||||
return
|
||||
}
|
||||
candidates := make(map[string]struct{})
|
||||
for vchannel, sd := range c.soloDispatchers {
|
||||
if sd.CurTs() == c.mainDispatcher.CurTs() {
|
||||
candidates[vchannel] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("start merging...", zap.Any("vchannel", candidates))
|
||||
c.mainDispatcher.Handle(pause)
|
||||
for vchannel := range candidates {
|
||||
c.soloDispatchers[vchannel].Handle(pause)
|
||||
// after pause, check alignment again, if not, evict it and try to merge next time
|
||||
if c.mainDispatcher.CurTs() != c.soloDispatchers[vchannel].CurTs() {
|
||||
c.soloDispatchers[vchannel].Handle(resume)
|
||||
delete(candidates, vchannel)
|
||||
}
|
||||
}
|
||||
for vchannel := range candidates {
|
||||
t, err := c.soloDispatchers[vchannel].GetTarget(vchannel)
|
||||
if err == nil {
|
||||
c.mainDispatcher.AddTarget(t)
|
||||
}
|
||||
c.soloDispatchers[vchannel].Handle(terminate)
|
||||
delete(c.soloDispatchers, vchannel)
|
||||
}
|
||||
c.mainDispatcher.Handle(resume)
|
||||
log.Info("merge done", zap.Any("vchannel", candidates))
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) split(t *target) {
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", t.vchannel))
|
||||
log.Info("start splitting...")
|
||||
|
||||
// remove stale soloDispatcher if it existed
|
||||
if _, ok := c.soloDispatchers[t.vchannel]; ok {
|
||||
c.soloDispatchers[t.vchannel].Handle(terminate)
|
||||
delete(c.soloDispatchers, t.vchannel)
|
||||
}
|
||||
|
||||
var newSolo *Dispatcher
|
||||
err := retry.Do(context.Background(), func() error {
|
||||
var err error
|
||||
newSolo, err = NewDispatcher(c.factory, false, c.pchannel, t.pos,
|
||||
c.constructSubName(t.vchannel, false), mqwrapper.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets)
|
||||
return err
|
||||
}, retry.Attempts(10))
|
||||
if err != nil {
|
||||
log.Error("split failed", zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
newSolo.AddTarget(t)
|
||||
c.soloDispatchers[t.vchannel] = newSolo
|
||||
newSolo.Handle(start)
|
||||
log.Info("split done")
|
||||
}
|
|
@ -0,0 +1,354 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func TestManager(t *testing.T) {
|
||||
t.Run("test add and remove dispatcher", func(t *testing.T) {
|
||||
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
assert.Equal(t, 0, c.Num())
|
||||
|
||||
var offset int
|
||||
for i := 0; i < 100; i++ {
|
||||
r := rand.Intn(10) + 1
|
||||
for j := 0; j < r; j++ {
|
||||
offset++
|
||||
_, err := c.Add(fmt.Sprintf("mock_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, offset, c.Num())
|
||||
}
|
||||
for j := 0; j < rand.Intn(r); j++ {
|
||||
c.Remove(fmt.Sprintf("mock_vchannel_%d", offset))
|
||||
offset--
|
||||
assert.Equal(t, offset, c.Num())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test merge and split", func(t *testing.T) {
|
||||
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, c.Num())
|
||||
|
||||
c.(*dispatcherManager).tryMerge()
|
||||
assert.Equal(t, 1, c.Num())
|
||||
|
||||
info := &target{
|
||||
vchannel: "mock_vchannel_2",
|
||||
pos: nil,
|
||||
ch: nil,
|
||||
}
|
||||
c.(*dispatcherManager).split(info)
|
||||
assert.Equal(t, 2, c.Num())
|
||||
})
|
||||
|
||||
t.Run("test run and close", func(t *testing.T) {
|
||||
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, c.Num())
|
||||
|
||||
CheckPeriod = 10 * time.Millisecond
|
||||
go c.Run()
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
assert.Equal(t, 1, c.Num()) // expected merged
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
c.Close()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type vchannelHelper struct {
|
||||
output <-chan *msgstream.MsgPack
|
||||
|
||||
pubInsMsgNum int
|
||||
pubDelMsgNum int
|
||||
pubDDLMsgNum int
|
||||
pubPackNum int
|
||||
|
||||
subInsMsgNum int
|
||||
subDelMsgNum int
|
||||
subDDLMsgNum int
|
||||
subPackNum int
|
||||
}
|
||||
|
||||
type SimulationSuite struct {
|
||||
suite.Suite
|
||||
|
||||
testVchannelNum int
|
||||
|
||||
manager DispatcherManager
|
||||
pchannel string
|
||||
vchannels map[string]*vchannelHelper
|
||||
|
||||
producer msgstream.MsgStream
|
||||
factory msgstream.Factory
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) SetupSuite() {
|
||||
suite.factory = newMockFactory()
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) SetupTest() {
|
||||
suite.pchannel = fmt.Sprintf("by-dev-rootcoord-dispatcher-simulation-dml-%d-%d", rand.Int(), time.Now().UnixNano())
|
||||
producer, err := newMockProducer(suite.factory, suite.pchannel)
|
||||
assert.NoError(suite.T(), err)
|
||||
suite.producer = producer
|
||||
|
||||
suite.manager = NewDispatcherManager(suite.pchannel, typeutil.DataNodeRole, 0, suite.factory)
|
||||
CheckPeriod = 10 * time.Millisecond
|
||||
go suite.manager.Run()
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
const timeTickCount = 200
|
||||
var uniqueMsgID int64
|
||||
vchannelKeys := reflect.ValueOf(suite.vchannels).MapKeys()
|
||||
|
||||
for i := 1; i <= timeTickCount; i++ {
|
||||
// produce random insert
|
||||
insNum := rand.Intn(10)
|
||||
for j := 0; j < insNum; j++ {
|
||||
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
uniqueMsgID++
|
||||
suite.vchannels[vchannel].pubInsMsgNum++
|
||||
}
|
||||
// produce random delete
|
||||
delNum := rand.Intn(2)
|
||||
for j := 0; j < delNum; j++ {
|
||||
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
uniqueMsgID++
|
||||
suite.vchannels[vchannel].pubDelMsgNum++
|
||||
}
|
||||
// produce random ddl
|
||||
ddlNum := rand.Intn(2)
|
||||
for j := 0; j < ddlNum; j++ {
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
for k := range suite.vchannels {
|
||||
suite.vchannels[k].pubDDLMsgNum++
|
||||
}
|
||||
}
|
||||
// produce time tick
|
||||
ts := uint64(i * 100)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
for k := range suite.vchannels {
|
||||
suite.vchannels[k].pubPackNum++
|
||||
}
|
||||
}
|
||||
suite.T().Logf("[%s] produce %d msgPack for %s done", time.Now(), timeTickCount, suite.pchannel)
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) consumeMsg(ctx context.Context, wg *sync.WaitGroup, vchannel string) {
|
||||
defer wg.Done()
|
||||
var lastTs typeutil.Timestamp
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(2000 * time.Millisecond): // no message to consume
|
||||
return
|
||||
case pack := <-suite.vchannels[vchannel].output:
|
||||
assert.Greater(suite.T(), pack.EndTs, lastTs)
|
||||
lastTs = pack.EndTs
|
||||
helper := suite.vchannels[vchannel]
|
||||
helper.subPackNum++
|
||||
for _, msg := range pack.Msgs {
|
||||
switch msg.Type() {
|
||||
case commonpb.MsgType_Insert:
|
||||
helper.subInsMsgNum++
|
||||
case commonpb.MsgType_Delete:
|
||||
helper.subDelMsgNum++
|
||||
case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection,
|
||||
commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition:
|
||||
helper.subDDLMsgNum++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
|
||||
var tt = 1
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
ts := uint64(tt * 1000)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
tt++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) TestDispatchToVchannels() {
|
||||
const vchannelNum = 20
|
||||
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
|
||||
for i := 0; i < vchannelNum; i++ {
|
||||
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
|
||||
output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
|
||||
assert.NoError(suite.T(), err)
|
||||
suite.vchannels[vchannel] = &vchannelHelper{output: output}
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go suite.produceMsg(wg)
|
||||
wg.Wait()
|
||||
for vchannel := range suite.vchannels {
|
||||
wg.Add(1)
|
||||
go suite.consumeMsg(context.Background(), wg, vchannel)
|
||||
}
|
||||
wg.Wait()
|
||||
for _, helper := range suite.vchannels {
|
||||
assert.Equal(suite.T(), helper.pubInsMsgNum, helper.subInsMsgNum)
|
||||
assert.Equal(suite.T(), helper.pubDelMsgNum, helper.subDelMsgNum)
|
||||
assert.Equal(suite.T(), helper.pubDDLMsgNum, helper.subDDLMsgNum)
|
||||
assert.Equal(suite.T(), helper.pubPackNum, helper.subPackNum)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *SimulationSuite) TestMerge() {
	ctx, cancel := context.WithCancel(context.Background())
	go suite.produceTimeTickOnly(ctx)

	const vchannelNum = 20
	suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
	positions, err := getSeekPositions(suite.factory, suite.pchannel, 200)
	assert.NoError(suite.T(), err)

	for i := 0; i < vchannelNum; i++ {
		vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
		output, err := suite.manager.Add(vchannel, positions[rand.Intn(len(positions))],
			mqwrapper.SubscriptionPositionUnknown) // seek from a random position
		assert.NoError(suite.T(), err)
		suite.vchannels[vchannel] = &vchannelHelper{output: output}
	}
	wg := &sync.WaitGroup{}
	for vchannel := range suite.vchannels {
		wg.Add(1)
		go suite.consumeMsg(ctx, wg, vchannel)
	}

	suite.Eventually(func() bool {
		suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.Num())
		return suite.manager.Num() == 1 // expect all dispatchers merged, only the mainDispatcher exists
	}, 10*time.Second, 100*time.Millisecond)

	cancel()
	wg.Wait()
}

func (suite *SimulationSuite) TestSplit() {
	ctx, cancel := context.WithCancel(context.Background())
	go suite.produceTimeTickOnly(ctx)

	const vchannelNum = 10
	suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
	DefaultTargetChanSize = 10
	MaxTolerantLag = 500 * time.Millisecond
	for i := 0; i < vchannelNum; i++ {
		vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
		output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
		assert.NoError(suite.T(), err)
		suite.vchannels[vchannel] = &vchannelHelper{output: output}
	}

	const splitNum = 3
	wg := &sync.WaitGroup{}
	counter := 0
	for vchannel := range suite.vchannels {
		wg.Add(1)
		go suite.consumeMsg(ctx, wg, vchannel)
		counter++
		if counter >= len(suite.vchannels)-splitNum {
			break
		}
	}

	suite.Eventually(func() bool {
		suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.Num(), splitNum+1)
		return suite.manager.Num() == splitNum+1 // expect 1 mainDispatcher and `splitNum` soloDispatchers
	}, 10*time.Second, 100*time.Millisecond)

	cancel()
	wg.Wait()
}

func (suite *SimulationSuite) TearDownTest() {
	for vchannel := range suite.vchannels {
		suite.manager.Remove(vchannel)
	}
	suite.manager.Close()
}

func (suite *SimulationSuite) TearDownSuite() {

}

func TestSimulation(t *testing.T) {
	suite.Run(t, new(SimulationSuite))
}
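TestMerge and TestSplit pin down the dispatcher's intended behavior: vchannels that keep up are all served by a single mainDispatcher, while vchannels that lag are split out into their own soloDispatchers. A minimal standalone sketch of that idea, using only illustrative names and none of the real msgdispatcher types:

package main

import (
	"fmt"
	"time"
)

// pack stands in for a dispatched MsgPack.
type pack struct{ seq int }

// fanOut delivers one pack to every target channel and returns the names of
// targets that could not accept it within the tolerated lag. In the real
// dispatcher those lagging vchannels are handed to soloDispatchers that
// consume independently and are merged back once they catch up.
func fanOut(p pack, targets map[string]chan pack, tolerance time.Duration) []string {
	var lagging []string
	for name, ch := range targets {
		select {
		case ch <- p:
		case <-time.After(tolerance):
			lagging = append(lagging, name)
		}
	}
	return lagging
}

func main() {
	targets := map[string]chan pack{
		"vchannel-0": make(chan pack, 4),
		"vchannel-1": make(chan pack, 4),
	}
	// Only vchannel-0 is drained, so vchannel-1 eventually lags behind.
	go func() {
		for range targets["vchannel-0"] {
		}
	}()
	for i := 0; i < 16; i++ {
		for _, name := range fanOut(pack{seq: i}, targets, 50*time.Millisecond) {
			fmt.Println("would split out:", name)
			delete(targets, name) // hand the vchannel over to its own solo dispatcher
		}
	}
}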
|
|
@ -0,0 +1,213 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher

import (
	"context"
	"fmt"
	"math/rand"
	"time"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/schemapb"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/util/paramtable"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

const (
	dim = 128
)

func newMockFactory() msgstream.Factory {
	paramtable.Init()
	return msgstream.NewRmsFactory("/tmp/milvus/rocksmq/")
}

func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgStream, error) {
	stream, err := factory.NewMsgStream(context.Background())
	if err != nil {
		return nil, err
	}
	stream.AsProducer([]string{pchannel})
	stream.SetRepackFunc(defaultInsertRepackFunc)
	return stream, nil
}

func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([]*msgstream.MsgPosition, error) {
	stream, err := factory.NewTtMsgStream(context.Background())
	if err != nil {
		return nil, err
	}
	defer stream.Close()
	stream.AsConsumer([]string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest)
	positions := make([]*msgstream.MsgPosition, 0)
	for {
		select {
		case <-time.After(100 * time.Millisecond): // no message to consume
			return positions, nil
		case pack := <-stream.Chan():
			positions = append(positions, pack.EndPositions[0])
			if len(positions) >= maxNum {
				return positions, nil
			}
		}
	}
}
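getSeekPositions drains the tt stream until it has collected maxNum positions or the stream stays quiet for 100 ms. The same drain-until-idle pattern, reduced to plain ints so the sketch runs on its own (names are illustrative):

package main

import (
	"fmt"
	"time"
)

// collectUntilIdle mirrors the loop in getSeekPositions: keep reading values
// until either maxNum items were gathered or the source stays quiet for the
// idle window.
func collectUntilIdle(src <-chan int, maxNum int, idle time.Duration) []int {
	out := make([]int, 0, maxNum)
	for {
		select {
		case <-time.After(idle):
			return out // no message to consume
		case v := <-src:
			out = append(out, v)
			if len(out) >= maxNum {
				return out
			}
		}
	}
}

func main() {
	src := make(chan int, 8)
	for i := 0; i < 5; i++ {
		src <- i
	}
	fmt.Println(collectUntilIdle(src, 200, 100*time.Millisecond)) // [0 1 2 3 4]
}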
|
||||
|
||||
func genPKs(numRows int) []typeutil.IntPrimaryKey {
|
||||
ids := make([]typeutil.IntPrimaryKey, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ids[i] = typeutil.IntPrimaryKey(i)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func genTimestamps(numRows int) []typeutil.Timestamp {
|
||||
ts := make([]typeutil.Timestamp, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ts[i] = typeutil.Timestamp(i + 1)
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.InsertMsg {
|
||||
floatVec := make([]float32, numRows*dim)
|
||||
for i := 0; i < numRows*dim; i++ {
|
||||
floatVec[i] = rand.Float32()
|
||||
}
|
||||
hashValues := make([]uint32, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
hashValues[i] = uint32(1)
|
||||
}
|
||||
return &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: hashValues},
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID},
|
||||
ShardName: vchannel,
|
||||
Timestamps: genTimestamps(numRows),
|
||||
RowIDs: genPKs(numRows),
|
||||
FieldsData: []*schemapb.FieldData{{
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: floatVec}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
NumRows: uint64(numRows),
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.DeleteMsg {
|
||||
return &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: make([]uint32, numRows)},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID},
|
||||
ShardName: vchannel,
|
||||
PrimaryKeys: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: genPKs(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
Timestamps: genTimestamps(numRows),
|
||||
NumRows: int64(numRows),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func genDDLMsg(msgType commonpb.MsgType) msgstream.TsMsg {
|
||||
switch msgType {
|
||||
case commonpb.MsgType_CreateCollection:
|
||||
return &msgstream.CreateCollectionMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
CreateCollectionRequest: internalpb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
|
||||
},
|
||||
}
|
||||
case commonpb.MsgType_DropCollection:
|
||||
return &msgstream.DropCollectionMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
DropCollectionRequest: internalpb.DropCollectionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
|
||||
},
|
||||
}
|
||||
case commonpb.MsgType_CreatePartition:
|
||||
return &msgstream.CreatePartitionMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
CreatePartitionRequest: internalpb.CreatePartitionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition},
|
||||
},
|
||||
}
|
||||
case commonpb.MsgType_DropPartition:
|
||||
return &msgstream.DropPartitionMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
DropPartitionRequest: internalpb.DropPartitionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func genTimeTickMsg(ts typeutil.Timestamp) *msgstream.TimeTickMsg {
|
||||
return &msgstream.TimeTickMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
TimeTickMsg: internalpb.TimeTickMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_TimeTick,
|
||||
Timestamp: ts,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultInsertRepackFunc repacks the dml messages.
func defaultInsertRepackFunc(
	tsMsgs []msgstream.TsMsg,
	hashKeys [][]int32,
) (map[int32]*msgstream.MsgPack, error) {

	if len(hashKeys) < len(tsMsgs) {
		return nil, fmt.Errorf(
			"the length of hash keys (%d) is less than the length of messages (%d)",
			len(hashKeys),
			len(tsMsgs),
		)
	}

	// after assigning segment id to msg, tsMsgs was already re-bucketed
	pack := make(map[int32]*msgstream.MsgPack)
	for idx, msg := range tsMsgs {
		if len(hashKeys[idx]) <= 0 {
			return nil, fmt.Errorf("no hash key for %dth message", idx)
		}
		key := hashKeys[idx][0]
		_, ok := pack[key]
		if !ok {
			pack[key] = &msgstream.MsgPack{}
		}
		pack[key].Msgs = append(pack[key].Msgs, msg)
	}
	return pack, nil
}
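defaultInsertRepackFunc routes each message by the first of its hash keys, so messages sharing a key land in the same MsgPack. A standalone sketch of the same bucketing rule over plain strings, not the real msgstream types:

package main

import (
	"errors"
	"fmt"
)

// repackByFirstKey groups messages by the first of their hash keys, the same
// rule defaultInsertRepackFunc applies to msgstream.TsMsg values.
func repackByFirstKey(msgs []string, hashKeys [][]int32) (map[int32][]string, error) {
	if len(hashKeys) < len(msgs) {
		return nil, errors.New("fewer hash keys than messages")
	}
	packs := make(map[int32][]string)
	for i, msg := range msgs {
		if len(hashKeys[i]) == 0 {
			return nil, fmt.Errorf("no hash key for message %d", i)
		}
		key := hashKeys[i][0]
		packs[key] = append(packs[key], msg)
	}
	return packs, nil
}

func main() {
	packs, err := repackByFirstKey(
		[]string{"insert-a", "insert-b", "insert-c"},
		[][]int32{{0}, {1}, {0}},
	)
	fmt.Println(packs, err) // map[0:[insert-a insert-c] 1:[insert-b]] <nil>
}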
|
|
@ -0,0 +1,72 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TODO: dyh, move to config
var (
	MaxTolerantLag        = 3 * time.Second
	DefaultTargetChanSize = 1024
)

type target struct {
	vchannel string
	ch       chan *MsgPack
	pos      *Pos

	closeMu   sync.Mutex
	closeOnce sync.Once
	closed    bool
}

func newTarget(vchannel string, pos *Pos) *target {
	t := &target{
		vchannel: vchannel,
		ch:       make(chan *MsgPack, DefaultTargetChanSize),
		pos:      pos,
	}
	t.closed = false
	return t
}

func (t *target) close() {
	t.closeMu.Lock()
	defer t.closeMu.Unlock()
	t.closeOnce.Do(func() {
		t.closed = true
		close(t.ch)
	})
}

func (t *target) send(pack *MsgPack) error {
	t.closeMu.Lock()
	defer t.closeMu.Unlock()
	if t.closed {
		return nil
	}
	select {
	case <-time.After(MaxTolerantLag):
		return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s", t.vchannel, MaxTolerantLag)
	case t.ch <- pack:
		return nil
	}
}
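send and close share closeMu, so a send can never race with close and panic on an already closed channel, and a send that cannot complete within MaxTolerantLag reports a timeout, which is the signal used to split a lagging vchannel out. A small standalone sketch of this guarded close/send discipline, with illustrative names rather than the dispatcher's API:

package main

import (
	"fmt"
	"sync"
)

// guardedChan mirrors the close/send discipline of target above: send and
// close share a mutex, so a send never hits an already closed channel.
type guardedChan struct {
	mu     sync.Mutex
	once   sync.Once
	closed bool
	ch     chan int
}

func (g *guardedChan) send(v int) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.closed {
		return false // silently dropped, just like target.send after close
	}
	g.ch <- v
	return true
}

func (g *guardedChan) close() {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.once.Do(func() {
		g.closed = true
		close(g.ch)
	})
}

func main() {
	g := &guardedChan{ch: make(chan int, 1)}
	fmt.Println(g.send(1)) // true
	g.close()
	fmt.Println(g.send(2)) // false, no panic
}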
|
|
@ -53,6 +53,7 @@ type TsMsg interface {
|
|||
Unmarshal(MarshalType) (TsMsg, error)
|
||||
Position() *MsgPosition
|
||||
SetPosition(*MsgPosition)
|
||||
VChannel() string
|
||||
}
|
||||
|
||||
// BaseMsg is a basic structure that contains begin timestamp, end timestamp and the position of msgstream
|
||||
|
@ -62,6 +63,7 @@ type BaseMsg struct {
|
|||
EndTimestamp Timestamp
|
||||
HashValues []uint32
|
||||
MsgPosition *MsgPosition
|
||||
Vchannel string
|
||||
}
|
||||
|
||||
// TraceCtx returns the context of opentracing
|
||||
|
@ -99,6 +101,10 @@ func (bm *BaseMsg) SetPosition(position *MsgPosition) {
|
|||
bm.MsgPosition = position
|
||||
}
|
||||
|
||||
func (bm *BaseMsg) VChannel() string {
|
||||
return bm.Vchannel
|
||||
}
|
||||
|
||||
func convertToByteArray(input interface{}) ([]byte, error) {
|
||||
switch output := input.(type) {
|
||||
case []byte:
|
||||
|
@ -170,6 +176,7 @@ func (it *InsertMsg) Unmarshal(input MarshalType) (TsMsg, error) {
|
|||
insertMsg.BeginTimestamp = timestamp
|
||||
}
|
||||
}
|
||||
insertMsg.Vchannel = insertMsg.ShardName
|
||||
|
||||
return insertMsg, nil
|
||||
}
|
||||
|
@ -278,6 +285,7 @@ func (it *InsertMsg) IndexMsg(index int) *InsertMsg {
|
|||
Ctx: it.TraceCtx(),
|
||||
BeginTimestamp: it.BeginTimestamp,
|
||||
EndTimestamp: it.EndTimestamp,
|
||||
Vchannel: it.Vchannel,
|
||||
HashValues: it.HashValues,
|
||||
MsgPosition: it.MsgPosition,
|
||||
},
|
||||
|
@ -361,7 +369,7 @@ func (dt *DeleteMsg) Unmarshal(input MarshalType) (TsMsg, error) {
|
|||
deleteMsg.BeginTimestamp = timestamp
|
||||
}
|
||||
}
|
||||
|
||||
deleteMsg.Vchannel = deleteRequest.ShardName
|
||||
return deleteMsg, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -148,6 +148,7 @@ func assignSegmentID(ctx context.Context, insertMsg *msgstream.InsertMsg, result
|
|||
msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
|
||||
msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
|
||||
msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])
|
||||
msg.BaseMsg.Vchannel = channelName
|
||||
msg.NumRows++
|
||||
requestSize += curRowMessageSize
|
||||
}
|
||||
|
|
|
@ -268,6 +268,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
partitionName := dt.deleteMsg.PartitionName
|
||||
proxyID := dt.deleteMsg.Base.SourceID
|
||||
for index, key := range dt.deleteMsg.HashValues {
|
||||
vchannel := channelNames[key]
|
||||
ts := dt.deleteMsg.Timestamps[index]
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
|
@ -297,6 +298,8 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index])
|
||||
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index)
|
||||
curMsg.NumRows++
|
||||
curMsg.ShardName = vchannel
|
||||
curMsg.Vchannel = vchannel
|
||||
}
|
||||
|
||||
// send delete request to log broker
|
||||
|
|
|
@ -439,6 +439,7 @@ func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgP
|
|||
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
|
||||
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
|
||||
curMsg.NumRows++
|
||||
curMsg.ShardName = channelNames[key]
|
||||
}
|
||||
|
||||
// send delete request to log broker
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
|
@ -40,6 +41,7 @@ type dataSyncService struct {
|
|||
|
||||
metaReplica ReplicaInterface
|
||||
tSafeReplica TSafeReplicaInterface
|
||||
dispClient msgdispatcher.Client
|
||||
msFactory msgstream.Factory
|
||||
}
|
||||
|
||||
|
@ -51,7 +53,7 @@ func (dsService *dataSyncService) getFlowGraphNum() int {
|
|||
}
|
||||
|
||||
// addFlowGraphsForDMLChannels add flowGraphs to dmlChannel2FlowGraph
|
||||
func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels []string) (map[string]*queryNodeFlowGraph, error) {
|
||||
func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels map[string]*msgstream.MsgPosition) (map[string]*queryNodeFlowGraph, error) {
|
||||
dsService.mu.Lock()
|
||||
defer dsService.mu.Unlock()
|
||||
|
||||
|
@ -61,7 +63,7 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
|
|||
}
|
||||
|
||||
results := make(map[string]*queryNodeFlowGraph)
|
||||
for _, channel := range dmlChannels {
|
||||
for channel, position := range dmlChannels {
|
||||
if _, ok := dsService.dmlChannel2FlowGraph[channel]; ok {
|
||||
log.Warn("dml flow graph has been existed",
|
||||
zap.Any("collectionID", collectionID),
|
||||
|
@ -74,7 +76,8 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
|
|||
dsService.metaReplica,
|
||||
dsService.tSafeReplica,
|
||||
channel,
|
||||
dsService.msFactory)
|
||||
position,
|
||||
dsService.dispClient)
|
||||
if err != nil {
|
||||
for _, fg := range results {
|
||||
fg.flowGraph.Close()
|
||||
|
@ -128,7 +131,7 @@ func (dsService *dataSyncService) addFlowGraphsForDeltaChannels(collectionID Uni
|
|||
dsService.metaReplica,
|
||||
dsService.tSafeReplica,
|
||||
channel,
|
||||
dsService.msFactory)
|
||||
dsService.dispClient)
|
||||
if err != nil {
|
||||
for channel, fg := range results {
|
||||
fg.flowGraph.Close()
|
||||
|
@ -291,6 +294,7 @@ func (dsService *dataSyncService) removeEmptyFlowGraphByChannel(collectionID int
|
|||
func newDataSyncService(ctx context.Context,
|
||||
metaReplica ReplicaInterface,
|
||||
tSafeReplica TSafeReplicaInterface,
|
||||
dispClient msgdispatcher.Client,
|
||||
factory msgstream.Factory) *dataSyncService {
|
||||
|
||||
return &dataSyncService{
|
||||
|
@ -299,6 +303,7 @@ func newDataSyncService(ctx context.Context,
|
|||
deltaChannel2FlowGraph: make(map[Channel]*queryNodeFlowGraph),
|
||||
metaReplica: metaReplica,
|
||||
tSafeReplica: tSafeReplica,
|
||||
dispClient: dispClient,
|
||||
msFactory: factory,
|
||||
}
|
||||
}
|
||||
|
@ -308,6 +313,7 @@ func (dsService *dataSyncService) close() {
|
|||
// close DML flow graphs
|
||||
for channel, nodeFG := range dsService.dmlChannel2FlowGraph {
|
||||
if nodeFG != nil {
|
||||
dsService.dispClient.Deregister(channel)
|
||||
nodeFG.flowGraph.Close()
|
||||
}
|
||||
delete(dsService.dmlChannel2FlowGraph, channel)
|
||||
|
@ -315,6 +321,7 @@ func (dsService *dataSyncService) close() {
|
|||
// close delta flow graphs
|
||||
for channel, nodeFG := range dsService.deltaChannel2FlowGraph {
|
||||
if nodeFG != nil {
|
||||
dsService.dispClient.Deregister(channel)
|
||||
nodeFG.flowGraph.Close()
|
||||
}
|
||||
delete(dsService.deltaChannel2FlowGraph, channel)
|
||||
|
|
|
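Both teardown loops now deregister the vchannel from the dispatcher client before closing the flow graph, so the shared pchannel consumer can release that target. A rough standalone sketch of the register/deregister lifecycle; the registry type is illustrative, and the real client additionally merges feeds that share a pchannel:

package main

import "fmt"

// registry approximates the msgdispatcher client from the flow graph's point
// of view: Register hands back a feed for a vchannel, Deregister releases it.
type registry struct {
	feeds map[string]chan struct{}
}

func (r *registry) Register(vchannel string) <-chan struct{} {
	ch := make(chan struct{})
	r.feeds[vchannel] = ch
	return ch
}

func (r *registry) Deregister(vchannel string) {
	if ch, ok := r.feeds[vchannel]; ok {
		close(ch)
		delete(r.feeds, vchannel)
	}
}

func main() {
	r := &registry{feeds: make(map[string]chan struct{})}
	flowGraphs := map[string]bool{"dml-vchannel-0": true, "delta-vchannel-0": true}
	for vchannel := range flowGraphs {
		_ = r.Register(vchannel)
	}
	// Mirrors dataSyncService.close(): deregister first, then drop the flow graph.
	for vchannel := range flowGraphs {
		r.Deregister(vchannel)
		delete(flowGraphs, vchannel)
	}
	fmt.Println(len(r.feeds), len(flowGraphs)) // 0 0
}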
@ -21,14 +21,20 @@ import (
|
|||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rateCol, _ = newRateCollector()
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
|
||||
|
@ -40,17 +46,18 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
|
|||
|
||||
fac := genFactory()
|
||||
assert.NoError(t, err)
|
||||
dispClient := msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||
|
||||
tSafe := newTSafeReplica()
|
||||
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
|
||||
dataSyncService := newDataSyncService(ctx, replica, tSafe, dispClient, fac)
|
||||
assert.NotNil(t, dataSyncService)
|
||||
|
||||
t.Run("test DMLFlowGraphs", func(t *testing.T) {
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
|
||||
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
|
||||
|
||||
|
@ -68,7 +75,7 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
|
|||
assert.Nil(t, fg)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
|
||||
|
||||
|
@ -88,7 +95,7 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
|
|||
t.Run("test addFlowGraphsForDMLChannels checkReplica Failed", func(t *testing.T) {
|
||||
err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
|
||||
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
|
||||
assert.Error(t, err)
|
||||
dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
|
||||
})
|
||||
|
@ -103,9 +110,10 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
|
|||
|
||||
fac := genFactory()
|
||||
assert.NoError(t, err)
|
||||
dispClient := msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||
|
||||
tSafe := newTSafeReplica()
|
||||
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
|
||||
dataSyncService := newDataSyncService(ctx, replica, tSafe, dispClient, fac)
|
||||
assert.NotNil(t, dataSyncService)
|
||||
|
||||
t.Run("test DeltaFlowGraphs", func(t *testing.T) {
|
||||
|
@ -160,12 +168,14 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
|
|||
|
||||
type DataSyncServiceSuite struct {
|
||||
suite.Suite
|
||||
factory dependency.Factory
|
||||
dsService *dataSyncService
|
||||
dispClient msgdispatcher.Client
|
||||
factory dependency.Factory
|
||||
dsService *dataSyncService
|
||||
}
|
||||
|
||||
func (s *DataSyncServiceSuite) SetupSuite() {
|
||||
s.factory = genFactory()
|
||||
s.dispClient = msgdispatcher.NewClient(s.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||
}
|
||||
|
||||
func (s *DataSyncServiceSuite) SetupTest() {
|
||||
|
@ -176,7 +186,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
tSafe := newTSafeReplica()
|
||||
s.dsService = newDataSyncService(ctx, replica, tSafe, s.factory)
|
||||
s.dsService = newDataSyncService(ctx, replica, tSafe, s.dispClient, s.factory)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,22 +18,20 @@ package querynode
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/flowgraph"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -49,9 +47,9 @@ type queryNodeFlowGraph struct {
|
|||
collectionID UniqueID
|
||||
vchannel Channel
|
||||
flowGraph *flowgraph.TimeTickedFlowGraph
|
||||
dmlStream msgstream.MsgStream
|
||||
tSafeReplica TSafeReplicaInterface
|
||||
consumerCnt int
|
||||
dispClient msgdispatcher.Client
|
||||
}
|
||||
|
||||
// newQueryNodeFlowGraph returns a new queryNodeFlowGraph
|
||||
|
@ -60,16 +58,18 @@ func newQueryNodeFlowGraph(ctx context.Context,
|
|||
metaReplica ReplicaInterface,
|
||||
tSafeReplica TSafeReplicaInterface,
|
||||
vchannel Channel,
|
||||
factory msgstream.Factory) (*queryNodeFlowGraph, error) {
|
||||
pos *msgstream.MsgPosition,
|
||||
dispClient msgdispatcher.Client) (*queryNodeFlowGraph, error) {
|
||||
|
||||
q := &queryNodeFlowGraph{
|
||||
collectionID: collectionID,
|
||||
vchannel: vchannel,
|
||||
tSafeReplica: tSafeReplica,
|
||||
flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx),
|
||||
dispClient: dispClient,
|
||||
}
|
||||
|
||||
dmStreamNode, err := q.newDmInputNode(ctx, factory, collectionID, vchannel, metrics.InsertLabel)
|
||||
dmStreamNode, err := q.newDmInputNode(collectionID, vchannel, pos, metrics.InsertLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -123,16 +123,18 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
|
|||
metaReplica ReplicaInterface,
|
||||
tSafeReplica TSafeReplicaInterface,
|
||||
vchannel Channel,
|
||||
factory msgstream.Factory) (*queryNodeFlowGraph, error) {
|
||||
dispClient msgdispatcher.Client) (*queryNodeFlowGraph, error) {
|
||||
|
||||
q := &queryNodeFlowGraph{
|
||||
collectionID: collectionID,
|
||||
vchannel: vchannel,
|
||||
tSafeReplica: tSafeReplica,
|
||||
flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx),
|
||||
dispClient: dispClient,
|
||||
}
|
||||
|
||||
dmStreamNode, err := q.newDmInputNode(ctx, factory, collectionID, vchannel, metrics.DeleteLabel)
|
||||
// use nil position, let deltaFlowGraph consume from latest.
|
||||
dmStreamNode, err := q.newDmInputNode(collectionID, vchannel, nil, metrics.DeleteLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -184,84 +186,45 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
|
|||
}
|
||||
|
||||
// newDmInputNode returns a new inputNode

func (q *queryNodeFlowGraph) newDmInputNode(ctx context.Context, factory msgstream.Factory, collectionID UniqueID, vchannel Channel, dataType string) (*flowgraph.InputNode, error) {
	insertStream, err := factory.NewTtMsgStream(ctx)
	if err != nil {
		return nil, err
func (q *queryNodeFlowGraph) newDmInputNode(collectionID UniqueID, vchannel Channel, pos *msgstream.MsgPosition, dataType string) (*flowgraph.InputNode, error) {
	log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
		zap.Int64("collection ID", collectionID),
		zap.String("vchannel", vchannel))
	var err error
	var input <-chan *msgstream.MsgPack
	tsBegin := time.Now()
	if pos != nil && len(pos.MsgID) != 0 {
		input, err = q.dispClient.Register(vchannel, pos, mqwrapper.SubscriptionPositionUnknown)
		if err != nil {
			return nil, err
		}
		log.Info("QueryNode seek successfully when register to msgDispatcher",
			zap.ByteString("msgID", pos.GetMsgID()),
			zap.Time("tsTime", tsoutil.PhysicalTime(pos.GetTimestamp())),
			zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(pos.GetTimestamp()))),
			zap.Duration("timeTaken", time.Since(tsBegin)))
	} else {
		input, err = q.dispClient.Register(vchannel, nil, mqwrapper.SubscriptionPositionLatest)
		if err != nil {
			return nil, err
		}
		log.Info("QueryNode consume successfully when register to msgDispatcher",
			zap.Duration("timeTaken", time.Since(tsBegin)))
	}

	q.dmlStream = insertStream

	maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
	maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism.GetAsInt32()
	name := fmt.Sprintf("dmInputNode-query-%d-%s", collectionID, vchannel)
	node := flowgraph.NewInputNode(insertStream, name, maxQueueLength, maxParallelism, typeutil.QueryNodeRole,
	node := flowgraph.NewInputNode(input, name, maxQueueLength, maxParallelism, typeutil.QueryNodeRole,
		paramtable.GetNodeID(), collectionID, dataType)
	return node, nil
}
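The rewritten newDmInputNode no longer builds its own MsgStream. It registers the vchannel with the dispatcher client and only decides how to subscribe: seek from the checkpoint when a usable one exists, otherwise start from the latest message. That decision in isolation, as a runnable sketch with stand-in types instead of the real mqwrapper constants:

package main

import "fmt"

// position is a stand-in for msgstream.MsgPosition; subPos mimics the
// mqwrapper subscription-position constants. Both are illustrative only.
type position struct{ MsgID []byte }

type subPos int

const (
	subUnknown subPos = iota // seek from the provided position
	subLatest                // consume from the latest message
)

// chooseSubscription reproduces the branch newDmInputNode takes before calling
// dispClient.Register: with a usable checkpoint the flow graph seeks from it,
// otherwise it starts from the latest message with a nil position.
func chooseSubscription(pos *position) (*position, subPos) {
	if pos != nil && len(pos.MsgID) != 0 {
		return pos, subUnknown
	}
	return nil, subLatest
}

func main() {
	p, mode := chooseSubscription(&position{MsgID: []byte("checkpoint")})
	fmt.Println(p != nil, mode) // true 0 -> seek from checkpoint
	p, mode = chooseSubscription(nil)
	fmt.Println(p != nil, mode) // false 1 -> consume from latest
}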
|
||||
|
||||
// consumeFlowGraph would consume by channel and subName
|
||||
func (q *queryNodeFlowGraph) consumeFlowGraph(channel Channel, subName ConsumeSubName) error {
|
||||
if q.dmlStream == nil {
|
||||
return errors.New("null dml message stream in flow graph")
|
||||
}
|
||||
q.dmlStream.AsConsumer([]string{channel}, subName, mqwrapper.SubscriptionPositionUnknown)
|
||||
log.Info("query node flow graph consumes from PositionUnknown",
|
||||
zap.Int64("collectionID", q.collectionID),
|
||||
zap.String("pchannel", channel),
|
||||
zap.String("vchannel", q.vchannel),
|
||||
zap.String("subName", subName),
|
||||
)
|
||||
q.consumerCnt++
|
||||
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
// consumeFlowGraphFromLatest would consume from latest by channel and subName
|
||||
func (q *queryNodeFlowGraph) consumeFlowGraphFromLatest(channel Channel, subName ConsumeSubName) error {
|
||||
if q.dmlStream == nil {
|
||||
return errors.New("null dml message stream in flow graph")
|
||||
}
|
||||
q.dmlStream.AsConsumer([]string{channel}, subName, mqwrapper.SubscriptionPositionLatest)
|
||||
log.Info("query node flow graph consumes from latest",
|
||||
zap.Int64("collectionID", q.collectionID),
|
||||
zap.String("pchannel", channel),
|
||||
zap.String("vchannel", q.vchannel),
|
||||
zap.String("subName", subName),
|
||||
)
|
||||
q.consumerCnt++
|
||||
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
// seekQueryNodeFlowGraph would seek by position
|
||||
func (q *queryNodeFlowGraph) consumeFlowGraphFromPosition(position *internalpb.MsgPosition) error {
|
||||
q.dmlStream.AsConsumer([]string{position.ChannelName}, position.MsgGroup, mqwrapper.SubscriptionPositionUnknown)
|
||||
|
||||
start := time.Now()
|
||||
err := q.dmlStream.Seek([]*internalpb.MsgPosition{position})
|
||||
// setup first ts
|
||||
q.tSafeReplica.setTSafe(q.vchannel, position.GetTimestamp())
|
||||
|
||||
ts, _ := tsoutil.ParseTS(position.GetTimestamp())
|
||||
log.Info("query node flow graph seeks from position",
|
||||
zap.Int64("collectionID", q.collectionID),
|
||||
zap.String("pchannel", position.ChannelName),
|
||||
zap.String("vchannel", q.vchannel),
|
||||
zap.Time("checkpointTs", ts),
|
||||
zap.Duration("tsLag", time.Since(ts)),
|
||||
zap.Duration("elapse", time.Since(start)),
|
||||
)
|
||||
q.consumerCnt++
|
||||
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
|
||||
return err
|
||||
}
|
||||
|
||||
// close would close queryNodeFlowGraph
|
||||
func (q *queryNodeFlowGraph) close() {
|
||||
q.dispClient.Deregister(q.vchannel)
|
||||
q.flowGraph.Close()
|
||||
if q.dmlStream != nil && q.consumerCnt > 0 {
|
||||
if q.consumerCnt > 0 {
|
||||
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(float64(q.consumerCnt))
|
||||
}
|
||||
log.Info("stop query node flow graph",
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
)
|
||||
|
||||
func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tSafe := newTSafeReplica()
|
||||
|
||||
streamingReplica, err := genSimpleReplica()
|
||||
assert.NoError(t, err)
|
||||
|
||||
fac := genFactory()
|
||||
|
||||
fg, err := newQueryNodeFlowGraph(ctx,
|
||||
defaultCollectionID,
|
||||
streamingReplica,
|
||||
tSafe,
|
||||
defaultDMLChannel,
|
||||
fac)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = fg.consumeFlowGraph(defaultDMLChannel, defaultSubName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fg.close()
|
||||
}
|
||||
|
||||
func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
streamingReplica, err := genSimpleReplica()
|
||||
assert.NoError(t, err)
|
||||
|
||||
fac := genFactory()
|
||||
|
||||
tSafe := newTSafeReplica()
|
||||
|
||||
fg, err := newQueryNodeFlowGraph(ctx,
|
||||
defaultCollectionID,
|
||||
streamingReplica,
|
||||
tSafe,
|
||||
defaultDMLChannel,
|
||||
fac)
|
||||
assert.NoError(t, err)
|
||||
|
||||
position := &internalpb.MsgPosition{
|
||||
ChannelName: defaultDMLChannel,
|
||||
MsgID: []byte{},
|
||||
MsgGroup: defaultSubName,
|
||||
Timestamp: 0,
|
||||
}
|
||||
err = fg.consumeFlowGraphFromPosition(position)
|
||||
assert.Error(t, err)
|
||||
|
||||
fg.close()
|
||||
}
|
|
@ -26,7 +26,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/log"
|
||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
|
@ -189,31 +188,6 @@ func (l *loadSegmentsTask) watchDeltaChannel(deltaChannels []string) error {
|
|||
}
|
||||
}
|
||||
}()
|
||||
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID())
|
||||
|
||||
// channels as consumer
|
||||
for channel, fg := range channel2FlowGraph {
|
||||
pchannel := VPDeltaChannels[channel]
|
||||
// use pChannel to consume
|
||||
err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName)
|
||||
if err != nil {
|
||||
log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Close()
|
||||
}
|
||||
gcChannels := make([]Channel, 0)
|
||||
for channel := range channel2FlowGraph {
|
||||
gcChannels = append(gcChannels, channel)
|
||||
}
|
||||
l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/common"
|
||||
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||
|
@ -1702,6 +1703,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
|
|||
etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue())
|
||||
node.etcdKV = etcdKV
|
||||
|
||||
node.dispClient = msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||
node.tSafeReplica = newTSafeReplica()
|
||||
|
||||
replica, err := genSimpleReplicaWithSealSegment(ctx)
|
||||
|
@ -1711,7 +1713,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
|
|||
node.tSafeReplica.addTSafe(defaultDMLChannel)
|
||||
|
||||
node.tSafeReplica.addTSafe(defaultDeltaChannel)
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.factory)
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.dispClient, node.factory)
|
||||
|
||||
node.metaReplica = replica
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import "C"
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
|
@ -106,8 +107,9 @@ type QueryNode struct {
|
|||
etcdCli *clientv3.Client
|
||||
address string
|
||||
|
||||
factory dependency.Factory
|
||||
scheduler *taskScheduler
|
||||
dispClient msgdispatcher.Client
|
||||
factory dependency.Factory
|
||||
scheduler *taskScheduler
|
||||
|
||||
sessionMu sync.Mutex
|
||||
session *sessionutil.Session
|
||||
|
@ -256,6 +258,9 @@ func (node *QueryNode) Init() error {
|
|||
}
|
||||
log.Info("QueryNode init rateCollector done", zap.Int64("nodeID", paramtable.GetNodeID()))
|
||||
|
||||
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||
log.Info("QueryNode init dispatcher client done", zap.Int64("nodeID", paramtable.GetNodeID()))
|
||||
|
||||
node.vectorStorage, err = node.factory.NewPersistentStorageChunkManager(node.queryNodeLoopCtx)
|
||||
if err != nil {
|
||||
log.Error("QueryNode init vector storage failed", zap.Error(err))
|
||||
|
@ -283,7 +288,7 @@ func (node *QueryNode) Init() error {
|
|||
node.vectorStorage,
|
||||
node.factory)
|
||||
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.factory)
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.dispClient, node.factory)
|
||||
|
||||
node.InitSegcore()
|
||||
|
||||
|
|
|
@ -27,13 +27,14 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
|
||||
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
var embedetcdServer *embed.Etcd
|
||||
|
@ -98,10 +99,9 @@ func newQueryNodeMock() *QueryNode {
|
|||
factory := newMessageStreamFactory()
|
||||
svr := NewQueryNode(ctx, factory)
|
||||
tsReplica := newTSafeReplica()
|
||||
|
||||
replica := newCollectionReplica()
|
||||
svr.metaReplica = replica
|
||||
svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, factory)
|
||||
svr.dispClient = msgdispatcher.NewClient(factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||
svr.metaReplica = newCollectionReplica()
|
||||
svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, svr.dispClient, factory)
|
||||
svr.vectorStorage, err = factory.NewPersistentStorageChunkManager(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -458,7 +458,7 @@ func TestTask_releasePartitionTask(t *testing.T) {
|
|||
req: genReleasePartitionsRequest(),
|
||||
node: node,
|
||||
}
|
||||
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
|
||||
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
|
||||
assert.NoError(t, err)
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
@ -534,7 +534,7 @@ func TestTask_releasePartitionTask(t *testing.T) {
|
|||
req: genReleasePartitionsRequest(),
|
||||
node: node,
|
||||
}
|
||||
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
|
||||
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
|
||||
assert.NoError(t, err)
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -122,7 +122,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
channel2FlowGraph, err := w.initFlowGraph(ctx, collectionID, vChannels, VPChannels)
|
||||
channel2FlowGraph, err := w.initFlowGraph(collectionID, vChannels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -243,21 +243,16 @@ func (w *watchDmChannelsTask) LoadGrowingSegments(ctx context.Context, collectio
|
|||
return unFlushedSegmentIDs, nil
|
||||
}
|
||||
|
||||
func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID UniqueID, vChannels []Channel, VPChannels map[string]string) (map[string]*queryNodeFlowGraph, error) {
|
||||
func (w *watchDmChannelsTask) initFlowGraph(collectionID UniqueID, vChannels []Channel) (map[string]*queryNodeFlowGraph, error) {
|
||||
// So far, we don't support to enable each node with two different channel
|
||||
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID())
|
||||
|
||||
// group channels by to seeking or consuming
|
||||
// group channels by to seeking
|
||||
channel2SeekPosition := make(map[string]*internalpb.MsgPosition)
|
||||
|
||||
// for channel with no position
|
||||
channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition)
|
||||
for _, info := range w.req.Infos {
|
||||
if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 {
|
||||
channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition
|
||||
continue
|
||||
if info.SeekPosition != nil && len(info.SeekPosition.MsgID) != 0 {
|
||||
info.SeekPosition.MsgGroup = consumeSubName
|
||||
}
|
||||
info.SeekPosition.MsgGroup = consumeSubName
|
||||
channel2SeekPosition[info.ChannelName] = info.SeekPosition
|
||||
}
|
||||
log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID))
|
||||
|
@ -333,49 +328,11 @@ func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID Un
|
|||
)
|
||||
|
||||
// add flow graph
|
||||
channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels)
|
||||
channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, channel2SeekPosition)
|
||||
if err != nil {
|
||||
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels))
|
||||
|
||||
// channels as consumer
|
||||
for channel, fg := range channel2FlowGraph {
|
||||
if _, ok := channel2AsConsumerPosition[channel]; ok {
|
||||
// use pChannel to consume
|
||||
err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName)
|
||||
if err != nil {
|
||||
log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pos, ok := channel2SeekPosition[channel]; ok {
|
||||
pos.MsgGroup = consumeSubName
|
||||
// use pChannel to seek
|
||||
pos.ChannelName = VPChannels[channel]
|
||||
err = fg.consumeFlowGraphFromPosition(pos)
|
||||
if err != nil {
|
||||
log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Close()
|
||||
}
|
||||
gcChannels := make([]Channel, 0)
|
||||
for channel := range channel2FlowGraph {
|
||||
gcChannels = append(gcChannels, channel)
|
||||
}
|
||||
w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
|
||||
return channel2FlowGraph, nil
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ import (
|
|||
// InputNode is the entry point of flowgraph
type InputNode struct {
	BaseNode
	inStream msgstream.MsgStream
	input    <-chan *msgstream.MsgPack
	lastMsg  *msgstream.MsgPack
	name     string
	role     string
|
||||
|
@ -51,17 +51,6 @@ func (inNode *InputNode) IsInputNode() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// Start is used to start input msgstream
|
||||
func (inNode *InputNode) Start() {
|
||||
}
|
||||
|
||||
// Close implements node
|
||||
func (inNode *InputNode) Close() {
|
||||
inNode.closeOnce.Do(func() {
|
||||
inNode.inStream.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func (inNode *InputNode) IsValidInMsg(in []Msg) bool {
|
||||
return true
|
||||
}
|
||||
|
@ -71,16 +60,11 @@ func (inNode *InputNode) Name() string {
|
|||
return inNode.name
|
||||
}
|
||||
|
||||
// InStream returns the internal MsgStream
|
||||
func (inNode *InputNode) InStream() msgstream.MsgStream {
|
||||
return inNode.inStream
|
||||
}
|
||||
|
||||
// Operate consume a message pack from msgstream and return
|
||||
func (inNode *InputNode) Operate(in []Msg) []Msg {
|
||||
msgPack, ok := <-inNode.inStream.Chan()
|
||||
msgPack, ok := <-inNode.input
|
||||
if !ok {
|
||||
log.Warn("MsgStream closed", zap.Any("input node", inNode.Name()))
|
||||
log.Warn("input closed", zap.Any("input node", inNode.Name()))
|
||||
if inNode.lastMsg != nil {
|
||||
log.Info("trigger force sync", zap.Int64("collection", inNode.collectionID), zap.Any("position", inNode.lastMsg))
|
||||
return []Msg{&MsgStreamMsg{
|
||||
|
@ -151,15 +135,15 @@ func (inNode *InputNode) Operate(in []Msg) []Msg {
|
|||
return []Msg{msgStreamMsg}
|
||||
}
|
||||
|
||||
// NewInputNode composes an InputNode with provided MsgStream, name and parameters
func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32, role string, nodeID int64, collectionID int64, dataType string) *InputNode {
// NewInputNode composes an InputNode with provided input channel, name and parameters
func NewInputNode(input <-chan *msgstream.MsgPack, nodeName string, maxQueueLength int32, maxParallelism int32, role string, nodeID int64, collectionID int64, dataType string) *InputNode {
	baseNode := BaseNode{}
	baseNode.SetMaxQueueLength(maxQueueLength)
	baseNode.SetMaxParallelism(maxParallelism)

	return &InputNode{
		BaseNode: baseNode,
		inStream: inStream,
		input:    input,
		name:     nodeName,
		role:     role,
		nodeID:   nodeID,
|
||||
|
|
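With NewInputNode taking a receive-only channel, an input node no longer owns or closes a MsgStream; the dispatcher owns the stream and can feed several flow graphs from it. A standalone sketch of an input node driven purely by a channel, with illustrative types rather than the flowgraph package:

package main

import "fmt"

// msgPack stands in for msgstream.MsgPack.
type msgPack struct{ beginTs, endTs uint64 }

type inputNode struct {
	input <-chan *msgPack
	name  string
}

func newInputNode(input <-chan *msgPack, name string) *inputNode {
	return &inputNode{input: input, name: name}
}

// operate reads exactly one pack, as flowgraph.InputNode.Operate does, and
// reports when the upstream feed has been closed by the dispatcher.
func (n *inputNode) operate() (*msgPack, bool) {
	pack, ok := <-n.input
	return pack, ok
}

func main() {
	feed := make(chan *msgPack, 1)
	node := newInputNode(feed, "dmInputNode-query-demo")
	feed <- &msgPack{beginTs: 1, endTs: 2}
	pack, ok := node.operate()
	fmt.Println(pack.endTs, ok) // 2 true
	close(feed)
	_, ok = node.operate()
	fmt.Println(ok) // false -> input closed, node shuts down
}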
|
@ -40,7 +40,7 @@ func TestInputNode(t *testing.T) {
|
|||
produceStream.Produce(&msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream, nodeName, 100, 100, "", 0, 0, "")
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
defer inputNode.Close()
|
||||
|
||||
isInputNode := inputNode.IsInputNode()
|
||||
|
@ -49,9 +49,6 @@ func TestInputNode(t *testing.T) {
|
|||
name := inputNode.Name()
|
||||
assert.Equal(t, name, nodeName)
|
||||
|
||||
stream := inputNode.InStream()
|
||||
assert.NotNil(t, stream)
|
||||
|
||||
output := inputNode.Operate(nil)
|
||||
assert.NotNil(t, output)
|
||||
msg, ok := output[0].(*MsgStreamMsg)
|
||||
|
|
|
@ -76,6 +76,10 @@ func (bm *MockMsg) SetPosition(position *MsgPosition) {
|
|||
|
||||
}
|
||||
|
||||
func (bm *MockMsg) VChannel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func Test_GenerateMsgStreamMsg(t *testing.T) {
|
||||
messages := make([]msgstream.TsMsg, 1)
|
||||
messages[0] = &MockMsg{
|
||||
|
|
|
@ -74,7 +74,7 @@ func TestNodeCtx_Start(t *testing.T) {
|
|||
produceStream.Produce(&msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream, nodeName, 100, 100, "", 0, 0, "")
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
|
||||
node := &nodeCtx{
|
||||
node: inputNode,
|
||||
|
|