feat: datacoord/node watch based on rpc (#32036)

issue: https://github.com/milvus-io/milvus/issues/25309

Signed-off-by: yiwangdr <yiwangdr@gmail.com>
pull/32819/head
yiwangdr 2024-05-07 00:49:30 -07:00 committed by GitHub
parent efa0c73c62
commit b1eacb2ae8
39 changed files with 4577 additions and 1255 deletions

View File

@ -457,6 +457,7 @@ generate-mockery-datacoord: getdeps
$(INSTALL_PATH)/mockery --name=CompactionMeta --dir=internal/datacoord --filename=mock_compaction_meta.go --output=internal/datacoord --structname=MockCompactionMeta --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=Scheduler --dir=internal/datacoord --filename=mock_scheduler.go --output=internal/datacoord --structname=MockScheduler --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=SubCluster --dir=internal/datacoord --filename=mock_subcluster.go --output=internal/datacoord --structname=MockSubCluster --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage
generate-mockery-datanode: getdeps

View File

@ -19,9 +19,14 @@ package datacoord
import (
"fmt"
"github.com/gogo/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
type ROChannel interface {
@ -39,7 +44,30 @@ type RWChannel interface {
UpdateWatchInfo(info *datapb.ChannelWatchInfo)
}
var _ RWChannel = (*channelMeta)(nil)
func NewRWChannel(name string,
collectionID int64,
startPos []*commonpb.KeyDataPair,
schema *schemapb.CollectionSchema,
createTs uint64,
) RWChannel {
if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() {
return &StateChannel{
Name: name,
CollectionID: collectionID,
StartPositions: startPos,
Schema: schema,
CreateTimestamp: createTs,
}
}
return &channelMeta{
Name: name,
CollectionID: collectionID,
StartPositions: startPos,
Schema: schema,
CreateTimestamp: createTs,
}
}
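// Editor's example (a sketch, not part of this commit): with the
// EnableBalanceChannelWithRPC flag on, NewRWChannel yields the RPC-driven
// StateChannel; otherwise the legacy etcd-watch channelMeta. The channel
// name and collection ID below are illustrative only.
//
//	paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "true")
//	ch := NewRWChannel("by-dev-rootcoord-dml_0", 100, nil, nil, 0)
//	if _, ok := ch.(*StateChannel); ok {
//		// the RPC-based watch path is active
//	}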
type channelMeta struct {
Name string
@ -50,8 +78,13 @@ type channelMeta struct {
WatchInfo *datapb.ChannelWatchInfo
}
var _ RWChannel = (*channelMeta)(nil)
func (ch *channelMeta) UpdateWatchInfo(info *datapb.ChannelWatchInfo) {
ch.WatchInfo = info
log.Info("Channel updating watch info",
zap.Any("old watch info", ch.WatchInfo),
zap.Any("new watch info", info))
ch.WatchInfo = proto.Clone(info).(*datapb.ChannelWatchInfo)
}
func (ch *channelMeta) GetWatchInfo() *datapb.ChannelWatchInfo {
@ -83,3 +116,166 @@ func (ch *channelMeta) String() string {
// schema may be too large to print
return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions)
}
type ChannelState string
const (
Standby ChannelState = "Standby"
ToWatch ChannelState = "ToWatch"
Watching ChannelState = "Watching"
Watched ChannelState = "Watched"
ToRelease ChannelState = "ToRelease"
Releasing ChannelState = "Releasing"
Legacy ChannelState = "Legacy"
)
type StateChannel struct {
Name string
CollectionID UniqueID
StartPositions []*commonpb.KeyDataPair
Schema *schemapb.CollectionSchema
CreateTimestamp uint64
Info *datapb.ChannelWatchInfo
currentState ChannelState
assignedNode int64
}
var _ RWChannel = (*StateChannel)(nil)
func NewStateChannel(ch RWChannel) *StateChannel {
c := &StateChannel{
Name: ch.GetName(),
CollectionID: ch.GetCollectionID(),
StartPositions: ch.GetStartPositions(),
Schema: ch.GetSchema(),
CreateTimestamp: ch.GetCreateTimestamp(),
Info: ch.GetWatchInfo(),
assignedNode: bufferID,
}
c.setState(Standby)
return c
}
func NewStateChannelByWatchInfo(nodeID int64, info *datapb.ChannelWatchInfo) *StateChannel {
c := &StateChannel{
Name: info.GetVchan().GetChannelName(),
CollectionID: info.GetVchan().GetCollectionID(),
Schema: info.GetSchema(),
Info: info,
assignedNode: nodeID,
}
switch info.GetState() {
case datapb.ChannelWatchState_ToWatch:
c.setState(ToWatch)
case datapb.ChannelWatchState_ToRelease:
c.setState(ToRelease)
// legacy state
case datapb.ChannelWatchState_WatchSuccess:
c.setState(Watched)
case datapb.ChannelWatchState_WatchFailure, datapb.ChannelWatchState_ReleaseSuccess, datapb.ChannelWatchState_ReleaseFailure:
c.setState(Standby)
default:
c.setState(Standby)
}
if nodeID == bufferID {
c.setState(Standby)
}
return c
}
func (c *StateChannel) TransitionOnSuccess() {
switch c.currentState {
case Standby:
c.setState(ToWatch)
case ToWatch:
c.setState(Watching)
case Watching:
c.setState(Watched)
case Watched:
c.setState(ToRelease)
case ToRelease:
c.setState(Releasing)
case Releasing:
c.setState(Standby)
}
}
func (c *StateChannel) TransitionOnFailure() {
switch c.currentState {
case Watching:
c.setState(Standby)
case Releasing:
c.setState(Standby)
case Standby, ToWatch, Watched, ToRelease:
// Stay in the original state
}
}
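// Editor's example (a sketch, not part of this commit): the happy path driven
// by TransitionOnSuccess is Standby -> ToWatch -> Watching -> Watched, and a
// release continues Watched -> ToRelease -> Releasing -> Standby;
// TransitionOnFailure only falls back to Standby from the in-flight
// Watching/Releasing states.
//
//	c := NewStateChannel(&channelMeta{Name: "ch1", CollectionID: 1})
//	for _, want := range []ChannelState{ToWatch, Watching, Watched} {
//		c.TransitionOnSuccess()
//		if c.currentState != want {
//			panic("unexpected state")
//		}
//	}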
func (c *StateChannel) Clone() *StateChannel {
return &StateChannel{
Name: c.Name,
CollectionID: c.CollectionID,
StartPositions: c.StartPositions,
Schema: c.Schema,
CreateTimestamp: c.CreateTimestamp,
Info: proto.Clone(c.Info).(*datapb.ChannelWatchInfo),
currentState: c.currentState,
assignedNode: c.assignedNode,
}
}
func (c *StateChannel) String() string {
// schema may be too large to print
return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", c.Name, c.CollectionID, c.StartPositions)
}
func (c *StateChannel) GetName() string {
return c.Name
}
func (c *StateChannel) GetCollectionID() UniqueID {
return c.CollectionID
}
func (c *StateChannel) GetStartPositions() []*commonpb.KeyDataPair {
return c.StartPositions
}
func (c *StateChannel) GetSchema() *schemapb.CollectionSchema {
return c.Schema
}
func (c *StateChannel) GetCreateTimestamp() Timestamp {
return c.CreateTimestamp
}
func (c *StateChannel) GetWatchInfo() *datapb.ChannelWatchInfo {
return c.Info
}
func (c *StateChannel) UpdateWatchInfo(info *datapb.ChannelWatchInfo) {
if c.Info != nil && c.Info.Vchan != nil && info.GetVchan().GetChannelName() != c.Info.GetVchan().GetChannelName() {
log.Warn("Updating incorrect channel watch info",
zap.Any("old watch info", c.Info),
zap.Any("new watch info", info),
zap.Stack("call stack"),
)
return
}
c.Info = proto.Clone(info).(*datapb.ChannelWatchInfo)
}
func (c *StateChannel) Assign(nodeID int64) {
c.assignedNode = nodeID
}
func (c *StateChannel) setState(state ChannelState) {
c.currentState = state
}
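// Editor's example (a sketch, not part of this commit): UpdateWatchInfo
// stores a deep copy via proto.Clone, so later mutation of the caller's
// info cannot leak into the channel's stored state.
//
//	c := &StateChannel{Name: "ch1"}
//	info := &datapb.ChannelWatchInfo{Vchan: &datapb.VchannelInfo{ChannelName: "ch1"}}
//	c.UpdateWatchInfo(info)
//	info.State = datapb.ChannelWatchState_ToRelease // c.Info is a clone and is unaffected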

View File

@ -36,25 +36,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/logutil"
)
type ChannelManager interface {
Startup(ctx context.Context, nodes []int64) error
Close()
AddNode(nodeID int64) error
DeleteNode(nodeID int64) error
Watch(ctx context.Context, ch RWChannel) error
RemoveChannel(channelName string) error
Release(nodeID UniqueID, channelName string) error
Match(nodeID int64, channel string) bool
FindWatcher(channel string) (int64, error)
GetNodeChannelsByCollectionID(collectionID UniqueID) map[UniqueID][]string
GetChannelsByCollectionID(collectionID UniqueID) []RWChannel
GetCollectionIDByChannel(channel string) (bool, UniqueID)
GetNodeIDByChannelName(channel string) (bool, UniqueID)
}
// ChannelManagerImpl manages the allocation and the balance between channels and data nodes.
type ChannelManagerImpl struct {
ctx context.Context
@ -66,8 +47,8 @@ type ChannelManagerImpl struct {
deregisterPolicy DeregisterPolicy
assignPolicy ChannelAssignPolicy
reassignPolicy ChannelReassignPolicy
bgChecker ChannelBGChecker
balancePolicy BalanceChannelPolicy
bgChecker ChannelBGChecker
msgstreamFactory msgstream.Factory
stateChecker channelStateChecker
@ -105,7 +86,7 @@ func NewChannelManager(
c := &ChannelManagerImpl{
ctx: context.TODO(),
h: h,
factory: NewChannelPolicyFactoryV1(kv),
factory: NewChannelPolicyFactoryV1(),
store: NewChannelStore(kv),
stateTimer: newChannelStateTimer(kv),
}
@ -128,7 +109,7 @@ func NewChannelManager(
}
// Startup adjusts the channel store according to current cluster states.
func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error {
func (c *ChannelManagerImpl) Startup(ctx context.Context, legacyNodes, allNodes []int64) error {
c.ctx = ctx
channels := c.store.GetNodesChannels()
// Retrieve the current old nodes.
@ -138,13 +119,13 @@ func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error {
}
// Process watch states for old nodes.
oldOnLines := c.getOldOnlines(nodes, oNodes)
oldOnLines := c.getOldOnlines(allNodes, oNodes)
if err := c.checkOldNodes(oldOnLines); err != nil {
return err
}
// Add new online nodes to the cluster.
newOnLines := c.getNewOnLines(nodes, oNodes)
newOnLines := c.getNewOnLines(allNodes, oNodes)
for _, n := range newOnLines {
if err := c.AddNode(n); err != nil {
return err
@ -152,7 +133,7 @@ func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error {
}
// Remove new offline nodes from the cluster.
offLines := c.getOffLines(nodes, oNodes)
offLines := c.getOffLines(allNodes, oNodes)
for _, n := range offLines {
if err := c.DeleteNode(n); err != nil {
return err
@ -176,7 +157,7 @@ func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error {
}
log.Info("cluster start up",
zap.Int64s("nodes", nodes),
zap.Int64s("nodes", allNodes),
zap.Int64s("oNodes", oNodes),
zap.Int64s("old onlines", oldOnLines),
zap.Int64s("new onlines", newOnLines),
@ -247,7 +228,7 @@ func (c *ChannelManagerImpl) checkOldNodes(nodes []UniqueID) error {
// unwatchDroppedChannels removes channels that are marked to drop.
func (c *ChannelManagerImpl) unwatchDroppedChannels() {
nodeChannels := c.store.GetChannels()
nodeChannels := c.store.GetNodesChannels()
for _, nodeChannel := range nodeChannels {
for _, ch := range nodeChannel.Channels {
if !c.isMarkedDrop(ch.GetName()) {
@ -284,9 +265,14 @@ func (c *ChannelManagerImpl) bgCheckChannelsWork(ctx context.Context) {
if !c.isSilent() {
log.Info("ChannelManager is not silent, skip channel balance this round")
} else {
toReleases := c.balancePolicy(c.store, time.Now())
log.Info("channel manager bg check balance", zap.Array("toReleases", toReleases))
if err := c.updateWithTimer(toReleases, datapb.ChannelWatchState_ToRelease); err != nil {
currCluster := c.store.GetNodesChannels()
updates := c.balancePolicy(currCluster)
if updates == nil {
continue
}
log.Info("channel manager bg check balance", zap.Array("toReleases", updates))
if err := c.updateWithTimer(updates, datapb.ChannelWatchState_ToRelease); err != nil {
log.Warn("channel store update error", zap.Error(err))
}
}
@ -345,7 +331,7 @@ func (c *ChannelManagerImpl) AddNode(nodeID int64) error {
c.mu.Lock()
defer c.mu.Unlock()
c.store.Add(nodeID)
c.store.AddNode(nodeID)
bufferedUpdates, balanceUpdates := c.registerPolicy(c.store, nodeID)
@ -386,6 +372,7 @@ func (c *ChannelManagerImpl) DeleteNode(nodeID int64) error {
nodeChannelInfo := c.store.GetNode(nodeID)
if nodeChannelInfo == nil {
c.store.RemoveNode(nodeID)
return nil
}
@ -393,6 +380,7 @@ func (c *ChannelManagerImpl) DeleteNode(nodeID int64) error {
updates := c.deregisterPolicy(c.store, nodeID)
if updates == nil {
c.store.RemoveNode(nodeID)
return nil
}
log.Info("deregister node", zap.Int64("nodeID", nodeID), zap.Array("updates", updates))
@ -417,8 +405,8 @@ func (c *ChannelManagerImpl) DeleteNode(nodeID int64) error {
}
// No channels will be returned
_, err := c.store.Delete(nodeID)
return err
c.store.RemoveNode(nodeID)
return nil
}
// unsubAttempt attempts to unsubscribe node-channel info from the channel.
@ -558,25 +546,21 @@ func (c *ChannelManagerImpl) Match(nodeID int64, channel string) bool {
}
// FindWatcher finds the datanode watching the provided channel.
func (c *ChannelManagerImpl) FindWatcher(channel string) (int64, error) {
func (c *ChannelManagerImpl) FindWatcher(channelName string) (int64, error) {
c.mu.RLock()
defer c.mu.RUnlock()
infos := c.store.GetNodesChannels()
for _, info := range infos {
for _, channelInfo := range info.Channels {
if channelInfo.GetName() == channel {
return info.NodeID, nil
}
if _, ok := info.Channels[channelName]; ok {
return info.NodeID, nil
}
}
// channel in buffer
bufferInfo := c.store.GetBufferChannelInfo()
for _, channelInfo := range bufferInfo.Channels {
if channelInfo.GetName() == channel {
return bufferID, errChannelInBuffer
}
if _, ok := bufferInfo.Channels[channelName]; ok {
return bufferID, errChannelInBuffer
}
return 0, errChannelNotWatched
}
@ -610,10 +594,8 @@ func (c *ChannelManagerImpl) remove(nodeID int64, ch RWChannel) error {
func (c *ChannelManagerImpl) findChannel(channelName string) (int64, RWChannel) {
infos := c.store.GetNodesChannels()
for _, info := range infos {
for _, channelInfo := range info.Channels {
if channelInfo.GetName() == channelName {
return info.NodeID, channelInfo
}
if channelInfo, ok := info.Channels[channelName]; ok {
return info.NodeID, channelInfo
}
}
return 0, nil
@ -640,7 +622,7 @@ type ackEvent struct {
func (c *ChannelManagerImpl) updateWithTimer(updates *ChannelOpSet, state datapb.ChannelWatchState) error {
channelsWithTimer := []string{}
for _, op := range updates.Collect() {
if op.Type == Add {
if op.Type != Delete {
channelsWithTimer = append(channelsWithTimer, c.fillChannelWatchInfoWithState(op, state)...)
}
}
@ -807,14 +789,9 @@ func (c *ChannelManagerImpl) Reassign(originNodeID UniqueID, channelName string)
reallocates := NewNodeChannelInfo(originNodeID, ch)
isDropped := c.isMarkedDrop(channelName)
c.mu.Lock()
defer c.mu.Unlock()
ch = c.getChannelByNodeAndName(originNodeID, channelName)
if ch == nil {
return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", originNodeID, channelName)
}
if isDropped {
c.mu.Lock()
defer c.mu.Unlock()
if err := c.remove(originNodeID, ch); err != nil {
return fmt.Errorf("failed to remove watch info: %v,%s", ch, err.Error())
}
@ -825,6 +802,8 @@ func (c *ChannelManagerImpl) Reassign(originNodeID UniqueID, channelName string)
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
// Reassign policy won't choose the original node when reassigning a channel.
updates := c.reassignPolicy(c.store, []*NodeChannelInfo{reallocates})
if updates == nil {
@ -864,11 +843,6 @@ func (c *ChannelManagerImpl) CleanupAndReassign(nodeID UniqueID, channelName str
c.mu.Lock()
defer c.mu.Unlock()
chToCleanUp = c.getChannelByNodeAndName(nodeID, channelName)
if chToCleanUp == nil {
return fmt.Errorf("failed to find matching channel: %s and node: %d", channelName, nodeID)
}
if isDropped {
if err := c.remove(nodeID, chToCleanUp); err != nil {
return fmt.Errorf("failed to remove watch info: %v,%s", chToCleanUp, err.Error())
@ -900,42 +874,38 @@ func (c *ChannelManagerImpl) CleanupAndReassign(nodeID UniqueID, channelName str
}
func (c *ChannelManagerImpl) getChannelByNodeAndName(nodeID UniqueID, channelName string) RWChannel {
var ret RWChannel
nodeChannelInfo := c.store.GetNode(nodeID)
if nodeChannelInfo == nil {
return nil
}
for _, channel := range nodeChannelInfo.Channels {
if channel.GetName() == channelName {
ret = channel
break
if nodeChannelInfo := c.store.GetNode(nodeID); nodeChannelInfo != nil {
if ch, ok := nodeChannelInfo.Channels[channelName]; ok {
return ch
}
}
return ret
return nil
}
func (c *ChannelManagerImpl) GetCollectionIDByChannel(channel string) (bool, UniqueID) {
func (c *ChannelManagerImpl) GetCollectionIDByChannel(channelName string) (bool, UniqueID) {
for _, nodeChannel := range c.GetAssignedChannels() {
for _, ch := range nodeChannel.Channels {
if ch.GetName() == channel {
return true, ch.GetCollectionID()
}
if ch, ok := nodeChannel.Channels[channelName]; ok {
return true, ch.GetCollectionID()
}
}
return false, 0
}
func (c *ChannelManagerImpl) GetNodeIDByChannelName(channel string) (bool, UniqueID) {
func (c *ChannelManagerImpl) GetNodeIDByChannelName(channelName string) (UniqueID, bool) {
for _, nodeChannel := range c.GetAssignedChannels() {
for _, ch := range nodeChannel.Channels {
if ch.GetName() == channel {
return true, nodeChannel.NodeID
}
if _, ok := nodeChannel.Channels[channelName]; ok {
return nodeChannel.NodeID, true
}
}
return false, 0
return 0, false
}
func (c *ChannelManagerImpl) GetChannel(nodeID int64, channelName string) (RWChannel, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
ch := c.getChannelByNodeAndName(nodeID, channelName)
return ch, ch != nil
}
func (c *ChannelManagerImpl) isMarkedDrop(channel string) bool {

View File

@ -16,10 +16,6 @@
package datacoord
import (
"github.com/milvus-io/milvus/internal/kv"
)
// ChannelPolicyFactory is the abstract factory that creates policies for channel manager.
type ChannelPolicyFactory interface {
// NewRegisterPolicy creates a new register policy.
@ -35,13 +31,11 @@ type ChannelPolicyFactory interface {
}
// ChannelPolicyFactoryV1 equal to policy batch
type ChannelPolicyFactoryV1 struct {
kv kv.TxnKV
}
type ChannelPolicyFactoryV1 struct{}
// NewChannelPolicyFactoryV1 helper function creates a Channel policy factory v1 from kv.
func NewChannelPolicyFactoryV1(kv kv.TxnKV) *ChannelPolicyFactoryV1 {
return &ChannelPolicyFactoryV1{kv: kv}
func NewChannelPolicyFactoryV1() *ChannelPolicyFactoryV1 {
return &ChannelPolicyFactoryV1{}
}
// NewRegisterPolicy implementing ChannelPolicyFactory returns BufferChannelAssignPolicy.

View File

@ -492,7 +492,7 @@ func TestChannelManager(t *testing.T) {
waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, bufferID, bufferCh, collectionID)
chManager.store.Add(nodeID)
chManager.store.AddNode(nodeID)
err = chManager.Watch(context.TODO(), &channelMeta{Name: chanToAdd, CollectionID: collectionID})
assert.NoError(t, err)
waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, nodeID, chanToAdd, collectionID)
@ -544,7 +544,7 @@ func TestChannelManager(t *testing.T) {
// prepare tests
for _, test := range tests {
chManager.store.Add(test.nodeID)
chManager.store.AddNode(test.nodeID)
ops := getTestOps(test.nodeID, &channelMeta{Name: test.chName, CollectionID: collectionID, WatchInfo: &datapb.ChannelWatchInfo{}})
err = chManager.store.Update(ops)
require.NoError(t, err)
@ -557,7 +557,7 @@ func TestChannelManager(t *testing.T) {
remainTest, reassignTest := tests[0], tests[1]
err = chManager.Reassign(reassignTest.nodeID, reassignTest.chName)
assert.NoError(t, err)
chManager.stateTimer.stopIfExist(&ackEvent{releaseSuccessAck, reassignTest.chName, reassignTest.nodeID})
chManager.stateTimer.stopIfExist(&ackEvent{watchSuccessAck, reassignTest.chName, reassignTest.nodeID})
// test that the node of reassignTest contains no channel
// test that all channels are assigned to the node of remainTest
@ -587,6 +587,7 @@ func TestChannelManager(t *testing.T) {
t.Run("test Reassign with dropped channel", func(t *testing.T) {
collectionID := UniqueID(5)
watchkv.RemoveWithPrefix("")
handler := NewNMockHandler(t)
handler.EXPECT().
CheckShouldDropChannel(mock.Anything).
@ -595,7 +596,7 @@ func TestChannelManager(t *testing.T) {
chManager, err := NewChannelManager(watchkv, handler)
require.NoError(t, err)
chManager.store.Add(1)
chManager.store.AddNode(1)
ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: collectionID, WatchInfo: &datapb.ChannelWatchInfo{}})
err = chManager.store.Update(ops)
require.NoError(t, err)
@ -610,24 +611,16 @@ func TestChannelManager(t *testing.T) {
var chManager *ChannelManagerImpl
var err error
handler := NewNMockHandler(t)
handler.EXPECT().
CheckShouldDropChannel(mock.Anything).
Run(func(channel string) {
channels, err := chManager.store.Delete(1)
assert.NoError(t, err)
assert.Equal(t, 1, len(channels))
}).Return(true).Once()
chManager, err = NewChannelManager(watchkv, handler)
require.NoError(t, err)
chManager.store.Add(1)
chManager.store.AddNode(1)
ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: 1, WatchInfo: &datapb.ChannelWatchInfo{}})
err = chManager.store.Update(ops)
require.NoError(t, err)
assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1))
err = chManager.Reassign(1, "chan")
err = chManager.Reassign(2, "chan")
assert.Error(t, err)
})
@ -635,24 +628,18 @@ func TestChannelManager(t *testing.T) {
var chManager *ChannelManagerImpl
var err error
handler := NewNMockHandler(t)
handler.EXPECT().
CheckShouldDropChannel(mock.Anything).
Run(func(channel string) {
channels, err := chManager.store.Delete(1)
assert.NoError(t, err)
assert.Equal(t, 1, len(channels))
}).Return(true).Once()
watchkv.RemoveWithPrefix("")
chManager, err = NewChannelManager(watchkv, handler)
require.NoError(t, err)
chManager.store.Add(1)
chManager.store.AddNode(1)
ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: 1, WatchInfo: &datapb.ChannelWatchInfo{}})
err = chManager.store.Update(ops)
require.NoError(t, err)
assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1))
err = chManager.CleanupAndReassign(1, "chan")
err = chManager.CleanupAndReassign(2, "chan")
assert.Error(t, err)
})
@ -670,10 +657,11 @@ func TestChannelManager(t *testing.T) {
CheckShouldDropChannel(mock.Anything).
Return(true)
handler.EXPECT().FinishDropChannel(mock.Anything, mock.Anything).Return(nil)
watchkv.RemoveWithPrefix("")
chManager, err := NewChannelManager(watchkv, handler)
require.NoError(t, err)
chManager.store.Add(1)
chManager.store.AddNode(1)
ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: 1, WatchInfo: &datapb.ChannelWatchInfo{}})
err = chManager.store.Update(ops)
require.NoError(t, err)
@ -728,7 +716,7 @@ func TestChannelManager(t *testing.T) {
// prepare tests
for _, test := range tests {
chManager.store.Add(test.nodeID)
chManager.store.AddNode(test.nodeID)
ops := getTestOps(test.nodeID, &channelMeta{Name: test.chName, CollectionID: collectionID, WatchInfo: &datapb.ChannelWatchInfo{}})
err = chManager.store.Update(ops)
require.NoError(t, err)
@ -776,7 +764,7 @@ func TestChannelManager(t *testing.T) {
ch := chManager.getChannelByNodeAndName(nodeID, channelName)
assert.Nil(t, ch)
chManager.store.Add(nodeID)
chManager.store.AddNode(nodeID)
ch = chManager.getChannelByNodeAndName(nodeID, channelName)
assert.Nil(t, ch)
@ -837,7 +825,7 @@ func TestChannelManager(t *testing.T) {
chManager, err := NewChannelManager(watchkv, newMockHandler())
require.NoError(t, err)
chManager.store.Add(nodeID)
chManager.store.AddNode(nodeID)
opSet := NewChannelOpSet(NewAddOp(nodeID, &channelMeta{Name: channelName, CollectionID: collectionID}))
@ -864,7 +852,7 @@ func TestChannelManager(t *testing.T) {
chManager, err := NewChannelManager(watchkv, newMockHandler(), withBgChecker())
require.NoError(t, err)
assert.NotNil(t, chManager.bgChecker)
chManager.Startup(ctx, []int64{nodeID})
chManager.Startup(ctx, nil, []int64{nodeID})
// 2. test isSilent function running correctly
Params.Save(Params.DataCoordCfg.ChannelBalanceSilentDuration.Key, "3")
@ -1049,7 +1037,7 @@ func TestChannelManager_Reload(t *testing.T) {
cm2, err := NewChannelManager(watchkv, newMockHandler())
assert.NoError(t, err)
assert.Nil(t, cm2.Startup(ctx, []int64{3}))
assert.Nil(t, cm2.Startup(ctx, nil, []int64{3}))
waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, 3, "channel1", 1)
waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, 3, "channel2", 1)

View File

@ -0,0 +1,727 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/kv"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type ChannelManager interface {
Startup(ctx context.Context, legacyNodes, allNodes []int64) error
Close()
AddNode(nodeID UniqueID) error
DeleteNode(nodeID UniqueID) error
Watch(ctx context.Context, ch RWChannel) error
Release(nodeID UniqueID, channelName string) error
Match(nodeID UniqueID, channel string) bool
FindWatcher(channel string) (UniqueID, error)
GetChannel(nodeID int64, channel string) (RWChannel, bool)
GetNodeIDByChannelName(channel string) (int64, bool)
GetNodeChannelsByCollectionID(collectionID int64) map[int64][]string
GetChannelsByCollectionID(collectionID int64) []RWChannel
GetChannelNamesByCollectionID(collectionID int64) []string
}
// SubCluster is an interface that sessionManager implements.
type SubCluster interface {
NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error
CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)
}
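// Editor's note: sessionManager is the intended implementer. A minimal
// always-succeeding stand-in (an assumption for tests/sketches, not part of
// this commit) only needs the two RPC hooks:
//
//	type noopSubCluster struct{}
//
//	func (noopSubCluster) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error {
//		return nil // pretend the DataNode accepted the operation
//	}
//
//	func (noopSubCluster) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) {
//		return &datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchSuccess}, nil
//	}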
type ChannelManagerImplV2 struct {
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
h Handler
store RWChannelStore
subCluster SubCluster // sessionManager
allocator allocator
factory ChannelPolicyFactory
balancePolicy BalanceChannelPolicy
balanceCheckLoop ChannelBGChecker
legacyNodes typeutil.UniqueSet
lastActiveTimestamp time.Time
}
// ChannelBGChecker is a background goroutine for checking channels.
type ChannelBGChecker func(ctx context.Context)
// ChannelmanagerOptV2 sets optional parameters in the channel manager.
type ChannelmanagerOptV2 func(c *ChannelManagerImplV2)
func withFactoryV2(f ChannelPolicyFactory) ChannelmanagerOptV2 {
return func(c *ChannelManagerImplV2) { c.factory = f }
}
func withCheckerV2() ChannelmanagerOptV2 {
return func(c *ChannelManagerImplV2) { c.balanceCheckLoop = c.CheckLoop }
}
func NewChannelManagerV2(
kv kv.TxnKV,
h Handler,
subCluster SubCluster, // sessionManager
alloc allocator,
options ...ChannelmanagerOptV2,
) (*ChannelManagerImplV2, error) {
m := &ChannelManagerImplV2{
h: h,
ctx: context.TODO(), // TODO
factory: NewChannelPolicyFactoryV1(),
store: NewChannelStoreV2(kv),
subCluster: subCluster,
allocator: alloc,
}
if err := m.store.Reload(); err != nil {
return nil, err
}
for _, opt := range options {
opt(m)
}
m.balancePolicy = m.factory.NewBalancePolicy()
m.lastActiveTimestamp = time.Now()
return m, nil
}
func (m *ChannelManagerImplV2) Startup(ctx context.Context, legacyNodes, allNodes []int64) error {
m.ctx, m.cancel = context.WithCancel(ctx)
m.legacyNodes = typeutil.NewUniqueSet(legacyNodes...)
m.mu.Lock()
m.store.SetLegacyChannelByNode(legacyNodes...)
oNodes := m.store.GetNodes()
m.mu.Unlock()
// Add new online nodes to the cluster.
offLines, newOnLines := lo.Difference(oNodes, allNodes)
lo.ForEach(newOnLines, func(nodeID int64, _ int) {
m.AddNode(nodeID)
})
// Delete offlines from the cluster
lo.ForEach(offLines, func(nodeID int64, _ int) {
m.DeleteNode(nodeID)
})
m.mu.Lock()
nodeChannels := m.store.GetNodeChannelsBy(
WithAllNodes(),
func(ch *StateChannel) bool {
return m.h.CheckShouldDropChannel(ch.GetName())
})
m.mu.Unlock()
for _, info := range nodeChannels {
m.finishRemoveChannel(info.NodeID, lo.Values(info.Channels)...)
}
if m.balanceCheckLoop != nil {
log.Info("starting channel balance loop")
go m.balanceCheckLoop(m.ctx)
}
log.Info("cluster start up",
zap.Int64s("allNodes", allNodes),
zap.Int64s("legacyNodes", legacyNodes),
zap.Int64s("oldNodes", oNodes),
zap.Int64s("newOnlines", newOnLines),
zap.Int64s("offLines", offLines))
return nil
}
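// Editor's example (a sketch, not part of this commit): the server is
// expected to split its alive DataNodes into legacy nodes (still on the
// etcd-watch path) and the full list; node IDs below are illustrative.
//
//	legacyNodes := []int64{1}    // nodes still managed by the old watch path
//	allNodes := []int64{1, 2, 3} // every currently alive DataNode
//	if err := m.Startup(ctx, legacyNodes, allNodes); err != nil {
//		log.Warn("channel manager startup failed", zap.Error(err))
//	}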
func (m *ChannelManagerImplV2) Close() {
if m.cancel != nil {
m.cancel()
}
}
func (m *ChannelManagerImplV2) AddNode(nodeID UniqueID) error {
m.mu.Lock()
defer m.mu.Unlock()
log.Info("register node", zap.Int64("registered node", nodeID))
m.store.AddNode(nodeID)
updates := AvgAssignByCountPolicy(m.store.GetNodesChannels(), m.store.GetBufferChannelInfo().GetChannels(), m.legacyNodes.Collect())
if updates == nil {
log.Info("register node with no reassignment", zap.Int64("registered node", nodeID))
return nil
}
err := m.execute(updates)
if err != nil {
log.Warn("fail to update channel operation updates into meta", zap.Error(err))
}
return err
}
// Release writes the ToRelease channel watch state for a channel
func (m *ChannelManagerImplV2) Release(nodeID UniqueID, channelName string) error {
log := log.With(
zap.Int64("nodeID", nodeID),
zap.String("channel", channelName),
)
// channels in bufferID are already released
if nodeID == bufferID {
return nil
}
log.Info("Releasing channel from watched node")
ch, found := m.GetChannel(nodeID, channelName)
if !found {
return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
}
m.mu.Lock()
defer m.mu.Unlock()
updates := NewChannelOpSet(NewChannelOp(nodeID, Release, ch))
return m.execute(updates)
}
func (m *ChannelManagerImplV2) Watch(ctx context.Context, ch RWChannel) error {
log := log.Ctx(ctx).With(zap.String("channel", ch.GetName()))
m.mu.Lock()
defer m.mu.Unlock()
log.Info("Add channel")
updates := NewChannelOpSet(NewChannelOp(bufferID, Watch, ch))
err := m.execute(updates)
if err != nil {
log.Warn("fail to update new channel updates into meta",
zap.Array("updates", updates), zap.Error(err))
}
// The channel is already written into meta; try to assign it to the cluster.
// No error is returned on failure; the assignment will be retried later.
updates = AvgAssignByCountPolicy(m.store.GetNodesChannels(), []RWChannel{ch}, m.legacyNodes.Collect())
if updates == nil {
return nil
}
if err := m.execute(updates); err != nil {
log.Warn("fail to assign channel, will retry later", zap.Array("updates", updates), zap.Error(err))
return nil
}
log.Info("Assign channel", zap.Array("updates", updates))
return nil
}
func (m *ChannelManagerImplV2) DeleteNode(nodeID UniqueID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.legacyNodes.Remove(nodeID)
info := m.store.GetNode(nodeID)
if info == nil || len(info.Channels) == 0 {
if nodeID != bufferID {
m.store.RemoveNode(nodeID)
}
return nil
}
updates := NewChannelOpSet(
NewDeleteOp(info.NodeID, lo.Values(info.Channels)...),
NewChannelOp(bufferID, Watch, lo.Values(info.Channels)...),
)
log.Info("deregister node", zap.Int64("nodeID", nodeID), zap.Array("updates", updates))
err := m.execute(updates)
if err != nil {
log.Warn("fail to update channel operation updates into meta", zap.Error(err))
return err
}
if nodeID != bufferID {
m.store.RemoveNode(nodeID)
}
return nil
}
// reassign reassigns a channel to another DataNode.
func (m *ChannelManagerImplV2) reassign(original *NodeChannelInfo) error {
m.mu.Lock()
defer m.mu.Unlock()
updates := AvgAssignByCountPolicy(m.store.GetNodesChannels(), original.GetChannels(), m.legacyNodes.Collect())
if updates != nil {
return m.execute(updates)
}
if original.NodeID != bufferID {
log.RatedWarn(5.0, "Failed to reassign channel to other nodes, assigning to the original node",
zap.Any("original node", original.NodeID),
zap.Strings("channels", lo.Keys(original.Channels)),
)
updates := NewChannelOpSet(NewChannelOp(original.NodeID, Watch, lo.Values(original.Channels)...))
return m.execute(updates)
}
return nil
}
func (m *ChannelManagerImplV2) Balance() {
m.mu.Lock()
defer m.mu.Unlock()
watchedCluster := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(Watched))
updates := m.balancePolicy(watchedCluster)
if updates == nil {
return
}
log.Info("Channel balancer got new reAllocations:", zap.Array("assignment", updates))
if err := m.execute(updates); err != nil {
log.Warn("Channel balancer fail to execute", zap.Array("assignment", updates), zap.Error(err))
}
}
func (m *ChannelManagerImplV2) Match(nodeID UniqueID, channel string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
info := m.store.GetNode(nodeID)
if info == nil {
return false
}
_, ok := info.Channels[channel]
return ok
}
func (m *ChannelManagerImplV2) GetChannel(nodeID int64, channelName string) (RWChannel, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
if nodeChannelInfo := m.store.GetNode(nodeID); nodeChannelInfo != nil {
if ch, ok := nodeChannelInfo.Channels[channelName]; ok {
return ch, true
}
}
return nil, false
}
func (m *ChannelManagerImplV2) GetNodeIDByChannelName(channel string) (int64, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
nodeChannels := m.store.GetNodeChannelsBy(
WithoutBufferNode(),
WithChannelName(channel))
if len(nodeChannels) > 0 {
return nodeChannels[0].NodeID, true
}
return 0, false
}
func (m *ChannelManagerImplV2) GetNodeChannelsByCollectionID(collectionID int64) map[int64][]string {
m.mu.RLock()
defer m.mu.RUnlock()
nodeChs := make(map[UniqueID][]string)
nodeChannels := m.store.GetNodeChannelsBy(
WithoutBufferNode(),
WithCollectionIDV2(collectionID))
lo.ForEach(nodeChannels, func(info *NodeChannelInfo, _ int) {
nodeChs[info.NodeID] = lo.Keys(info.Channels)
})
return nodeChs
}
func (m *ChannelManagerImplV2) GetChannelsByCollectionID(collectionID int64) []RWChannel {
m.mu.RLock()
defer m.mu.RUnlock()
channels := []RWChannel{}
nodeChannels := m.store.GetNodeChannelsBy(
WithAllNodes(),
WithCollectionIDV2(collectionID))
lo.ForEach(nodeChannels, func(info *NodeChannelInfo, _ int) {
channels = append(channels, lo.Values(info.Channels)...)
})
return channels
}
func (m *ChannelManagerImplV2) GetChannelNamesByCollectionID(collectionID int64) []string {
channels := m.GetChannelsByCollectionID(collectionID)
return lo.Map(channels, func(ch RWChannel, _ int) string {
return ch.GetName()
})
}
func (m *ChannelManagerImplV2) FindWatcher(channel string) (UniqueID, error) {
m.mu.RLock()
defer m.mu.RUnlock()
infos := m.store.GetNodesChannels()
for _, info := range infos {
for _, channelInfo := range info.Channels {
if channelInfo.GetName() == channel {
return info.NodeID, nil
}
}
}
// channel in buffer
bufferInfo := m.store.GetBufferChannelInfo()
for _, channelInfo := range bufferInfo.Channels {
if channelInfo.GetName() == channel {
return bufferID, errChannelInBuffer
}
}
return 0, errChannelNotWatched
}
// removeChannel is an unsafe inner func; callers must hold the lock.
func (m *ChannelManagerImplV2) removeChannel(nodeID int64, ch RWChannel) error {
op := NewChannelOpSet(NewChannelOp(nodeID, Delete, ch))
log.Info("remove channel assignment",
zap.String("channel", ch.GetName()),
zap.Int64("assignment", nodeID),
zap.Int64("collectionID", ch.GetCollectionID()))
return m.store.Update(op)
}
func (m *ChannelManagerImplV2) CheckLoop(ctx context.Context) {
balanceTicker := time.NewTicker(Params.DataCoordCfg.ChannelBalanceInterval.GetAsDuration(time.Second))
defer balanceTicker.Stop()
checkTicker := time.NewTicker(Params.DataCoordCfg.ChannelCheckInterval.GetAsDuration(time.Second))
defer checkTicker.Stop()
for {
select {
case <-ctx.Done():
log.Info("background checking channels loop quit")
return
case <-balanceTicker.C:
// balance
if time.Since(m.lastActiveTimestamp) >= Params.DataCoordCfg.ChannelBalanceSilentDuration.GetAsDuration(time.Second) {
m.Balance()
}
case <-checkTicker.C:
m.AdvanceChannelState()
}
}
}
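// Editor's note: both tickers read paramtable durations in seconds, so a
// test can speed the loop up by shrinking them (an assumption based on the
// keys referenced above):
//
//	Params.Save(Params.DataCoordCfg.ChannelCheckInterval.Key, "1")
//	Params.Save(Params.DataCoordCfg.ChannelBalanceInterval.Key, "3")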
func (m *ChannelManagerImplV2) AdvanceChannelState() {
m.mu.RLock()
standbys := m.store.GetNodeChannelsBy(WithAllNodes(), WithChannelStates(Standby))
toNotifies := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(ToWatch, ToRelease))
toChecks := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(Watching, Releasing))
m.mu.RUnlock()
// Processing standby channels
updatedStandbys := m.advanceStandbys(standbys)
updatedToChecks := m.advanceToChecks(toChecks)
updatedToNotifies := m.advanceToNotifies(toNotifies)
if updatedStandbys || updatedToChecks || updatedToNotifies {
m.lastActiveTimestamp = time.Now()
}
}
func (m *ChannelManagerImplV2) finishRemoveChannel(nodeID int64, channels ...RWChannel) {
m.mu.Lock()
defer m.mu.Unlock()
for _, ch := range channels {
if err := m.removeChannel(nodeID, ch); err != nil {
log.Warn("Failed to remove channel", zap.Any("channel", ch), zap.Error(err))
continue
}
if err := m.h.FinishDropChannel(ch.GetName(), ch.GetCollectionID()); err != nil {
log.Warn("Failed to finish drop channel", zap.Any("channel", ch), zap.Error(err))
continue
}
}
}
func (m *ChannelManagerImplV2) advanceStandbys(standbys []*NodeChannelInfo) bool {
advanced := false
for _, nodeAssign := range standbys {
validChannels := make(map[string]RWChannel)
for chName, ch := range nodeAssign.Channels {
// drop marked-drop channels
if m.h.CheckShouldDropChannel(chName) {
m.finishRemoveChannel(nodeAssign.NodeID, ch)
continue
}
validChannels[chName] = ch
}
nodeAssign.Channels = validChannels
if len(nodeAssign.Channels) == 0 {
continue
}
chNames := lo.Keys(validChannels)
if err := m.reassign(nodeAssign); err != nil {
log.Warn("Reassign channels fail",
zap.Int64("nodeID", nodeAssign.NodeID),
zap.Strings("channels", chNames),
zap.Error(err),
)
}
log.Info("Reassign standby channels to node",
zap.Int64("nodeID", nodeAssign.NodeID),
zap.Strings("channels", chNames),
)
advanced = true
}
return advanced
}
func (m *ChannelManagerImplV2) advanceToNotifies(toNotifies []*NodeChannelInfo) bool {
advanced := false
for _, nodeAssign := range toNotifies {
channelCount := len(nodeAssign.Channels)
if channelCount == 0 {
continue
}
var (
succeededChannels = make([]RWChannel, 0, channelCount)
failedChannels = make([]RWChannel, 0, channelCount)
futures = make([]*conc.Future[any], 0, channelCount)
)
chNames := lo.Keys(nodeAssign.Channels)
log.Info("Notify channel operations to datanode",
zap.Int64("assignment", nodeAssign.NodeID),
zap.Int("total operation count", len(nodeAssign.Channels)),
zap.Strings("channel names", chNames),
)
for _, ch := range nodeAssign.Channels {
innerCh := ch
future := getOrCreateIOPool().Submit(func() (any, error) {
err := m.Notify(nodeAssign.NodeID, innerCh.GetWatchInfo())
return innerCh, err
})
futures = append(futures, future)
}
for _, f := range futures {
ch, err := f.Await()
if err != nil {
failedChannels = append(failedChannels, ch.(RWChannel))
} else {
succeededChannels = append(succeededChannels, ch.(RWChannel))
advanced = true
}
}
log.Info("Finish to notify channel operations to datanode",
zap.Int64("assignment", nodeAssign.NodeID),
zap.Int("operation count", channelCount),
zap.Int("success count", len(succeededChannels)),
zap.Int("failure count", len(failedChannels)),
)
m.mu.Lock()
m.store.UpdateState(false, failedChannels...)
m.store.UpdateState(true, succeededChannels...)
m.mu.Unlock()
}
return advanced
}
type poolResult struct {
successful bool
ch RWChannel
}
func (m *ChannelManagerImplV2) advanceToChecks(toChecks []*NodeChannelInfo) bool {
advanced := false
for _, nodeAssign := range toChecks {
if len(nodeAssign.Channels) == 0 {
continue
}
futures := make([]*conc.Future[any], 0, len(nodeAssign.Channels))
chNames := lo.Keys(nodeAssign.Channels)
log.Info("Check ToWatch/ToRelease channel operations progress",
zap.Int("channel count", len(nodeAssign.Channels)),
zap.Strings("channel names", chNames),
)
for _, ch := range nodeAssign.Channels {
innerCh := ch
future := getOrCreateIOPool().Submit(func() (any, error) {
successful, got := m.Check(nodeAssign.NodeID, innerCh.GetWatchInfo())
if got {
return poolResult{
successful: successful,
ch: innerCh,
}, nil
}
return nil, errors.New("got results with no progress")
})
futures = append(futures, future)
}
for _, f := range futures {
got, err := f.Await()
if err == nil {
m.mu.Lock()
result := got.(poolResult)
m.store.UpdateState(result.successful, result.ch)
m.mu.Unlock()
advanced = true
}
}
log.Info("Finish to Check ToWatch/ToRelease channel operations progress",
zap.Int("channel count", len(nodeAssign.Channels)),
zap.Strings("channel names", chNames),
)
}
return advanced
}
func (m *ChannelManagerImplV2) Notify(nodeID int64, info *datapb.ChannelWatchInfo) error {
log := log.With(
zap.String("channel", info.GetVchan().GetChannelName()),
zap.Int64("assignment", nodeID),
zap.String("operation", info.GetState().String()),
)
log.Info("Notify channel operation")
err := m.subCluster.NotifyChannelOperation(m.ctx, nodeID, &datapb.ChannelOperationsRequest{Infos: []*datapb.ChannelWatchInfo{info}})
if err != nil {
log.Warn("Fail to notify channel operations", zap.Error(err))
return err
}
log.Debug("Success to notify channel operations")
return nil
}
func (m *ChannelManagerImplV2) Check(nodeID int64, info *datapb.ChannelWatchInfo) (successful bool, got bool) {
log := log.With(
zap.Int64("opID", info.GetOpID()),
zap.Int64("nodeID", nodeID),
zap.String("check operation", info.GetState().String()),
zap.String("channel", info.GetVchan().GetChannelName()),
)
resp, err := m.subCluster.CheckChannelOperationProgress(m.ctx, nodeID, info)
if err != nil {
log.Warn("Fail to check channel operation progress")
return false, false
}
log.Info("Got channel operation progress",
zap.String("got state", resp.GetState().String()),
zap.Int32("progress", resp.GetProgress()))
switch info.GetState() {
case datapb.ChannelWatchState_ToWatch:
if resp.GetState() == datapb.ChannelWatchState_ToWatch {
return false, false
}
if resp.GetState() == datapb.ChannelWatchState_WatchSuccess {
return true, true
}
if resp.GetState() == datapb.ChannelWatchState_WatchFailure {
return false, true
}
case datapb.ChannelWatchState_ToRelease:
if resp.GetState() == datapb.ChannelWatchState_ToRelease {
return false, false
}
if resp.GetState() == datapb.ChannelWatchState_ReleaseSuccess {
return true, true
}
if resp.GetState() == datapb.ChannelWatchState_ReleaseFailure {
return false, true
}
}
return false, false
}
func (m *ChannelManagerImplV2) execute(updates *ChannelOpSet) error {
for _, op := range updates.ops {
if op.Type != Delete {
if err := m.fillChannelWatchInfo(op); err != nil {
log.Warn("fail to fill channel watch info", zap.Error(err))
return err
}
}
}
return m.store.Update(updates)
}
// fillChannelWatchInfo updates the channel op by filling in channel watch info.
func (m *ChannelManagerImplV2) fillChannelWatchInfo(op *ChannelOp) error {
startTs := time.Now().Unix()
for _, ch := range op.Channels {
vcInfo := m.h.GetDataVChanPositions(ch, allPartitionID)
opID, err := m.allocator.allocID(context.Background())
if err != nil {
return err
}
info := &datapb.ChannelWatchInfo{
Vchan: vcInfo,
StartTs: startTs,
State: inferStateByOpType(op.Type),
Schema: ch.GetSchema(),
OpID: opID,
}
ch.UpdateWatchInfo(info)
}
return nil
}
func inferStateByOpType(opType ChannelOpType) datapb.ChannelWatchState {
switch opType {
case Watch:
return datapb.ChannelWatchState_ToWatch
case Release:
return datapb.ChannelWatchState_ToRelease
default:
return datapb.ChannelWatchState_ToWatch
}
}
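// Editor's example (a sketch, not part of this commit): typical wiring of
// the RPC-based manager inside DataCoord. metaKV, handler, sessionMgr,
// idAllocator, legacyNodes and allNodes are placeholders supplied by the
// server.
//
//	m, err := NewChannelManagerV2(metaKV, handler, sessionMgr, idAllocator, withCheckerV2())
//	if err != nil {
//		return err
//	}
//	if err := m.Startup(ctx, legacyNodes, allNodes); err != nil {
//		return err
//	}
//	// CheckLoop then periodically calls Balance and AdvanceChannelState.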

View File

@ -0,0 +1,661 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"testing"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
kvmock "github.com/milvus-io/milvus/internal/kv/mocks"
"github.com/milvus-io/milvus/internal/kv/predicates"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestChannelManagerSuite(t *testing.T) {
suite.Run(t, new(ChannelManagerSuite))
}
type ChannelManagerSuite struct {
suite.Suite
mockKv *kvmock.MetaKv
mockCluster *MockSubCluster
mockAlloc *NMockAllocator
mockHandler *NMockHandler
}
func (s *ChannelManagerSuite) prepareMeta(chNodes map[string]int64, state datapb.ChannelWatchState) {
s.SetupTest()
if chNodes == nil {
s.mockKv.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, nil).Once()
return
}
var keys, values []string
for channel, nodeID := range chNodes {
keys = append(keys, fmt.Sprintf("channel_store/%d/%s", nodeID, channel))
info := generateWatchInfo(channel, state)
bs, err := proto.Marshal(info)
s.Require().NoError(err)
values = append(values, string(bs))
}
s.mockKv.EXPECT().LoadWithPrefix(mock.Anything).Return(keys, values, nil).Once()
}
func (s *ChannelManagerSuite) checkAssignment(m *ChannelManagerImplV2, nodeID int64, channel string, state ChannelState) {
rwChannel, found := m.GetChannel(nodeID, channel)
s.True(found)
s.NotNil(rwChannel)
s.Equal(channel, rwChannel.GetName())
sChannel, ok := rwChannel.(*StateChannel)
s.True(ok)
s.Equal(state, sChannel.currentState)
s.EqualValues(nodeID, sChannel.assignedNode)
s.True(m.Match(nodeID, channel))
if nodeID != bufferID {
gotNode, err := m.FindWatcher(channel)
s.NoError(err)
s.EqualValues(gotNode, nodeID)
}
}
func (s *ChannelManagerSuite) checkNoAssignment(m *ChannelManagerImplV2, nodeID int64, channel string) {
rwChannel, found := m.GetChannel(nodeID, channel)
s.False(found)
s.Nil(rwChannel)
s.False(m.Match(nodeID, channel))
}
func (s *ChannelManagerSuite) SetupTest() {
s.mockKv = kvmock.NewMetaKv(s.T())
s.mockCluster = NewMockSubCluster(s.T())
s.mockAlloc = NewNMockAllocator(s.T())
s.mockHandler = NewNMockHandler(s.T())
s.mockHandler.EXPECT().GetDataVChanPositions(mock.Anything, mock.Anything).
RunAndReturn(func(ch RWChannel, partitionID UniqueID) *datapb.VchannelInfo {
return &datapb.VchannelInfo{
CollectionID: ch.GetCollectionID(),
ChannelName: ch.GetName(),
}
}).Maybe()
s.mockAlloc.EXPECT().allocID(mock.Anything).Return(19530, nil).Maybe()
s.mockKv.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything).RunAndReturn(
func(save map[string]string, removals []string, preds ...predicates.Predicate) error {
log.Info("test save and remove", zap.Any("save", save), zap.Any("removals", removals))
return nil
}).Maybe()
}
func (s *ChannelManagerSuite) TearDownTest() {}
func (s *ChannelManagerSuite) TestAddNode() {
s.Run("AddNode with empty store", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
var testNode int64 = 1
err = m.AddNode(testNode)
s.NoError(err)
info := m.store.GetNode(testNode)
s.NotNil(info)
s.Empty(info.Channels)
s.Equal(info.NodeID, testNode)
})
s.Run("AddNode with channel in bufferID", func() {
chNodes := map[string]int64{
"ch1": bufferID,
"ch2": bufferID,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
var (
testNodeID int64 = 1
testChannels = []string{"ch1", "ch2"}
)
lo.ForEach(testChannels, func(ch string, _ int) {
s.checkAssignment(m, bufferID, ch, Standby)
})
err = m.AddNode(testNodeID)
s.NoError(err)
lo.ForEach(testChannels, func(ch string, _ int) {
s.checkAssignment(m, testNodeID, ch, ToWatch)
})
})
s.Run("AddNode with channels evenly in other node", func() {
var (
testNodeID int64 = 100
storedNodeID int64 = 1
testChannel = "ch1"
)
chNodes := map[string]int64{testChannel: storedNodeID}
s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, storedNodeID, testChannel, Watched)
err = m.AddNode(testNodeID)
s.NoError(err)
s.ElementsMatch([]int64{100, 1}, m.store.GetNodes())
s.checkNoAssignment(m, testNodeID, testChannel)
testNodeID = 101
paramtable.Get().Save(paramtable.Get().DataCoordCfg.AutoBalance.Key, "true")
defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.AutoBalance.Key)
err = m.AddNode(testNodeID)
s.NoError(err)
s.ElementsMatch([]int64{100, 101, 1}, m.store.GetNodes())
s.checkNoAssignment(m, testNodeID, testChannel)
})
s.Run("AddNode with channels unevenly in other node", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
"ch3": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
var testNodeID int64 = 100
paramtable.Get().Save(paramtable.Get().DataCoordCfg.AutoBalance.Key, "true")
defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.AutoBalance.Key)
err = m.AddNode(testNodeID)
s.NoError(err)
s.ElementsMatch([]int64{testNodeID, 1}, m.store.GetNodes())
})
}
func (s *ChannelManagerSuite) TestWatch() {
s.Run("test Watch with empty store", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
var testCh string = "ch1"
err = m.Watch(context.TODO(), getChannel(testCh, 1))
s.NoError(err)
s.checkAssignment(m, bufferID, testCh, Standby)
})
s.Run("test Watch with nodeID in store", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
var (
testCh string = "ch1"
testNodeID int64 = 1
)
err = m.AddNode(testNodeID)
s.NoError(err)
s.checkNoAssignment(m, testNodeID, testCh)
err = m.Watch(context.TODO(), getChannel(testCh, 1))
s.NoError(err)
s.checkAssignment(m, testNodeID, testCh, ToWatch)
})
}
func (s *ChannelManagerSuite) TestRelease() {
s.Run("release not exist nodeID and channel", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
err = m.Release(1, "ch1")
s.Error(err)
log.Info("error", zap.String("msg", err.Error()))
m.AddNode(1)
err = m.Release(1, "ch1")
s.Error(err)
log.Info("error", zap.String("msg", err.Error()))
})
s.Run("release channel in bufferID", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
m.Watch(context.TODO(), getChannel("ch1", 1))
s.checkAssignment(m, bufferID, "ch1", Standby)
err = m.Release(bufferID, "ch1")
s.NoError(err)
s.checkAssignment(m, bufferID, "ch1", Standby)
})
}
func (s *ChannelManagerSuite) TestDeleteNode() {
s.Run("delete not exsit node", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
info := m.store.GetNode(1)
s.Require().Nil(info)
err = m.DeleteNode(1)
s.NoError(err)
})
s.Run("delete bufferID", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
info := m.store.GetNode(bufferID)
s.Require().NotNil(info)
err = m.DeleteNode(bufferID)
s.NoError(err)
})
s.Run("delete node without assigment", func() {
s.prepareMeta(nil, 0)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
err = m.AddNode(1)
s.NoError(err)
info := m.store.GetNode(bufferID)
s.Require().NotNil(info)
err = m.DeleteNode(1)
s.NoError(err)
info = m.store.GetNode(1)
s.Nil(info)
})
s.Run("delete node with channel", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
"ch3": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", Watched)
s.checkAssignment(m, 1, "ch2", Watched)
s.checkAssignment(m, 1, "ch3", Watched)
err = m.AddNode(2)
s.NoError(err)
err = m.DeleteNode(1)
s.NoError(err)
info := m.store.GetNode(bufferID)
s.NotNil(info)
s.Equal(3, len(info.Channels))
s.EqualValues(bufferID, info.NodeID)
s.checkAssignment(m, bufferID, "ch1", Standby)
s.checkAssignment(m, bufferID, "ch2", Standby)
s.checkAssignment(m, bufferID, "ch3", Standby)
info = m.store.GetNode(1)
s.Nil(info)
})
}
func (s *ChannelManagerSuite) TestFindWatcher() {
chNodes := map[string]int64{
"ch1": bufferID,
"ch2": bufferID,
"ch3": 1,
"ch4": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
tests := []struct {
description string
testCh string
outNodeID int64
outError bool
}{
{"channel not exist", "ch-notexist", 0, true},
{"channel in bufferID", "ch1", bufferID, true},
{"channel in bufferID", "ch2", bufferID, true},
{"channel in nodeID=1", "ch3", 1, false},
{"channel in nodeID=1", "ch4", 1, false},
}
for _, test := range tests {
s.Run(test.description, func() {
gotID, gotErr := m.FindWatcher(test.testCh)
s.EqualValues(test.outNodeID, gotID)
if test.outError {
s.Error(gotErr)
} else {
s.NoError(gotErr)
}
})
}
}
func (s *ChannelManagerSuite) TestAdvanceChannelState() {
s.Run("advance statndby with no available nodes", func() {
chNodes := map[string]int64{
"ch1": bufferID,
"ch2": bufferID,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, bufferID, "ch1", Standby)
s.checkAssignment(m, bufferID, "ch2", Standby)
m.AdvanceChannelState()
s.checkAssignment(m, bufferID, "ch1", Standby)
s.checkAssignment(m, bufferID, "ch2", Standby)
})
s.Run("advance statndby with node 1", func() {
chNodes := map[string]int64{
"ch1": bufferID,
"ch2": bufferID,
"ch3": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess)
s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false).Times(2)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, bufferID, "ch1", Standby)
s.checkAssignment(m, bufferID, "ch2", Standby)
s.checkAssignment(m, 1, "ch3", Watched)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
})
s.Run("advance towatch channels notify success check success", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Watching)
s.checkAssignment(m, 1, "ch2", Watching)
})
s.Run("advance watching channels check no progress", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Watching)
s.checkAssignment(m, 1, "ch2", Watching)
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ToWatch}, nil).Twice()
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Watching)
s.checkAssignment(m, 1, "ch2", Watching)
})
s.Run("advance watching channels check watch success", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Watching)
s.checkAssignment(m, 1, "ch2", Watching)
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchSuccess}, nil).Twice()
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Watched)
s.checkAssignment(m, 1, "ch2", Watched)
})
s.Run("advance watching channels check watch fail", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(2)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Watching)
s.checkAssignment(m, 1, "ch2", Watching)
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchFailure}, nil).Twice()
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Standby)
s.checkAssignment(m, 1, "ch2", Standby)
s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
})
s.Run("advance releasing channels check release no progress", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToRelease)
s.checkAssignment(m, 1, "ch2", ToRelease)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Releasing)
s.checkAssignment(m, 1, "ch2", Releasing)
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ToRelease}, nil).Twice()
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Releasing)
s.checkAssignment(m, 1, "ch2", Releasing)
})
s.Run("advance releasing channels check release success", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToRelease)
s.checkAssignment(m, 1, "ch2", ToRelease)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Releasing)
s.checkAssignment(m, 1, "ch2", Releasing)
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ReleaseSuccess}, nil).Twice()
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Standby)
s.checkAssignment(m, 1, "ch2", Standby)
s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
})
s.Run("advance releasing channels check release fail", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToRelease)
s.checkAssignment(m, 1, "ch2", ToRelease)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Releasing)
s.checkAssignment(m, 1, "ch2", Releasing)
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ReleaseFailure}, nil).Twice()
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Standby)
s.checkAssignment(m, 1, "ch2", Standby)
s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false)
m.AdvanceChannelState()
// TODO: do not assign to abnormal nodes
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
})
s.Run("advance towatch channels notify fail", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).
Return(fmt.Errorf("mock error")).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", ToWatch)
s.checkAssignment(m, 1, "ch2", ToWatch)
})
s.Run("advance to release channels notify success", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToRelease)
s.checkAssignment(m, 1, "ch2", ToRelease)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", Releasing)
s.checkAssignment(m, 1, "ch2", Releasing)
})
s.Run("advance to release channels notify fail", func() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease)
s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).
Return(fmt.Errorf("mock error")).Twice()
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
s.checkAssignment(m, 1, "ch1", ToRelease)
s.checkAssignment(m, 1, "ch2", ToRelease)
m.AdvanceChannelState()
s.checkAssignment(m, 1, "ch1", ToRelease)
s.checkAssignment(m, 1, "ch2", ToRelease)
})
}
func (s *ChannelManagerSuite) TestStartup() {
chNodes := map[string]int64{
"ch1": 1,
"ch2": 1,
"ch3": 3,
}
s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease)
s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false)
m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc)
s.Require().NoError(err)
var (
legacyNodes = []int64{1}
allNodes = []int64{1}
)
err = m.Startup(context.TODO(), legacyNodes, allNodes)
s.NoError(err)
s.checkAssignment(m, 1, "ch1", Legacy)
s.checkAssignment(m, 1, "ch2", Legacy)
s.checkAssignment(m, bufferID, "ch3", Standby)
err = m.DeleteNode(1)
s.NoError(err)
s.checkAssignment(m, bufferID, "ch1", Standby)
s.checkAssignment(m, bufferID, "ch2", Standby)
err = m.AddNode(2)
s.NoError(err)
s.checkAssignment(m, 2, "ch1", ToWatch)
s.checkAssignment(m, 2, "ch2", ToWatch)
s.checkAssignment(m, 2, "ch3", ToWatch)
}
func (s *ChannelManagerSuite) TestCheckLoop() {}
func (s *ChannelManagerSuite) TestGet() {}

View File

@ -33,8 +33,51 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// ROChannelStore is a read only channel store for channels and nodes.
type ROChannelStore interface {
// GetNode returns the channel info of a specific node.
// Returns nil if the node doesn't belong to the cluster.
GetNode(nodeID int64) *NodeChannelInfo
// HasChannel checks if store already has the channel
HasChannel(channel string) bool
// GetNodesChannels returns the channels assigned to real nodes,
// excluding the bufferID node.
GetNodesChannels() []*NodeChannelInfo
// GetBufferChannelInfo gets the unassigned channels.
GetBufferChannelInfo() *NodeChannelInfo
// GetNodes gets all node ids in store.
GetNodes() []int64
// GetNodeChannelCount returns the number of channels assigned to the given node.
GetNodeChannelCount(nodeID int64) int
// GetNodeChannelsBy is used by channel_store_v2 and channel_manager_v2 only
GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo
}
// RWChannelStore is the read write channel store for channels and nodes.
type RWChannelStore interface {
ROChannelStore
// Reload restores the buffer channels and node-channels mapping from kv.
Reload() error
// AddNode creates a new node-channels mapping, with no channels assigned to the node.
AddNode(nodeID int64)
// RemoveNode removes the given node and its channel assignments from the store.
RemoveNode(nodeID int64)
// Update applies the operations in ChannelOpSet.
Update(op *ChannelOpSet) error
// UpdateState is used by StateChannelStore only
UpdateState(isSuccessful bool, channels ...RWChannel)
// SetLegacyChannelByNode is used by StateChannelStore only
SetLegacyChannelByNode(nodeIDs ...int64)
}
// ChannelOpTypeNames maps ChannelOpType values to their names for zap log marshalling.
var ChannelOpTypeNames = []string{"Add", "Delete", "Watch", "Release"}
const (
bufferID = math.MinInt64
delimiter = "/"
@ -49,6 +92,8 @@ type ChannelOpType int8
const (
Add ChannelOpType = iota
Delete
Watch
Release
)
// ChannelOp is an individual Add/Delete/Watch/Release operation on the channel store.
@ -58,6 +103,14 @@ type ChannelOp struct {
Channels []RWChannel
}
func NewChannelOp(ID int64, opType ChannelOpType, channels ...RWChannel) *ChannelOp {
return &ChannelOp{
Type: opType,
NodeID: ID,
Channels: channels,
}
}
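// Illustrative sketch (hypothetical node ID): Watch and Release ops are
// built the same way as Add ops and persist the channel's watch info, e.g.
//
//	op := NewChannelOp(100, Watch, ch)
//	saves, removals, err := op.BuildKV() // marshals ch's watch info under node 100's key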
func NewAddOp(id int64, channels ...RWChannel) *ChannelOp {
return &ChannelOp{
NodeID: id,
@ -92,7 +145,7 @@ func (op *ChannelOp) BuildKV() (map[string]string, []string, error) {
for _, ch := range op.Channels {
k := buildNodeChannelKey(op.NodeID, ch.GetName())
switch op.Type {
case Add:
case Add, Watch, Release:
info, err := proto.Marshal(ch.GetWatchInfo())
if err != nil {
return saves, removals, err
@ -107,6 +160,24 @@ func (op *ChannelOp) BuildKV() (map[string]string, []string, error) {
return saves, removals, nil
}
// TODO: NIT: ObjectMarshaler -> ObjectMarshaller
// MarshalLogObject implements the interface ObjectMarshaler.
func (op *ChannelOp) MarshalLogObject(enc zapcore.ObjectEncoder) error {
enc.AddString("type", ChannelOpTypeNames[op.Type])
enc.AddInt64("nodeID", op.NodeID)
cstr := "["
if len(op.Channels) > 0 {
for _, s := range op.Channels {
cstr += s.GetName()
cstr += ", "
}
cstr = cstr[:len(cstr)-2]
}
cstr += "]"
enc.AddString("channels", cstr)
return nil
}
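// For example, a Watch op on node 1 for channels ch1 and ch2 renders
// roughly as {"type": "Watch", "nodeID": 1, "channels": "[ch1, ch2]"}.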
// ChannelOpSet is a set of channel operations.
type ChannelOpSet struct {
ops []*ChannelOp
@ -139,24 +210,31 @@ func (c *ChannelOpSet) Len() int {
}
// Add appends a new Add channel op, used for ToWatch and ToRelease
func (c *ChannelOpSet) Add(id int64, channels ...RWChannel) {
c.ops = append(c.ops, NewAddOp(id, channels...))
func (c *ChannelOpSet) Add(ID int64, channels ...RWChannel) {
c.Append(ID, Add, channels...)
}
func (c *ChannelOpSet) Delete(id int64, channels ...RWChannel) {
c.ops = append(c.ops, NewDeleteOp(id, channels...))
func (c *ChannelOpSet) Delete(ID int64, channels ...RWChannel) {
c.Append(ID, Delete, channels...)
}
func (c *ChannelOpSet) Append(ID int64, opType ChannelOpType, channels ...RWChannel) {
c.ops = append(c.ops, NewChannelOp(ID, opType, channels...))
}
func (c *ChannelOpSet) GetChannelNumber() int {
if c == nil {
return 0
}
number := 0
uniqChannels := typeutil.NewSet[string]()
for _, op := range c.ops {
number += len(op.Channels)
uniqChannels.Insert(lo.Map(op.Channels, func(ch RWChannel, _ int) string {
return ch.GetName()
})...)
}
return number
return uniqChannels.Len()
}
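// Note that channels are deduplicated by name: a paired op set such as a
// Delete of ch1 from one node plus a Watch of ch1 on another counts as one
// channel, not two.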
func (c *ChannelOpSet) SplitByChannel() map[string]*ChannelOpSet {
@ -168,43 +246,19 @@ func (c *ChannelOpSet) SplitByChannel() map[string]*ChannelOpSet {
perChOps[ch.GetName()] = NewChannelOpSet()
}
if op.Type == Add {
perChOps[ch.GetName()].Add(op.NodeID, ch)
} else {
perChOps[ch.GetName()].Delete(op.NodeID, ch)
}
perChOps[ch.GetName()].Append(op.NodeID, op.Type, ch)
}
}
return perChOps
}
// ROChannelStore is a read only channel store for channels and nodes.
type ROChannelStore interface {
// GetNode returns the channel info of a specific node.
GetNode(nodeID int64) *NodeChannelInfo
// GetChannels returns info of all channels.
GetChannels() []*NodeChannelInfo
// GetNodesChannels returns the channels that are assigned to nodes.
GetNodesChannels() []*NodeChannelInfo
// GetBufferChannelInfo gets the unassigned channels.
GetBufferChannelInfo() *NodeChannelInfo
// GetNodes gets all node ids in store.
GetNodes() []int64
// GetNodeChannelCount
GetNodeChannelCount(nodeID int64) int
}
// RWChannelStore is the read write channel store for channels and nodes.
type RWChannelStore interface {
ROChannelStore
// Reload restores the buffer channels and node-channels mapping form kv.
Reload() error
// Add creates a new node-channels mapping, with no channels assigned to the node.
Add(nodeID int64)
// Delete removes nodeID and returns its channels.
Delete(nodeID int64) ([]RWChannel, error)
// Update applies the operations in ChannelOpSet.
Update(op *ChannelOpSet) error
// TODO: NIT: ArrayMarshaler -> ArrayMarshaller
// MarshalLogArray implements the interface of ArrayMarshaler of zap.
func (c *ChannelOpSet) MarshalLogArray(enc zapcore.ArrayEncoder) error {
for _, o := range c.Collect() {
enc.AppendObject(o)
}
return nil
}
// ChannelStore must satisfy RWChannelStore.
@ -246,6 +300,13 @@ func NewNodeChannelInfo(nodeID int64, channels ...RWChannel) *NodeChannelInfo {
return info
}
func (info *NodeChannelInfo) GetChannels() []RWChannel {
if info == nil {
return nil
}
return lo.Values(info.Channels)
}
// NewChannelStore creates and returns a new ChannelStore.
func NewChannelStore(kv kv.TxnKV) *ChannelStore {
c := &ChannelStore{
@ -280,7 +341,7 @@ func (c *ChannelStore) Reload() error {
}
reviseVChannelInfo(cw.GetVchan())
c.Add(nodeID)
c.AddNode(nodeID)
channel := &channelMeta{
Name: cw.GetVchan().GetChannelName(),
CollectionID: cw.GetVchan().GetCollectionID(),
@ -297,9 +358,9 @@ func (c *ChannelStore) Reload() error {
return nil
}
// Add creates a new node-channels mapping for the given node, and assigns no channels to it.
// AddNode creates a new node-channels mapping for the given node, and assigns no channels to it.
// Returns immediately if the node's already in the channel.
func (c *ChannelStore) Add(nodeID int64) {
func (c *ChannelStore) AddNode(nodeID int64) {
if _, ok := c.channelsInfo[nodeID]; ok {
return
}
@ -356,7 +417,7 @@ func (c *ChannelStore) update(opSet *ChannelOpSet) error {
// Update node id -> channel mapping.
for _, op := range opSet.Collect() {
switch op.Type {
case Add:
case Add, Watch, Release:
for _, ch := range op.Channels {
if c.checkIfExist(op.NodeID, ch) {
continue // prevent adding duplicated channel info
@ -420,16 +481,9 @@ func (c *ChannelStore) GetNodeChannelCount(nodeID int64) int {
return 0
}
// Delete removes the given node from the channel store and returns its channels.
func (c *ChannelStore) Delete(nodeID int64) ([]RWChannel, error) {
if info, ok := c.channelsInfo[nodeID]; ok {
if err := c.remove(nodeID); err != nil {
return nil, err
}
delete(c.channelsInfo, nodeID)
return lo.Values(info.Channels), nil
}
return nil, nil
// RemoveNode removes the given node from the channel store and returns its channels.
func (c *ChannelStore) RemoveNode(nodeID int64) {
delete(c.channelsInfo, nodeID)
}
// GetNodes returns a slice of all nodes ids in the current channel store.
@ -467,7 +521,32 @@ func (c *ChannelStore) txn(opSet *ChannelOpSet) error {
return c.store.MultiSaveAndRemove(saves, removals)
}
func (c *ChannelStore) HasChannel(channel string) bool {
for _, info := range c.channelsInfo {
for _, ch := range info.Channels {
if ch.GetName() == channel {
return true
}
}
}
return false
}
func (c *ChannelStore) GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo {
log.Error("ChannelStore doesn't implement GetNodeChannelsBy")
return nil
}
func (c *ChannelStore) UpdateState(isSuccessful bool, channels ...RWChannel) {
log.Error("ChannelStore doesn't implement UpdateState")
}
func (c *ChannelStore) SetLegacyChannelByNode(nodeIDs ...int64) {
log.Error("ChannelStore doesn't implement SetLegacyChannelByNode")
}
// buildNodeChannelKey generates a key for kv store, where the key is a concatenation of ChannelWatchSubPath, nodeID and channel name.
// ${WatchSubPath}/${nodeID}/${channelName}
func buildNodeChannelKey(nodeID int64, chName string) string {
return fmt.Sprintf("%s%s%d%s%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), delimiter, nodeID, delimiter, chName)
}
@ -485,33 +564,3 @@ func parseNodeKey(key string) (int64, error) {
}
return strconv.ParseInt(s[len(s)-2], 10, 64)
}
// ChannelOpTypeNames implements zap log marshaller for ChannelOpSet.
var ChannelOpTypeNames = []string{"Add", "Delete"}
// TODO: NIT: ObjectMarshaler -> ObjectMarshaller
// MarshalLogObject implements the interface ObjectMarshaler.
func (op *ChannelOp) MarshalLogObject(enc zapcore.ObjectEncoder) error {
enc.AddString("type", ChannelOpTypeNames[op.Type])
enc.AddInt64("nodeID", op.NodeID)
cstr := "["
if len(op.Channels) > 0 {
for _, s := range op.Channels {
cstr += s.GetName()
cstr += ", "
}
cstr = cstr[:len(cstr)-2]
}
cstr += "]"
enc.AddString("channels", cstr)
return nil
}
// TODO: NIT: ArrayMarshaler -> ArrayMarshaller
// MarshalLogArray implements the interface of ArrayMarshaler of zap.
func (c *ChannelOpSet) MarshalLogArray(enc zapcore.ArrayEncoder) error {
for _, o := range c.Collect() {
enc.AppendObject(o)
}
return nil
}

View File

@ -43,7 +43,7 @@ func genNodeChannelInfos(id int64, num int) *NodeChannelInfo {
return NewNodeChannelInfo(id, channels...)
}
func genChannelOperations(from, to int64, num int) *ChannelOpSet {
func genChannelOperationsV1(from, to int64, num int) *ChannelOpSet {
channels := make([]RWChannel, 0, num)
for i := 0; i < num; i++ {
name := fmt.Sprintf("ch%d", i)
@ -86,7 +86,7 @@ func TestChannelStore_Update(t *testing.T) {
},
},
args{
genChannelOperations(1, 2, 250),
genChannelOperationsV1(1, 2, 250),
},
false,
},

View File

@ -0,0 +1,432 @@
package datacoord
import (
"strconv"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/kv"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
type StateChannelStore struct {
store kv.TxnKV
channelsInfo map[int64]*NodeChannelInfo // A map of (nodeID) -> (NodeChannelInfo).
}
var _ RWChannelStore = (*StateChannelStore)(nil)
var errChannelNotExistInNode = errors.New("channel doesn't exist in given node")
func NewChannelStoreV2(kv kv.TxnKV) RWChannelStore {
return NewStateChannelStore(kv)
}
func NewStateChannelStore(kv kv.TxnKV) *StateChannelStore {
c := StateChannelStore{
store: kv,
channelsInfo: make(map[int64]*NodeChannelInfo),
}
c.channelsInfo[bufferID] = &NodeChannelInfo{
NodeID: bufferID,
Channels: make(map[string]RWChannel),
}
return &c
}
func (c *StateChannelStore) Reload() error {
record := timerecord.NewTimeRecorder("datacoord")
keys, values, err := c.store.LoadWithPrefix(Params.CommonCfg.DataCoordWatchSubPath.GetValue())
if err != nil {
return err
}
for i := 0; i < len(keys); i++ {
k := keys[i]
v := values[i]
nodeID, err := parseNodeKey(k)
if err != nil {
return err
}
info := &datapb.ChannelWatchInfo{}
if err := proto.Unmarshal([]byte(v), info); err != nil {
return err
}
reviseVChannelInfo(info.GetVchan())
c.AddNode(nodeID)
channel := NewStateChannelByWatchInfo(nodeID, info)
c.channelsInfo[nodeID].AddChannel(channel)
log.Info("channel store reload channel",
zap.Int64("nodeID", nodeID), zap.String("channel", channel.Name))
metrics.DataCoordDmlChannelNum.WithLabelValues(strconv.FormatInt(nodeID, 10)).Set(float64(len(c.channelsInfo[nodeID].Channels)))
}
log.Info("channel store reload done", zap.Duration("duration", record.ElapseSpan()))
return nil
}
func (c *StateChannelStore) AddNode(nodeID int64) {
if _, ok := c.channelsInfo[nodeID]; ok {
return
}
c.channelsInfo[nodeID] = &NodeChannelInfo{
NodeID: nodeID,
Channels: make(map[string]RWChannel),
}
}
func (c *StateChannelStore) UpdateState(isSuccessful bool, channels ...RWChannel) {
lo.ForEach(channels, func(ch RWChannel, _ int) {
for _, cInfo := range c.channelsInfo {
if stateChannel, ok := cInfo.Channels[ch.GetName()]; ok {
if isSuccessful {
stateChannel.(*StateChannel).TransitionOnSuccess()
} else {
stateChannel.(*StateChannel).TransitionOnFailure()
}
}
}
})
}
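// For example, a channel in Standby transitions to ToWatch on success and
// stays in Standby on failure (see TestUpdateState).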
func (c *StateChannelStore) SetLegacyChannelByNode(nodeIDs ...int64) {
lo.ForEach(nodeIDs, func(nodeID int64, _ int) {
if cInfo, ok := c.channelsInfo[nodeID]; ok {
for _, ch := range cInfo.Channels {
ch.(*StateChannel).setState(Legacy)
}
}
})
}
func (c *StateChannelStore) Update(opSet *ChannelOpSet) error {
// Split the opset into multiple txns. Operations on the same channel must be executed in one txn.
perChOps := opSet.SplitByChannel()
// Execute a txn for at most maxOperationsPerTxn operations.
count := 0
operations := make([]*ChannelOp, 0, maxOperationsPerTxn)
for _, opset := range perChOps {
if !c.sanityCheckPerChannelOpSet(opset) {
log.Error("unsupported ChannelOpSet", zap.Any("OpSet", opset))
continue
}
if opset.Len() > maxOperationsPerTxn {
log.Error("Operations for one channel exceeds maxOperationsPerTxn",
zap.Any("opset size", opset.Len()),
zap.Int("limit", maxOperationsPerTxn))
}
if count+opset.Len() > maxOperationsPerTxn {
if err := c.updateMeta(NewChannelOpSet(operations...)); err != nil {
return err
}
count = 0
operations = make([]*ChannelOp, 0, maxOperationsPerTxn)
}
count += opset.Len()
operations = append(operations, opset.Collect()...)
}
if count == 0 {
return nil
}
return c.updateMeta(NewChannelOpSet(operations...))
}
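// A rough sketch of the batching above: with N = maxOperationsPerTxn,
// 2N+1 single-op channels yield three MultiSaveAndRemove txns
// (N + N + 1), as exercised by TestUpdateWithTxnLimit.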
// removeAssignment removes the channel from the node's assignments.
func (c *StateChannelStore) removeAssignment(nodeID int64, channelName string) {
if cInfo, ok := c.channelsInfo[nodeID]; ok {
delete(cInfo.Channels, channelName)
}
}
func (c *StateChannelStore) addAssignment(nodeID int64, channel RWChannel) {
if cInfo, ok := c.channelsInfo[nodeID]; ok {
cInfo.Channels[channel.GetName()] = channel
} else {
c.channelsInfo[nodeID] = &NodeChannelInfo{
NodeID: nodeID,
Channels: map[string]RWChannel{
channel.GetName(): channel,
},
}
}
}
// updateMeta applies the WATCH/RELEASE/DELETE operations to the current channel store.
// DELETE + WATCH ---> from bufferID to nodeID
// DELETE + WATCH ---> from legacyID to nodeID
// DELETE + WATCH ---> from deletedNode to nodeID/bufferID
// RELEASE ---> release from nodeID
// WATCH ---> watch a new channel
// DELETE ---> remove the channel
func (c *StateChannelStore) sanityCheckPerChannelOpSet(opSet *ChannelOpSet) bool {
if opSet.Len() == 2 {
ops := opSet.Collect()
return (ops[0].Type == Delete && ops[1].Type == Watch) || (ops[1].Type == Delete && ops[0].Type == Watch)
} else if opSet.Len() == 1 {
t := opSet.Collect()[0].Type
return t == Delete || t == Watch || t == Release
}
return false
}
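// Illustrative sketch (hypothetical IDs): the only accepted two-op shape is
// a Delete+Watch pair on the same channel, e.g. moving ch1 out of the buffer:
//
//	opSet := NewChannelOpSet(
//	    NewChannelOp(bufferID, Delete, ch1),
//	    NewChannelOp(100, Watch, ch1),
//	)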
// DELETE + WATCH
func (c *StateChannelStore) updateMetaMemoryForPairOp(chName string, opSet *ChannelOpSet) error {
if !c.sanityCheckPerChannelOpSet(opSet) {
return errUnknownOpType
}
ops := opSet.Collect()
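// op1 holds the Delete op and op2 the Watch op, regardless of input order.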
op1 := ops[1]
op2 := ops[0]
if ops[0].Type == Delete {
op1 = ops[0]
op2 = ops[1]
}
cInfo, ok := c.channelsInfo[op1.NodeID]
if !ok {
return errChannelNotExistInNode
}
var ch *StateChannel
if channel, ok := cInfo.Channels[chName]; ok {
ch = channel.(*StateChannel)
c.addAssignment(op2.NodeID, ch)
c.removeAssignment(op1.NodeID, chName)
} else {
if cInfo, ok = c.channelsInfo[op2.NodeID]; ok {
if channel2, ok := cInfo.Channels[chName]; ok {
ch = channel2.(*StateChannel)
}
}
}
// update channel
if ch != nil {
ch.Assign(op2.NodeID)
if op2.NodeID == bufferID {
ch.setState(Standby)
} else {
ch.setState(ToWatch)
}
}
return nil
}
func (c *StateChannelStore) getChannel(nodeID int64, channelName string) *StateChannel {
if cInfo, ok := c.channelsInfo[nodeID]; ok {
if storedChannel, ok := cInfo.Channels[channelName]; ok {
return storedChannel.(*StateChannel)
}
log.Error("Channel doesn't exist in Node", zap.String("channel", channelName), zap.Int64("nodeID", nodeID))
} else {
log.Error("Node doesn't exist", zap.Int64("NodeID", nodeID))
}
return nil
}
func (c *StateChannelStore) updateMetaMemoryForSingleOp(op *ChannelOp) error {
lo.ForEach(op.Channels, func(ch RWChannel, _ int) {
switch op.Type {
case Release: // release an existing channel-node pair
if channel := c.getChannel(op.NodeID, ch.GetName()); channel != nil {
channel.setState(ToRelease)
}
case Watch:
storedChannel := c.getChannel(op.NodeID, ch.GetName())
if storedChannel == nil { // New Channel
// set the correct assignment and state for the NEW stateChannel
newChannel := NewStateChannel(ch)
newChannel.Assign(op.NodeID)
if op.NodeID != bufferID {
newChannel.setState(ToWatch)
}
// add channel to memory
c.addAssignment(op.NodeID, newChannel)
} else { // assign to the original nodes
storedChannel.setState(ToWatch)
}
case Delete: // Remove Channel
// unless deleting from bufferID, remove the channel from the node
if op.NodeID != bufferID {
c.removeAssignment(op.NodeID, ch.GetName())
}
default:
log.Error("unknown opType in updateMetaMemoryForSingleOp", zap.Any("type", op.Type))
}
})
return nil
}
func (c *StateChannelStore) updateMeta(opSet *ChannelOpSet) error {
// Update ChannelStore's kv store.
if err := c.txn(opSet); err != nil {
return err
}
// Update memory
chOpSet := opSet.SplitByChannel()
for chName, ops := range chOpSet {
// DELETE + WATCH
if ops.Len() == 2 {
c.updateMetaMemoryForPairOp(chName, ops)
// RELEASE, DELETE, WATCH
} else if ops.Len() == 1 {
c.updateMetaMemoryForSingleOp(ops.Collect()[0])
} else {
log.Error("unsupported ChannelOpSet", zap.Any("OpSet", ops))
}
}
return nil
}
// txn updates the channelStore's kv store with the given channel ops.
func (c *StateChannelStore) txn(opSet *ChannelOpSet) error {
var (
saves = make(map[string]string)
removals []string
)
for _, op := range opSet.Collect() {
opSaves, opRemovals, err := op.BuildKV()
if err != nil {
return err
}
saves = lo.Assign(opSaves, saves)
removals = append(removals, opRemovals...)
}
return c.store.MultiSaveAndRemove(saves, removals)
}
func (c *StateChannelStore) RemoveNode(nodeID int64) {
delete(c.channelsInfo, nodeID)
}
func (c *StateChannelStore) HasChannel(channel string) bool {
for _, info := range c.channelsInfo {
if _, ok := info.Channels[channel]; ok {
return true
}
}
return false
}
type (
ChannelSelector func(ch *StateChannel) bool
NodeSelector func(ID int64) bool
)
func WithAllNodes() NodeSelector {
return func(ID int64) bool {
return true
}
}
func WithoutBufferNode() NodeSelector {
return func(ID int64) bool {
return ID != int64(bufferID)
}
}
func WithNodeIDs(IDs ...int64) NodeSelector {
return func(ID int64) bool {
return lo.Contains(IDs, ID)
}
}
func WithoutNodeIDs(IDs ...int64) NodeSelector {
return func(ID int64) bool {
return !lo.Contains(IDs, ID)
}
}
func WithChannelName(channel string) ChannelSelector {
return func(ch *StateChannel) bool {
return ch.GetName() == channel
}
}
func WithCollectionIDV2(collectionID int64) ChannelSelector {
return func(ch *StateChannel) bool {
return ch.GetCollectionID() == collectionID
}
}
func WithChannelStates(states ...ChannelState) ChannelSelector {
return func(ch *StateChannel) bool {
return lo.Contains(states, ch.currentState)
}
}
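// Selectors compose with AND semantics in GetNodeChannelsBy below. An
// illustrative call (hypothetical IDs):
//
//	infos := store.GetNodeChannelsBy(
//	    WithNodeIDs(100, 101),
//	    WithCollectionIDV2(1),
//	    WithChannelStates(ToWatch, Watching),
//	)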
func (c *StateChannelStore) GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo {
nodeChannels := make(map[int64]*NodeChannelInfo)
for nodeID, cInfo := range c.channelsInfo {
if nodeSelector(nodeID) {
selected := make(map[string]RWChannel)
for chName, channel := range cInfo.Channels {
sel := true
for _, selector := range channelSelectors {
if !selector(channel.(*StateChannel)) {
sel = false
break
}
}
if sel {
selected[chName] = channel
}
}
nodeChannels[nodeID] = &NodeChannelInfo{
NodeID: nodeID,
Channels: selected,
}
}
}
return lo.Values(nodeChannels)
}
func (c *StateChannelStore) GetNodesChannels() []*NodeChannelInfo {
ret := make([]*NodeChannelInfo, 0, len(c.channelsInfo))
for id, info := range c.channelsInfo {
if id != bufferID {
ret = append(ret, info)
}
}
return ret
}
func (c *StateChannelStore) GetBufferChannelInfo() *NodeChannelInfo {
return c.GetNode(bufferID)
}
func (c *StateChannelStore) GetNode(nodeID int64) *NodeChannelInfo {
if info, ok := c.channelsInfo[nodeID]; ok {
return info
}
return nil
}
func (c *StateChannelStore) GetNodeChannelCount(nodeID int64) int {
if cInfo, ok := c.channelsInfo[nodeID]; ok {
return len(cInfo.Channels)
}
return 0
}
func (c *StateChannelStore) GetNodes() []int64 {
return lo.Filter(lo.Keys(c.channelsInfo), func(ID int64, _ int) bool {
return ID != bufferID
})
}
// remove deletes kv pairs from the kv store where keys have the given nodeID as prefix.
func (c *StateChannelStore) remove(nodeID int64) error {
k := buildKeyPrefix(nodeID)
return c.store.RemoveWithPrefix(k)
}

View File

@ -0,0 +1,483 @@
package datacoord
import (
"fmt"
"strconv"
"testing"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/kv/mocks"
"github.com/milvus-io/milvus/internal/kv/predicates"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/testutils"
)
func TestStateChannelStore(t *testing.T) {
suite.Run(t, new(StateChannelStoreSuite))
}
type StateChannelStoreSuite struct {
testutils.PromMetricsSuite
mockTxn *mocks.TxnKV
}
func (s *StateChannelStoreSuite) SetupTest() {
s.mockTxn = mocks.NewTxnKV(s.T())
}
func generateWatchInfo(name string, state datapb.ChannelWatchState) *datapb.ChannelWatchInfo {
return &datapb.ChannelWatchInfo{
Vchan: &datapb.VchannelInfo{
ChannelName: name,
},
State: state,
}
}
func (s *StateChannelStoreSuite) createChannelInfo(nodeID int64, channels ...RWChannel) *NodeChannelInfo {
cInfo := &NodeChannelInfo{
NodeID: nodeID,
Channels: make(map[string]RWChannel),
}
for _, channel := range channels {
cInfo.Channels[channel.GetName()] = channel
}
return cInfo
}
func (s *StateChannelStoreSuite) TestGetNodeChannelsBy() {
nodes := []int64{bufferID, 100, 101, 102}
nodesExcludeBufferID := []int64{100, 101, 102}
channels := []*StateChannel{
getChannel("ch1", 1),
getChannel("ch2", 1),
getChannel("ch3", 1),
getChannel("ch4", 1),
getChannel("ch5", 1),
getChannel("ch6", 1),
getChannel("ch7", 1),
}
channelsInfo := map[int64]*NodeChannelInfo{
bufferID: s.createChannelInfo(bufferID, channels[0]),
100: s.createChannelInfo(100, channels[1], channels[2]),
101: s.createChannelInfo(101, channels[3], channels[4]),
102: s.createChannelInfo(102, channels[5], channels[6]), // legacy nodes
}
store := NewStateChannelStore(s.mockTxn)
lo.ForEach(nodes, func(nodeID int64, _ int) { store.AddNode(nodeID) })
store.channelsInfo = channelsInfo
lo.ForEach(channels, func(ch *StateChannel, _ int) {
if ch.GetName() == "ch6" || ch.GetName() == "ch7" {
ch.setState(Legacy)
}
s.Require().True(store.HasChannel(ch.GetName()))
})
s.Require().ElementsMatch(nodesExcludeBufferID, store.GetNodes())
store.SetLegacyChannelByNode(102)
s.Run("test AddNode RemoveNode", func() {
var nodeID int64 = 19530
_, ok := store.channelsInfo[nodeID]
s.Require().False(ok)
store.AddNode(nodeID)
_, ok = store.channelsInfo[nodeID]
s.True(ok)
store.RemoveNode(nodeID)
_, ok = store.channelsInfo[nodeID]
s.False(ok)
})
s.Run("test GetNodeChannels", func() {
infos := store.GetNodesChannels()
expectedResults := map[int64][]string{
100: {"ch2", "ch3"},
101: {"ch4", "ch5"},
102: {"ch6", "ch7"},
}
s.Equal(3, len(infos))
lo.ForEach(infos, func(info *NodeChannelInfo, _ int) {
expectedChannels, ok := expectedResults[info.NodeID]
s.True(ok)
gotChannels := lo.Keys(info.Channels)
s.ElementsMatch(expectedChannels, gotChannels)
})
})
s.Run("test GetBufferChannelInfo", func() {
info := store.GetBufferChannelInfo()
s.NotNil(info)
gotChannels := lo.Keys(info.Channels)
s.ElementsMatch([]string{"ch1"}, gotChannels)
})
s.Run("test GetNode", func() {
info := store.GetNode(19530)
s.Nil(info)
info = store.GetNode(bufferID)
s.NotNil(info)
gotChannels := lo.Keys(info.Channels)
s.ElementsMatch([]string{"ch1"}, gotChannels)
})
tests := []struct {
description string
nodeSelector NodeSelector
channelSelectors []ChannelSelector
expectedResult map[int64][]string
}{
{"test withnodeIDs bufferID", WithNodeIDs(bufferID), nil, map[int64][]string{bufferID: {"ch1"}}},
{"test withnodeIDs 100", WithNodeIDs(100), nil, map[int64][]string{100: {"ch2", "ch3"}}},
{"test withnodeIDs 101 102", WithNodeIDs(101, 102), nil, map[int64][]string{
101: {"ch4", "ch5"},
102: {"ch6", "ch7"},
}},
{"test withAllNodes", WithAllNodes(), nil, map[int64][]string{
bufferID: {"ch1"},
100: {"ch2", "ch3"},
101: {"ch4", "ch5"},
102: {"ch6", "ch7"},
}},
{"test WithoutBufferNode", WithoutBufferNode(), nil, map[int64][]string{
100: {"ch2", "ch3"},
101: {"ch4", "ch5"},
102: {"ch6", "ch7"},
}},
{"test WithoutNodeIDs 100, 101", WithoutNodeIDs(100, 101), nil, map[int64][]string{
bufferID: {"ch1"},
102: {"ch6", "ch7"},
}},
{
"test WithChannelName ch1", WithNodeIDs(bufferID),
[]ChannelSelector{WithChannelName("ch1")},
map[int64][]string{
bufferID: {"ch1"},
},
},
{
"test WithChannelName ch1, collectionID 1", WithNodeIDs(100),
[]ChannelSelector{
WithChannelName("ch2"),
WithCollectionIDV2(1),
},
map[int64][]string{100: {"ch2"}},
},
{
"test WithCollectionID 1", WithAllNodes(),
[]ChannelSelector{
WithCollectionIDV2(1),
},
map[int64][]string{
bufferID: {"ch1"},
100: {"ch2", "ch3"},
101: {"ch4", "ch5"},
102: {"ch6", "ch7"},
},
},
{
"test WithChannelState", WithNodeIDs(102),
[]ChannelSelector{
WithChannelStates(Legacy),
},
map[int64][]string{
102: {"ch6", "ch7"},
},
},
}
for _, test := range tests {
s.Run(test.description, func() {
if test.channelSelectors == nil {
test.channelSelectors = []ChannelSelector{}
}
infos := store.GetNodeChannelsBy(test.nodeSelector, test.channelSelectors...)
log.Info("got test infos", zap.Any("infos", infos))
s.Equal(len(test.expectedResult), len(infos))
lo.ForEach(infos, func(info *NodeChannelInfo, _ int) {
expectedChannels, ok := test.expectedResult[info.NodeID]
s.True(ok)
gotChannels := lo.Keys(info.Channels)
s.ElementsMatch(expectedChannels, gotChannels)
})
})
}
}
func (s *StateChannelStoreSuite) TestUpdateWithTxnLimit() {
tests := []struct {
description string
inOpCount int
outTxnCount int
}{
{"operations count < maxPerTxn", maxOperationsPerTxn - 1, 1},
{"operations count = maxPerTxn", maxOperationsPerTxn, 1},
{"operations count > maxPerTxn", maxOperationsPerTxn + 1, 2},
{"operations count = 2*maxPerTxn", maxOperationsPerTxn * 2, 2},
{"operations count = 2*maxPerTxn+1", maxOperationsPerTxn*2 + 1, 3},
}
for _, test := range tests {
s.SetupTest()
s.Run(test.description, func() {
s.mockTxn.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything).
Run(func(saves map[string]string, removals []string, preds ...predicates.Predicate) {
log.Info("test save and remove", zap.Any("saves", saves), zap.Any("removals", removals))
}).Return(nil).Times(test.outTxnCount)
store := NewStateChannelStore(s.mockTxn)
store.AddNode(1)
s.Require().ElementsMatch([]int64{1}, store.GetNodes())
s.Require().Equal(0, store.GetNodeChannelCount(1))
// Get operations
ops := genChannelOperations(1, Watch, test.inOpCount)
err := store.Update(ops)
s.NoError(err)
})
}
}
func (s *StateChannelStoreSuite) TestUpdateMeta() {
tests := []struct {
description string
opSet *ChannelOpSet
nodeIDs []int64
channels []*StateChannel
assignments map[int64][]string
outAssignments map[int64][]string
}{
{
"delete_watch_ch1 from bufferID to nodeID=100",
NewChannelOpSet(
NewChannelOp(bufferID, Delete, getChannel("ch1", 1)),
NewChannelOp(100, Watch, getChannel("ch1", 1)),
),
[]int64{bufferID, 100},
[]*StateChannel{getChannel("ch1", 1)},
map[int64][]string{
bufferID: {"ch1"},
},
map[int64][]string{
100: {"ch1"},
},
},
{
"delete_watch_ch1 from lagecyID=99 to nodeID=100",
NewChannelOpSet(
NewChannelOp(99, Delete, getChannel("ch1", 1)),
NewChannelOp(100, Watch, getChannel("ch1", 1)),
),
[]int64{bufferID, 99, 100},
[]*StateChannel{getChannel("ch1", 1)},
map[int64][]string{
99: {"ch1"},
},
map[int64][]string{
100: {"ch1"},
},
},
{
"release from nodeID=100",
NewChannelOpSet(
NewChannelOp(100, Release, getChannel("ch1", 1)),
),
[]int64{bufferID, 100},
[]*StateChannel{getChannel("ch1", 1)},
map[int64][]string{
100: {"ch1"},
},
map[int64][]string{
100: {"ch1"},
},
},
{
"watch a new channel from nodeID=100",
NewChannelOpSet(
NewChannelOp(100, Watch, getChannel("ch1", 1)),
),
[]int64{bufferID, 100},
[]*StateChannel{getChannel("ch1", 1)},
map[int64][]string{
100: {"ch1"},
},
map[int64][]string{
100: {"ch1"},
},
},
{
"Delete remove a channelfrom nodeID=100",
NewChannelOpSet(
NewChannelOp(100, Delete, getChannel("ch1", 1)),
),
[]int64{bufferID, 100},
[]*StateChannel{getChannel("ch1", 1)},
map[int64][]string{
100: {"ch1"},
},
map[int64][]string{
100: {},
},
},
}
s.SetupTest()
s.mockTxn.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything).
Run(func(saves map[string]string, removals []string, preds ...predicates.Predicate) {
}).Return(nil).Times(len(tests))
for _, test := range tests {
s.Run(test.description, func() {
store := NewStateChannelStore(s.mockTxn)
lo.ForEach(test.nodeIDs, func(nodeID int64, _ int) {
store.AddNode(nodeID)
s.Require().Equal(0, store.GetNodeChannelCount(nodeID))
})
c := make(map[string]*StateChannel)
lo.ForEach(test.channels, func(ch *StateChannel, _ int) { c[ch.GetName()] = ch })
for nodeID, channels := range test.assignments {
lo.ForEach(channels, func(ch string, _ int) {
store.addAssignment(nodeID, c[ch])
})
s.Require().Equal(1, store.GetNodeChannelCount(nodeID))
}
err := store.updateMeta(test.opSet)
s.NoError(err)
for nodeID, channels := range test.outAssignments {
got := store.GetNodeChannelsBy(WithNodeIDs(nodeID))
s.NotNil(got)
s.Require().Equal(1, len(got))
info := got[0]
s.ElementsMatch(channels, lo.Keys(info.Channels))
}
})
}
}
func (s *StateChannelStoreSuite) TestUpdateState() {
tests := []struct {
description string
inSuccess bool
inChannelState ChannelState
outChannelState ChannelState
}{
{"input standby, fail", false, Standby, Standby},
{"input standby, success", true, Standby, ToWatch},
}
for _, test := range tests {
s.Run(test.description, func() {
store := NewStateChannelStore(s.mockTxn)
ch := "ch-1"
channel := NewStateChannel(getChannel(ch, 1))
channel.setState(test.inChannelState)
store.channelsInfo[1] = &NodeChannelInfo{
NodeID: 1,
Channels: map[string]RWChannel{
ch: channel,
},
}
store.UpdateState(test.inSuccess, channel)
s.Equal(test.outChannelState, channel.currentState)
})
}
}
func (s *StateChannelStoreSuite) TestReload() {
type item struct {
nodeID int64
channelName string
}
type testCase struct {
tag string
items []item
expect map[int64]int
}
cases := []testCase{
{
tag: "empty",
items: []item{},
expect: map[int64]int{},
},
{
tag: "normal",
items: []item{
{nodeID: 1, channelName: "dml1_v0"},
{nodeID: 1, channelName: "dml2_v1"},
{nodeID: 2, channelName: "dml3_v0"},
},
expect: map[int64]int{1: 2, 2: 1},
},
{
tag: "buffer",
items: []item{
{nodeID: bufferID, channelName: "dml1_v0"},
},
expect: map[int64]int{bufferID: 1},
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
s.mockTxn.ExpectedCalls = nil
var keys, values []string
for _, item := range tc.items {
keys = append(keys, fmt.Sprintf("channel_store/%d/%s", item.nodeID, item.channelName))
info := generateWatchInfo(item.channelName, datapb.ChannelWatchState_WatchSuccess)
bs, err := proto.Marshal(info)
s.Require().NoError(err)
values = append(values, string(bs))
}
s.mockTxn.EXPECT().LoadWithPrefix(mock.AnythingOfType("string")).Return(keys, values, nil)
store := NewStateChannelStore(s.mockTxn)
err := store.Reload()
s.Require().NoError(err)
for nodeID, expect := range tc.expect {
s.MetricsEqual(metrics.DataCoordDmlChannelNum.WithLabelValues(strconv.FormatInt(nodeID, 10)), float64(expect))
}
})
}
}
func genChannelOperations(nodeID int64, opType ChannelOpType, num int) *ChannelOpSet {
channels := make([]RWChannel, 0, num)
for i := 0; i < num; i++ {
name := fmt.Sprintf("ch%d", i)
channel := NewStateChannel(getChannel(name, 1))
channel.Info = &datapb.ChannelWatchInfo{}
channels = append(channels, channel)
}
ops := NewChannelOpSet(NewChannelOp(nodeID, opType, channels...))
return ops
}

View File

@ -35,7 +35,7 @@ type Cluster interface {
Startup(ctx context.Context, nodes []*NodeInfo) error
Register(node *NodeInfo) error
UnRegister(node *NodeInfo) error
Watch(ctx context.Context, ch string, collectionID UniqueID) error
Watch(ctx context.Context, ch RWChannel) error
Flush(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error
FlushChannels(ctx context.Context, nodeID int64, flushTs Timestamp, channels []string) error
PreImport(nodeID int64, in *datapb.PreImportRequest) error
@ -69,10 +69,19 @@ func (c *ClusterImpl) Startup(ctx context.Context, nodes []*NodeInfo) error {
for _, node := range nodes {
c.sessionManager.AddSession(node)
}
currs := lo.Map(nodes, func(info *NodeInfo, _ int) int64 {
return info.NodeID
var (
legacyNodes []int64
allNodes []int64
)
lo.ForEach(nodes, func(info *NodeInfo, _ int) {
if info.IsLegacy {
legacyNodes = append(legacyNodes, info.NodeID)
}
allNodes = append(allNodes, info.NodeID)
})
return c.channelManager.Startup(ctx, currs)
return c.channelManager.Startup(ctx, legacyNodes, allNodes)
}
// Register registers a new node in cluster
@ -88,14 +97,15 @@ func (c *ClusterImpl) UnRegister(node *NodeInfo) error {
}
// Watch tries to add a channel in datanode cluster
func (c *ClusterImpl) Watch(ctx context.Context, ch string, collectionID UniqueID) error {
return c.channelManager.Watch(ctx, &channelMeta{Name: ch, CollectionID: collectionID})
func (c *ClusterImpl) Watch(ctx context.Context, ch RWChannel) error {
return c.channelManager.Watch(ctx, ch)
}
// Flush sends async FlushSegments requests to dataNodes
// which also according to channels where segments are assigned to.
func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error {
if !c.channelManager.Match(nodeID, channel) {
ch, found := c.channelManager.GetChannel(nodeID, channel)
if !found {
log.Warn("node is not matched with channel",
zap.String("channel", channel),
zap.Int64("nodeID", nodeID),
@ -103,8 +113,6 @@ func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, s
return fmt.Errorf("channel %s is not watched on node %d", channel, nodeID)
}
_, collID := c.channelManager.GetCollectionIDByChannel(channel)
getSegmentID := func(segment *datapb.SegmentInfo, _ int) int64 {
return segment.GetID()
}
@ -115,7 +123,7 @@ func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, s
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithTargetID(nodeID),
),
CollectionID: collID,
CollectionID: ch.GetCollectionID(),
SegmentIDs: lo.Map(segments, getSegmentID),
ChannelName: channel,
}

View File

@ -67,8 +67,8 @@ func (suite *ClusterSuite) TestStartup() {
{NodeID: 4, Address: "addr4"},
}
suite.mockSession.EXPECT().AddSession(mock.Anything).Return().Times(len(nodes))
suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, nodeIDs []int64) error {
suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, legacys []int64, nodeIDs []int64) error {
suite.ElementsMatch(lo.Map(nodes, func(info *NodeInfo, _ int) int64 { return info.NodeID }), nodeIDs)
return nil
}).Once()
@ -122,17 +122,19 @@ func (suite *ClusterSuite) TestWatch() {
}).Once()
cluster := NewClusterImpl(suite.mockSession, suite.mockChManager)
err := cluster.Watch(context.Background(), ch, collectionID)
err := cluster.Watch(context.Background(), getChannel(ch, collectionID))
suite.NoError(err)
}
func (suite *ClusterSuite) TestFlush() {
suite.mockChManager.EXPECT().Match(mock.Anything, mock.Anything).
RunAndReturn(func(nodeID int64, channel string) bool {
return nodeID != 1
suite.mockChManager.EXPECT().GetChannel(mock.Anything, mock.Anything).
RunAndReturn(func(nodeID int64, channel string) (RWChannel, bool) {
if nodeID == 1 {
return nil, false
}
return getChannel("ch-1", 2), true
}).Twice()
suite.mockChManager.EXPECT().GetCollectionIDByChannel(mock.Anything).Return(true, 100).Once()
suite.mockSession.EXPECT().Flush(mock.Anything, mock.Anything, mock.Anything).Once()
cluster := NewClusterImpl(suite.mockSession, suite.mockChManager)

View File

@ -17,89 +17,35 @@ func (_m *MockRWChannelStore) EXPECT() *MockRWChannelStore_Expecter {
return &MockRWChannelStore_Expecter{mock: &_m.Mock}
}
// Add provides a mock function with given fields: nodeID
func (_m *MockRWChannelStore) Add(nodeID int64) {
// AddNode provides a mock function with given fields: nodeID
func (_m *MockRWChannelStore) AddNode(nodeID int64) {
_m.Called(nodeID)
}
// MockRWChannelStore_Add_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Add'
type MockRWChannelStore_Add_Call struct {
// MockRWChannelStore_AddNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddNode'
type MockRWChannelStore_AddNode_Call struct {
*mock.Call
}
// Add is a helper method to define mock.On call
// AddNode is a helper method to define mock.On call
// - nodeID int64
func (_e *MockRWChannelStore_Expecter) Add(nodeID interface{}) *MockRWChannelStore_Add_Call {
return &MockRWChannelStore_Add_Call{Call: _e.mock.On("Add", nodeID)}
func (_e *MockRWChannelStore_Expecter) AddNode(nodeID interface{}) *MockRWChannelStore_AddNode_Call {
return &MockRWChannelStore_AddNode_Call{Call: _e.mock.On("AddNode", nodeID)}
}
func (_c *MockRWChannelStore_Add_Call) Run(run func(nodeID int64)) *MockRWChannelStore_Add_Call {
func (_c *MockRWChannelStore_AddNode_Call) Run(run func(nodeID int64)) *MockRWChannelStore_AddNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockRWChannelStore_Add_Call) Return() *MockRWChannelStore_Add_Call {
func (_c *MockRWChannelStore_AddNode_Call) Return() *MockRWChannelStore_AddNode_Call {
_c.Call.Return()
return _c
}
func (_c *MockRWChannelStore_Add_Call) RunAndReturn(run func(int64)) *MockRWChannelStore_Add_Call {
_c.Call.Return(run)
return _c
}
// Delete provides a mock function with given fields: nodeID
func (_m *MockRWChannelStore) Delete(nodeID int64) ([]RWChannel, error) {
ret := _m.Called(nodeID)
var r0 []RWChannel
var r1 error
if rf, ok := ret.Get(0).(func(int64) ([]RWChannel, error)); ok {
return rf(nodeID)
}
if rf, ok := ret.Get(0).(func(int64) []RWChannel); ok {
r0 = rf(nodeID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]RWChannel)
}
}
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(nodeID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockRWChannelStore_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete'
type MockRWChannelStore_Delete_Call struct {
*mock.Call
}
// Delete is a helper method to define mock.On call
// - nodeID int64
func (_e *MockRWChannelStore_Expecter) Delete(nodeID interface{}) *MockRWChannelStore_Delete_Call {
return &MockRWChannelStore_Delete_Call{Call: _e.mock.On("Delete", nodeID)}
}
func (_c *MockRWChannelStore_Delete_Call) Run(run func(nodeID int64)) *MockRWChannelStore_Delete_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockRWChannelStore_Delete_Call) Return(_a0 []RWChannel, _a1 error) *MockRWChannelStore_Delete_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockRWChannelStore_Delete_Call) RunAndReturn(run func(int64) ([]RWChannel, error)) *MockRWChannelStore_Delete_Call {
func (_c *MockRWChannelStore_AddNode_Call) RunAndReturn(run func(int64)) *MockRWChannelStore_AddNode_Call {
_c.Call.Return(run)
return _c
}
@ -147,49 +93,6 @@ func (_c *MockRWChannelStore_GetBufferChannelInfo_Call) RunAndReturn(run func()
return _c
}
// GetChannels provides a mock function with given fields:
func (_m *MockRWChannelStore) GetChannels() []*NodeChannelInfo {
ret := _m.Called()
var r0 []*NodeChannelInfo
if rf, ok := ret.Get(0).(func() []*NodeChannelInfo); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*NodeChannelInfo)
}
}
return r0
}
// MockRWChannelStore_GetChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannels'
type MockRWChannelStore_GetChannels_Call struct {
*mock.Call
}
// GetChannels is a helper method to define mock.On call
func (_e *MockRWChannelStore_Expecter) GetChannels() *MockRWChannelStore_GetChannels_Call {
return &MockRWChannelStore_GetChannels_Call{Call: _e.mock.On("GetChannels")}
}
func (_c *MockRWChannelStore_GetChannels_Call) Run(run func()) *MockRWChannelStore_GetChannels_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockRWChannelStore_GetChannels_Call) Return(_a0 []*NodeChannelInfo) *MockRWChannelStore_GetChannels_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockRWChannelStore_GetChannels_Call) RunAndReturn(run func() []*NodeChannelInfo) *MockRWChannelStore_GetChannels_Call {
_c.Call.Return(run)
return _c
}
// GetNode provides a mock function with given fields: nodeID
func (_m *MockRWChannelStore) GetNode(nodeID int64) *NodeChannelInfo {
ret := _m.Called(nodeID)
@ -276,6 +179,65 @@ func (_c *MockRWChannelStore_GetNodeChannelCount_Call) RunAndReturn(run func(int
return _c
}
// GetNodeChannelsBy provides a mock function with given fields: nodeSelector, channelSelectors
func (_m *MockRWChannelStore) GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo {
_va := make([]interface{}, len(channelSelectors))
for _i := range channelSelectors {
_va[_i] = channelSelectors[_i]
}
var _ca []interface{}
_ca = append(_ca, nodeSelector)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 []*NodeChannelInfo
if rf, ok := ret.Get(0).(func(NodeSelector, ...ChannelSelector) []*NodeChannelInfo); ok {
r0 = rf(nodeSelector, channelSelectors...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*NodeChannelInfo)
}
}
return r0
}
// MockRWChannelStore_GetNodeChannelsBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeChannelsBy'
type MockRWChannelStore_GetNodeChannelsBy_Call struct {
*mock.Call
}
// GetNodeChannelsBy is a helper method to define mock.On call
// - nodeSelector NodeSelector
// - channelSelectors ...ChannelSelector
func (_e *MockRWChannelStore_Expecter) GetNodeChannelsBy(nodeSelector interface{}, channelSelectors ...interface{}) *MockRWChannelStore_GetNodeChannelsBy_Call {
return &MockRWChannelStore_GetNodeChannelsBy_Call{Call: _e.mock.On("GetNodeChannelsBy",
append([]interface{}{nodeSelector}, channelSelectors...)...)}
}
func (_c *MockRWChannelStore_GetNodeChannelsBy_Call) Run(run func(nodeSelector NodeSelector, channelSelectors ...ChannelSelector)) *MockRWChannelStore_GetNodeChannelsBy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]ChannelSelector, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(ChannelSelector)
}
}
run(args[0].(NodeSelector), variadicArgs...)
})
return _c
}
func (_c *MockRWChannelStore_GetNodeChannelsBy_Call) Return(_a0 []*NodeChannelInfo) *MockRWChannelStore_GetNodeChannelsBy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockRWChannelStore_GetNodeChannelsBy_Call) RunAndReturn(run func(NodeSelector, ...ChannelSelector) []*NodeChannelInfo) *MockRWChannelStore_GetNodeChannelsBy_Call {
_c.Call.Return(run)
return _c
}
// GetNodes provides a mock function with given fields:
func (_m *MockRWChannelStore) GetNodes() []int64 {
ret := _m.Called()
@ -362,6 +324,48 @@ func (_c *MockRWChannelStore_GetNodesChannels_Call) RunAndReturn(run func() []*N
return _c
}
// HasChannel provides a mock function with given fields: channel
func (_m *MockRWChannelStore) HasChannel(channel string) bool {
ret := _m.Called(channel)
var r0 bool
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(channel)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockRWChannelStore_HasChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasChannel'
type MockRWChannelStore_HasChannel_Call struct {
*mock.Call
}
// HasChannel is a helper method to define mock.On call
// - channel string
func (_e *MockRWChannelStore_Expecter) HasChannel(channel interface{}) *MockRWChannelStore_HasChannel_Call {
return &MockRWChannelStore_HasChannel_Call{Call: _e.mock.On("HasChannel", channel)}
}
func (_c *MockRWChannelStore_HasChannel_Call) Run(run func(channel string)) *MockRWChannelStore_HasChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockRWChannelStore_HasChannel_Call) Return(_a0 bool) *MockRWChannelStore_HasChannel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockRWChannelStore_HasChannel_Call) RunAndReturn(run func(string) bool) *MockRWChannelStore_HasChannel_Call {
_c.Call.Return(run)
return _c
}
// Reload provides a mock function with given fields:
func (_m *MockRWChannelStore) Reload() error {
ret := _m.Called()
@ -403,6 +407,85 @@ func (_c *MockRWChannelStore_Reload_Call) RunAndReturn(run func() error) *MockRW
return _c
}
// RemoveNode provides a mock function with given fields: nodeID
func (_m *MockRWChannelStore) RemoveNode(nodeID int64) {
_m.Called(nodeID)
}
// MockRWChannelStore_RemoveNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveNode'
type MockRWChannelStore_RemoveNode_Call struct {
*mock.Call
}
// RemoveNode is a helper method to define mock.On call
// - nodeID int64
func (_e *MockRWChannelStore_Expecter) RemoveNode(nodeID interface{}) *MockRWChannelStore_RemoveNode_Call {
return &MockRWChannelStore_RemoveNode_Call{Call: _e.mock.On("RemoveNode", nodeID)}
}
func (_c *MockRWChannelStore_RemoveNode_Call) Run(run func(nodeID int64)) *MockRWChannelStore_RemoveNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockRWChannelStore_RemoveNode_Call) Return() *MockRWChannelStore_RemoveNode_Call {
_c.Call.Return()
return _c
}
func (_c *MockRWChannelStore_RemoveNode_Call) RunAndReturn(run func(int64)) *MockRWChannelStore_RemoveNode_Call {
_c.Call.Return(run)
return _c
}
// SetLegacyChannelByNode provides a mock function with given fields: nodeIDs
func (_m *MockRWChannelStore) SetLegacyChannelByNode(nodeIDs ...int64) {
_va := make([]interface{}, len(nodeIDs))
for _i := range nodeIDs {
_va[_i] = nodeIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, _va...)
_m.Called(_ca...)
}
// MockRWChannelStore_SetLegacyChannelByNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetLegacyChannelByNode'
type MockRWChannelStore_SetLegacyChannelByNode_Call struct {
*mock.Call
}
// SetLegacyChannelByNode is a helper method to define mock.On call
// - nodeIDs ...int64
func (_e *MockRWChannelStore_Expecter) SetLegacyChannelByNode(nodeIDs ...interface{}) *MockRWChannelStore_SetLegacyChannelByNode_Call {
return &MockRWChannelStore_SetLegacyChannelByNode_Call{Call: _e.mock.On("SetLegacyChannelByNode",
append([]interface{}{}, nodeIDs...)...)}
}
func (_c *MockRWChannelStore_SetLegacyChannelByNode_Call) Run(run func(nodeIDs ...int64)) *MockRWChannelStore_SetLegacyChannelByNode_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]int64, len(args)-0)
for i, a := range args[0:] {
if a != nil {
variadicArgs[i] = a.(int64)
}
}
run(variadicArgs...)
})
return _c
}
func (_c *MockRWChannelStore_SetLegacyChannelByNode_Call) Return() *MockRWChannelStore_SetLegacyChannelByNode_Call {
_c.Call.Return()
return _c
}
func (_c *MockRWChannelStore_SetLegacyChannelByNode_Call) RunAndReturn(run func(...int64)) *MockRWChannelStore_SetLegacyChannelByNode_Call {
_c.Call.Return(run)
return _c
}
// Update provides a mock function with given fields: op
func (_m *MockRWChannelStore) Update(op *ChannelOpSet) error {
ret := _m.Called(op)
@ -445,6 +528,54 @@ func (_c *MockRWChannelStore_Update_Call) RunAndReturn(run func(*ChannelOpSet) e
return _c
}
// UpdateState provides a mock function with given fields: isSuccessful, channels
func (_m *MockRWChannelStore) UpdateState(isSuccessful bool, channels ...RWChannel) {
_va := make([]interface{}, len(channels))
for _i := range channels {
_va[_i] = channels[_i]
}
var _ca []interface{}
_ca = append(_ca, isSuccessful)
_ca = append(_ca, _va...)
_m.Called(_ca...)
}
// MockRWChannelStore_UpdateState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateState'
type MockRWChannelStore_UpdateState_Call struct {
*mock.Call
}
// UpdateState is a helper method to define mock.On call
// - isSuccessful bool
// - channels ...RWChannel
func (_e *MockRWChannelStore_Expecter) UpdateState(isSuccessful interface{}, channels ...interface{}) *MockRWChannelStore_UpdateState_Call {
return &MockRWChannelStore_UpdateState_Call{Call: _e.mock.On("UpdateState",
append([]interface{}{isSuccessful}, channels...)...)}
}
func (_c *MockRWChannelStore_UpdateState_Call) Run(run func(isSuccessful bool, channels ...RWChannel)) *MockRWChannelStore_UpdateState_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]RWChannel, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(RWChannel)
}
}
run(args[0].(bool), variadicArgs...)
})
return _c
}
func (_c *MockRWChannelStore_UpdateState_Call) Return() *MockRWChannelStore_UpdateState_Call {
_c.Call.Return()
return _c
}
func (_c *MockRWChannelStore_UpdateState_Call) RunAndReturn(run func(bool, ...RWChannel)) *MockRWChannelStore_UpdateState_Call {
_c.Call.Return(run)
return _c
}
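
For orientation, a minimal sketch (hypothetical test, not part of the patch) of how the variadic expecter above is typically wired up; channelMeta is the concrete RWChannel this package already uses:

package datacoord

import (
	"testing"

	"github.com/stretchr/testify/mock"
)

func TestUpdateStateExpectation(t *testing.T) {
	store := NewMockRWChannelStore(t)
	// One expected call: any success flag, exactly one channel argument.
	store.EXPECT().UpdateState(mock.Anything, mock.Anything).Return()
	store.UpdateState(true, &channelMeta{Name: "ch-1"})
}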
// NewMockRWChannelStore creates a new instance of MockRWChannelStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockRWChannelStore(t interface {

View File

@ -189,6 +189,105 @@ func (_c *MockChannelManager_FindWatcher_Call) RunAndReturn(run func(string) (in
return _c
}
// GetChannel provides a mock function with given fields: nodeID, channel
func (_m *MockChannelManager) GetChannel(nodeID int64, channel string) (RWChannel, bool) {
ret := _m.Called(nodeID, channel)
var r0 RWChannel
var r1 bool
if rf, ok := ret.Get(0).(func(int64, string) (RWChannel, bool)); ok {
return rf(nodeID, channel)
}
if rf, ok := ret.Get(0).(func(int64, string) RWChannel); ok {
r0 = rf(nodeID, channel)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(RWChannel)
}
}
if rf, ok := ret.Get(1).(func(int64, string) bool); ok {
r1 = rf(nodeID, channel)
} else {
r1 = ret.Get(1).(bool)
}
return r0, r1
}
// MockChannelManager_GetChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannel'
type MockChannelManager_GetChannel_Call struct {
*mock.Call
}
// GetChannel is a helper method to define mock.On call
// - nodeID int64
// - channel string
func (_e *MockChannelManager_Expecter) GetChannel(nodeID interface{}, channel interface{}) *MockChannelManager_GetChannel_Call {
return &MockChannelManager_GetChannel_Call{Call: _e.mock.On("GetChannel", nodeID, channel)}
}
func (_c *MockChannelManager_GetChannel_Call) Run(run func(nodeID int64, channel string)) *MockChannelManager_GetChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(string))
})
return _c
}
func (_c *MockChannelManager_GetChannel_Call) Return(_a0 RWChannel, _a1 bool) *MockChannelManager_GetChannel_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockChannelManager_GetChannel_Call) RunAndReturn(run func(int64, string) (RWChannel, bool)) *MockChannelManager_GetChannel_Call {
_c.Call.Return(run)
return _c
}
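
A similar hedged sketch for the new two-value getter, simulating a miss so the ok flag drives the assertion:

package datacoord

import (
	"testing"

	"github.com/stretchr/testify/mock"
)

func TestGetChannelExpectation(t *testing.T) {
	mgr := NewMockChannelManager(t)
	// Simulate an unknown channel: nil RWChannel, ok == false.
	mgr.EXPECT().GetChannel(mock.Anything, mock.Anything).Return(nil, false)
	if _, ok := mgr.GetChannel(1, "by-dev-rootcoord-dml-0"); ok {
		t.Fatal("expected the channel to be absent")
	}
}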
// GetChannelNamesByCollectionID provides a mock function with given fields: collectionID
func (_m *MockChannelManager) GetChannelNamesByCollectionID(collectionID int64) []string {
ret := _m.Called(collectionID)
var r0 []string
if rf, ok := ret.Get(0).(func(int64) []string); ok {
r0 = rf(collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockChannelManager_GetChannelNamesByCollectionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannelNamesByCollectionID'
type MockChannelManager_GetChannelNamesByCollectionID_Call struct {
*mock.Call
}
// GetChannelNamesByCollectionID is a helper method to define mock.On call
// - collectionID int64
func (_e *MockChannelManager_Expecter) GetChannelNamesByCollectionID(collectionID interface{}) *MockChannelManager_GetChannelNamesByCollectionID_Call {
return &MockChannelManager_GetChannelNamesByCollectionID_Call{Call: _e.mock.On("GetChannelNamesByCollectionID", collectionID)}
}
func (_c *MockChannelManager_GetChannelNamesByCollectionID_Call) Run(run func(collectionID int64)) *MockChannelManager_GetChannelNamesByCollectionID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockChannelManager_GetChannelNamesByCollectionID_Call) Return(_a0 []string) *MockChannelManager_GetChannelNamesByCollectionID_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockChannelManager_GetChannelNamesByCollectionID_Call) RunAndReturn(run func(int64) []string) *MockChannelManager_GetChannelNamesByCollectionID_Call {
_c.Call.Return(run)
return _c
}
// GetChannelsByCollectionID provides a mock function with given fields: collectionID
func (_m *MockChannelManager) GetChannelsByCollectionID(collectionID int64) []RWChannel {
ret := _m.Called(collectionID)
@ -233,58 +332,6 @@ func (_c *MockChannelManager_GetChannelsByCollectionID_Call) RunAndReturn(run fu
return _c
}
// GetCollectionIDByChannel provides a mock function with given fields: channel
func (_m *MockChannelManager) GetCollectionIDByChannel(channel string) (bool, int64) {
ret := _m.Called(channel)
var r0 bool
var r1 int64
if rf, ok := ret.Get(0).(func(string) (bool, int64)); ok {
return rf(channel)
}
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(channel)
} else {
r0 = ret.Get(0).(bool)
}
if rf, ok := ret.Get(1).(func(string) int64); ok {
r1 = rf(channel)
} else {
r1 = ret.Get(1).(int64)
}
return r0, r1
}
// MockChannelManager_GetCollectionIDByChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionIDByChannel'
type MockChannelManager_GetCollectionIDByChannel_Call struct {
*mock.Call
}
// GetCollectionIDByChannel is a helper method to define mock.On call
// - channel string
func (_e *MockChannelManager_Expecter) GetCollectionIDByChannel(channel interface{}) *MockChannelManager_GetCollectionIDByChannel_Call {
return &MockChannelManager_GetCollectionIDByChannel_Call{Call: _e.mock.On("GetCollectionIDByChannel", channel)}
}
func (_c *MockChannelManager_GetCollectionIDByChannel_Call) Run(run func(channel string)) *MockChannelManager_GetCollectionIDByChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockChannelManager_GetCollectionIDByChannel_Call) Return(_a0 bool, _a1 int64) *MockChannelManager_GetCollectionIDByChannel_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockChannelManager_GetCollectionIDByChannel_Call) RunAndReturn(run func(string) (bool, int64)) *MockChannelManager_GetCollectionIDByChannel_Call {
_c.Call.Return(run)
return _c
}
// GetNodeChannelsByCollectionID provides a mock function with given fields: collectionID
func (_m *MockChannelManager) GetNodeChannelsByCollectionID(collectionID int64) map[int64][]string {
ret := _m.Called(collectionID)
@ -330,24 +377,24 @@ func (_c *MockChannelManager_GetNodeChannelsByCollectionID_Call) RunAndReturn(ru
}
// GetNodeIDByChannelName provides a mock function with given fields: channel
func (_m *MockChannelManager) GetNodeIDByChannelName(channel string) (bool, int64) {
func (_m *MockChannelManager) GetNodeIDByChannelName(channel string) (int64, bool) {
ret := _m.Called(channel)
var r0 bool
var r1 int64
if rf, ok := ret.Get(0).(func(string) (bool, int64)); ok {
var r0 int64
var r1 bool
if rf, ok := ret.Get(0).(func(string) (int64, bool)); ok {
return rf(channel)
}
if rf, ok := ret.Get(0).(func(string) bool); ok {
if rf, ok := ret.Get(0).(func(string) int64); ok {
r0 = rf(channel)
} else {
r0 = ret.Get(0).(bool)
r0 = ret.Get(0).(int64)
}
if rf, ok := ret.Get(1).(func(string) int64); ok {
if rf, ok := ret.Get(1).(func(string) bool); ok {
r1 = rf(channel)
} else {
r1 = ret.Get(1).(int64)
r1 = ret.Get(1).(bool)
}
return r0, r1
@ -371,12 +418,12 @@ func (_c *MockChannelManager_GetNodeIDByChannelName_Call) Run(run func(channel s
return _c
}
func (_c *MockChannelManager_GetNodeIDByChannelName_Call) Return(_a0 bool, _a1 int64) *MockChannelManager_GetNodeIDByChannelName_Call {
func (_c *MockChannelManager_GetNodeIDByChannelName_Call) Return(_a0 int64, _a1 bool) *MockChannelManager_GetNodeIDByChannelName_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockChannelManager_GetNodeIDByChannelName_Call) RunAndReturn(run func(string) (bool, int64)) *MockChannelManager_GetNodeIDByChannelName_Call {
func (_c *MockChannelManager_GetNodeIDByChannelName_Call) RunAndReturn(run func(string) (int64, bool)) *MockChannelManager_GetNodeIDByChannelName_Call {
_c.Call.Return(run)
return _c
}
@ -467,55 +514,13 @@ func (_c *MockChannelManager_Release_Call) RunAndReturn(run func(int64, string)
return _c
}
// RemoveChannel provides a mock function with given fields: channelName
func (_m *MockChannelManager) RemoveChannel(channelName string) error {
ret := _m.Called(channelName)
// Startup provides a mock function with given fields: ctx, legacyNodes, allNodes
func (_m *MockChannelManager) Startup(ctx context.Context, legacyNodes []int64, allNodes []int64) error {
ret := _m.Called(ctx, legacyNodes, allNodes)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(channelName)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockChannelManager_RemoveChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveChannel'
type MockChannelManager_RemoveChannel_Call struct {
*mock.Call
}
// RemoveChannel is a helper method to define mock.On call
// - channelName string
func (_e *MockChannelManager_Expecter) RemoveChannel(channelName interface{}) *MockChannelManager_RemoveChannel_Call {
return &MockChannelManager_RemoveChannel_Call{Call: _e.mock.On("RemoveChannel", channelName)}
}
func (_c *MockChannelManager_RemoveChannel_Call) Run(run func(channelName string)) *MockChannelManager_RemoveChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockChannelManager_RemoveChannel_Call) Return(_a0 error) *MockChannelManager_RemoveChannel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockChannelManager_RemoveChannel_Call) RunAndReturn(run func(string) error) *MockChannelManager_RemoveChannel_Call {
_c.Call.Return(run)
return _c
}
// Startup provides a mock function with given fields: ctx, nodes
func (_m *MockChannelManager) Startup(ctx context.Context, nodes []int64) error {
ret := _m.Called(ctx, nodes)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []int64) error); ok {
r0 = rf(ctx, nodes)
if rf, ok := ret.Get(0).(func(context.Context, []int64, []int64) error); ok {
r0 = rf(ctx, legacyNodes, allNodes)
} else {
r0 = ret.Error(0)
}
@ -530,14 +535,15 @@ type MockChannelManager_Startup_Call struct {
// Startup is a helper method to define mock.On call
// - ctx context.Context
// - nodes []int64
func (_e *MockChannelManager_Expecter) Startup(ctx interface{}, nodes interface{}) *MockChannelManager_Startup_Call {
return &MockChannelManager_Startup_Call{Call: _e.mock.On("Startup", ctx, nodes)}
// - legacyNodes []int64
// - allNodes []int64
func (_e *MockChannelManager_Expecter) Startup(ctx interface{}, legacyNodes interface{}, allNodes interface{}) *MockChannelManager_Startup_Call {
return &MockChannelManager_Startup_Call{Call: _e.mock.On("Startup", ctx, legacyNodes, allNodes)}
}
func (_c *MockChannelManager_Startup_Call) Run(run func(ctx context.Context, nodes []int64)) *MockChannelManager_Startup_Call {
func (_c *MockChannelManager_Startup_Call) Run(run func(ctx context.Context, legacyNodes []int64, allNodes []int64)) *MockChannelManager_Startup_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]int64))
run(args[0].(context.Context), args[1].([]int64), args[2].([]int64))
})
return _c
}
@ -547,7 +553,7 @@ func (_c *MockChannelManager_Startup_Call) Return(_a0 error) *MockChannelManager
return _c
}
func (_c *MockChannelManager_Startup_Call) RunAndReturn(run func(context.Context, []int64) error) *MockChannelManager_Startup_Call {
func (_c *MockChannelManager_Startup_Call) RunAndReturn(run func(context.Context, []int64, []int64) error) *MockChannelManager_Startup_Call {
_c.Call.Return(run)
return _c
}

View File

@ -553,13 +553,13 @@ func (_c *MockCluster_UnRegister_Call) RunAndReturn(run func(*NodeInfo) error) *
return _c
}
// Watch provides a mock function with given fields: ctx, ch, collectionID
func (_m *MockCluster) Watch(ctx context.Context, ch string, collectionID int64) error {
ret := _m.Called(ctx, ch, collectionID)
// Watch provides a mock function with given fields: ctx, ch
func (_m *MockCluster) Watch(ctx context.Context, ch RWChannel) error {
ret := _m.Called(ctx, ch)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, int64) error); ok {
r0 = rf(ctx, ch, collectionID)
if rf, ok := ret.Get(0).(func(context.Context, RWChannel) error); ok {
r0 = rf(ctx, ch)
} else {
r0 = ret.Error(0)
}
@ -574,15 +574,14 @@ type MockCluster_Watch_Call struct {
// Watch is a helper method to define mock.On call
// - ctx context.Context
// - ch string
// - collectionID int64
func (_e *MockCluster_Expecter) Watch(ctx interface{}, ch interface{}, collectionID interface{}) *MockCluster_Watch_Call {
return &MockCluster_Watch_Call{Call: _e.mock.On("Watch", ctx, ch, collectionID)}
// - ch RWChannel
func (_e *MockCluster_Expecter) Watch(ctx interface{}, ch interface{}) *MockCluster_Watch_Call {
return &MockCluster_Watch_Call{Call: _e.mock.On("Watch", ctx, ch)}
}
func (_c *MockCluster_Watch_Call) Run(run func(ctx context.Context, ch string, collectionID int64)) *MockCluster_Watch_Call {
func (_c *MockCluster_Watch_Call) Run(run func(ctx context.Context, ch RWChannel)) *MockCluster_Watch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(int64))
run(args[0].(context.Context), args[1].(RWChannel))
})
return _c
}
@ -592,7 +591,7 @@ func (_c *MockCluster_Watch_Call) Return(_a0 error) *MockCluster_Watch_Call {
return _c
}
func (_c *MockCluster_Watch_Call) RunAndReturn(run func(context.Context, string, int64) error) *MockCluster_Watch_Call {
func (_c *MockCluster_Watch_Call) RunAndReturn(run func(context.Context, RWChannel) error) *MockCluster_Watch_Call {
_c.Call.Return(run)
return _c
}

View File

@ -0,0 +1,137 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package datacoord
import (
context "context"
datapb "github.com/milvus-io/milvus/internal/proto/datapb"
mock "github.com/stretchr/testify/mock"
)
// MockSubCluster is an autogenerated mock type for the SubCluster type
type MockSubCluster struct {
mock.Mock
}
type MockSubCluster_Expecter struct {
mock *mock.Mock
}
func (_m *MockSubCluster) EXPECT() *MockSubCluster_Expecter {
return &MockSubCluster_Expecter{mock: &_m.Mock}
}
// CheckChannelOperationProgress provides a mock function with given fields: ctx, nodeID, info
func (_m *MockSubCluster) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) {
ret := _m.Called(ctx, nodeID, info)
var r0 *datapb.ChannelOperationProgressResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)); ok {
return rf(ctx, nodeID, info)
}
if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.ChannelWatchInfo) *datapb.ChannelOperationProgressResponse); ok {
r0 = rf(ctx, nodeID, info)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*datapb.ChannelOperationProgressResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64, *datapb.ChannelWatchInfo) error); ok {
r1 = rf(ctx, nodeID, info)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockSubCluster_CheckChannelOperationProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckChannelOperationProgress'
type MockSubCluster_CheckChannelOperationProgress_Call struct {
*mock.Call
}
// CheckChannelOperationProgress is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
// - info *datapb.ChannelWatchInfo
func (_e *MockSubCluster_Expecter) CheckChannelOperationProgress(ctx interface{}, nodeID interface{}, info interface{}) *MockSubCluster_CheckChannelOperationProgress_Call {
return &MockSubCluster_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", ctx, nodeID, info)}
}
func (_c *MockSubCluster_CheckChannelOperationProgress_Call) Run(run func(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo)) *MockSubCluster_CheckChannelOperationProgress_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(*datapb.ChannelWatchInfo))
})
return _c
}
func (_c *MockSubCluster_CheckChannelOperationProgress_Call) Return(_a0 *datapb.ChannelOperationProgressResponse, _a1 error) *MockSubCluster_CheckChannelOperationProgress_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockSubCluster_CheckChannelOperationProgress_Call) RunAndReturn(run func(context.Context, int64, *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)) *MockSubCluster_CheckChannelOperationProgress_Call {
_c.Call.Return(run)
return _c
}
// NotifyChannelOperation provides a mock function with given fields: ctx, nodeID, req
func (_m *MockSubCluster) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error {
ret := _m.Called(ctx, nodeID, req)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.ChannelOperationsRequest) error); ok {
r0 = rf(ctx, nodeID, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockSubCluster_NotifyChannelOperation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NotifyChannelOperation'
type MockSubCluster_NotifyChannelOperation_Call struct {
*mock.Call
}
// NotifyChannelOperation is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
// - req *datapb.ChannelOperationsRequest
func (_e *MockSubCluster_Expecter) NotifyChannelOperation(ctx interface{}, nodeID interface{}, req interface{}) *MockSubCluster_NotifyChannelOperation_Call {
return &MockSubCluster_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", ctx, nodeID, req)}
}
func (_c *MockSubCluster_NotifyChannelOperation_Call) Run(run func(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest)) *MockSubCluster_NotifyChannelOperation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(*datapb.ChannelOperationsRequest))
})
return _c
}
func (_c *MockSubCluster_NotifyChannelOperation_Call) Return(_a0 error) *MockSubCluster_NotifyChannelOperation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSubCluster_NotifyChannelOperation_Call) RunAndReturn(run func(context.Context, int64, *datapb.ChannelOperationsRequest) error) *MockSubCluster_NotifyChannelOperation_Call {
_c.Call.Return(run)
return _c
}
// NewMockSubCluster creates a new instance of MockSubCluster. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockSubCluster(t interface {
mock.TestingT
Cleanup(func())
}) *MockSubCluster {
mock := &MockSubCluster{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
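
A hedged usage sketch for this new mock, echoing the request's OpID back through RunAndReturn the way the RPC-watch tests elsewhere in this patch correlate progress responses:

package datacoord

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/proto/datapb"
)

func TestSubClusterProgress(t *testing.T) {
	sub := NewMockSubCluster(t)
	sub.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
		RunAndReturn(func(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) {
			// Echo the operation ID so the caller can correlate progress.
			return &datapb.ChannelOperationProgressResponse{OpID: info.GetOpID()}, nil
		})
	resp, err := sub.CheckChannelOperationProgress(context.Background(), 1, &datapb.ChannelWatchInfo{OpID: 19530})
	if err != nil || resp.GetOpID() != 19530 {
		t.Fatalf("unexpected progress response: %v, %v", resp, err)
	}
}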

View File

@ -17,20 +17,18 @@
package datacoord
import (
"context"
"math"
"sort"
"strconv"
"time"
"github.com/samber/lo"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// RegisterPolicy decides the channels mapping after registering the nodeID
// RegisterPolicy decides the channel mapping after registering a new nodeID
// return bufferedUpdates and balanceUpdates
type RegisterPolicy func(store ROChannelStore, nodeID int64) (*ChannelOpSet, *ChannelOpSet)
@ -47,8 +45,8 @@ func BufferChannelAssignPolicy(store ROChannelStore, nodeID int64) *ChannelOpSet
}
opSet := NewChannelOpSet(
NewDeleteOp(bufferID, lo.Values(info.Channels)...),
NewAddOp(nodeID, lo.Values(info.Channels)...))
NewChannelOp(bufferID, Delete, lo.Values(info.Channels)...),
NewChannelOp(nodeID, Watch, lo.Values(info.Channels)...))
return opSet
}
@ -61,14 +59,15 @@ func AvgAssignRegisterPolicy(store ROChannelStore, nodeID int64) (*ChannelOpSet,
}
// Get a list of available node-channel info.
avaNodes := filterNode(store.GetNodesChannels(), nodeID)
allNodes := store.GetNodesChannels()
avaNodes := filterNode(allNodes, nodeID)
channelNum := 0
for _, info := range avaNodes {
channelNum += len(info.Channels)
}
// the store has already added the new node
chPerNode := channelNum / len(store.GetNodes())
chPerNode := channelNum / len(allNodes)
if chPerNode == 0 {
return nil, nil
}
@ -95,7 +94,7 @@ func AvgAssignRegisterPolicy(store ROChannelStore, nodeID int64) (*ChannelOpSet,
// Channels in `releases` are reassigned eventually by channel manager.
opSet = NewChannelOpSet()
for k, v := range releases {
opSet.Add(k, v...)
opSet.Append(k, Release, v...)
}
return nil, opSet
}
@ -112,20 +111,14 @@ func filterNode(infos []*NodeChannelInfo, nodeID int64) []*NodeChannelInfo {
return filtered
}
func formatNodeID(nodeID int64) string {
return strconv.FormatInt(nodeID, 10)
}
func deformatNodeID(node string) (int64, error) {
return strconv.ParseInt(node, 10, 64)
}
// ChannelAssignPolicy assign channels to registered nodes.
// ChannelAssignPolicy assigns new channels to registered nodes.
type ChannelAssignPolicy func(store ROChannelStore, channels []RWChannel) *ChannelOpSet
// AverageAssignPolicy ensures that the number of channels per node is approximately the same
func AverageAssignPolicy(store ROChannelStore, channels []RWChannel) *ChannelOpSet {
newChannels := filterChannels(store, channels)
newChannels := lo.Filter(channels, func(ch RWChannel, _ int) bool {
return !store.HasChannel(ch.GetName())
})
if len(newChannels) == 0 {
return nil
}
@ -135,7 +128,7 @@ func AverageAssignPolicy(store ROChannelStore, channels []RWChannel) *ChannelOpS
// If no datanode is alive, save channels in buffer
if len(allDataNodes) == 0 {
opSet.Add(bufferID, channels...)
opSet.Append(bufferID, Watch, channels...)
return opSet
}
@ -151,35 +144,11 @@ func AverageAssignPolicy(store ROChannelStore, channels []RWChannel) *ChannelOpS
}
for id, chs := range updates {
opSet.Add(id, chs...)
opSet.Append(id, Watch, chs...)
}
return opSet
}
func filterChannels(store ROChannelStore, channels []RWChannel) []RWChannel {
channelsMap := make(map[string]RWChannel)
for _, c := range channels {
channelsMap[c.GetName()] = c
}
allChannelsInfo := store.GetChannels()
for _, info := range allChannelsInfo {
for _, c := range info.Channels {
delete(channelsMap, c.GetName())
}
}
if len(channelsMap) == 0 {
return nil
}
filtered := make([]RWChannel, 0, len(channelsMap))
for _, v := range channelsMap {
filtered = append(filtered, v)
}
return filtered
}
// DeregisterPolicy determines the mapping after deregistering the nodeID
type DeregisterPolicy func(store ROChannelStore, nodeID int64) *ChannelOpSet
@ -190,22 +159,21 @@ func EmptyDeregisterPolicy(store ROChannelStore, nodeID int64) *ChannelOpSet {
// AvgAssignUnregisteredChannels evenly assigns the unregistered channels
func AvgAssignUnregisteredChannels(store ROChannelStore, nodeID int64) *ChannelOpSet {
allNodes := store.GetNodesChannels()
avaNodes := make([]*NodeChannelInfo, 0, len(allNodes))
unregisteredChannels := make([]RWChannel, 0)
opSet := NewChannelOpSet()
for _, c := range allNodes {
if c.NodeID == nodeID {
opSet.Delete(nodeID, lo.Values(c.Channels)...)
unregisteredChannels = append(unregisteredChannels, lo.Values(c.Channels)...)
continue
}
avaNodes = append(avaNodes, c)
nodeChannel := store.GetNode(nodeID)
if nodeChannel == nil || len(nodeChannel.Channels) == 0 {
return nil
}
unregisteredChannels := nodeChannel.Channels
avaNodes := lo.Filter(store.GetNodesChannels(), func(info *NodeChannelInfo, _ int) bool {
return info.NodeID != nodeID
})
opSet := NewChannelOpSet()
opSet.Delete(nodeChannel.NodeID, lo.Values(nodeChannel.Channels)...)
if len(avaNodes) == 0 {
opSet.Add(bufferID, unregisteredChannels...)
opSet.Append(bufferID, Watch, lo.Values(unregisteredChannels)...)
return opSet
}
@ -215,33 +183,19 @@ func AvgAssignUnregisteredChannels(store ROChannelStore, nodeID int64) *ChannelO
})
updates := make(map[int64][]RWChannel)
for i, unregisteredChannel := range unregisteredChannels {
n := avaNodes[i%len(avaNodes)].NodeID
cnt := 0
for _, unregisteredChannel := range unregisteredChannels {
n := avaNodes[cnt%len(avaNodes)].NodeID
updates[n] = append(updates[n], unregisteredChannel)
cnt++
}
for id, chs := range updates {
opSet.Add(id, chs...)
opSet.Append(id, Watch, chs...)
}
return opSet
}
type BalanceChannelPolicy func(store ROChannelStore, ts time.Time) *ChannelOpSet
func AvgBalanceChannelPolicy(store ROChannelStore, ts time.Time) *ChannelOpSet {
opSet := NewChannelOpSet()
reAllocates, err := BgBalanceCheck(store.GetNodesChannels(), ts)
if err != nil {
log.Error("failed to balance node channels", zap.Error(err))
return opSet
}
for _, reAlloc := range reAllocates {
opSet.Add(reAlloc.NodeID, lo.Values(reAlloc.Channels)...)
}
return opSet
}
// ChannelReassignPolicy is a policy for reassigning channels
type ChannelReassignPolicy func(store ROChannelStore, reassigns []*NodeChannelInfo) *ChannelOpSet
@ -250,25 +204,20 @@ func EmptyReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *Ch
return nil
}
// EmptyBalancePolicy is a dummy balance policy
func EmptyBalancePolicy(store ROChannelStore, ts time.Time) *ChannelOpSet {
return nil
}
// AverageReassignPolicy is a reassigning policy that evenly balances channels among datanodes;
// it is used by bgChecker
func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *ChannelOpSet {
allNodes := store.GetNodesChannels()
filterMap := make(map[int64]struct{})
toReassignTotalNum := 0
for _, reassign := range reassigns {
filterMap[reassign.NodeID] = struct{}{}
toReassignTotalNum += len(reassign.Channels)
}
avaNodes := make([]*NodeChannelInfo, 0, len(allNodes))
avaNodesChannelSum := 0
for _, node := range allNodes {
if _, ok := filterMap[node.NodeID]; ok {
if lo.ContainsBy(reassigns, func(info *NodeChannelInfo) bool {
return node.NodeID == info.NodeID
}) {
continue
}
avaNodes = append(avaNodes, node)
@ -279,7 +228,6 @@ func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *
if len(avaNodes) == 0 {
// if no node is left, do not reassign
log.Warn("there is no available nodes when reassigning, return")
return nil
}
@ -322,7 +270,7 @@ func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *
nodeIdx++
}
if _, ok := addUpdates[targetID]; !ok {
addUpdates[targetID] = NewAddOp(targetID, ch)
addUpdates[targetID] = NewChannelOp(targetID, Watch, ch)
} else {
addUpdates[targetID].Append(ch)
}
@ -334,18 +282,19 @@ func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *
return opSet
}
// ChannelBGChecker checks nodes' channels and returns the channels that need to be reallocated.
type ChannelBGChecker func(ctx context.Context)
type Assignments []*NodeChannelInfo
// EmptyBgChecker does nothing
func EmptyBgChecker(channels []*NodeChannelInfo, ts time.Time) ([]*NodeChannelInfo, error) {
return nil, nil
func (a Assignments) GetChannelCount(nodeID int64) int {
for _, info := range a {
if info.NodeID == nodeID {
return len(info.Channels)
}
}
return 0
}
type ReAllocates []*NodeChannelInfo
func (rallocates ReAllocates) MarshalLogArray(enc zapcore.ArrayEncoder) error {
for _, nChannelInfo := range rallocates {
func (a Assignments) MarshalLogArray(enc zapcore.ArrayEncoder) error {
for _, nChannelInfo := range a {
enc.AppendString("nodeID:")
enc.AppendInt64(nChannelInfo.NodeID)
cstr := "["
@ -362,22 +311,33 @@ func (rallocates ReAllocates) MarshalLogArray(enc zapcore.ArrayEncoder) error {
return nil
}
func BgBalanceCheck(nodeChannels []*NodeChannelInfo, ts time.Time) ([]*NodeChannelInfo, error) {
avaNodeNum := len(nodeChannels)
reAllocations := make(ReAllocates, 0, avaNodeNum)
// BalanceChannelPolicy tries to balance watched channels across registered nodes
type BalanceChannelPolicy func(cluster Assignments) *ChannelOpSet
// EmptyBalancePolicy is a dummy balance policy
func EmptyBalancePolicy(cluster Assignments) *ChannelOpSet {
return nil
}
// AvgBalanceChannelPolicy tries to balance channels evenly across nodes
func AvgBalanceChannelPolicy(cluster Assignments) *ChannelOpSet {
avaNodeNum := len(cluster)
if avaNodeNum == 0 {
return reAllocations, nil
return nil
}
reAllocations := make(Assignments, 0, avaNodeNum)
totalChannelNum := 0
for _, nodeChs := range nodeChannels {
for _, nodeChs := range cluster {
totalChannelNum += len(nodeChs.Channels)
}
channelCountPerNode := totalChannelNum / avaNodeNum
for _, nChannels := range nodeChannels {
for _, nChannels := range cluster {
chCount := len(nChannels.Channels)
if chCount <= channelCountPerNode+1 {
log.Info("node channel count is not much larger than average, skip reallocate",
zap.Int64("nodeID", nChannels.NodeID), zap.Int("channelCount", chCount),
zap.Int64("nodeID", nChannels.NodeID),
zap.Int("channelCount", chCount),
zap.Int("channelCountPerNode", channelCountPerNode))
continue
}
@ -392,25 +352,136 @@ func BgBalanceCheck(nodeChannels []*NodeChannelInfo, ts time.Time) ([]*NodeChann
}
reAllocations = append(reAllocations, reallocate)
}
log.Info("Channel Balancer got new reAllocations:", zap.Array("reAllocations", reAllocations))
return reAllocations, nil
}
func formatNodeIDs(ids []int64) []string {
formatted := make([]string, 0, len(ids))
for _, id := range ids {
formatted = append(formatted, formatNodeID(id))
if len(reAllocations) == 0 {
return nil
}
return formatted
opSet := NewChannelOpSet()
for _, reAlloc := range reAllocations {
opSet.Append(reAlloc.NodeID, Release, lo.Values(reAlloc.Channels)...)
}
return opSet
}
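
To make the average+1 threshold concrete, a hypothetical test sketch; it assumes NodeChannelInfo can be built as a plain literal with NodeID and a name-keyed Channels map, matching how this file reads those fields:

package datacoord

import "testing"

func TestAvgBalanceNoOpWhenBalanced(t *testing.T) {
	// Two nodes, one channel each: the average is 1 and neither node
	// exceeds average+1, so no Release operations are generated.
	cluster := Assignments{
		{NodeID: 1, Channels: map[string]RWChannel{"ch-1": nil}},
		{NodeID: 2, Channels: map[string]RWChannel{"ch-2": nil}},
	}
	if opSet := AvgBalanceChannelPolicy(cluster); opSet != nil {
		t.Fatalf("expected no balance ops for a balanced cluster, got %v", opSet)
	}
}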
func formatNodeIDsWithFilter(ids []int64, filter int64) []string {
formatted := make([]string, 0, len(ids))
for _, id := range ids {
if id == filter {
continue
func AvgAssignByCountPolicy(currentCluster Assignments, unassignedChannels []RWChannel, exclusiveNodes []int64) *ChannelOpSet {
var (
toCluster Assignments
fromCluster Assignments
channelNum int = 0
)
nodeToAvg := typeutil.NewUniqueSet()
lo.ForEach(currentCluster, func(info *NodeChannelInfo, _ int) {
if !lo.Contains(exclusiveNodes, info.NodeID) {
toCluster = append(toCluster, info)
nodeToAvg.Insert(info.NodeID)
}
formatted = append(formatted, formatNodeID(id))
if len(info.Channels) > 0 {
fromCluster = append(fromCluster, info)
channelNum += len(info.Channels)
nodeToAvg.Insert(info.NodeID)
}
})
// If no datanode is alive, do nothing
if len(toCluster) == 0 {
return nil
}
return formatted
// 1. assign unassigned channels first
if len(unassignedChannels) > 0 {
chPerNode := (len(unassignedChannels) + channelNum) / nodeToAvg.Len()
// sort by assigned channel count, ascending
sort.Slice(toCluster, func(i, j int) bool {
	return len(toCluster[i].Channels) < len(toCluster[j].Channels)
})
nodesLackOfChannels := Assignments(lo.Filter(toCluster, func(info *NodeChannelInfo, _ int) bool {
return len(info.Channels) < chPerNode
}))
if len(nodesLackOfChannels) == 0 {
nodesLackOfChannels = toCluster
}
updates := make(map[int64][]RWChannel)
for i, newChannel := range unassignedChannels {
n := nodesLackOfChannels[i%len(nodesLackOfChannels)].NodeID
updates[n] = append(updates[n], newChannel)
}
opSet := NewChannelOpSet()
for id, chs := range updates {
opSet.Append(id, Watch, chs...)
opSet.Delete(bufferID, chs...)
}
log.Info("Assign channels to nodes by channel count",
zap.Int("channel count", len(unassignedChannels)),
zap.Int("cluster count", len(toCluster)),
zap.Int64s("exclusive nodes", execlusiveNodes),
zap.Any("operations", opSet),
zap.Int64s("nodesLackOfChannels", lo.Map(nodesLackOfChannels, func(info *NodeChannelInfo, _ int) int64 {
return info.NodeID
})),
)
return opSet
}
if !Params.DataCoordCfg.AutoBalance.GetAsBool() {
log.Info("auto balance disabled")
return nil
}
// 2. balance fromCluster to toCluster if no unassignedChannels
if len(fromCluster) == 0 {
return nil
}
chPerNode := channelNum / nodeToAvg.Len()
if chPerNode == 0 {
return nil
}
// sort in descending order and reallocate
sort.Slice(fromCluster, func(i, j int) bool {
return len(fromCluster[i].Channels) > len(fromCluster[j].Channels)
})
releases := make(map[int64][]RWChannel)
for _, info := range fromCluster {
if len(info.Channels) > chPerNode {
cnt := 0
for _, ch := range info.Channels {
cnt++
if cnt > chPerNode {
releases[info.NodeID] = append(releases[info.NodeID], ch)
}
}
}
}
// Channels in `releases` are reassigned eventually by channel manager.
opSet := NewChannelOpSet()
for k, v := range releases {
if lo.Contains(exclusiveNodes, k) {
opSet.Append(k, Delete, v...)
opSet.Append(bufferID, Watch, v...)
} else {
opSet.Append(k, Release, v...)
}
}
log.Info("Assign channels to nodes by channel count",
zap.Int64s("exclusive nodes", execlusiveNodes),
zap.Int("channel count", channelNum),
zap.Int("channel per node", chPerNode),
zap.Any("operations", opSet),
zap.Array("fromCluster", fromCluster),
zap.Array("toCluster", toCluster),
)
return opSet
}
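
And the fresh-assignment branch of the same policy, sketched under the same struct-literal assumptions (channelMeta as the concrete RWChannel):

package datacoord

import "testing"

func TestAvgAssignByCountSpreadsChannels(t *testing.T) {
	// Two empty nodes and two fresh channels: chPerNode is 1, so each
	// node should receive one Watch operation.
	cluster := Assignments{
		{NodeID: 1, Channels: map[string]RWChannel{}},
		{NodeID: 2, Channels: map[string]RWChannel{}},
	}
	unassigned := []RWChannel{
		&channelMeta{Name: "ch-1"},
		&channelMeta{Name: "ch-2"},
	}
	if opSet := AvgAssignByCountPolicy(cluster, unassigned, nil); opSet == nil {
		t.Fatal("expected Watch operations for the fresh channels")
	}
}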

File diff suppressed because it is too large

View File

@ -343,6 +343,7 @@ func (s *Server) initDataCoord() error {
log.Info("init rootcoord client done")
s.broker = broker.NewCoordinatorBroker(s.rootCoordClient)
s.allocator = newRootCoordAllocator(s.rootCoordClient)
storageCli, err := s.newChunkManagerFactory()
if err != nil {
@ -364,8 +365,6 @@ func (s *Server) initDataCoord() error {
}
log.Info("init datanode cluster done")
s.allocator = newRootCoordAllocator(s.rootCoordClient)
s.initIndexNodeManager()
if err = s.initServiceDiscovery(); err != nil {
@ -466,6 +465,13 @@ func (s *Server) startDataCoord() {
sessionutil.SaveServerInfo(typeutil.DataCoordRole, s.session.GetServerID())
}
func (s *Server) GetServerID() int64 {
if s.session != nil {
return s.session.GetServerID()
}
return paramtable.GetNodeID()
}
func (s *Server) afterStart() {}
func (s *Server) initCluster() error {
@ -473,13 +479,20 @@ func (s *Server) initCluster() error {
return nil
}
var err error
s.channelManager, err = NewChannelManager(s.watchClient, s.handler, withMsgstreamFactory(s.factory),
withStateChecker(), withBgChecker())
if err != nil {
return err
}
s.sessionManager = NewSessionManagerImpl(withSessionCreator(s.dataNodeCreator))
var err error
if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() {
s.channelManager, err = NewChannelManagerV2(s.watchClient, s.handler, s.sessionManager, s.allocator, withCheckerV2())
if err != nil {
return err
}
} else {
s.channelManager, err = NewChannelManager(s.watchClient, s.handler, withMsgstreamFactory(s.factory), withStateChecker(), withBgChecker())
if err != nil {
return err
}
}
s.cluster = NewClusterImpl(s.sessionManager, s.channelManager)
return nil
}
@ -559,11 +572,21 @@ func (s *Server) initServiceDiscovery() error {
log.Info("DataCoord success to get DataNode sessions", zap.Any("sessions", sessions))
datanodes := make([]*NodeInfo, 0, len(sessions))
legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue())
if err != nil {
log.Warn("DataCoord failed to init service discovery", zap.Error(err))
}
for _, session := range sessions {
info := &NodeInfo{
NodeID: session.ServerID,
Address: session.Address,
}
if session.Version.LTE(legacyVersion) {
info.IsLegacy = true
}
datanodes = append(datanodes, info)
}

View File

@ -2424,7 +2424,7 @@ func TestOptions(t *testing.T) {
defer kv.RemoveWithPrefix("")
sessionManager := NewSessionManagerImpl()
channelManager, err := NewChannelManager(kv, newMockHandler())
channelManager, err := NewChannelManagerV2(kv, newMockHandler(), sessionManager, newMockAllocator())
assert.NoError(t, err)
cluster := NewClusterImpl(sessionManager, channelManager)
@ -2479,7 +2479,7 @@ func TestHandleSessionEvent(t *testing.T) {
defer cancel()
sessionManager := NewSessionManagerImpl()
channelManager, err := NewChannelManager(kv, newMockHandler(), withFactory(&mockPolicyFactory{}))
channelManager, err := NewChannelManagerV2(kv, newMockHandler(), sessionManager, newMockAllocator(), withFactoryV2(&mockPolicyFactory{}))
assert.NoError(t, err)
cluster := NewClusterImpl(sessionManager, channelManager)

View File

@ -1249,20 +1249,14 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
}, nil
}
for _, channelName := range req.GetChannelNames() {
ch := &channelMeta{
Name: channelName,
CollectionID: req.GetCollectionID(),
StartPositions: req.GetStartPositions(),
Schema: req.GetSchema(),
CreateTimestamp: req.GetCreateTimestamp(),
}
ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp())
err := s.channelManager.Watch(ctx, ch)
if err != nil {
log.Warn("fail to watch channelName", zap.Error(err))
resp.Status = merr.Status(err)
return resp, nil
}
if err := s.meta.catalog.MarkChannelAdded(ctx, ch.Name); err != nil {
if err := s.meta.catalog.MarkChannelAdded(ctx, channelName); err != nil {
// TODO: add background task to periodically cleanup the orphaned channel add marks.
log.Error("failed to mark channel added", zap.Error(err))
resp.Status = merr.Status(err)

View File

@ -50,7 +50,7 @@ func WithChannelManager(cm ChannelManager) Option {
func (s *ServerSuite) SetupTest() {
s.mockChMgr = NewMockChannelManager(s.T())
s.mockChMgr.EXPECT().Startup(mock.Anything, mock.Anything).Return(nil).Maybe()
s.mockChMgr.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
s.mockChMgr.EXPECT().Close().Maybe()
s.testServer = newTestServer(s.T(), WithChannelManager(s.mockChMgr))

View File

@ -30,8 +30,9 @@ var errDisposed = errors.New("client is disposed")
// NodeInfo contains node base info
type NodeInfo struct {
NodeID int64
Address string
NodeID int64
Address string
IsLegacy bool
}
// Session contains session info of a node

View File

@ -35,7 +35,7 @@ func (s *SessionManagerSuite) SetupTest() {
return s.dn, nil
}))
s.m.AddSession(&NodeInfo{1000, "addr-1"})
s.m.AddSession(&NodeInfo{1000, "addr-1", true})
s.MetricsEqual(metrics.DataCoordNumDataNodes, 1)
}

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -52,24 +53,22 @@ type ChannelManagerImpl struct {
releaseFunc releaseFunc
closeCh chan struct{}
closeOnce sync.Once
closeCh lifetime.SafeChan
closeWaiter sync.WaitGroup
}
func NewChannelManager(dn *DataNode) *ChannelManagerImpl {
fm := newFlowgraphManager()
cm := ChannelManagerImpl{
dn: dn,
fgManager: fm,
fgManager: dn.flowgraphManager,
communicateCh: make(chan *opState, 100),
opRunners: typeutil.NewConcurrentMap[string, *opRunner](),
abnormals: typeutil.NewConcurrentMap[int64, string](),
releaseFunc: fm.RemoveFlowgraph,
releaseFunc: dn.flowgraphManager.RemoveFlowgraph,
closeCh: make(chan struct{}),
closeCh: lifetime.NewSafeChan(),
}
return &cm
@ -131,14 +130,14 @@ func (m *ChannelManagerImpl) GetProgress(info *datapb.ChannelWatchInfo) *datapb.
}
func (m *ChannelManagerImpl) Close() {
m.closeOnce.Do(func() {
if m.opRunners != nil {
m.opRunners.Range(func(channel string, runner *opRunner) bool {
runner.Close()
return true
})
close(m.closeCh)
m.closeWaiter.Wait()
})
}
m.closeCh.Close()
m.closeWaiter.Wait()
}
func (m *ChannelManagerImpl) Start() {
@ -150,7 +149,7 @@ func (m *ChannelManagerImpl) Start() {
select {
case opState := <-m.communicateCh:
m.handleOpState(opState)
case <-m.closeCh:
case <-m.closeCh.CloseCh():
log.Info("DataNode ChannelManager exit")
return
}
@ -170,23 +169,19 @@ func (m *ChannelManagerImpl) handleOpState(opState *opState) {
case datapb.ChannelWatchState_WatchSuccess:
log.Info("Success to watch")
m.fgManager.AddFlowgraph(opState.fg)
m.finishOp(opState.opID, opState.channel)
case datapb.ChannelWatchState_WatchFailure:
log.Info("Fail to watch")
m.finishOp(opState.opID, opState.channel)
case datapb.ChannelWatchState_ReleaseSuccess:
log.Info("Success to release")
m.finishOp(opState.opID, opState.channel)
m.destoryRunner(opState.channel)
case datapb.ChannelWatchState_ReleaseFailure:
log.Info("Fail to release, add channel to abnormal lists")
m.abnormals.Insert(opState.opID, opState.channel)
m.finishOp(opState.opID, opState.channel)
m.destoryRunner(opState.channel)
}
m.finishOp(opState.opID, opState.channel)
}
func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner {
@ -197,15 +192,10 @@ func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner {
return runner
}
func (m *ChannelManagerImpl) destoryRunner(channel string) {
if runner, loaded := m.opRunners.GetAndRemove(channel); loaded {
runner.Close()
}
}
func (m *ChannelManagerImpl) finishOp(opID int64, channel string) {
if runner, loaded := m.opRunners.Get(channel); loaded {
if runner, loaded := m.opRunners.GetAndRemove(channel); loaded {
runner.FinishOp(opID)
runner.Close()
}
}
@ -223,9 +213,8 @@ type opRunner struct {
opsInQueue chan *datapb.ChannelWatchInfo
resultCh chan *opState
closeWg sync.WaitGroup
closeOnce sync.Once
closeCh chan struct{}
closeCh lifetime.SafeChan
closeWg sync.WaitGroup
}
func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opState) *opRunner {
@ -236,7 +225,7 @@ func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opS
opsInQueue: make(chan *datapb.ChannelWatchInfo, 10),
allOps: make(map[UniqueID]*opInfo),
resultCh: resultCh,
closeCh: make(chan struct{}),
closeCh: lifetime.NewSafeChan(),
}
}
@ -248,7 +237,7 @@ func (r *opRunner) Start() {
select {
case info := <-r.opsInQueue:
r.NotifyState(r.Execute(info))
case <-r.closeCh:
case <-r.closeCh.CloseCh():
return
}
}
@ -301,7 +290,7 @@ func (r *opRunner) Execute(info *datapb.ChannelWatchInfo) *opState {
}
// ToRelease state
return releaseWithTimer(r.releaseFunc, info.GetVchan().GetChannelName(), info.GetOpID())
return r.releaseWithTimer(r.releaseFunc, info.GetVchan().GetChannelName(), info.GetOpID())
}
// watchWithTimer will return WatchFailure after WatchTimeoutInterval
@ -314,13 +303,13 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
r.guard.Lock()
opInfo, ok := r.allOps[info.GetOpID()]
r.guard.Unlock()
if !ok {
opState.state = datapb.ChannelWatchState_WatchFailure
return opState
}
tickler := newTickler()
opInfo.tickler = tickler
r.guard.Unlock()
var (
successSig = make(chan struct{}, 1)
@ -348,6 +337,13 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
log.Info("Stop timer for ToWatch operation timeout")
return
case <-r.closeCh.CloseCh():
// runner closed from outside
tickler.close()
cancel()
log.Info("Suspend ToWatch operation from outside of opRunner")
return
case <-tickler.progressSig:
log.Info("Reset timer for tickler updated")
timer.Reset(watchTimeout)
@ -379,7 +375,7 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
}
// releaseWithTimer will return ReleaseFailure after WatchTimeoutInterval
func releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *opState {
func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *opState {
opState := &opState{
channel: channel,
opID: opID,
@ -389,23 +385,29 @@ func releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *o
waiter sync.WaitGroup
)
log := log.With(zap.String("channel", channel))
log := log.With(zap.Int64("opID", opID), zap.String("channel", channel))
startTimer := func(wg *sync.WaitGroup) {
defer wg.Done()
releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second)
timer := time.NewTimer(releaseTimeout)
defer timer.Stop()
log.Info("Start timer for ToRelease operation", zap.Duration("timeout", releaseTimeout))
log := log.With(zap.Duration("timeout", releaseTimeout))
log.Info("Start ToRelease timer")
for {
select {
case <-timer.C:
log.Info("Stop timer for ToRelease operation timeout", zap.Duration("timeout", releaseTimeout))
log.Info("Stop timer for ToRelease operation timeout")
opState.state = datapb.ChannelWatchState_ReleaseFailure
return
case <-r.closeCh.CloseCh():
// runner closed from outside
log.Info("Stop timer for opRunner closed")
return
case <-successSig:
log.Info("Stop timer for ToRelease operation succeeded", zap.Duration("timeout", releaseTimeout))
log.Info("Stop timer for ToRelease operation succeeded")
opState.state = datapb.ChannelWatchState_ReleaseSuccess
return
}
@ -436,18 +438,8 @@ func (r *opRunner) NotifyState(state *opState) {
}
func (r *opRunner) Close() {
r.guard.Lock()
for _, info := range r.allOps {
if info.tickler != nil {
info.tickler.close()
}
}
r.guard.Unlock()
r.closeOnce.Do(func() {
close(r.closeCh)
r.closeWg.Wait()
})
r.closeCh.Close()
r.closeWg.Wait()
}
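
The lifetime.SafeChan swap above is what lets Close() drop its closeOnce guard: SafeChan.Close() is idempotent, and every loop exits via CloseCh(). A minimal sketch of the pattern, using only the methods this patch relies on:

package datanode

import (
	"sync"
	"time"

	"github.com/milvus-io/milvus/pkg/util/lifetime"
)

type loop struct {
	closeCh lifetime.SafeChan
	wg      sync.WaitGroup
}

func startLoop() *loop {
	l := &loop{closeCh: lifetime.NewSafeChan()}
	l.wg.Add(1)
	go func() {
		defer l.wg.Done()
		for {
			select {
			case <-time.After(time.Second):
				// periodic work would run here
			case <-l.closeCh.CloseCh():
				return
			}
		}
	}()
	return l
}

// Stop is safe to call repeatedly; SafeChan absorbs the double close
// that would panic with a raw close(chan struct{}).
func (l *loop) Stop() {
	l.closeCh.Close()
	l.wg.Wait()
}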
type opState struct {

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus/internal/datanode/allocator"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -34,6 +35,44 @@ func TestChannelManagerSuite(t *testing.T) {
suite.Run(t, new(ChannelManagerSuite))
}
func TestOpRunnerSuite(t *testing.T) {
suite.Run(t, new(OpRunnerSuite))
}
func (s *OpRunnerSuite) SetupTest() {
ctx := context.Background()
s.mockAlloc = allocator.NewMockAllocator(s.T())
s.node = newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
s.node.allocator = s.mockAlloc
}
func (s *OpRunnerSuite) TestWatchWithTimer() {
var (
channel string = "ch-1"
commuCh = make(chan *opState)
)
info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
mockReleaseFunc := func(channel string) {
log.Info("mock release func")
}
runner := NewOpRunner(channel, s.node, mockReleaseFunc, commuCh)
err := runner.Enqueue(info)
s.Require().NoError(err)
opState := runner.watchWithTimer(info)
s.NotNil(opState.fg)
s.Equal(channel, opState.channel)
runner.FinishOp(100)
}
type OpRunnerSuite struct {
suite.Suite
node *DataNode
mockAlloc *allocator.MockAllocator
}
type ChannelManagerSuite struct {
suite.Suite
@ -45,6 +84,8 @@ func (s *ChannelManagerSuite) SetupTest() {
ctx := context.Background()
s.node = newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
s.node.allocator = allocator.NewMockAllocator(s.T())
s.node.flowgraphManager = newFlowgraphManager()
s.manager = NewChannelManager(s.node)
}
@ -80,7 +121,9 @@ func getWatchInfoByOpID(opID UniqueID, channel string, state datapb.ChannelWatch
}
func (s *ChannelManagerSuite) TearDownTest() {
s.manager.Close()
if s.manager != nil {
s.manager.Close()
}
}
func (s *ChannelManagerSuite) TestWatchFail() {
@ -167,11 +210,12 @@ func (s *ChannelManagerSuite) TestSubmitIdempotent() {
func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() {
channel := "by-dev-rootcoord-dml-0"
// watch
info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
err := s.manager.Submit(info)
s.NoError(err)
// wait for result
opState := <-s.manager.communicateCh
s.NotNil(opState)
s.Equal(datapb.ChannelWatchState_WatchSuccess, opState.state)
@ -184,8 +228,8 @@ func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() {
s.manager.handleOpState(opState)
s.Equal(1, s.manager.fgManager.GetFlowgraphCount())
s.True(s.manager.opRunners.Contain(info.GetVchan().GetChannelName()))
s.Equal(1, s.manager.opRunners.Len())
s.False(s.manager.opRunners.Contain(info.GetVchan().GetChannelName()))
s.Equal(0, s.manager.opRunners.Len())
resp = s.manager.GetProgress(info)
s.Equal(info.GetOpID(), resp.GetOpID())
@ -193,10 +237,10 @@ func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() {
// release
info = getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
err = s.manager.Submit(info)
s.NoError(err)
// wait for result
opState = <-s.manager.communicateCh
s.NotNil(opState)
s.Equal(datapb.ChannelWatchState_ReleaseSuccess, opState.state)

View File

@ -83,7 +83,6 @@ var Params *paramtable.ComponentParam = paramtable.Get()
// `segmentCache` stores all flushing and flushed segments.
type DataNode struct {
ctx context.Context
serverID int64
cancel context.CancelFunc
Role string
stateCode atomic.Value // commonpb.StateCode_Initializing
@ -129,7 +128,7 @@ type DataNode struct {
}
// NewDataNode will return a DataNode with abnormal state.
func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64) *DataNode {
func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode {
rand.Seed(time.Now().UnixNano())
ctx2, cancel2 := context.WithCancel(ctx)
node := &DataNode{
@ -140,13 +139,10 @@ func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64
rootCoord: nil,
dataCoord: nil,
factory: factory,
serverID: serverID,
segmentCache: newCache(),
compactionExecutor: newCompactionExecutor(),
eventManager: NewEventManager(),
flowgraphManager: newFlowgraphManager(),
clearSignal: make(chan string, 100),
clearSignal: make(chan string, 100),
reportImportRetryTimes: 10,
}
@ -228,10 +224,10 @@ func (node *DataNode) initRateCollector() error {
}
func (node *DataNode) GetNodeID() int64 {
if node.serverID == 0 && node.session != nil {
if node.session != nil {
return node.session.ServerID
}
return node.serverID
return paramtable.GetNodeID()
}
func (node *DataNode) Init() error {
@ -294,6 +290,13 @@ func (node *DataNode) Init() error {
node.importTaskMgr = importv2.NewTaskManager()
node.importScheduler = importv2.NewScheduler(node.importTaskMgr, node.syncMgr, node.chunkManager)
node.channelCheckpointUpdater = newChannelCheckpointUpdater(node)
node.flowgraphManager = newFlowgraphManager()
if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() {
node.channelManager = NewChannelManager(node)
} else {
node.eventManager = NewEventManager()
}
log.Info("init datanode done", zap.String("Address", node.address))
})
@ -322,9 +325,15 @@ func (node *DataNode) handleChannelEvt(evt *clientv3.Event) {
// tryToReleaseFlowgraph tries to release a flowgraph
func (node *DataNode) tryToReleaseFlowgraph(channel string) {
log.Info("try to release flowgraph", zap.String("channel", channel))
node.compactionExecutor.discardPlan(channel)
node.flowgraphManager.RemoveFlowgraph(channel)
node.writeBufferManager.RemoveChannel(channel)
if node.compactionExecutor != nil {
node.compactionExecutor.discardPlan(channel)
}
if node.flowgraphManager != nil {
node.flowgraphManager.RemoveFlowgraph(channel)
}
if node.writeBufferManager != nil {
node.writeBufferManager.RemoveChannel(channel)
}
}
// BackGroundGC runs in background to release datanode resources
@ -398,8 +407,12 @@ func (node *DataNode) Start() error {
go node.channelCheckpointUpdater.start()
// Start node watch node
node.startWatchChannelsAtBackground(node.ctx)
if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() {
node.channelManager.Start()
} else {
// Start node watch node
node.startWatchChannelsAtBackground(node.ctx)
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
})
@ -433,9 +446,13 @@ func (node *DataNode) Stop() error {
node.stopOnce.Do(func() {
// https://github.com/milvus-io/milvus/issues/12282
node.UpdateStateCode(commonpb.StateCode_Abnormal)
if node.channelManager != nil {
node.channelManager.Close()
}
node.flowgraphManager.Close()
node.eventManager.CloseAll()
if node.eventManager != nil {
node.eventManager.CloseAll()
}
if node.writeBufferManager != nil {
node.writeBufferManager.Stop()
@ -466,6 +483,7 @@ func (node *DataNode) Stop() error {
node.importScheduler.Close()
}
// Delay the cancellation of ctx to ensure that the session is automatically recycled after closed the flow graph
node.cancel()
node.stopWaiter.Wait()
})

View File

@ -92,7 +92,7 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) {
// serves the corner case for etcd connection lost and missing some events
func (node *DataNode) checkWatchedList() error {
// REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name}
prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.serverID))
prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetNodeID()))
keys, values, err := node.watchKv.LoadWithPrefix(prefix)
if err != nil {
return err

View File

@ -42,6 +42,9 @@ import (
func TestWatchChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key)
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
etcdCli, err := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),

View File

@ -83,7 +83,7 @@ var segID2SegInfo = map[int64]*datapb.SegmentInfo{
func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode {
factory := dependency.NewDefaultFactory(true)
node := NewDataNode(ctx, factory, 1)
node := NewDataNode(ctx, factory)
node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}})
node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())

View File

@ -333,6 +333,11 @@ func (node *DataNode) NotifyChannelOperation(ctx context.Context, req *datapb.Ch
log.Ctx(ctx).Info("DataNode receives NotifyChannelOperation",
zap.Int("operation count", len(req.GetInfos())))
if node.channelManager == nil {
log.Warn("DataNode NotifyChannelOperation failed due to nil channelManager")
return merr.Status(merr.WrapErrServiceInternal("channelManager is nil! Ignore if you are upgrading datanode/coord to rpc based watch")), nil
}
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.NotifyChannelOperation failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return merr.Status(err), nil
@ -356,6 +361,14 @@ func (node *DataNode) CheckChannelOperationProgress(ctx context.Context, req *da
)
log.Info("DataNode receives CheckChannelOperationProgress")
if node.channelManager == nil {
log.Warn("DataNode CheckChannelOperationProgress failed due to nil channelManager")
return &datapb.ChannelOperationProgressResponse{
Status: merr.Status(merr.WrapErrServiceInternal("channelManager is nil! Ignore if you are upgrading datanode/coord to rpc based watch")),
}, nil
}
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.CheckChannelOperationProgress failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return &datapb.ChannelOperationProgressResponse{

View File

@ -635,40 +635,6 @@ func (s *DataNodeServicesSuite) TestResendSegmentStats() {
s.Assert().True(merr.Ok(resp.GetStatus()), "empty call, status shall be OK")
}
/*
func (s *DataNodeServicesSuite) TestFlushChannels() {
dmChannelName := "fake-by-dev-rootcoord-dml-channel-TestFlushChannels"
vChan := &datapb.VchannelInfo{
CollectionID: 1,
ChannelName: dmChannelName,
UnflushedSegmentIds: []int64{},
FlushedSegmentIds: []int64{},
}
err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, vChan, nil, genTestTickler())
s.Require().NoError(err)
fgService, ok := s.node.flowgraphManager.getFlowgraphService(dmChannelName)
s.Require().True(ok)
flushTs := Timestamp(100)
req := &datapb.FlushChannelsRequest{
Base: &commonpb.MsgBase{
TargetID: s.node.GetSession().ServerID,
},
FlushTs: flushTs,
Channels: []string{dmChannelName},
}
status, err := s.node.FlushChannels(s.ctx, req)
s.Assert().NoError(err)
s.Assert().True(merr.Ok(status))
s.Assert().True(fgService.channel.getFlushTs() == flushTs)
}*/
func (s *DataNodeServicesSuite) TestRPCWatch() {
s.Run("node not healthy", func() {
s.SetupTest()
@ -686,22 +652,16 @@ func (s *DataNodeServicesSuite) TestRPCWatch() {
s.ErrorIs(merr.Error(status), merr.ErrServiceNotReady)
})
s.Run("node healthy", func() {
s.Run("submit error", func() {
s.SetupTest()
mockChManager := NewMockChannelManager(s.T())
s.node.channelManager = mockChManager
mockChManager.EXPECT().Submit(mock.Anything).Return(nil).Once()
ctx := context.Background()
status, err := s.node.NotifyChannelOperation(ctx, &datapb.ChannelOperationsRequest{Infos: []*datapb.ChannelWatchInfo{{OpID: 19530}}})
s.NoError(err)
s.True(merr.Ok(status))
mockChManager.EXPECT().GetProgress(mock.Anything).Return(
&datapb.ChannelOperationProgressResponse{Status: merr.Status(nil)},
).Once()
s.False(merr.Ok(status))
s.NotErrorIs(merr.Error(status), merr.ErrServiceNotReady)
resp, err := s.node.CheckChannelOperationProgress(ctx, nil)
s.NoError(err)
s.True(merr.Ok(resp.GetStatus()))
s.False(merr.Ok(resp.GetStatus()))
})
}

View File

@ -180,7 +180,7 @@ func (s *Server) startGrpcLoop(grpcPort int) {
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
s.serverID.Store(s.dataCoord.(*datacoord.Server).GetServerID())
}
return s.serverID.Load()
}),
@ -191,7 +191,7 @@ func (s *Server) startGrpcLoop(grpcPort int) {
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
s.serverID.Store(s.dataCoord.(*datacoord.Server).GetServerID())
}
return s.serverID.Load()
}),
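Both interceptor hunks make the same swap: the validated server ID now comes from the running coordinator rather than the process-level paramtable. A hedged before/after fragment; the active-standby rationale is an assumption, since this diff does not state the motivation.

```go
// Inside the ServerIDValidation provider closure (context: startGrpcLoop).

// Before: a process-wide ID from paramtable, which may not match the ID the
// DataCoord registered with its session (plausible in active-standby setups;
// this is an assumption, not stated in the diff).
id := paramtable.GetNodeID()

// After: the coordinator's own registered server ID, so ServerID validation
// agrees with what clients discovered through service discovery.
id = s.dataCoord.(*datacoord.Server).GetServerID()
_ = id
```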

View File

@ -91,7 +91,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
}
s.serverID.Store(paramtable.GetNodeID())
s.datanode = dn.NewDataNode(s.ctx, s.factory, s.serverID.Load())
s.datanode = dn.NewDataNode(s.ctx, s.factory)
return s, nil
}

View File

@ -81,9 +81,8 @@ func (pool *Pool[T]) Submit(method func() (T, error)) *Future[T] {
res, err := method()
if err != nil {
future.err = err
} else {
future.value = res
}
future.value = res
})
if err != nil {
future.err = err

View File

@ -2601,6 +2601,8 @@ user-task-polling:
type dataCoordConfig struct {
// --- CHANNEL ---
WatchTimeoutInterval ParamItem `refreshable:"false"`
EnableBalanceChannelWithRPC ParamItem `refreshable:"false"`
LegacyVersionWithoutRPCWatch ParamItem `refreshable:"false"`
ChannelBalanceSilentDuration ParamItem `refreshable:"true"`
ChannelBalanceInterval ParamItem `refreshable:"true"`
ChannelCheckInterval ParamItem `refreshable:"true"`
@ -2692,6 +2694,24 @@ func (p *dataCoordConfig) init(base *BaseTable) {
}
p.WatchTimeoutInterval.Init(base.mgr)
p.EnableBalanceChannelWithRPC = ParamItem{
Key: "dataCoord.channel.balanceWithRpc",
Version: "2.4.0",
DefaultValue: "true",
Doc: "Whether to balance channels through RPC; when disabled, the legacy etcd-watch path is used",
Export: true,
}
p.EnableBalanceChannelWithRPC.Init(base.mgr)
p.LegacyVersionWithoutRPCWatch = ParamItem{
Key: "dataCoord.channel.legacyVersionWithoutRPCWatch",
Version: "2.4.0",
DefaultValue: "2.4.0",
Doc: "DataNodes at or below this version are treated as legacy nodes that lack the RPC-based watch. Only used during rolling upgrades, where legacy nodes are not assigned new channels",
Export: true,
}
p.LegacyVersionWithoutRPCWatch.Init(base.mgr)
p.ChannelBalanceSilentDuration = ParamItem{
Key: "dataCoord.channel.balanceSilentDuration",
Version: "2.2.3",
@ -2713,7 +2733,7 @@ func (p *dataCoordConfig) init(base *BaseTable) {
p.ChannelCheckInterval = ParamItem{
Key: "dataCoord.channel.checkInterval",
Version: "2.4.0",
DefaultValue: "10",
DefaultValue: "1",
Doc: "The interval in seconds with which the channel manager advances channel states",
Export: true,
}
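A short sketch of reading the new knobs at runtime, using only keys and accessors that appear in this diff; the `GetAsDuration` unit argument is an assumption consistent with the "seconds" doc string.

```go
import (
	"time"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

func logChannelWatchKnobs() {
	p := &paramtable.Get().DataCoordCfg

	// dataCoord.channel.balanceWithRpc: gates RPC-based vs. etcd-based watch.
	useRPC := p.EnableBalanceChannelWithRPC.GetAsBool()

	// dataCoord.channel.legacyVersionWithoutRPCWatch: nodes at or below this
	// version stay off new channels during a rolling upgrade.
	legacyVersion := p.LegacyVersionWithoutRPCWatch.GetValue()

	// dataCoord.channel.checkInterval: how often the channel manager advances
	// channel states; this change drops the default from 10s to 1s.
	checkEvery := p.ChannelCheckInterval.GetAsDuration(time.Second)

	log.Info("channel watch knobs",
		zap.Bool("useRPC", useRPC),
		zap.String("legacyVersion", legacyVersion),
		zap.Duration("checkEvery", checkEvery))
}
```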

View File

@ -414,17 +414,19 @@ func (cluster *MiniClusterV2) StopAllQueryNodes() {
for _, node := range cluster.querynodes {
node.Stop()
}
log.Info(fmt.Sprintf("mini cluster stoped %d extra querynode", numExtraQN))
cluster.querynodes = nil
log.Info(fmt.Sprintf("mini cluster stopped %d extra querynode", numExtraQN))
}
func (cluster *MiniClusterV2) StopAllDataNodes() {
cluster.DataNode.Stop()
log.Info("mini cluster main dataNode stopped")
numExtraQN := len(cluster.datanodes)
numExtraDN := len(cluster.datanodes)
for _, node := range cluster.datanodes {
node.Stop()
}
log.Info(fmt.Sprintf("mini cluster stoped %d extra datanode", numExtraQN))
cluster.datanodes = nil
log.Info(fmt.Sprintf("mini cluster stopped %d extra datanode", numExtraDN))
}
func (cluster *MiniClusterV2) GetContext() context.Context {

View File

@ -0,0 +1,365 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package watchcompatibility
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord"
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
)
type DataNodeCompatibility struct {
integration.MiniClusterSuite
maxGoRoutineNum int
dim int
numCollections int
rowsPerCollection int
waitTimeInSec time.Duration
prefix string
}
func (s *DataNodeCompatibility) setupParam() {
s.maxGoRoutineNum = 100
s.dim = 128
s.numCollections = 1
s.rowsPerCollection = 100
s.waitTimeInSec = time.Second * 1
}
func (s *DataNodeCompatibility) flush(collectionName string) {
c := s.Cluster
flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
DbName: "",
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
s.Require().True(has)
s.Require().NotEmpty(segmentIDs)
ids := segmentIDs.GetData()
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
s.WaitForFlush(context.TODO(), ids, flushTs, "", collectionName)
}
func (s *DataNodeCompatibility) loadCollection(collectionName string) {
c := s.Cluster
dbName := ""
schema := integration.ConstructSchema(collectionName, s.dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
s.NoError(err)
showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
batchSize := 500000
for start := 0; start < s.rowsPerCollection; start += batchSize {
rowNum := batchSize
if start+batchSize > s.rowsPerCollection {
rowNum = s.rowsPerCollection - start
}
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, s.dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
}
s.flush(collectionName)
// create index
createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
s.NoError(err)
s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
// load
loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
s.NoError(err)
s.WaitForLoad(context.TODO(), collectionName)
}
func (s *DataNodeCompatibility) checkCollections() bool {
req := &milvuspb.ShowCollectionsRequest{
DbName: "",
TimeStamp: 0, // means now
}
resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req)
s.NoError(err)
s.Equal(len(resp.CollectionIds), s.numCollections)
notLoaded := 0
loaded := 0
for _, name := range resp.CollectionNames {
loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
DbName: "",
CollectionName: name,
})
s.NoError(err)
if loadProgress.GetProgress() != int64(100) {
notLoaded++
} else {
loaded++
}
}
return notLoaded == 0
}
func (s *DataNodeCompatibility) search(collectionName string, currentNumRows int) {
c := s.Cluster
var err error
// Query
queryReq := &milvuspb.QueryRequest{
Base: nil,
CollectionName: collectionName,
PartitionNames: nil,
Expr: "",
OutputFields: []string{"count(*)"},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
s.Equal(len(queryResult.FieldsData), 1)
numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
s.Equal(numEntities, int64(currentNumRows))
// Search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(context.TODO(), searchReq)
err = merr.Error(searchResult.GetStatus())
s.NoError(err)
}
func (s *DataNodeCompatibility) insertBatchCollections(prefix string, collectionBatchSize, idxStart int, wg *sync.WaitGroup) {
for idx := 0; idx < collectionBatchSize; idx++ {
collectionName := prefix + "_" + strconv.Itoa(idxStart+idx)
s.loadCollection(collectionName)
}
wg.Done()
}
func (s *DataNodeCompatibility) insert(collectionName string, rowNum int) {
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, s.dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := s.Cluster.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
DbName: "",
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
s.flush(collectionName)
}
func (s *DataNodeCompatibility) insertAndCheck(collectionName string, currentNumRows *int, testInsert bool) {
s.search(collectionName, *currentNumRows)
insertRows := 1000
if testInsert {
s.insert(collectionName, insertRows)
*currentNumRows += insertRows
}
s.search(collectionName, *currentNumRows)
}
func (s *DataNodeCompatibility) setupData() {
// Add the second data node
s.Cluster.AddDataNode()
goRoutineNum := s.maxGoRoutineNum
if goRoutineNum > s.numCollections {
goRoutineNum = s.numCollections
}
collectionBatchSize := s.numCollections / goRoutineNum
log.Info(fmt.Sprintf("=========================test with dim=%d, s.rowsPerCollection=%d, s.numCollections=%d, goRoutineNum=%d==================", s.dim, s.rowsPerCollection, s.numCollections, goRoutineNum))
log.Info("=========================Start to inject data=========================")
s.prefix = "TestDataNodeUtil" + funcutil.GenRandomStr()
searchName := s.prefix + "_0"
wg := sync.WaitGroup{}
for idx := 0; idx < goRoutineNum; idx++ {
wg.Add(1)
go s.insertBatchCollections(s.prefix, collectionBatchSize, idx*collectionBatchSize, &wg)
}
wg.Wait()
log.Info("=========================Data injection finished=========================")
s.checkCollections()
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName, s.rowsPerCollection)
log.Info("=========================Search finished=========================")
time.Sleep(s.waitTimeInSec)
s.checkCollections()
log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName))
s.search(searchName, s.rowsPerCollection)
log.Info("=========================Search2 finished=========================")
s.checkAllCollectionsReady()
}
func (s *DataNodeCompatibility) checkAllCollectionsReady() {
goRoutineNum := s.maxGoRoutineNum
if goRoutineNum > s.numCollections {
goRoutineNum = s.numCollections
}
collectionBatchSize := s.numCollections / goRoutineNum
for i := 0; i < goRoutineNum; i++ {
for idx := 0; idx < collectionBatchSize; idx++ {
collectionName := s.prefix + "_" + strconv.Itoa(i*collectionBatchSize+idx)
s.search(collectionName, s.rowsPerCollection)
queryReq := &milvuspb.QueryRequest{
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
}
_, err := s.Cluster.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
}
}
}
func (s *DataNodeCompatibility) checkSingleDNRestarts(currentNumRows *int, numNodes, idx int, testInsert bool) {
// Stop all data nodes
s.Cluster.StopAllDataNodes()
// Add new data nodes.
var dn []*grpcdatanode.Server
for i := 0; i < numNodes; i++ {
dn = append(dn, s.Cluster.AddDataNode())
}
time.Sleep(s.waitTimeInSec)
cn := fmt.Sprintf("%s_0", s.prefix)
s.insertAndCheck(cn, currentNumRows, testInsert)
dn[idx].Stop()
time.Sleep(s.waitTimeInSec)
s.insertAndCheck(cn, currentNumRows, testInsert)
}
func (s *DataNodeCompatibility) checkDNRestarts(currentNumRows *int, testInsert bool) {
numDatanodes := 2 // configurable
for idx := 0; idx < numDatanodes; idx++ {
s.checkSingleDNRestarts(currentNumRows, numDatanodes, idx, testInsert)
}
}
func (s *DataNodeCompatibility) restartDC() {
c := s.Cluster
c.DataCoord.Stop()
c.DataCoord = grpcdatacoord.NewServer(context.TODO(), c.GetFactory())
err := c.DataCoord.Run()
s.NoError(err)
}
func (s *DataNodeCompatibility) TestCompatibility() {
s.setupParam()
s.setupData()
rows := s.rowsPerCollection
// new coord + new node
s.checkDNRestarts(&rows, true)
// new coord + old node
paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "false")
s.checkDNRestarts(&rows, false)
// old coord + old node
s.restartDC()
s.checkDNRestarts(&rows, true)
// old coord + new node
paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "true")
s.checkDNRestarts(&rows, false)
// new coord + both old & new datanodes.
paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "false")
s.restartDC()
s.Cluster.StopAllDataNodes()
d1 := s.Cluster.AddDataNode()
d2 := s.Cluster.AddDataNode()
cn := fmt.Sprintf("%s_0", s.prefix)
s.insertAndCheck(cn, &rows, true)
paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "true")
s.restartDC()
s.insertAndCheck(cn, &rows, false)
s.Cluster.AddDataNode()
d1.Stop()
s.checkDNRestarts(&rows, true)
s.Cluster.AddDataNode()
d2.Stop()
s.checkDNRestarts(&rows, true)
}
func TestDataNodeCompatibility(t *testing.T) {
suite.Run(t, new(DataNodeCompatibility))
}
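TestCompatibility walks the upgrade matrix by flipping `EnableBalanceChannelWithRPC` and restarting DataCoord to emulate old and new coordinators. A minimal sketch of that pattern, relying only on the paramtable calls already used in this test:

```go
// Flag-flip pattern exercised above; only calls already used in this test
// are assumed.
key := paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key

paramtable.Get().Save(key, "false") // coordinator behaves like a pre-RPC-watch release
// restart DataCoord so it reloads the flag, then exercise ingest and search

paramtable.Get().Save(key, "true") // back on the RPC-based watch path
```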