fix: Make leader checker generate leader task instead of segment task (#30258)

See also #30150

When a leader view distribution contains offline nodes, a release task can
never be sent to the querynode because of the targetNode online check logic.
Even if the request were dispatched, a normal release task does not set the
"force" flag when calling `delegator.ReleaseSegment`.

This PR adds a new type of querycoord task, LeaderTask, whose
responsibility is to rectify the leader view distribution.
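
In effect, each mismatched segment in a leader view now yields a LeaderTask wrapping a single LeaderAction. Below is a hedged sketch of the construction the checker performs; variables such as `c`, `ctx`, `replica`, `leaderView`, and `s` stand for the checker's own state, the signatures match the ones introduced in this diff, and unlike `NewSegmentTask`, `NewLeaderTask` returns no error:

```go
// Sketch only, not a drop-in snippet: the new construction path
// inside findNeedLoadedSegments.
action := task.NewLeaderAction(
	leaderView.ID,        // delegator (shard leader) node whose view is synced
	s.Node,               // worker node that actually holds the segment
	task.ActionTypeGrow,  // add the segment to the leader view
	s.GetInsertChannel(), // shard (DM channel) name
	s.GetID(),            // segment ID
)
t := task.NewLeaderTask(
	ctx,
	params.Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond),
	c.ID(),
	s.GetCollectionID(),
	replica,
	leaderView.ID,
	action,
)
t.SetPriority(task.TaskPriorityHigh)
t.SetReason("add segment to leader view")
```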

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2024-02-21 11:08:51 +08:00 committed by GitHub
parent 81d6cb1a0c
commit 7b91fa3db8
7 changed files with 592 additions and 100 deletions

View File

@ -22,7 +22,6 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
@ -139,23 +138,16 @@ func (c *LeaderChecker) findNeedLoadedSegments(ctx context.Context, replica int6
log.RatedDebug(10, "leader checker append a segment to set",
zap.Int64("segmentID", s.GetID()),
zap.Int64("nodeID", s.Node))
action := task.NewSegmentActionWithScope(s.Node, task.ActionTypeGrow, s.GetInsertChannel(), s.GetID(), querypb.DataScope_Historical)
t, err := task.NewSegmentTask(
action := task.NewLeaderAction(leaderView.ID, s.Node, task.ActionTypeGrow, s.GetInsertChannel(), s.GetID())
t := task.NewLeaderTask(
ctx,
params.Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond),
c.ID(),
s.GetCollectionID(),
replica,
leaderView.ID,
action,
)
if err != nil {
log.Warn("create segment update task failed",
zap.Int64("segmentID", s.GetID()),
zap.Int64("node", s.Node),
zap.Error(err),
)
continue
}
// index task shall have lower or equal priority than balance task
t.SetPriority(task.TaskPriorityHigh)
t.SetReason("add segment to leader view")
@ -189,23 +181,16 @@ func (c *LeaderChecker) findNeedRemovedSegments(ctx context.Context, replica int
log.Debug("leader checker append a segment to remove",
zap.Int64("segmentID", sid),
zap.Int64("nodeID", s.NodeID))
action := task.NewSegmentActionWithScope(s.NodeID, task.ActionTypeReduce, leaderView.Channel, sid, querypb.DataScope_Historical)
t, err := task.NewSegmentTask(
action := task.NewLeaderAction(leaderView.ID, s.NodeID, task.ActionTypeReduce, leaderView.Channel, sid)
t := task.NewLeaderTask(
ctx,
paramtable.Get().QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond),
c.ID(),
leaderView.CollectionID,
replica,
leaderView.ID,
action,
)
if err != nil {
log.Warn("create segment reduce task failed",
zap.Int64("segmentID", sid),
zap.Int64("nodeID", s.NodeID),
zap.Error(err))
continue
}
t.SetPriority(task.TaskPriorityHigh)
t.SetReason("remove segment from leader view")

View File

@ -119,7 +119,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow)
suite.Equal(tasks[0].Actions()[0].Node(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}
@ -161,7 +161,7 @@ func (suite *LeaderCheckerTestSuite) TestActivation() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow)
suite.Equal(tasks[0].Actions()[0].Node(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}
@ -236,7 +236,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow)
suite.Equal(tasks[0].Actions()[0].Node(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}
@ -289,7 +289,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreBalancedSegment() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow)
suite.Equal(tasks[0].Actions()[0].Node(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}
@ -334,7 +334,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow)
suite.Equal(tasks[0].Actions()[0].Node(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}
@ -368,7 +368,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeReduce)
suite.Equal(tasks[0].Actions()[0].Node(), int64(1))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(3))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(3))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}
@ -405,7 +405,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() {
suite.Len(tasks[0].Actions(), 1)
suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeReduce)
suite.Equal(tasks[0].Actions()[0].Node(), int64(2))
suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(3))
suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(3))
suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh)
}

View File

@ -23,7 +23,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/pkg/util/funcutil"
. "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type ActionType int32
@ -51,12 +51,12 @@ type Action interface {
}
type BaseAction struct {
nodeID UniqueID
nodeID typeutil.UniqueID
typ ActionType
shard string
}
func NewBaseAction(nodeID UniqueID, typ ActionType, shard string) *BaseAction {
func NewBaseAction(nodeID typeutil.UniqueID, typ ActionType, shard string) *BaseAction {
return &BaseAction{
nodeID: nodeID,
typ: typ,
@ -79,17 +79,17 @@ func (action *BaseAction) Shard() string {
type SegmentAction struct {
*BaseAction
segmentID UniqueID
segmentID typeutil.UniqueID
scope querypb.DataScope
rpcReturned atomic.Bool
}
func NewSegmentAction(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID) *SegmentAction {
func NewSegmentAction(nodeID typeutil.UniqueID, typ ActionType, shard string, segmentID typeutil.UniqueID) *SegmentAction {
return NewSegmentActionWithScope(nodeID, typ, shard, segmentID, querypb.DataScope_All)
}
func NewSegmentActionWithScope(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID, scope querypb.DataScope) *SegmentAction {
func NewSegmentActionWithScope(nodeID typeutil.UniqueID, typ ActionType, shard string, segmentID typeutil.UniqueID, scope querypb.DataScope) *SegmentAction {
base := NewBaseAction(nodeID, typ, shard)
return &SegmentAction{
BaseAction: base,
@ -99,7 +99,7 @@ func NewSegmentActionWithScope(nodeID UniqueID, typ ActionType, shard string, se
}
}
func (action *SegmentAction) SegmentID() UniqueID {
func (action *SegmentAction) SegmentID() typeutil.UniqueID {
return action.segmentID
}
@ -143,7 +143,7 @@ type ChannelAction struct {
*BaseAction
}
func NewChannelAction(nodeID UniqueID, typ ActionType, channelName string) *ChannelAction {
func NewChannelAction(nodeID typeutil.UniqueID, typ ActionType, channelName string) *ChannelAction {
return &ChannelAction{
BaseAction: NewBaseAction(nodeID, typ, channelName),
}
@ -160,3 +160,43 @@ func (action *ChannelAction) IsFinished(distMgr *meta.DistributionManager) bool
return hasNode == isGrow
}
type LeaderAction struct {
*BaseAction
leaderID typeutil.UniqueID
segmentID typeutil.UniqueID
rpcReturned atomic.Bool
}
func NewLeaderAction(leaderID, workerID typeutil.UniqueID, typ ActionType, shard string, segmentID typeutil.UniqueID) *LeaderAction {
action := &LeaderAction{
BaseAction: NewBaseAction(workerID, typ, shard),
leaderID: leaderID,
segmentID: segmentID,
}
action.rpcReturned.Store(false)
return action
}
func (action *LeaderAction) SegmentID() typeutil.UniqueID {
return action.segmentID
}
func (action *LeaderAction) IsFinished(distMgr *meta.DistributionManager) bool {
views := distMgr.LeaderViewManager.GetLeaderView(action.leaderID)
view := views[action.Shard()]
if view == nil {
return false
}
dist := view.Segments[action.SegmentID()]
switch action.Type() {
case ActionTypeGrow:
return action.rpcReturned.Load() && dist != nil && dist.NodeID == action.Node()
case ActionTypeReduce:
return action.rpcReturned.Load() && (dist == nil || dist.NodeID != action.Node())
}
return false
}
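
For clarity, the completion predicate added above can be restated as a small standalone function. This is a minimal sketch with placeholder types, not the real protobuf or DistributionManager API:

```go
package main

import "fmt"

// segmentDist stands in for the NodeID field of querypb.SegmentDist
// referenced by LeaderAction.IsFinished (sketch type, not the real message).
type segmentDist struct {
	NodeID int64
}

// leaderActionFinished restates the predicate: a Grow action is finished once
// the RPC returned and the leader view maps the segment to the expected worker
// node; a Reduce action is finished once the RPC returned and the segment is
// either absent from the view or owned by a different node.
func leaderActionFinished(rpcReturned, grow bool, dist *segmentDist, workerID int64) bool {
	if !rpcReturned {
		return false
	}
	if grow {
		return dist != nil && dist.NodeID == workerID
	}
	return dist == nil || dist.NodeID != workerID
}

func main() {
	fmt.Println(leaderActionFinished(true, true, &segmentDist{NodeID: 3}, 3))  // Grow done: true
	fmt.Println(leaderActionFinished(true, false, nil, 3))                     // Reduce done: true
	fmt.Println(leaderActionFinished(false, true, &segmentDist{NodeID: 3}, 3)) // RPC not returned: false
}
```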

View File

@ -26,6 +26,8 @@ import (
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
@ -33,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -111,6 +114,9 @@ func (ex *Executor) Execute(task Task, step int) bool {
case *ChannelAction:
ex.executeDmChannelAction(task.(*ChannelTask), step)
case *LeaderAction:
ex.executeLeaderAction(task.(*LeaderTask), step)
}
}()
@ -162,70 +168,15 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
ex.removeTask(task, step)
}()
collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID())
collectionInfo, loadMeta, channel, err := ex.getMetaInfo(ctx, task)
if err != nil {
log.Warn("failed to get collection info", zap.Error(err))
return err
}
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
if err != nil {
log.Warn("failed to get partitions of collection", zap.Error(err))
return err
}
loadMeta := packLoadMeta(
ex.meta.GetLoadType(task.CollectionID()),
task.CollectionID(),
partitions...,
)
// get channel first, in case of target updated after segment info fetched
channel := ex.targetMgr.GetDmChannel(task.CollectionID(), task.shard, meta.NextTargetFirst)
if channel == nil {
return merr.WrapErrChannelNotAvailable(task.shard)
}
resp, err := ex.broker.GetSegmentInfo(ctx, task.SegmentID())
if err != nil || len(resp.GetInfos()) == 0 {
log.Warn("failed to get segment info from DataCoord", zap.Error(err))
loadInfo, indexInfos, err := ex.getLoadInfo(ctx, task.CollectionID(), action.SegmentID(), channel)
if err != nil {
return err
}
segment := resp.GetInfos()[0]
log = log.With(zap.String("level", segment.GetLevel().String()))
indexes, err := ex.broker.GetIndexInfo(ctx, task.CollectionID(), segment.GetID())
if err != nil {
if !errors.Is(err, merr.ErrIndexNotFound) {
log.Warn("failed to get index of segment", zap.Error(err))
return err
}
indexes = nil
}
// Get collection index info
indexInfos, err := ex.broker.DescribeIndex(ctx, task.CollectionID())
if err != nil {
log.Warn("fail to get index meta of collection")
return err
}
// update the field index params
for _, segmentIndex := range indexes {
index, found := lo.Find(indexInfos, func(indexInfo *indexpb.IndexInfo) bool {
return indexInfo.IndexID == segmentIndex.IndexID
})
if !found {
log.Warn("no collection index info for the given segment index", zap.String("indexName", segmentIndex.GetIndexName()))
}
params := funcutil.KeyValuePair2Map(segmentIndex.GetIndexParams())
for _, kv := range index.GetUserIndexParams() {
if indexparams.IsConfigableIndexParam(kv.GetKey()) {
params[kv.GetKey()] = kv.GetValue()
}
}
segmentIndex.IndexParams = funcutil.Map2KeyValuePair(params)
}
loadInfo := utils.PackSegmentLoadInfo(segment, channel.GetSeekPosition(), indexes)
req := packLoadSegmentRequest(
task,
@ -238,10 +189,10 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
)
// Get shard leader for the given replica and segment
leaderID, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), segment.GetInsertChannel())
leaderID, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), task.Shard())
if !ok {
msg := "no shard leader for the segment to execute loading"
err = merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "shard delegator not found")
err = merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
log.Warn(msg, zap.Error(err))
return err
}
@ -444,3 +395,211 @@ func (ex *Executor) unsubscribeChannel(task *ChannelTask, step int) error {
log.Info("unsubscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
return nil
}
func (ex *Executor) executeLeaderAction(task *LeaderTask, step int) {
switch task.Actions()[step].Type() {
case ActionTypeGrow, ActionTypeUpdate:
ex.setDistribution(task, step)
case ActionTypeReduce:
ex.removeDistribution(task, step)
}
}
func (ex *Executor) setDistribution(task *LeaderTask, step int) error {
action := task.Actions()[step].(*LeaderAction)
defer action.rpcReturned.Store(true)
ctx := task.Context()
log := log.Ctx(ctx).With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.Int64("segmentID", task.segmentID),
zap.Int64("leader", action.leaderID),
zap.Int64("node", action.Node()),
zap.String("source", task.Source().String()),
)
var err error
defer func() {
if err != nil {
task.Fail(err)
}
ex.removeTask(task, step)
}()
collectionInfo, loadMeta, channel, err := ex.getMetaInfo(ctx, task)
if err != nil {
return err
}
loadInfo, _, err := ex.getLoadInfo(ctx, task.CollectionID(), action.SegmentID(), channel)
if err != nil {
return err
}
req := &querypb.SyncDistributionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments),
commonpbutil.WithMsgID(task.ID()),
),
CollectionID: task.collectionID,
Channel: task.Shard(),
Schema: collectionInfo.GetSchema(),
LoadMeta: loadMeta,
ReplicaID: task.ReplicaID(),
Actions: []*querypb.SyncAction{
{
Type: querypb.SyncType_Set,
PartitionID: loadInfo.GetPartitionID(),
SegmentID: action.SegmentID(),
NodeID: action.Node(),
Info: loadInfo,
},
},
}
startTs := time.Now()
log.Info("Sync Distribution...")
status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req)
err = merr.CheckRPCCall(status, err)
if err != nil {
log.Warn("failed to sync distribution", zap.Error(err))
return err
}
elapsed := time.Since(startTs)
log.Info("sync distribution done", zap.Duration("elapsed", elapsed))
return nil
}
func (ex *Executor) removeDistribution(task *LeaderTask, step int) error {
action := task.Actions()[step].(*LeaderAction)
defer action.rpcReturned.Store(true)
ctx := task.Context()
log := log.Ctx(ctx).With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.Int64("segmentID", task.segmentID),
zap.Int64("leader", action.leaderID),
zap.Int64("node", action.Node()),
zap.String("source", task.Source().String()),
)
var err error
defer func() {
if err != nil {
task.Fail(err)
}
ex.removeTask(task, step)
}()
req := &querypb.SyncDistributionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments),
commonpbutil.WithMsgID(task.ID()),
),
CollectionID: task.collectionID,
Channel: task.Shard(),
ReplicaID: task.ReplicaID(),
Actions: []*querypb.SyncAction{
{
Type: querypb.SyncType_Set,
SegmentID: action.SegmentID(),
},
},
}
startTs := time.Now()
log.Info("Sync Distribution...")
status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req)
// status, err := ex.cluster.LoadSegments(task.Context(), leaderID, req)
err = merr.CheckRPCCall(status, err)
if err != nil {
log.Warn("failed to sync distribution", zap.Error(err))
return err
}
elapsed := time.Since(startTs)
log.Info("sync distribution done", zap.Duration("elapsed", elapsed))
return nil
}
func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.DescribeCollectionResponse, *querypb.LoadMetaInfo, *meta.DmChannel, error) {
collectionID := task.CollectionID()
shard := task.Shard()
log := log.Ctx(ctx)
collectionInfo, err := ex.broker.DescribeCollection(ctx, collectionID)
if err != nil {
log.Warn("failed to get collection info", zap.Error(err))
return nil, nil, nil, err
}
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID)
if err != nil {
log.Warn("failed to get partitions of collection", zap.Error(err))
return nil, nil, nil, err
}
loadMeta := packLoadMeta(
ex.meta.GetLoadType(collectionID),
collectionID,
partitions...,
)
// get channel first, in case of target updated after segment info fetched
channel := ex.targetMgr.GetDmChannel(collectionID, shard, meta.NextTargetFirst)
if channel == nil {
return nil, nil, nil, merr.WrapErrChannelNotAvailable(shard)
}
return collectionInfo, loadMeta, channel, nil
}
func (ex *Executor) getLoadInfo(ctx context.Context, collectionID, segmentID int64, channel *meta.DmChannel) (*querypb.SegmentLoadInfo, []*indexpb.IndexInfo, error) {
log := log.Ctx(ctx)
resp, err := ex.broker.GetSegmentInfo(ctx, segmentID)
if err != nil || len(resp.GetInfos()) == 0 {
log.Warn("failed to get segment info from DataCoord", zap.Error(err))
return nil, nil, err
}
segment := resp.GetInfos()[0]
log = log.With(zap.String("level", segment.GetLevel().String()))
indexes, err := ex.broker.GetIndexInfo(ctx, collectionID, segment.GetID())
if err != nil {
if !errors.Is(err, merr.ErrIndexNotFound) {
log.Warn("failed to get index of segment", zap.Error(err))
return nil, nil, err
}
indexes = nil
}
// Get collection index info
indexInfos, err := ex.broker.DescribeIndex(ctx, collectionID)
if err != nil {
log.Warn("fail to get index meta of collection", zap.Error(err))
return nil, nil, err
}
// update the field index params
for _, segmentIndex := range indexes {
index, found := lo.Find(indexInfos, func(indexInfo *indexpb.IndexInfo) bool {
return indexInfo.IndexID == segmentIndex.IndexID
})
if !found {
log.Warn("no collection index info for the given segment index", zap.String("indexName", segmentIndex.GetIndexName()))
}
params := funcutil.KeyValuePair2Map(segmentIndex.GetIndexParams())
for _, kv := range index.GetUserIndexParams() {
if indexparams.IsConfigableIndexParam(kv.GetKey()) {
params[kv.GetKey()] = kv.GetValue()
}
}
segmentIndex.IndexParams = funcutil.Map2KeyValuePair(params)
}
loadInfo := utils.PackSegmentLoadInfo(segment, channel.GetSeekPosition(), indexes)
return loadInfo, indexInfos, nil
}

View File

@ -75,6 +75,14 @@ func NewReplicaSegmentIndex(task *SegmentTask) replicaSegmentIndex {
}
}
func NewReplicaLeaderIndex(task *LeaderTask) replicaSegmentIndex {
return replicaSegmentIndex{
ReplicaID: task.ReplicaID(),
SegmentID: task.SegmentID(),
IsGrowing: false,
}
}
type replicaChannelIndex struct {
ReplicaID int64
Channel string
@ -263,6 +271,10 @@ func (scheduler *taskScheduler) Add(task Task) error {
case *ChannelTask:
index := replicaChannelIndex{task.ReplicaID(), task.Channel()}
scheduler.channelTasks[index] = task
case *LeaderTask:
index := NewReplicaLeaderIndex(task)
scheduler.segmentTasks[index] = task
}
scheduler.updateTaskMetrics()
@ -369,6 +381,23 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
return merr.WrapErrServiceInternal("source channel unsubscribed, stop balancing")
}
}
case *LeaderTask:
index := NewReplicaLeaderIndex(task)
if old, ok := scheduler.segmentTasks[index]; ok {
if task.Priority() > old.Priority() {
log.Info("replace old task, the new one with higher priority",
zap.Int64("oldID", old.ID()),
zap.String("oldPriority", old.Priority().String()),
zap.Int64("newID", task.ID()),
zap.String("newPriority", task.Priority().String()),
)
old.Cancel(merr.WrapErrServiceInternal("replaced with the other one with higher priority"))
scheduler.remove(old)
return nil
}
return merr.WrapErrServiceInternal("task with the same segment exists")
}
default:
panic(fmt.Sprintf("preAdd: forget to process task type: %+v", task))
}
@ -755,6 +784,11 @@ func (scheduler *taskScheduler) remove(task Task) {
index := replicaChannelIndex{task.ReplicaID(), task.Channel()}
delete(scheduler.channelTasks, index)
log = log.With(zap.String("channel", task.Channel()))
case *LeaderTask:
index := NewReplicaLeaderIndex(task)
delete(scheduler.segmentTasks, index)
log = log.With(zap.Int64("segmentID", task.SegmentID()))
}
scheduler.updateTaskMetrics()
@ -780,6 +814,11 @@ func (scheduler *taskScheduler) checkStale(task Task) error {
return err
}
case *LeaderTask:
if err := scheduler.checkLeaderTaskStale(task); err != nil {
return err
}
default:
panic(fmt.Sprintf("checkStale: forget to check task type: %+v", task))
}
@ -865,3 +904,53 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error {
}
return nil
}
func (scheduler *taskScheduler) checkLeaderTaskStale(task *LeaderTask) error {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.String("source", task.Source().String()),
zap.Int64("leaderID", task.leaderID),
)
for _, action := range task.Actions() {
switch action.Type() {
case ActionTypeGrow:
taskType := GetTaskType(task)
var segment *datapb.SegmentInfo
if taskType == TaskTypeMove || taskType == TaskTypeUpdate {
segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget)
} else {
segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget)
}
if segment == nil {
log.Warn("task stale due to the segment to load not exists in targets",
zap.Int64("segment", task.segmentID),
zap.String("taskType", taskType.String()),
)
return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment")
}
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
if replica == nil {
log.Warn("task stale due to replica not found")
return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID")
}
view := scheduler.distMgr.GetLeaderShardView(task.leaderID, task.Shard())
if view == nil {
log.Warn("task stale due to leader not found")
return merr.WrapErrChannelNotFound(task.Shard(), "failed to get shard delegator")
}
case ActionTypeReduce:
view := scheduler.distMgr.GetLeaderShardView(task.leaderID, task.Shard())
if view == nil {
log.Warn("task stale due to leader not found")
return merr.WrapErrChannelNotFound(task.Shard(), "failed to get shard delegator")
}
}
}
return nil
}
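
One consequence worth noting: NewReplicaLeaderIndex reuses the scheduler's `segmentTasks` key type with `IsGrowing` fixed to false, so a LeaderTask and a sealed-segment SegmentTask for the same replica and segment appear to map to the same entry and are deduplicated against each other by `preAdd`. A minimal sketch of that key collision, with the struct fields copied from the diff above and made-up IDs:

```go
package main

import "fmt"

// replicaSegmentIndex mirrors the scheduler's map key shown in this diff.
type replicaSegmentIndex struct {
	ReplicaID int64
	SegmentID int64
	IsGrowing bool
}

func main() {
	// Key produced by NewReplicaLeaderIndex for a LeaderTask.
	leaderKey := replicaSegmentIndex{ReplicaID: 1, SegmentID: 100, IsGrowing: false}
	// Key a sealed-segment SegmentTask would produce (IsGrowing false for a
	// non-streaming data scope).
	sealedKey := replicaSegmentIndex{ReplicaID: 1, SegmentID: 100, IsGrowing: false}
	// Identical keys mean the scheduler keeps at most one of the two tasks,
	// replacing the old one only when the new one has higher priority.
	fmt.Println(leaderKey == sealedKey) // true
}
```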

View File

@ -72,6 +72,7 @@ type Task interface {
ID() typeutil.UniqueID
CollectionID() typeutil.UniqueID
ReplicaID() typeutil.UniqueID
Shard() string
SetID(id typeutil.UniqueID)
Status() Status
SetStatus(status Status)
@ -162,6 +163,10 @@ func (task *baseTask) ReplicaID() typeutil.UniqueID {
return task.replicaID
}
func (task *baseTask) Shard() string {
return task.shard
}
func (task *baseTask) LoadType() querypb.LoadType {
return task.loadType
}
@ -318,10 +323,6 @@ func NewSegmentTask(ctx context.Context,
}, nil
}
func (task *SegmentTask) Shard() string {
return task.shard
}
func (task *SegmentTask) SegmentID() typeutil.UniqueID {
return task.segmentID
}
@ -383,3 +384,40 @@ func (task *ChannelTask) Index() string {
func (task *ChannelTask) String() string {
return fmt.Sprintf("%s [channel=%s]", task.baseTask.String(), task.Channel())
}
type LeaderTask struct {
*baseTask
segmentID typeutil.UniqueID
leaderID int64
}
func NewLeaderTask(ctx context.Context,
timeout time.Duration,
source Source,
collectionID,
replicaID typeutil.UniqueID,
leaderID int64,
action *LeaderAction,
) *LeaderTask {
segmentID := action.SegmentID()
base := newBaseTask(ctx, source, collectionID, replicaID, action.Shard(), fmt.Sprintf("LeaderTask-%s-%d", action.Type().String(), segmentID))
base.actions = []Action{action}
return &LeaderTask{
baseTask: base,
segmentID: segmentID,
leaderID: leaderID,
}
}
func (task *LeaderTask) SegmentID() typeutil.UniqueID {
return task.segmentID
}
func (task *LeaderTask) Index() string {
return fmt.Sprintf("%s[segment=%d][growing=false]", task.baseTask.Index(), task.segmentID)
}
func (task *LeaderTask) String() string {
return fmt.Sprintf("%s [segmentID=%d][leader=%d]", task.baseTask.String(), task.segmentID, task.leaderID)
}

View File

@ -171,6 +171,8 @@ func (suite *TaskSuite) BeforeTest(suiteName, testName string) {
"TestMoveSegmentTaskStale",
"TestSubmitDuplicateLoadSegmentTask",
"TestSubmitDuplicateSubscribeChannelTask",
"TestLeaderTaskSet",
"TestLeaderTaskRemove",
"TestNoExecutor":
suite.meta.PutCollection(&meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
@ -1213,6 +1215,113 @@ func (suite *TaskSuite) TestChannelTaskReplace() {
suite.AssertTaskNum(0, channelNum, channelNum, 0)
}
func (suite *TaskSuite) TestLeaderTaskSet() {
ctx := context.Background()
timeout := 10 * time.Second
targetNode := int64(3)
partition := int64(100)
channel := &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test",
}
// Expect
suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{
Schema: &schemapb.CollectionSchema{
Name: "TestLoadSegmentTask",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector},
},
},
}, nil)
suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{
{
CollectionID: suite.collection,
},
}, nil)
for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{
Infos: []*datapb.SegmentInfo{
{
ID: segment,
CollectionID: suite.collection,
PartitionID: partition,
InsertChannel: channel.ChannelName,
},
},
}, nil)
suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil)
}
suite.cluster.EXPECT().SyncDistribution(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil)
// Test load segment task
suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: channel.ChannelName,
}))
tasks := []Task{}
segments := make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.loadSegments {
segments = append(segments, &datapb.SegmentInfo{
ID: segment,
InsertChannel: channel.ChannelName,
PartitionID: 1,
})
task := NewLeaderTask(
ctx,
timeout,
WrapIDSource(0),
suite.collection,
suite.replica,
targetNode,
NewLeaderAction(targetNode, targetNode, ActionTypeGrow, channel.GetChannelName(), segment),
)
tasks = append(tasks, task)
err := suite.scheduler.Add(task)
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
view := &meta.LeaderView{
ID: targetNode,
CollectionID: suite.collection,
Channel: channel.GetChannelName(),
Segments: map[int64]*querypb.SegmentDist{},
}
suite.dist.LeaderViewManager.Update(targetNode, view)
// Process tasks
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks done
// Dist contains channels
view = &meta.LeaderView{
ID: targetNode,
CollectionID: suite.collection,
Channel: channel.GetChannelName(),
Segments: map[int64]*querypb.SegmentDist{},
}
for _, segment := range suite.loadSegments {
view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0}
}
distSegments := lo.Map(segments, func(info *datapb.SegmentInfo, _ int) *meta.Segment {
return meta.SegmentFromInfo(info)
})
suite.dist.LeaderViewManager.Update(targetNode, view)
suite.dist.SegmentDistManager.Update(targetNode, distSegments...)
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(0, 0, 0, 0)
for _, task := range tasks {
suite.Equal(TaskStatusSucceeded, task.Status())
suite.NoError(task.Err())
}
}
func (suite *TaskSuite) TestCreateTaskBehavior() {
chanelTask, err := NewChannelTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0)
suite.ErrorIs(err, merr.ErrParameterInvalid)
@ -1244,6 +1353,10 @@ func (suite *TaskSuite) TestCreateTaskBehavior() {
segmentTask, err = NewSegmentTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, segmentAction1, segmentAction2)
suite.ErrorIs(err, merr.ErrParameterInvalid)
suite.Nil(segmentTask)
leaderAction := NewLeaderAction(1, 2, ActionTypeGrow, "fake-channel1", 100)
leaderTask := NewLeaderTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, 1, leaderAction)
suite.NotNil(leaderTask)
}
func (suite *TaskSuite) TestSegmentTaskReplace() {
@ -1387,6 +1500,74 @@ func (suite *TaskSuite) dispatchAndWait(node int64) {
suite.FailNow("executor hangs in executing tasks", "count=%d keys=%+v", count, keys)
}
func (suite *TaskSuite) TestLeaderTaskRemove() {
ctx := context.Background()
timeout := 10 * time.Second
targetNode := int64(3)
partition := int64(100)
channel := &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test",
}
// Expect
suite.cluster.EXPECT().SyncDistribution(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil)
// Test remove segment task
view := &meta.LeaderView{
ID: targetNode,
CollectionID: suite.collection,
Channel: channel.ChannelName,
Segments: make(map[int64]*querypb.SegmentDist),
}
segments := make([]*meta.Segment, 0)
tasks := []Task{}
for _, segment := range suite.releaseSegments {
segments = append(segments, &meta.Segment{
SegmentInfo: &datapb.SegmentInfo{
ID: segment,
CollectionID: suite.collection,
PartitionID: partition,
InsertChannel: channel.ChannelName,
},
})
view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0}
task := NewLeaderTask(
ctx,
timeout,
WrapIDSource(0),
suite.collection,
suite.replica,
targetNode,
NewLeaderAction(targetNode, targetNode, ActionTypeReduce, channel.GetChannelName(), segment),
)
tasks = append(tasks, task)
err := suite.scheduler.Add(task)
suite.NoError(err)
}
suite.dist.SegmentDistManager.Update(targetNode, segments...)
suite.dist.LeaderViewManager.Update(targetNode, view)
segmentsNum := len(suite.releaseSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
// Process tasks
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
view.Segments = make(map[int64]*querypb.SegmentDist)
suite.dist.LeaderViewManager.Update(targetNode, view)
// Process tasks done
// suite.dist.LeaderViewManager.Update(targetNode)
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(0, 0, 0, 0)
for _, task := range tasks {
suite.Equal(TaskStatusSucceeded, task.Status())
suite.NoError(task.Err())
}
}
func (suite *TaskSuite) newScheduler() *taskScheduler {
return NewScheduler(
context.Background(),