Postpone the execution of handoff until index creation is complete (#11648)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/11648/merge
xige-16 2021-11-12 18:49:10 +08:00 committed by GitHub
parent 93149c5ad9
commit d857577a7b
12 changed files with 434 additions and 41 deletions

View File

@ -74,6 +74,10 @@ queryCoord:
address: localhost
port: 19531
autoHandoff: true
autoBalance: false
overloadedMemoryThresholdPercentage: 90
balanceIntervalSeconds: 60
memoryUsageMaxDifferencePercentage: 30
grpc:
serverMaxRecvSize: 2147483647 # math.MaxInt32

View File

@ -25,6 +25,7 @@ import (
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
@ -59,6 +60,7 @@ type Cluster interface {
releasePartitions(ctx context.Context, nodeID int64, in *querypb.ReleasePartitionsRequest) error
getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
getSegmentInfoByNode(ctx context.Context, nodeID int64, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error)
registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error
getNodeInfoByID(nodeID int64) (Node, error)
@ -69,7 +71,7 @@ type Cluster interface {
offlineNodes() (map[int64]Node, error)
hasNode(nodeID int64) bool
allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error
allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error
allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error
getSessionVersion() int64
@ -381,6 +383,37 @@ func (c *queryNodeCluster) releasePartitions(ctx context.Context, nodeID int64,
return fmt.Errorf("ReleasePartitions: can't find query node by nodeID, nodeID = %d", nodeID)
}
func (c *queryNodeCluster) getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error) {
c.RLock()
defer c.RUnlock()
segmentInfo, err := c.clusterMeta.getSegmentInfoByID(segmentID)
if err != nil {
return nil, err
}
if node, ok := c.nodes[segmentInfo.NodeID]; ok {
res, err := node.getSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: segmentInfo.CollectionID,
})
if err != nil {
return nil, err
}
if res != nil {
for _, info := range res.Infos {
if info.SegmentID == segmentID {
return info, nil
}
}
}
return nil, fmt.Errorf("updateSegmentInfo: can't find segment %d on query node %d", segmentID, segmentInfo.NodeID)
}
return nil, fmt.Errorf("updateSegmentInfo: can't find query node by nodeID, nodeID = %d", segmentInfo.NodeID)
}
func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
c.RLock()
defer c.RUnlock()
@ -650,8 +683,8 @@ func (c *queryNodeCluster) getCollectionInfosByID(ctx context.Context, nodeID in
return nil
}
func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error {
return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs)
func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs, includeNodeIDs)
}
func (c *queryNodeCluster) allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error {

View File

@ -32,9 +32,11 @@ import (
)
const (
defaultTotalmemPerNode = 6000000000
defaultTotalmemPerNode = 6000000
)
var GlobalSegmentInfos = make(map[UniqueID]*querypb.SegmentInfo)
type queryNodeServerMock struct {
querypb.QueryNodeServer
ctx context.Context
@ -58,9 +60,9 @@ type queryNodeServerMock struct {
getSegmentInfos func() (*querypb.GetSegmentInfoResponse, error)
getMetrics func() (*milvuspb.GetMetricsResponse, error)
totalMem uint64
memUsage uint64
memUsageRate float64
segmentInfos map[UniqueID]*querypb.SegmentInfo
totalMem uint64
}
func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
@ -81,9 +83,9 @@ func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
getSegmentInfos: returnSuccessGetSegmentInfoResult,
getMetrics: returnSuccessGetMetricsResult,
totalMem: defaultTotalmemPerNode,
memUsage: uint64(0),
memUsageRate: float64(0),
segmentInfos: GlobalSegmentInfos,
totalMem: defaultTotalmemPerNode,
}
}
@ -194,12 +196,19 @@ func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.Lo
if err != nil {
return returnFailedResult()
}
totalNumRow := int64(0)
for _, info := range req.Infos {
totalNumRow += info.NumOfRows
segmentInfo := &querypb.SegmentInfo{
SegmentID: info.SegmentID,
PartitionID: info.PartitionID,
CollectionID: info.CollectionID,
NodeID: qs.queryNodeID,
SegmentState: querypb.SegmentState_sealed,
MemSize: info.NumOfRows * int64(sizePerRecord),
NumRows: info.NumOfRows,
}
qs.segmentInfos[info.SegmentID] = segmentInfo
}
qs.memUsage += uint64(totalNumRow) * uint64(sizePerRecord)
qs.memUsageRate = float64(qs.memUsage) / float64(qs.totalMem)
return qs.loadSegment()
}
@ -215,8 +224,19 @@ func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb
return qs.releaseSegments()
}
func (qs *queryNodeServerMock) GetSegmentInfo(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return qs.getSegmentInfos()
func (qs *queryNodeServerMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
segmentInfos := make([]*querypb.SegmentInfo, 0)
for _, info := range qs.segmentInfos {
if info.CollectionID == req.CollectionID && info.NodeID == qs.queryNodeID {
segmentInfos = append(segmentInfos, info)
}
}
res, err := qs.getSegmentInfos()
if err == nil {
res.Infos = segmentInfos
}
return res, err
}
func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
@ -227,13 +247,20 @@ func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.Get
if response.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, errors.New("query node do task failed")
}
totalMemUsage := uint64(0)
for _, info := range qs.segmentInfos {
if info.NodeID == qs.queryNodeID {
totalMemUsage += uint64(info.MemSize)
}
}
nodeInfos := metricsinfo.QueryNodeInfos{
BaseComponentInfos: metricsinfo.BaseComponentInfos{
Name: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, qs.queryNodeID),
HardwareInfos: metricsinfo.HardwareMetrics{
IP: qs.queryNodeIP,
Memory: qs.totalMem,
MemoryUsage: qs.memUsage,
MemoryUsage: totalMemUsage,
},
Type: typeutil.QueryNodeRole,
ID: qs.queryNodeID,

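Note on the mock above: the reported memory usage is now derived from the segments the mock has loaded rather than from a separately tracked memUsage counter, and defaultTotalmemPerNode is much smaller, presumably so a test node can actually cross the overload threshold. A minimal standalone sketch of that accounting (sizePerRecord is a hypothetical stand-in for the value the real code derives from the collection schema):

package main

import "fmt"

// segmentInfo mirrors only the fields this illustration needs.
type segmentInfo struct {
	SegmentID int64
	NumRows   int64
	MemSize   int64
}

func main() {
	const totalMem = 6000000   // defaultTotalmemPerNode in the mock
	const sizePerRecord = 1000 // hypothetical; the real value comes from the collection schema

	loaded := map[int64]*segmentInfo{}
	// "Load" two segments of 1000 rows each, as LoadSegments in the mock would.
	for id := int64(1); id <= 2; id++ {
		loaded[id] = &segmentInfo{SegmentID: id, NumRows: 1000, MemSize: 1000 * sizePerRecord}
	}

	// GetMetrics in the mock sums MemSize over the segments the node hosts.
	var totalMemUsage uint64
	for _, info := range loaded {
		totalMemUsage += uint64(info.MemSize)
	}
	fmt.Printf("usage=%d rate=%.2f\n", totalMemUsage, float64(totalMemUsage)/float64(totalMem))
}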
View File

@ -70,6 +70,12 @@ type ParamTable struct {
//---- Handoff ---
AutoHandoff bool
//---- Balance ---
AutoBalance bool
OverloadedMemoryThresholdPercentage float64
BalanceIntervalSeconds int64
MemoryUsageMaxDifferencePercentage float64
}
// Params are variables of the ParamTable type
@ -117,6 +123,12 @@ func (p *ParamTable) Init() {
p.initDmlChannelName()
p.initDeltaChannelName()
//---- Balance ---
p.initAutoBalance()
p.initOverloadedMemoryThresholdPercentage()
p.initBalanceIntervalSeconds()
p.initMemoryUsageMaxDifferencePercentage()
}
func (p *ParamTable) initQueryCoordAddress() {
@ -271,6 +283,42 @@ func (p *ParamTable) initAutoHandoff() {
}
}
func (p *ParamTable) initAutoBalance() {
balanceStr := p.LoadWithDefault("queryCoord.autoBalance", "false")
autoBalance, err := strconv.ParseBool(balanceStr)
if err != nil {
panic(err)
}
p.AutoBalance = autoBalance
}
func (p *ParamTable) initOverloadedMemoryThresholdPercentage() {
overloadedMemoryThresholdPercentage := p.LoadWithDefault("queryCoord.overloadedMemoryThresholdPercentage", "90")
thresholdPercentage, err := strconv.ParseInt(overloadedMemoryThresholdPercentage, 10, 64)
if err != nil {
panic(err)
}
p.OverloadedMemoryThresholdPercentage = float64(thresholdPercentage) / 100
}
func (p *ParamTable) initBalanceIntervalSeconds() {
balanceInterval := p.LoadWithDefault("queryCoord.balanceIntervalSeconds", "60")
interval, err := strconv.ParseInt(balanceInterval, 10, 64)
if err != nil {
panic(err)
}
p.BalanceIntervalSeconds = interval
}
func (p *ParamTable) initMemoryUsageMaxDifferencePercentage() {
maxDiff := p.LoadWithDefault("queryCoord.memoryUsageMaxDifferencePercentage", "30")
diffPercentage, err := strconv.ParseInt(maxDiff, 10, 64)
if err != nil {
panic(err)
}
p.MemoryUsageMaxDifferencePercentage = float64(diffPercentage) / 100
}
func (p *ParamTable) initDmlChannelName() {
config, err := p.Load("msgChannel.chanNamePrefix.rootCoordDml")
if err != nil {

View File

@ -14,6 +14,9 @@ package querycoord
import (
"context"
"errors"
"math"
"sort"
"fmt"
"math/rand"
"strconv"
@ -184,6 +187,9 @@ func (qc *QueryCoord) Start() error {
qc.loopWg.Add(1)
go qc.watchHandoffSegmentLoop()
qc.loopWg.Add(1)
go qc.loadBalanceSegmentLoop()
go qc.session.LivenessCheck(qc.loopCtx, func() {
log.Error("Query Coord disconnected from etcd, process will exit", zap.Int64("Server Id", qc.session.ServerID))
if err := qc.Stop(); err != nil {
@ -563,3 +569,179 @@ func (qc *QueryCoord) processHandoffAfterIndexDone(ctx context.Context, indexedC
}
}
}
func (qc *QueryCoord) loadBalanceSegmentLoop() {
ctx, cancel := context.WithCancel(qc.loopCtx)
defer cancel()
defer qc.loopWg.Done()
log.Debug("query coordinator start load balance segment loop")
timer := time.NewTicker(time.Duration(Params.BalanceIntervalSeconds) * time.Second)
for {
select {
case <-ctx.Done():
return
case <-timer.C:
onlineNodes, err := qc.cluster.onlineNodes()
if err != nil {
log.Warn("loadBalanceSegmentLoop: there are no online query node to balance")
continue
}
// get mem info of online nodes from cluster
nodeID2MemUsageRate := make(map[int64]float64)
nodeID2MemUsage := make(map[int64]uint64)
nodeID2TotalMem := make(map[int64]uint64)
nodeID2SegmentInfos := make(map[int64]map[UniqueID]*querypb.SegmentInfo)
onlineNodeIDs := make([]int64, 0)
for nodeID := range onlineNodes {
nodeInfo, err := qc.cluster.getNodeInfoByID(nodeID)
if err != nil {
log.Warn("loadBalanceSegmentLoop: get node info from query node failed", zap.Int64("nodeID", nodeID), zap.Error(err))
delete(onlineNodes, nodeID)
continue
}
updateSegmentInfoDone := true
leastSegmentInfos := make(map[UniqueID]*querypb.SegmentInfo)
segmentInfos := qc.meta.getSegmentInfosByNode(nodeID)
for _, segmentInfo := range segmentInfos {
leastInfo, err := qc.cluster.getSegmentInfoByID(ctx, segmentInfo.SegmentID)
if err != nil {
log.Warn("loadBalanceSegmentLoop: get segment info from query node failed", zap.Int64("nodeID", nodeID), zap.Error(err))
delete(onlineNodes, nodeID)
updateSegmentInfoDone = false
break
}
leastSegmentInfos[segmentInfo.SegmentID] = leastInfo
}
if updateSegmentInfoDone {
nodeID2MemUsageRate[nodeID] = nodeInfo.(*queryNode).memUsageRate
nodeID2MemUsage[nodeID] = nodeInfo.(*queryNode).memUsage
nodeID2TotalMem[nodeID] = nodeInfo.(*queryNode).totalMem
onlineNodeIDs = append(onlineNodeIDs, nodeID)
nodeID2SegmentInfos[nodeID] = leastSegmentInfos
}
}
log.Debug("loadBalanceSegmentLoop: memory usage rage of all online query node", zap.Any("mem rate", nodeID2MemUsageRate))
if len(onlineNodeIDs) <= 1 {
log.Warn("loadBalanceSegmentLoop: there are too few online query nodes to balance", zap.Int64s("onlineNodeIDs", onlineNodeIDs))
continue
}
// check which nodes need balance and determine which segments on these nodes need to be migrated to other nodes
memoryInsufficient := false
loadBalanceTasks := make([]*loadBalanceTask, 0)
for {
var selectedSegmentInfo *querypb.SegmentInfo = nil
sort.Slice(onlineNodeIDs, func(i, j int) bool {
return nodeID2MemUsageRate[onlineNodeIDs[i]] > nodeID2MemUsageRate[onlineNodeIDs[j]]
})
// the sourceNode has the highest memoryUsageRate among the online query nodes
sourceNodeID := onlineNodeIDs[0]
dstNodeID := onlineNodeIDs[len(onlineNodeIDs)-1]
memUsageRateDiff := nodeID2MemUsageRate[sourceNodeID] - nodeID2MemUsageRate[dstNodeID]
// if the memoryUsageRate of the source node is greater than 90%, or the max memUsageRate diff is greater than 30%,
// then migrate segments from the source node to other query nodes
if nodeID2MemUsageRate[sourceNodeID] > Params.OverloadedMemoryThresholdPercentage ||
memUsageRateDiff > Params.MemoryUsageMaxDifferencePercentage {
segmentInfos := nodeID2SegmentInfos[sourceNodeID]
// select a segment on the source node that needs to be balanced
selectedSegmentInfo, err = chooseSegmentToBalance(sourceNodeID, dstNodeID, segmentInfos, nodeID2MemUsage, nodeID2TotalMem, nodeID2MemUsageRate)
if err == nil && selectedSegmentInfo != nil {
req := &querypb.LoadBalanceRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
},
BalanceReason: querypb.TriggerCondition_loadBalance,
SourceNodeIDs: []UniqueID{sourceNodeID},
DstNodeIDs: []UniqueID{dstNodeID},
SealedSegmentIDs: []UniqueID{selectedSegmentInfo.SegmentID},
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_loadBalance)
balanceTask := &loadBalanceTask{
baseTask: baseTask,
LoadBalanceRequest: req,
rootCoord: qc.rootCoordClient,
dataCoord: qc.dataCoordClient,
cluster: qc.cluster,
meta: qc.meta,
}
loadBalanceTasks = append(loadBalanceTasks, balanceTask)
nodeID2MemUsage[sourceNodeID] -= uint64(selectedSegmentInfo.MemSize)
nodeID2MemUsage[dstNodeID] += uint64(selectedSegmentInfo.MemSize)
nodeID2MemUsageRate[sourceNodeID] = float64(nodeID2MemUsage[sourceNodeID]) / float64(nodeID2TotalMem[sourceNodeID])
nodeID2MemUsageRate[dstNodeID] = float64(nodeID2MemUsage[dstNodeID]) / float64(nodeID2TotalMem[dstNodeID])
delete(nodeID2SegmentInfos[sourceNodeID], selectedSegmentInfo.SegmentID)
nodeID2SegmentInfos[dstNodeID][selectedSegmentInfo.SegmentID] = selectedSegmentInfo
continue
}
}
if err != nil {
// not enough memory on the query nodes to balance; notify the proxy to stop inserting
memoryInsufficient = true
}
// if memoryInsufficient == false,
// every query node's memoryUsageRate is less than 90% and the max memUsageRate diff is less than 30%,
// so this balance loop is done
break
}
if !memoryInsufficient {
for _, t := range loadBalanceTasks {
qc.scheduler.Enqueue(t)
log.Debug("loadBalanceSegmentLoop: enqueue a loadBalance task", zap.Any("task", t))
err = t.waitToFinish()
if err != nil {
// if failed, wait for next balance loop
// it may be that the collection/partition of the balanced segment has been released
// or some other abnormal error may have occurred
log.Error("loadBalanceSegmentLoop: balance task execute failed", zap.Any("task", t))
} else {
log.Debug("loadBalanceSegmentLoop: balance task execute success", zap.Any("task", t))
}
}
log.Debug("loadBalanceSegmentLoop: load balance Done in this loop", zap.Any("tasks", loadBalanceTasks))
} else {
// not enough memory on the query nodes to balance; notify the proxy to stop inserting
//TODO:: xige-16
log.Error("loadBalanceSegmentLoop: query node has insufficient memory, stop inserting data")
}
}
}
}
func chooseSegmentToBalance(sourceNodeID int64, dstNodeID int64,
segmentInfos map[UniqueID]*querypb.SegmentInfo,
nodeID2MemUsage map[int64]uint64,
nodeID2TotalMem map[int64]uint64,
nodeID2MemUsageRate map[int64]float64) (*querypb.SegmentInfo, error) {
memoryInsufficient := true
minMemDiffPercentage := 1.0
var selectedSegmentInfo *querypb.SegmentInfo = nil
for _, info := range segmentInfos {
dstNodeMemUsageAfterBalance := nodeID2MemUsage[dstNodeID] + uint64(info.MemSize)
dstNodeMemUsageRateAfterBalance := float64(dstNodeMemUsageAfterBalance) / float64(nodeID2TotalMem[dstNodeID])
// if the memUsageRate of dstNode would exceed OverloadedMemoryThresholdPercentage after balance, this segment can't be moved there
if dstNodeMemUsageRateAfterBalance < Params.OverloadedMemoryThresholdPercentage {
memoryInsufficient = false
sourceNodeMemUsageAfterBalance := nodeID2MemUsage[sourceNodeID] - uint64(info.MemSize)
sourceNodeMemUsageRateAfterBalance := float64(sourceNodeMemUsageAfterBalance) / float64(nodeID2TotalMem[sourceNodeID])
// assume all query nodes have the same memory capacity
// if the memUsageRate diff between the two nodes does not become smaller after balance, there is no need to balance
diffBeforBalance := nodeID2MemUsageRate[sourceNodeID] - nodeID2MemUsageRate[dstNodeID]
diffAfterBalance := dstNodeMemUsageRateAfterBalance - sourceNodeMemUsageRateAfterBalance
if diffAfterBalance < diffBeforBalance {
if math.Abs(diffAfterBalance) < minMemDiffPercentage {
selectedSegmentInfo = info
}
}
}
}
if memoryInsufficient {
return nil, errors.New("all query nodes has insufficient memory")
}
return selectedSegmentInfo, nil
}
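To make the selection criteria concrete, here is a hypothetical, simplified sketch of the same idea: a segment is only a candidate if the destination stays under the overload threshold, and only moves that shrink the memory-usage-rate gap are considered; this sketch additionally picks the move that leaves the smallest gap. The function name, node sizes, and that greedy tie-break are illustrative assumptions, not the exact production logic.

package main

import (
	"errors"
	"fmt"
	"math"
)

// pickSegment: choose a segment on the source node whose move leaves the
// smallest usage-rate gap, skipping moves that would overload the destination.
func pickSegment(srcUsage, dstUsage, srcTotal, dstTotal uint64, segSizes map[int64]uint64, overloaded float64) (int64, error) {
	srcRate := float64(srcUsage) / float64(srcTotal)
	dstRate := float64(dstUsage) / float64(dstTotal)
	memoryInsufficient := true
	minGap := 1.0
	selected := int64(-1)
	for id, size := range segSizes {
		dstAfter := float64(dstUsage+size) / float64(dstTotal)
		if dstAfter >= overloaded {
			continue // destination would exceed the overload threshold
		}
		memoryInsufficient = false
		srcAfter := float64(srcUsage-size) / float64(srcTotal)
		gapAfter := dstAfter - srcAfter
		// only moves that shrink the gap are worth doing
		if gapAfter < srcRate-dstRate && math.Abs(gapAfter) < minGap {
			minGap = math.Abs(gapAfter)
			selected = id
		}
	}
	if memoryInsufficient {
		return 0, errors.New("no segment fits on the destination node")
	}
	return selected, nil // -1 means no move improves the balance
}

func main() {
	// source node at 95% memory usage, destination at 40%; both have 100 units of memory
	segs := map[int64]uint64{101: 30, 102: 10, 103: 50}
	id, err := pickSegment(95, 40, 100, 100, segs, 0.9)
	fmt.Println(id, err) // expect segment 101: moving 30 units leaves rates 0.65 vs 0.70
}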

View File

@ -43,6 +43,7 @@ func refreshParams() {
Params.MetaRootPath = Params.MetaRootPath + suffix
Params.DmlChannelPrefix = "Dml"
Params.DeltaChannelPrefix = "delta"
GlobalSegmentInfos = make(map[UniqueID]*querypb.SegmentInfo)
}
func TestMain(m *testing.M) {
@ -490,3 +491,73 @@ func TestHandoffSegmentLoop(t *testing.T) {
err = removeAllSession()
assert.Nil(t, err)
}
func TestLoadBalanceSegmentLoop(t *testing.T) {
refreshParams()
Params.BalanceIntervalSeconds = 10
baseCtx := context.Background()
queryCoord, err := startQueryCoord(baseCtx)
assert.Nil(t, err)
queryCoord.cluster.(*queryNodeCluster).segmentAllocator = shuffleSegmentsToQueryNode
queryNode1, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, queryNode1.queryNodeID)
loadCollectionTask := genLoadCollectionTask(baseCtx, queryCoord)
err = queryCoord.scheduler.Enqueue(loadCollectionTask)
assert.Nil(t, err)
waitTaskFinalState(loadCollectionTask, taskExpired)
partitionID := defaultPartitionID
for {
req := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{partitionID},
Schema: genCollectionSchema(defaultCollectionID, false),
}
baseTask := newBaseTask(baseCtx, querypb.TriggerCondition_grpcRequest)
loadPartitionTask := &loadPartitionTask{
baseTask: baseTask,
LoadPartitionsRequest: req,
dataCoord: queryCoord.dataCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
err = queryCoord.scheduler.Enqueue(loadPartitionTask)
assert.Nil(t, err)
waitTaskFinalState(loadPartitionTask, taskExpired)
nodeInfo, err := queryCoord.cluster.getNodeInfoByID(queryNode1.queryNodeID)
assert.Nil(t, err)
if nodeInfo.(*queryNode).memUsageRate >= Params.OverloadedMemoryThresholdPercentage {
break
}
partitionID++
}
queryNode2, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, queryNode2.queryNodeID)
// if a sealed segment has been balanced to query node2, then the balance works
for {
segmentInfos, err := queryCoord.cluster.getSegmentInfoByNode(baseCtx, queryNode2.queryNodeID, &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
},
CollectionID: defaultCollectionID,
})
assert.Nil(t, err)
if len(segmentInfos) > 0 {
break
}
}
queryCoord.Stop()
err = removeAllSession()
assert.Nil(t, err)
}

View File

@ -24,19 +24,17 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const MaxMemUsagePerNode = 0.9
func defaultSegAllocatePolicy() SegmentAllocatePolicy {
return shuffleSegmentsToQueryNodeV2
}
// SegmentAllocatePolicy helper function definition to allocate Segment to queryNode
type SegmentAllocatePolicy func(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error
type SegmentAllocatePolicy func(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error
// shuffleSegmentsToQueryNode shuffles segments to online nodes
// returned are node ids for each segment, which satisfy:
// len(returnedNodeIds) == len(segmentIDs) && segmentIDs[i] is assigned to returnedNodeIds[i]
func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
if len(reqs) == 0 {
return nil
}
@ -57,6 +55,10 @@ func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegment
nodeID2NumSegemnt := make(map[int64]int)
for nodeID := range availableNodes {
if len(includeNodeIDs) > 0 && !nodeIncluded(nodeID, includeNodeIDs) {
delete(availableNodes, nodeID)
continue
}
numSegments, err := cluster.getNumSegments(nodeID)
if err != nil {
delete(availableNodes, nodeID)
@ -87,7 +89,7 @@ func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegment
}
}
func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
// key = offset, value = segmentSize
if len(reqs) == 0 {
return nil
@ -118,6 +120,10 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
delete(availableNodes, id)
}
for nodeID := range availableNodes {
if len(includeNodeIDs) > 0 && !nodeIncluded(nodeID, includeNodeIDs) {
delete(availableNodes, nodeID)
continue
}
// collect statistics: nodeInfo, used memory, and memory usage rate of every query node
nodeInfo, err := cluster.getNodeInfoByID(nodeID)
if err != nil {
@ -127,7 +133,7 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
}
queryNodeInfo := nodeInfo.(*queryNode)
// avoid allocating segments to nodes whose memUsageRate is high
if queryNodeInfo.memUsageRate >= MaxMemUsagePerNode {
if queryNodeInfo.memUsageRate >= Params.OverloadedMemoryThresholdPercentage {
log.Debug("shuffleSegmentsToQueryNodeV2: queryNode memUsageRate large than MaxMemUsagePerNode", zap.Int64("nodeID", nodeID), zap.Float64("current rate", queryNodeInfo.memUsageRate))
delete(availableNodes, nodeID)
continue
@ -152,7 +158,7 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
for _, nodeID := range nodeIDSlice {
memUsageAfterLoad := memUsage[nodeID] + uint64(sizeOfReq)
memUsageRateAfterLoad := float64(memUsageAfterLoad) / float64(totalMem[nodeID])
if memUsageRateAfterLoad > MaxMemUsagePerNode {
if memUsageRateAfterLoad > Params.OverloadedMemoryThresholdPercentage {
continue
}
reqs[offset].DstNodeID = nodeID
@ -181,3 +187,13 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
}
}
}
func nodeIncluded(nodeID int64, includeNodeIDs []int64) bool {
for _, id := range includeNodeIDs {
if id == nodeID {
return true
}
}
return false
}
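The new includeNodeIDs argument is the counterpart of excludeNodeIDs: when it is non-empty, only the listed nodes remain candidates. A minimal self-contained sketch of that pruning, with hypothetical node ids; the exclude/include pairing mirrors how loadBalanceTask later passes SourceNodeIDs and DstNodeIDs:

package main

import "fmt"

func nodeIncluded(nodeID int64, includeNodeIDs []int64) bool {
	for _, id := range includeNodeIDs {
		if id == nodeID {
			return true
		}
	}
	return false
}

// filterNodes mirrors the pruning done by the shuffle policies: drop excluded
// nodes, and when includeNodeIDs is non-empty keep only the listed nodes.
func filterNodes(available map[int64]struct{}, excludeNodeIDs, includeNodeIDs []int64) {
	for _, id := range excludeNodeIDs {
		delete(available, id)
	}
	for nodeID := range available {
		if len(includeNodeIDs) > 0 && !nodeIncluded(nodeID, includeNodeIDs) {
			delete(available, nodeID)
		}
	}
}

func main() {
	available := map[int64]struct{}{1: {}, 2: {}, 3: {}}
	// a balance from node 1 to node 3 excludes the source and includes only the destination,
	// so node 3 is the only remaining candidate
	filterNodes(available, []int64{1}, []int64{3})
	fmt.Println(available) // map[3:{}]
}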

View File

@ -70,7 +70,7 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
reqs := []*querypb.LoadSegmentsRequest{firstReq, secondReq}
t.Run("Test shuffleSegmentsWithoutQueryNode", func(t *testing.T) {
err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil, nil)
assert.NotNil(t, err)
})
@ -82,7 +82,7 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
waitQueryNodeOnline(cluster, node1ID)
t.Run("Test shuffleSegmentsToQueryNode", func(t *testing.T) {
err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil, nil)
assert.Nil(t, err)
assert.Equal(t, node1ID, firstReq.DstNodeID)
@ -98,7 +98,7 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
cluster.stopNode(node1ID)
t.Run("Test shuffleSegmentsToQueryNodeV2", func(t *testing.T) {
err = shuffleSegmentsToQueryNodeV2(baseCtx, reqs, cluster, false, nil)
err = shuffleSegmentsToQueryNodeV2(baseCtx, reqs, cluster, false, nil, nil)
assert.Nil(t, err)
assert.Equal(t, node2ID, firstReq.DstNodeID)

View File

@ -453,7 +453,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
}
internalTasks, err := assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
internalTasks, err := assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil, nil)
if err != nil {
log.Warn("loadCollectionTask: assign child task failed", zap.Int64("collectionID", collectionID))
lct.setResultInfo(err)
@ -783,7 +783,7 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
}
}
internalTasks, err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil)
internalTasks, err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil, nil)
if err != nil {
log.Warn("loadPartitionTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
lpt.setResultInfo(err)
@ -1082,7 +1082,7 @@ func (lst *loadSegmentTask) reschedule(ctx context.Context) ([]task, error) {
}
lst.excludeNodeIDs = append(lst.excludeNodeIDs, lst.DstNodeID)
//TODO:: wait or not according msgType
reScheduledTasks, err := assignInternalTask(ctx, collectionID, lst.getParentTask(), lst.meta, lst.cluster, loadSegmentReqs, nil, nil, false, lst.excludeNodeIDs)
reScheduledTasks, err := assignInternalTask(ctx, collectionID, lst.getParentTask(), lst.meta, lst.cluster, loadSegmentReqs, nil, nil, false, lst.excludeNodeIDs, nil)
if err != nil {
log.Error("loadSegment reschedule failed", zap.Int64s("excludeNodes", lst.excludeNodeIDs), zap.Error(err))
return nil, err
@ -1257,7 +1257,7 @@ func (wdt *watchDmChannelTask) reschedule(ctx context.Context) ([]task, error) {
}
wdt.excludeNodeIDs = append(wdt.excludeNodeIDs, wdt.NodeID)
//TODO:: wait or not according msgType
reScheduledTasks, err := assignInternalTask(ctx, collectionID, wdt.parentTask, wdt.meta, wdt.cluster, nil, watchDmChannelReqs, nil, false, wdt.excludeNodeIDs)
reScheduledTasks, err := assignInternalTask(ctx, collectionID, wdt.parentTask, wdt.meta, wdt.cluster, nil, watchDmChannelReqs, nil, false, wdt.excludeNodeIDs, nil)
if err != nil {
log.Error("watchDmChannel reschedule failed", zap.Int64s("excludeNodes", wdt.excludeNodeIDs), zap.Error(err))
return nil, err
@ -1557,7 +1557,7 @@ func (ht *handoffTask) execute(ctx context.Context) error {
ht.setResultInfo(err)
return err
}
internalTasks, err := assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil)
internalTasks, err := assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil, nil)
if err != nil {
log.Error("handoffTask: assign child task failed", zap.Any("segmentInfo", segmentInfo))
ht.setResultInfo(err)
@ -1774,7 +1774,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
}
}
internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs)
internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs, lbt.DstNodeIDs)
if err != nil {
log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
lbt.setResultInfo(err)
@ -1925,7 +1925,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
}
// TODO:: assignInternalTask with multi collection
internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs)
internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs, lbt.DstNodeIDs)
if err != nil {
log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
lbt.setResultInfo(err)
@ -2006,11 +2006,11 @@ func assignInternalTask(ctx context.Context,
loadSegmentRequests []*querypb.LoadSegmentsRequest,
watchDmChannelRequests []*querypb.WatchDmChannelsRequest,
watchDeltaChannelRequests []*querypb.WatchDeltaChannelsRequest,
wait bool, excludeNodeIDs []int64) ([]task, error) {
wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) ([]task, error) {
sp, _ := trace.StartSpanFromContext(ctx)
defer sp.Finish()
internalTasks := make([]task, 0)
err := cluster.allocateSegmentsToQueryNode(ctx, loadSegmentRequests, wait, excludeNodeIDs)
err := cluster.allocateSegmentsToQueryNode(ctx, loadSegmentRequests, wait, excludeNodeIDs, includeNodeIDs)
if err != nil {
log.Error("assignInternalTask: assign segment to node failed", zap.Any("load segments requests", loadSegmentRequests))
return nil, err

View File

@ -694,7 +694,7 @@ func Test_AssignInternalTask(t *testing.T) {
loadSegmentRequests = append(loadSegmentRequests, req)
}
internalTasks, err := assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil)
internalTasks, err := assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil, nil)
assert.Nil(t, err)
assert.NotEqual(t, 1, len(internalTasks))

View File

@ -87,6 +87,9 @@ type ParamTable struct {
// recovery
skipQueryChannelRecovery bool
// memory limit
OverloadedMemoryThresholdPercentage float64
}
// Params is a package scoped variable of type ParamTable.
@ -146,6 +149,7 @@ func (p *ParamTable) Init() {
p.initRoleName()
p.initSkipQueryChannelRecovery()
p.initOverloadedMemoryThresholdPercentage()
}
func (p *ParamTable) initCacheSize() {
@ -346,3 +350,12 @@ func (p *ParamTable) initRoleName() {
func (p *ParamTable) initSkipQueryChannelRecovery() {
p.skipQueryChannelRecovery = p.ParseBool("msgChannel.skipQueryChannelRecovery", false)
}
func (p *ParamTable) initOverloadedMemoryThresholdPercentage() {
overloadedMemoryThresholdPercentage := p.LoadWithDefault("queryCoord.overloadedMemoryThresholdPercentage", "90")
thresholdPercentage, err := strconv.ParseInt(overloadedMemoryThresholdPercentage, 10, 64)
if err != nil {
panic(err)
}
p.OverloadedMemoryThresholdPercentage = float64(thresholdPercentage) / 100
}

View File

@ -525,7 +525,6 @@ func (loader *segmentLoader) estimateSegmentSize(segment *Segment,
}
func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentSizes map[UniqueID]int64) error {
const thresholdFactor = 0.9
usedMem, err := getUsedMemory()
if err != nil {
return err
@ -548,16 +547,16 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentSize
zap.Any("usedMem", usedMem),
zap.Any("segmentTotalSize", segmentTotalSize),
zap.Any("currentSegmentSize", size),
zap.Any("thresholdFactor", thresholdFactor),
zap.Any("thresholdFactor", Params.OverloadedMemoryThresholdPercentage),
)
if int64(usedMem)+segmentTotalSize+size > int64(float64(totalMem)*thresholdFactor) {
if int64(usedMem)+segmentTotalSize+size > int64(float64(totalMem)*Params.OverloadedMemoryThresholdPercentage) {
return errors.New(fmt.Sprintln("load segment failed, OOM if load, "+
"collectionID = ", collectionID, ", ",
"usedMem = ", usedMem, ", ",
"segmentTotalSize = ", segmentTotalSize, ", ",
"currentSegmentSize = ", size, ", ",
"totalMem = ", totalMem, ", ",
"thresholdFactor = ", thresholdFactor,
"thresholdFactor = ", Params.OverloadedMemoryThresholdPercentage,
))
}
}
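A worked example of the guard above, assuming the default threshold of 0.9; wouldOOM and the byte counts are hypothetical illustrations of the same inequality:

package main

import "fmt"

// wouldOOM: loading is rejected when the memory already used plus the segments
// being loaded would exceed threshold * totalMem.
func wouldOOM(usedMem, segmentTotalSize, currentSegmentSize, totalMem int64, threshold float64) bool {
	return usedMem+segmentTotalSize+currentSegmentSize > int64(float64(totalMem)*threshold)
}

func main() {
	const threshold = 0.9 // default OverloadedMemoryThresholdPercentage
	// 6 GiB used, 2 GiB already queued, trying to load another 1.5 GiB on a 10 GiB node:
	fmt.Println(wouldOOM(6<<30, 2<<30, 3<<29, 10<<30, threshold)) // true: 9.5 GiB > 9 GiB
}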