From b1eacb2ae88b9f13705962e31a61e998fac8c921 Mon Sep 17 00:00:00 2001 From: yiwangdr <80064917+yiwangdr@users.noreply.github.com> Date: Tue, 7 May 2024 00:49:30 -0700 Subject: [PATCH] feat: datacoord/node watch based on rpc (#32036) issue: https://github.com/milvus-io/milvus/issues/25309 Signed-off-by: yiwangdr --- Makefile | 1 + internal/datacoord/channel.go | 200 ++- internal/datacoord/channel_manager.go | 134 +- internal/datacoord/channel_manager_factory.go | 12 +- internal/datacoord/channel_manager_test.go | 46 +- internal/datacoord/channel_manager_v2.go | 727 +++++++++++ internal/datacoord/channel_manager_v2_test.go | 661 ++++++++++ internal/datacoord/channel_store.go | 217 ++-- internal/datacoord/channel_store_test.go | 4 +- internal/datacoord/channel_store_v2.go | 432 +++++++ internal/datacoord/channel_store_v2_test.go | 483 ++++++++ internal/datacoord/cluster.go | 28 +- internal/datacoord/cluster_test.go | 16 +- internal/datacoord/mock_channel_store.go | 345 ++++-- internal/datacoord/mock_channelmanager.go | 236 ++-- internal/datacoord/mock_cluster.go | 23 +- internal/datacoord/mock_subcluster.go | 137 +++ internal/datacoord/policy.go | 311 +++-- internal/datacoord/policy_test.go | 1095 +++++++++-------- internal/datacoord/server.go | 39 +- internal/datacoord/server_test.go | 4 +- internal/datacoord/services.go | 10 +- internal/datacoord/services_test.go | 2 +- internal/datacoord/session.go | 5 +- internal/datacoord/session_manager_test.go | 2 +- internal/datanode/channel_manager.go | 88 +- internal/datanode/channel_manager_test.go | 54 +- internal/datanode/data_node.go | 48 +- internal/datanode/event_manager.go | 2 +- internal/datanode/event_manager_test.go | 3 + internal/datanode/mock_test.go | 2 +- internal/datanode/services.go | 13 + internal/datanode/services_test.go | 48 +- internal/distributed/datacoord/service.go | 4 +- internal/distributed/datanode/service.go | 2 +- pkg/util/conc/pool.go | 3 +- pkg/util/paramtable/component_param.go | 22 +- 
tests/integration/minicluster_v2.go | 8 +- .../watchcompatibility/watch_test.go | 365 ++++++ 39 files changed, 4577 insertions(+), 1255 deletions(-) create mode 100644 internal/datacoord/channel_manager_v2.go create mode 100644 internal/datacoord/channel_manager_v2_test.go create mode 100644 internal/datacoord/channel_store_v2.go create mode 100644 internal/datacoord/channel_store_v2_test.go create mode 100644 internal/datacoord/mock_subcluster.go create mode 100644 tests/integration/watchcompatibility/watch_test.go diff --git a/Makefile b/Makefile index cec50b01cb..3edb13eb4c 100644 --- a/Makefile +++ b/Makefile @@ -457,6 +457,7 @@ generate-mockery-datacoord: getdeps $(INSTALL_PATH)/mockery --name=CompactionMeta --dir=internal/datacoord --filename=mock_compaction_meta.go --output=internal/datacoord --structname=MockCompactionMeta --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=Scheduler --dir=internal/datacoord --filename=mock_scheduler.go --output=internal/datacoord --structname=MockScheduler --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage + $(INSTALL_PATH)/mockery --name=SubCluster --dir=internal/datacoord --filename=mock_subcluster.go --output=internal/datacoord --structname=MockSubCluster --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage generate-mockery-datanode: getdeps diff --git a/internal/datacoord/channel.go b/internal/datacoord/channel.go index 6eaf9df007..f2a8f44bf8 100644 --- a/internal/datacoord/channel.go +++ b/internal/datacoord/channel.go @@ -19,9 +19,14 @@ package datacoord import ( "fmt" + "github.com/gogo/protobuf/proto" + "go.uber.org/zap" + 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type ROChannel interface { @@ -39,7 +44,30 @@ type RWChannel interface { UpdateWatchInfo(info *datapb.ChannelWatchInfo) } -var _ RWChannel = (*channelMeta)(nil) +func NewRWChannel(name string, + collectionID int64, + startPos []*commonpb.KeyDataPair, + schema *schemapb.CollectionSchema, + createTs uint64, +) RWChannel { + if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() { + return &StateChannel{ + Name: name, + CollectionID: collectionID, + StartPositions: startPos, + Schema: schema, + CreateTimestamp: createTs, + } + } + + return &channelMeta{ + Name: name, + CollectionID: collectionID, + StartPositions: startPos, + Schema: schema, + CreateTimestamp: createTs, + } +} type channelMeta struct { Name string @@ -50,8 +78,13 @@ type channelMeta struct { WatchInfo *datapb.ChannelWatchInfo } +var _ RWChannel = (*channelMeta)(nil) + func (ch *channelMeta) UpdateWatchInfo(info *datapb.ChannelWatchInfo) { - ch.WatchInfo = info + log.Info("Channel updating watch info", + zap.Any("old watch info", ch.WatchInfo), + zap.Any("new watch info", info)) + ch.WatchInfo = proto.Clone(info).(*datapb.ChannelWatchInfo) } func (ch *channelMeta) GetWatchInfo() *datapb.ChannelWatchInfo { @@ -83,3 +116,166 @@ func (ch *channelMeta) String() string { // schema maybe too large to print return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions) } + +type ChannelState string + +const ( + Standby ChannelState = "Standby" + ToWatch ChannelState = "ToWatch" + Watching ChannelState = "Watching" + Watched ChannelState = "Watched" + ToRelease ChannelState = "ToRelease" + Releasing ChannelState = "Releasing" + Legacy ChannelState = "Legacy" +) + +type StateChannel struct { 
+ Name string + CollectionID UniqueID + StartPositions []*commonpb.KeyDataPair + Schema *schemapb.CollectionSchema + CreateTimestamp uint64 + Info *datapb.ChannelWatchInfo + + currentState ChannelState + assignedNode int64 +} + +var _ RWChannel = (*StateChannel)(nil) + +func NewStateChannel(ch RWChannel) *StateChannel { + c := &StateChannel{ + Name: ch.GetName(), + CollectionID: ch.GetCollectionID(), + StartPositions: ch.GetStartPositions(), + Schema: ch.GetSchema(), + CreateTimestamp: ch.GetCreateTimestamp(), + Info: ch.GetWatchInfo(), + + assignedNode: bufferID, + } + + c.setState(Standby) + return c +} + +func NewStateChannelByWatchInfo(nodeID int64, info *datapb.ChannelWatchInfo) *StateChannel { + c := &StateChannel{ + Name: info.GetVchan().GetChannelName(), + CollectionID: info.GetVchan().GetCollectionID(), + Schema: info.GetSchema(), + Info: info, + assignedNode: nodeID, + } + + switch info.GetState() { + case datapb.ChannelWatchState_ToWatch: + c.setState(ToWatch) + case datapb.ChannelWatchState_ToRelease: + c.setState(ToRelease) + // legacy state + case datapb.ChannelWatchState_WatchSuccess: + c.setState(Watched) + case datapb.ChannelWatchState_WatchFailure, datapb.ChannelWatchState_ReleaseSuccess, datapb.ChannelWatchState_ReleaseFailure: + c.setState(Standby) + default: + c.setState(Standby) + } + + if nodeID == bufferID { + c.setState(Standby) + } + return c +} + +func (c *StateChannel) TransitionOnSuccess() { + switch c.currentState { + case Standby: + c.setState(ToWatch) + case ToWatch: + c.setState(Watching) + case Watching: + c.setState(Watched) + case Watched: + c.setState(ToRelease) + case ToRelease: + c.setState(Releasing) + case Releasing: + c.setState(Standby) + } +} + +func (c *StateChannel) TransitionOnFailure() { + switch c.currentState { + case Watching: + c.setState(Standby) + case Releasing: + c.setState(Standby) + case Standby, ToWatch, Watched, ToRelease: + // Stay original state + } +} + +func (c *StateChannel) Clone() *StateChannel { + 
return &StateChannel{ + Name: c.Name, + CollectionID: c.CollectionID, + StartPositions: c.StartPositions, + Schema: c.Schema, + CreateTimestamp: c.CreateTimestamp, + Info: proto.Clone(c.Info).(*datapb.ChannelWatchInfo), + + currentState: c.currentState, + assignedNode: c.assignedNode, + } +} + +func (c *StateChannel) String() string { + // schema maybe too large to print + return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", c.Name, c.CollectionID, c.StartPositions) +} + +func (c *StateChannel) GetName() string { + return c.Name +} + +func (c *StateChannel) GetCollectionID() UniqueID { + return c.CollectionID +} + +func (c *StateChannel) GetStartPositions() []*commonpb.KeyDataPair { + return c.StartPositions +} + +func (c *StateChannel) GetSchema() *schemapb.CollectionSchema { + return c.Schema +} + +func (c *StateChannel) GetCreateTimestamp() Timestamp { + return c.CreateTimestamp +} + +func (c *StateChannel) GetWatchInfo() *datapb.ChannelWatchInfo { + return c.Info +} + +func (c *StateChannel) UpdateWatchInfo(info *datapb.ChannelWatchInfo) { + if c.Info != nil && c.Info.Vchan != nil && info.GetVchan().GetChannelName() != c.Info.GetVchan().GetChannelName() { + log.Warn("Updating incorrect channel watch info", + zap.Any("old watch info", c.Info), + zap.Any("new watch info", info), + zap.Stack("call stack"), + ) + return + } + + c.Info = proto.Clone(info).(*datapb.ChannelWatchInfo) +} + +func (c *StateChannel) Assign(nodeID int64) { + c.assignedNode = nodeID +} + +func (c *StateChannel) setState(state ChannelState) { + c.currentState = state +} diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index 5f4d0bc136..396fcd6eaf 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -36,25 +36,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/logutil" ) -type ChannelManager interface { - Startup(ctx context.Context, nodes []int64) error - Close() - - AddNode(nodeID int64) 
error - DeleteNode(nodeID int64) error - Watch(ctx context.Context, ch RWChannel) error - RemoveChannel(channelName string) error - Release(nodeID UniqueID, channelName string) error - - Match(nodeID int64, channel string) bool - FindWatcher(channel string) (int64, error) - - GetNodeChannelsByCollectionID(collectionID UniqueID) map[UniqueID][]string - GetChannelsByCollectionID(collectionID UniqueID) []RWChannel - GetCollectionIDByChannel(channel string) (bool, UniqueID) - GetNodeIDByChannelName(channel string) (bool, UniqueID) -} - // ChannelManagerImpl manages the allocation and the balance between channels and data nodes. type ChannelManagerImpl struct { ctx context.Context @@ -66,8 +47,8 @@ type ChannelManagerImpl struct { deregisterPolicy DeregisterPolicy assignPolicy ChannelAssignPolicy reassignPolicy ChannelReassignPolicy - bgChecker ChannelBGChecker balancePolicy BalanceChannelPolicy + bgChecker ChannelBGChecker msgstreamFactory msgstream.Factory stateChecker channelStateChecker @@ -105,7 +86,7 @@ func NewChannelManager( c := &ChannelManagerImpl{ ctx: context.TODO(), h: h, - factory: NewChannelPolicyFactoryV1(kv), + factory: NewChannelPolicyFactoryV1(), store: NewChannelStore(kv), stateTimer: newChannelStateTimer(kv), } @@ -128,7 +109,7 @@ func NewChannelManager( } // Startup adjusts the channel store according to current cluster states. -func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error { +func (c *ChannelManagerImpl) Startup(ctx context.Context, legacyNodes, allNodes []int64) error { c.ctx = ctx channels := c.store.GetNodesChannels() // Retrieve the current old nodes. @@ -138,13 +119,13 @@ func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error { } // Process watch states for old nodes. - oldOnLines := c.getOldOnlines(nodes, oNodes) + oldOnLines := c.getOldOnlines(allNodes, oNodes) if err := c.checkOldNodes(oldOnLines); err != nil { return err } // Add new online nodes to the cluster. 
- newOnLines := c.getNewOnLines(nodes, oNodes) + newOnLines := c.getNewOnLines(allNodes, oNodes) for _, n := range newOnLines { if err := c.AddNode(n); err != nil { return err @@ -152,7 +133,7 @@ func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error { } // Remove new offline nodes from the cluster. - offLines := c.getOffLines(nodes, oNodes) + offLines := c.getOffLines(allNodes, oNodes) for _, n := range offLines { if err := c.DeleteNode(n); err != nil { return err @@ -176,7 +157,7 @@ func (c *ChannelManagerImpl) Startup(ctx context.Context, nodes []int64) error { } log.Info("cluster start up", - zap.Int64s("nodes", nodes), + zap.Int64s("nodes", allNodes), zap.Int64s("oNodes", oNodes), zap.Int64s("old onlines", oldOnLines), zap.Int64s("new onlines", newOnLines), @@ -247,7 +228,7 @@ func (c *ChannelManagerImpl) checkOldNodes(nodes []UniqueID) error { // unwatchDroppedChannels removes drops channel that are marked to drop. func (c *ChannelManagerImpl) unwatchDroppedChannels() { - nodeChannels := c.store.GetChannels() + nodeChannels := c.store.GetNodesChannels() for _, nodeChannel := range nodeChannels { for _, ch := range nodeChannel.Channels { if !c.isMarkedDrop(ch.GetName()) { @@ -284,9 +265,14 @@ func (c *ChannelManagerImpl) bgCheckChannelsWork(ctx context.Context) { if !c.isSilent() { log.Info("ChannelManager is not silent, skip channel balance this round") } else { - toReleases := c.balancePolicy(c.store, time.Now()) - log.Info("channel manager bg check balance", zap.Array("toReleases", toReleases)) - if err := c.updateWithTimer(toReleases, datapb.ChannelWatchState_ToRelease); err != nil { + currCluster := c.store.GetNodesChannels() + updates := c.balancePolicy(currCluster) + if updates == nil { + continue + } + + log.Info("channel manager bg check balance", zap.Array("toReleases", updates)) + if err := c.updateWithTimer(updates, datapb.ChannelWatchState_ToRelease); err != nil { log.Warn("channel store update error", zap.Error(err)) } } 
@@ -345,7 +331,7 @@ func (c *ChannelManagerImpl) AddNode(nodeID int64) error { c.mu.Lock() defer c.mu.Unlock() - c.store.Add(nodeID) + c.store.AddNode(nodeID) bufferedUpdates, balanceUpdates := c.registerPolicy(c.store, nodeID) @@ -386,6 +372,7 @@ func (c *ChannelManagerImpl) DeleteNode(nodeID int64) error { nodeChannelInfo := c.store.GetNode(nodeID) if nodeChannelInfo == nil { + c.store.RemoveNode(nodeID) return nil } @@ -393,6 +380,7 @@ func (c *ChannelManagerImpl) DeleteNode(nodeID int64) error { updates := c.deregisterPolicy(c.store, nodeID) if updates == nil { + c.store.RemoveNode(nodeID) return nil } log.Info("deregister node", zap.Int64("nodeID", nodeID), zap.Array("updates", updates)) @@ -417,8 +405,8 @@ func (c *ChannelManagerImpl) DeleteNode(nodeID int64) error { } // No channels will be return - _, err := c.store.Delete(nodeID) - return err + c.store.RemoveNode(nodeID) + return nil } // unsubAttempt attempts to unsubscribe node-channel info from the channel. @@ -558,25 +546,21 @@ func (c *ChannelManagerImpl) Match(nodeID int64, channel string) bool { } // FindWatcher finds the datanode watching the provided channel. 
-func (c *ChannelManagerImpl) FindWatcher(channel string) (int64, error) { +func (c *ChannelManagerImpl) FindWatcher(channelName string) (int64, error) { c.mu.RLock() defer c.mu.RUnlock() infos := c.store.GetNodesChannels() for _, info := range infos { - for _, channelInfo := range info.Channels { - if channelInfo.GetName() == channel { - return info.NodeID, nil - } + if _, ok := info.Channels[channelName]; ok { + return info.NodeID, nil } } // channel in buffer bufferInfo := c.store.GetBufferChannelInfo() - for _, channelInfo := range bufferInfo.Channels { - if channelInfo.GetName() == channel { - return bufferID, errChannelInBuffer - } + if _, ok := bufferInfo.Channels[channelName]; ok { + return bufferID, errChannelInBuffer } return 0, errChannelNotWatched } @@ -610,10 +594,8 @@ func (c *ChannelManagerImpl) remove(nodeID int64, ch RWChannel) error { func (c *ChannelManagerImpl) findChannel(channelName string) (int64, RWChannel) { infos := c.store.GetNodesChannels() for _, info := range infos { - for _, channelInfo := range info.Channels { - if channelInfo.GetName() == channelName { - return info.NodeID, channelInfo - } + if channelInfo, ok := info.Channels[channelName]; ok { + return info.NodeID, channelInfo } } return 0, nil @@ -640,7 +622,7 @@ type ackEvent struct { func (c *ChannelManagerImpl) updateWithTimer(updates *ChannelOpSet, state datapb.ChannelWatchState) error { channelsWithTimer := []string{} for _, op := range updates.Collect() { - if op.Type == Add { + if op.Type != Delete { channelsWithTimer = append(channelsWithTimer, c.fillChannelWatchInfoWithState(op, state)...) 
} } @@ -807,14 +789,9 @@ func (c *ChannelManagerImpl) Reassign(originNodeID UniqueID, channelName string) reallocates := NewNodeChannelInfo(originNodeID, ch) isDropped := c.isMarkedDrop(channelName) - c.mu.Lock() - defer c.mu.Unlock() - ch = c.getChannelByNodeAndName(originNodeID, channelName) - if ch == nil { - return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", originNodeID, channelName) - } - if isDropped { + c.mu.Lock() + defer c.mu.Unlock() if err := c.remove(originNodeID, ch); err != nil { return fmt.Errorf("failed to remove watch info: %v,%s", ch, err.Error()) } @@ -825,6 +802,8 @@ func (c *ChannelManagerImpl) Reassign(originNodeID UniqueID, channelName string) return nil } + c.mu.Lock() + defer c.mu.Unlock() // Reassign policy won't choose the original node when a reassigning a channel. updates := c.reassignPolicy(c.store, []*NodeChannelInfo{reallocates}) if updates == nil { @@ -864,11 +843,6 @@ func (c *ChannelManagerImpl) CleanupAndReassign(nodeID UniqueID, channelName str c.mu.Lock() defer c.mu.Unlock() - chToCleanUp = c.getChannelByNodeAndName(nodeID, channelName) - if chToCleanUp == nil { - return fmt.Errorf("failed to find matching channel: %s and node: %d", channelName, nodeID) - } - if isDropped { if err := c.remove(nodeID, chToCleanUp); err != nil { return fmt.Errorf("failed to remove watch info: %v,%s", chToCleanUp, err.Error()) @@ -900,42 +874,38 @@ func (c *ChannelManagerImpl) CleanupAndReassign(nodeID UniqueID, channelName str } func (c *ChannelManagerImpl) getChannelByNodeAndName(nodeID UniqueID, channelName string) RWChannel { - var ret RWChannel - - nodeChannelInfo := c.store.GetNode(nodeID) - if nodeChannelInfo == nil { - return nil - } - - for _, channel := range nodeChannelInfo.Channels { - if channel.GetName() == channelName { - ret = channel - break + if nodeChannelInfo := c.store.GetNode(nodeID); nodeChannelInfo != nil { + if ch, ok := nodeChannelInfo.Channels[channelName]; ok { + return ch } } - return ret + 
return nil } -func (c *ChannelManagerImpl) GetCollectionIDByChannel(channel string) (bool, UniqueID) { +func (c *ChannelManagerImpl) GetCollectionIDByChannel(channelName string) (bool, UniqueID) { for _, nodeChannel := range c.GetAssignedChannels() { - for _, ch := range nodeChannel.Channels { - if ch.GetName() == channel { - return true, ch.GetCollectionID() - } + if ch, ok := nodeChannel.Channels[channelName]; ok { + return true, ch.GetCollectionID() } } return false, 0 } -func (c *ChannelManagerImpl) GetNodeIDByChannelName(channel string) (bool, UniqueID) { +func (c *ChannelManagerImpl) GetNodeIDByChannelName(channelName string) (UniqueID, bool) { for _, nodeChannel := range c.GetAssignedChannels() { - for _, ch := range nodeChannel.Channels { - if ch.GetName() == channel { - return true, nodeChannel.NodeID - } + if _, ok := nodeChannel.Channels[channelName]; ok { + return nodeChannel.NodeID, true } } - return false, 0 + return 0, false +} + +func (c *ChannelManagerImpl) GetChannel(nodeID int64, channelName string) (RWChannel, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + ch := c.getChannelByNodeAndName(nodeID, channelName) + return ch, ch != nil } func (c *ChannelManagerImpl) isMarkedDrop(channel string) bool { diff --git a/internal/datacoord/channel_manager_factory.go b/internal/datacoord/channel_manager_factory.go index 69d999c86f..cadaafc46e 100644 --- a/internal/datacoord/channel_manager_factory.go +++ b/internal/datacoord/channel_manager_factory.go @@ -16,10 +16,6 @@ package datacoord -import ( - "github.com/milvus-io/milvus/internal/kv" -) - // ChannelPolicyFactory is the abstract factory that creates policies for channel manager. type ChannelPolicyFactory interface { // NewRegisterPolicy creates a new register policy. 
@@ -35,13 +31,11 @@ type ChannelPolicyFactory interface { } // ChannelPolicyFactoryV1 equal to policy batch -type ChannelPolicyFactoryV1 struct { - kv kv.TxnKV -} +type ChannelPolicyFactoryV1 struct{} // NewChannelPolicyFactoryV1 helper function creates a Channel policy factory v1 from kv. -func NewChannelPolicyFactoryV1(kv kv.TxnKV) *ChannelPolicyFactoryV1 { - return &ChannelPolicyFactoryV1{kv: kv} +func NewChannelPolicyFactoryV1() *ChannelPolicyFactoryV1 { + return &ChannelPolicyFactoryV1{} } // NewRegisterPolicy implementing ChannelPolicyFactory returns BufferChannelAssignPolicy. diff --git a/internal/datacoord/channel_manager_test.go b/internal/datacoord/channel_manager_test.go index 7e890b55d3..d255e64ac9 100644 --- a/internal/datacoord/channel_manager_test.go +++ b/internal/datacoord/channel_manager_test.go @@ -492,7 +492,7 @@ func TestChannelManager(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, bufferID, bufferCh, collectionID) - chManager.store.Add(nodeID) + chManager.store.AddNode(nodeID) err = chManager.Watch(context.TODO(), &channelMeta{Name: chanToAdd, CollectionID: collectionID}) assert.NoError(t, err) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, nodeID, chanToAdd, collectionID) @@ -544,7 +544,7 @@ func TestChannelManager(t *testing.T) { // prepare tests for _, test := range tests { - chManager.store.Add(test.nodeID) + chManager.store.AddNode(test.nodeID) ops := getTestOps(test.nodeID, &channelMeta{Name: test.chName, CollectionID: collectionID, WatchInfo: &datapb.ChannelWatchInfo{}}) err = chManager.store.Update(ops) require.NoError(t, err) @@ -557,7 +557,7 @@ func TestChannelManager(t *testing.T) { remainTest, reassignTest := tests[0], tests[1] err = chManager.Reassign(reassignTest.nodeID, reassignTest.chName) assert.NoError(t, err) - chManager.stateTimer.stopIfExist(&ackEvent{releaseSuccessAck, reassignTest.chName, reassignTest.nodeID}) + chManager.stateTimer.stopIfExist(&ackEvent{watchSuccessAck, 
reassignTest.chName, reassignTest.nodeID}) // test nodes of reassignTest contains no channel // test all channels are assgined to node of remainTest @@ -587,6 +587,7 @@ func TestChannelManager(t *testing.T) { t.Run("test Reassign with dropped channel", func(t *testing.T) { collectionID := UniqueID(5) + watchkv.RemoveWithPrefix("") handler := NewNMockHandler(t) handler.EXPECT(). CheckShouldDropChannel(mock.Anything). @@ -595,7 +596,7 @@ func TestChannelManager(t *testing.T) { chManager, err := NewChannelManager(watchkv, handler) require.NoError(t, err) - chManager.store.Add(1) + chManager.store.AddNode(1) ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: collectionID, WatchInfo: &datapb.ChannelWatchInfo{}}) err = chManager.store.Update(ops) require.NoError(t, err) @@ -610,24 +611,16 @@ func TestChannelManager(t *testing.T) { var chManager *ChannelManagerImpl var err error handler := NewNMockHandler(t) - handler.EXPECT(). - CheckShouldDropChannel(mock.Anything). - Run(func(channel string) { - channels, err := chManager.store.Delete(1) - assert.NoError(t, err) - assert.Equal(t, 1, len(channels)) - }).Return(true).Once() - chManager, err = NewChannelManager(watchkv, handler) require.NoError(t, err) - chManager.store.Add(1) + chManager.store.AddNode(1) ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: 1, WatchInfo: &datapb.ChannelWatchInfo{}}) err = chManager.store.Update(ops) require.NoError(t, err) assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1)) - err = chManager.Reassign(1, "chan") + err = chManager.Reassign(2, "chan") assert.Error(t, err) }) @@ -635,24 +628,18 @@ func TestChannelManager(t *testing.T) { var chManager *ChannelManagerImpl var err error handler := NewNMockHandler(t) - handler.EXPECT(). - CheckShouldDropChannel(mock.Anything). 
- Run(func(channel string) { - channels, err := chManager.store.Delete(1) - assert.NoError(t, err) - assert.Equal(t, 1, len(channels)) - }).Return(true).Once() + watchkv.RemoveWithPrefix("") chManager, err = NewChannelManager(watchkv, handler) require.NoError(t, err) - chManager.store.Add(1) + chManager.store.AddNode(1) ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: 1, WatchInfo: &datapb.ChannelWatchInfo{}}) err = chManager.store.Update(ops) require.NoError(t, err) assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1)) - err = chManager.CleanupAndReassign(1, "chan") + err = chManager.CleanupAndReassign(2, "chan") assert.Error(t, err) }) @@ -670,10 +657,11 @@ func TestChannelManager(t *testing.T) { CheckShouldDropChannel(mock.Anything). Return(true) handler.EXPECT().FinishDropChannel(mock.Anything, mock.Anything).Return(nil) + watchkv.RemoveWithPrefix("") chManager, err := NewChannelManager(watchkv, handler) require.NoError(t, err) - chManager.store.Add(1) + chManager.store.AddNode(1) ops := getTestOps(1, &channelMeta{Name: "chan", CollectionID: 1, WatchInfo: &datapb.ChannelWatchInfo{}}) err = chManager.store.Update(ops) require.NoError(t, err) @@ -728,7 +716,7 @@ func TestChannelManager(t *testing.T) { // prepare tests for _, test := range tests { - chManager.store.Add(test.nodeID) + chManager.store.AddNode(test.nodeID) ops := getTestOps(test.nodeID, &channelMeta{Name: test.chName, CollectionID: collectionID, WatchInfo: &datapb.ChannelWatchInfo{}}) err = chManager.store.Update(ops) require.NoError(t, err) @@ -776,7 +764,7 @@ func TestChannelManager(t *testing.T) { ch := chManager.getChannelByNodeAndName(nodeID, channelName) assert.Nil(t, ch) - chManager.store.Add(nodeID) + chManager.store.AddNode(nodeID) ch = chManager.getChannelByNodeAndName(nodeID, channelName) assert.Nil(t, ch) @@ -837,7 +825,7 @@ func TestChannelManager(t *testing.T) { chManager, err := NewChannelManager(watchkv, newMockHandler()) require.NoError(t, err) - 
chManager.store.Add(nodeID) + chManager.store.AddNode(nodeID) opSet := NewChannelOpSet(NewAddOp(nodeID, &channelMeta{Name: channelName, CollectionID: collectionID})) @@ -864,7 +852,7 @@ func TestChannelManager(t *testing.T) { chManager, err := NewChannelManager(watchkv, newMockHandler(), withBgChecker()) require.NoError(t, err) assert.NotNil(t, chManager.bgChecker) - chManager.Startup(ctx, []int64{nodeID}) + chManager.Startup(ctx, nil, []int64{nodeID}) // 2. test isSilent function running correctly Params.Save(Params.DataCoordCfg.ChannelBalanceSilentDuration.Key, "3") @@ -1049,7 +1037,7 @@ func TestChannelManager_Reload(t *testing.T) { cm2, err := NewChannelManager(watchkv, newMockHandler()) assert.NoError(t, err) - assert.Nil(t, cm2.Startup(ctx, []int64{3})) + assert.Nil(t, cm2.Startup(ctx, nil, []int64{3})) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, 3, "channel1", 1) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, 3, "channel2", 1) diff --git a/internal/datacoord/channel_manager_v2.go b/internal/datacoord/channel_manager_v2.go new file mode 100644 index 0000000000..4b69161741 --- /dev/null +++ b/internal/datacoord/channel_manager_v2.go @@ -0,0 +1,727 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package datacoord
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/internal/kv"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/conc"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+// ChannelManager tracks which node holds which channel. It covers node
+// membership changes (AddNode, DeleteNode), channel registration and release
+// (Watch, Release) and read-only lookups over the current assignments.
+type ChannelManager interface {
+	Startup(ctx context.Context, legacyNodes, allNodes []int64) error
+	Close()
+
+	AddNode(nodeID UniqueID) error
+	DeleteNode(nodeID UniqueID) error
+	Watch(ctx context.Context, ch RWChannel) error
+	Release(nodeID UniqueID, channelName string) error
+
+	Match(nodeID UniqueID, channel string) bool
+	FindWatcher(channel string) (UniqueID, error)
+
+	GetChannel(nodeID int64, channel string) (RWChannel, bool)
+	GetNodeIDByChannelName(channel string) (int64, bool)
+	GetNodeChannelsByCollectionID(collectionID int64) map[int64][]string
+	GetChannelsByCollectionID(collectionID int64) []RWChannel
+	GetChannelNamesByCollectionID(collectionID int64) []string
+}
+
+// SubCluster is the interface that sessionManager implements.
+// ChannelManagerImplV2 uses it to send channel operations to a node and to
+// poll the progress of an operation that was sent earlier.
+type SubCluster interface {
+	NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error
+	CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)
+}
+
+// ChannelManagerImplV2 keeps channel assignments in an RWChannelStore and
+// drives channel operations on nodes through SubCluster calls.
+type ChannelManagerImplV2 struct {
+	// ctx and cancel are replaced by Startup; Close calls cancel.
+	ctx    context.Context
+	cancel context.CancelFunc
+	// mu is held by the methods of this type around accesses to store and legacyNodes.
+	mu sync.RWMutex
+
+	h          Handler
+	store      RWChannelStore
+	subCluster SubCluster // sessionManager
+	allocator  allocator
+
+	factory       ChannelPolicyFactory
+	balancePolicy BalanceChannelPolicy
+
+	// balanceCheckLoop is started by Startup when it is non-nil (see withCheckerV2).
+	balanceCheckLoop ChannelBGChecker
+
+	legacyNodes typeutil.UniqueSet
+
+	// lastActiveTimestamp is refreshed by AdvanceChannelState whenever a channel
+	// state advanced; CheckLoop compares it against ChannelBalanceSilentDuration
+	// before balancing.
+	lastActiveTimestamp time.Time
+}
+
+// ChannelBGChecker is a background routine. Startup runs it in its own
+// goroutine and hands it the manager context.
+type ChannelBGChecker func(ctx context.Context)
+
+// ChannelmanagerOptV2 is to set optional parameters in channel manager.
+type ChannelmanagerOptV2 func(c *ChannelManagerImplV2)
+
+// withFactoryV2 overrides the policy factory used to build the balance policy.
+func withFactoryV2(f ChannelPolicyFactory) ChannelmanagerOptV2 {
+	return func(c *ChannelManagerImplV2) { c.factory = f }
+}
+
+// withCheckerV2 makes Startup launch the manager's own CheckLoop in the background.
+func withCheckerV2() ChannelmanagerOptV2 {
+	return func(c *ChannelManagerImplV2) { c.balanceCheckLoop = c.CheckLoop }
+}
+
+// NewChannelManagerV2 builds a ChannelManagerImplV2 on top of a store created
+// from kv, reloads the store, applies the options and then derives the balance
+// policy from the (possibly overridden) factory.
+//
+// It returns the error of store.Reload unchanged when the reload fails.
+func NewChannelManagerV2(
+	kv kv.TxnKV,
+	h Handler,
+	subCluster SubCluster, // sessionManager
+	alloc allocator,
+	options ...ChannelmanagerOptV2,
+) (*ChannelManagerImplV2, error) {
+	m := &ChannelManagerImplV2{
+		h:          h,
+		ctx:        context.TODO(), // placeholder, replaced by Startup
+		factory:    NewChannelPolicyFactoryV1(),
+		store:      NewChannelStoreV2(kv),
+		subCluster: subCluster,
+		allocator:  alloc,
+	}
+
+	if err := m.store.Reload(); err != nil {
+		return nil, err
+	}
+
+	for _, opt := range options {
+		opt(m)
+	}
+
+	// Built after the options so that withFactoryV2 takes effect.
+	m.balancePolicy = m.factory.NewBalancePolicy()
+	m.lastActiveTimestamp = time.Now()
+	return m, nil
+}
+
+// Startup reconciles the stored node list with allNodes and starts the
+// background checker when one is configured.
+//
+// Nodes present in allNodes but not in the store are added, nodes present in
+// the store but not in allNodes are deleted, and every stored channel for
+// which the handler reports CheckShouldDropChannel is removed. Startup always
+// returns nil.
+func (m *ChannelManagerImplV2) Startup(ctx context.Context, legacyNodes, allNodes []int64) error {
+	m.ctx, m.cancel = context.WithCancel(ctx)
+
+	m.mu.Lock()
+	// legacyNodes is read by the other methods while they hold mu, so it is
+	// written under the same lock.
+	m.legacyNodes = typeutil.NewUniqueSet(legacyNodes...)
+	m.store.SetLegacyChannelByNode(legacyNodes...)
+	oNodes := m.store.GetNodes()
+	m.mu.Unlock()
+
+	// lo.Difference returns (only in oNodes, only in allNodes).
+	offLines, newOnLines := lo.Difference(oNodes, allNodes)
+
+	// Add new online nodes to the cluster.
+	for _, nodeID := range newOnLines {
+		// AddNode logs its own failure; startup continues with the other nodes.
+		_ = m.AddNode(nodeID)
+	}
+
+	// Delete offline nodes from the cluster.
+	for _, nodeID := range offLines {
+		// DeleteNode logs its own failure; startup continues with the other nodes.
+		_ = m.DeleteNode(nodeID)
+	}
+
+	m.mu.Lock()
+	nodeChannels := m.store.GetNodeChannelsBy(
+		WithAllNodes(),
+		func(ch *StateChannel) bool {
+			return m.h.CheckShouldDropChannel(ch.GetName())
+		})
+	m.mu.Unlock()
+
+	// finishRemoveChannel takes mu itself, hence the unlock above.
+	for _, info := range nodeChannels {
+		m.finishRemoveChannel(info.NodeID, lo.Values(info.Channels)...)
+	}
+
+	if m.balanceCheckLoop != nil {
+		log.Info("starting channel balance loop")
+		go m.balanceCheckLoop(m.ctx)
+	}
+
+	log.Info("cluster start up",
+		zap.Int64s("allNodes", allNodes),
+		zap.Int64s("legacyNodes", legacyNodes),
+		zap.Int64s("oldNodes", oNodes),
+		zap.Int64s("newOnlines", newOnLines),
+		zap.Int64s("offLines", offLines))
+	return nil
+}
+
+// Close cancels the context created by Startup. It is a no-op when Startup
+// has not been called.
+func (m *ChannelManagerImplV2) Close() {
+	if m.cancel != nil {
+		m.cancel()
+	}
+}
+
+// AddNode registers nodeID in the store and passes the channels currently held
+// in the buffer to AvgAssignByCountPolicy. The resulting operations, if any,
+// are executed; the error of that execution is logged and returned.
+func (m *ChannelManagerImplV2) AddNode(nodeID UniqueID) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	log.Info("register node", zap.Int64("registered node", nodeID))
+
+	m.store.AddNode(nodeID)
+	updates := AvgAssignByCountPolicy(m.store.GetNodesChannels(), m.store.GetBufferChannelInfo().GetChannels(), m.legacyNodes.Collect())
+
+	if updates == nil {
+		log.Info("register node with no reassignment", zap.Int64("registered node", nodeID))
+		return nil
+	}
+
+	err := m.execute(updates)
+	if err != nil {
+		log.Warn("fail to update channel operation updates into meta", zap.Error(err))
+	}
+	return err
+}
+
+// Release writes ToRelease channel watch states for a channel
+//
+// Releasing from bufferID is a no-op that returns nil. An error is returned
+// when channelName is not held by nodeID, or when the update cannot be
+// executed.
+func (m *ChannelManagerImplV2) Release(nodeID UniqueID, channelName string) error {
+	log := log.With(
+		zap.Int64("nodeID", nodeID),
+		zap.String("channel", channelName),
+	)
+
+	// channel in bufferID are released already
+	if nodeID == bufferID {
+		return nil
+	}
+
+	// Look the channel up and write the update under one write lock, so the
+	// assignment cannot change between the two steps.
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	log.Info("Releasing channel from watched node")
+	var (
+		ch    RWChannel
+		found bool
+	)
+	if info := m.store.GetNode(nodeID); info != nil {
+		ch, found = info.Channels[channelName]
+	}
+	if !found {
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
+	}
+
+	updates := NewChannelOpSet(NewChannelOp(nodeID, Release, ch))
+	return m.execute(updates)
+}
+
+// Watch records ch under bufferID and then tries to assign it to a node.
+//
+// The error of the first step is returned: without it the channel is not
+// stored anywhere. A failure of the assignment step is only logged and Watch
+// returns nil, because the channel is already recorded under bufferID.
+func (m *ChannelManagerImplV2) Watch(ctx context.Context, ch RWChannel) error {
+	log := log.Ctx(ctx).With(zap.String("channel", ch.GetName()))
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	log.Info("Add channel")
+	updates := NewChannelOpSet(NewChannelOp(bufferID, Watch, ch))
+	if err := m.execute(updates); err != nil {
+		log.Warn("fail to update new channel updates into meta",
+			zap.Array("updates", updates), zap.Error(err))
+		return err
+	}
+
+	// channel already written into meta, try to assign it to the cluster
+	// no error is returned if this fails, the assignment will be retried later
+	updates = AvgAssignByCountPolicy(m.store.GetNodesChannels(), []RWChannel{ch}, m.legacyNodes.Collect())
+	if updates == nil {
+		return nil
+	}
+
+	if err := m.execute(updates); err != nil {
+		log.Warn("fail to assign channel, will retry later", zap.Array("updates", updates), zap.Error(err))
+		return nil
+	}
+
+	log.Info("Assign channel", zap.Array("updates", updates))
+	return nil
+}
+
+// DeleteNode drops nodeID from the legacy set and from the store.
+//
+// When the node still holds channels, a delete operation for the node and a
+// Watch operation on bufferID for the same channels are executed first; if
+// that fails the error is returned and the node stays in the store. The
+// bufferID entry itself is never removed from the store.
+func (m *ChannelManagerImplV2) DeleteNode(nodeID UniqueID) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	m.legacyNodes.Remove(nodeID)
+	info := m.store.GetNode(nodeID)
+	if info == nil || len(info.Channels) == 0 {
+		if nodeID != bufferID {
+			m.store.RemoveNode(nodeID)
+		}
+		return nil
+	}
+
+	updates := NewChannelOpSet(
+		NewDeleteOp(info.NodeID, lo.Values(info.Channels)...),
+		NewChannelOp(bufferID, Watch, lo.Values(info.Channels)...),
+	)
+	log.Info("deregister node", zap.Int64("nodeID", nodeID), zap.Array("updates", updates))
+
+	err := m.execute(updates)
+	if err != nil {
+		log.Warn("fail to update channel operation updates into meta", zap.Error(err))
+		return err
+	}
+
+	if nodeID != bufferID {
+		m.store.RemoveNode(nodeID)
+	}
+	return nil
+}
+
+// reassign reassigns a channel to another DataNode.
+// It passes the channels of original to AvgAssignByCountPolicy and executes
+// the resulting operations. When the policy yields nothing and original is
+// not the buffer, the channels are written back to the original node with a
+// Watch operation; for the buffer it returns nil without any update.
+func (m *ChannelManagerImplV2) reassign(original *NodeChannelInfo) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	updates := AvgAssignByCountPolicy(m.store.GetNodesChannels(), original.GetChannels(), m.legacyNodes.Collect())
+	if updates != nil {
+		return m.execute(updates)
+	}
+
+	// Channels that already sit in the buffer have no node to fall back to.
+	if original.NodeID == bufferID {
+		return nil
+	}
+
+	log.RatedWarn(5.0, "Failed to reassign channel to other nodes, assign to the original nodes",
+		zap.Any("original node", original.NodeID),
+		zap.Strings("channels", lo.Keys(original.Channels)),
+	)
+	fallback := NewChannelOpSet(NewChannelOp(original.NodeID, Watch, lo.Values(original.Channels)...))
+	return m.execute(fallback)
+}
+
+// Balance feeds the Watched channels of all non-buffer nodes to the balance
+// policy and executes the operations it returns. Failures are logged only.
+func (m *ChannelManagerImplV2) Balance() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	watchedCluster := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(Watched))
+	updates := m.balancePolicy(watchedCluster)
+	if updates == nil {
+		return
+	}
+
+	log.Info("Channel balancer got new reAllocations:", zap.Array("assignment", updates))
+	if err := m.execute(updates); err != nil {
+		log.Warn("Channel balancer fail to execute", zap.Array("assignment", updates), zap.Error(err))
+	}
+}
+
+// Match reports whether the store lists channel under nodeID.
+func (m *ChannelManagerImplV2) Match(nodeID UniqueID, channel string) bool {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	info := m.store.GetNode(nodeID)
+	if info == nil {
+		return false
+	}
+
+	_, ok := info.Channels[channel]
+	return ok
+}
+
+// GetChannel returns the channel named channelName that the store lists under
+// nodeID, and false when there is no such node or channel.
+func (m *ChannelManagerImplV2) GetChannel(nodeID int64, channelName string) (RWChannel, bool) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	if nodeChannelInfo := m.store.GetNode(nodeID); nodeChannelInfo != nil {
+		if ch, ok := nodeChannelInfo.Channels[channelName]; ok {
+			return ch, true
+		}
+	}
+	return nil, false
+}
+
+// GetNodeIDByChannelName returns the ID of the first non-buffer node the store
+// reports for channel, and false when no such node exists.
+func (m *ChannelManagerImplV2) GetNodeIDByChannelName(channel string) (int64, bool) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	nodeChannels := m.store.GetNodeChannelsBy(
+		WithoutBufferNode(),
+		WithChannelName(channel))
+
+	if len(nodeChannels) > 0 {
+		return nodeChannels[0].NodeID, true
+	}
+
+	return 0, false
+}
+
+// GetNodeChannelsByCollectionID returns, per non-buffer node, the names of the
+// channels the store selects for collectionID.
+func (m *ChannelManagerImplV2) GetNodeChannelsByCollectionID(collectionID int64) map[int64][]string {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	nodeChs := make(map[UniqueID][]string)
+	nodeChannels := m.store.GetNodeChannelsBy(
+		WithoutBufferNode(),
+		WithCollectionIDV2(collectionID))
+	for _, info := range nodeChannels {
+		nodeChs[info.NodeID] = lo.Keys(info.Channels)
+	}
+	return nodeChs
+}
+
+// GetChannelsByCollectionID returns the channels the store selects for
+// collectionID across all nodes, the buffer included. The result is never nil.
+func (m *ChannelManagerImplV2) GetChannelsByCollectionID(collectionID int64) []RWChannel {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	channels := []RWChannel{}
+
+	nodeChannels := m.store.GetNodeChannelsBy(
+		WithAllNodes(),
+		WithCollectionIDV2(collectionID))
+	for _, info := range nodeChannels {
+		channels = append(channels, lo.Values(info.Channels)...)
+	}
+	return channels
+}
+
+// GetChannelNamesByCollectionID returns the names of the channels reported by
+// GetChannelsByCollectionID.
+func (m *ChannelManagerImplV2) GetChannelNamesByCollectionID(collectionID int64) []string {
+	channels := m.GetChannelsByCollectionID(collectionID)
+	return lo.Map(channels, func(ch RWChannel, _ int) string {
+		return ch.GetName()
+	})
+}
+
+// FindWatcher returns the node that holds channel.
+//
+// It returns (bufferID, errChannelInBuffer) when the channel is only found in
+// the buffer, and (0, errChannelNotWatched) when it is not found at all.
+func (m *ChannelManagerImplV2) FindWatcher(channel string) (UniqueID, error) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	infos := m.store.GetNodesChannels()
+	for _, info := range infos {
+		for _, channelInfo := range info.Channels {
+			if channelInfo.GetName() == channel {
+				return info.NodeID, nil
+			}
+		}
+	}
+
+	// channel in buffer
+	bufferInfo := m.store.GetBufferChannelInfo()
+	for _, channelInfo := range bufferInfo.Channels {
+		if channelInfo.GetName() == channel {
+			return bufferID, errChannelInBuffer
+		}
+	}
+	return 0, errChannelNotWatched
+}
+
+// removeChannel writes a Delete operation for ch on nodeID to the store.
+// It does no locking of its own; finishRemoveChannel calls it with m.mu held.
+func (m *ChannelManagerImplV2) removeChannel(nodeID int64, ch RWChannel) error {
+	op := NewChannelOpSet(NewChannelOp(nodeID, Delete, ch))
+	log.Info("remove channel assignment",
+		zap.String("channel", ch.GetName()),
+		zap.Int64("assignment", nodeID),
+		zap.Int64("collectionID", ch.GetCollectionID()))
+	return m.store.Update(op)
+}
+
+// CheckLoop runs until ctx is done. On every ChannelCheckInterval tick it
+// calls AdvanceChannelState; on every ChannelBalanceInterval tick it calls
+// Balance, provided no channel state advanced during the last
+// ChannelBalanceSilentDuration.
+func (m *ChannelManagerImplV2) CheckLoop(ctx context.Context) {
+	balanceTicker := time.NewTicker(Params.DataCoordCfg.ChannelBalanceInterval.GetAsDuration(time.Second))
+	defer balanceTicker.Stop()
+	checkTicker := time.NewTicker(Params.DataCoordCfg.ChannelCheckInterval.GetAsDuration(time.Second))
+	defer checkTicker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			log.Info("background checking channels loop quit")
+			return
+		case <-balanceTicker.C:
+			// balance
+			if time.Since(m.lastActiveTimestamp) >= Params.DataCoordCfg.ChannelBalanceSilentDuration.GetAsDuration(time.Second) {
+				m.Balance()
+			}
+		case <-checkTicker.C:
+			m.AdvanceChannelState()
+		}
+	}
+}
+
+// AdvanceChannelState takes a snapshot of the channels in Standby,
+// ToWatch/ToRelease and Watching/Releasing state and processes each group.
+// lastActiveTimestamp is refreshed when any group made progress.
+func (m *ChannelManagerImplV2) AdvanceChannelState() {
+	m.mu.RLock()
+	standbys := m.store.GetNodeChannelsBy(WithAllNodes(), WithChannelStates(Standby))
+	toNotifies := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(ToWatch, ToRelease))
+	toChecks := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(Watching, Releasing))
+	m.mu.RUnlock()
+
+	// The helpers below take m.mu themselves, so the read lock is released first.
+	updatedStandbys := m.advanceStandbys(standbys)
+	updatedToCheckes := m.advanceToChecks(toChecks)
+	updatedToNotifies := m.advanceToNotifies(toNotifies)
+
+	if updatedStandbys || updatedToCheckes || updatedToNotifies {
+		m.lastActiveTimestamp = time.Now()
+	}
+}
+
+// finishRemoveChannel removes every given channel from nodeID in the store and
+// then reports the drop to the handler. A failure on one channel is logged and
+// the remaining channels are still processed; the handler is not called for a
+// channel whose removal failed.
+func (m *ChannelManagerImplV2) finishRemoveChannel(nodeID int64, channels ...RWChannel) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	for _, ch := range channels {
+		if err := m.removeChannel(nodeID, ch); err != nil {
+			log.Warn("Failed to remove channel", zap.Any("channel", ch), zap.Error(err))
+			continue
+		}
+
+		if err := m.h.FinishDropChannel(ch.GetName(), ch.GetCollectionID()); err != nil {
+			log.Warn("Failed to finish drop channel", zap.Any("channel", ch), zap.Error(err))
+		}
+	}
+}
+
+// advanceStandbys handles the Standby channels of every entry in standbys:
+// channels the handler marks as should-drop are removed, the rest are
+// reassigned. It reports whether at least one reassignment succeeded.
+func (m *ChannelManagerImplV2) advanceStandbys(standbys []*NodeChannelInfo) bool {
+	advanced := false
+	for _, nodeAssign := range standbys {
+		validChannels := make(map[string]RWChannel)
+		for chName, ch := range nodeAssign.Channels {
+			// drop marked-drop channels
+			if m.h.CheckShouldDropChannel(chName) {
+				m.finishRemoveChannel(nodeAssign.NodeID, ch)
+				continue
+			}
+			validChannels[chName] = ch
+		}
+		nodeAssign.Channels = validChannels
+
+		if len(nodeAssign.Channels) == 0 {
+			continue
+		}
+
+		chNames := lo.Keys(validChannels)
+		if err := m.reassign(nodeAssign); err != nil {
+			log.Warn("Reassign channels fail",
+				zap.Int64("nodeID", nodeAssign.NodeID),
+				zap.Strings("channels", chNames),
+				zap.Error(err),
+			)
+			// A failed reassignment is not progress: do not log success and
+			// do not mark the round as advanced.
+			continue
+		}
+
+		log.Info("Reassign standby channels to node",
+			zap.Int64("nodeID", nodeAssign.NodeID),
+			zap.Strings("channels", chNames),
+		)
+		advanced = true
+	}
+
+	return advanced
+}
+
+// advanceToNotifies sends the pending operation of every channel in
+// toNotifies to its node through the IO pool, waits for all calls of a node
+// and then records the outcome per channel with store.UpdateState. It reports
+// whether at least one notification succeeded.
+func (m *ChannelManagerImplV2) advanceToNotifies(toNotifies []*NodeChannelInfo) bool {
+	advanced := false
+	for _, nodeAssign := range toNotifies {
+		channelCount := len(nodeAssign.Channels)
+		if channelCount == 0 {
+			continue
+		}
+
+		var (
+			succeededChannels = make([]RWChannel, 0, channelCount)
+			failedChannels    = make([]RWChannel, 0, channelCount)
+			futures           = make([]*conc.Future[any], 0, channelCount)
+		)
+
+		chNames := lo.Keys(nodeAssign.Channels)
+		log.Info("Notify channel operations to datanode",
+			zap.Int64("assignment", nodeAssign.NodeID),
+			zap.Int("total operation count", len(nodeAssign.Channels)),
+			zap.Strings("channel names", chNames),
+		)
+		for _, ch := range nodeAssign.Channels {
+			innerCh := ch
+
+			// The channel is returned together with the error so that a
+			// failed call can still be attributed to its channel below.
+			future := getOrCreateIOPool().Submit(func() (any, error) {
+				err := m.Notify(nodeAssign.NodeID, innerCh.GetWatchInfo())
+				return innerCh, err
+			})
+			futures = append(futures, future)
+		}
+
+		for _, f := range futures {
+			ch, err := f.Await()
+			if err != nil {
+				failedChannels = append(failedChannels, ch.(RWChannel))
+			} else {
+				succeededChannels = append(succeededChannels, ch.(RWChannel))
+				advanced = true
+			}
+		}
+
+		log.Info("Finish to notify channel operations to datanode",
+			zap.Int64("assignment", nodeAssign.NodeID),
+			zap.Int("operation count", channelCount),
+			zap.Int("success count", len(succeededChannels)),
+			zap.Int("failure count", len(failedChannels)),
+		)
+		m.mu.Lock()
+		m.store.UpdateState(false, failedChannels...)
+		m.store.UpdateState(true, succeededChannels...)
+		m.mu.Unlock()
+	}
+
+	return advanced
+}
+
+// poolResult carries the outcome of one progress check out of the IO pool.
+type poolResult struct {
+	successful bool
+	ch         RWChannel
+}
+
+// advanceToChecks polls the operation progress of every channel in toChecks
+// through the IO pool. For each check that produced a final result the
+// channel state is updated with store.UpdateState; checks without a result
+// leave the channel untouched. It reports whether any state was updated.
+func (m *ChannelManagerImplV2) advanceToChecks(toChecks []*NodeChannelInfo) bool {
+	advanced := false
+	for _, nodeAssign := range toChecks {
+		if len(nodeAssign.Channels) == 0 {
+			continue
+		}
+
+		futures := make([]*conc.Future[any], 0, len(nodeAssign.Channels))
+
+		chNames := lo.Keys(nodeAssign.Channels)
+		log.Info("Check ToWatch/ToRelease channel operations progress",
+			zap.Int("channel count", len(nodeAssign.Channels)),
+			zap.Strings("channel names", chNames),
+		)
+
+		for _, ch := range nodeAssign.Channels {
+			innerCh := ch
+
+			future := getOrCreateIOPool().Submit(func() (any, error) {
+				successful, got := m.Check(nodeAssign.NodeID, innerCh.GetWatchInfo())
+				if got {
+					return poolResult{
+						successful: successful,
+						ch:         innerCh,
+					}, nil
+				}
+				return nil, errors.New("Got results with no progress")
+			})
+			futures = append(futures, future)
+		}
+
+		for _, f := range futures {
+			got, err := f.Await()
+			if err != nil {
+				// No final result yet; the channel is checked again next round.
+				continue
+			}
+
+			result := got.(poolResult)
+			m.mu.Lock()
+			m.store.UpdateState(result.successful, result.ch)
+			m.mu.Unlock()
+
+			advanced = true
+		}
+
+		log.Info("Finish to Check ToWatch/ToRelease channel operations progress",
+			zap.Int("channel count", len(nodeAssign.Channels)),
+			zap.Strings("channel names", chNames),
+		)
+	}
+	return advanced
+}
+
+// Notify sends the operation described by info to nodeID as a single-entry
+// ChannelOperationsRequest and returns the error of that call.
+func (m *ChannelManagerImplV2) Notify(nodeID int64, info *datapb.ChannelWatchInfo) error {
+	log := log.With(
+		zap.String("channel", info.GetVchan().GetChannelName()),
+		zap.Int64("assignment", nodeID),
+		zap.String("operation", info.GetState().String()),
+	)
+	log.Info("Notify channel operation")
+	err := m.subCluster.NotifyChannelOperation(m.ctx, nodeID, &datapb.ChannelOperationsRequest{Infos: []*datapb.ChannelWatchInfo{info}})
+	if err != nil {
+		log.Warn("Fail to notify channel operations", zap.Error(err))
+		return err
+	}
+	log.Debug("Success to notify channel operations")
+	return nil
+}
+
+// Check asks nodeID for the progress of the operation described by info.
+//
+// got is true only when the node reported a final state for the requested
+// operation (WatchSuccess/WatchFailure for ToWatch, ReleaseSuccess/
+// ReleaseFailure for ToRelease); successful then tells which of the two it
+// was. A failed call, an unfinished operation or any other state combination
+// yields (false, false).
+func (m *ChannelManagerImplV2) Check(nodeID int64, info *datapb.ChannelWatchInfo) (successful bool, got bool) {
+	log := log.With(
+		zap.Int64("opID", info.GetOpID()),
+		zap.Int64("nodeID", nodeID),
+		zap.String("check operation", info.GetState().String()),
+		zap.String("channel", info.GetVchan().GetChannelName()),
+	)
+	resp, err := m.subCluster.CheckChannelOperationProgress(m.ctx, nodeID, info)
+	if err != nil {
+		log.Warn("Fail to check channel operation progress", zap.Error(err))
+		return false, false
+	}
+	log.Info("Got channel operation progress",
+		zap.String("got state", resp.GetState().String()),
+		zap.Int32("progress", resp.GetProgress()))
+	switch info.GetState() {
+	case datapb.ChannelWatchState_ToWatch:
+		switch resp.GetState() {
+		case datapb.ChannelWatchState_WatchSuccess:
+			return true, true
+		case datapb.ChannelWatchState_WatchFailure:
+			return false, true
+		}
+	case datapb.ChannelWatchState_ToRelease:
+		switch resp.GetState() {
+		case datapb.ChannelWatchState_ReleaseSuccess:
+			return true, true
+		case datapb.ChannelWatchState_ReleaseFailure:
+			return false, true
+		}
+	}
+	return false, false
+}
+
+// execute fills the watch info of every non-Delete operation in updates and
+// then writes the whole set to the store. It does no locking of its own; the
+// callers in this file hold m.mu.
+func (m *ChannelManagerImplV2) execute(updates *ChannelOpSet) error {
+	for _, op := range updates.ops {
+		if op.Type != Delete {
+			if err := m.fillChannelWatchInfo(op); err != nil {
+				log.Warn("fail to fill channel watch info", zap.Error(err))
+				return err
+			}
+		}
+	}
+
+	return m.store.Update(updates)
+}
+
+// fillChannelWatchInfo attaches a new ChannelWatchInfo to every channel of op.
+// Each info gets the vchannel positions reported by the handler, a freshly
+// allocated operation ID, the channel schema, the state derived from the
+// operation type and one start timestamp shared by the whole operation.
+//
+// It returns the allocator error as soon as an ID cannot be allocated;
+// channels handled before that point keep their updated info.
+func (m *ChannelManagerImplV2) fillChannelWatchInfo(op *ChannelOp) error {
+	state := inferStateByOpType(op.Type)
+	now := time.Now().Unix()
+	for _, channel := range op.Channels {
+		positions := m.h.GetDataVChanPositions(channel, allPartitionID)
+		id, err := m.allocator.allocID(context.Background())
+		if err != nil {
+			return err
+		}
+
+		channel.UpdateWatchInfo(&datapb.ChannelWatchInfo{
+			Vchan:   positions,
+			StartTs: now,
+			State:   state,
+			Schema:  channel.GetSchema(),
+			OpID:    id,
+		})
+	}
+	return nil
+}
+
+// inferStateByOpType maps an operation type to the watch state written for it:
+// Release becomes ToRelease, every other type becomes ToWatch.
+func inferStateByOpType(opType ChannelOpType) datapb.ChannelWatchState {
+	if opType == Release {
+		return datapb.ChannelWatchState_ToRelease
+	}
+	return datapb.ChannelWatchState_ToWatch
+}
diff --git a/internal/datacoord/channel_manager_v2_test.go b/internal/datacoord/channel_manager_v2_test.go
new file mode 100644
index 0000000000..87ac92dcb3
--- /dev/null
+++ b/internal/datacoord/channel_manager_v2_test.go
@@ -0,0 +1,661 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package datacoord + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + kvmock "github.com/milvus-io/milvus/internal/kv/mocks" + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestChannelManagerSuite(t *testing.T) { + suite.Run(t, new(ChannelManagerSuite)) +} + +type ChannelManagerSuite struct { + suite.Suite + + mockKv *kvmock.MetaKv + mockCluster *MockSubCluster + mockAlloc *NMockAllocator + mockHandler *NMockHandler +} + +func (s *ChannelManagerSuite) prepareMeta(chNodes map[string]int64, state datapb.ChannelWatchState) { + s.SetupTest() + if chNodes == nil { + s.mockKv.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, nil).Once() + return + } + var keys, values []string + for channel, nodeID := range chNodes { + keys = append(keys, fmt.Sprintf("channel_store/%d/%s", nodeID, channel)) + info := generateWatchInfo(channel, state) + bs, err := proto.Marshal(info) + s.Require().NoError(err) + values = append(values, string(bs)) + } + s.mockKv.EXPECT().LoadWithPrefix(mock.Anything).Return(keys, values, nil).Once() +} + +func (s *ChannelManagerSuite) checkAssignment(m *ChannelManagerImplV2, nodeID int64, channel string, state ChannelState) { + rwChannel, found := m.GetChannel(nodeID, channel) + s.True(found) + s.NotNil(rwChannel) + s.Equal(channel, rwChannel.GetName()) + sChannel, ok := rwChannel.(*StateChannel) + s.True(ok) + s.Equal(state, sChannel.currentState) + s.EqualValues(nodeID, sChannel.assignedNode) + s.True(m.Match(nodeID, channel)) + + if nodeID != bufferID { + gotNode, err := m.FindWatcher(channel) + s.NoError(err) + s.EqualValues(gotNode, nodeID) + } +} + +func (s *ChannelManagerSuite) checkNoAssignment(m *ChannelManagerImplV2, nodeID 
int64, channel string) { + rwChannel, found := m.GetChannel(nodeID, channel) + s.False(found) + s.Nil(rwChannel) + s.False(m.Match(nodeID, channel)) +} + +func (s *ChannelManagerSuite) SetupTest() { + s.mockKv = kvmock.NewMetaKv(s.T()) + s.mockCluster = NewMockSubCluster(s.T()) + s.mockAlloc = NewNMockAllocator(s.T()) + s.mockHandler = NewNMockHandler(s.T()) + s.mockHandler.EXPECT().GetDataVChanPositions(mock.Anything, mock.Anything). + RunAndReturn(func(ch RWChannel, partitionID UniqueID) *datapb.VchannelInfo { + return &datapb.VchannelInfo{ + CollectionID: ch.GetCollectionID(), + ChannelName: ch.GetName(), + } + }).Maybe() + s.mockAlloc.EXPECT().allocID(mock.Anything).Return(19530, nil).Maybe() + s.mockKv.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything).RunAndReturn( + func(save map[string]string, removals []string, preds ...predicates.Predicate) error { + log.Info("test save and remove", zap.Any("save", save), zap.Any("removals", removals)) + return nil + }).Maybe() +} + +func (s *ChannelManagerSuite) TearDownTest() {} + +func (s *ChannelManagerSuite) TestAddNode() { + s.Run("AddNode with empty store", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + var testNode int64 = 1 + err = m.AddNode(testNode) + s.NoError(err) + + info := m.store.GetNode(testNode) + s.NotNil(info) + s.Empty(info.Channels) + s.Equal(info.NodeID, testNode) + }) + s.Run("AddNode with channel in bufferID", func() { + chNodes := map[string]int64{ + "ch1": bufferID, + "ch2": bufferID, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + var ( + testNodeID int64 = 1 + testChannels = []string{"ch1", "ch2"} + ) + lo.ForEach(testChannels, func(ch string, _ int) { + s.checkAssignment(m, bufferID, ch, Standby) + }) + + err = m.AddNode(testNodeID) + s.NoError(err) + + 
lo.ForEach(testChannels, func(ch string, _ int) { + s.checkAssignment(m, testNodeID, ch, ToWatch) + }) + }) + s.Run("AddNode with channels evenly in other node", func() { + var ( + testNodeID int64 = 100 + storedNodeID int64 = 1 + testChannel = "ch1" + ) + + chNodes := map[string]int64{testChannel: storedNodeID} + s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess) + + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + s.checkAssignment(m, storedNodeID, testChannel, Watched) + + err = m.AddNode(testNodeID) + s.NoError(err) + s.ElementsMatch([]int64{100, 1}, m.store.GetNodes()) + s.checkNoAssignment(m, testNodeID, testChannel) + + testNodeID = 101 + paramtable.Get().Save(paramtable.Get().DataCoordCfg.AutoBalance.Key, "true") + defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.AutoBalance.Key) + + err = m.AddNode(testNodeID) + s.NoError(err) + s.ElementsMatch([]int64{100, 101, 1}, m.store.GetNodes()) + s.checkNoAssignment(m, testNodeID, testChannel) + }) + s.Run("AddNode with channels unevenly in other node", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess) + + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + var testNodeID int64 = 100 + paramtable.Get().Save(paramtable.Get().DataCoordCfg.AutoBalance.Key, "true") + defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.AutoBalance.Key) + + err = m.AddNode(testNodeID) + s.NoError(err) + s.ElementsMatch([]int64{testNodeID, 1}, m.store.GetNodes()) + }) +} + +func (s *ChannelManagerSuite) TestWatch() { + s.Run("test Watch with empty store", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + var testCh string = "ch1" + + err = m.Watch(context.TODO(), getChannel(testCh, 1)) + 
s.NoError(err) + + s.checkAssignment(m, bufferID, testCh, Standby) + }) + s.Run("test Watch with nodeID in store", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + var ( + testCh string = "ch1" + testNodeID int64 = 1 + ) + err = m.AddNode(testNodeID) + s.NoError(err) + s.checkNoAssignment(m, testNodeID, testCh) + + err = m.Watch(context.TODO(), getChannel(testCh, 1)) + s.NoError(err) + + s.checkAssignment(m, testNodeID, testCh, ToWatch) + }) +} + +func (s *ChannelManagerSuite) TestRelease() { + s.Run("release not exist nodeID and channel", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + err = m.Release(1, "ch1") + s.Error(err) + log.Info("error", zap.String("msg", err.Error())) + + m.AddNode(1) + err = m.Release(1, "ch1") + s.Error(err) + log.Info("error", zap.String("msg", err.Error())) + }) + + s.Run("release channel in bufferID", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + m.Watch(context.TODO(), getChannel("ch1", 1)) + s.checkAssignment(m, bufferID, "ch1", Standby) + + err = m.Release(bufferID, "ch1") + s.NoError(err) + s.checkAssignment(m, bufferID, "ch1", Standby) + }) +} + +func (s *ChannelManagerSuite) TestDeleteNode() { + s.Run("delete not exsit node", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + info := m.store.GetNode(1) + s.Require().Nil(info) + + err = m.DeleteNode(1) + s.NoError(err) + }) + s.Run("delete bufferID", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + info := m.store.GetNode(bufferID) + s.Require().NotNil(info) + + err = m.DeleteNode(bufferID) 
+ s.NoError(err) + }) + + s.Run("delete node without assigment", func() { + s.prepareMeta(nil, 0) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + err = m.AddNode(1) + s.NoError(err) + info := m.store.GetNode(bufferID) + s.Require().NotNil(info) + + err = m.DeleteNode(1) + s.NoError(err) + info = m.store.GetNode(1) + s.Nil(info) + }) + s.Run("delete node with channel", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", Watched) + s.checkAssignment(m, 1, "ch2", Watched) + s.checkAssignment(m, 1, "ch3", Watched) + + err = m.AddNode(2) + s.NoError(err) + + err = m.DeleteNode(1) + s.NoError(err) + info := m.store.GetNode(bufferID) + s.NotNil(info) + + s.Equal(3, len(info.Channels)) + s.EqualValues(bufferID, info.NodeID) + s.checkAssignment(m, bufferID, "ch1", Standby) + s.checkAssignment(m, bufferID, "ch2", Standby) + s.checkAssignment(m, bufferID, "ch3", Standby) + + info = m.store.GetNode(1) + s.Nil(info) + }) +} + +func (s *ChannelManagerSuite) TestFindWatcher() { + chNodes := map[string]int64{ + "ch1": bufferID, + "ch2": bufferID, + "ch3": 1, + "ch4": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + tests := []struct { + description string + testCh string + + outNodeID int64 + outError bool + }{ + {"channel not exist", "ch-notexist", 0, true}, + {"channel in bufferID", "ch1", bufferID, true}, + {"channel in bufferID", "ch2", bufferID, true}, + {"channel in nodeID=1", "ch3", 1, false}, + {"channel in nodeID=1", "ch4", 1, false}, + } + + for _, test := range tests { + s.Run(test.description, func() { + gotID, gotErr := 
m.FindWatcher(test.testCh) + s.EqualValues(test.outNodeID, gotID) + if test.outError { + s.Error(gotErr) + } else { + s.NoError(gotErr) + } + }) + } +} + +func (s *ChannelManagerSuite) TestAdvanceChannelState() { + s.Run("advance statndby with no available nodes", func() { + chNodes := map[string]int64{ + "ch1": bufferID, + "ch2": bufferID, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, bufferID, "ch1", Standby) + s.checkAssignment(m, bufferID, "ch2", Standby) + + m.AdvanceChannelState() + s.checkAssignment(m, bufferID, "ch1", Standby) + s.checkAssignment(m, bufferID, "ch2", Standby) + }) + + s.Run("advance statndby with node 1", func() { + chNodes := map[string]int64{ + "ch1": bufferID, + "ch2": bufferID, + "ch3": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_WatchSuccess) + s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false).Times(2) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, bufferID, "ch1", Standby) + s.checkAssignment(m, bufferID, "ch2", Standby) + s.checkAssignment(m, 1, "ch3", Watched) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + }) + s.Run("advance towatch channels notify success check success", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + + 
m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Watching) + s.checkAssignment(m, 1, "ch2", Watching) + }) + s.Run("advance watching channels check no progress", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Watching) + s.checkAssignment(m, 1, "ch2", Watching) + + s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). + Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ToWatch}, nil).Twice() + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Watching) + s.checkAssignment(m, 1, "ch2", Watching) + }) + s.Run("advance watching channels check watch success", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Watching) + s.checkAssignment(m, 1, "ch2", Watching) + + s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchSuccess}, nil).Twice() + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Watched) + s.checkAssignment(m, 1, "ch2", Watched) + }) + s.Run("advance watching channels check watch fail", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(2) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Watching) + s.checkAssignment(m, 1, "ch2", Watching) + + s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). + Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchFailure}, nil).Twice() + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Standby) + s.checkAssignment(m, 1, "ch2", Standby) + + s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + }) + s.Run("advance releasing channels check release no progress", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToRelease) + s.checkAssignment(m, 1, "ch2", ToRelease) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Releasing) + s.checkAssignment(m, 1, "ch2", Releasing) + + 
s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). + Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ToRelease}, nil).Twice() + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Releasing) + s.checkAssignment(m, 1, "ch2", Releasing) + }) + s.Run("advance releasing channels check release success", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToRelease) + s.checkAssignment(m, 1, "ch2", ToRelease) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Releasing) + s.checkAssignment(m, 1, "ch2", Releasing) + + s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ReleaseSuccess}, nil).Twice() + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Standby) + s.checkAssignment(m, 1, "ch2", Standby) + + s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + }) + s.Run("advance releasing channels check release fail", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToRelease) + s.checkAssignment(m, 1, "ch2", ToRelease) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Releasing) + s.checkAssignment(m, 1, "ch2", Releasing) + + s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). + Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ReleaseFailure}, nil).Twice() + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Standby) + s.checkAssignment(m, 1, "ch2", Standby) + + s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) + m.AdvanceChannelState() + // TODO, donot assign to abnormal nodes + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + }) + s.Run("advance towatch channels notify fail", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToWatch) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything). 
+ Return(fmt.Errorf("mock error")).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", ToWatch) + s.checkAssignment(m, 1, "ch2", ToWatch) + }) + s.Run("advance to release channels notify success", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToRelease) + s.checkAssignment(m, 1, "ch2", ToRelease) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", Releasing) + s.checkAssignment(m, 1, "ch2", Releasing) + }) + s.Run("advance to release channels notify fail", func() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease) + s.mockCluster.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything, mock.Anything). 
+ Return(fmt.Errorf("mock error")).Twice() + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + s.checkAssignment(m, 1, "ch1", ToRelease) + s.checkAssignment(m, 1, "ch2", ToRelease) + + m.AdvanceChannelState() + s.checkAssignment(m, 1, "ch1", ToRelease) + s.checkAssignment(m, 1, "ch2", ToRelease) + }) +} + +func (s *ChannelManagerSuite) TestStartup() { + chNodes := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 3, + } + s.prepareMeta(chNodes, datapb.ChannelWatchState_ToRelease) + s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) + m, err := NewChannelManagerV2(s.mockKv, s.mockHandler, s.mockCluster, s.mockAlloc) + s.Require().NoError(err) + + var ( + legacyNodes = []int64{1} + allNodes = []int64{1} + ) + err = m.Startup(context.TODO(), legacyNodes, allNodes) + s.NoError(err) + + s.checkAssignment(m, 1, "ch1", Legacy) + s.checkAssignment(m, 1, "ch2", Legacy) + s.checkAssignment(m, bufferID, "ch3", Standby) + + err = m.DeleteNode(1) + s.NoError(err) + + s.checkAssignment(m, bufferID, "ch1", Standby) + s.checkAssignment(m, bufferID, "ch2", Standby) + + err = m.AddNode(2) + s.NoError(err) + s.checkAssignment(m, 2, "ch1", ToWatch) + s.checkAssignment(m, 2, "ch2", ToWatch) + s.checkAssignment(m, 2, "ch3", ToWatch) +} + +func (s *ChannelManagerSuite) TestCheckLoop() {} +func (s *ChannelManagerSuite) TestGet() {} diff --git a/internal/datacoord/channel_store.go b/internal/datacoord/channel_store.go index b463b4aaa0..c59e626a6f 100644 --- a/internal/datacoord/channel_store.go +++ b/internal/datacoord/channel_store.go @@ -33,8 +33,51 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) +// ROChannelStore is a read only channel store for channels and nodes. +type ROChannelStore interface { + // GetNode returns the channel info of a specific node. 
+ // Returns nil if the node doesn't belong to the cluster + GetNode(nodeID int64) *NodeChannelInfo + // HasChannel checks if store already has the channel + HasChannel(channel string) bool + // GetNodesChannels returns the channels that are assigned to nodes. + // without bufferID node + GetNodesChannels() []*NodeChannelInfo + // GetBufferChannelInfo gets the unassigned channels. + GetBufferChannelInfo() *NodeChannelInfo + // GetNodes gets all node ids in store. + GetNodes() []int64 + // GetNodeChannelCount + GetNodeChannelCount(nodeID int64) int + + // GetNodeChannelsBy used by channel_store_v2 and channel_manager_v2 only + GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo +} + +// RWChannelStore is the read write channel store for channels and nodes. +type RWChannelStore interface { + ROChannelStore + // Reload restores the buffer channels and node-channels mapping form kv. + Reload() error + // Add creates a new node-channels mapping, with no channels assigned to the node. + AddNode(nodeID int64) + // Delete removes nodeID and returns its channels. + RemoveNode(nodeID int64) + // Update applies the operations in ChannelOpSet. + Update(op *ChannelOpSet) error + + // UpdateState is used by StateChannelStore only + UpdateState(isSuccessful bool, channels ...RWChannel) + // SegLegacyChannelByNode is used by StateChannelStore only + SetLegacyChannelByNode(nodeIDs ...int64) +} + +// ChannelOpTypeNames implements zap log marshaller for ChannelOpSet. +var ChannelOpTypeNames = []string{"Add", "Delete", "Watch", "Release"} + const ( bufferID = math.MinInt64 delimiter = "/" @@ -49,6 +92,8 @@ type ChannelOpType int8 const ( Add ChannelOpType = iota Delete + Watch + Release ) // ChannelOp is an individual ADD or DELETE operation to the channel store. 
@@ -58,6 +103,14 @@ type ChannelOp struct { Channels []RWChannel } +func NewChannelOp(ID int64, opType ChannelOpType, channels ...RWChannel) *ChannelOp { + return &ChannelOp{ + Type: opType, + NodeID: ID, + Channels: channels, + } +} + func NewAddOp(id int64, channels ...RWChannel) *ChannelOp { return &ChannelOp{ NodeID: id, @@ -92,7 +145,7 @@ func (op *ChannelOp) BuildKV() (map[string]string, []string, error) { for _, ch := range op.Channels { k := buildNodeChannelKey(op.NodeID, ch.GetName()) switch op.Type { - case Add: + case Add, Watch, Release: info, err := proto.Marshal(ch.GetWatchInfo()) if err != nil { return saves, removals, err @@ -107,6 +160,24 @@ func (op *ChannelOp) BuildKV() (map[string]string, []string, error) { return saves, removals, nil } +// TODO: NIT: ObjectMarshaler -> ObjectMarshaller +// MarshalLogObject implements the interface ObjectMarshaler. +func (op *ChannelOp) MarshalLogObject(enc zapcore.ObjectEncoder) error { + enc.AddString("type", ChannelOpTypeNames[op.Type]) + enc.AddInt64("nodeID", op.NodeID) + cstr := "[" + if len(op.Channels) > 0 { + for _, s := range op.Channels { + cstr += s.GetName() + cstr += ", " + } + cstr = cstr[:len(cstr)-2] + } + cstr += "]" + enc.AddString("channels", cstr) + return nil +} + // ChannelOpSet is a set of channel operations. type ChannelOpSet struct { ops []*ChannelOp @@ -139,24 +210,31 @@ func (c *ChannelOpSet) Len() int { } // Add a new Add channel op, for ToWatch and ToRelease -func (c *ChannelOpSet) Add(id int64, channels ...RWChannel) { - c.ops = append(c.ops, NewAddOp(id, channels...)) +func (c *ChannelOpSet) Add(ID int64, channels ...RWChannel) { + c.Append(ID, Add, channels...) } -func (c *ChannelOpSet) Delete(id int64, channels ...RWChannel) { - c.ops = append(c.ops, NewDeleteOp(id, channels...)) +func (c *ChannelOpSet) Delete(ID int64, channels ...RWChannel) { + c.Append(ID, Delete, channels...) 
+} + +func (c *ChannelOpSet) Append(ID int64, opType ChannelOpType, channels ...RWChannel) { + c.ops = append(c.ops, NewChannelOp(ID, opType, channels...)) } func (c *ChannelOpSet) GetChannelNumber() int { if c == nil { return 0 } - number := 0 + + uniqChannels := typeutil.NewSet[string]() for _, op := range c.ops { - number += len(op.Channels) + uniqChannels.Insert(lo.Map(op.Channels, func(ch RWChannel, _ int) string { + return ch.GetName() + })...) } - return number + return uniqChannels.Len() } func (c *ChannelOpSet) SplitByChannel() map[string]*ChannelOpSet { @@ -168,43 +246,19 @@ func (c *ChannelOpSet) SplitByChannel() map[string]*ChannelOpSet { perChOps[ch.GetName()] = NewChannelOpSet() } - if op.Type == Add { - perChOps[ch.GetName()].Add(op.NodeID, ch) - } else { - perChOps[ch.GetName()].Delete(op.NodeID, ch) - } + perChOps[ch.GetName()].Append(op.NodeID, op.Type, ch) } } return perChOps } -// ROChannelStore is a read only channel store for channels and nodes. -type ROChannelStore interface { - // GetNode returns the channel info of a specific node. - GetNode(nodeID int64) *NodeChannelInfo - // GetChannels returns info of all channels. - GetChannels() []*NodeChannelInfo - // GetNodesChannels returns the channels that are assigned to nodes. - GetNodesChannels() []*NodeChannelInfo - // GetBufferChannelInfo gets the unassigned channels. - GetBufferChannelInfo() *NodeChannelInfo - // GetNodes gets all node ids in store. - GetNodes() []int64 - // GetNodeChannelCount - GetNodeChannelCount(nodeID int64) int -} - -// RWChannelStore is the read write channel store for channels and nodes. -type RWChannelStore interface { - ROChannelStore - // Reload restores the buffer channels and node-channels mapping form kv. - Reload() error - // Add creates a new node-channels mapping, with no channels assigned to the node. - Add(nodeID int64) - // Delete removes nodeID and returns its channels. 
- Delete(nodeID int64) ([]RWChannel, error) - // Update applies the operations in ChannelOpSet. - Update(op *ChannelOpSet) error +// TODO: NIT: ArrayMarshaler -> ArrayMarshaller +// MarshalLogArray implements the interface of ArrayMarshaler of zap. +func (c *ChannelOpSet) MarshalLogArray(enc zapcore.ArrayEncoder) error { + for _, o := range c.Collect() { + enc.AppendObject(o) + } + return nil } // ChannelStore must satisfy RWChannelStore. @@ -246,6 +300,13 @@ func NewNodeChannelInfo(nodeID int64, channels ...RWChannel) *NodeChannelInfo { return info } +func (info *NodeChannelInfo) GetChannels() []RWChannel { + if info == nil { + return nil + } + return lo.Values(info.Channels) +} + // NewChannelStore creates and returns a new ChannelStore. func NewChannelStore(kv kv.TxnKV) *ChannelStore { c := &ChannelStore{ @@ -280,7 +341,7 @@ func (c *ChannelStore) Reload() error { } reviseVChannelInfo(cw.GetVchan()) - c.Add(nodeID) + c.AddNode(nodeID) channel := &channelMeta{ Name: cw.GetVchan().GetChannelName(), CollectionID: cw.GetVchan().GetCollectionID(), @@ -297,9 +358,9 @@ func (c *ChannelStore) Reload() error { return nil } -// Add creates a new node-channels mapping for the given node, and assigns no channels to it. +// AddNode creates a new node-channels mapping for the given node, and assigns no channels to it. // Returns immediately if the node's already in the channel. -func (c *ChannelStore) Add(nodeID int64) { +func (c *ChannelStore) AddNode(nodeID int64) { if _, ok := c.channelsInfo[nodeID]; ok { return } @@ -356,7 +417,7 @@ func (c *ChannelStore) update(opSet *ChannelOpSet) error { // Update node id -> channel mapping. 
for _, op := range opSet.Collect() { switch op.Type { - case Add: + case Add, Watch, Release: for _, ch := range op.Channels { if c.checkIfExist(op.NodeID, ch) { continue // prevent adding duplicated channel info @@ -420,16 +481,9 @@ func (c *ChannelStore) GetNodeChannelCount(nodeID int64) int { return 0 } -// Delete removes the given node from the channel store and returns its channels. -func (c *ChannelStore) Delete(nodeID int64) ([]RWChannel, error) { - if info, ok := c.channelsInfo[nodeID]; ok { - if err := c.remove(nodeID); err != nil { - return nil, err - } - delete(c.channelsInfo, nodeID) - return lo.Values(info.Channels), nil - } - return nil, nil +// RemoveNode removes the given node from the channel store and returns its channels. +func (c *ChannelStore) RemoveNode(nodeID int64) { + delete(c.channelsInfo, nodeID) } // GetNodes returns a slice of all nodes ids in the current channel store. @@ -467,7 +521,32 @@ func (c *ChannelStore) txn(opSet *ChannelOpSet) error { return c.store.MultiSaveAndRemove(saves, removals) } +func (c *ChannelStore) HasChannel(channel string) bool { + for _, info := range c.channelsInfo { + for _, ch := range info.Channels { + if ch.GetName() == channel { + return true + } + } + } + return false +} + +func (c *ChannelStore) GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo { + log.Error("ChannelStore doesn't implement GetNodeChannelsBy") + return nil +} + +func (c *ChannelStore) UpdateState(isSuccessful bool, channels ...RWChannel) { + log.Error("ChannelStore doesn't implement UpdateState") +} + +func (c *ChannelStore) SetLegacyChannelByNode(nodeIDs ...int64) { + log.Error("ChannelStore doesn't implement SetLegacyChannelByNode") +} + // buildNodeChannelKey generates a key for kv store, where the key is a concatenation of ChannelWatchSubPath, nodeID and channel name. 
+// ${WatchSubPath}/${nodeID}/${channelName} func buildNodeChannelKey(nodeID int64, chName string) string { return fmt.Sprintf("%s%s%d%s%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), delimiter, nodeID, delimiter, chName) } @@ -485,33 +564,3 @@ func parseNodeKey(key string) (int64, error) { } return strconv.ParseInt(s[len(s)-2], 10, 64) } - -// ChannelOpTypeNames implements zap log marshaller for ChannelOpSet. -var ChannelOpTypeNames = []string{"Add", "Delete"} - -// TODO: NIT: ObjectMarshaler -> ObjectMarshaller -// MarshalLogObject implements the interface ObjectMarshaler. -func (op *ChannelOp) MarshalLogObject(enc zapcore.ObjectEncoder) error { - enc.AddString("type", ChannelOpTypeNames[op.Type]) - enc.AddInt64("nodeID", op.NodeID) - cstr := "[" - if len(op.Channels) > 0 { - for _, s := range op.Channels { - cstr += s.GetName() - cstr += ", " - } - cstr = cstr[:len(cstr)-2] - } - cstr += "]" - enc.AddString("channels", cstr) - return nil -} - -// TODO: NIT: ArrayMarshaler -> ArrayMarshaller -// MarshalLogArray implements the interface of ArrayMarshaler of zap. -func (c *ChannelOpSet) MarshalLogArray(enc zapcore.ArrayEncoder) error { - for _, o := range c.Collect() { - enc.AppendObject(o) - } - return nil -} diff --git a/internal/datacoord/channel_store_test.go b/internal/datacoord/channel_store_test.go index 47413961fb..235bd5103c 100644 --- a/internal/datacoord/channel_store_test.go +++ b/internal/datacoord/channel_store_test.go @@ -43,7 +43,7 @@ func genNodeChannelInfos(id int64, num int) *NodeChannelInfo { return NewNodeChannelInfo(id, channels...) 
} -func genChannelOperations(from, to int64, num int) *ChannelOpSet { +func genChannelOperationsV1(from, to int64, num int) *ChannelOpSet { channels := make([]RWChannel, 0, num) for i := 0; i < num; i++ { name := fmt.Sprintf("ch%d", i) @@ -86,7 +86,7 @@ func TestChannelStore_Update(t *testing.T) { }, }, args{ - genChannelOperations(1, 2, 250), + genChannelOperationsV1(1, 2, 250), }, false, }, diff --git a/internal/datacoord/channel_store_v2.go b/internal/datacoord/channel_store_v2.go new file mode 100644 index 0000000000..82f0d14e9e --- /dev/null +++ b/internal/datacoord/channel_store_v2.go @@ -0,0 +1,432 @@ +package datacoord + +import ( + "strconv" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + "github.com/samber/lo" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/kv" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/timerecord" +) + +type StateChannelStore struct { + store kv.TxnKV + channelsInfo map[int64]*NodeChannelInfo // A map of (nodeID) -> (NodeChannelInfo). 
+} + +var _ RWChannelStore = (*StateChannelStore)(nil) + +var errChannelNotExistInNode = errors.New("channel doesn't exist in given node") + +func NewChannelStoreV2(kv kv.TxnKV) RWChannelStore { + return NewStateChannelStore(kv) +} + +func NewStateChannelStore(kv kv.TxnKV) *StateChannelStore { + c := StateChannelStore{ + store: kv, + channelsInfo: make(map[int64]*NodeChannelInfo), + } + c.channelsInfo[bufferID] = &NodeChannelInfo{ + NodeID: bufferID, + Channels: make(map[string]RWChannel), + } + return &c +} + +func (c *StateChannelStore) Reload() error { + record := timerecord.NewTimeRecorder("datacoord") + keys, values, err := c.store.LoadWithPrefix(Params.CommonCfg.DataCoordWatchSubPath.GetValue()) + if err != nil { + return err + } + for i := 0; i < len(keys); i++ { + k := keys[i] + v := values[i] + nodeID, err := parseNodeKey(k) + if err != nil { + return err + } + + info := &datapb.ChannelWatchInfo{} + if err := proto.Unmarshal([]byte(v), info); err != nil { + return err + } + reviseVChannelInfo(info.GetVchan()) + + c.AddNode(nodeID) + + channel := NewStateChannelByWatchInfo(nodeID, info) + c.channelsInfo[nodeID].AddChannel(channel) + log.Info("channel store reload channel", + zap.Int64("nodeID", nodeID), zap.String("channel", channel.Name)) + metrics.DataCoordDmlChannelNum.WithLabelValues(strconv.FormatInt(nodeID, 10)).Set(float64(len(c.channelsInfo[nodeID].Channels))) + } + log.Info("channel store reload done", zap.Duration("duration", record.ElapseSpan())) + return nil +} + +func (c *StateChannelStore) AddNode(nodeID int64) { + if _, ok := c.channelsInfo[nodeID]; ok { + return + } + c.channelsInfo[nodeID] = &NodeChannelInfo{ + NodeID: nodeID, + Channels: make(map[string]RWChannel), + } +} + +func (c *StateChannelStore) UpdateState(isSuccessful bool, channels ...RWChannel) { + lo.ForEach(channels, func(ch RWChannel, _ int) { + for _, cInfo := range c.channelsInfo { + if stateChannel, ok := cInfo.Channels[ch.GetName()]; ok { + if isSuccessful { + 
stateChannel.(*StateChannel).TransitionOnSuccess() + } else { + stateChannel.(*StateChannel).TransitionOnFailure() + } + } + } + }) +} + +func (c *StateChannelStore) SetLegacyChannelByNode(nodeIDs ...int64) { + lo.ForEach(nodeIDs, func(nodeID int64, _ int) { + if cInfo, ok := c.channelsInfo[nodeID]; ok { + for _, ch := range cInfo.Channels { + ch.(*StateChannel).setState(Legacy) + } + } + }) +} + +func (c *StateChannelStore) Update(opSet *ChannelOpSet) error { + // Split opset into multiple txn. Operations on the same channel must be executed in one txn. + perChOps := opSet.SplitByChannel() + + // Execute a txn for every 64 operations. + count := 0 + operations := make([]*ChannelOp, 0, maxOperationsPerTxn) + for _, opset := range perChOps { + if !c.sanityCheckPerChannelOpSet(opset) { + log.Error("unsupported ChannelOpSet", zap.Any("OpSet", opset)) + continue + } + if opset.Len() > maxOperationsPerTxn { + log.Error("Operations for one channel exceeds maxOperationsPerTxn", + zap.Any("opset size", opset.Len()), + zap.Int("limit", maxOperationsPerTxn)) + } + if count+opset.Len() > maxOperationsPerTxn { + if err := c.updateMeta(NewChannelOpSet(operations...)); err != nil { + return err + } + count = 0 + operations = make([]*ChannelOp, 0, maxOperationsPerTxn) + } + count += opset.Len() + operations = append(operations, opset.Collect()...) 
+ } + if count == 0 { + return nil + } + + return c.updateMeta(NewChannelOpSet(operations...)) +} + +// remove from the assignments +func (c *StateChannelStore) removeAssignment(nodeID int64, channelName string) { + if cInfo, ok := c.channelsInfo[nodeID]; ok { + delete(cInfo.Channels, channelName) + } +} + +func (c *StateChannelStore) addAssignment(nodeID int64, channel RWChannel) { + if cInfo, ok := c.channelsInfo[nodeID]; ok { + cInfo.Channels[channel.GetName()] = channel + } else { + c.channelsInfo[nodeID] = &NodeChannelInfo{ + NodeID: nodeID, + Channels: map[string]RWChannel{ + channel.GetName(): channel, + }, + } + } +} + +// updateMeta applies the WATCH/RELEASE/DELETE operations to the current channel store. +// DELETE + WATCH ---> from bufferID to nodeID +// DELETE + WATCH ---> from lagecyID to nodeID +// DELETE + WATCH ---> from deletedNode to nodeID/bufferID +// RELEASE ---> release from nodeID +// WATCH ---> watch to a new channel +// DELETE ---> remove the channel +func (c *StateChannelStore) sanityCheckPerChannelOpSet(opSet *ChannelOpSet) bool { + if opSet.Len() == 2 { + ops := opSet.Collect() + return (ops[0].Type == Delete && ops[1].Type == Watch) || (ops[1].Type == Delete && ops[0].Type == Watch) + } else if opSet.Len() == 1 { + t := opSet.Collect()[0].Type + return t == Delete || t == Watch || t == Release + } + return false +} + +// DELETE + WATCH +func (c *StateChannelStore) updateMetaMemoryForPairOp(chName string, opSet *ChannelOpSet) error { + if !c.sanityCheckPerChannelOpSet(opSet) { + return errUnknownOpType + } + ops := opSet.Collect() + op1 := ops[1] + op2 := ops[0] + if ops[0].Type == Delete { + op1 = ops[0] + op2 = ops[1] + } + cInfo, ok := c.channelsInfo[op1.NodeID] + if !ok { + return errChannelNotExistInNode + } + var ch *StateChannel + if channel, ok := cInfo.Channels[chName]; ok { + ch = channel.(*StateChannel) + c.addAssignment(op2.NodeID, ch) + c.removeAssignment(op1.NodeID, chName) + } else { + if cInfo, ok = 
c.channelsInfo[op2.NodeID]; ok { + if channel2, ok := cInfo.Channels[chName]; ok { + ch = channel2.(*StateChannel) + } + } + } + // update channel + if ch != nil { + ch.Assign(op2.NodeID) + if op2.NodeID == bufferID { + ch.setState(Standby) + } else { + ch.setState(ToWatch) + } + } + return nil +} + +func (c *StateChannelStore) getChannel(nodeID int64, channelName string) *StateChannel { + if cInfo, ok := c.channelsInfo[nodeID]; ok { + if storedChannel, ok := cInfo.Channels[channelName]; ok { + return storedChannel.(*StateChannel) + } + log.Error("Channel doesn't exist in Node", zap.String("channel", channelName), zap.Int64("nodeID", nodeID)) + } else { + log.Error("Node doesn't exist", zap.Int64("NodeID", nodeID)) + } + return nil +} + +func (c *StateChannelStore) updateMetaMemoryForSingleOp(op *ChannelOp) error { + lo.ForEach(op.Channels, func(ch RWChannel, _ int) { + switch op.Type { + case Release: // release an already exsits storedChannel-node pair + if channel := c.getChannel(op.NodeID, ch.GetName()); channel != nil { + channel.setState(ToRelease) + } + case Watch: + storedChannel := c.getChannel(op.NodeID, ch.GetName()) + if storedChannel == nil { // New Channel + // set the correct assigment and state for NEW stateChannel + newChannel := NewStateChannel(ch) + newChannel.Assign(op.NodeID) + + if op.NodeID != bufferID { + newChannel.setState(ToWatch) + } + + // add channel to memory + c.addAssignment(op.NodeID, newChannel) + } else { // assign to the original nodes + storedChannel.setState(ToWatch) + } + case Delete: // Remove Channel + // if not Delete from bufferID, remove from channel + if op.NodeID != bufferID { + c.removeAssignment(op.NodeID, ch.GetName()) + } + default: + log.Error("unknown opType in updateMetaMemoryForSingleOp", zap.Any("type", op.Type)) + } + }) + return nil +} + +func (c *StateChannelStore) updateMeta(opSet *ChannelOpSet) error { + // Update ChannelStore's kv store. 
+ if err := c.txn(opSet); err != nil { + return err + } + + // Update memory + chOpSet := opSet.SplitByChannel() + for chName, ops := range chOpSet { + // DELETE + WATCH + if ops.Len() == 2 { + c.updateMetaMemoryForPairOp(chName, ops) + // RELEASE, DELETE, WATCH + } else if ops.Len() == 1 { + c.updateMetaMemoryForSingleOp(ops.Collect()[0]) + } else { + log.Error("unsupported ChannelOpSet", zap.Any("OpSet", ops)) + } + } + return nil +} + +// txn updates the channelStore's kv store with the given channel ops. +func (c *StateChannelStore) txn(opSet *ChannelOpSet) error { + var ( + saves = make(map[string]string) + removals []string + ) + for _, op := range opSet.Collect() { + opSaves, opRemovals, err := op.BuildKV() + if err != nil { + return err + } + + saves = lo.Assign(opSaves, saves) + removals = append(removals, opRemovals...) + } + return c.store.MultiSaveAndRemove(saves, removals) +} + +func (c *StateChannelStore) RemoveNode(nodeID int64) { + delete(c.channelsInfo, nodeID) +} + +func (c *StateChannelStore) HasChannel(channel string) bool { + for _, info := range c.channelsInfo { + if _, ok := info.Channels[channel]; ok { + return true + } + } + return false +} + +type ( + ChannelSelector func(ch *StateChannel) bool + NodeSelector func(ID int64) bool +) + +func WithAllNodes() NodeSelector { + return func(ID int64) bool { + return true + } +} + +func WithoutBufferNode() NodeSelector { + return func(ID int64) bool { + return ID != int64(bufferID) + } +} + +func WithNodeIDs(IDs ...int64) NodeSelector { + return func(ID int64) bool { + return lo.Contains(IDs, ID) + } +} + +func WithoutNodeIDs(IDs ...int64) NodeSelector { + return func(ID int64) bool { + return !lo.Contains(IDs, ID) + } +} + +func WithChannelName(channel string) ChannelSelector { + return func(ch *StateChannel) bool { + return ch.GetName() == channel + } +} + +func WithCollectionIDV2(collectionID int64) ChannelSelector { + return func(ch *StateChannel) bool { + return ch.GetCollectionID() == 
collectionID + } +} + +func WithChannelStates(states ...ChannelState) ChannelSelector { + return func(ch *StateChannel) bool { + return lo.Contains(states, ch.currentState) + } +} + +func (c *StateChannelStore) GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo { + nodeChannels := make(map[int64]*NodeChannelInfo) + for nodeID, cInfo := range c.channelsInfo { + if nodeSelector(nodeID) { + selected := make(map[string]RWChannel) + for chName, channel := range cInfo.Channels { + var sel bool = true + for _, selector := range channelSelectors { + if !selector(channel.(*StateChannel)) { + sel = false + break + } + } + if sel { + selected[chName] = channel + } + } + nodeChannels[nodeID] = &NodeChannelInfo{ + NodeID: nodeID, + Channels: selected, + } + } + } + return lo.Values(nodeChannels) +} + +func (c *StateChannelStore) GetNodesChannels() []*NodeChannelInfo { + ret := make([]*NodeChannelInfo, 0, len(c.channelsInfo)) + for id, info := range c.channelsInfo { + if id != bufferID { + ret = append(ret, info) + } + } + return ret +} + +func (c *StateChannelStore) GetBufferChannelInfo() *NodeChannelInfo { + return c.GetNode(bufferID) +} + +func (c *StateChannelStore) GetNode(nodeID int64) *NodeChannelInfo { + if info, ok := c.channelsInfo[nodeID]; ok { + return info + } + return nil +} + +func (c *StateChannelStore) GetNodeChannelCount(nodeID int64) int { + if cInfo, ok := c.channelsInfo[nodeID]; ok { + return len(cInfo.Channels) + } + return 0 +} + +func (c *StateChannelStore) GetNodes() []int64 { + return lo.Filter(lo.Keys(c.channelsInfo), func(ID int64, _ int) bool { + return ID != bufferID + }) +} + +// remove deletes kv pairs from the kv store where keys have given nodeID as prefix. 
+func (c *StateChannelStore) remove(nodeID int64) error { + k := buildKeyPrefix(nodeID) + return c.store.RemoveWithPrefix(k) +} diff --git a/internal/datacoord/channel_store_v2_test.go b/internal/datacoord/channel_store_v2_test.go new file mode 100644 index 0000000000..d2f9a22f58 --- /dev/null +++ b/internal/datacoord/channel_store_v2_test.go @@ -0,0 +1,483 @@ +package datacoord + +import ( + "fmt" + "strconv" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/kv/mocks" + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/testutils" +) + +func TestStateChannelStore(t *testing.T) { + suite.Run(t, new(StateChannelStoreSuite)) +} + +type StateChannelStoreSuite struct { + testutils.PromMetricsSuite + + mockTxn *mocks.TxnKV +} + +func (s *StateChannelStoreSuite) SetupTest() { + s.mockTxn = mocks.NewTxnKV(s.T()) +} + +func generateWatchInfo(name string, state datapb.ChannelWatchState) *datapb.ChannelWatchInfo { + return &datapb.ChannelWatchInfo{ + Vchan: &datapb.VchannelInfo{ + ChannelName: name, + }, + State: state, + } +} + +func (s *StateChannelStoreSuite) createChannelInfo(nodeID int64, channels ...RWChannel) *NodeChannelInfo { + cInfo := &NodeChannelInfo{ + NodeID: nodeID, + Channels: make(map[string]RWChannel), + } + for _, channel := range channels { + cInfo.Channels[channel.GetName()] = channel + } + return cInfo +} + +func (s *StateChannelStoreSuite) TestGetNodeChannelsBy() { + nodes := []int64{bufferID, 100, 101, 102} + nodesExcludeBufferID := []int64{100, 101, 102} + channels := []*StateChannel{ + getChannel("ch1", 1), + getChannel("ch2", 1), + getChannel("ch3", 1), + getChannel("ch4", 1), + getChannel("ch5", 1), + 
getChannel("ch6", 1), + getChannel("ch7", 1), + } + + channelsInfo := map[int64]*NodeChannelInfo{ + bufferID: s.createChannelInfo(bufferID, channels[0]), + 100: s.createChannelInfo(100, channels[1], channels[2]), + 101: s.createChannelInfo(101, channels[3], channels[4]), + 102: s.createChannelInfo(102, channels[5], channels[6]), // legacy nodes + } + + store := NewStateChannelStore(s.mockTxn) + lo.ForEach(nodes, func(nodeID int64, _ int) { store.AddNode(nodeID) }) + store.channelsInfo = channelsInfo + lo.ForEach(channels, func(ch *StateChannel, _ int) { + if ch.GetName() == "ch6" || ch.GetName() == "ch7" { + ch.setState(Legacy) + } + s.Require().True(store.HasChannel(ch.GetName())) + }) + s.Require().ElementsMatch(nodesExcludeBufferID, store.GetNodes()) + store.SetLegacyChannelByNode(102) + + s.Run("test AddNode RemoveNode", func() { + var nodeID int64 = 19530 + _, ok := store.channelsInfo[nodeID] + s.Require().False(ok) + store.AddNode(nodeID) + _, ok = store.channelsInfo[nodeID] + s.True(ok) + + store.RemoveNode(nodeID) + _, ok = store.channelsInfo[nodeID] + s.False(ok) + }) + + s.Run("test GetNodeChannels", func() { + infos := store.GetNodesChannels() + expectedResults := map[int64][]string{ + 100: {"ch2", "ch3"}, + 101: {"ch4", "ch5"}, + 102: {"ch6", "ch7"}, + } + + s.Equal(3, len(infos)) + + lo.ForEach(infos, func(info *NodeChannelInfo, _ int) { + expectedChannels, ok := expectedResults[info.NodeID] + s.True(ok) + + gotChannels := lo.Keys(info.Channels) + s.ElementsMatch(expectedChannels, gotChannels) + }) + }) + + s.Run("test GetBufferChannelInfo", func() { + info := store.GetBufferChannelInfo() + s.NotNil(info) + + gotChannels := lo.Keys(info.Channels) + s.ElementsMatch([]string{"ch1"}, gotChannels) + }) + + s.Run("test GetNode", func() { + info := store.GetNode(19530) + s.Nil(info) + + info = store.GetNode(bufferID) + s.NotNil(info) + + gotChannels := lo.Keys(info.Channels) + s.ElementsMatch([]string{"ch1"}, gotChannels) + }) + + tests := []struct { + 
description string + nodeSelector NodeSelector + channelSelectors []ChannelSelector + + expectedResult map[int64][]string + }{ + {"test withnodeIDs bufferID", WithNodeIDs(bufferID), nil, map[int64][]string{bufferID: {"ch1"}}}, + {"test withnodeIDs 100", WithNodeIDs(100), nil, map[int64][]string{100: {"ch2", "ch3"}}}, + {"test withnodeIDs 101 102", WithNodeIDs(101, 102), nil, map[int64][]string{ + 101: {"ch4", "ch5"}, + 102: {"ch6", "ch7"}, + }}, + {"test withAllNodes", WithAllNodes(), nil, map[int64][]string{ + bufferID: {"ch1"}, + 100: {"ch2", "ch3"}, + 101: {"ch4", "ch5"}, + 102: {"ch6", "ch7"}, + }}, + {"test WithoutBufferNode", WithoutBufferNode(), nil, map[int64][]string{ + 100: {"ch2", "ch3"}, + 101: {"ch4", "ch5"}, + 102: {"ch6", "ch7"}, + }}, + {"test WithoutNodeIDs 100, 101", WithoutNodeIDs(100, 101), nil, map[int64][]string{ + bufferID: {"ch1"}, + 102: {"ch6", "ch7"}, + }}, + { + "test WithChannelName ch1", WithNodeIDs(bufferID), + []ChannelSelector{WithChannelName("ch1")}, + map[int64][]string{ + bufferID: {"ch1"}, + }, + }, + { + "test WithChannelName ch2, collectionID 1", WithNodeIDs(100), + []ChannelSelector{ + WithChannelName("ch2"), + WithCollectionIDV2(1), + }, + map[int64][]string{100: {"ch2"}}, + }, + { + "test WithCollectionID 1", WithAllNodes(), + []ChannelSelector{ + WithCollectionIDV2(1), + }, + map[int64][]string{ + bufferID: {"ch1"}, + 100: {"ch2", "ch3"}, + 101: {"ch4", "ch5"}, + 102: {"ch6", "ch7"}, + }, + }, + { + "test WithChannelState", WithNodeIDs(102), + []ChannelSelector{ + WithChannelStates(Legacy), + }, + map[int64][]string{ + 102: {"ch6", "ch7"}, + }, + }, + } + + for _, test := range tests { + s.Run(test.description, func() { + if test.channelSelectors == nil { + test.channelSelectors = []ChannelSelector{} + } + + infos := store.GetNodeChannelsBy(test.nodeSelector, test.channelSelectors...) 
+ log.Info("got test infos", zap.Any("infos", infos)) + s.Equal(len(test.expectedResult), len(infos)) + + lo.ForEach(infos, func(info *NodeChannelInfo, _ int) { + expectedChannels, ok := test.expectedResult[info.NodeID] + s.True(ok) + + gotChannels := lo.Keys(info.Channels) + s.ElementsMatch(expectedChannels, gotChannels) + }) + }) + } +} + +func (s *StateChannelStoreSuite) TestUpdateWithTxnLimit() { + tests := []struct { + description string + inOpCount int + outTxnCount int + }{ + {"operations count < maxPerTxn", maxOperationsPerTxn - 1, 1}, + {"operations count = maxPerTxn", maxOperationsPerTxn, 1}, + {"operations count > maxPerTxn", maxOperationsPerTxn + 1, 2}, + {"operations count = 2*maxPerTxn", maxOperationsPerTxn * 2, 2}, + {"operations count = 2*maxPerTxn+1", maxOperationsPerTxn*2 + 1, 3}, + } + + for _, test := range tests { + s.SetupTest() + s.Run(test.description, func() { + s.mockTxn.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything). + Run(func(saves map[string]string, removals []string, preds ...predicates.Predicate) { + log.Info("test save and remove", zap.Any("saves", saves), zap.Any("removals", removals)) + }).Return(nil).Times(test.outTxnCount) + + store := NewStateChannelStore(s.mockTxn) + store.AddNode(1) + s.Require().ElementsMatch([]int64{1}, store.GetNodes()) + s.Require().Equal(0, store.GetNodeChannelCount(1)) + + // Get operations + ops := genChannelOperations(1, Watch, test.inOpCount) + err := store.Update(ops) + s.NoError(err) + }) + } +} + +func (s *StateChannelStoreSuite) TestUpdateMeta() { + tests := []struct { + description string + + opSet *ChannelOpSet + nodeIDs []int64 + channels []*StateChannel + assignments map[int64][]string + + outAssignments map[int64][]string + }{ + { + "delete_watch_ch1 from bufferID to nodeID=100", + NewChannelOpSet( + NewChannelOp(bufferID, Delete, getChannel("ch1", 1)), + NewChannelOp(100, Watch, getChannel("ch1", 1)), + ), + []int64{bufferID, 100}, + []*StateChannel{getChannel("ch1", 1)}, + 
map[int64][]string{ + bufferID: {"ch1"}, + }, + map[int64][]string{ + 100: {"ch1"}, + }, + }, + { + "delete_watch_ch1 from legacyID=99 to nodeID=100", + NewChannelOpSet( + NewChannelOp(99, Delete, getChannel("ch1", 1)), + NewChannelOp(100, Watch, getChannel("ch1", 1)), + ), + []int64{bufferID, 99, 100}, + []*StateChannel{getChannel("ch1", 1)}, + map[int64][]string{ + 99: {"ch1"}, + }, + map[int64][]string{ + 100: {"ch1"}, + }, + }, + { + "release from nodeID=100", + NewChannelOpSet( + NewChannelOp(100, Release, getChannel("ch1", 1)), + ), + []int64{bufferID, 100}, + []*StateChannel{getChannel("ch1", 1)}, + map[int64][]string{ + 100: {"ch1"}, + }, + map[int64][]string{ + 100: {"ch1"}, + }, + }, + { + "watch a new channel from nodeID=100", + NewChannelOpSet( + NewChannelOp(100, Watch, getChannel("ch1", 1)), + ), + []int64{bufferID, 100}, + []*StateChannel{getChannel("ch1", 1)}, + map[int64][]string{ + 100: {"ch1"}, + }, + map[int64][]string{ + 100: {"ch1"}, + }, + }, + { + "Delete remove a channel from nodeID=100", + NewChannelOpSet( + NewChannelOp(100, Delete, getChannel("ch1", 1)), + ), + []int64{bufferID, 100}, + []*StateChannel{getChannel("ch1", 1)}, + map[int64][]string{ + 100: {"ch1"}, + }, + map[int64][]string{ + 100: {}, + }, + }, + } + s.SetupTest() + s.mockTxn.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything). 
+ Run(func(saves map[string]string, removals []string, preds ...predicates.Predicate) { + }).Return(nil).Times(len(tests)) + + for _, test := range tests { + s.Run(test.description, func() { + store := NewStateChannelStore(s.mockTxn) + + lo.ForEach(test.nodeIDs, func(nodeID int64, _ int) { + store.AddNode(nodeID) + s.Require().Equal(0, store.GetNodeChannelCount(nodeID)) + }) + c := make(map[string]*StateChannel) + lo.ForEach(test.channels, func(ch *StateChannel, _ int) { c[ch.GetName()] = ch }) + for nodeID, channels := range test.assignments { + lo.ForEach(channels, func(ch string, _ int) { + store.addAssignment(nodeID, c[ch]) + }) + s.Require().Equal(1, store.GetNodeChannelCount(nodeID)) + } + + err := store.updateMeta(test.opSet) + s.NoError(err) + + for nodeID, channels := range test.outAssignments { + got := store.GetNodeChannelsBy(WithNodeIDs(nodeID)) + s.NotNil(got) + s.Require().Equal(1, len(got)) + info := got[0] + s.ElementsMatch(channels, lo.Keys(info.Channels)) + } + }) + } +} + +func (s *StateChannelStoreSuite) TestUpdateState() { + tests := []struct { + description string + + inSuccess bool + inChannelState ChannelState + outChannelState ChannelState + }{ + {"input standby, fail", false, Standby, Standby}, + {"input standby, success", true, Standby, ToWatch}, + } + + for _, test := range tests { + s.Run(test.description, func() { + store := NewStateChannelStore(s.mockTxn) + + ch := "ch-1" + channel := NewStateChannel(getChannel(ch, 1)) + channel.setState(test.inChannelState) + store.channelsInfo[1] = &NodeChannelInfo{ + NodeID: bufferID, + Channels: map[string]RWChannel{ + ch: channel, + }, + } + + store.UpdateState(test.inSuccess, channel) + s.Equal(test.outChannelState, channel.currentState) + }) + } +} + +func (s *StateChannelStoreSuite) TestReload() { + type item struct { + nodeID int64 + channelName string + } + type testCase struct { + tag string + items []item + expect map[int64]int + } + + cases := []testCase{ + { + tag: "empty", + items: 
[]item{}, + expect: map[int64]int{}, + }, + { + tag: "normal", + items: []item{ + {nodeID: 1, channelName: "dml1_v0"}, + {nodeID: 1, channelName: "dml2_v1"}, + {nodeID: 2, channelName: "dml3_v0"}, + }, + expect: map[int64]int{1: 2, 2: 1}, + }, + { + tag: "buffer", + items: []item{ + {nodeID: bufferID, channelName: "dml1_v0"}, + }, + expect: map[int64]int{bufferID: 1}, + }, + } + + for _, tc := range cases { + s.Run(tc.tag, func() { + s.mockTxn.ExpectedCalls = nil + + var keys, values []string + for _, item := range tc.items { + keys = append(keys, fmt.Sprintf("channel_store/%d/%s", item.nodeID, item.channelName)) + info := generateWatchInfo(item.channelName, datapb.ChannelWatchState_WatchSuccess) + bs, err := proto.Marshal(info) + s.Require().NoError(err) + values = append(values, string(bs)) + } + s.mockTxn.EXPECT().LoadWithPrefix(mock.AnythingOfType("string")).Return(keys, values, nil) + + store := NewStateChannelStore(s.mockTxn) + err := store.Reload() + s.Require().NoError(err) + + for nodeID, expect := range tc.expect { + s.MetricsEqual(metrics.DataCoordDmlChannelNum.WithLabelValues(strconv.FormatInt(nodeID, 10)), float64(expect)) + } + }) + } +} + +func genChannelOperations(nodeID int64, opType ChannelOpType, num int) *ChannelOpSet { + channels := make([]RWChannel, 0, num) + for i := 0; i < num; i++ { + name := fmt.Sprintf("ch%d", i) + channel := NewStateChannel(getChannel(name, 1)) + channel.Info = &datapb.ChannelWatchInfo{} + channels = append(channels, channel) + } + + ops := NewChannelOpSet(NewChannelOp(nodeID, opType, channels...)) + return ops +} diff --git a/internal/datacoord/cluster.go b/internal/datacoord/cluster.go index bc8d3f0844..ee07d7be41 100644 --- a/internal/datacoord/cluster.go +++ b/internal/datacoord/cluster.go @@ -35,7 +35,7 @@ type Cluster interface { Startup(ctx context.Context, nodes []*NodeInfo) error Register(node *NodeInfo) error UnRegister(node *NodeInfo) error - Watch(ctx context.Context, ch string, collectionID UniqueID) error + 
Watch(ctx context.Context, ch RWChannel) error Flush(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error FlushChannels(ctx context.Context, nodeID int64, flushTs Timestamp, channels []string) error PreImport(nodeID int64, in *datapb.PreImportRequest) error @@ -69,10 +69,19 @@ func (c *ClusterImpl) Startup(ctx context.Context, nodes []*NodeInfo) error { for _, node := range nodes { c.sessionManager.AddSession(node) } - currs := lo.Map(nodes, func(info *NodeInfo, _ int) int64 { - return info.NodeID + + var ( + legacyNodes []int64 + allNodes []int64 + ) + + lo.ForEach(nodes, func(info *NodeInfo, _ int) { + if info.IsLegacy { + legacyNodes = append(legacyNodes, info.NodeID) + } + allNodes = append(allNodes, info.NodeID) }) - return c.channelManager.Startup(ctx, currs) + return c.channelManager.Startup(ctx, legacyNodes, allNodes) } // Register registers a new node in cluster @@ -88,14 +97,15 @@ func (c *ClusterImpl) UnRegister(node *NodeInfo) error { } // Watch tries to add a channel in datanode cluster -func (c *ClusterImpl) Watch(ctx context.Context, ch string, collectionID UniqueID) error { - return c.channelManager.Watch(ctx, &channelMeta{Name: ch, CollectionID: collectionID}) +func (c *ClusterImpl) Watch(ctx context.Context, ch RWChannel) error { + return c.channelManager.Watch(ctx, ch) } // Flush sends async FlushSegments requests to dataNodes // which also according to channels where segments are assigned to. 
func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error { - if !c.channelManager.Match(nodeID, channel) { + ch, founded := c.channelManager.GetChannel(nodeID, channel) + if !founded { log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID), @@ -103,8 +113,6 @@ func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, s return fmt.Errorf("channel %s is not watched on node %d", channel, nodeID) } - _, collID := c.channelManager.GetCollectionIDByChannel(channel) - getSegmentID := func(segment *datapb.SegmentInfo, _ int) int64 { return segment.GetID() } @@ -115,7 +123,7 @@ func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, s commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithTargetID(nodeID), ), - CollectionID: collID, + CollectionID: ch.GetCollectionID(), SegmentIDs: lo.Map(segments, getSegmentID), ChannelName: channel, } diff --git a/internal/datacoord/cluster_test.go b/internal/datacoord/cluster_test.go index 57c886f387..145cd11662 100644 --- a/internal/datacoord/cluster_test.go +++ b/internal/datacoord/cluster_test.go @@ -67,8 +67,8 @@ func (suite *ClusterSuite) TestStartup() { {NodeID: 4, Address: "addr4"}, } suite.mockSession.EXPECT().AddSession(mock.Anything).Return().Times(len(nodes)) - suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything). - RunAndReturn(func(ctx context.Context, nodeIDs []int64) error { + suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything). 
+ RunAndReturn(func(ctx context.Context, legacys []int64, nodeIDs []int64) error { suite.ElementsMatch(lo.Map(nodes, func(info *NodeInfo, _ int) int64 { return info.NodeID }), nodeIDs) return nil }).Once() @@ -122,17 +122,19 @@ func (suite *ClusterSuite) TestWatch() { }).Once() cluster := NewClusterImpl(suite.mockSession, suite.mockChManager) - err := cluster.Watch(context.Background(), ch, collectionID) + err := cluster.Watch(context.Background(), getChannel(ch, collectionID)) suite.NoError(err) } func (suite *ClusterSuite) TestFlush() { - suite.mockChManager.EXPECT().Match(mock.Anything, mock.Anything). - RunAndReturn(func(nodeID int64, channel string) bool { - return nodeID != 1 + suite.mockChManager.EXPECT().GetChannel(mock.Anything, mock.Anything). + RunAndReturn(func(nodeID int64, channel string) (RWChannel, bool) { + if nodeID == 1 { + return nil, false + } + return getChannel("ch-1", 2), true }).Twice() - suite.mockChManager.EXPECT().GetCollectionIDByChannel(mock.Anything).Return(true, 100).Once() suite.mockSession.EXPECT().Flush(mock.Anything, mock.Anything, mock.Anything).Once() cluster := NewClusterImpl(suite.mockSession, suite.mockChManager) diff --git a/internal/datacoord/mock_channel_store.go b/internal/datacoord/mock_channel_store.go index 81b8bc73cf..e0e469fba7 100644 --- a/internal/datacoord/mock_channel_store.go +++ b/internal/datacoord/mock_channel_store.go @@ -17,89 +17,35 @@ func (_m *MockRWChannelStore) EXPECT() *MockRWChannelStore_Expecter { return &MockRWChannelStore_Expecter{mock: &_m.Mock} } -// Add provides a mock function with given fields: nodeID -func (_m *MockRWChannelStore) Add(nodeID int64) { +// AddNode provides a mock function with given fields: nodeID +func (_m *MockRWChannelStore) AddNode(nodeID int64) { _m.Called(nodeID) } -// MockRWChannelStore_Add_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Add' -type MockRWChannelStore_Add_Call struct { +// MockRWChannelStore_AddNode_Call is a 
*mock.Call that shadows Run/Return methods with type explicit version for method 'AddNode' +type MockRWChannelStore_AddNode_Call struct { *mock.Call } -// Add is a helper method to define mock.On call +// AddNode is a helper method to define mock.On call // - nodeID int64 -func (_e *MockRWChannelStore_Expecter) Add(nodeID interface{}) *MockRWChannelStore_Add_Call { - return &MockRWChannelStore_Add_Call{Call: _e.mock.On("Add", nodeID)} +func (_e *MockRWChannelStore_Expecter) AddNode(nodeID interface{}) *MockRWChannelStore_AddNode_Call { + return &MockRWChannelStore_AddNode_Call{Call: _e.mock.On("AddNode", nodeID)} } -func (_c *MockRWChannelStore_Add_Call) Run(run func(nodeID int64)) *MockRWChannelStore_Add_Call { +func (_c *MockRWChannelStore_AddNode_Call) Run(run func(nodeID int64)) *MockRWChannelStore_AddNode_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(int64)) }) return _c } -func (_c *MockRWChannelStore_Add_Call) Return() *MockRWChannelStore_Add_Call { +func (_c *MockRWChannelStore_AddNode_Call) Return() *MockRWChannelStore_AddNode_Call { _c.Call.Return() return _c } -func (_c *MockRWChannelStore_Add_Call) RunAndReturn(run func(int64)) *MockRWChannelStore_Add_Call { - _c.Call.Return(run) - return _c -} - -// Delete provides a mock function with given fields: nodeID -func (_m *MockRWChannelStore) Delete(nodeID int64) ([]RWChannel, error) { - ret := _m.Called(nodeID) - - var r0 []RWChannel - var r1 error - if rf, ok := ret.Get(0).(func(int64) ([]RWChannel, error)); ok { - return rf(nodeID) - } - if rf, ok := ret.Get(0).(func(int64) []RWChannel); ok { - r0 = rf(nodeID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]RWChannel) - } - } - - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(nodeID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockRWChannelStore_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' -type MockRWChannelStore_Delete_Call struct { - 
*mock.Call -} - -// Delete is a helper method to define mock.On call -// - nodeID int64 -func (_e *MockRWChannelStore_Expecter) Delete(nodeID interface{}) *MockRWChannelStore_Delete_Call { - return &MockRWChannelStore_Delete_Call{Call: _e.mock.On("Delete", nodeID)} -} - -func (_c *MockRWChannelStore_Delete_Call) Run(run func(nodeID int64)) *MockRWChannelStore_Delete_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) - }) - return _c -} - -func (_c *MockRWChannelStore_Delete_Call) Return(_a0 []RWChannel, _a1 error) *MockRWChannelStore_Delete_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockRWChannelStore_Delete_Call) RunAndReturn(run func(int64) ([]RWChannel, error)) *MockRWChannelStore_Delete_Call { +func (_c *MockRWChannelStore_AddNode_Call) RunAndReturn(run func(int64)) *MockRWChannelStore_AddNode_Call { _c.Call.Return(run) return _c } @@ -147,49 +93,6 @@ func (_c *MockRWChannelStore_GetBufferChannelInfo_Call) RunAndReturn(run func() return _c } -// GetChannels provides a mock function with given fields: -func (_m *MockRWChannelStore) GetChannels() []*NodeChannelInfo { - ret := _m.Called() - - var r0 []*NodeChannelInfo - if rf, ok := ret.Get(0).(func() []*NodeChannelInfo); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*NodeChannelInfo) - } - } - - return r0 -} - -// MockRWChannelStore_GetChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannels' -type MockRWChannelStore_GetChannels_Call struct { - *mock.Call -} - -// GetChannels is a helper method to define mock.On call -func (_e *MockRWChannelStore_Expecter) GetChannels() *MockRWChannelStore_GetChannels_Call { - return &MockRWChannelStore_GetChannels_Call{Call: _e.mock.On("GetChannels")} -} - -func (_c *MockRWChannelStore_GetChannels_Call) Run(run func()) *MockRWChannelStore_GetChannels_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c 
*MockRWChannelStore_GetChannels_Call) Return(_a0 []*NodeChannelInfo) *MockRWChannelStore_GetChannels_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockRWChannelStore_GetChannels_Call) RunAndReturn(run func() []*NodeChannelInfo) *MockRWChannelStore_GetChannels_Call { - _c.Call.Return(run) - return _c -} - // GetNode provides a mock function with given fields: nodeID func (_m *MockRWChannelStore) GetNode(nodeID int64) *NodeChannelInfo { ret := _m.Called(nodeID) @@ -276,6 +179,65 @@ func (_c *MockRWChannelStore_GetNodeChannelCount_Call) RunAndReturn(run func(int return _c } +// GetNodeChannelsBy provides a mock function with given fields: nodeSelector, channelSelectors +func (_m *MockRWChannelStore) GetNodeChannelsBy(nodeSelector NodeSelector, channelSelectors ...ChannelSelector) []*NodeChannelInfo { + _va := make([]interface{}, len(channelSelectors)) + for _i := range channelSelectors { + _va[_i] = channelSelectors[_i] + } + var _ca []interface{} + _ca = append(_ca, nodeSelector) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []*NodeChannelInfo + if rf, ok := ret.Get(0).(func(NodeSelector, ...ChannelSelector) []*NodeChannelInfo); ok { + r0 = rf(nodeSelector, channelSelectors...) 
+ } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*NodeChannelInfo) + } + } + + return r0 +} + +// MockRWChannelStore_GetNodeChannelsBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeChannelsBy' +type MockRWChannelStore_GetNodeChannelsBy_Call struct { + *mock.Call +} + +// GetNodeChannelsBy is a helper method to define mock.On call +// - nodeSelector NodeSelector +// - channelSelectors ...ChannelSelector +func (_e *MockRWChannelStore_Expecter) GetNodeChannelsBy(nodeSelector interface{}, channelSelectors ...interface{}) *MockRWChannelStore_GetNodeChannelsBy_Call { + return &MockRWChannelStore_GetNodeChannelsBy_Call{Call: _e.mock.On("GetNodeChannelsBy", + append([]interface{}{nodeSelector}, channelSelectors...)...)} +} + +func (_c *MockRWChannelStore_GetNodeChannelsBy_Call) Run(run func(nodeSelector NodeSelector, channelSelectors ...ChannelSelector)) *MockRWChannelStore_GetNodeChannelsBy_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]ChannelSelector, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(ChannelSelector) + } + } + run(args[0].(NodeSelector), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRWChannelStore_GetNodeChannelsBy_Call) Return(_a0 []*NodeChannelInfo) *MockRWChannelStore_GetNodeChannelsBy_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRWChannelStore_GetNodeChannelsBy_Call) RunAndReturn(run func(NodeSelector, ...ChannelSelector) []*NodeChannelInfo) *MockRWChannelStore_GetNodeChannelsBy_Call { + _c.Call.Return(run) + return _c +} + // GetNodes provides a mock function with given fields: func (_m *MockRWChannelStore) GetNodes() []int64 { ret := _m.Called() @@ -362,6 +324,48 @@ func (_c *MockRWChannelStore_GetNodesChannels_Call) RunAndReturn(run func() []*N return _c } +// HasChannel provides a mock function with given fields: channel +func (_m *MockRWChannelStore) HasChannel(channel string) bool { + ret := _m.Called(channel) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(channel) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockRWChannelStore_HasChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasChannel' +type MockRWChannelStore_HasChannel_Call struct { + *mock.Call +} + +// HasChannel is a helper method to define mock.On call +// - channel string +func (_e *MockRWChannelStore_Expecter) HasChannel(channel interface{}) *MockRWChannelStore_HasChannel_Call { + return &MockRWChannelStore_HasChannel_Call{Call: _e.mock.On("HasChannel", channel)} +} + +func (_c *MockRWChannelStore_HasChannel_Call) Run(run func(channel string)) *MockRWChannelStore_HasChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockRWChannelStore_HasChannel_Call) Return(_a0 bool) *MockRWChannelStore_HasChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRWChannelStore_HasChannel_Call) RunAndReturn(run func(string) bool) *MockRWChannelStore_HasChannel_Call { + _c.Call.Return(run) + return _c +} + // Reload provides a mock function with given fields: func 
(_m *MockRWChannelStore) Reload() error { ret := _m.Called() @@ -403,6 +407,85 @@ func (_c *MockRWChannelStore_Reload_Call) RunAndReturn(run func() error) *MockRW return _c } +// RemoveNode provides a mock function with given fields: nodeID +func (_m *MockRWChannelStore) RemoveNode(nodeID int64) { + _m.Called(nodeID) +} + +// MockRWChannelStore_RemoveNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveNode' +type MockRWChannelStore_RemoveNode_Call struct { + *mock.Call +} + +// RemoveNode is a helper method to define mock.On call +// - nodeID int64 +func (_e *MockRWChannelStore_Expecter) RemoveNode(nodeID interface{}) *MockRWChannelStore_RemoveNode_Call { + return &MockRWChannelStore_RemoveNode_Call{Call: _e.mock.On("RemoveNode", nodeID)} +} + +func (_c *MockRWChannelStore_RemoveNode_Call) Run(run func(nodeID int64)) *MockRWChannelStore_RemoveNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockRWChannelStore_RemoveNode_Call) Return() *MockRWChannelStore_RemoveNode_Call { + _c.Call.Return() + return _c +} + +func (_c *MockRWChannelStore_RemoveNode_Call) RunAndReturn(run func(int64)) *MockRWChannelStore_RemoveNode_Call { + _c.Call.Return(run) + return _c +} + +// SetLegacyChannelByNode provides a mock function with given fields: nodeIDs +func (_m *MockRWChannelStore) SetLegacyChannelByNode(nodeIDs ...int64) { + _va := make([]interface{}, len(nodeIDs)) + for _i := range nodeIDs { + _va[_i] = nodeIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + _m.Called(_ca...) 
+} + +// MockRWChannelStore_SetLegacyChannelByNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetLegacyChannelByNode' +type MockRWChannelStore_SetLegacyChannelByNode_Call struct { + *mock.Call +} + +// SetLegacyChannelByNode is a helper method to define mock.On call +// - nodeIDs ...int64 +func (_e *MockRWChannelStore_Expecter) SetLegacyChannelByNode(nodeIDs ...interface{}) *MockRWChannelStore_SetLegacyChannelByNode_Call { + return &MockRWChannelStore_SetLegacyChannelByNode_Call{Call: _e.mock.On("SetLegacyChannelByNode", + append([]interface{}{}, nodeIDs...)...)} +} + +func (_c *MockRWChannelStore_SetLegacyChannelByNode_Call) Run(run func(nodeIDs ...int64)) *MockRWChannelStore_SetLegacyChannelByNode_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]int64, len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(int64) + } + } + run(variadicArgs...) + }) + return _c +} + +func (_c *MockRWChannelStore_SetLegacyChannelByNode_Call) Return() *MockRWChannelStore_SetLegacyChannelByNode_Call { + _c.Call.Return() + return _c +} + +func (_c *MockRWChannelStore_SetLegacyChannelByNode_Call) RunAndReturn(run func(...int64)) *MockRWChannelStore_SetLegacyChannelByNode_Call { + _c.Call.Return(run) + return _c +} + // Update provides a mock function with given fields: op func (_m *MockRWChannelStore) Update(op *ChannelOpSet) error { ret := _m.Called(op) @@ -445,6 +528,54 @@ func (_c *MockRWChannelStore_Update_Call) RunAndReturn(run func(*ChannelOpSet) e return _c } +// UpdateState provides a mock function with given fields: isSuccessful, channels +func (_m *MockRWChannelStore) UpdateState(isSuccessful bool, channels ...RWChannel) { + _va := make([]interface{}, len(channels)) + for _i := range channels { + _va[_i] = channels[_i] + } + var _ca []interface{} + _ca = append(_ca, isSuccessful) + _ca = append(_ca, _va...) + _m.Called(_ca...) 
+} + +// MockRWChannelStore_UpdateState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateState' +type MockRWChannelStore_UpdateState_Call struct { + *mock.Call +} + +// UpdateState is a helper method to define mock.On call +// - isSuccessful bool +// - channels ...RWChannel +func (_e *MockRWChannelStore_Expecter) UpdateState(isSuccessful interface{}, channels ...interface{}) *MockRWChannelStore_UpdateState_Call { + return &MockRWChannelStore_UpdateState_Call{Call: _e.mock.On("UpdateState", + append([]interface{}{isSuccessful}, channels...)...)} +} + +func (_c *MockRWChannelStore_UpdateState_Call) Run(run func(isSuccessful bool, channels ...RWChannel)) *MockRWChannelStore_UpdateState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]RWChannel, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(RWChannel) + } + } + run(args[0].(bool), variadicArgs...) + }) + return _c +} + +func (_c *MockRWChannelStore_UpdateState_Call) Return() *MockRWChannelStore_UpdateState_Call { + _c.Call.Return() + return _c +} + +func (_c *MockRWChannelStore_UpdateState_Call) RunAndReturn(run func(bool, ...RWChannel)) *MockRWChannelStore_UpdateState_Call { + _c.Call.Return(run) + return _c +} + // NewMockRWChannelStore creates a new instance of MockRWChannelStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
func NewMockRWChannelStore(t interface { diff --git a/internal/datacoord/mock_channelmanager.go b/internal/datacoord/mock_channelmanager.go index 795478a3a4..e8b4ebe897 100644 --- a/internal/datacoord/mock_channelmanager.go +++ b/internal/datacoord/mock_channelmanager.go @@ -189,6 +189,105 @@ func (_c *MockChannelManager_FindWatcher_Call) RunAndReturn(run func(string) (in return _c } +// GetChannel provides a mock function with given fields: nodeID, channel +func (_m *MockChannelManager) GetChannel(nodeID int64, channel string) (RWChannel, bool) { + ret := _m.Called(nodeID, channel) + + var r0 RWChannel + var r1 bool + if rf, ok := ret.Get(0).(func(int64, string) (RWChannel, bool)); ok { + return rf(nodeID, channel) + } + if rf, ok := ret.Get(0).(func(int64, string) RWChannel); ok { + r0 = rf(nodeID, channel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(RWChannel) + } + } + + if rf, ok := ret.Get(1).(func(int64, string) bool); ok { + r1 = rf(nodeID, channel) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// MockChannelManager_GetChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannel' +type MockChannelManager_GetChannel_Call struct { + *mock.Call +} + +// GetChannel is a helper method to define mock.On call +// - nodeID int64 +// - channel string +func (_e *MockChannelManager_Expecter) GetChannel(nodeID interface{}, channel interface{}) *MockChannelManager_GetChannel_Call { + return &MockChannelManager_GetChannel_Call{Call: _e.mock.On("GetChannel", nodeID, channel)} +} + +func (_c *MockChannelManager_GetChannel_Call) Run(run func(nodeID int64, channel string)) *MockChannelManager_GetChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(string)) + }) + return _c +} + +func (_c *MockChannelManager_GetChannel_Call) Return(_a0 RWChannel, _a1 bool) *MockChannelManager_GetChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c 
*MockChannelManager_GetChannel_Call) RunAndReturn(run func(int64, string) (RWChannel, bool)) *MockChannelManager_GetChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetChannelNamesByCollectionID provides a mock function with given fields: collectionID +func (_m *MockChannelManager) GetChannelNamesByCollectionID(collectionID int64) []string { + ret := _m.Called(collectionID) + + var r0 []string + if rf, ok := ret.Get(0).(func(int64) []string); ok { + r0 = rf(collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// MockChannelManager_GetChannelNamesByCollectionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannelNamesByCollectionID' +type MockChannelManager_GetChannelNamesByCollectionID_Call struct { + *mock.Call +} + +// GetChannelNamesByCollectionID is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockChannelManager_Expecter) GetChannelNamesByCollectionID(collectionID interface{}) *MockChannelManager_GetChannelNamesByCollectionID_Call { + return &MockChannelManager_GetChannelNamesByCollectionID_Call{Call: _e.mock.On("GetChannelNamesByCollectionID", collectionID)} +} + +func (_c *MockChannelManager_GetChannelNamesByCollectionID_Call) Run(run func(collectionID int64)) *MockChannelManager_GetChannelNamesByCollectionID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockChannelManager_GetChannelNamesByCollectionID_Call) Return(_a0 []string) *MockChannelManager_GetChannelNamesByCollectionID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChannelManager_GetChannelNamesByCollectionID_Call) RunAndReturn(run func(int64) []string) *MockChannelManager_GetChannelNamesByCollectionID_Call { + _c.Call.Return(run) + return _c +} + // GetChannelsByCollectionID provides a mock function with given fields: collectionID func (_m *MockChannelManager) 
GetChannelsByCollectionID(collectionID int64) []RWChannel { ret := _m.Called(collectionID) @@ -233,58 +332,6 @@ func (_c *MockChannelManager_GetChannelsByCollectionID_Call) RunAndReturn(run fu return _c } -// GetCollectionIDByChannel provides a mock function with given fields: channel -func (_m *MockChannelManager) GetCollectionIDByChannel(channel string) (bool, int64) { - ret := _m.Called(channel) - - var r0 bool - var r1 int64 - if rf, ok := ret.Get(0).(func(string) (bool, int64)); ok { - return rf(channel) - } - if rf, ok := ret.Get(0).(func(string) bool); ok { - r0 = rf(channel) - } else { - r0 = ret.Get(0).(bool) - } - - if rf, ok := ret.Get(1).(func(string) int64); ok { - r1 = rf(channel) - } else { - r1 = ret.Get(1).(int64) - } - - return r0, r1 -} - -// MockChannelManager_GetCollectionIDByChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionIDByChannel' -type MockChannelManager_GetCollectionIDByChannel_Call struct { - *mock.Call -} - -// GetCollectionIDByChannel is a helper method to define mock.On call -// - channel string -func (_e *MockChannelManager_Expecter) GetCollectionIDByChannel(channel interface{}) *MockChannelManager_GetCollectionIDByChannel_Call { - return &MockChannelManager_GetCollectionIDByChannel_Call{Call: _e.mock.On("GetCollectionIDByChannel", channel)} -} - -func (_c *MockChannelManager_GetCollectionIDByChannel_Call) Run(run func(channel string)) *MockChannelManager_GetCollectionIDByChannel_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) - }) - return _c -} - -func (_c *MockChannelManager_GetCollectionIDByChannel_Call) Return(_a0 bool, _a1 int64) *MockChannelManager_GetCollectionIDByChannel_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockChannelManager_GetCollectionIDByChannel_Call) RunAndReturn(run func(string) (bool, int64)) *MockChannelManager_GetCollectionIDByChannel_Call { - _c.Call.Return(run) - return _c -} - // 
GetNodeChannelsByCollectionID provides a mock function with given fields: collectionID func (_m *MockChannelManager) GetNodeChannelsByCollectionID(collectionID int64) map[int64][]string { ret := _m.Called(collectionID) @@ -330,24 +377,24 @@ func (_c *MockChannelManager_GetNodeChannelsByCollectionID_Call) RunAndReturn(ru } // GetNodeIDByChannelName provides a mock function with given fields: channel -func (_m *MockChannelManager) GetNodeIDByChannelName(channel string) (bool, int64) { +func (_m *MockChannelManager) GetNodeIDByChannelName(channel string) (int64, bool) { ret := _m.Called(channel) - var r0 bool - var r1 int64 - if rf, ok := ret.Get(0).(func(string) (bool, int64)); ok { + var r0 int64 + var r1 bool + if rf, ok := ret.Get(0).(func(string) (int64, bool)); ok { return rf(channel) } - if rf, ok := ret.Get(0).(func(string) bool); ok { + if rf, ok := ret.Get(0).(func(string) int64); ok { r0 = rf(channel) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(string) int64); ok { + if rf, ok := ret.Get(1).(func(string) bool); ok { r1 = rf(channel) } else { - r1 = ret.Get(1).(int64) + r1 = ret.Get(1).(bool) } return r0, r1 @@ -371,12 +418,12 @@ func (_c *MockChannelManager_GetNodeIDByChannelName_Call) Run(run func(channel s return _c } -func (_c *MockChannelManager_GetNodeIDByChannelName_Call) Return(_a0 bool, _a1 int64) *MockChannelManager_GetNodeIDByChannelName_Call { +func (_c *MockChannelManager_GetNodeIDByChannelName_Call) Return(_a0 int64, _a1 bool) *MockChannelManager_GetNodeIDByChannelName_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockChannelManager_GetNodeIDByChannelName_Call) RunAndReturn(run func(string) (bool, int64)) *MockChannelManager_GetNodeIDByChannelName_Call { +func (_c *MockChannelManager_GetNodeIDByChannelName_Call) RunAndReturn(run func(string) (int64, bool)) *MockChannelManager_GetNodeIDByChannelName_Call { _c.Call.Return(run) return _c } @@ -467,55 +514,13 @@ func (_c 
*MockChannelManager_Release_Call) RunAndReturn(run func(int64, string) return _c } -// RemoveChannel provides a mock function with given fields: channelName -func (_m *MockChannelManager) RemoveChannel(channelName string) error { - ret := _m.Called(channelName) +// Startup provides a mock function with given fields: ctx, legacyNodes, allNodes +func (_m *MockChannelManager) Startup(ctx context.Context, legacyNodes []int64, allNodes []int64) error { + ret := _m.Called(ctx, legacyNodes, allNodes) var r0 error - if rf, ok := ret.Get(0).(func(string) error); ok { - r0 = rf(channelName) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockChannelManager_RemoveChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveChannel' -type MockChannelManager_RemoveChannel_Call struct { - *mock.Call -} - -// RemoveChannel is a helper method to define mock.On call -// - channelName string -func (_e *MockChannelManager_Expecter) RemoveChannel(channelName interface{}) *MockChannelManager_RemoveChannel_Call { - return &MockChannelManager_RemoveChannel_Call{Call: _e.mock.On("RemoveChannel", channelName)} -} - -func (_c *MockChannelManager_RemoveChannel_Call) Run(run func(channelName string)) *MockChannelManager_RemoveChannel_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) - }) - return _c -} - -func (_c *MockChannelManager_RemoveChannel_Call) Return(_a0 error) *MockChannelManager_RemoveChannel_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockChannelManager_RemoveChannel_Call) RunAndReturn(run func(string) error) *MockChannelManager_RemoveChannel_Call { - _c.Call.Return(run) - return _c -} - -// Startup provides a mock function with given fields: ctx, nodes -func (_m *MockChannelManager) Startup(ctx context.Context, nodes []int64) error { - ret := _m.Called(ctx, nodes) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, []int64) error); ok { - r0 = rf(ctx, nodes) + if rf, ok := 
ret.Get(0).(func(context.Context, []int64, []int64) error); ok { + r0 = rf(ctx, legacyNodes, allNodes) } else { r0 = ret.Error(0) } @@ -530,14 +535,15 @@ type MockChannelManager_Startup_Call struct { // Startup is a helper method to define mock.On call // - ctx context.Context -// - nodes []int64 -func (_e *MockChannelManager_Expecter) Startup(ctx interface{}, nodes interface{}) *MockChannelManager_Startup_Call { - return &MockChannelManager_Startup_Call{Call: _e.mock.On("Startup", ctx, nodes)} +// - legacyNodes []int64 +// - allNodes []int64 +func (_e *MockChannelManager_Expecter) Startup(ctx interface{}, legacyNodes interface{}, allNodes interface{}) *MockChannelManager_Startup_Call { + return &MockChannelManager_Startup_Call{Call: _e.mock.On("Startup", ctx, legacyNodes, allNodes)} } -func (_c *MockChannelManager_Startup_Call) Run(run func(ctx context.Context, nodes []int64)) *MockChannelManager_Startup_Call { +func (_c *MockChannelManager_Startup_Call) Run(run func(ctx context.Context, legacyNodes []int64, allNodes []int64)) *MockChannelManager_Startup_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([]int64)) + run(args[0].(context.Context), args[1].([]int64), args[2].([]int64)) }) return _c } @@ -547,7 +553,7 @@ func (_c *MockChannelManager_Startup_Call) Return(_a0 error) *MockChannelManager return _c } -func (_c *MockChannelManager_Startup_Call) RunAndReturn(run func(context.Context, []int64) error) *MockChannelManager_Startup_Call { +func (_c *MockChannelManager_Startup_Call) RunAndReturn(run func(context.Context, []int64, []int64) error) *MockChannelManager_Startup_Call { _c.Call.Return(run) return _c } diff --git a/internal/datacoord/mock_cluster.go b/internal/datacoord/mock_cluster.go index e83494cf33..77a9b56633 100644 --- a/internal/datacoord/mock_cluster.go +++ b/internal/datacoord/mock_cluster.go @@ -553,13 +553,13 @@ func (_c *MockCluster_UnRegister_Call) RunAndReturn(run func(*NodeInfo) error) * return _c } 
-// Watch provides a mock function with given fields: ctx, ch, collectionID -func (_m *MockCluster) Watch(ctx context.Context, ch string, collectionID int64) error { - ret := _m.Called(ctx, ch, collectionID) +// Watch provides a mock function with given fields: ctx, ch +func (_m *MockCluster) Watch(ctx context.Context, ch RWChannel) error { + ret := _m.Called(ctx, ch) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, int64) error); ok { - r0 = rf(ctx, ch, collectionID) + if rf, ok := ret.Get(0).(func(context.Context, RWChannel) error); ok { + r0 = rf(ctx, ch) } else { r0 = ret.Error(0) } @@ -574,15 +574,14 @@ type MockCluster_Watch_Call struct { // Watch is a helper method to define mock.On call // - ctx context.Context -// - ch string -// - collectionID int64 -func (_e *MockCluster_Expecter) Watch(ctx interface{}, ch interface{}, collectionID interface{}) *MockCluster_Watch_Call { - return &MockCluster_Watch_Call{Call: _e.mock.On("Watch", ctx, ch, collectionID)} +// - ch RWChannel +func (_e *MockCluster_Expecter) Watch(ctx interface{}, ch interface{}) *MockCluster_Watch_Call { + return &MockCluster_Watch_Call{Call: _e.mock.On("Watch", ctx, ch)} } -func (_c *MockCluster_Watch_Call) Run(run func(ctx context.Context, ch string, collectionID int64)) *MockCluster_Watch_Call { +func (_c *MockCluster_Watch_Call) Run(run func(ctx context.Context, ch RWChannel)) *MockCluster_Watch_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(int64)) + run(args[0].(context.Context), args[1].(RWChannel)) }) return _c } @@ -592,7 +591,7 @@ func (_c *MockCluster_Watch_Call) Return(_a0 error) *MockCluster_Watch_Call { return _c } -func (_c *MockCluster_Watch_Call) RunAndReturn(run func(context.Context, string, int64) error) *MockCluster_Watch_Call { +func (_c *MockCluster_Watch_Call) RunAndReturn(run func(context.Context, RWChannel) error) *MockCluster_Watch_Call { _c.Call.Return(run) return _c } diff --git 
a/internal/datacoord/mock_subcluster.go b/internal/datacoord/mock_subcluster.go new file mode 100644 index 0000000000..465eb2ac73 --- /dev/null +++ b/internal/datacoord/mock_subcluster.go @@ -0,0 +1,137 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package datacoord + +import ( + context "context" + + datapb "github.com/milvus-io/milvus/internal/proto/datapb" + mock "github.com/stretchr/testify/mock" +) + +// MockSubCluster is an autogenerated mock type for the SubCluster type +type MockSubCluster struct { + mock.Mock +} + +type MockSubCluster_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSubCluster) EXPECT() *MockSubCluster_Expecter { + return &MockSubCluster_Expecter{mock: &_m.Mock} +} + +// CheckChannelOperationProgress provides a mock function with given fields: ctx, nodeID, info +func (_m *MockSubCluster) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { + ret := _m.Called(ctx, nodeID, info) + + var r0 *datapb.ChannelOperationProgressResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)); ok { + return rf(ctx, nodeID, info) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.ChannelWatchInfo) *datapb.ChannelOperationProgressResponse); ok { + r0 = rf(ctx, nodeID, info) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.ChannelOperationProgressResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, *datapb.ChannelWatchInfo) error); ok { + r1 = rf(ctx, nodeID, info) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSubCluster_CheckChannelOperationProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckChannelOperationProgress' +type MockSubCluster_CheckChannelOperationProgress_Call struct { + *mock.Call +} + +// 
CheckChannelOperationProgress is a helper method to define mock.On call +// - ctx context.Context +// - nodeID int64 +// - info *datapb.ChannelWatchInfo +func (_e *MockSubCluster_Expecter) CheckChannelOperationProgress(ctx interface{}, nodeID interface{}, info interface{}) *MockSubCluster_CheckChannelOperationProgress_Call { + return &MockSubCluster_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", ctx, nodeID, info)} +} + +func (_c *MockSubCluster_CheckChannelOperationProgress_Call) Run(run func(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo)) *MockSubCluster_CheckChannelOperationProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(*datapb.ChannelWatchInfo)) + }) + return _c +} + +func (_c *MockSubCluster_CheckChannelOperationProgress_Call) Return(_a0 *datapb.ChannelOperationProgressResponse, _a1 error) *MockSubCluster_CheckChannelOperationProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSubCluster_CheckChannelOperationProgress_Call) RunAndReturn(run func(context.Context, int64, *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)) *MockSubCluster_CheckChannelOperationProgress_Call { + _c.Call.Return(run) + return _c +} + +// NotifyChannelOperation provides a mock function with given fields: ctx, nodeID, req +func (_m *MockSubCluster) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error { + ret := _m.Called(ctx, nodeID, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.ChannelOperationsRequest) error); ok { + r0 = rf(ctx, nodeID, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockSubCluster_NotifyChannelOperation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NotifyChannelOperation' +type MockSubCluster_NotifyChannelOperation_Call struct { + *mock.Call +} + 
+// NotifyChannelOperation is a helper method to define mock.On call +// - ctx context.Context +// - nodeID int64 +// - req *datapb.ChannelOperationsRequest +func (_e *MockSubCluster_Expecter) NotifyChannelOperation(ctx interface{}, nodeID interface{}, req interface{}) *MockSubCluster_NotifyChannelOperation_Call { + return &MockSubCluster_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", ctx, nodeID, req)} +} + +func (_c *MockSubCluster_NotifyChannelOperation_Call) Run(run func(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest)) *MockSubCluster_NotifyChannelOperation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(*datapb.ChannelOperationsRequest)) + }) + return _c +} + +func (_c *MockSubCluster_NotifyChannelOperation_Call) Return(_a0 error) *MockSubCluster_NotifyChannelOperation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSubCluster_NotifyChannelOperation_Call) RunAndReturn(run func(context.Context, int64, *datapb.ChannelOperationsRequest) error) *MockSubCluster_NotifyChannelOperation_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSubCluster creates a new instance of MockSubCluster. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockSubCluster(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSubCluster { + mock := &MockSubCluster{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/datacoord/policy.go b/internal/datacoord/policy.go index fc943c9503..6d6077dd82 100644 --- a/internal/datacoord/policy.go +++ b/internal/datacoord/policy.go @@ -17,20 +17,18 @@ package datacoord import ( - "context" "math" "sort" - "strconv" - "time" "github.com/samber/lo" "go.uber.org/zap" "go.uber.org/zap/zapcore" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) -// RegisterPolicy decides the channels mapping after registering the nodeID +// RegisterPolicy decides the channels mapping after registering a new nodeID // return bufferedUpdates and balanceUpdates type RegisterPolicy func(store ROChannelStore, nodeID int64) (*ChannelOpSet, *ChannelOpSet) @@ -47,8 +45,8 @@ func BufferChannelAssignPolicy(store ROChannelStore, nodeID int64) *ChannelOpSet } opSet := NewChannelOpSet( - NewDeleteOp(bufferID, lo.Values(info.Channels)...), - NewAddOp(nodeID, lo.Values(info.Channels)...)) + NewChannelOp(bufferID, Delete, lo.Values(info.Channels)...), + NewChannelOp(nodeID, Watch, lo.Values(info.Channels)...)) return opSet } @@ -61,14 +59,15 @@ func AvgAssignRegisterPolicy(store ROChannelStore, nodeID int64) (*ChannelOpSet, } // Get a list of available node-channel info. 
- avaNodes := filterNode(store.GetNodesChannels(), nodeID) + allNodes := store.GetNodesChannels() + avaNodes := filterNode(allNodes, nodeID) channelNum := 0 for _, info := range avaNodes { channelNum += len(info.Channels) } // store already add the new node - chPerNode := channelNum / len(store.GetNodes()) + chPerNode := channelNum / len(allNodes) if chPerNode == 0 { return nil, nil } @@ -95,7 +94,7 @@ func AvgAssignRegisterPolicy(store ROChannelStore, nodeID int64) (*ChannelOpSet, // Channels in `releases` are reassigned eventually by channel manager. opSet = NewChannelOpSet() for k, v := range releases { - opSet.Add(k, v...) + opSet.Append(k, Release, v...) } return nil, opSet } @@ -112,20 +111,14 @@ func filterNode(infos []*NodeChannelInfo, nodeID int64) []*NodeChannelInfo { return filtered } -func formatNodeID(nodeID int64) string { - return strconv.FormatInt(nodeID, 10) -} - -func deformatNodeID(node string) (int64, error) { - return strconv.ParseInt(node, 10, 64) -} - -// ChannelAssignPolicy assign channels to registered nodes. +// ChannelAssignPolicy assign new channels to registered nodes. type ChannelAssignPolicy func(store ROChannelStore, channels []RWChannel) *ChannelOpSet // AverageAssignPolicy ensure that the number of channels per nodes is approximately the same func AverageAssignPolicy(store ROChannelStore, channels []RWChannel) *ChannelOpSet { - newChannels := filterChannels(store, channels) + newChannels := lo.Filter(channels, func(ch RWChannel, _ int) bool { + return !store.HasChannel(ch.GetName()) + }) if len(newChannels) == 0 { return nil } @@ -135,7 +128,7 @@ func AverageAssignPolicy(store ROChannelStore, channels []RWChannel) *ChannelOpS // If no datanode alive, save channels in buffer if len(allDataNodes) == 0 { - opSet.Add(bufferID, channels...) + opSet.Append(bufferID, Watch, channels...) 
return opSet } @@ -151,35 +144,11 @@ func AverageAssignPolicy(store ROChannelStore, channels []RWChannel) *ChannelOpS } for id, chs := range updates { - opSet.Add(id, chs...) + opSet.Append(id, Watch, chs...) } return opSet } -func filterChannels(store ROChannelStore, channels []RWChannel) []RWChannel { - channelsMap := make(map[string]RWChannel) - for _, c := range channels { - channelsMap[c.GetName()] = c - } - - allChannelsInfo := store.GetChannels() - for _, info := range allChannelsInfo { - for _, c := range info.Channels { - delete(channelsMap, c.GetName()) - } - } - - if len(channelsMap) == 0 { - return nil - } - - filtered := make([]RWChannel, 0, len(channelsMap)) - for _, v := range channelsMap { - filtered = append(filtered, v) - } - return filtered -} - // DeregisterPolicy determine the mapping after deregistering the nodeID type DeregisterPolicy func(store ROChannelStore, nodeID int64) *ChannelOpSet @@ -190,22 +159,21 @@ func EmptyDeregisterPolicy(store ROChannelStore, nodeID int64) *ChannelOpSet { // AvgAssignUnregisteredChannels evenly assign the unregistered channels func AvgAssignUnregisteredChannels(store ROChannelStore, nodeID int64) *ChannelOpSet { - allNodes := store.GetNodesChannels() - avaNodes := make([]*NodeChannelInfo, 0, len(allNodes)) - unregisteredChannels := make([]RWChannel, 0) - opSet := NewChannelOpSet() - - for _, c := range allNodes { - if c.NodeID == nodeID { - opSet.Delete(nodeID, lo.Values(c.Channels)...) - unregisteredChannels = append(unregisteredChannels, lo.Values(c.Channels)...) - continue - } - avaNodes = append(avaNodes, c) + nodeChannel := store.GetNode(nodeID) + if nodeChannel == nil || len(nodeChannel.Channels) == 0 { + return nil } + unregisteredChannels := nodeChannel.Channels + avaNodes := lo.Filter(store.GetNodesChannels(), func(info *NodeChannelInfo, _ int) bool { + return info.NodeID != nodeID + }) + + opSet := NewChannelOpSet() + opSet.Delete(nodeChannel.NodeID, lo.Values(nodeChannel.Channels)...) 
+ if len(avaNodes) == 0 { - opSet.Add(bufferID, unregisteredChannels...) + opSet.Append(bufferID, Watch, lo.Values(unregisteredChannels)...) return opSet } @@ -215,33 +183,19 @@ func AvgAssignUnregisteredChannels(store ROChannelStore, nodeID int64) *ChannelO }) updates := make(map[int64][]RWChannel) - for i, unregisteredChannel := range unregisteredChannels { - n := avaNodes[i%len(avaNodes)].NodeID + cnt := 0 + for _, unregisteredChannel := range unregisteredChannels { + n := avaNodes[cnt%len(avaNodes)].NodeID updates[n] = append(updates[n], unregisteredChannel) + cnt++ } for id, chs := range updates { - opSet.Add(id, chs...) + opSet.Append(id, Watch, chs...) } return opSet } -type BalanceChannelPolicy func(store ROChannelStore, ts time.Time) *ChannelOpSet - -func AvgBalanceChannelPolicy(store ROChannelStore, ts time.Time) *ChannelOpSet { - opSet := NewChannelOpSet() - reAllocates, err := BgBalanceCheck(store.GetNodesChannels(), ts) - if err != nil { - log.Error("failed to balance node channels", zap.Error(err)) - return opSet - } - for _, reAlloc := range reAllocates { - opSet.Add(reAlloc.NodeID, lo.Values(reAlloc.Channels)...) 
- } - - return opSet -} - // ChannelReassignPolicy is a policy for reassigning channels type ChannelReassignPolicy func(store ROChannelStore, reassigns []*NodeChannelInfo) *ChannelOpSet @@ -250,25 +204,20 @@ func EmptyReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *Ch return nil } -// EmptyBalancePolicy is a dummy balance policy -func EmptyBalancePolicy(store ROChannelStore, ts time.Time) *ChannelOpSet { - return nil -} - // AverageReassignPolicy is a reassigning policy that evenly balance channels among datanodes -// which is used by bgChecker func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) *ChannelOpSet { allNodes := store.GetNodesChannels() - filterMap := make(map[int64]struct{}) toReassignTotalNum := 0 for _, reassign := range reassigns { - filterMap[reassign.NodeID] = struct{}{} toReassignTotalNum += len(reassign.Channels) } + avaNodes := make([]*NodeChannelInfo, 0, len(allNodes)) avaNodesChannelSum := 0 for _, node := range allNodes { - if _, ok := filterMap[node.NodeID]; ok { + if lo.ContainsBy(reassigns, func(info *NodeChannelInfo) bool { + return node.NodeID == info.NodeID + }) { continue } avaNodes = append(avaNodes, node) @@ -279,7 +228,6 @@ func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) * if len(avaNodes) == 0 { // if no node is left, do not reassign - log.Warn("there is no available nodes when reassigning, return") return nil } @@ -322,7 +270,7 @@ func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) * nodeIdx++ } if _, ok := addUpdates[targetID]; !ok { - addUpdates[targetID] = NewAddOp(targetID, ch) + addUpdates[targetID] = NewChannelOp(targetID, Watch, ch) } else { addUpdates[targetID].Append(ch) } @@ -334,18 +282,19 @@ func AverageReassignPolicy(store ROChannelStore, reassigns []*NodeChannelInfo) * return opSet } -// ChannelBGChecker check nodes' channels and return the channels needed to be reallocated. 
-type ChannelBGChecker func(ctx context.Context) +type Assignments []*NodeChannelInfo -// EmptyBgChecker does nothing -func EmptyBgChecker(channels []*NodeChannelInfo, ts time.Time) ([]*NodeChannelInfo, error) { - return nil, nil +func (a Assignments) GetChannelCount(nodeID int64) int { + for _, info := range a { + if info.NodeID == nodeID { + return len(info.Channels) + } + } + return 0 } -type ReAllocates []*NodeChannelInfo - -func (rallocates ReAllocates) MarshalLogArray(enc zapcore.ArrayEncoder) error { - for _, nChannelInfo := range rallocates { +func (a Assignments) MarshalLogArray(enc zapcore.ArrayEncoder) error { + for _, nChannelInfo := range a { enc.AppendString("nodeID:") enc.AppendInt64(nChannelInfo.NodeID) cstr := "[" @@ -362,22 +311,33 @@ func (rallocates ReAllocates) MarshalLogArray(enc zapcore.ArrayEncoder) error { return nil } -func BgBalanceCheck(nodeChannels []*NodeChannelInfo, ts time.Time) ([]*NodeChannelInfo, error) { - avaNodeNum := len(nodeChannels) - reAllocations := make(ReAllocates, 0, avaNodeNum) +// BalanceChannelPolicy try to balance watched channels to registered nodes +type BalanceChannelPolicy func(cluster Assignments) *ChannelOpSet + +// EmptyBalancePolicy is a dummy balance policy +func EmptyBalancePolicy(cluster Assignments) *ChannelOpSet { + return nil +} + +// AvgBalanceChannelPolicy tries to balance channel evenly +func AvgBalanceChannelPolicy(cluster Assignments) *ChannelOpSet { + avaNodeNum := len(cluster) if avaNodeNum == 0 { - return reAllocations, nil + return nil } + + reAllocations := make(Assignments, 0, avaNodeNum) totalChannelNum := 0 - for _, nodeChs := range nodeChannels { + for _, nodeChs := range cluster { totalChannelNum += len(nodeChs.Channels) } channelCountPerNode := totalChannelNum / avaNodeNum - for _, nChannels := range nodeChannels { + for _, nChannels := range cluster { chCount := len(nChannels.Channels) if chCount <= channelCountPerNode+1 { log.Info("node channel count is not much larger than average, 
skip reallocate", - zap.Int64("nodeID", nChannels.NodeID), zap.Int("channelCount", chCount), + zap.Int64("nodeID", nChannels.NodeID), + zap.Int("channelCount", chCount), zap.Int("channelCountPerNode", channelCountPerNode)) continue } @@ -392,25 +352,136 @@ func BgBalanceCheck(nodeChannels []*NodeChannelInfo, ts time.Time) ([]*NodeChann } reAllocations = append(reAllocations, reallocate) } - log.Info("Channel Balancer got new reAllocations:", zap.Array("reAllocations", reAllocations)) - return reAllocations, nil -} - -func formatNodeIDs(ids []int64) []string { - formatted := make([]string, 0, len(ids)) - for _, id := range ids { - formatted = append(formatted, formatNodeID(id)) + if len(reAllocations) == 0 { + return nil } - return formatted + + opSet := NewChannelOpSet() + for _, reAlloc := range reAllocations { + opSet.Append(reAlloc.NodeID, Release, lo.Values(reAlloc.Channels)...) + } + return opSet } -func formatNodeIDsWithFilter(ids []int64, filter int64) []string { - formatted := make([]string, 0, len(ids)) - for _, id := range ids { - if id == filter { - continue +func AvgAssignByCountPolicy(currentCluster Assignments, unassignedChannels []RWChannel, execlusiveNodes []int64) *ChannelOpSet { + var ( + toCluster Assignments + fromCluster Assignments + channelNum int = 0 + ) + + nodeToAvg := typeutil.NewUniqueSet() + + lo.ForEach(currentCluster, func(info *NodeChannelInfo, _ int) { + if !lo.Contains(execlusiveNodes, info.NodeID) { + toCluster = append(toCluster, info) + nodeToAvg.Insert(info.NodeID) } - formatted = append(formatted, formatNodeID(id)) + + if len(info.Channels) > 0 { + fromCluster = append(fromCluster, info) + channelNum += len(info.Channels) + nodeToAvg.Insert(info.NodeID) + } + }) + + // If no datanode alive, do nothing + if len(toCluster) == 0 { + return nil } - return formatted + + // 1. 
assign unassigned channels first + if len(unassignedChannels) > 0 { + chPerNode := (len(unassignedChannels) + channelNum) / nodeToAvg.Len() + + // sort by assigned channels count ascsending + sort.Slice(toCluster, func(i, j int) bool { + return len(toCluster[i].Channels) <= len(toCluster[j].Channels) + }) + + nodesLackOfChannels := Assignments(lo.Filter(toCluster, func(info *NodeChannelInfo, _ int) bool { + return len(info.Channels) < chPerNode + })) + + if len(nodesLackOfChannels) == 0 { + nodesLackOfChannels = toCluster + } + + updates := make(map[int64][]RWChannel) + for i, newChannel := range unassignedChannels { + n := nodesLackOfChannels[i%len(nodesLackOfChannels)].NodeID + updates[n] = append(updates[n], newChannel) + } + + opSet := NewChannelOpSet() + for id, chs := range updates { + opSet.Append(id, Watch, chs...) + opSet.Delete(bufferID, chs...) + } + + log.Info("Assign channels to nodes by channel count", + zap.Int("channel count", len(unassignedChannels)), + zap.Int("cluster count", len(toCluster)), + zap.Int64s("exclusive nodes", execlusiveNodes), + zap.Any("operations", opSet), + zap.Int64s("nodesLackOfChannels", lo.Map(nodesLackOfChannels, func(info *NodeChannelInfo, _ int) int64 { + return info.NodeID + })), + ) + return opSet + } + + if !Params.DataCoordCfg.AutoBalance.GetAsBool() { + log.Info("auto balance disabled") + return nil + } + + // 2. 
balance fromCluster to toCluster if no unassignedChannels + if len(fromCluster) == 0 { + return nil + } + chPerNode := channelNum / nodeToAvg.Len() + if chPerNode == 0 { + return nil + } + + // sort in descending order and reallocate + sort.Slice(fromCluster, func(i, j int) bool { + return len(fromCluster[i].Channels) > len(fromCluster[j].Channels) + }) + + releases := make(map[int64][]RWChannel) + for _, info := range fromCluster { + if len(info.Channels) > chPerNode { + cnt := 0 + for _, ch := range info.Channels { + cnt++ + if cnt > chPerNode { + releases[info.NodeID] = append(releases[info.NodeID], ch) + } + } + } + } + + // Channels in `releases` are reassigned eventually by channel manager. + opSet := NewChannelOpSet() + for k, v := range releases { + if lo.Contains(execlusiveNodes, k) { + opSet.Append(k, Delete, v...) + opSet.Append(bufferID, Watch, v...) + } else { + opSet.Append(k, Release, v...) + } + } + + log.Info("Assign channels to nodes by channel count", + zap.Int64s("exclusive nodes", execlusiveNodes), + zap.Int("channel count", channelNum), + zap.Int("channel per node", chPerNode), + zap.Any("operations", opSet), + zap.Array("fromCluster", fromCluster), + zap.Array("toCluster", toCluster), + ) + + return opSet } diff --git a/internal/datacoord/policy_test.go b/internal/datacoord/policy_test.go index fbd7c9d95d..839228017b 100644 --- a/internal/datacoord/policy_test.go +++ b/internal/datacoord/policy_test.go @@ -17,568 +17,601 @@ package datacoord import ( + "fmt" "testing" - "time" "github.com/samber/lo" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" - memkv "github.com/milvus-io/milvus/internal/kv/mem" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/log" ) -func TestBufferChannelAssignPolicy(t *testing.T) { - kv := memkv.NewMemoryKV() +func TestPolicySuite(t *testing.T) { + 
suite.Run(t, new(PolicySuite)) +} - channels := []RWChannel{getChannel("chan1", 1)} - store := &ChannelStore{ - store: kv, - channelsInfo: map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1), - bufferID: NewNodeChannelInfo(bufferID, channels...), - }, +func getChannel(name string, collID int64) *StateChannel { + return &StateChannel{ + Name: name, + CollectionID: collID, + Info: &datapb.ChannelWatchInfo{}, } - - updates := BufferChannelAssignPolicy(store, 1).Collect() - assert.NotNil(t, updates) - assert.Equal(t, 2, len(updates)) - assert.ElementsMatch(t, - NewChannelOpSet( - NewAddOp(1, channels...), - NewDeleteOp(bufferID, channels...), - ).Collect(), - updates) } -func getChannel(name string, collID int64) *channelMeta { - return &channelMeta{Name: name, CollectionID: collID} +func getChannels(ch2Coll map[string]int64) map[string]RWChannel { + ret := make(map[string]RWChannel) + for k, v := range ch2Coll { + ret[k] = getChannel(k, v) + } + return ret } -func getChannels(ch2Coll map[string]int64) []RWChannel { - return lo.MapToSlice(ch2Coll, func(name string, coll int64) RWChannel { - return &channelMeta{Name: name, CollectionID: coll} +type PolicySuite struct { + suite.Suite + + mockStore *MockRWChannelStore +} + +func (s *PolicySuite) SetupSubTest() { + s.mockStore = NewMockRWChannelStore(s.T()) +} + +func (s *PolicySuite) TestBufferChannelAssignPolicy() { + s.Run("Test no channels in bufferID", func() { + s.mockStore.EXPECT().GetBufferChannelInfo().Return(nil) + + opSet := BufferChannelAssignPolicy(s.mockStore, 1) + s.Nil(opSet) + }) + + s.Run("Test channels remain in bufferID", func() { + ch2Colls := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 2, + } + info := &NodeChannelInfo{NodeID: bufferID, Channels: getChannels(ch2Colls)} + s.mockStore.EXPECT().GetBufferChannelInfo().Return(info) + + var testNodeID int64 = 100 + opSet := BufferChannelAssignPolicy(s.mockStore, testNodeID) + s.NotNil(opSet) + s.Equal(2, opSet.Len()) + s.Equal(3, 
opSet.GetChannelNumber()) + + for _, op := range opSet.Collect() { + s.ElementsMatch([]string{"ch1", "ch2", "ch3"}, op.GetChannelNames()) + s.True(op.Type == Delete || op.Type == Watch) + if op.Type == Delete { + s.EqualValues(bufferID, op.NodeID) + } + + if op.Type == Watch { + s.EqualValues(testNodeID, op.NodeID) + } + } }) } -func TestAverageAssignPolicy(t *testing.T) { - type args struct { - store ROChannelStore - channels []RWChannel +func (s *PolicySuite) TestAvarageAssignPolicy() { + ch2Coll := map[string]int64{ + "ch1": 100, + "ch2": 100, } - tests := []struct { - name string - args args - want *ChannelOpSet - }{ - { - "test assign empty cluster", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{}, - }, - []RWChannel{getChannel("chan1", 1)}, - }, - NewChannelOpSet(NewAddOp(bufferID, getChannel("chan1", 1))), - }, - { - "test watch same channel", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1)), - }, - }, - []RWChannel{getChannel("chan1", 1)}, - }, - NewChannelOpSet(), - }, - { - "test normal assign", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan", 1), getChannel("chan2", 1)), - 2: NewNodeChannelInfo(2, getChannel("chan3", 1)), - }, - }, - []RWChannel{getChannel("chan4", 1)}, - }, - NewChannelOpSet(NewAddOp(2, getChannel("chan4", 1))), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := AverageAssignPolicy(tt.args.store, tt.args.channels) - assert.EqualValues(t, tt.want.Collect(), got.Collect()) + var testNodeID int64 = 9 + + s.Run("no balance after register", func() { + s.mockStore.EXPECT().GetBufferChannelInfo().Return(nil) + s.mockStore.EXPECT().GetNodesChannels().Return([]*NodeChannelInfo{ + {NodeID: testNodeID}, }) - } -} -func TestAvgAssignUnregisteredChannels(t *testing.T) { - type args struct { - store ROChannelStore - nodeID int64 - } - 
tests := []struct { - name string - args args - want *ChannelOpSet - }{ - { - "test deregister the last node", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1)), - }, - }, - 1, - }, - NewChannelOpSet( - NewDeleteOp(1, getChannel("chan1", 1)), - NewAddOp(bufferID, getChannel("chan1", 1)), - ), - }, - { - "test rebalance channels after deregister", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1)), - 2: NewNodeChannelInfo(2, getChannel("chan2", 1)), - 3: NewNodeChannelInfo(3), - }, - }, - 2, - }, - NewChannelOpSet( - NewDeleteOp(2, getChannel("chan2", 1)), - NewAddOp(3, getChannel("chan2", 1)), - ), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := AvgAssignUnregisteredChannels(tt.args.store, tt.args.nodeID) - assert.EqualValues(t, tt.want.Collect(), got.Collect()) - }) - } -} + bufferOp, balanceOp := AvgAssignRegisterPolicy(s.mockStore, testNodeID) + s.Nil(bufferOp) + s.Nil(balanceOp) + }) + s.Run("balance bufferID channels after register", func() { + s.mockStore.EXPECT().GetBufferChannelInfo().Return( + &NodeChannelInfo{NodeID: bufferID, Channels: getChannels(ch2Coll)}, + ) -func TestBgCheckForChannelBalance(t *testing.T) { - type args struct { - channels []*NodeChannelInfo - timestamp time.Time - } + bufferOp, balanceOp := AvgAssignRegisterPolicy(s.mockStore, testNodeID) + s.Nil(balanceOp) + s.NotNil(bufferOp) + s.Equal(2, bufferOp.Len()) + s.Equal(2, bufferOp.GetChannelNumber()) - tests := []struct { - name string - args args - // want []*NodeChannelInfo - want int - wantErr error - }{ - { - "test even distribution", - args{ - []*NodeChannelInfo{ - NewNodeChannelInfo(1, getChannel("chan1", 1), getChannel("chan2", 1)), - NewNodeChannelInfo(2, getChannel("chan1", 2), getChannel("chan2", 2)), - NewNodeChannelInfo(3, getChannel("chan1", 3), getChannel("chan2", 3)), - }, - 
time.Now(), - }, - // there should be no reallocate - 0, - nil, - }, - { - "test uneven with conservative effect", - args{ - []*NodeChannelInfo{ - NewNodeChannelInfo(1, getChannel("chan1", 1), getChannel("chan2", 1)), - NewNodeChannelInfo(2), - }, - time.Now(), - }, - // as we deem that the node having only one channel more than average as even, so there's no reallocation - // for this test case - 0, - nil, - }, - { - "test uneven with zero", - args{ - []*NodeChannelInfo{ - NewNodeChannelInfo(1, getChannel("chan1", 1), getChannel("chan2", 1), getChannel("chan3", 1)), - NewNodeChannelInfo(2), - }, - time.Now(), - }, - 1, - nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - policy := BgBalanceCheck - got, err := policy(tt.args.channels, tt.args.timestamp) - assert.Equal(t, tt.wantErr, err) - assert.EqualValues(t, tt.want, len(got)) - }) - } -} + for _, op := range bufferOp.Collect() { + s.ElementsMatch(lo.Keys(ch2Coll), op.GetChannelNames()) + s.True(op.Type == Delete || op.Type == Watch) + if op.Type == Delete { + s.EqualValues(bufferID, op.NodeID) + } -func TestAvgReassignPolicy(t *testing.T) { - type args struct { - store ROChannelStore - reassigns []*NodeChannelInfo - } - tests := []struct { - name string - args args - want *ChannelOpSet - }{ - { - "test_only_one_node", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1)), - }, - }, - []*NodeChannelInfo{NewNodeChannelInfo(1, getChannel("chan1", 1))}, - }, - // as there's no available nodes except the input node, there's no reassign plan generated - NewChannelOpSet(), - }, - { - "test_zero_avg", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1)), - 2: NewNodeChannelInfo(2), - 3: NewNodeChannelInfo(3), - 4: NewNodeChannelInfo(4), - }, - }, - []*NodeChannelInfo{NewNodeChannelInfo(1, getChannel("chan1", 1))}, - }, - // as we use ceil 
to calculate the wanted average number, there should be one reassign - // though the average num less than 1 - NewChannelOpSet( - NewDeleteOp(1, getChannel("chan1", 1)), - NewAddOp(2, getChannel("chan1", 1)), - ), - }, - { - "test_normal_reassigning_for_one_available_nodes", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1), getChannel("chan2", 1)), - 2: NewNodeChannelInfo(2), - }, - }, - []*NodeChannelInfo{NewNodeChannelInfo(1, getChannel("chan1", 1), getChannel("chan2", 1))}, - }, - NewChannelOpSet( - NewDeleteOp(1, getChannel("chan1", 1), getChannel("chan2", 1)), - NewAddOp(2, getChannel("chan1", 1), getChannel("chan2", 1)), - ), - }, - { - "test_normal_reassigning_for_multiple_available_nodes", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1)), - 2: NewNodeChannelInfo(2), - 3: NewNodeChannelInfo(3), - 4: NewNodeChannelInfo(4), - }, - }, - []*NodeChannelInfo{NewNodeChannelInfo(1, getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1))}, - }, - NewChannelOpSet( - NewDeleteOp(1, []RWChannel{ - getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - }...), - NewAddOp(2, getChannel("chan1", 1)), - NewAddOp(3, getChannel("chan2", 1)), - NewAddOp(4, getChannel("chan3", 1)), - ), - }, - { - "test_reassigning_for_extreme_case", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, - getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1), - getChannel("chan5", 1), - getChannel("chan6", 1), - getChannel("chan7", 1), - getChannel("chan8", 1), - getChannel("chan9", 1), - getChannel("chan10", 1), - getChannel("chan11", 1), - getChannel("chan12", 1)), - 2: NewNodeChannelInfo(2, - 
getChannel("chan13", 1), - getChannel("chan14", 1)), - 3: NewNodeChannelInfo(3, getChannel("chan15", 1)), - 4: NewNodeChannelInfo(4), - }, - }, - []*NodeChannelInfo{NewNodeChannelInfo(1, - getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1), - getChannel("chan5", 1), - getChannel("chan6", 1), - getChannel("chan7", 1), - getChannel("chan8", 1), - getChannel("chan9", 1), - getChannel("chan10", 1), - getChannel("chan11", 1), - getChannel("chan12", 1))}, - }, - NewChannelOpSet( - NewDeleteOp(1, []RWChannel{ - getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1), - getChannel("chan5", 1), - getChannel("chan6", 1), - getChannel("chan7", 1), - getChannel("chan8", 1), - getChannel("chan9", 1), - getChannel("chan10", 1), - getChannel("chan11", 1), - getChannel("chan12", 1), - }...), - NewAddOp(4, []RWChannel{ - getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1), - getChannel("chan5", 1), - }...), - NewAddOp(3, []RWChannel{ - getChannel("chan6", 1), - getChannel("chan7", 1), - getChannel("chan8", 1), - getChannel("chan9", 1), - }...), - NewAddOp(2, []RWChannel{ - getChannel("chan10", 1), - getChannel("chan11", 1), - getChannel("chan12", 1), - }...), - ), - }, - } - for _, tt := range tests { - if tt.name == "test_reassigning_for_extreme_case" || - tt.name == "test_normal_reassigning_for_multiple_available_nodes" { - continue + if op.Type == Watch { + s.EqualValues(testNodeID, op.NodeID) + } } - t.Run(tt.name, func(t *testing.T) { - got := AverageReassignPolicy(tt.args.store, tt.args.reassigns) + log.Info("got bufferOp", zap.Any("op", bufferOp)) + }) - wantMap, gotMap := tt.want.SplitByChannel(), got.SplitByChannel() - assert.ElementsMatch(t, lo.Keys(wantMap), lo.Keys(gotMap)) + s.Run("balance after register", func() { + s.mockStore.EXPECT().GetBufferChannelInfo().Return(nil) + 
s.mockStore.EXPECT().GetNodesChannels().Return([]*NodeChannelInfo{ + {NodeID: 1, Channels: getChannels(ch2Coll)}, + {NodeID: testNodeID}, + }) - for k, opSet := range wantMap { - gotOpSet, ok := gotMap[k] - require.True(t, ok) - assert.ElementsMatch(t, opSet.Collect(), gotOpSet.Collect()) + bufferOp, balanceOp := AvgAssignRegisterPolicy(s.mockStore, testNodeID) + s.Nil(bufferOp) + s.NotNil(balanceOp) + s.Equal(1, balanceOp.Len()) + s.Equal(1, balanceOp.GetChannelNumber()) + + for _, op := range balanceOp.Collect() { + s.Equal(Release, op.Type) + s.EqualValues(1, op.NodeID) + } + log.Info("got balanceOp", zap.Any("op", balanceOp)) + }) +} + +func (s *PolicySuite) TestAverageAssignPolicy() { + ch2Coll := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 2, + } + channels := getChannels(ch2Coll) + + s.Run("no new channels", func() { + s.mockStore.EXPECT().HasChannel(mock.Anything).Return(true) + + opSet := AverageAssignPolicy(s.mockStore, lo.Values(channels)) + s.Nil(opSet) + }) + + s.Run("no datanodes", func() { + s.mockStore.EXPECT().HasChannel(mock.Anything).Return(false) + s.mockStore.EXPECT().GetNodesChannels().Return(nil) + channels := getChannels(ch2Coll) + + opSet := AverageAssignPolicy(s.mockStore, lo.Values(channels)) + s.NotNil(opSet) + s.Equal(1, opSet.Len()) + + op := opSet.Collect()[0] + s.EqualValues(bufferID, op.NodeID) + s.Equal(Watch, op.Type) + s.ElementsMatch(lo.Keys(ch2Coll), op.GetChannelNames()) + }) + + s.Run("one datanode", func() { + // Test three channels assigned one datanode + s.mockStore.EXPECT().HasChannel(mock.Anything).Return(false) + s.mockStore.EXPECT().GetNodesChannels().Return([]*NodeChannelInfo{ + {NodeID: 1, Channels: getChannels(map[string]int64{"channel": 1})}, + }) + channels := getChannels(ch2Coll) + + opSet := AverageAssignPolicy(s.mockStore, lo.Values(channels)) + s.NotNil(opSet) + s.Equal(1, opSet.Len()) + s.Equal(3, opSet.GetChannelNumber()) + + for _, op := range opSet.Collect() { + s.Equal(Watch, op.Type) + + 
s.EqualValues(1, op.NodeID) + s.Equal(3, len(op.GetChannelNames())) + s.ElementsMatch(lo.Keys(ch2Coll), op.GetChannelNames()) + } + log.Info("test OpSet", zap.Any("opset", opSet)) + }) + + s.Run("three datanode", func() { + // Test three channels assigned evenly to three datanodes + s.mockStore.EXPECT().HasChannel(mock.Anything).Return(false) + s.mockStore.EXPECT().GetNodesChannels().Return([]*NodeChannelInfo{ + {NodeID: 1}, + {NodeID: 2}, + {NodeID: 3}, + }) + + opSet := AverageAssignPolicy(s.mockStore, lo.Values(channels)) + s.NotNil(opSet) + s.Equal(3, opSet.Len()) + s.Equal(3, opSet.GetChannelNumber()) + + s.ElementsMatch([]int64{1, 2, 3}, lo.Map(opSet.Collect(), func(op *ChannelOp, _ int) int64 { + return op.NodeID + })) + for _, op := range opSet.Collect() { + s.True(lo.Contains([]int64{1, 2, 3}, op.NodeID)) + s.Equal(1, len(op.GetChannelNames())) + s.Equal(Watch, op.Type) + s.True(lo.Contains(lo.Keys(ch2Coll), op.GetChannelNames()[0])) + } + log.Info("test OpSet", zap.Any("opset", opSet)) + }) +} + +func (s *PolicySuite) TestAvgAssignUnregisteredChannels() { + ch2Coll := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 2, + } + info := &NodeChannelInfo{ + NodeID: 1, + Channels: getChannels(ch2Coll), + } + + s.Run("deregistering last node", func() { + s.mockStore.EXPECT().GetNode(mock.Anything).Return(info) + s.mockStore.EXPECT().GetNodesChannels().Return(nil) + + opSet := AvgAssignUnregisteredChannels(s.mockStore, info.NodeID) + s.NotNil(opSet) + s.Equal(2, opSet.Len()) + + for _, op := range opSet.Collect() { + s.ElementsMatch(lo.Keys(ch2Coll), op.GetChannelNames()) + if op.Type == Delete { + s.EqualValues(info.NodeID, op.NodeID) } + + if op.Type == Watch { + s.EqualValues(bufferID, op.NodeID) + } + } + log.Info("test OpSet", zap.Any("opset", opSet)) + }) + + s.Run("assign channels after deregistering", func() { + s.mockStore.EXPECT().GetNode(mock.Anything).Return(info) + s.mockStore.EXPECT().GetNodesChannels().Return([]*NodeChannelInfo{ + {NodeID: 100}, 
}) + + opSet := AvgAssignUnregisteredChannels(s.mockStore, info.NodeID) + s.NotNil(opSet) + s.Equal(2, opSet.Len()) + for _, op := range opSet.Collect() { + s.ElementsMatch(lo.Keys(ch2Coll), op.GetChannelNames()) + s.True(op.Type == Delete || op.Type == Watch) + if op.Type == Delete { + s.EqualValues(info.NodeID, op.NodeID) + } + + if op.Type == Watch { + s.EqualValues(100, op.NodeID) + } + } + log.Info("test OpSet", zap.Any("opset", opSet)) + }) + + s.Run("test average", func() { + s.mockStore.EXPECT().GetNode(mock.Anything).Return(info) + s.mockStore.EXPECT().GetNodesChannels().Return([]*NodeChannelInfo{ + {NodeID: 100}, + {NodeID: 101}, + {NodeID: 102}, + }) + + opSet := AvgAssignUnregisteredChannels(s.mockStore, info.NodeID) + s.NotNil(opSet) + s.Equal(4, opSet.Len()) + + nodeCh := make(map[int64]string) + for _, op := range opSet.Collect() { + s.True(op.Type == Delete || op.Type == Watch) + if op.Type == Delete { + s.EqualValues(info.NodeID, op.NodeID) + s.ElementsMatch(lo.Keys(ch2Coll), op.GetChannelNames()) + } + + if op.Type == Watch { + s.Equal(1, len(op.GetChannelNames())) + nodeCh[op.NodeID] = op.GetChannelNames()[0] + } + } + + s.ElementsMatch([]int64{100, 101, 102}, lo.Keys(nodeCh)) + s.ElementsMatch(lo.Keys(ch2Coll), lo.Values(nodeCh)) + log.Info("test OpSet", zap.Any("opset", opSet)) + }) +} + +func (s *PolicySuite) TestAvgBalanceChannelPolicy() { + s.Run("test even distribution", func() { + // even distribution should have no results + evenDist := []*NodeChannelInfo{ + {100, getChannels(map[string]int64{"ch1": 1, "ch2": 1})}, + {101, getChannels(map[string]int64{"ch3": 2, "ch4": 2})}, + {102, getChannels(map[string]int64{"ch5": 3, "ch6": 3})}, + } + + opSet := AvgBalanceChannelPolicy(evenDist) + s.Nil(opSet) + }) + s.Run("test uneven with conservative effect", func() { + // as we deem that the node having only one channel more than average as even, so there's no reallocation + // for this test case + // even distribution should have no results + 
uneven := []*NodeChannelInfo{ + {100, getChannels(map[string]int64{"ch1": 1, "ch2": 1})}, + {NodeID: 101}, + } + + opSet := AvgBalanceChannelPolicy(uneven) + s.Nil(opSet) + }) + s.Run("test uneven with zero", func() { + uneven := []*NodeChannelInfo{ + {100, getChannels(map[string]int64{"ch1": 1, "ch2": 1, "ch3": 1})}, + {NodeID: 101}, + } + + opSet := AvgBalanceChannelPolicy(uneven) + s.NotNil(opSet) + s.Equal(1, opSet.Len()) + + for _, op := range opSet.Collect() { + s.Equal(Release, op.Type) + s.EqualValues(100, op.NodeID) + s.Equal(1, len(op.GetChannelNames())) + s.True(lo.Contains([]string{"ch1", "ch2", "ch3"}, op.GetChannelNames()[0])) + } + log.Info("test OpSet", zap.Any("opset", opSet)) + }) +} + +func (s *PolicySuite) TestAvgReassignPolicy() { + s.Run("test only one node", func() { + ch2Coll := map[string]int64{ + "ch1": 1, + "ch2": 1, + "ch3": 2, + "ch4": 2, + "ch5": 3, + } + fiveChannels := getChannels(ch2Coll) + storedInfo := []*NodeChannelInfo{{100, fiveChannels}} + s.mockStore.EXPECT().GetNodesChannels().Return(storedInfo) + + opSet := AverageReassignPolicy(s.mockStore, storedInfo) + s.Nil(opSet) + }) + s.Run("test zero average", func() { + // as we use ceil to calculate the wanted average number, there should be one reassign + // though the average num less than 1 + storedInfo := []*NodeChannelInfo{ + {100, getChannels(map[string]int64{"ch1": 1})}, + {NodeID: 102}, + {NodeID: 103}, + {NodeID: 104}, + } + + s.mockStore.EXPECT().GetNodesChannels().Return(storedInfo) + s.mockStore.EXPECT().GetNodeChannelCount(mock.Anything).RunAndReturn(func(nodeID int64) int { + for _, info := range storedInfo { + if info.NodeID == nodeID { + return len(info.Channels) + } + } + return 0 + }) + + opSet := AverageReassignPolicy(s.mockStore, storedInfo[0:1]) + s.NotNil(opSet) + s.Equal(2, opSet.Len()) + + for _, op := range opSet.Collect() { + s.Equal(1, len(op.GetChannelNames())) + s.Equal("ch1", op.GetChannelNames()[0]) + + s.True(op.Type == Delete || op.Type == Watch) + 
} + log.Info("test OpSet", zap.Any("opset", opSet)) + }) + s.Run("test reassign one to one", func() { + storedInfo := []*NodeChannelInfo{ + {100, getChannels(map[string]int64{"ch1": 1, "ch2": 1, "ch3": 1, "ch4": 1})}, + {NodeID: 101}, + {NodeID: 102}, + } + + s.mockStore.EXPECT().GetNodesChannels().Return(storedInfo) + s.mockStore.EXPECT().GetNodeChannelCount(mock.Anything).RunAndReturn(func(nodeID int64) int { + for _, info := range storedInfo { + if info.NodeID == nodeID { + return len(info.Channels) + } + } + return 0 + }) + + opSet := AverageReassignPolicy(s.mockStore, storedInfo[0:1]) + s.NotNil(opSet) + s.Equal(3, opSet.Len()) + + for _, op := range opSet.Collect() { + s.True(op.Type == Delete || op.Type == Watch) + if op.Type == Delete { + s.ElementsMatch([]string{"ch1", "ch2", "ch3", "ch4"}, op.GetChannelNames()) + s.EqualValues(100, op.NodeID) + } + + if op.Type == Watch { + s.Equal(2, len(op.GetChannelNames())) + s.True(lo.Contains([]int64{102, 101}, op.NodeID)) + } + } + log.Info("test OpSet", zap.Any("opset", opSet)) + }) +} + +type AssignByCountPolicySuite struct { + suite.Suite + + curCluster Assignments +} + +func TestAssignByCountPolicySuite(t *testing.T) { + suite.Run(t, new(AssignByCountPolicySuite)) +} + +func (s *AssignByCountPolicySuite) SetupSubTest() { + s.curCluster = []*NodeChannelInfo{ + {1, getChannels(map[string]int64{"ch-1": 1, "ch2": 1, "ch-2": 2})}, + {2, getChannels(map[string]int64{"ch-3": 1, "ch2": 1, "ch-4": 4})}, + {NodeID: 3, Channels: map[string]RWChannel{}}, } } -func TestAvgBalanceChannelPolicy(t *testing.T) { - type args struct { - store ROChannelStore - } - tests := []struct { - name string - args args - want int - }{ - { - "test_only_one_node", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, - getChannel("chan1", 1), - getChannel("chan2", 1), - getChannel("chan3", 1), - getChannel("chan4", 1), - ), - 2: NewNodeChannelInfo(2), - }, - }, - }, - 1, - }, - } +func (s 
*AssignByCountPolicySuite) TestWithoutUnassignedChannels() { + s.Run("balance without exclusive", func() { + opSet := AvgAssignByCountPolicy(s.curCluster, nil, nil) + s.NotNil(opSet) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := AvgBalanceChannelPolicy(tt.args.store, time.Now()) - assert.EqualValues(t, tt.want, len(got.Collect())) - }) - } -} + s.Equal(2, opSet.GetChannelNumber()) + for _, op := range opSet.Collect() { + s.True(lo.Contains([]int64{1, 2}, op.NodeID)) + } + }) + s.Run("balance with exclusive", func() { + execlusiveNodes := []int64{1, 2} + opSet := AvgAssignByCountPolicy(s.curCluster, nil, execlusiveNodes) + s.NotNil(opSet) -func TestAvgAssignRegisterPolicy(t *testing.T) { - type args struct { - store ROChannelStore - nodeID int64 - } - tests := []struct { - name string - args args - bufferedUpdates *ChannelOpSet - balanceUpdates *ChannelOpSet - exact bool - bufferedUpdatesNum int - balanceUpdatesNum int - }{ - { - "test empty", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1), - }, - }, - 1, - }, - NewChannelOpSet(), - NewChannelOpSet(), - true, - 0, - 0, - }, - { - "test with buffer channel", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - bufferID: NewNodeChannelInfo(bufferID, getChannel("ch1", 1)), - 1: NewNodeChannelInfo(1), - }, - }, - 1, - }, - NewChannelOpSet( - NewDeleteOp(bufferID, getChannel("ch1", 1)), - NewAddOp(1, getChannel("ch1", 1)), - ), - NewChannelOpSet(), - true, - 0, - 0, - }, - { - "test with avg assign", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("ch1", 1), getChannel("ch2", 1)), - 3: NewNodeChannelInfo(3), - }, - }, - 3, - }, - NewChannelOpSet(), - NewChannelOpSet(NewAddOp(1, getChannel("ch1", 1))), - false, - 0, - 1, - }, - { - "test with avg equals to zero", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 
1: NewNodeChannelInfo(1, getChannel("ch1", 1)), - 2: NewNodeChannelInfo(2, getChannel("ch3", 1)), - 3: NewNodeChannelInfo(3), - }, - }, - 3, - }, - NewChannelOpSet(), - NewChannelOpSet(), - true, - 0, - 0, - }, - { - "test_node_with_empty_channel", - args{ - &ChannelStore{ - memkv.NewMemoryKV(), - map[int64]*NodeChannelInfo{ - 1: NewNodeChannelInfo(1, getChannel("ch1", 1), getChannel("ch2", 1), getChannel("ch3", 1)), - 2: NewNodeChannelInfo(2), - 3: NewNodeChannelInfo(3), - }, - }, - 3, - }, - NewChannelOpSet(), - NewChannelOpSet(NewAddOp(1, getChannel("ch1", 1))), - false, - 0, - 1, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bufferedUpdates, balanceUpdates := AvgAssignRegisterPolicy(tt.args.store, tt.args.nodeID) - if tt.exact { - assert.EqualValues(t, tt.bufferedUpdates.Collect(), bufferedUpdates.Collect()) - assert.EqualValues(t, tt.balanceUpdates.Collect(), balanceUpdates.Collect()) + s.Equal(2, opSet.GetChannelNumber()) + for _, op := range opSet.Collect() { + if op.NodeID == bufferID { + s.Equal(Watch, op.Type) } else { - assert.Equal(t, tt.bufferedUpdatesNum, len(bufferedUpdates.Collect())) - assert.Equal(t, tt.balanceUpdatesNum, len(balanceUpdates.Collect())) + s.True(lo.Contains([]int64{1, 2}, op.NodeID)) + s.Equal(Delete, op.Type) } - }) - } + } + }) + s.Run("extreme cases", func() { + m := make(map[string]int64) + for i := 0; i < 100; i++ { + m[fmt.Sprintf("ch-%d", i)] = 1 + } + s.curCluster = []*NodeChannelInfo{ + {NodeID: 1, Channels: getChannels(m)}, + {NodeID: 2, Channels: map[string]RWChannel{}}, + {NodeID: 3, Channels: map[string]RWChannel{}}, + {NodeID: 4, Channels: map[string]RWChannel{}}, + {NodeID: 5, Channels: map[string]RWChannel{}}, + } + + execlusiveNodes := []int64{4, 5} + opSet := AvgAssignByCountPolicy(s.curCluster, nil, execlusiveNodes) + s.NotNil(opSet) + }) +} + +func (s *AssignByCountPolicySuite) TestWithUnassignedChannels() { + s.Run("one unassigned channel", func() { + unassigned := []RWChannel{ + 
getChannel("new-ch-1", 1), + } + + opSet := AvgAssignByCountPolicy(s.curCluster, unassigned, nil) + s.NotNil(opSet) + + s.Equal(1, opSet.GetChannelNumber()) + for _, op := range opSet.Collect() { + if op.NodeID == bufferID { + s.Equal(Delete, op.Type) + } else { + s.EqualValues(3, op.NodeID) + } + } + }) + + s.Run("three unassigned channel", func() { + unassigned := []RWChannel{ + getChannel("new-ch-1", 1), + getChannel("new-ch-2", 1), + getChannel("new-ch-3", 1), + } + + opSet := AvgAssignByCountPolicy(s.curCluster, unassigned, nil) + s.NotNil(opSet) + + s.Equal(3, opSet.GetChannelNumber()) + for _, op := range opSet.Collect() { + if op.NodeID == bufferID { + s.Equal(Delete, op.Type) + } + } + s.Equal(2, opSet.Len()) + + nodeIDs := lo.FilterMap(opSet.Collect(), func(op *ChannelOp, _ int) (int64, bool) { + return op.NodeID, op.NodeID != bufferID + }) + s.ElementsMatch([]int64{3}, nodeIDs) + }) + + s.Run("three unassigned channel with execlusiveNodes", func() { + unassigned := []RWChannel{ + getChannel("new-ch-1", 1), + getChannel("new-ch-2", 1), + getChannel("new-ch-3", 1), + } + + opSet := AvgAssignByCountPolicy(s.curCluster, unassigned, []int64{1, 2}) + s.NotNil(opSet) + + s.Equal(3, opSet.GetChannelNumber()) + for _, op := range opSet.Collect() { + if op.NodeID == bufferID { + s.Equal(Delete, op.Type) + } + } + s.Equal(2, opSet.Len()) + + nodeIDs := lo.FilterMap(opSet.Collect(), func(op *ChannelOp, _ int) (int64, bool) { + return op.NodeID, op.NodeID != bufferID + }) + s.ElementsMatch([]int64{3}, nodeIDs) + }) + s.Run("67 unassigned with 33 in node1, none in node2,3", func() { + var unassigned []RWChannel + m1 := make(map[string]int64) + for i := 0; i < 33; i++ { + m1[fmt.Sprintf("ch-%d", i)] = 1 + } + for i := 33; i < 100; i++ { + unassigned = append(unassigned, getChannel(fmt.Sprintf("ch-%d", i), 1)) + } + s.curCluster = []*NodeChannelInfo{ + {NodeID: 1, Channels: getChannels(m1)}, + {NodeID: 2, Channels: map[string]RWChannel{}}, + {NodeID: 3, Channels: 
map[string]RWChannel{}}, + } + + opSet := AvgAssignByCountPolicy(s.curCluster, unassigned, nil) + s.NotNil(opSet) + + s.Equal(67, opSet.GetChannelNumber()) + for _, op := range opSet.Collect() { + if op.NodeID == bufferID { + s.Equal(Delete, op.Type) + } + } + s.Equal(4, opSet.Len()) + + nodeIDs := lo.FilterMap(opSet.Collect(), func(op *ChannelOp, _ int) (int64, bool) { + return op.NodeID, op.NodeID != bufferID + }) + s.ElementsMatch([]int64{3, 2}, nodeIDs) + }) } diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index de0a8a453f..abe346ceee 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -343,6 +343,7 @@ func (s *Server) initDataCoord() error { log.Info("init rootcoord client done") s.broker = broker.NewCoordinatorBroker(s.rootCoordClient) + s.allocator = newRootCoordAllocator(s.rootCoordClient) storageCli, err := s.newChunkManagerFactory() if err != nil { @@ -364,8 +365,6 @@ func (s *Server) initDataCoord() error { } log.Info("init datanode cluster done") - s.allocator = newRootCoordAllocator(s.rootCoordClient) - s.initIndexNodeManager() if err = s.initServiceDiscovery(); err != nil { @@ -466,6 +465,13 @@ func (s *Server) startDataCoord() { sessionutil.SaveServerInfo(typeutil.DataCoordRole, s.session.GetServerID()) } +func (s *Server) GetServerID() int64 { + if s.session != nil { + return s.session.GetServerID() + } + return paramtable.GetNodeID() +} + func (s *Server) afterStart() {} func (s *Server) initCluster() error { @@ -473,13 +479,20 @@ func (s *Server) initCluster() error { return nil } - var err error - s.channelManager, err = NewChannelManager(s.watchClient, s.handler, withMsgstreamFactory(s.factory), - withStateChecker(), withBgChecker()) - if err != nil { - return err - } s.sessionManager = NewSessionManagerImpl(withSessionCreator(s.dataNodeCreator)) + + var err error + if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() { + s.channelManager, err = 
NewChannelManagerV2(s.watchClient, s.handler, s.sessionManager, s.allocator, withCheckerV2()) + if err != nil { + return err + } + } else { + s.channelManager, err = NewChannelManager(s.watchClient, s.handler, withMsgstreamFactory(s.factory), withStateChecker(), withBgChecker()) + if err != nil { + return err + } + } s.cluster = NewClusterImpl(s.sessionManager, s.channelManager) return nil } @@ -559,11 +572,21 @@ func (s *Server) initServiceDiscovery() error { log.Info("DataCoord success to get DataNode sessions", zap.Any("sessions", sessions)) datanodes := make([]*NodeInfo, 0, len(sessions)) + legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue()) + if err != nil { + log.Warn("DataCoord failed to init service discovery", zap.Error(err)) + } + for _, session := range sessions { info := &NodeInfo{ NodeID: session.ServerID, Address: session.Address, } + + if session.Version.LTE(legacyVersion) { + info.IsLegacy = true + } + datanodes = append(datanodes, info) } diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 69e641aac3..349c498617 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -2424,7 +2424,7 @@ func TestOptions(t *testing.T) { defer kv.RemoveWithPrefix("") sessionManager := NewSessionManagerImpl() - channelManager, err := NewChannelManager(kv, newMockHandler()) + channelManager, err := NewChannelManagerV2(kv, newMockHandler(), sessionManager, newMockAllocator()) assert.NoError(t, err) cluster := NewClusterImpl(sessionManager, channelManager) @@ -2479,7 +2479,7 @@ func TestHandleSessionEvent(t *testing.T) { defer cancel() sessionManager := NewSessionManagerImpl() - channelManager, err := NewChannelManager(kv, newMockHandler(), withFactory(&mockPolicyFactory{})) + channelManager, err := NewChannelManagerV2(kv, newMockHandler(), sessionManager, newMockAllocator(), withFactoryV2(&mockPolicyFactory{})) assert.NoError(t, err) cluster := 
NewClusterImpl(sessionManager, channelManager) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 6535af3a20..22e4a2d585 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1249,20 +1249,14 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq }, nil } for _, channelName := range req.GetChannelNames() { - ch := &channelMeta{ - Name: channelName, - CollectionID: req.GetCollectionID(), - StartPositions: req.GetStartPositions(), - Schema: req.GetSchema(), - CreateTimestamp: req.GetCreateTimestamp(), - } + ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp()) err := s.channelManager.Watch(ctx, ch) if err != nil { log.Warn("fail to watch channelName", zap.Error(err)) resp.Status = merr.Status(err) return resp, nil } - if err := s.meta.catalog.MarkChannelAdded(ctx, ch.Name); err != nil { + if err := s.meta.catalog.MarkChannelAdded(ctx, channelName); err != nil { // TODO: add background task to periodically cleanup the orphaned channel add marks. 
log.Error("failed to mark channel added", zap.Error(err)) resp.Status = merr.Status(err) diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 6e573239b7..4db96c9a05 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -50,7 +50,7 @@ func WithChannelManager(cm ChannelManager) Option { func (s *ServerSuite) SetupTest() { s.mockChMgr = NewMockChannelManager(s.T()) - s.mockChMgr.EXPECT().Startup(mock.Anything, mock.Anything).Return(nil).Maybe() + s.mockChMgr.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() s.mockChMgr.EXPECT().Close().Maybe() s.testServer = newTestServer(s.T(), WithChannelManager(s.mockChMgr)) diff --git a/internal/datacoord/session.go b/internal/datacoord/session.go index 115209bc40..e3393281df 100644 --- a/internal/datacoord/session.go +++ b/internal/datacoord/session.go @@ -30,8 +30,9 @@ var errDisposed = errors.New("client is disposed") // NodeInfo contains node base info type NodeInfo struct { - NodeID int64 - Address string + NodeID int64 + Address string + IsLegacy bool } // Session contains session info of a node diff --git a/internal/datacoord/session_manager_test.go b/internal/datacoord/session_manager_test.go index ab26c7efdc..da6b20dc71 100644 --- a/internal/datacoord/session_manager_test.go +++ b/internal/datacoord/session_manager_test.go @@ -35,7 +35,7 @@ func (s *SessionManagerSuite) SetupTest() { return s.dn, nil })) - s.m.AddSession(&NodeInfo{1000, "addr-1"}) + s.m.AddSession(&NodeInfo{1000, "addr-1", true}) s.MetricsEqual(metrics.DataCoordNumDataNodes, 1) } diff --git a/internal/datanode/channel_manager.go b/internal/datanode/channel_manager.go index 5bb7ad86de..6f6242c726 100644 --- a/internal/datanode/channel_manager.go +++ b/internal/datanode/channel_manager.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" + 
"github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -52,24 +53,22 @@ type ChannelManagerImpl struct { releaseFunc releaseFunc - closeCh chan struct{} - closeOnce sync.Once + closeCh lifetime.SafeChan closeWaiter sync.WaitGroup } func NewChannelManager(dn *DataNode) *ChannelManagerImpl { - fm := newFlowgraphManager() cm := ChannelManagerImpl{ dn: dn, - fgManager: fm, + fgManager: dn.flowgraphManager, communicateCh: make(chan *opState, 100), opRunners: typeutil.NewConcurrentMap[string, *opRunner](), abnormals: typeutil.NewConcurrentMap[int64, string](), - releaseFunc: fm.RemoveFlowgraph, + releaseFunc: dn.flowgraphManager.RemoveFlowgraph, - closeCh: make(chan struct{}), + closeCh: lifetime.NewSafeChan(), } return &cm @@ -131,14 +130,14 @@ func (m *ChannelManagerImpl) GetProgress(info *datapb.ChannelWatchInfo) *datapb. } func (m *ChannelManagerImpl) Close() { - m.closeOnce.Do(func() { + if m.opRunners != nil { m.opRunners.Range(func(channel string, runner *opRunner) bool { runner.Close() return true }) - close(m.closeCh) - m.closeWaiter.Wait() - }) + } + m.closeCh.Close() + m.closeWaiter.Wait() } func (m *ChannelManagerImpl) Start() { @@ -150,7 +149,7 @@ func (m *ChannelManagerImpl) Start() { select { case opState := <-m.communicateCh: m.handleOpState(opState) - case <-m.closeCh: + case <-m.closeCh.CloseCh(): log.Info("DataNode ChannelManager exit") return } @@ -170,23 +169,19 @@ func (m *ChannelManagerImpl) handleOpState(opState *opState) { case datapb.ChannelWatchState_WatchSuccess: log.Info("Success to watch") m.fgManager.AddFlowgraph(opState.fg) - m.finishOp(opState.opID, opState.channel) case datapb.ChannelWatchState_WatchFailure: log.Info("Fail to watch") - m.finishOp(opState.opID, opState.channel) case datapb.ChannelWatchState_ReleaseSuccess: log.Info("Success to release") - m.finishOp(opState.opID, opState.channel) - m.destoryRunner(opState.channel) case 
datapb.ChannelWatchState_ReleaseFailure: log.Info("Fail to release, add channel to abnormal lists") m.abnormals.Insert(opState.opID, opState.channel) - m.finishOp(opState.opID, opState.channel) - m.destoryRunner(opState.channel) } + + m.finishOp(opState.opID, opState.channel) } func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner { @@ -197,15 +192,10 @@ func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner { return runner } -func (m *ChannelManagerImpl) destoryRunner(channel string) { - if runner, loaded := m.opRunners.GetAndRemove(channel); loaded { - runner.Close() - } -} - func (m *ChannelManagerImpl) finishOp(opID int64, channel string) { - if runner, loaded := m.opRunners.Get(channel); loaded { + if runner, loaded := m.opRunners.GetAndRemove(channel); loaded { runner.FinishOp(opID) + runner.Close() } } @@ -223,9 +213,8 @@ type opRunner struct { opsInQueue chan *datapb.ChannelWatchInfo resultCh chan *opState - closeWg sync.WaitGroup - closeOnce sync.Once - closeCh chan struct{} + closeCh lifetime.SafeChan + closeWg sync.WaitGroup } func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opState) *opRunner { @@ -236,7 +225,7 @@ func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opS opsInQueue: make(chan *datapb.ChannelWatchInfo, 10), allOps: make(map[UniqueID]*opInfo), resultCh: resultCh, - closeCh: make(chan struct{}), + closeCh: lifetime.NewSafeChan(), } } @@ -248,7 +237,7 @@ func (r *opRunner) Start() { select { case info := <-r.opsInQueue: r.NotifyState(r.Execute(info)) - case <-r.closeCh: + case <-r.closeCh.CloseCh(): return } } @@ -301,7 +290,7 @@ func (r *opRunner) Execute(info *datapb.ChannelWatchInfo) *opState { } // ToRelease state - return releaseWithTimer(r.releaseFunc, info.GetVchan().GetChannelName(), info.GetOpID()) + return r.releaseWithTimer(r.releaseFunc, info.GetVchan().GetChannelName(), info.GetOpID()) } // watchWithTimer will return WatchFailure after 
WatchTimeoutInterval @@ -314,13 +303,13 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { r.guard.Lock() opInfo, ok := r.allOps[info.GetOpID()] + r.guard.Unlock() if !ok { opState.state = datapb.ChannelWatchState_WatchFailure return opState } tickler := newTickler() opInfo.tickler = tickler - r.guard.Unlock() var ( successSig = make(chan struct{}, 1) @@ -348,6 +337,13 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { log.Info("Stop timer for ToWatch operation timeout") return + case <-r.closeCh.CloseCh(): + // runner closed from outside + tickler.close() + cancel() + log.Info("Suspend ToWatch operation from outside of opRunner") + return + case <-tickler.progressSig: log.Info("Reset timer for tickler updated") timer.Reset(watchTimeout) @@ -379,7 +375,7 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { } // releaseWithTimer will return ReleaseFailure after WatchTimeoutInterval -func releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *opState { +func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *opState { opState := &opState{ channel: channel, opID: opID, @@ -389,23 +385,29 @@ func releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *o waiter sync.WaitGroup ) - log := log.With(zap.String("channel", channel)) + log := log.With(zap.Int64("opID", opID), zap.String("channel", channel)) startTimer := func(wg *sync.WaitGroup) { defer wg.Done() releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) timer := time.NewTimer(releaseTimeout) defer timer.Stop() - log.Info("Start timer for ToRelease operation", zap.Duration("timeout", releaseTimeout)) + log := log.With(zap.Duration("timeout", releaseTimeout)) + log.Info("Start ToRelease timer") for { select { case <-timer.C: - log.Info("Stop timer for ToRelease operation timeout", zap.Duration("timeout", releaseTimeout)) + 
log.Info("Stop timer for ToRelease operation timeout") opState.state = datapb.ChannelWatchState_ReleaseFailure return + case <-r.closeCh.CloseCh(): + // runner closed from outside + log.Info("Stop timer for opRunner closed") + return + case <-successSig: - log.Info("Stop timer for ToRelease operation succeeded", zap.Duration("timeout", releaseTimeout)) + log.Info("Stop timer for ToRelease operation succeeded") opState.state = datapb.ChannelWatchState_ReleaseSuccess return } @@ -436,18 +438,8 @@ func (r *opRunner) NotifyState(state *opState) { } func (r *opRunner) Close() { - r.guard.Lock() - for _, info := range r.allOps { - if info.tickler != nil { - info.tickler.close() - } - } - r.guard.Unlock() - - r.closeOnce.Do(func() { - close(r.closeCh) - r.closeWg.Wait() - }) + r.closeCh.Close() + r.closeWg.Wait() } type opState struct { diff --git a/internal/datanode/channel_manager_test.go b/internal/datanode/channel_manager_test.go index 71c2830773..8a6b2f441d 100644 --- a/internal/datanode/channel_manager_test.go +++ b/internal/datanode/channel_manager_test.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/datanode/allocator" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -34,6 +35,44 @@ func TestChannelManagerSuite(t *testing.T) { suite.Run(t, new(ChannelManagerSuite)) } +func TestOpRunnerSuite(t *testing.T) { + suite.Run(t, new(OpRunnerSuite)) +} + +func (s *OpRunnerSuite) SetupTest() { + ctx := context.Background() + s.mockAlloc = allocator.NewMockAllocator(s.T()) + + s.node = newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) + s.node.allocator = s.mockAlloc +} + +func (s *OpRunnerSuite) TestWatchWithTimer() { + var ( + channel string = "ch-1" + commuCh = make(chan *opState) + ) + info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) + mockReleaseFunc := func(channel string) { + 
log.Info("mock release func") + } + runner := NewOpRunner(channel, s.node, mockReleaseFunc, commuCh) + err := runner.Enqueue(info) + s.Require().NoError(err) + + opState := runner.watchWithTimer(info) + s.NotNil(opState.fg) + s.Equal(channel, opState.channel) + + runner.FinishOp(100) +} + +type OpRunnerSuite struct { + suite.Suite + node *DataNode + mockAlloc *allocator.MockAllocator +} + type ChannelManagerSuite struct { suite.Suite @@ -45,6 +84,8 @@ func (s *ChannelManagerSuite) SetupTest() { ctx := context.Background() s.node = newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) s.node.allocator = allocator.NewMockAllocator(s.T()) + s.node.flowgraphManager = newFlowgraphManager() + s.manager = NewChannelManager(s.node) } @@ -80,7 +121,9 @@ func getWatchInfoByOpID(opID UniqueID, channel string, state datapb.ChannelWatch } func (s *ChannelManagerSuite) TearDownTest() { - s.manager.Close() + if s.manager != nil { + s.manager.Close() + } } func (s *ChannelManagerSuite) TestWatchFail() { @@ -167,11 +210,12 @@ func (s *ChannelManagerSuite) TestSubmitIdempotent() { func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() { channel := "by-dev-rootcoord-dml-0" + // watch info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) - err := s.manager.Submit(info) s.NoError(err) + // wait for result opState := <-s.manager.communicateCh s.NotNil(opState) s.Equal(datapb.ChannelWatchState_WatchSuccess, opState.state) @@ -184,8 +228,8 @@ func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() { s.manager.handleOpState(opState) s.Equal(1, s.manager.fgManager.GetFlowgraphCount()) - s.True(s.manager.opRunners.Contain(info.GetVchan().GetChannelName())) - s.Equal(1, s.manager.opRunners.Len()) + s.False(s.manager.opRunners.Contain(info.GetVchan().GetChannelName())) + s.Equal(0, s.manager.opRunners.Len()) resp = s.manager.GetProgress(info) s.Equal(info.GetOpID(), resp.GetOpID()) @@ -193,10 +237,10 @@ func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() { // 
release info = getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease) - err = s.manager.Submit(info) s.NoError(err) + // wait for result opState = <-s.manager.communicateCh s.NotNil(opState) s.Equal(datapb.ChannelWatchState_ReleaseSuccess, opState.state) diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 8818c0d497..234c3ec9f9 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -83,7 +83,6 @@ var Params *paramtable.ComponentParam = paramtable.Get() // `segmentCache` stores all flushing and flushed segments. type DataNode struct { ctx context.Context - serverID int64 cancel context.CancelFunc Role string stateCode atomic.Value // commonpb.StateCode_Initializing @@ -129,7 +128,7 @@ type DataNode struct { } // NewDataNode will return a DataNode with abnormal state. -func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64) *DataNode { +func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode { rand.Seed(time.Now().UnixNano()) ctx2, cancel2 := context.WithCancel(ctx) node := &DataNode{ @@ -140,13 +139,10 @@ func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64 rootCoord: nil, dataCoord: nil, factory: factory, - serverID: serverID, segmentCache: newCache(), compactionExecutor: newCompactionExecutor(), - eventManager: NewEventManager(), - flowgraphManager: newFlowgraphManager(), - clearSignal: make(chan string, 100), + clearSignal: make(chan string, 100), reportImportRetryTimes: 10, } @@ -228,10 +224,10 @@ func (node *DataNode) initRateCollector() error { } func (node *DataNode) GetNodeID() int64 { - if node.serverID == 0 && node.session != nil { + if node.session != nil { return node.session.ServerID } - return node.serverID + return paramtable.GetNodeID() } func (node *DataNode) Init() error { @@ -294,6 +290,13 @@ func (node *DataNode) Init() error { node.importTaskMgr = importv2.NewTaskManager() node.importScheduler = 
importv2.NewScheduler(node.importTaskMgr, node.syncMgr, node.chunkManager) node.channelCheckpointUpdater = newChannelCheckpointUpdater(node) + node.flowgraphManager = newFlowgraphManager() + + if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() { + node.channelManager = NewChannelManager(node) + } else { + node.eventManager = NewEventManager() + } log.Info("init datanode done", zap.String("Address", node.address)) }) @@ -322,9 +325,15 @@ func (node *DataNode) handleChannelEvt(evt *clientv3.Event) { // tryToReleaseFlowgraph tries to release a flowgraph func (node *DataNode) tryToReleaseFlowgraph(channel string) { log.Info("try to release flowgraph", zap.String("channel", channel)) - node.compactionExecutor.discardPlan(channel) - node.flowgraphManager.RemoveFlowgraph(channel) - node.writeBufferManager.RemoveChannel(channel) + if node.compactionExecutor != nil { + node.compactionExecutor.discardPlan(channel) + } + if node.flowgraphManager != nil { + node.flowgraphManager.RemoveFlowgraph(channel) + } + if node.writeBufferManager != nil { + node.writeBufferManager.RemoveChannel(channel) + } } // BackGroundGC runs in background to release datanode resources @@ -398,8 +407,12 @@ func (node *DataNode) Start() error { go node.channelCheckpointUpdater.start() - // Start node watch node - node.startWatchChannelsAtBackground(node.ctx) + if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() { + node.channelManager.Start() + } else { + // Start node watch node + node.startWatchChannelsAtBackground(node.ctx) + } node.UpdateStateCode(commonpb.StateCode_Healthy) }) @@ -433,9 +446,13 @@ func (node *DataNode) Stop() error { node.stopOnce.Do(func() { // https://github.com/milvus-io/milvus/issues/12282 node.UpdateStateCode(commonpb.StateCode_Abnormal) + if node.channelManager != nil { + node.channelManager.Close() + } - node.flowgraphManager.Close() - node.eventManager.CloseAll() + if node.eventManager != nil { + node.eventManager.CloseAll() + } 
if node.writeBufferManager != nil { node.writeBufferManager.Stop() @@ -466,6 +483,7 @@ func (node *DataNode) Stop() error { node.importScheduler.Close() } + // Delay the cancellation of ctx to ensure that the session is automatically recycled after closed the flow graph node.cancel() node.stopWaiter.Wait() }) diff --git a/internal/datanode/event_manager.go b/internal/datanode/event_manager.go index 5fda6f0f2c..4d0ceea6c2 100644 --- a/internal/datanode/event_manager.go +++ b/internal/datanode/event_manager.go @@ -92,7 +92,7 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) { // serves the corner case for etcd connection lost and missing some events func (node *DataNode) checkWatchedList() error { // REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name} - prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.serverID)) + prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetNodeID())) keys, values, err := node.watchKv.LoadWithPrefix(prefix) if err != nil { return err diff --git a/internal/datanode/event_manager_test.go b/internal/datanode/event_manager_test.go index 177626d7b3..41005cde3e 100644 --- a/internal/datanode/event_manager_test.go +++ b/internal/datanode/event_manager_test.go @@ -42,6 +42,9 @@ import ( func TestWatchChannel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) etcdCli, err := etcd.GetEtcdClient( Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index c5afe6cd5a..ab9a99ad8f 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -83,7 +83,7 @@ var segID2SegInfo = 
map[int64]*datapb.SegmentInfo{ func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode { factory := dependency.NewDefaultFactory(true) - node := NewDataNode(ctx, factory, 1) + node := NewDataNode(ctx, factory) node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID()) diff --git a/internal/datanode/services.go b/internal/datanode/services.go index b90fc98722..1b66f157f9 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -333,6 +333,11 @@ func (node *DataNode) NotifyChannelOperation(ctx context.Context, req *datapb.Ch log.Ctx(ctx).Info("DataNode receives NotifyChannelOperation", zap.Int("operation count", len(req.GetInfos()))) + if node.channelManager == nil { + log.Warn("DataNode NotifyChannelOperation failed due to nil channelManager") + return merr.Status(merr.WrapErrServiceInternal("channelManager is nil! Ignore if you are upgrading datanode/coord to rpc based watch")), nil + } + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.NotifyChannelOperation failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err)) return merr.Status(err), nil @@ -356,6 +361,14 @@ func (node *DataNode) CheckChannelOperationProgress(ctx context.Context, req *da ) log.Info("DataNode receives CheckChannelOperationProgress") + + if node.channelManager == nil { + log.Warn("DataNode CheckChannelOperationProgress failed due to nil channelManager") + return &datapb.ChannelOperationProgressResponse{ + Status: merr.Status(merr.WrapErrServiceInternal("channelManager is nil! 
Ignore if you are upgrading datanode/coord to rpc based watch")), + }, nil + } + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.CheckChannelOperationProgress failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err)) return &datapb.ChannelOperationProgressResponse{ diff --git a/internal/datanode/services_test.go b/internal/datanode/services_test.go index 16d4157089..1393f2a57d 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -635,40 +635,6 @@ func (s *DataNodeServicesSuite) TestResendSegmentStats() { s.Assert().True(merr.Ok(resp.GetStatus()), "empty call, status shall be OK") } -/* -func (s *DataNodeServicesSuite) TestFlushChannels() { - dmChannelName := "fake-by-dev-rootcoord-dml-channel-TestFlushChannels" - - vChan := &datapb.VchannelInfo{ - CollectionID: 1, - ChannelName: dmChannelName, - UnflushedSegmentIds: []int64{}, - FlushedSegmentIds: []int64{}, - } - - err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, vChan, nil, genTestTickler()) - s.Require().NoError(err) - - fgService, ok := s.node.flowgraphManager.getFlowgraphService(dmChannelName) - s.Require().True(ok) - - flushTs := Timestamp(100) - - req := &datapb.FlushChannelsRequest{ - Base: &commonpb.MsgBase{ - TargetID: s.node.GetSession().ServerID, - }, - FlushTs: flushTs, - Channels: []string{dmChannelName}, - } - - status, err := s.node.FlushChannels(s.ctx, req) - s.Assert().NoError(err) - s.Assert().True(merr.Ok(status)) - - s.Assert().True(fgService.channel.getFlushTs() == flushTs) -}*/ - func (s *DataNodeServicesSuite) TestRPCWatch() { s.Run("node not healthy", func() { s.SetupTest() @@ -686,22 +652,16 @@ func (s *DataNodeServicesSuite) TestRPCWatch() { s.ErrorIs(merr.Error(status), merr.ErrServiceNotReady) }) - s.Run("node healthy", func() { + s.Run("submit error", func() { s.SetupTest() - mockChManager := NewMockChannelManager(s.T()) - s.node.channelManager = mockChManager - 
mockChManager.EXPECT().Submit(mock.Anything).Return(nil).Once() ctx := context.Background() status, err := s.node.NotifyChannelOperation(ctx, &datapb.ChannelOperationsRequest{Infos: []*datapb.ChannelWatchInfo{{OpID: 19530}}}) s.NoError(err) - s.True(merr.Ok(status)) - - mockChManager.EXPECT().GetProgress(mock.Anything).Return( - &datapb.ChannelOperationProgressResponse{Status: merr.Status(nil)}, - ).Once() + s.False(merr.Ok(status)) + s.NotErrorIs(merr.Error(status), merr.ErrServiceNotReady) resp, err := s.node.CheckChannelOperationProgress(ctx, nil) s.NoError(err) - s.True(merr.Ok(resp.GetStatus())) + s.False(merr.Ok(resp.GetStatus())) }) } diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 387017aaf0..0b8aedce31 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -180,7 +180,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { interceptor.ClusterValidationUnaryServerInterceptor(), interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 { if s.serverID.Load() == 0 { - s.serverID.Store(paramtable.GetNodeID()) + s.serverID.Store(s.dataCoord.(*datacoord.Server).GetServerID()) } return s.serverID.Load() }), @@ -191,7 +191,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { interceptor.ClusterValidationStreamServerInterceptor(), interceptor.ServerIDValidationStreamServerInterceptor(func() int64 { if s.serverID.Load() == 0 { - s.serverID.Store(paramtable.GetNodeID()) + s.serverID.Store(s.dataCoord.(*datacoord.Server).GetServerID()) } return s.serverID.Load() }), diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index dc240d695c..5bbf12224e 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -91,7 +91,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) } s.serverID.Store(paramtable.GetNodeID()) - s.datanode = 
dn.NewDataNode(s.ctx, s.factory, s.serverID.Load()) + s.datanode = dn.NewDataNode(s.ctx, s.factory) return s, nil } diff --git a/pkg/util/conc/pool.go b/pkg/util/conc/pool.go index 8c6c1fb25c..f042dc04b2 100644 --- a/pkg/util/conc/pool.go +++ b/pkg/util/conc/pool.go @@ -81,9 +81,8 @@ func (pool *Pool[T]) Submit(method func() (T, error)) *Future[T] { res, err := method() if err != nil { future.err = err - } else { - future.value = res } + future.value = res }) if err != nil { future.err = err diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 00876612bc..97f8b39378 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -2601,6 +2601,8 @@ user-task-polling: type dataCoordConfig struct { // --- CHANNEL --- WatchTimeoutInterval ParamItem `refreshable:"false"` + EnableBalanceChannelWithRPC ParamItem `refreshable:"false"` + LegacyVersionWithoutRPCWatch ParamItem `refreshable:"false"` ChannelBalanceSilentDuration ParamItem `refreshable:"true"` ChannelBalanceInterval ParamItem `refreshable:"true"` ChannelCheckInterval ParamItem `refreshable:"true"` @@ -2692,6 +2694,24 @@ func (p *dataCoordConfig) init(base *BaseTable) { } p.WatchTimeoutInterval.Init(base.mgr) + p.EnableBalanceChannelWithRPC = ParamItem{ + Key: "dataCoord.channel.balanceWithRpc", + Version: "2.4.0", + DefaultValue: "true", + Doc: "Whether to enable balance with RPC, default to use etcd watch", + Export: true, + } + p.EnableBalanceChannelWithRPC.Init(base.mgr) + + p.LegacyVersionWithoutRPCWatch = ParamItem{ + Key: "dataCoord.channel.legacyVersionWithoutRPCWatch", + Version: "2.4.0", + DefaultValue: "2.4.0", + Doc: "Datanodes <= this version are considered as legacy nodes, which doesn't have rpc based watch(). 
This is only used during rolling upgrade where legacy nodes won't get new channels", + Export: true, + } + p.LegacyVersionWithoutRPCWatch.Init(base.mgr) + p.ChannelBalanceSilentDuration = ParamItem{ Key: "dataCoord.channel.balanceSilentDuration", Version: "2.2.3", @@ -2713,7 +2733,7 @@ func (p *dataCoordConfig) init(base *BaseTable) { p.ChannelCheckInterval = ParamItem{ Key: "dataCoord.channel.checkInterval", Version: "2.4.0", - DefaultValue: "10", + DefaultValue: "1", Doc: "The interval in seconds with which the channel manager advances channel states", Export: true, } diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go index ee674d41aa..626cbbdafa 100644 --- a/tests/integration/minicluster_v2.go +++ b/tests/integration/minicluster_v2.go @@ -414,17 +414,19 @@ func (cluster *MiniClusterV2) StopAllQueryNodes() { for _, node := range cluster.querynodes { node.Stop() } - log.Info(fmt.Sprintf("mini cluster stoped %d extra querynode", numExtraQN)) + cluster.querynodes = nil + log.Info(fmt.Sprintf("mini cluster stopped %d extra querynode", numExtraQN)) } func (cluster *MiniClusterV2) StopAllDataNodes() { cluster.DataNode.Stop() log.Info("mini cluster main dataNode stopped") - numExtraQN := len(cluster.datanodes) + numExtraDN := len(cluster.datanodes) for _, node := range cluster.datanodes { node.Stop() } - log.Info(fmt.Sprintf("mini cluster stoped %d extra datanode", numExtraQN)) + cluster.datanodes = nil + log.Info(fmt.Sprintf("mini cluster stopped %d extra datanode", numExtraDN)) } func (cluster *MiniClusterV2) GetContext() context.Context { diff --git a/tests/integration/watchcompatibility/watch_test.go b/tests/integration/watchcompatibility/watch_test.go new file mode 100644 index 0000000000..0a0fc3674f --- /dev/null +++ b/tests/integration/watchcompatibility/watch_test.go @@ -0,0 +1,365 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package watchcompatibility + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord" + grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +type DataNodeCompatibility struct { + integration.MiniClusterSuite + maxGoRoutineNum int + dim int + numCollections int + rowsPerCollection int + waitTimeInSec time.Duration + prefix string +} + +func (s *DataNodeCompatibility) setupParam() { + s.maxGoRoutineNum = 100 + s.dim = 128 + s.numCollections = 1 + s.rowsPerCollection = 100 + s.waitTimeInSec = time.Second * 1 +} + +func (s *DataNodeCompatibility) flush(collectionName string) { + c := s.Cluster + flushResp, err := c.Proxy.Flush(context.TODO(), 
&milvuspb.FlushRequest{ + DbName: "", + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + s.Require().True(has) + s.Require().NotEmpty(segmentIDs) + ids := segmentIDs.GetData() + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + + segments, err := c.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + s.WaitForFlush(context.TODO(), ids, flushTs, "", collectionName) +} + +func (s *DataNodeCompatibility) loadCollection(collectionName string) { + c := s.Cluster + dbName := "" + schema := integration.ConstructSchema(collectionName, s.dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + }) + s.NoError(err) + + err = merr.Error(createCollectionStatus) + s.NoError(err) + + showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.GetStatus())) + + batchSize := 500000 + for start := 0; start < s.rowsPerCollection; start += batchSize { + rowNum := batchSize + if start+batchSize > s.rowsPerCollection { + rowNum = s.rowsPerCollection - start + } + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, s.dim) + hashKeys := integration.GenerateHashKeys(rowNum) + insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + } + s.flush(collectionName) + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), 
&milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexFaissIvfFlat, metric.IP), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + s.NoError(err) + s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField) + + // load + loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + s.NoError(err) + s.WaitForLoad(context.TODO(), collectionName) +} + +func (s *DataNodeCompatibility) checkCollections() bool { + req := &milvuspb.ShowCollectionsRequest{ + DbName: "", + TimeStamp: 0, // means now + } + resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req) + s.NoError(err) + s.Equal(len(resp.CollectionIds), s.numCollections) + notLoaded := 0 + loaded := 0 + for _, name := range resp.CollectionNames { + loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{ + DbName: "", + CollectionName: name, + }) + s.NoError(err) + if loadProgress.GetProgress() != int64(100) { + notLoaded++ + } else { + loaded++ + } + } + return notLoaded == 0 +} + +func (s *DataNodeCompatibility) search(collectionName string, currentNumRows int) { + c := s.Cluster + var err error + // Query + queryReq := &milvuspb.QueryRequest{ + Base: nil, + CollectionName: collectionName, + PartitionNames: nil, + Expr: "", + OutputFields: []string{"count(*)"}, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + } + queryResult, err := c.Proxy.Query(context.TODO(), queryReq) + s.NoError(err) + s.Equal(len(queryResult.FieldsData), 1) + numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0] + s.Equal(numEntities, int64(currentNumRows)) + + // Search + expr := fmt.Sprintf("%s > 0", integration.Int64Field) + nq := 10 + topk := 10 + 
roundDecimal := -1 + radius := 10 + + params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP) + params["radius"] = radius + searchReq := integration.ConstructSearchRequest("", collectionName, expr, + integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal) + + searchResult, _ := c.Proxy.Search(context.TODO(), searchReq) + + err = merr.Error(searchResult.GetStatus()) + s.NoError(err) +} + +func (s *DataNodeCompatibility) insertBatchCollections(prefix string, collectionBatchSize, idxStart int, wg *sync.WaitGroup) { + for idx := 0; idx < collectionBatchSize; idx++ { + collectionName := prefix + "_" + strconv.Itoa(idxStart+idx) + s.loadCollection(collectionName) + } + wg.Done() +} + +func (s *DataNodeCompatibility) insert(collectionName string, rowNum int) { + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, s.dim) + hashKeys := integration.GenerateHashKeys(rowNum) + insertResult, err := s.Cluster.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{ + DbName: "", + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + s.flush(collectionName) +} + +func (s *DataNodeCompatibility) insertAndCheck(collectionName string, currentNumRows *int, testInsert bool) { + s.search(collectionName, *currentNumRows) + insertRows := 1000 + if testInsert { + s.insert(collectionName, insertRows) + *currentNumRows += insertRows + } + s.search(collectionName, *currentNumRows) +} + +func (s *DataNodeCompatibility) setupData() { + // Add the second data node + s.Cluster.AddDataNode() + goRoutineNum := s.maxGoRoutineNum + if goRoutineNum > s.numCollections { + goRoutineNum = s.numCollections + } + collectionBatchSize := s.numCollections / goRoutineNum + log.Info(fmt.Sprintf("=========================test with dim=%d, 
s.rowsPerCollection=%d, s.numCollections=%d, goRoutineNum=%d==================", s.dim, s.rowsPerCollection, s.numCollections, goRoutineNum)) + log.Info("=========================Start to inject data=========================") + s.prefix = "TestDataNodeUtil" + funcutil.GenRandomStr() + searchName := s.prefix + "_0" + wg := sync.WaitGroup{} + for idx := 0; idx < goRoutineNum; idx++ { + wg.Add(1) + go s.insertBatchCollections(s.prefix, collectionBatchSize, idx*collectionBatchSize, &wg) + } + wg.Wait() + log.Info("=========================Data injection finished=========================") + s.checkCollections() + log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName)) + s.search(searchName, s.rowsPerCollection) + log.Info("=========================Search finished=========================") + time.Sleep(s.waitTimeInSec) + s.checkCollections() + log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName)) + s.search(searchName, s.rowsPerCollection) + log.Info("=========================Search2 finished=========================") + s.checkAllCollectionsReady() +} + +func (s *DataNodeCompatibility) checkAllCollectionsReady() { + goRoutineNum := s.maxGoRoutineNum + if goRoutineNum > s.numCollections { + goRoutineNum = s.numCollections + } + collectionBatchSize := s.numCollections / goRoutineNum + for i := 0; i < goRoutineNum; i++ { + for idx := 0; idx < collectionBatchSize; idx++ { + collectionName := s.prefix + "_" + strconv.Itoa(i*collectionBatchSize+idx) + s.search(collectionName, s.rowsPerCollection) + queryReq := &milvuspb.QueryRequest{ + CollectionName: collectionName, + Expr: "", + OutputFields: []string{"count(*)"}, + } + _, err := s.Cluster.Proxy.Query(context.TODO(), queryReq) + s.NoError(err) + } + } +} + +func (s *DataNodeCompatibility) checkSingleDNRestarts(currentNumRows *int, numNodes, idx int, testInsert bool) { + // Stop all data nodes + 
s.Cluster.StopAllDataNodes() + // Add new data nodes. + var dn []*grpcdatanode.Server + for i := 0; i < numNodes; i++ { + dn = append(dn, s.Cluster.AddDataNode()) + } + time.Sleep(s.waitTimeInSec) + cn := fmt.Sprintf("%s_0", s.prefix) + s.insertAndCheck(cn, currentNumRows, testInsert) + dn[idx].Stop() + time.Sleep(s.waitTimeInSec) + s.insertAndCheck(cn, currentNumRows, testInsert) +} + +func (s *DataNodeCompatibility) checkDNRestarts(currentNumRows *int, testInsert bool) { + numDatanodes := 2 // configurable + for idx := 0; idx < numDatanodes; idx++ { + s.checkSingleDNRestarts(currentNumRows, numDatanodes, idx, testInsert) + } +} + +func (s *DataNodeCompatibility) restartDC() { + c := s.Cluster + c.DataCoord.Stop() + c.DataCoord = grpcdatacoord.NewServer(context.TODO(), c.GetFactory()) + err := c.DataCoord.Run() + s.NoError(err) +} + +func (s *DataNodeCompatibility) TestCompatibility() { + s.setupParam() + s.setupData() + rows := s.rowsPerCollection + + // new coord + new node + s.checkDNRestarts(&rows, true) + + // new coord + old node + paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "false") + s.checkDNRestarts(&rows, false) + + // old coord + old node + s.restartDC() + s.checkDNRestarts(&rows, true) + + // old coord + new node + paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "true") + s.checkDNRestarts(&rows, false) + + // new coord + both old & new datanodes. 
+ paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "false") + s.restartDC() + s.Cluster.StopAllDataNodes() + d1 := s.Cluster.AddDataNode() + d2 := s.Cluster.AddDataNode() + cn := fmt.Sprintf("%s_0", s.prefix) + s.insertAndCheck(cn, &rows, true) + paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.Key, "true") + s.restartDC() + s.insertAndCheck(cn, &rows, false) + s.Cluster.AddDataNode() + d1.Stop() + s.checkDNRestarts(&rows, true) + s.Cluster.AddDataNode() + d2.Stop() + s.checkDNRestarts(&rows, true) +} + +func TestDataNodeCompatibility(t *testing.T) { + suite.Run(t, new(DataNodeCompatibility)) +}