Add retry mechanism for NodeDown LoadBalance (#17306)

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/17315/head
yah01 2022-06-01 20:00:03 +08:00 committed by GitHub
parent 3dea300efc
commit f5bd519e49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 402 additions and 334 deletions

View File

@ -404,10 +404,14 @@ func (qc *QueryCoord) getUnallocatedNodes() []int64 {
}
func (qc *QueryCoord) handleNodeEvent(ctx context.Context) {
offlineNodeCh := make(chan UniqueID, 100)
go qc.loadBalanceNodeLoop(ctx, offlineNodeCh)
for {
select {
case <-ctx.Done():
return
case event, ok := <-qc.eventChan:
if !ok {
// ErrCompacted is handled inside SessionWatcher
@ -420,6 +424,7 @@ func (qc *QueryCoord) handleNodeEvent(ctx context.Context) {
}
return
}
switch event.EventType {
case sessionutil.SessionAddEvent:
serverID := event.Session.ServerID
@ -446,32 +451,66 @@ func (qc *QueryCoord) handleNodeEvent(ctx context.Context) {
}
qc.cluster.stopNode(serverID)
loadBalanceSegment := &querypb.LoadBalanceRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
SourceID: qc.session.ServerID,
},
SourceNodeIDs: []int64{serverID},
BalanceReason: querypb.TriggerCondition_NodeDown,
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_NodeDown)
loadBalanceTask := &loadBalanceTask{
baseTask: baseTask,
LoadBalanceRequest: loadBalanceSegment,
broker: qc.broker,
cluster: qc.cluster,
meta: qc.meta,
}
qc.metricsCacheManager.InvalidateSystemInfoMetrics()
//TODO:: deal enqueue error
qc.scheduler.Enqueue(loadBalanceTask)
log.Info("start a loadBalance task", zap.Any("task", loadBalanceTask))
offlineNodeCh <- serverID
}
}
}
}
// loadBalanceNodeLoop consumes offline node IDs from offlineNodeCh and, for
// each one, schedules a NodeDown LoadBalance task to move the node's segments
// and DmChannels off it. If enqueueing or executing the task fails, the node
// ID is pushed back onto offlineNodeCh so the balance is retried; on success
// the offline node has been fully handled (see the task's postExecute).
// The loop exits when ctx is cancelled.
//
// NOTE(review): this loop is the only consumer of offlineNodeCh, so a plain
// blocking re-send could deadlock if the buffer ever fills up; the requeue
// therefore also honors ctx cancellation.
func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context, offlineNodeCh chan UniqueID) {
	for {
		select {
		case <-ctx.Done():
			return
		case node := <-offlineNodeCh:
			loadBalanceSegment := &querypb.LoadBalanceRequest{
				Base: &commonpb.MsgBase{
					MsgType:  commonpb.MsgType_LoadBalanceSegments,
					SourceID: qc.session.ServerID,
				},
				SourceNodeIDs: []int64{node},
				BalanceReason: querypb.TriggerCondition_NodeDown,
			}
			// retryCount is set to 0 here — presumably "no internal retries",
			// since this loop performs its own retry via offlineNodeCh;
			// TODO confirm against baseTask's retry semantics.
			baseTask := newBaseTaskWithRetry(qc.loopCtx, querypb.TriggerCondition_NodeDown, 0)
			loadBalanceTask := &loadBalanceTask{
				baseTask:           baseTask,
				LoadBalanceRequest: loadBalanceSegment,
				broker:             qc.broker,
				cluster:            qc.cluster,
				meta:               qc.meta,
			}
			qc.metricsCacheManager.InvalidateSystemInfoMetrics()
			if err := qc.scheduler.Enqueue(loadBalanceTask); err != nil {
				log.Warn("failed to enqueue LoadBalance task into the scheduler",
					zap.Int64("nodeID", node),
					zap.Error(err))
				// Requeue for retry without risking a blocking send from the
				// channel's only consumer, and stop promptly on shutdown.
				select {
				case offlineNodeCh <- node:
				case <-ctx.Done():
					return
				}
				continue
			}
			log.Info("start a loadBalance task",
				zap.Int64("nodeID", node),
				zap.Int64("taskID", loadBalanceTask.getTaskID()))
			if err := loadBalanceTask.waitToFinish(); err != nil {
				log.Warn("failed to process LoadBalance task",
					zap.Int64("nodeID", node),
					zap.Error(err))
				select {
				case offlineNodeCh <- node:
				case <-ctx.Done():
					return
				}
				continue
			}
			log.Info("LoadBalance task done, offline node is removed",
				zap.Int64("nodeID", node))
		}
	}
}
func (qc *QueryCoord) watchHandoffSegmentLoop() {
ctx, cancel := context.WithCancel(qc.loopCtx)

View File

@ -141,6 +141,12 @@ func newBaseTask(ctx context.Context, triggerType querypb.TriggerCondition) *bas
return baseTask
}
// newBaseTaskWithRetry builds a baseTask for the given trigger condition and
// overrides its retry count with the supplied value.
func newBaseTaskWithRetry(ctx context.Context, triggerType querypb.TriggerCondition, retryCount int) *baseTask {
	t := newBaseTask(ctx, triggerType)
	t.retryCount = retryCount
	return t
}
// getTaskID function returns the unique taskID of the trigger task
func (bt *baseTask) getTaskID() UniqueID {
return bt.taskID
@ -1780,6 +1786,20 @@ func (lbt *loadBalanceTask) preExecute(context.Context) error {
zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs),
zap.Any("balanceReason", lbt.BalanceReason),
zap.Int64("taskID", lbt.getTaskID()))
if lbt.triggerCondition == querypb.TriggerCondition_LoadBalance {
if err := lbt.checkForManualLoadBalance(); err != nil {
lbt.setResultInfo(err)
return err
}
if len(lbt.SourceNodeIDs) == 0 {
err := errors.New("loadBalanceTask: empty source Node list to balance")
log.Error(err.Error())
lbt.setResultInfo(err)
return err
}
}
return nil
}
@ -1848,118 +1868,99 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
defer lbt.reduceRetryCount()
if lbt.triggerCondition == querypb.TriggerCondition_NodeDown {
var internalTasks []task
for _, nodeID := range lbt.SourceNodeIDs {
segmentID2Info := make(map[UniqueID]*querypb.SegmentInfo)
dmChannel2WatchInfo := make(map[string]*querypb.DmChannelWatchInfo)
recoveredCollectionIDs := make(typeutil.UniqueSet)
segmentInfos := lbt.meta.getSegmentInfosByNode(nodeID)
for _, segmentInfo := range segmentInfos {
segmentID2Info[segmentInfo.SegmentID] = segmentInfo
recoveredCollectionIDs.Insert(segmentInfo.CollectionID)
err := lbt.processNodeDownLoadBalance(ctx)
if err != nil {
return err
}
} else if lbt.triggerCondition == querypb.TriggerCondition_LoadBalance {
err := lbt.processManualLoadBalance(ctx)
if err != nil {
return err
}
}
log.Info("loadBalanceTask Execute done",
zap.Int32("trigger type", int32(lbt.triggerCondition)),
zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs),
zap.Any("balanceReason", lbt.BalanceReason),
zap.Int64("taskID", lbt.getTaskID()))
return nil
}
func (lbt *loadBalanceTask) processNodeDownLoadBalance(ctx context.Context) error {
var internalTasks []task
for _, nodeID := range lbt.SourceNodeIDs {
segments := make(map[UniqueID]*querypb.SegmentInfo)
dmChannels := make(map[string]*querypb.DmChannelWatchInfo)
recoveredCollectionIDs := make(typeutil.UniqueSet)
segmentInfos := lbt.meta.getSegmentInfosByNode(nodeID)
for _, segmentInfo := range segmentInfos {
segments[segmentInfo.SegmentID] = segmentInfo
recoveredCollectionIDs.Insert(segmentInfo.CollectionID)
}
dmChannelWatchInfos := lbt.meta.getDmChannelInfosByNodeID(nodeID)
for _, watchInfo := range dmChannelWatchInfos {
dmChannels[watchInfo.DmChannel] = watchInfo
recoveredCollectionIDs.Insert(watchInfo.CollectionID)
}
if len(segments) == 0 && len(dmChannels) == 0 {
continue
}
for collectionID := range recoveredCollectionIDs {
loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0)
watchDmChannelReqs := make([]*querypb.WatchDmChannelsRequest, 0)
collectionInfo, err := lbt.meta.getCollectionInfoByID(collectionID)
if err != nil {
log.Error("loadBalanceTask: get collectionInfo from meta failed", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
return err
}
dmChannelWatchInfos := lbt.meta.getDmChannelInfosByNodeID(nodeID)
for _, watchInfo := range dmChannelWatchInfos {
dmChannel2WatchInfo[watchInfo.DmChannel] = watchInfo
recoveredCollectionIDs.Insert(watchInfo.CollectionID)
schema := collectionInfo.Schema
var deltaChannelInfos []*datapb.VchannelInfo
var dmChannelInfos []*datapb.VchannelInfo
var toRecoverPartitionIDs []UniqueID
if collectionInfo.LoadType == querypb.LoadType_LoadCollection {
toRecoverPartitionIDs, err = lbt.broker.showPartitionIDs(ctx, collectionID)
if err != nil {
log.Error("loadBalanceTask: show collection's partitionIDs failed", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
} else {
toRecoverPartitionIDs = collectionInfo.PartitionIDs
}
log.Info("loadBalanceTask: get collection's all partitionIDs", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", toRecoverPartitionIDs))
replica, err := lbt.getReplica(nodeID, collectionID)
if err != nil {
lbt.setResultInfo(err)
return err
}
for collectionID := range recoveredCollectionIDs {
loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0)
watchDmChannelReqs := make([]*querypb.WatchDmChannelsRequest, 0)
collectionInfo, err := lbt.meta.getCollectionInfoByID(collectionID)
for _, partitionID := range toRecoverPartitionIDs {
vChannelInfos, binlogs, err := lbt.broker.getRecoveryInfo(lbt.ctx, collectionID, partitionID)
if err != nil {
log.Error("loadBalanceTask: get collectionInfo from meta failed", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
return err
}
schema := collectionInfo.Schema
var deltaChannelInfos []*datapb.VchannelInfo
var dmChannelInfos []*datapb.VchannelInfo
var toRecoverPartitionIDs []UniqueID
if collectionInfo.LoadType == querypb.LoadType_LoadCollection {
toRecoverPartitionIDs, err = lbt.broker.showPartitionIDs(ctx, collectionID)
if err != nil {
log.Error("loadBalanceTask: show collection's partitionIDs failed", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
} else {
toRecoverPartitionIDs = collectionInfo.PartitionIDs
}
log.Info("loadBalanceTask: get collection's all partitionIDs", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", toRecoverPartitionIDs))
replica, err := lbt.getReplica(nodeID, collectionID)
if err != nil {
lbt.setResultInfo(err)
return err
}
for _, partitionID := range toRecoverPartitionIDs {
vChannelInfos, binlogs, err := lbt.broker.getRecoveryInfo(lbt.ctx, collectionID, partitionID)
if err != nil {
log.Error("loadBalanceTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
for _, segmentBingLog := range binlogs {
segmentID := segmentBingLog.SegmentID
if _, ok := segmentID2Info[segmentID]; ok {
segmentLoadInfo := lbt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, schema)
msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_LoadSegments
loadSegmentReq := &querypb.LoadSegmentsRequest{
Base: msgBase,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: schema,
CollectionID: collectionID,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: collectionInfo.LoadType,
CollectionID: collectionID,
PartitionIDs: toRecoverPartitionIDs,
},
ReplicaID: replica.ReplicaID,
}
loadSegmentReqs = append(loadSegmentReqs, loadSegmentReq)
}
}
for _, info := range vChannelInfos {
deltaChannel, err := generateWatchDeltaChannelInfo(info)
if err != nil {
log.Error("loadBalanceTask: generateWatchDeltaChannelInfo failed", zap.Int64("collectionID", collectionID), zap.String("channelName", info.ChannelName), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
deltaChannelInfos = append(deltaChannelInfos, deltaChannel)
dmChannelInfos = append(dmChannelInfos, info)
}
}
mergedDeltaChannel := mergeWatchDeltaChannelInfo(deltaChannelInfos)
// If meta is not updated here, deltaChannel meta will not be available when loadSegment reschedule
err = lbt.meta.setDeltaChannel(collectionID, mergedDeltaChannel)
if err != nil {
log.Error("loadBalanceTask: set delta channel info meta failed", zap.Int64("collectionID", collectionID), zap.Error(err))
log.Error("loadBalanceTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
mergedDmChannel := mergeDmChannelInfo(dmChannelInfos)
for channelName, vChannelInfo := range mergedDmChannel {
if _, ok := dmChannel2WatchInfo[channelName]; ok {
for _, segmentBingLog := range binlogs {
segmentID := segmentBingLog.SegmentID
if _, ok := segments[segmentID]; ok {
segmentLoadInfo := lbt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, schema)
msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchDmChannels
watchRequest := &querypb.WatchDmChannelsRequest{
msgBase.MsgType = commonpb.MsgType_LoadSegments
loadSegmentReq := &querypb.LoadSegmentsRequest{
Base: msgBase,
CollectionID: collectionID,
Infos: []*datapb.VchannelInfo{vChannelInfo},
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: schema,
CollectionID: collectionID,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: collectionInfo.LoadType,
CollectionID: collectionID,
@ -1968,216 +1969,246 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
ReplicaID: replica.ReplicaID,
}
if collectionInfo.LoadType == querypb.LoadType_LoadPartition {
watchRequest.PartitionIDs = toRecoverPartitionIDs
}
watchDmChannelReqs = append(watchDmChannelReqs, watchRequest)
loadSegmentReqs = append(loadSegmentReqs, loadSegmentReq)
}
}
tasks, err := assignInternalTask(ctx, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, true, lbt.SourceNodeIDs, lbt.DstNodeIDs, replica.GetReplicaID(), lbt.broker)
if err != nil {
log.Error("loadBalanceTask: assign child task failed", zap.Int64("sourceNodeID", nodeID))
lbt.setResultInfo(err)
panic(err)
for _, info := range vChannelInfos {
deltaChannel, err := generateWatchDeltaChannelInfo(info)
if err != nil {
log.Error("loadBalanceTask: generateWatchDeltaChannelInfo failed", zap.Int64("collectionID", collectionID), zap.String("channelName", info.ChannelName), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
deltaChannelInfos = append(deltaChannelInfos, deltaChannel)
dmChannelInfos = append(dmChannelInfos, info)
}
internalTasks = append(internalTasks, tasks...)
}
mergedDeltaChannel := mergeWatchDeltaChannelInfo(deltaChannelInfos)
// If meta is not updated here, deltaChannel meta will not be available when loadSegment reschedule
err = lbt.meta.setDeltaChannel(collectionID, mergedDeltaChannel)
if err != nil {
log.Error("loadBalanceTask: set delta channel info meta failed", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
panic(err)
}
mergedDmChannel := mergeDmChannelInfo(dmChannelInfos)
for channelName, vChannelInfo := range mergedDmChannel {
if _, ok := dmChannels[channelName]; ok {
msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchDmChannels
watchRequest := &querypb.WatchDmChannelsRequest{
Base: msgBase,
CollectionID: collectionID,
Infos: []*datapb.VchannelInfo{vChannelInfo},
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: collectionInfo.LoadType,
CollectionID: collectionID,
PartitionIDs: toRecoverPartitionIDs,
},
ReplicaID: replica.ReplicaID,
}
if collectionInfo.LoadType == querypb.LoadType_LoadPartition {
watchRequest.PartitionIDs = toRecoverPartitionIDs
}
watchDmChannelReqs = append(watchDmChannelReqs, watchRequest)
}
}
tasks, err := assignInternalTask(ctx, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, true, lbt.SourceNodeIDs, lbt.DstNodeIDs, replica.GetReplicaID(), lbt.broker)
if err != nil {
log.Error("loadBalanceTask: assign child task failed", zap.Int64("sourceNodeID", nodeID))
lbt.setResultInfo(err)
panic(err)
}
internalTasks = append(internalTasks, tasks...)
}
for _, internalTask := range internalTasks {
lbt.addChildTask(internalTask)
log.Info("loadBalanceTask: add a childTask", zap.String("task type", internalTask.msgType().String()), zap.Any("task", internalTask))
}
log.Info("loadBalanceTask: assign child task done", zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs))
} else if lbt.triggerCondition == querypb.TriggerCondition_LoadBalance {
if err := lbt.checkForManualLoadBalance(); err != nil {
lbt.setResultInfo(err)
return err
}
if len(lbt.SourceNodeIDs) == 0 {
err := errors.New("loadBalanceTask: empty source Node list to balance")
}
for _, internalTask := range internalTasks {
lbt.addChildTask(internalTask)
log.Info("loadBalanceTask: add a childTask", zap.String("task type", internalTask.msgType().String()), zap.Any("task", internalTask))
}
log.Info("loadBalanceTask: assign child task done", zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs))
return nil
}
func (lbt *loadBalanceTask) processManualLoadBalance(ctx context.Context) error {
balancedSegmentInfos := make(map[UniqueID]*querypb.SegmentInfo)
balancedSegmentIDs := make([]UniqueID, 0)
for _, nodeID := range lbt.SourceNodeIDs {
nodeExist := lbt.cluster.hasNode(nodeID)
if !nodeExist {
err := fmt.Errorf("loadBalanceTask: query node %d is not exist to balance", nodeID)
log.Error(err.Error())
lbt.setResultInfo(err)
return err
}
balancedSegmentInfos := make(map[UniqueID]*querypb.SegmentInfo)
balancedSegmentIDs := make([]UniqueID, 0)
for _, nodeID := range lbt.SourceNodeIDs {
nodeExist := lbt.cluster.hasNode(nodeID)
if !nodeExist {
err := fmt.Errorf("loadBalanceTask: query node %d is not exist to balance", nodeID)
log.Error(err.Error())
lbt.setResultInfo(err)
return err
}
segmentInfos := lbt.meta.getSegmentInfosByNode(nodeID)
for _, info := range segmentInfos {
balancedSegmentInfos[info.SegmentID] = info
balancedSegmentIDs = append(balancedSegmentIDs, info.SegmentID)
}
segmentInfos := lbt.meta.getSegmentInfosByNode(nodeID)
for _, info := range segmentInfos {
balancedSegmentInfos[info.SegmentID] = info
balancedSegmentIDs = append(balancedSegmentIDs, info.SegmentID)
}
}
// check balanced sealedSegmentIDs in request whether exist in query node
for _, segmentID := range lbt.SealedSegmentIDs {
if _, ok := balancedSegmentInfos[segmentID]; !ok {
err := fmt.Errorf("loadBalanceTask: unloaded segment %d", segmentID)
log.Warn(err.Error())
lbt.setResultInfo(err)
return err
}
}
if len(lbt.SealedSegmentIDs) != 0 {
balancedSegmentIDs = lbt.SealedSegmentIDs
}
// TODO(yah01): release balanced segments in source nodes
// balancedSegmentSet := make(typeutil.UniqueSet)
// balancedSegmentSet.Insert(balancedSegmentIDs...)
// for _, nodeID := range lbt.SourceNodeIDs {
// segments := lbt.meta.getSegmentInfosByNode(nodeID)
// shardSegments := make(map[string][]UniqueID)
// for _, segment := range segments {
// if !balancedSegmentSet.Contain(segment.SegmentID) {
// continue
// }
// shardSegments[segment.DmChannel] = append(shardSegments[segment.DmChannel], segment.SegmentID)
// }
// for dmc, segmentIDs := range shardSegments {
// shardLeader, err := getShardLeaderByNodeID(lbt.meta, lbt.replicaID, dmc)
// if err != nil {
// log.Error("failed to get shardLeader",
// zap.Int64("replicaID", lbt.replicaID),
// zap.Int64("nodeID", nodeID),
// zap.String("dmChannel", dmc),
// zap.Error(err))
// lbt.setResultInfo(err)
// return err
// }
// releaseSegmentReq := &querypb.ReleaseSegmentsRequest{
// Base: &commonpb.MsgBase{
// MsgType: commonpb.MsgType_ReleaseSegments,
// },
// NodeID: nodeID,
// SegmentIDs: segmentIDs,
// }
// baseTask := newBaseTask(ctx, querypb.TriggerCondition_LoadBalance)
// lbt.addChildTask(&releaseSegmentTask{
// baseTask: baseTask,
// ReleaseSegmentsRequest: releaseSegmentReq,
// cluster: lbt.cluster,
// leaderID: shardLeader,
// })
// }
// }
col2PartitionIDs := make(map[UniqueID][]UniqueID)
par2Segments := make(map[UniqueID][]*querypb.SegmentInfo)
for _, segmentID := range balancedSegmentIDs {
info := balancedSegmentInfos[segmentID]
collectionID := info.CollectionID
partitionID := info.PartitionID
if _, ok := col2PartitionIDs[collectionID]; !ok {
col2PartitionIDs[collectionID] = make([]UniqueID, 0)
}
if _, ok := par2Segments[partitionID]; !ok {
col2PartitionIDs[collectionID] = append(col2PartitionIDs[collectionID], partitionID)
par2Segments[partitionID] = make([]*querypb.SegmentInfo, 0)
}
par2Segments[partitionID] = append(par2Segments[partitionID], info)
}
loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0)
for collectionID, partitionIDs := range col2PartitionIDs {
var watchDeltaChannels []*datapb.VchannelInfo
collectionInfo, err := lbt.meta.getCollectionInfoByID(collectionID)
if err != nil {
log.Error("loadBalanceTask: can't find collectionID in meta", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
return err
}
for _, partitionID := range partitionIDs {
dmChannelInfos, binlogs, err := lbt.broker.getRecoveryInfo(lbt.ctx, collectionID, partitionID)
if err != nil {
log.Error("loadBalanceTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
lbt.setResultInfo(err)
return err
}
segmentID2Binlog := make(map[UniqueID]*datapb.SegmentBinlogs)
for _, binlog := range binlogs {
segmentID2Binlog[binlog.SegmentID] = binlog
}
for _, segmentInfo := range par2Segments[partitionID] {
segmentID := segmentInfo.SegmentID
if _, ok := segmentID2Binlog[segmentID]; !ok {
log.Warn("loadBalanceTask: can't find binlog of segment to balance, may be has been compacted", zap.Int64("segmentID", segmentID))
continue
}
segmentBingLog := segmentID2Binlog[segmentID]
segmentLoadInfo := lbt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, collectionInfo.Schema)
msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_LoadSegments
loadSegmentReq := &querypb.LoadSegmentsRequest{
Base: msgBase,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: collectionInfo.Schema,
CollectionID: collectionID,
ReplicaID: lbt.replicaID,
LoadMeta: &querypb.LoadMetaInfo{
CollectionID: collectionID,
PartitionIDs: collectionInfo.PartitionIDs,
},
}
loadSegmentReqs = append(loadSegmentReqs, loadSegmentReq)
}
for _, info := range dmChannelInfos {
deltaChannel, err := generateWatchDeltaChannelInfo(info)
if err != nil {
return err
}
watchDeltaChannels = append(watchDeltaChannels, deltaChannel)
}
}
mergedDeltaChannels := mergeWatchDeltaChannelInfo(watchDeltaChannels)
// If meta is not updated here, deltaChannel meta will not be available when loadSegment reschedule
err = lbt.meta.setDeltaChannel(collectionID, mergedDeltaChannels)
if err != nil {
log.Error("loadBalanceTask: set delta channel info to meta failed", zap.Error(err))
lbt.setResultInfo(err)
return err
}
}
internalTasks, err := assignInternalTask(ctx, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, false, lbt.SourceNodeIDs, lbt.DstNodeIDs, lbt.replicaID, lbt.broker)
if err != nil {
log.Error("loadBalanceTask: assign child task failed", zap.Any("balance request", lbt.LoadBalanceRequest))
// check balanced sealedSegmentIDs in request whether exist in query node
for _, segmentID := range lbt.SealedSegmentIDs {
if _, ok := balancedSegmentInfos[segmentID]; !ok {
err := fmt.Errorf("loadBalanceTask: unloaded segment %d", segmentID)
log.Warn(err.Error())
lbt.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
lbt.addChildTask(internalTask)
log.Info("loadBalanceTask: add a childTask", zap.String("task type", internalTask.msgType().String()), zap.Any("balance request", lbt.LoadBalanceRequest))
}
log.Info("loadBalanceTask: assign child task done", zap.Any("balance request", lbt.LoadBalanceRequest))
}
log.Info("loadBalanceTask Execute done",
zap.Int32("trigger type", int32(lbt.triggerCondition)),
zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs),
zap.Any("balanceReason", lbt.BalanceReason),
zap.Int64("taskID", lbt.getTaskID()))
if len(lbt.SealedSegmentIDs) != 0 {
balancedSegmentIDs = lbt.SealedSegmentIDs
}
// TODO(yah01): release balanced segments in source nodes
// balancedSegmentSet := make(typeutil.UniqueSet)
// balancedSegmentSet.Insert(balancedSegmentIDs...)
// for _, nodeID := range lbt.SourceNodeIDs {
// segments := lbt.meta.getSegmentInfosByNode(nodeID)
// shardSegments := make(map[string][]UniqueID)
// for _, segment := range segments {
// if !balancedSegmentSet.Contain(segment.SegmentID) {
// continue
// }
// shardSegments[segment.DmChannel] = append(shardSegments[segment.DmChannel], segment.SegmentID)
// }
// for dmc, segmentIDs := range shardSegments {
// shardLeader, err := getShardLeaderByNodeID(lbt.meta, lbt.replicaID, dmc)
// if err != nil {
// log.Error("failed to get shardLeader",
// zap.Int64("replicaID", lbt.replicaID),
// zap.Int64("nodeID", nodeID),
// zap.String("dmChannel", dmc),
// zap.Error(err))
// lbt.setResultInfo(err)
// return err
// }
// releaseSegmentReq := &querypb.ReleaseSegmentsRequest{
// Base: &commonpb.MsgBase{
// MsgType: commonpb.MsgType_ReleaseSegments,
// },
// NodeID: nodeID,
// SegmentIDs: segmentIDs,
// }
// baseTask := newBaseTask(ctx, querypb.TriggerCondition_LoadBalance)
// lbt.addChildTask(&releaseSegmentTask{
// baseTask: baseTask,
// ReleaseSegmentsRequest: releaseSegmentReq,
// cluster: lbt.cluster,
// leaderID: shardLeader,
// })
// }
// }
col2PartitionIDs := make(map[UniqueID][]UniqueID)
par2Segments := make(map[UniqueID][]*querypb.SegmentInfo)
for _, segmentID := range balancedSegmentIDs {
info := balancedSegmentInfos[segmentID]
collectionID := info.CollectionID
partitionID := info.PartitionID
if _, ok := col2PartitionIDs[collectionID]; !ok {
col2PartitionIDs[collectionID] = make([]UniqueID, 0)
}
if _, ok := par2Segments[partitionID]; !ok {
col2PartitionIDs[collectionID] = append(col2PartitionIDs[collectionID], partitionID)
par2Segments[partitionID] = make([]*querypb.SegmentInfo, 0)
}
par2Segments[partitionID] = append(par2Segments[partitionID], info)
}
loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0)
for collectionID, partitionIDs := range col2PartitionIDs {
var watchDeltaChannels []*datapb.VchannelInfo
collectionInfo, err := lbt.meta.getCollectionInfoByID(collectionID)
if err != nil {
log.Error("loadBalanceTask: can't find collectionID in meta", zap.Int64("collectionID", collectionID), zap.Error(err))
lbt.setResultInfo(err)
return err
}
for _, partitionID := range partitionIDs {
dmChannelInfos, binlogs, err := lbt.broker.getRecoveryInfo(lbt.ctx, collectionID, partitionID)
if err != nil {
log.Error("loadBalanceTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
lbt.setResultInfo(err)
return err
}
segmentID2Binlog := make(map[UniqueID]*datapb.SegmentBinlogs)
for _, binlog := range binlogs {
segmentID2Binlog[binlog.SegmentID] = binlog
}
for _, segmentInfo := range par2Segments[partitionID] {
segmentID := segmentInfo.SegmentID
if _, ok := segmentID2Binlog[segmentID]; !ok {
log.Warn("loadBalanceTask: can't find binlog of segment to balance, may be has been compacted", zap.Int64("segmentID", segmentID))
continue
}
segmentBingLog := segmentID2Binlog[segmentID]
segmentLoadInfo := lbt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, collectionInfo.Schema)
msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_LoadSegments
loadSegmentReq := &querypb.LoadSegmentsRequest{
Base: msgBase,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: collectionInfo.Schema,
CollectionID: collectionID,
ReplicaID: lbt.replicaID,
LoadMeta: &querypb.LoadMetaInfo{
CollectionID: collectionID,
PartitionIDs: collectionInfo.PartitionIDs,
},
}
loadSegmentReqs = append(loadSegmentReqs, loadSegmentReq)
}
for _, info := range dmChannelInfos {
deltaChannel, err := generateWatchDeltaChannelInfo(info)
if err != nil {
return err
}
watchDeltaChannels = append(watchDeltaChannels, deltaChannel)
}
}
mergedDeltaChannels := mergeWatchDeltaChannelInfo(watchDeltaChannels)
// If meta is not updated here, deltaChannel meta will not be available when loadSegment reschedule
err = lbt.meta.setDeltaChannel(collectionID, mergedDeltaChannels)
if err != nil {
log.Error("loadBalanceTask: set delta channel info to meta failed", zap.Error(err))
lbt.setResultInfo(err)
return err
}
}
internalTasks, err := assignInternalTask(ctx, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, false, lbt.SourceNodeIDs, lbt.DstNodeIDs, lbt.replicaID, lbt.broker)
if err != nil {
log.Error("loadBalanceTask: assign child task failed", zap.Any("balance request", lbt.LoadBalanceRequest))
lbt.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
lbt.addChildTask(internalTask)
log.Info("loadBalanceTask: add a childTask", zap.String("task type", internalTask.msgType().String()), zap.Any("balance request", lbt.LoadBalanceRequest))
}
log.Info("loadBalanceTask: assign child task done", zap.Any("balance request", lbt.LoadBalanceRequest))
return nil
}
@ -2199,22 +2230,6 @@ func (lbt *loadBalanceTask) postExecute(context.Context) error {
lbt.clearChildTasks()
}
// if loadBalanceTask execute failed after query node down, the lbt.getResultInfo().ErrorCode will be set to commonpb.ErrorCode_UnexpectedError
// then the queryCoord will panic, and the nodeInfo should not be removed immediately
// after queryCoord recovery, the balanceTask will redo
if lbt.triggerCondition == querypb.TriggerCondition_NodeDown && lbt.getResultInfo().ErrorCode == commonpb.ErrorCode_Success {
for _, offlineNodeID := range lbt.SourceNodeIDs {
err := lbt.cluster.removeNodeInfo(offlineNodeID)
if err != nil {
//TODO:: clear node info after removeNodeInfo failed
log.Warn("loadBalanceTask: occur error when removing node info from cluster",
zap.Int64("nodeID", offlineNodeID),
zap.Error(err))
continue
}
}
}
log.Info("loadBalanceTask postExecute done",
zap.Int32("trigger type", int32(lbt.triggerCondition)),
zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs),
@ -2247,9 +2262,9 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
}
log.Debug("removing offline nodes from replicas and segments...",
zap.Int("len(replicas)", len(replicas)),
zap.Int("len(segments)", len(segments)),
zap.Int64("trigger task ID", lbt.getTaskID()),
zap.Int("replicaNum", len(replicas)),
zap.Int("segmentNum", len(segments)),
zap.Int64("triggerTaskID", lbt.getTaskID()),
)
// Remove offline nodes from replica
@ -2282,6 +2297,20 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
return nil
})
}
// if loadBalanceTask execute failed after query node down, the lbt.getResultInfo().ErrorCode will be set to commonpb.ErrorCode_UnexpectedError
// then the queryCoord will panic, and the nodeInfo should not be removed immediately
// after queryCoord recovery, the balanceTask will redo
for _, offlineNodeID := range lbt.SourceNodeIDs {
err := lbt.cluster.removeNodeInfo(offlineNodeID)
if err != nil {
log.Error("loadBalanceTask: occur error when removing node info from cluster",
zap.Int64("nodeID", offlineNodeID),
zap.Error(err))
lbt.setResultInfo(err)
return err
}
}
}
for _, segment := range segments {