enhance segment balance by considering global rowCount (#22914) (#23056)

Signed-off-by: MrPresent-Han <jamesharden11122@gmail.com>
Co-authored-by: xiaofan-luan <xiaofan.luan@zilliz.com>
pull/23192/head
MrPresent-Han 2023-04-03 14:16:25 +08:00 committed by GitHub
parent a0ca4d6108
commit afd874b736
21 changed files with 1317 additions and 46 deletions

View File

@ -17,6 +17,7 @@
package balance
import (
"fmt"
"sort"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
@ -60,6 +61,11 @@ type SegmentAssignPlan struct {
Weight Weight
}
func (segPlan SegmentAssignPlan) ToString() string {
return fmt.Sprintf("SegmentPlan:[collectionID: %d, replicaID: %d, segmentID: %d, from: %d, to: %d, weight: %d]\n",
segPlan.Segment.CollectionID, segPlan.ReplicaID, segPlan.Segment.ID, segPlan.From, segPlan.To, segPlan.Weight)
}
type ChannelAssignPlan struct {
Channel *meta.DmChannel
ReplicaID int64
@ -68,8 +74,19 @@ type ChannelAssignPlan struct {
Weight Weight
}
func (chanPlan ChannelAssignPlan) ToString() string {
return fmt.Sprintf("ChannelPlan:[collectionID: %d, channel: %s, replicaID: %d, from: %d, to: %d, weight: %d]\n",
chanPlan.Channel.CollectionID, chanPlan.Channel.ChannelName, chanPlan.ReplicaID, chanPlan.From, chanPlan.To, chanPlan.Weight)
}
var (
RoundRobinBalancerName = "RoundRobinBalancer"
RowCountBasedBalancerName = "RowCountBasedBalancer"
ScoreBasedBalancerName = "ScoreBasedBalancer"
)
type Balance interface {
AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan
AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan
AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan
Balance() ([]SegmentAssignPlan, []ChannelAssignPlan)
}
@ -79,7 +96,7 @@ type RoundRobinBalancer struct {
nodeManager *session.NodeManager
}
func (b *RoundRobinBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
nodesInfo := b.getNodes(nodes)
if len(nodesInfo) == 0 {
return nil

View File

@ -92,7 +92,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignSegment(c.assignments, c.nodeIDs)
plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs)
suite.ElementsMatch(c.expectPlans, plans)
})
}

View File

@ -60,13 +60,13 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock
return _c
}
// AssignSegment provides a mock function with given fields: segments, nodes
func (_m *MockBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
ret := _m.Called(segments, nodes)
// AssignSegment provides a mock function with given fields: collectionID, segments, nodes
func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
ret := _m.Called(collectionID, segments, nodes)
var r0 []SegmentAssignPlan
if rf, ok := ret.Get(0).(func([]*meta.Segment, []int64) []SegmentAssignPlan); ok {
r0 = rf(segments, nodes)
if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64) []SegmentAssignPlan); ok {
r0 = rf(collectionID, segments, nodes)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]SegmentAssignPlan)
@ -82,15 +82,16 @@ type MockBalancer_AssignSegment_Call struct {
}
// AssignSegment is a helper method to define mock.On call
// - segments []*meta.Segment
// - nodes []int64
func (_e *MockBalancer_Expecter) AssignSegment(segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", segments, nodes)}
// - collectionID int64
// - segments []*meta.Segment
// - nodes []int64
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)}
}
func (_c *MockBalancer_AssignSegment_Call) Run(run func(segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call {
func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*meta.Segment), args[1].([]int64))
run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64))
})
return _c
}

View File

@ -0,0 +1,96 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package balance
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMinPriorityQueue(t *testing.T) {
pq := newPriorityQueue()
for i := 0; i < 5; i++ {
priority := i % 3
nodeItem := newNodeItem(priority, int64(i))
pq.push(&nodeItem)
}
item := pq.pop()
assert.Equal(t, item.getPriority(), 0)
assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
item = pq.pop()
assert.Equal(t, item.getPriority(), 0)
assert.Equal(t, item.(*nodeItem).nodeID, int64(3))
item = pq.pop()
assert.Equal(t, item.getPriority(), 1)
assert.Equal(t, item.(*nodeItem).nodeID, int64(1))
item = pq.pop()
assert.Equal(t, item.getPriority(), 1)
assert.Equal(t, item.(*nodeItem).nodeID, int64(4))
item = pq.pop()
assert.Equal(t, item.getPriority(), 2)
assert.Equal(t, item.(*nodeItem).nodeID, int64(2))
}
func TestPopPriorityQueue(t *testing.T) {
pq := newPriorityQueue()
for i := 0; i < 1; i++ {
priority := 1
nodeItem := newNodeItem(priority, int64(i))
pq.push(&nodeItem)
}
item := pq.pop()
assert.Equal(t, item.getPriority(), 1)
assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
pq.push(item)
// if the queue were round-robin, a different node would be popped here, but the same one comes back
item = pq.pop()
assert.Equal(t, item.getPriority(), 1)
assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
}
func TestMaxPriorityQueue(t *testing.T) {
pq := newPriorityQueue()
for i := 0; i < 5; i++ {
priority := i % 3
nodeItem := newNodeItem(-priority, int64(i))
pq.push(&nodeItem)
}
item := pq.pop()
assert.Equal(t, item.getPriority(), -2)
assert.Equal(t, item.(*nodeItem).nodeID, int64(2))
item = pq.pop()
assert.Equal(t, item.getPriority(), -1)
assert.Equal(t, item.(*nodeItem).nodeID, int64(4))
item = pq.pop()
assert.Equal(t, item.getPriority(), -1)
assert.Equal(t, item.(*nodeItem).nodeID, int64(1))
item = pq.pop()
assert.Equal(t, item.getPriority(), 0)
assert.Equal(t, item.(*nodeItem).nodeID, int64(3))
item = pq.pop()
assert.Equal(t, item.getPriority(), 0)
assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
}
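The priority queue implementation exercised by these tests is not part of this diff. A minimal sketch of what the tests assume, a min-heap of nodeItem ordered by ascending priority built on container/heap, might look roughly like the following; only the identifiers used in the tests (newPriorityQueue, push, pop, newNodeItem, getPriority, setPriority, nodeID) come from the source, everything else is an assumption:

package balance

import "container/heap"

// item is the minimal interface the queue needs; the tests only rely on getPriority/setPriority.
type item interface {
	getPriority() int
	setPriority(priority int)
}

type nodeItem struct {
	priority int
	nodeID   int64
}

func newNodeItem(priority int, nodeID int64) nodeItem {
	return nodeItem{priority: priority, nodeID: nodeID}
}

func (n *nodeItem) getPriority() int         { return n.priority }
func (n *nodeItem) setPriority(priority int) { n.priority = priority }

// itemHeap adapts a slice of items to container/heap, ordered by ascending priority.
type itemHeap []item

func (h itemHeap) Len() int            { return len(h) }
func (h itemHeap) Less(i, j int) bool  { return h[i].getPriority() < h[j].getPriority() }
func (h itemHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *itemHeap) Push(x interface{}) { *h = append(*h, x.(item)) }
func (h *itemHeap) Pop() interface{} {
	old := *h
	last := old[len(old)-1]
	*h = old[:len(old)-1]
	return last
}

// priorityQueue wraps the heap so that pop always returns the lowest-priority item.
type priorityQueue struct{ inner itemHeap }

func newPriorityQueue() priorityQueue {
	return priorityQueue{inner: itemHeap{}}
}

func (pq *priorityQueue) push(i item) { heap.Push(&pq.inner, i) }

func (pq *priorityQueue) pop() item { return heap.Pop(&pq.inner).(item) }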

View File

@ -37,7 +37,7 @@ type RowCountBasedBalancer struct {
targetMgr *meta.TargetManager
}
func (b *RowCountBasedBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
nodeItems := b.convertToNodeItems(nodes)
if len(nodeItems) == 0 {
return nil

View File

@ -126,7 +126,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
plans := balancer.AssignSegment(c.assignments, c.nodes)
plans := balancer.AssignSegment(0, c.assignments, c.nodes)
suite.ElementsMatch(c.expectPlans, plans)
})
}

View File

@ -0,0 +1,380 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package balance
import (
"sort"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/exp/maps"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type ScoreBasedBalancer struct {
*RowCountBasedBalancer
balancedCollectionsCurrentRound typeutil.UniqueSet
}
func NewScoreBasedBalancer(scheduler task.Scheduler,
nodeManager *session.NodeManager,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager) *ScoreBasedBalancer {
return &ScoreBasedBalancer{
RowCountBasedBalancer: NewRowCountBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr),
balancedCollectionsCurrentRound: typeutil.NewUniqueSet(),
}
}
// TODO assign channel need to think of global channels
func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
nodeItems := b.convertToNodeItems(collectionID, nodes)
if len(nodeItems) == 0 {
return nil
}
queue := newPriorityQueue()
for _, item := range nodeItems {
queue.push(item)
}
sort.Slice(segments, func(i, j int) bool {
return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
})
plans := make([]SegmentAssignPlan, 0, len(segments))
for _, s := range segments {
// pick the node with the lowest score and assign the segment to it.
ni := queue.pop().(*nodeItem)
plan := SegmentAssignPlan{
From: -1,
To: ni.nodeID,
Weight: GetWeight(1),
Segment: s,
}
plans = append(plans, plan)
// update the node's priority and push it back, accounting for both the collection row count and the weighted global row count
p := ni.getPriority()
ni.setPriority(p + int(s.GetNumOfRows()) +
int(float64(s.GetNumOfRows())*params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat()))
queue.push(ni)
}
return plans
}
func (b *ScoreBasedBalancer) convertToNodeItems(collectionID int64, nodeIDs []int64) []*nodeItem {
ret := make([]*nodeItem, 0, len(nodeIDs))
for _, nodeInfo := range b.getNodes(nodeIDs) {
node := nodeInfo.ID()
priority := b.calculatePriority(collectionID, node)
nodeItem := newNodeItem(priority, node)
ret = append(ret, &nodeItem)
}
return ret
}
func (b *ScoreBasedBalancer) calculatePriority(collectionID, nodeID int64) int {
globalSegments := b.dist.SegmentDistManager.GetByNode(nodeID)
rowCount := 0
for _, s := range globalSegments {
rowCount += int(s.GetNumOfRows())
}
collectionSegments := b.dist.SegmentDistManager.GetByCollectionAndNode(collectionID, nodeID)
collectionRowCount := 0
for _, s := range collectionSegments {
collectionRowCount += int(s.GetNumOfRows())
}
return collectionRowCount + int(float64(rowCount)*
params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
}
func (b *ScoreBasedBalancer) Balance() ([]SegmentAssignPlan, []ChannelAssignPlan) {
ids := b.meta.CollectionManager.GetAll()
// loading collection should skip balance
loadedCollections := lo.Filter(ids, func(cid int64, _ int) bool {
return b.meta.GetCollection(cid).Status == querypb.LoadStatus_Loaded
})
sort.Slice(loadedCollections, func(i, j int) bool {
return loadedCollections[i] < loadedCollections[j]
})
segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0)
hasUnBalancedCollections := false
for _, cid := range loadedCollections {
if b.balancedCollectionsCurrentRound.Contain(cid) {
log.Debug("ScoreBasedBalancer has balanced collection, skip balancing in this round",
zap.Int64("collectionID", cid))
continue
}
hasUnBalancedCollections = true
replicas := b.meta.ReplicaManager.GetByCollection(cid)
for _, replica := range replicas {
sPlans, cPlans := b.balanceReplica(replica)
PrintNewBalancePlans(cid, replica.GetID(), sPlans, cPlans)
segmentPlans = append(segmentPlans, sPlans...)
channelPlans = append(channelPlans, cPlans...)
}
b.balancedCollectionsCurrentRound.Insert(cid)
if len(segmentPlans) != 0 || len(channelPlans) != 0 {
log.Debug("ScoreBasedBalancer has generated balance plans for", zap.Int64("collectionID", cid))
break
}
}
if !hasUnBalancedCollections {
b.balancedCollectionsCurrentRound.Clear()
log.Debug("ScoreBasedBalancer has balanced all " +
"collections in one round, clear collectionIDs for this round")
}
return segmentPlans, channelPlans
}
func (b *ScoreBasedBalancer) balanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
nodes := replica.GetNodes()
if len(nodes) == 0 {
return nil, nil
}
nodesSegments := make(map[int64][]*meta.Segment)
stoppingNodesSegments := make(map[int64][]*meta.Segment)
outboundNodes := b.meta.ResourceManager.CheckOutboundNodes(replica)
// calculate stopping nodes and available nodes.
for _, nid := range nodes {
segments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nid)
// Only balance segments in targets
segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
})
if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil {
log.Info("not existed node", zap.Int64("nid", nid), zap.Any("segments", segments), zap.Error(err))
continue
} else if isStopping {
stoppingNodesSegments[nid] = segments
} else if outboundNodes.Contain(nid) {
// the node is stopping or has been moved to another resource group
log.RatedInfo(10, "meet outbound node, try to move out all segment/channel",
zap.Int64("collectionID", replica.GetCollectionID()),
zap.Int64("replicaID", replica.GetCollectionID()),
zap.Int64("node", nid),
)
stoppingNodesSegments[nid] = segments
} else {
nodesSegments[nid] = segments
}
}
if len(nodes) == len(stoppingNodesSegments) {
// no available nodes to balance
log.Warn("All nodes is under stopping mode or outbound, skip balance replica",
zap.Int64("collection", replica.CollectionID),
zap.Int64("replica id", replica.Replica.GetID()),
zap.String("replica group", replica.Replica.GetResourceGroup()),
zap.Int64s("nodes", replica.Replica.GetNodes()),
)
return nil, nil
}
if len(nodesSegments) <= 0 {
log.Warn("No nodes is available in resource group, skip balance replica",
zap.Int64("collection", replica.CollectionID),
zap.Int64("replica id", replica.Replica.GetID()),
zap.String("replica group", replica.Replica.GetResourceGroup()),
zap.Int64s("nodes", replica.Replica.GetNodes()),
)
return nil, nil
}
//print current distribution before generating plans
PrintCurrentReplicaDist(replica, stoppingNodesSegments, nodesSegments, b.dist.ChannelDistManager)
if len(stoppingNodesSegments) != 0 {
log.Info("Handle stopping nodes",
zap.Int64("collection", replica.CollectionID),
zap.Int64("replica id", replica.Replica.GetID()),
zap.String("replica group", replica.Replica.GetResourceGroup()),
zap.Any("stopping nodes", maps.Keys(stoppingNodesSegments)),
zap.Any("available nodes", maps.Keys(nodesSegments)),
)
// handle stopping nodes: their segments have to be assigned to the nodes with the smallest scores
return b.getStoppedSegmentPlan(replica, nodesSegments, stoppingNodesSegments), b.getStoppedChannelPlan(replica, lo.Keys(nodesSegments), lo.Keys(stoppingNodesSegments))
}
// normal balance, find segments from largest score nodes and transfer to smallest score nodes.
return b.getNormalSegmentPlan(replica, nodesSegments), b.getNormalChannelPlan(replica, lo.Keys(nodesSegments))
}
func (b *ScoreBasedBalancer) getStoppedSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment, stoppingNodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
// generate candidates
nodeItems := b.convertToNodeItems(replica.GetCollectionID(), lo.Keys(nodesSegments))
queue := newPriorityQueue()
for _, item := range nodeItems {
queue.push(item)
}
// collect the segments to assign
var segments []*meta.Segment
nodeIndex := make(map[int64]int64)
for nodeID, stoppingSegments := range stoppingNodesSegments {
for _, segment := range stoppingSegments {
segments = append(segments, segment)
nodeIndex[segment.GetID()] = nodeID
}
}
sort.Slice(segments, func(i, j int) bool {
return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
})
for _, s := range segments {
// pick the node with the lowest score and assign the segment to it.
ni := queue.pop().(*nodeItem)
plan := SegmentAssignPlan{
ReplicaID: replica.GetID(),
From: nodeIndex[s.GetID()],
To: ni.nodeID,
Weight: GetWeight(1),
Segment: s,
}
segmentPlans = append(segmentPlans, plan)
// update the node's priority and push it back, accounting for both the collection row count and the weighted global row count
p := ni.getPriority()
ni.setPriority(p + int(s.GetNumOfRows()) + int(float64(s.GetNumOfRows())*
params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat()))
queue.push(ni)
}
return segmentPlans
}
func (b *ScoreBasedBalancer) getStoppedChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range offlineNodes {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
plans := b.AssignChannel(dmChannels, onlineNodes)
for i := range plans {
plans[i].From = nodeID
plans[i].ReplicaID = replica.ID
plans[i].Weight = GetWeight(1)
}
channelPlans = append(channelPlans, plans...)
}
return channelPlans
}
func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
if b.scheduler.GetSegmentTaskNum() != 0 {
// scheduler is handling segment task, skip
return nil
}
segmentPlans := make([]SegmentAssignPlan, 0)
// generate candidates
nodeItems := b.convertToNodeItems(replica.GetCollectionID(), lo.Keys(nodesSegments))
lastIdx := len(nodeItems) - 1
havingMovedSegments := typeutil.NewUniqueSet()
for {
sort.Slice(nodeItems, func(i, j int) bool {
return nodeItems[i].priority <= nodeItems[j].priority
})
toNode := nodeItems[0]
fromNode := nodeItems[lastIdx]
// sort the segments in ascending order to avoid over-correcting the from/to imbalance
// TODO: segment infos inside dist manager may change in the process of making balance plan
fromSegments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.CollectionID, fromNode.nodeID)
sort.Slice(fromSegments, func(i, j int) bool {
return fromSegments[i].GetNumOfRows() < fromSegments[j].GetNumOfRows()
})
var targetSegmentToMove *meta.Segment
for _, segment := range fromSegments {
targetSegmentToMove = segment
if havingMovedSegments.Contain(targetSegmentToMove.GetID()) {
targetSegmentToMove = nil
continue
}
break
}
if targetSegmentToMove == nil {
//the node with the highest score doesn't have any segments suitable for balancing, stop balancing this round
break
}
fromPriority := fromNode.priority
toPriority := toNode.priority
unbalance := fromPriority - toPriority
nextFromPriority := fromPriority - int(targetSegmentToMove.GetNumOfRows()) - int(float64(targetSegmentToMove.GetNumOfRows())*
params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
nextToPriority := toPriority + int(targetSegmentToMove.GetNumOfRows()) + int(float64(targetSegmentToMove.GetNumOfRows())*
params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
//still unbalanced after this balance plan is executed
if nextToPriority <= nextFromPriority {
plan := SegmentAssignPlan{
ReplicaID: replica.GetID(),
From: fromNode.nodeID,
To: toNode.nodeID,
Segment: targetSegmentToMove,
Weight: GetWeight(0),
}
segmentPlans = append(segmentPlans, plan)
} else {
//if the imbalance would be reversed by this move, weigh the benefit:
//only generate the plan when the reversed imbalance it creates
//is far smaller than the original imbalance
nextUnbalance := nextToPriority - nextFromPriority
if int(float64(nextUnbalance)*params.Params.QueryCoordCfg.ScoreUnbalanceTolerationFactor.GetAsFloat()) < unbalance {
plan := SegmentAssignPlan{
ReplicaID: replica.GetID(),
From: fromNode.nodeID,
To: toNode.nodeID,
Segment: targetSegmentToMove,
Weight: GetWeight(0),
}
segmentPlans = append(segmentPlans, plan)
} else {
//if moving even the smallest segment between the highest-scored and lowest-scored nodes does
//not provide sufficient balance benefit, cease balancing in this round
break
}
}
havingMovedSegments.Insert(targetSegmentToMove.GetID())
//update node priority
toNode.setPriority(nextToPriority)
fromNode.setPriority(nextFromPriority)
// if no segment can be moved between toNode and fromNode, break; otherwise try another round of balancing
// TODO swap segment between toNode and fromNode, see if the cluster becomes more balance
}
return segmentPlans
}
func (b *ScoreBasedBalancer) getNormalChannelPlan(replica *meta.Replica, onlineNodes []int64) []ChannelAssignPlan {
// TODO
return make([]ChannelAssignPlan, 0)
}
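As a compact illustration of the scoring scheme above, and not the balancer itself, the following self-contained sketch assigns each segment to the node with the lowest score, where a node's score is its collection row count plus globalRowCountFactor times its global row count. The default factor of 0.1 from this patch is assumed, and the node and segment numbers are chosen to mirror the second AssignSegment test case added later in this patch:

package main

import (
	"fmt"
	"sort"
)

// globalRowCountFactor mirrors the default queryCoord.globalRowCountFactor (0.1) added in this patch.
const globalRowCountFactor = 0.1

type node struct {
	id             int64
	collectionRows int // rows of the collection being assigned that already live on this node
	globalRows     int // rows of all collections on this node
}

// score mirrors calculatePriority: collection row count plus a weighted share of the global row count.
func score(n node) int {
	return n.collectionRows + int(float64(n.globalRows)*globalRowCountFactor)
}

func main() {
	// initial distribution taken from the "non-empty cluster" AssignSegment test case in this patch
	nodes := []node{
		{id: 1, collectionRows: 10, globalRows: 310}, // score 41
		{id: 2, collectionRows: 20, globalRows: 200}, // score 40
		{id: 3, collectionRows: 30, globalRows: 50},  // score 35
	}
	segments := []int{15, 10, 5} // incoming segment row counts, largest first

	for _, rows := range segments {
		// pick the node with the lowest score, as the priority queue pop does above
		sort.Slice(nodes, func(i, j int) bool { return score(nodes[i]) < score(nodes[j]) })
		target := &nodes[0]
		fmt.Printf("assign %d rows to node %d (score %d)\n", rows, target.id, score(*target))
		// both the collection rows and the global rows grow, which raises the node's score
		target.collectionRows += rows
		target.globalRows += rows
	}
	// prints: node 3 gets 15 rows, node 2 gets 10, node 1 gets 5, matching the test expectations
}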

View File

@ -0,0 +1,604 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package balance
import (
"testing"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type ScoreBasedBalancerTestSuite struct {
suite.Suite
balancer *ScoreBasedBalancer
kv *etcdkv.EtcdKV
broker *meta.MockBroker
mockScheduler *task.MockScheduler
}
func (suite *ScoreBasedBalancerTestSuite) SetupSuite() {
Params.Init()
}
func (suite *ScoreBasedBalancerTestSuite) SetupTest() {
var err error
config := GenerateEtcdConfig()
cli, err := etcd.GetEtcdClient(
config.UseEmbedEtcd.GetAsBool(),
config.EtcdUseSSL.GetAsBool(),
config.Endpoints.GetAsStrings(),
config.EtcdTLSCert.GetValue(),
config.EtcdTLSKey.GetValue(),
config.EtcdTLSCACert.GetValue(),
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.broker = meta.NewMockBroker(suite.T())
store := meta.NewMetaStore(suite.kv)
idAllocator := RandomIncrementIDAllocator()
nodeManager := session.NewNodeManager()
testMeta := meta.NewMeta(idAllocator, store, nodeManager)
testTarget := meta.NewTargetManager(suite.broker, testMeta)
distManager := meta.NewDistributionManager()
suite.mockScheduler = task.NewMockScheduler(suite.T())
suite.balancer = NewScoreBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget)
}
func (suite *ScoreBasedBalancerTestSuite) TearDownTest() {
suite.kv.Close()
}
func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() {
cases := []struct {
name string
comment string
distributions map[int64][]*meta.Segment
assignments [][]*meta.Segment
nodes []int64
collectionIDs []int64
segmentCnts []int
states []session.State
expectPlans [][]SegmentAssignPlan
}{
{
name: "test empty cluster assigning one collection",
comment: "this is most simple case in which global row count is zero for all nodes",
distributions: map[int64][]*meta.Segment{},
assignments: [][]*meta.Segment{
{
{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5, CollectionID: 1}},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10, CollectionID: 1}},
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 15, CollectionID: 1}},
},
},
nodes: []int64{1, 2, 3},
collectionIDs: []int64{0},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
segmentCnts: []int{0, 0, 0},
expectPlans: [][]SegmentAssignPlan{
{
//since AssignSegment is used while loading a collection,
//all assign plans should have weight equal to 1 (HIGH PRIORITY)
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 15,
CollectionID: 1}}, From: -1, To: 1, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10,
CollectionID: 1}}, From: -1, To: 3, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5,
CollectionID: 1}}, From: -1, To: 2, Weight: 1},
},
},
},
{
name: "test non-empty cluster assigning one collection",
comment: "this case will verify the effect of global row for loading segments process, although node1" +
"has only 10 rows at the beginning, but it has so many rows on global view, resulting in a lower priority",
distributions: map[int64][]*meta.Segment{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 300, CollectionID: 2}, Node: 1},
//base: collection1-node1-priority is 10 + 0.1 * 310 = 41
//assign3: collection1-node1-priority is 15 + 0.1 * 315 = 46.5
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 20, CollectionID: 1}, Node: 2},
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 180, CollectionID: 2}, Node: 2},
//base: collection1-node2-priority is 20 + 0.1 * 200 = 40
//assign2: collection1-node2-priority is 30 + 0.1 * 210 = 51
},
3: {
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 30, CollectionID: 1}, Node: 3},
{SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 20, CollectionID: 2}, Node: 3},
//base: collection1-node3-priority is 30 + 0.1 * 50 = 35
//assign1: collection1-node3-priority is 45 + 0.1 * 65 = 51.5
},
},
assignments: [][]*meta.Segment{
{
{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 5, CollectionID: 1}},
{SegmentInfo: &datapb.SegmentInfo{ID: 8, NumOfRows: 10, CollectionID: 1}},
{SegmentInfo: &datapb.SegmentInfo{ID: 9, NumOfRows: 15, CollectionID: 1}},
},
},
nodes: []int64{1, 2, 3},
collectionIDs: []int64{1},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
segmentCnts: []int{0, 0, 0},
expectPlans: [][]SegmentAssignPlan{
{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 9, NumOfRows: 15, CollectionID: 1}}, From: -1, To: 3, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 8, NumOfRows: 10, CollectionID: 1}}, From: -1, To: 2, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 5, CollectionID: 1}}, From: -1, To: 1, Weight: 1},
},
},
},
{
name: "test non-empty cluster assigning two collections at one round segment checking",
comment: "this case is used to demonstrate the existing assign mechanism having flaws when assigning " +
"multi collections at one round by using the only segment distribution",
distributions: map[int64][]*meta.Segment{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 20, CollectionID: 1}, Node: 2},
},
3: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 40, CollectionID: 1}, Node: 3},
},
},
assignments: [][]*meta.Segment{
{
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}},
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}},
},
{
{SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 15, CollectionID: 2}},
{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 10, CollectionID: 2}},
},
},
nodes: []int64{1, 2, 3},
collectionIDs: []int64{1, 2},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
segmentCnts: []int{0, 0, 0},
expectPlans: [][]SegmentAssignPlan{
//note that these two segment plans are unbalanced globally:
//once the assignment for collection1 takes effect, node1 and node2 will both hold 70 rows,
//far more than node3, yet the following assignment still scores nodes based on [10,20,40]
//rather than [70,70,40]; this flaw is mitigated by the balance process and may be fixed in later versions
{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}}, From: -1, To: 1, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}}, From: -1, To: 2, Weight: 1},
},
{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 15, CollectionID: 2}}, From: -1, To: 1, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 10, CollectionID: 2}}, From: -1, To: 2, Weight: 1},
},
},
},
}
for _, c := range cases {
suite.Run(c.name, func() {
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
for i := range c.collectionIDs {
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes)
suite.ElementsMatch(c.expectPlans[i], plans)
}
})
}
}
func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
cases := []struct {
name string
nodes []int64
notExistedNodes []int64
collectionIDs []int64
replicaIDs []int64
collectionsSegments [][]*datapb.SegmentBinlogs
states []session.State
shouldMock bool
distributions map[int64][]*meta.Segment
distributionChannels map[int64][]*meta.DmChannel
expectPlans []SegmentAssignPlan
expectChannelPlans []ChannelAssignPlan
}{
{
name: "normal balance for one collection only",
nodes: []int64{1, 2},
collectionIDs: []int64{1},
replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{
{
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3},
},
},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal},
distributions: map[int64][]*meta.Segment{
1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2},
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
},
expectPlans: []SegmentAssignPlan{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: 1, ReplicaID: 1},
},
expectChannelPlans: []ChannelAssignPlan{},
},
{
name: "already balanced for one collection only",
nodes: []int64{1, 2},
collectionIDs: []int64{1},
replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{
{
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3},
},
},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal},
distributions: map[int64][]*meta.Segment{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
},
expectPlans: []SegmentAssignPlan{},
expectChannelPlans: []ChannelAssignPlan{},
},
}
for _, c := range cases {
suite.Run(c.name, func() {
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
//1. set up target for multi collections
collections := make([]*meta.Collection, 0, len(c.collectionIDs))
for i := range c.collectionIDs {
collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i]))
collections = append(collections, collection)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, c.collectionIDs[i], c.replicaIDs[i]).Return(
nil, c.collectionsSegments[i], nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionIDs[i]).Return([]int64{c.collectionIDs[i]}, nil).Maybe()
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i])
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaIDs[i], c.collectionIDs[i],
append(c.nodes, c.notExistedNodes...)))
}
//2. set up segment distribution for multiple collections
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for node, v := range c.distributionChannels {
balancer.dist.ChannelDistManager.Update(node, v...)
}
//3. set up nodes info and resourceManager for balancer
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, c.nodes[i])
}
//4. balance and verify result
segmentPlans, channelPlans := balancer.Balance()
suite.ElementsMatch(c.expectChannelPlans, channelPlans)
suite.ElementsMatch(c.expectPlans, segmentPlans)
})
}
}
func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
balanceCase := struct {
name string
nodes []int64
notExistedNodes []int64
collectionIDs []int64
replicaIDs []int64
collectionsSegments [][]*datapb.SegmentBinlogs
states []session.State
shouldMock bool
distributions []map[int64][]*meta.Segment
expectPlans [][]SegmentAssignPlan
}{
name: "balance considering both global rowCounts and collection rowCounts",
nodes: []int64{1, 2, 3},
collectionIDs: []int64{1, 2},
replicaIDs: []int64{1, 2},
collectionsSegments: [][]*datapb.SegmentBinlogs{
{
{SegmentID: 1}, {SegmentID: 3},
},
{
{SegmentID: 2}, {SegmentID: 4},
},
},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
distributions: []map[int64][]*meta.Segment{
{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 20}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 2, NumOfRows: 20}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, Node: 2},
{SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 2, NumOfRows: 30}, Node: 2},
},
},
{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 20}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 2, NumOfRows: 20}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 2, NumOfRows: 30}, Node: 2},
},
3: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, Node: 3},
},
},
},
expectPlans: [][]SegmentAssignPlan{
{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20},
Node: 2}, From: 2, To: 3, ReplicaID: 1,
},
},
{},
},
}
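// A rough walkthrough of why the expected plans above hold, assuming the defaults introduced in
// this patch (globalRowCountFactor = 0.1, scoreUnbalanceTolerationFactor = 1.3):
// round 1, collection 1: scores are node1 = 20 + 0.1*40 = 24, node2 = 20 + 0.1*50 = 25, node3 = 0;
// moving segment 3 (20 rows) from node2 to node3 yields fromScore = 3 and toScore = 22, i.e. a reversed
// imbalance of 19, and 19 * 1.3 ≈ 24 < 25 (the original imbalance), so the plan is generated.
// round 2, collection 2: scores are node1 = 24, node2 = 30 + 0.1*30 = 33, node3 = 0 + 0.1*20 = 2;
// moving segment 4 (30 rows) would yield fromScore = 0 and toScore = 35, and 35 * 1.3 > 31, so no plan is generated.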
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
//1. set up target for multi collections
collections := make([]*meta.Collection, 0, len(balanceCase.collectionIDs))
for i := range balanceCase.collectionIDs {
collection := utils.CreateTestCollection(balanceCase.collectionIDs[i], int32(balanceCase.replicaIDs[i]))
collections = append(collections, collection)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, balanceCase.collectionIDs[i], balanceCase.replicaIDs[i]).Return(
nil, balanceCase.collectionsSegments[i], nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, balanceCase.collectionIDs[i]).Return([]int64{balanceCase.collectionIDs[i]}, nil).Maybe()
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i],
append(balanceCase.nodes, balanceCase.notExistedNodes...)))
}
//2. set up segment distribution for multiple collections
for node, s := range balanceCase.distributions[0] {
balancer.dist.SegmentDistManager.Update(node, s...)
}
//3. set up nodes info and resourceManager for balancer
for i := range balanceCase.nodes {
nodeInfo := session.NewNodeInfo(balanceCase.nodes[i], "127.0.0.1:0")
nodeInfo.SetState(balanceCase.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, balanceCase.nodes[i])
}
//4. first round balance
segmentPlans, _ := balancer.Balance()
suite.ElementsMatch(balanceCase.expectPlans[0], segmentPlans)
//5. update segment distribution to simulate balance effect
for node, s := range balanceCase.distributions[1] {
balancer.dist.SegmentDistManager.Update(node, s...)
}
//6. balance again
segmentPlans, _ = balancer.Balance()
suite.ElementsMatch(balanceCase.expectPlans[1], segmentPlans)
//7. balance one more time and finish this round
segmentPlans, _ = balancer.Balance()
suite.ElementsMatch(balanceCase.expectPlans[1], segmentPlans)
}
func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
cases := []struct {
name string
nodes []int64
outBoundNodes []int64
notExistedNodes []int64
collectionIDs []int64
replicaIDs []int64
collectionsSegments [][]*datapb.SegmentBinlogs
states []session.State
shouldMock bool
distributions map[int64][]*meta.Segment
distributionChannels map[int64][]*meta.DmChannel
expectPlans []SegmentAssignPlan
expectChannelPlans []ChannelAssignPlan
}{
{
name: "stopped balance for one collection",
nodes: []int64{1, 2, 3},
outBoundNodes: []int64{},
collectionIDs: []int64{1},
replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{
{
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3},
},
},
states: []session.State{session.NodeStateStopping, session.NodeStateNormal, session.NodeStateNormal},
distributions: map[int64][]*meta.Segment{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
},
expectPlans: []SegmentAssignPlan{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20},
Node: 1}, From: 1, To: 3, ReplicaID: 1, Weight: 1},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10},
Node: 1}, From: 1, To: 3, ReplicaID: 1, Weight: 1},
},
expectChannelPlans: []ChannelAssignPlan{},
},
{
name: "all nodes stopping",
nodes: []int64{1, 2, 3},
outBoundNodes: []int64{},
collectionIDs: []int64{1},
replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{
{
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3},
},
},
states: []session.State{session.NodeStateStopping, session.NodeStateStopping, session.NodeStateStopping},
distributions: map[int64][]*meta.Segment{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
},
expectPlans: []SegmentAssignPlan{},
expectChannelPlans: []ChannelAssignPlan{},
},
{
name: "all nodes outbound",
nodes: []int64{1, 2, 3},
outBoundNodes: []int64{1, 2, 3},
collectionIDs: []int64{1},
replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{
{
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3},
},
},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
distributions: map[int64][]*meta.Segment{
1: {
{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1},
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1},
},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
},
expectPlans: []SegmentAssignPlan{},
expectChannelPlans: []ChannelAssignPlan{},
},
}
for i, c := range cases {
suite.Run(c.name, func() {
if i == 0 {
suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0)
}
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
//1. set up target for multi collections
collections := make([]*meta.Collection, 0, len(c.collectionIDs))
for i := range c.collectionIDs {
collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i]))
collections = append(collections, collection)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, c.collectionIDs[i], c.replicaIDs[i]).Return(
nil, c.collectionsSegments[i], nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionIDs[i]).Return([]int64{c.collectionIDs[i]}, nil).Maybe()
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i])
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaIDs[i], c.collectionIDs[i],
append(c.nodes, c.notExistedNodes...)))
}
//2. set up segment distribution for multiple collections
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for node, v := range c.distributionChannels {
balancer.dist.ChannelDistManager.Update(node, v...)
}
//3. set up nodes info and resourceManager for balancer
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, c.nodes[i])
}
for i := range c.outBoundNodes {
suite.balancer.meta.ResourceManager.UnassignNode(meta.DefaultResourceGroupName, c.outBoundNodes[i])
}
//4. balance and verify result
segmentPlans, channelPlans := balancer.Balance()
suite.ElementsMatch(c.expectChannelPlans, channelPlans)
suite.ElementsMatch(c.expectPlans, segmentPlans)
})
}
}
func TestScoreBasedBalancerSuite(t *testing.T) {
suite.Run(t, new(ScoreBasedBalancerTestSuite))
}

View File

@ -18,13 +18,19 @@ package balance
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"go.uber.org/zap"
)
const (
InfoPrefix = "Balance-Info:"
)
func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout time.Duration, plans []SegmentAssignPlan) []task.Task {
ret := make([]task.Task, 0)
for _, p := range plans {
@ -105,3 +111,75 @@ func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout t
}
return ret
}
func PrintNewBalancePlans(collectionID int64, replicaID int64, segmentPlans []SegmentAssignPlan,
channelPlans []ChannelAssignPlan) {
balanceInfo := fmt.Sprintf("%s{collectionID:%d, replicaID:%d, ", InfoPrefix, collectionID, replicaID)
for _, segmentPlan := range segmentPlans {
balanceInfo += segmentPlan.ToString()
}
for _, channelPlan := range channelPlans {
balanceInfo += channelPlan.ToString()
}
balanceInfo += "}"
log.Info(balanceInfo)
}
func PrintCurrentReplicaDist(replica *meta.Replica,
stoppingNodesSegments map[int64][]*meta.Segment, nodeSegments map[int64][]*meta.Segment,
channelManager *meta.ChannelDistManager) {
distInfo := fmt.Sprintf("%s {collectionID:%d, replicaID:%d, ", InfoPrefix, replica.CollectionID, replica.GetID())
//1. print stopping nodes segment distribution
distInfo += "[stoppingNodesSegmentDist:"
for stoppingNodeID, stoppedSegments := range stoppingNodesSegments {
distInfo += fmt.Sprintf("[nodeID:%d, ", stoppingNodeID)
distInfo += "stopped-segments:["
for _, stoppedSegment := range stoppedSegments {
distInfo += fmt.Sprintf("%d,", stoppedSegment.GetID())
}
distInfo += "]]"
}
distInfo += "]\n"
//2. print normal nodes segment distribution
distInfo += "[normalNodesSegmentDist:"
for normalNodeID, normalNodeSegments := range nodeSegments {
distInfo += fmt.Sprintf("[nodeID:%d, ", normalNodeID)
distInfo += "loaded-segments:["
nodeRowSum := int64(0)
for _, normalSegment := range normalNodeSegments {
distInfo += fmt.Sprintf("[segmentID: %d, rowCount: %d] ",
normalSegment.GetID(), normalSegment.GetNumOfRows())
nodeRowSum += normalSegment.GetNumOfRows()
}
distInfo += fmt.Sprintf("] nodeRowSum:%d]", nodeRowSum)
}
distInfo += "]\n"
//3. print stopping nodes channel distribution
distInfo += "[stoppingNodesChannelDist:"
for stoppingNodeID := range stoppingNodesSegments {
stoppingNodeChannels := channelManager.GetByNode(stoppingNodeID)
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels))
distInfo += "channels:["
for _, stoppingChan := range stoppingNodeChannels {
distInfo += fmt.Sprintf("%s,", stoppingChan.GetChannelName())
}
distInfo += "]]"
}
distInfo += "]\n"
//4. print normal nodes channel distribution
distInfo += "[normalNodesChannelDist:"
for normalNodeID := range nodeSegments {
normalNodeChannels := channelManager.GetByNode(normalNodeID)
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels))
distInfo += "channels:["
for _, normalNodeChan := range normalNodeChannels {
distInfo += fmt.Sprintf("%s,", normalNodeChan.GetChannelName())
}
distInfo += "]]"
}
distInfo += "]\n"
log.Info(distInfo)
}

View File

@ -54,13 +54,14 @@ func NewCheckerController(
dist *meta.DistributionManager,
targetMgr *meta.TargetManager,
balancer balance.Balance,
nodeMgr *session.NodeManager,
scheduler task.Scheduler) *CheckerController {
// CheckerController runs checkers with the order,
// the former checker has higher priority
checkers := []Checker{
NewChannelChecker(meta, dist, targetMgr, balancer),
NewSegmentChecker(meta, dist, targetMgr, balancer),
NewSegmentChecker(meta, dist, targetMgr, balancer, nodeMgr),
NewBalanceChecker(balancer),
}
for i, checker := range checkers {

View File

@ -26,9 +26,11 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
"go.uber.org/zap"
)
@ -39,6 +41,7 @@ type SegmentChecker struct {
dist *meta.DistributionManager
targetMgr *meta.TargetManager
balancer balance.Balance
nodeMgr *session.NodeManager
}
func NewSegmentChecker(
@ -46,12 +49,14 @@ func NewSegmentChecker(
dist *meta.DistributionManager,
targetMgr *meta.TargetManager,
balancer balance.Balance,
nodeMgr *session.NodeManager,
) *SegmentChecker {
return &SegmentChecker{
meta: meta,
dist: dist,
targetMgr: targetMgr,
balancer: balancer,
nodeMgr: nodeMgr,
}
}
@ -274,9 +279,13 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments []
}
outboundNodes := c.meta.ResourceManager.CheckOutboundNodes(replica)
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool {
return !outboundNodes.Contain(node)
stop, err := c.nodeMgr.IsStoppingNode(node)
if err != nil {
return false
}
return !outboundNodes.Contain(node) && !stop
})
plans := c.balancer.AssignSegment(packedSegments, availableNodes)
plans := c.balancer.AssignSegment(replica.CollectionID, packedSegments, availableNodes)
for i := range plans {
plans[i].ReplicaID = replica.GetID()
}

View File

@ -73,7 +73,7 @@ func (suite *SegmentCheckerTestSuite) SetupTest() {
targetManager := meta.NewTargetManager(suite.broker, suite.meta)
balancer := suite.createMockBalancer()
suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer)
suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer, suite.nodeMgr)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
}
@ -84,7 +84,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() {
func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance {
balancer := balance.NewMockBalancer(suite.T())
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything).Maybe().Return(func(segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan {
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan {
plans := make([]balance.SegmentAssignPlan, 0, len(segments))
for i, s := range segments {
plan := balance.SegmentAssignPlan{

View File

@ -98,7 +98,11 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
if dstNodeSet.Len() == 0 {
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool {
return !outboundNodes.Contain(node)
stop, err := s.nodeMgr.IsStoppingNode(node)
if err != nil {
return false
}
return !outboundNodes.Contain(node) && !stop
})
dstNodeSet.Insert(availableNodes...)
}
@ -132,7 +136,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
zap.Int64("srcNodeID", srcNode),
zap.Int64s("destNodeIDs", dstNodeSet.Collect()),
)
plans := s.balancer.AssignSegment(toBalance.Collect(), dstNodeSet.Collect())
plans := s.balancer.AssignSegment(req.GetCollectionID(), toBalance.Collect(), dstNodeSet.Collect())
tasks := make([]task.Task, 0, len(plans))
for _, plan := range plans {
log.Info("manually balance segment...",

View File

@ -109,7 +109,8 @@ type Server struct {
replicaObserver *observers.ReplicaObserver
resourceObserver *observers.ResourceObserver
balancer balance.Balance
balancer balance.Balance
balancerMap map[string]balance.Balance
// Active-standby
enableActiveStandBy bool
@ -249,15 +250,21 @@ func (s *Server) initQueryCoord() error {
s.taskScheduler,
)
// Init balancer
log.Info("init balancer")
s.balancer = balance.NewRowCountBasedBalancer(
s.taskScheduler,
s.nodeMgr,
s.dist,
s.meta,
s.targetMgr,
)
// Init balancer map and balancer
log.Info("init all available balancer")
s.balancerMap = make(map[string]balance.Balance)
s.balancerMap[balance.RoundRobinBalancerName] = balance.NewRoundRobinBalancer(s.taskScheduler, s.nodeMgr)
s.balancerMap[balance.RowCountBasedBalancerName] = balance.NewRowCountBasedBalancer(s.taskScheduler,
s.nodeMgr, s.dist, s.meta, s.targetMgr)
s.balancerMap[balance.ScoreBasedBalancerName] = balance.NewScoreBasedBalancer(s.taskScheduler,
s.nodeMgr, s.dist, s.meta, s.targetMgr)
if balancer, ok := s.balancerMap[params.Params.QueryCoordCfg.Balancer.GetValue()]; ok {
s.balancer = balancer
log.Info("use config balancer", zap.String("balancer", params.Params.QueryCoordCfg.Balancer.GetValue()))
} else {
s.balancer = s.balancerMap[balance.RowCountBasedBalancerName]
log.Info("use rowCountBased auto balancer")
}
// Init checker controller
log.Info("init checker controller")
@ -266,6 +273,7 @@ func (s *Server) initQueryCoord() error {
s.dist,
s.targetMgr,
s.balancer,
s.nodeMgr,
s.taskScheduler,
)

View File

@ -475,6 +475,7 @@ func (suite *ServerSuite) hackServer() {
suite.server.dist,
suite.server.targetMgr,
suite.server.balancer,
suite.server.nodeMgr,
suite.server.taskScheduler,
)
suite.server.targetObserver = observers.NewTargetObserver(

View File

@ -13,6 +13,14 @@ type MockScheduler struct {
mock.Mock
}
func (_m *MockScheduler) GetChannelTaskNum() int {
return 0
}
func (_m *MockScheduler) GetSegmentTaskNum() int {
return 0
}
type MockScheduler_Expecter struct {
mock *mock.Mock
}
@ -41,7 +49,7 @@ type MockScheduler_Add_Call struct {
}
// Add is a helper method to define mock.On call
// - task Task
// - task Task
func (_e *MockScheduler_Expecter) Add(task interface{}) *MockScheduler_Add_Call {
return &MockScheduler_Add_Call{Call: _e.mock.On("Add", task)}
}
@ -69,7 +77,7 @@ type MockScheduler_AddExecutor_Call struct {
}
// AddExecutor is a helper method to define mock.On call
// - nodeID int64
// - nodeID int64
func (_e *MockScheduler_Expecter) AddExecutor(nodeID interface{}) *MockScheduler_AddExecutor_Call {
return &MockScheduler_AddExecutor_Call{Call: _e.mock.On("AddExecutor", nodeID)}
}
@ -97,7 +105,7 @@ type MockScheduler_Dispatch_Call struct {
}
// Dispatch is a helper method to define mock.On call
// - node int64
// - node int64
func (_e *MockScheduler_Expecter) Dispatch(node interface{}) *MockScheduler_Dispatch_Call {
return &MockScheduler_Dispatch_Call{Call: _e.mock.On("Dispatch", node)}
}
@ -134,7 +142,7 @@ type MockScheduler_GetNodeChannelDelta_Call struct {
}
// GetNodeChannelDelta is a helper method to define mock.On call
// - nodeID int64
// - nodeID int64
func (_e *MockScheduler_Expecter) GetNodeChannelDelta(nodeID interface{}) *MockScheduler_GetNodeChannelDelta_Call {
return &MockScheduler_GetNodeChannelDelta_Call{Call: _e.mock.On("GetNodeChannelDelta", nodeID)}
}
@ -171,7 +179,7 @@ type MockScheduler_GetNodeSegmentDelta_Call struct {
}
// GetNodeSegmentDelta is a helper method to define mock.On call
// - nodeID int64
// - nodeID int64
func (_e *MockScheduler_Expecter) GetNodeSegmentDelta(nodeID interface{}) *MockScheduler_GetNodeSegmentDelta_Call {
return &MockScheduler_GetNodeSegmentDelta_Call{Call: _e.mock.On("GetNodeSegmentDelta", nodeID)}
}
@ -199,7 +207,7 @@ type MockScheduler_RemoveByNode_Call struct {
}
// RemoveByNode is a helper method to define mock.On call
// - node int64
// - node int64
func (_e *MockScheduler_Expecter) RemoveByNode(node interface{}) *MockScheduler_RemoveByNode_Call {
return &MockScheduler_RemoveByNode_Call{Call: _e.mock.On("RemoveByNode", node)}
}
@ -227,7 +235,7 @@ type MockScheduler_RemoveExecutor_Call struct {
}
// RemoveExecutor is a helper method to define mock.On call
// - nodeID int64
// - nodeID int64
func (_e *MockScheduler_Expecter) RemoveExecutor(nodeID interface{}) *MockScheduler_RemoveExecutor_Call {
return &MockScheduler_RemoveExecutor_Call{Call: _e.mock.On("RemoveExecutor", nodeID)}
}
@ -255,7 +263,7 @@ type MockScheduler_Start_Call struct {
}
// Start is a helper method to define mock.On call
// - ctx context.Context
// - ctx context.Context
func (_e *MockScheduler_Expecter) Start(ctx interface{}) *MockScheduler_Start_Call {
return &MockScheduler_Start_Call{Call: _e.mock.On("Start", ctx)}
}

View File

@ -120,6 +120,8 @@ type Scheduler interface {
RemoveByNode(node int64)
GetNodeSegmentDelta(nodeID int64) int
GetNodeChannelDelta(nodeID int64) int
GetChannelTaskNum() int
GetSegmentTaskNum() int
}
type taskScheduler struct {
@ -292,7 +294,6 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
return merr.WrapErrServiceInternal("task with the same channel exists")
}
if GetTaskType(task) == TaskTypeGrow {
nodesWithChannel := scheduler.distMgr.LeaderViewManager.GetChannelDist(task.Channel())
replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel)
@ -300,11 +301,9 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
return merr.WrapErrServiceInternal("channel subscribed, it can be only balanced")
}
}
default:
panic(fmt.Sprintf("preAdd: forget to process task type: %+v", task))
}
return nil
}
@ -386,6 +385,20 @@ func (scheduler *taskScheduler) GetNodeChannelDelta(nodeID int64) int {
return calculateNodeDelta(nodeID, scheduler.channelTasks)
}
func (scheduler *taskScheduler) GetChannelTaskNum() int {
scheduler.rwmutex.RLock()
defer scheduler.rwmutex.RUnlock()
return len(scheduler.channelTasks)
}
func (scheduler *taskScheduler) GetSegmentTaskNum() int {
scheduler.rwmutex.RLock()
defer scheduler.rwmutex.RUnlock()
return len(scheduler.segmentTasks)
}
func calculateNodeDelta[K comparable, T ~map[K]Task](nodeID int64, tasks T) int {
delta := 0
for _, task := range tasks {

View File

@ -435,7 +435,6 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
// LoadSegments load historical data into query node, historical data can be vector data or index
func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.session.ServerID
log.Info("wayblink", zap.Int64("nodeID", nodeID))
// check node healthy
if !node.lifetime.Add(commonpbutil.IsHealthy) {
err := fmt.Errorf("query node %d is not ready", nodeID)

View File

@ -1084,6 +1084,9 @@ type queryCoordConfig struct {
//---- Balance ---
AutoBalance ParamItem `refreshable:"true"`
Balancer ParamItem `refreshable:"true"`
GlobalRowCountFactor ParamItem `refreshable:"true"`
ScoreUnbalanceTolerationFactor ParamItem `refreshable:"true"`
OverloadedMemoryThresholdPercentage ParamItem `refreshable:"true"`
BalanceIntervalSeconds ParamItem `refreshable:"true"`
MemoryUsageMaxDifferencePercentage ParamItem `refreshable:"true"`
@ -1149,13 +1152,43 @@ func (p *queryCoordConfig) init(base *BaseTable) {
p.AutoBalance = ParamItem{
Key: "queryCoord.autoBalance",
Version: "2.0.0",
DefaultValue: "tru",
DefaultValue: "true",
PanicIfEmpty: true,
Doc: "Enable auto balance",
Export: true,
}
p.AutoBalance.Init(base.mgr)
p.Balancer = ParamItem{
Key: "queryCoord.balancer",
Version: "2.0.0",
DefaultValue: "RowCountBasedBalancer",
PanicIfEmpty: true,
Doc: "auto balancer used for segments on queryNodes",
Export: true,
}
p.Balancer.Init(base.mgr)
p.GlobalRowCountFactor = ParamItem{
Key: "queryCoord.globalRowCountFactor",
Version: "2.0.0",
DefaultValue: "0.1",
PanicIfEmpty: true,
Doc: "the weight used when balancing segments among queryNodes",
Export: true,
}
p.GlobalRowCountFactor.Init(base.mgr)
p.ScoreUnbalanceTolerationFactor = ParamItem{
Key: "queryCoord.scoreUnbalanceTolerationFactor",
Version: "2.0.0",
DefaultValue: "1.3",
PanicIfEmpty: true,
Doc: "the largest value for unbalanced extent between from and to nodes when doing balance",
Export: true,
}
p.ScoreUnbalanceTolerationFactor.Init(base.mgr)
p.OverloadedMemoryThresholdPercentage = ParamItem{
Key: "queryCoord.overloadedMemoryThresholdPercentage",
Version: "2.0.0",

View File

@ -91,6 +91,10 @@ func (set Set[T]) Remove(elements ...T) {
}
}
func (set Set[T]) Clear() {
set.Remove(set.Collect()...)
}
// Get all elements in the set
func (set Set[T]) Collect() []T {
elements := make([]T, 0, len(set))

View File

@ -36,3 +36,18 @@ func TestUniqueSet(t *testing.T) {
assert.True(t, set.Contain(9))
assert.False(t, set.Contain(5, 7, 9))
}
func TestUniqueSetClear(t *testing.T) {
set := make(UniqueSet)
set.Insert(5, 7, 9)
assert.True(t, set.Contain(5))
assert.True(t, set.Contain(7))
assert.True(t, set.Contain(9))
assert.Equal(t, 3, set.Len())
set.Clear()
assert.False(t, set.Contain(5))
assert.False(t, set.Contain(7))
assert.False(t, set.Contain(9))
assert.Equal(t, 0, set.Len())
}