From 8388478ef3a918f5a1bc8f29a36680aaa352f03e Mon Sep 17 00:00:00 2001 From: yah01 Date: Fri, 24 Jun 2022 23:24:15 +0800 Subject: [PATCH] SyncReplicaSegments syncs all segments (#17774) Signed-off-by: yah01 --- internal/querycoord/task.go | 331 +++++++++++++------------- internal/querycoord/task_scheduler.go | 45 ++-- internal/querycoord/util.go | 159 +++++++------ 3 files changed, 278 insertions(+), 257 deletions(-) diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index b6ee027396..7f2851a31e 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -356,17 +356,7 @@ func (lct *loadCollectionTask) updateTaskProcess() { } if allDone { - err := syncReplicaSegments(lct.ctx, lct.cluster, childTasks) - if err != nil { - log.Error("loadCollectionTask: failed to sync replica segments to shard leader", - zap.Int64("taskID", lct.getTaskID()), - zap.Int64("collectionID", collectionID), - zap.Error(err)) - lct.setResultInfo(err) - return - } - - err = lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_LoadCollection) + err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_LoadCollection) if err != nil { log.Error("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID)) lct.setResultInfo(err) @@ -609,6 +599,32 @@ func (lct *loadCollectionTask) postExecute(ctx context.Context) error { return nil } +func (lct *loadCollectionTask) globalPostExecute(ctx context.Context) error { + collection, err := lct.meta.getCollectionInfoByID(lct.CollectionID) + if err != nil { + log.Error("loadCollectionTask: failed to get collection info from meta", + zap.Int64("taskID", lct.getTaskID()), + zap.Int64("collectionID", lct.CollectionID), + zap.Error(err)) + + return err + } + + for _, replica := range collection.ReplicaIds { + err := syncReplicaSegments(lct.ctx, lct.meta, lct.cluster, replica) + if err != nil { + log.Error("loadCollectionTask: failed to sync replica segments to shard leader", + zap.Int64("taskID", lct.getTaskID()), + zap.Int64("collectionID", lct.CollectionID), + zap.Error(err)) + + return err + } + } + + return nil +} + func (lct *loadCollectionTask) rollBack(ctx context.Context) []task { onlineNodeIDs := lct.cluster.OnlineNodeIDs() resultTasks := make([]task, 0) @@ -804,16 +820,6 @@ func (lpt *loadPartitionTask) updateTaskProcess() { } } if allDone { - err := syncReplicaSegments(lpt.ctx, lpt.cluster, childTasks) - if err != nil { - log.Error("loadPartitionTask: failed to sync replica segments to shard leader", - zap.Int64("taskID", lpt.getTaskID()), - zap.Int64("collectionID", collectionID), - zap.Error(err)) - lpt.setResultInfo(err) - return - } - for _, id := range partitionIDs { err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition) if err != nil { @@ -1049,6 +1055,34 @@ func (lpt *loadPartitionTask) postExecute(ctx context.Context) error { return nil } +func (lpt *loadPartitionTask) globalPostExecute(ctx context.Context) error { + collectionID := lpt.CollectionID + + collection, err := lpt.meta.getCollectionInfoByID(collectionID) + if err != nil { + log.Error("loadPartitionTask: failed to get collection info from meta", + zap.Int64("taskID", lpt.getTaskID()), + zap.Int64("collectionID", collectionID), + zap.Error(err)) + + return err + } + + for _, replica := range collection.ReplicaIds { + err := syncReplicaSegments(lpt.ctx, lpt.meta, lpt.cluster, replica) + if err != nil { + log.Error("loadPartitionTask: failed to sync replica segments to shard leader", + zap.Int64("taskID", lpt.getTaskID()), + zap.Int64("collectionID", collectionID), + zap.Error(err)) + + return err + } + } + + return nil +} + func (lpt *loadPartitionTask) rollBack(ctx context.Context) []task { collectionID := lpt.CollectionID resultTasks := make([]task, 0) @@ -2264,167 +2298,144 @@ func (lbt *loadBalanceTask) postExecute(context.Context) error { } func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error { - if len(lbt.getChildTask()) > 0 { - replicas := make(map[UniqueID]*milvuspb.ReplicaInfo) - segments := make(map[UniqueID]*querypb.SegmentInfo) - dmChannels := make(map[string]*querypb.DmChannelWatchInfo) + if lbt.BalanceReason != querypb.TriggerCondition_NodeDown { + return nil + } - for _, id := range lbt.SourceNodeIDs { - for _, segment := range lbt.meta.getSegmentInfosByNode(id) { - segments[segment.SegmentID] = segment - } - for _, dmChannel := range lbt.meta.getDmChannelInfosByNodeID(id) { - dmChannels[dmChannel.DmChannel] = dmChannel - } + replicas := make(map[UniqueID]*milvuspb.ReplicaInfo) + segments := make(map[UniqueID]*querypb.SegmentInfo) + dmChannels := make(map[string]*querypb.DmChannelWatchInfo) - nodeReplicas, err := lbt.meta.getReplicasByNodeID(id) + for _, id := range lbt.SourceNodeIDs { + for _, segment := range lbt.meta.getSegmentInfosByNode(id) { + segments[segment.SegmentID] = segment + } + for _, dmChannel := range lbt.meta.getDmChannelInfosByNodeID(id) { + dmChannels[dmChannel.DmChannel] = dmChannel + } + + nodeReplicas, err := lbt.meta.getReplicasByNodeID(id) + if err != nil { + log.Warn("failed to get replicas for removing offline querynode from it", + zap.Int64("querynodeID", id), + zap.Error(err)) + + continue + } + for _, replica := range nodeReplicas { + replicas[replica.ReplicaID] = replica + } + } + + log.Debug("removing offline nodes from replicas and segments...", + zap.Int("replicaNum", len(replicas)), + zap.Int("segmentNum", len(segments)), + zap.Int64("triggerTaskID", lbt.getTaskID()), + ) + + wg := errgroup.Group{} + // Remove offline nodes from replica + for replicaID := range replicas { + replicaID := replicaID + wg.Go(func() error { + return lbt.meta.applyReplicaBalancePlan( + NewRemoveBalancePlan(replicaID, lbt.SourceNodeIDs...)) + }) + } + + // Remove offline nodes from dmChannels + for _, dmChannel := range dmChannels { + dmChannel := dmChannel + wg.Go(func() error { + dmChannel.NodeIds = removeFromSlice(dmChannel.NodeIds, lbt.SourceNodeIDs...) + + err := lbt.meta.setDmChannelInfos(dmChannel) if err != nil { - log.Warn("failed to get replicas for removing offline querynode from it", - zap.Int64("querynodeID", id), + log.Error("failed to remove offline nodes from dmChannel info", + zap.String("dmChannel", dmChannel.DmChannel), zap.Error(err)) - continue - } - for _, replica := range nodeReplicas { - replicas[replica.ReplicaID] = replica - } - } - - log.Debug("removing offline nodes from replicas and segments...", - zap.Int("replicaNum", len(replicas)), - zap.Int("segmentNum", len(segments)), - zap.Int64("triggerTaskID", lbt.getTaskID()), - ) - - wg := errgroup.Group{} - if lbt.triggerCondition == querypb.TriggerCondition_NodeDown { - // Remove offline nodes from replica - for replicaID := range replicas { - replicaID := replicaID - wg.Go(func() error { - return lbt.meta.applyReplicaBalancePlan( - NewRemoveBalancePlan(replicaID, lbt.SourceNodeIDs...)) - }) + return err } - // Remove offline nodes from dmChannels - for _, dmChannel := range dmChannels { - dmChannel := dmChannel - wg.Go(func() error { - dmChannel.NodeIds = removeFromSlice(dmChannel.NodeIds, lbt.SourceNodeIDs...) + log.Info("remove offline nodes from dmChannel", + zap.Int64("taskID", lbt.getTaskID()), + zap.String("dmChannel", dmChannel.DmChannel), + zap.Int64s("nodeIds", dmChannel.NodeIds)) - err := lbt.meta.setDmChannelInfos(dmChannel) - if err != nil { - log.Error("failed to remove offline nodes from dmChannel info", - zap.String("dmChannel", dmChannel.DmChannel), - zap.Error(err)) + return nil + }) + } - return err - } + // Update shard leaders for replicas + for _, childTask := range lbt.getChildTask() { + if task, ok := childTask.(*watchDmChannelTask); ok { + wg.Go(func() error { + leaderID := task.NodeID + dmChannel := task.Infos[0].ChannelName - log.Info("remove offline nodes from dmChannel", - zap.Int64("taskID", lbt.getTaskID()), - zap.String("dmChannel", dmChannel.DmChannel), - zap.Int64s("nodeIds", dmChannel.NodeIds)) - - return nil - }) - } - } - - // Remove offline nodes from segment - // for _, segment := range segments { - // segment := segment - // wg.Go(func() error { - // segment.NodeID = -1 - // segment.NodeIds = removeFromSlice(segment.NodeIds, lbt.SourceNodeIDs...) - - // err := lbt.meta.saveSegmentInfo(segment) - // if err != nil { - // log.Error("failed to remove offline nodes from segment info", - // zap.Int64("segmentID", segment.SegmentID), - // zap.Error(err)) - - // return err - // } - - // log.Info("remove offline nodes from segment", - // zap.Int64("taskID", lbt.getTaskID()), - // zap.Int64("segmentID", segment.GetSegmentID()), - // zap.Int64s("nodeIds", segment.GetNodeIds())) - - // return nil - // }) - // } - - // Wait for the previous goroutines, - // which conflicts with the code below due to modifing replicas - err := wg.Wait() - if err != nil { - return err - } - for _, childTask := range lbt.getChildTask() { - if task, ok := childTask.(*watchDmChannelTask); ok { - wg.Go(func() error { - leaderID := task.NodeID - dmChannel := task.Infos[0].ChannelName - - nodeInfo, err := lbt.cluster.GetNodeInfoByID(leaderID) - if err != nil { - log.Error("failed to get node info to update shard leader info", - zap.Int64("triggerTaskID", lbt.getTaskID()), - zap.Int64("taskID", task.getTaskID()), - zap.Int64("nodeID", leaderID), - zap.String("dmChannel", dmChannel), - zap.Error(err)) - return err - } - - err = lbt.meta.updateShardLeader(task.ReplicaID, dmChannel, leaderID, nodeInfo.(*queryNode).address) - if err != nil { - log.Error("failed to update shard leader info of replica", - zap.Int64("triggerTaskID", lbt.getTaskID()), - zap.Int64("taskID", task.getTaskID()), - zap.Int64("replicaID", task.ReplicaID), - zap.String("dmChannel", dmChannel), - zap.Error(err)) - return err - } - - log.Debug("LoadBalance: update shard leader", + nodeInfo, err := lbt.cluster.GetNodeInfoByID(leaderID) + if err != nil { + log.Error("failed to get node info to update shard leader info", zap.Int64("triggerTaskID", lbt.getTaskID()), zap.Int64("taskID", task.getTaskID()), + zap.Int64("nodeID", leaderID), zap.String("dmChannel", dmChannel), - zap.Int64("leader", leaderID)) + zap.Error(err)) + return err + } - return nil - }) - } + err = lbt.meta.updateShardLeader(task.ReplicaID, dmChannel, leaderID, nodeInfo.(*queryNode).address) + if err != nil { + log.Error("failed to update shard leader info of replica", + zap.Int64("triggerTaskID", lbt.getTaskID()), + zap.Int64("taskID", task.getTaskID()), + zap.Int64("replicaID", task.ReplicaID), + zap.String("dmChannel", dmChannel), + zap.Error(err)) + return err + } + + log.Debug("LoadBalance: update shard leader", + zap.Int64("triggerTaskID", lbt.getTaskID()), + zap.Int64("taskID", task.getTaskID()), + zap.String("dmChannel", dmChannel), + zap.Int64("leader", leaderID)) + + return nil + }) } - err = wg.Wait() - if err != nil { - return err + } + + err := wg.Wait() + if err != nil { + return err + } + + for replicaID := range replicas { + shards := make([]string, 0, len(dmChannels)) + for _, dmc := range dmChannels { + shards = append(shards, dmc.DmChannel) } - err = syncReplicaSegments(ctx, lbt.cluster, lbt.getChildTask()) + err := syncReplicaSegments(lbt.ctx, lbt.meta, lbt.cluster, replicaID, shards...) if err != nil { + log.Error("loadBalanceTask: failed to sync segments distribution", + zap.Int64("collectionID", lbt.CollectionID), + zap.Int64("replicaID", lbt.replicaID), + zap.Error(err)) return err } } - // if loadBalanceTask execute failed after query node down, the lbt.getResultInfo().ErrorCode will be set to commonpb.ErrorCode_UnexpectedError - // then the queryCoord will panic, and the nodeInfo should not be removed immediately - // after queryCoord recovery, the balanceTask will redo - if lbt.BalanceReason == querypb.TriggerCondition_NodeDown { - for _, offlineNodeID := range lbt.SourceNodeIDs { - err := lbt.cluster.RemoveNodeInfo(offlineNodeID) - if err != nil { - log.Error("loadBalanceTask: occur error when removing node info from cluster", - zap.Int64("nodeID", offlineNodeID), - zap.Error(err)) - lbt.setResultInfo(err) - return err - } + for _, offlineNodeID := range lbt.SourceNodeIDs { + err := lbt.cluster.RemoveNodeInfo(offlineNodeID) + if err != nil { + log.Error("loadBalanceTask: occur error when removing node info from cluster", + zap.Int64("nodeID", offlineNodeID), + zap.Error(err)) + lbt.setResultInfo(err) + return err } } diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index 1663fbe054..9127896ba1 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -660,27 +660,13 @@ func (scheduler *TaskScheduler) scheduleLoop() { processInternalTaskFn(derivedInternalTasks, triggerTask) } } + } - //TODO::xige-16, judging the triggerCondition is ugly, the taskScheduler will be refactored soon - // if query node down, the loaded segment and watched dmChannel by the node should be balance to new querynode - // if triggerCondition == NodeDown, loadSegment and watchDmchannel request will keep reschedule until the success - // the node info has been deleted after assgining child task to triggerTask - // so it is necessary to update the meta of segment and dmchannel, or some data may be lost in meta - resultInfo := triggerTask.getResultInfo() - if resultInfo.ErrorCode != commonpb.ErrorCode_Success { - if !alreadyNotify { - triggerTask.notify(errors.New(resultInfo.Reason)) - alreadyNotify = true - } - rollBackTasks := triggerTask.rollBack(scheduler.ctx) - log.Info("scheduleLoop: start rollBack after triggerTask failed", - zap.Int64("triggerTaskID", triggerTask.getTaskID()), - zap.Any("rollBackTasks", rollBackTasks), - zap.String("error", resultInfo.Reason)) - // there is no need to save rollBacked internal task to etcd - // After queryCoord recover, it will retry failed childTask - // if childTask still execute failed, then reProduce rollBacked tasks - processInternalTaskFn(rollBackTasks, triggerTask) + // triggerTask may be LoadCollection, LoadPartitions, LoadBalance, Handoff + if triggerTask.getResultInfo().ErrorCode == commonpb.ErrorCode_Success || triggerTask.getTriggerCondition() == querypb.TriggerCondition_NodeDown { + err = updateSegmentInfoFromTask(scheduler.ctx, triggerTask, scheduler.meta) + if err != nil { + triggerTask.setResultInfo(err) } } @@ -694,12 +680,21 @@ func (scheduler *TaskScheduler) scheduleLoop() { } } - // triggerTask may be LoadCollection, LoadPartitions, LoadBalance, Handoff - if triggerTask.getResultInfo().ErrorCode == commonpb.ErrorCode_Success || triggerTask.getTriggerCondition() == querypb.TriggerCondition_NodeDown { - err = updateSegmentInfoFromTask(scheduler.ctx, triggerTask, scheduler.meta) - if err != nil { - triggerTask.setResultInfo(err) + resultInfo := triggerTask.getResultInfo() + if resultInfo.ErrorCode != commonpb.ErrorCode_Success { + if !alreadyNotify { + triggerTask.notify(errors.New(resultInfo.Reason)) + alreadyNotify = true } + rollBackTasks := triggerTask.rollBack(scheduler.ctx) + log.Info("scheduleLoop: start rollBack after triggerTask failed", + zap.Int64("triggerTaskID", triggerTask.getTaskID()), + zap.Any("rollBackTasks", rollBackTasks), + zap.String("error", resultInfo.Reason)) + // there is no need to save rollBacked internal task to etcd + // After queryCoord recover, it will retry failed childTask + // if childTask still execute failed, then reProduce rollBacked tasks + processInternalTaskFn(rollBackTasks, triggerTask) } err = removeTaskFromKVFn(triggerTask) diff --git a/internal/querycoord/util.go b/internal/querycoord/util.go index 6ce01aa8c5..9f2b0b1e73 100644 --- a/internal/querycoord/util.go +++ b/internal/querycoord/util.go @@ -18,12 +18,15 @@ package querycoord import ( "context" + "sort" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" ) func getCompareMapFromSlice(sliceData []int64) map[int64]struct{} { @@ -106,95 +109,107 @@ func getDstNodeIDByTask(t task) int64 { return nodeID } -func syncReplicaSegments(ctx context.Context, cluster Cluster, childTasks []task) error { - type SegmentIndex struct { - NodeID UniqueID - PartitionID UniqueID - ReplicaID UniqueID +// syncReplicaSegments syncs the segments distribution of replica to shard leaders +// only syncs the segments in shards if not nil +func syncReplicaSegments(ctx context.Context, meta Meta, cluster Cluster, replicaID UniqueID, shards ...string) error { + replica, err := meta.getReplicaByID(replicaID) + if err != nil { + return err } - type ShardLeader struct { - ReplicaID UniqueID - LeaderID UniqueID + collectionSegments := make(map[UniqueID]*querypb.SegmentInfo) + for _, segment := range meta.showSegmentInfos(replica.CollectionID, nil) { + collectionSegments[segment.SegmentID] = segment } - shardSegments := make(map[string]map[SegmentIndex]typeutil.UniqueSet) // DMC -> set[Segment] - shardLeaders := make(map[string][]*ShardLeader) // DMC -> leader - for _, childTask := range childTasks { - switch task := childTask.(type) { - case *loadSegmentTask: - nodeID := getDstNodeIDByTask(task) - for _, segment := range task.Infos { - segments, ok := shardSegments[segment.InsertChannel] - if !ok { - segments = make(map[SegmentIndex]typeutil.UniqueSet) - } - - index := SegmentIndex{ - NodeID: nodeID, - PartitionID: segment.PartitionID, - ReplicaID: task.ReplicaID, - } - - _, ok = segments[index] - if !ok { - segments[index] = make(typeutil.UniqueSet) - } - segments[index].Insert(segment.SegmentID) - - shardSegments[segment.InsertChannel] = segments - } - - case *watchDmChannelTask: - leaderID := getDstNodeIDByTask(task) - leader := &ShardLeader{ - ReplicaID: task.ReplicaID, - LeaderID: leaderID, - } - - for _, dmc := range task.Infos { - leaders, ok := shardLeaders[dmc.ChannelName] - if !ok { - leaders = make([]*ShardLeader, 0) - } - - leaders = append(leaders, leader) - - shardLeaders[dmc.ChannelName] = leaders - } + shardSegments := make(map[string][]*querypb.SegmentInfo) // DMC -> []SegmentInfo + for _, segment := range collectionSegments { + // Group segments by shard + segments, ok := shardSegments[segment.DmChannel] + if !ok { + segments = make([]*querypb.SegmentInfo, 0) } + + segments = append(segments, segment) + shardSegments[segment.DmChannel] = segments } - for dmc, leaders := range shardLeaders { - // invoke sync segments even no segment - segments := shardSegments[dmc] + for _, shard := range replica.ShardReplicas { + if len(shards) > 0 && !isInShards(shard.DmChannelName, shards) { + continue + } - for _, leader := range leaders { - req := querypb.SyncReplicaSegmentsRequest{ - VchannelName: dmc, - ReplicaSegments: make([]*querypb.ReplicaSegmentsInfo, 0, len(segments)), + segments := shardSegments[shard.DmChannelName] + req := querypb.SyncReplicaSegmentsRequest{ + VchannelName: shard.DmChannelName, + ReplicaSegments: make([]*querypb.ReplicaSegmentsInfo, 0, len(segments)), + } + + sort.Slice(segments, func(i, j int) bool { + inode := getNodeInReplica(replica, segments[i].NodeIds) + jnode := getNodeInReplica(replica, segments[j].NodeIds) + + return inode < jnode || + inode == jnode && segments[i].PartitionID < segments[j].PartitionID + }) + + for i, j := 0, 0; i < len(segments); i = j { + node := getNodeInReplica(replica, segments[i].NodeIds) + partition := segments[i].PartitionID + + j++ + for j < len(segments) && + getNodeInReplica(replica, segments[j].NodeIds) == node && + segments[j].PartitionID == partition { + j++ } - for index, segmentSet := range segments { - if index.ReplicaID == leader.ReplicaID { - req.ReplicaSegments = append(req.ReplicaSegments, - &querypb.ReplicaSegmentsInfo{ - NodeId: index.NodeID, - PartitionId: index.PartitionID, - SegmentIds: segmentSet.Collect(), - }) - } - } - err := cluster.SyncReplicaSegments(ctx, leader.LeaderID, &req) - if err != nil { - return err + segmentIds := make([]UniqueID, 0, len(segments[i:j])) + for _, segment := range segments[i:j] { + segmentIds = append(segmentIds, segment.SegmentID) } + + req.ReplicaSegments = append(req.ReplicaSegments, &querypb.ReplicaSegmentsInfo{ + NodeId: node, + PartitionId: partition, + SegmentIds: segmentIds, + }) + } + + log.Debug("sync replica segments", + zap.Int64("replicaID", replicaID), + zap.Int64("leaderID", shard.LeaderID), + zap.Any("req", req)) + err := cluster.SyncReplicaSegments(ctx, shard.LeaderID, &req) + if err != nil { + return err } } return nil } +func isInShards(shard string, shards []string) bool { + for _, item := range shards { + if shard == item { + return true + } + } + + return false +} + +// getNodeInReplica gets the node which is in the replica +func getNodeInReplica(replica *milvuspb.ReplicaInfo, nodes []UniqueID) UniqueID { + for _, node := range nodes { + if nodeIncluded(node, replica.NodeIds) { + return node + } + } + + return 0 +} + func removeFromSlice(origin []UniqueID, del ...UniqueID) []UniqueID { set := make(typeutil.UniqueSet, len(origin)) set.Insert(origin...)