From e6d2849c9a1e5586ed1747596b7f06ad1114fb76 Mon Sep 17 00:00:00 2001
From: SimFG
Date: Thu, 29 Dec 2022 15:47:31 +0800
Subject: [PATCH] Check whether the node is stopping when using `load balance` api (#21438)

Signed-off-by: SimFG
---
 internal/querycoordv2/services.go      | 22 ++++++++++++++++++++++
 internal/querycoordv2/services_test.go | 20 ++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go
index 92caad0c13..9866b3b576 100644
--- a/internal/querycoordv2/services.go
+++ b/internal/querycoordv2/services.go
@@ -444,6 +444,20 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
 	}, nil
 }
 
+func (s *Server) isStoppingNode(nodeID int64) error {
+	isStopping, err := s.nodeMgr.IsStoppingNode(nodeID)
+	if err != nil {
+		log.Warn("fail to check whether the node is stopping", zap.Int64("node_id", nodeID), zap.Error(err))
+		return err
+	}
+	if isStopping {
+		msg := fmt.Sprintf("failed to balance due to the source/destination node[%d] is stopping", nodeID)
+		log.Warn(msg)
+		return errors.New(msg)
+	}
+	return nil
+}
+
 func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) {
 	log := log.Ctx(ctx).With(
 		zap.Int64("collectionID", req.GetCollectionID()),
@@ -478,12 +492,20 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
 		log.Warn(msg)
 		return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil
 	}
+	if err := s.isStoppingNode(srcNode); err != nil {
+		return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError,
+			fmt.Sprintf("can't balance, because the source node[%d] is invalid", srcNode), err), nil
+	}
 	for _, dstNode := range req.GetDstNodeIDs() {
 		if !replica.Nodes.Contain(dstNode) {
 			msg := "destination nodes have to be in the same replica of source node"
 			log.Warn(msg)
 			return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil
 		}
+		if err := s.isStoppingNode(dstNode); err != nil {
+			return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError,
+				fmt.Sprintf("can't balance, because the destination node[%d] is invalid", dstNode), err), nil
+		}
 	}
 
 	err := s.balanceSegments(ctx, req, replica)
diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go
index 3b5b420bed..f465f9da31 100644
--- a/internal/querycoordv2/services_test.go
+++ b/internal/querycoordv2/services_test.go
@@ -745,6 +745,26 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
 		suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
 		suite.Contains(resp.Reason, "failed to balance segments")
 		suite.Contains(resp.Reason, task.ErrTaskCanceled.Error())
+
+		suite.meta.ReplicaManager.AddNode(replicas[0].ID, 10)
+		req.SourceNodeIDs = []int64{10}
+		resp, err = server.LoadBalance(ctx, req)
+		suite.NoError(err)
+		suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
+
+		req.SourceNodeIDs = []int64{srcNode}
+		req.DstNodeIDs = []int64{10}
+		resp, err = server.LoadBalance(ctx, req)
+		suite.NoError(err)
+		suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
+
+		suite.nodeMgr.Add(session.NewNodeInfo(10, "localhost"))
+		suite.nodeMgr.Stopping(10)
+		resp, err = server.LoadBalance(ctx, req)
+		suite.NoError(err)
+		suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
+		suite.nodeMgr.Remove(10)
+		suite.meta.ReplicaManager.RemoveNode(replicas[0].ID, 10)
 	}
 }
 