diff --git a/internal/dataservice/channel.go b/internal/dataservice/channel.go index e88b9030b0..d67b16c95a 100644 --- a/internal/dataservice/channel.go +++ b/internal/dataservice/channel.go @@ -31,11 +31,11 @@ func newInsertChannelManager() *insertChannelManager { } } -func (cm *insertChannelManager) AllocChannels(collectionID UniqueID, groupNum int) ([]channelGroup, error) { +func (cm *insertChannelManager) GetChannels(collectionID UniqueID, groupNum int) ([]channelGroup, error) { cm.mu.Lock() defer cm.mu.Unlock() if _, ok := cm.channelGroups[collectionID]; ok { - return nil, fmt.Errorf("channel group of collection %d already exist", collectionID) + return cm.channelGroups[collectionID], nil } channels := Params.InsertChannelNumPerCollection m, n := channels/int64(groupNum), channels%int64(groupNum) @@ -74,19 +74,3 @@ func (cm *insertChannelManager) GetChannelGroup(collectionID UniqueID, channelNa } return nil, fmt.Errorf("channel name %s not found", channelName) } - -func (cm *insertChannelManager) ContainsCollection(collectionID UniqueID) (bool, []string) { - cm.mu.RLock() - defer cm.mu.RUnlock() - _, ok := cm.channelGroups[collectionID] - if !ok { - return false, nil - } - ret := make([]string, 0) - for _, cr := range cm.channelGroups[collectionID] { - for _, c := range cr { - ret = append(ret, c) - } - } - return true, ret -} diff --git a/internal/dataservice/channel_test.go b/internal/dataservice/channel_test.go index e05884c368..70a1dab871 100644 --- a/internal/dataservice/channel_test.go +++ b/internal/dataservice/channel_test.go @@ -14,19 +14,14 @@ func TestChannelAllocation(t *testing.T) { collectionID UniqueID groupNum int expectGroupNum int - success bool }{ - {1, 4, 4, true}, - {1, 4, 4, false}, - {2, 1, 1, true}, - {3, 5, 4, true}, + {1, 4, 4}, + {1, 4, 4}, + {2, 1, 1}, + {3, 5, 4}, } for _, c := range cases { - channels, err := manager.AllocChannels(c.collectionID, c.expectGroupNum) - if !c.success { - assert.NotNil(t, err) - continue - } + channels, err := manager.GetChannels(c.collectionID, c.expectGroupNum) assert.Nil(t, err) assert.EqualValues(t, c.expectGroupNum, len(channels)) total := 0 diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index 3495643ed0..9ae8d3f205 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -682,11 +682,7 @@ func (s *Server) GetInsertChannels(req *datapb.InsertChannelRequest) ([]string, if !s.checkStateIsHealthy() { return nil, errors.New("server is initializing") } - contains, ret := s.insertChannelMgr.ContainsCollection(req.CollectionID) - if contains { - return ret, nil - } - channelGroups, err := s.insertChannelMgr.AllocChannels(req.CollectionID, s.cluster.GetNumOfNodes()) + channelGroups, err := s.insertChannelMgr.GetChannels(req.CollectionID, s.cluster.GetNumOfNodes()) if err != nil { return nil, err } @@ -696,7 +692,6 @@ func (s *Server) GetInsertChannels(req *datapb.InsertChannelRequest) ([]string, channels = append(channels, group...) } s.cluster.WatchInsertChannels(channelGroups) - return channels, nil }