From f8cff79804eb87dc050c4cff4978a2145cc80a21 Mon Sep 17 00:00:00 2001 From: SimFG Date: Tue, 6 Dec 2022 22:59:19 +0800 Subject: [PATCH] Support the graceful stop for the query node (#20851) Signed-off-by: SimFG Signed-off-by: SimFG --- cmd/roles/roles.go | 7 +- configs/milvus.yaml | 1 + go.mod | 12 +- go.sum | 7 +- internal/core/src/pb/common.pb.cc | 17 +-- internal/core/src/pb/common.pb.h | 3 +- internal/distributed/querynode/service.go | 10 +- .../metastore/kv/datacoord/kv_catalog_test.go | 1 - internal/querycoordv2/balance/balance.go | 32 ++++- internal/querycoordv2/balance/balance_test.go | 42 ++++-- .../balance/rowcount_based_balancer.go | 119 ++++++++++++++--- .../balance/rowcount_based_balancer_test.go | 123 +++++++++++++++--- internal/querycoordv2/balance/utils.go | 2 + .../querycoordv2/checkers/balance_checker.go | 7 +- internal/querycoordv2/checkers/controller.go | 11 ++ internal/querycoordv2/dist/dist_controller.go | 7 +- internal/querycoordv2/dist/dist_handler.go | 55 ++++---- internal/querycoordv2/mocks/querynode.go | 4 + internal/querycoordv2/server.go | 10 ++ internal/querycoordv2/server_test.go | 10 ++ internal/querycoordv2/services_test.go | 16 ++- internal/querycoordv2/session/node_manager.go | 29 +++++ internal/querycoordv2/task/utils.go | 6 + internal/querynode/collection_test.go | 1 + internal/querynode/data_sync_service_test.go | 4 + internal/querynode/impl.go | 90 ++++++++----- internal/querynode/impl_test.go | 23 ++++ internal/querynode/meta_replica.go | 22 ++++ internal/querynode/meta_replica_test.go | 38 ++++++ internal/querynode/query_node.go | 31 ++++- internal/querynode/query_node_test.go | 1 + internal/util/paramtable/component_param.go | 15 +++ .../util/paramtable/component_param_test.go | 12 ++ internal/util/sessionutil/session_util.go | 59 ++++++++- .../util/sessionutil/session_util_test.go | 1 + scripts/check_proto_product.sh | 4 +- scripts/stop_graceful.sh | 48 +++++++ 37 files changed, 730 insertions(+), 150 deletions(-) create mode 100755 scripts/stop_graceful.sh diff --git a/cmd/roles/roles.go b/cmd/roles/roles.go index e5b16abbfb..46ba2bed3b 100644 --- a/cmd/roles/roles.go +++ b/cmd/roles/roles.go @@ -199,6 +199,10 @@ func (mr *MilvusRoles) setupLogger() { func (mr *MilvusRoles) Run(local bool, alias string) { log.Info("starting running Milvus components") ctx, cancel := context.WithCancel(context.Background()) + defer func() { + // some deferred Stop has race with context cancel + cancel() + }() mr.printLDPreLoad() // only standalone enable localMsg @@ -322,7 +326,4 @@ func (mr *MilvusRoles) Run(local bool, alias string) { syscall.SIGQUIT) sig := <-sc log.Error("Get signal to exit\n", zap.String("signal", sig.String())) - - // some deferred Stop has race with context cancel - cancel() } diff --git a/configs/milvus.yaml b/configs/milvus.yaml index be770a2ffc..ed81b82b68 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -374,6 +374,7 @@ common: entityExpiration: -1 # Entity expiration in seconds, CAUTION make sure entityExpiration >= retentionDuration and -1 means never expire gracefulTime: 5000 # milliseconds. it represents the interval (in ms) by which the request arrival time needs to be subtracted in the case of Bounded Consistency. + gracefulStopTimeout: 30 # seconds. it will force quit the server if the graceful stop process is not completed during this time. 
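The new common.gracefulStopTimeout entry above bounds how long a graceful shutdown may take before the process force-quits. A minimal sketch of that idea, for orientation only — the stopper interface, function names, and timeout wiring below are assumptions for illustration, not code from this patch:

package gracefulstop

import (
	"context"
	"errors"
	"time"
)

// stopper is any component that can shut down gracefully.
type stopper interface {
	Stop() error
}

// stopWithTimeout runs s.Stop() but gives up after timeout, mirroring the
// "force quit the server if the graceful stop process is not completed"
// behavior described by the config comment above.
func stopWithTimeout(s stopper, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	done := make(chan error, 1)
	go func() { done <- s.Stop() }()

	select {
	case err := <-done:
		return err
	case <-ctx.Done():
		return errors.New("graceful stop timed out, forcing quit")
	}
}

A caller would pass timeout as time.Duration(gracefulStopTimeout)*time.Second read from this config section; 30 seconds is the default shown here.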
# Default value: auto # Valid values: [auto, avx512, avx2, avx, sse4_2] diff --git a/go.mod b/go.mod index ad332718c4..19f13b443b 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/klauspost/compress v1.14.2 github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api v0.0.0-20221019080323-84e9fa2f9e45 + github.com/milvus-io/milvus-proto/go-api v0.0.0-20221126103108-4d988e37ebf2 github.com/minio/minio-go/v7 v7.0.17 github.com/opentracing/opentracing-go v1.2.0 github.com/panjf2000/ants/v2 v2.4.8 @@ -88,10 +88,8 @@ require ( github.com/facebookgo/ensure v0.0.0-20200202191622-63f1cf65ac4c // indirect github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4 // indirect - github.com/fatih/color v1.10.0 // indirect github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect - github.com/ghodss/yaml v1.0.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.13.0 // indirect @@ -109,15 +107,12 @@ require ( github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jonboulle/clockwork v0.2.2 // indirect github.com/json-iterator/go v1.1.11 // indirect github.com/klauspost/asmfmt v1.3.1 // indirect github.com/klauspost/cpuid v1.3.1 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect - github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06 // indirect - github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b // indirect github.com/leodido/go-urn v1.2.0 // indirect github.com/linkedin/goavro/v2 v2.11.1 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect @@ -147,7 +142,6 @@ require ( github.com/sirupsen/logrus v1.8.1 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spf13/afero v1.6.0 // indirect - github.com/spf13/cobra v1.1.3 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.4.0 // indirect @@ -175,13 +169,13 @@ require ( go.opentelemetry.io/otel/sdk/metric v0.20.0 // indirect go.opentelemetry.io/otel/trace v0.20.0 // indirect go.opentelemetry.io/proto/otlp v0.7.0 // indirect - go.uber.org/multierr v1.6.0 // indirect + go.uber.org/multierr v1.6.0 golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/text v0.3.7 golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect golang.org/x/tools v0.1.9 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index f5f1d15a9f..dc03c336f9 100644 --- a/go.sum +++ b/go.sum @@ -197,7 +197,6 @@ github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4 h1:7HZCaLC5+BZpm github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4/go.mod 
h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= -github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c= @@ -457,9 +456,7 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06 h1:vN4d3jSss3ExzUn2cE0WctxztfOgiKvMKnDrydBsg00= -github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06/go.mod h1:++9BgZujZd4v0ZTZCb5iPsaomXdZWyxotIAh1IiDm44= github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b h1:xYEM2oBUhBEhQjrV+KJ9lEWDWYZoNVZUaBF++Wyljq4= -github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b/go.mod h1:V0HF/ZBlN86HqewcDC/cVxMmYDiRukWjSrgKLUAn9Js= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76 h1:IVlcvV0CjvfBYYod5ePe89l+3LBAl//6n9kJ9Vr2i0k= @@ -493,8 +490,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyex github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api v0.0.0-20221019080323-84e9fa2f9e45 h1:QxGQqRtJbbdMf/jxodjYaTW8ZcqdyPDsuhIb9Y2jnwk= -github.com/milvus-io/milvus-proto/go-api v0.0.0-20221019080323-84e9fa2f9e45/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk= +github.com/milvus-io/milvus-proto/go-api v0.0.0-20221126103108-4d988e37ebf2 h1:kf0QwCVxVTV8HLX0rTsOFs4l/vw7MoUjKFVkd4yQRVM= +github.com/milvus-io/milvus-proto/go-api v0.0.0-20221126103108-4d988e37ebf2/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk= github.com/milvus-io/pulsar-client-go v0.6.8 h1:fZdZH73aPRszu2fazyeeahQEz34tyn1Pt9EkqJmV100= github.com/milvus-io/pulsar-client-go v0.6.8/go.mod h1:oFIlYIk23tamkSLttw849qphmMIpHY8ztEBWDWJW+sc= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= diff --git a/internal/core/src/pb/common.pb.cc b/internal/core/src/pb/common.pb.cc index b9a79da471..9c8d3d751b 100644 --- a/internal/core/src/pb/common.pb.cc +++ b/internal/core/src/pb/common.pb.cc @@ -468,13 +468,13 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE( "ateUser\020\024\022\032\n\026PrivilegeDropOwnership\020\025\022\034\n" "\030PrivilegeSelectOwnership\020\026\022\034\n\030Privilege" "ManageOwnership\020\027\022\027\n\023PrivilegeSelectUser" - "\020\030*E\n\tStateCode\022\020\n\014Initializing\020\000\022\013\n\007Hea" - "lthy\020\001\022\014\n\010Abnormal\020\002\022\013\n\007StandBy\020\003:^\n\021pri" - "vilege_ext_obj\022\037.google.protobuf.Message" - 
"Options\030\351\007 \001(\0132!.milvus.proto.common.Pri" - "vilegeExtBU\n\016io.milvus.grpcB\013CommonProto" - "P\001Z1github.com/milvus-io/milvus-proto/go" - "-api/commonpb\240\001\001b\006proto3" + "\020\030*S\n\tStateCode\022\020\n\014Initializing\020\000\022\013\n\007Hea" + "lthy\020\001\022\014\n\010Abnormal\020\002\022\013\n\007StandBy\020\003\022\014\n\010Sto" + "pping\020\004:^\n\021privilege_ext_obj\022\037.google.pr" + "otobuf.MessageOptions\030\351\007 \001(\0132!.milvus.pr" + "oto.common.PrivilegeExtBU\n\016io.milvus.grp" + "cB\013CommonProtoP\001Z1github.com/milvus-io/m" + "ilvus-proto/go-api/commonpb\240\001\001b\006proto3" ; static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_common_2eproto_deps[1] = { &::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto, @@ -495,7 +495,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_com static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_common_2eproto_once; static bool descriptor_table_common_2eproto_initialized = false; const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_common_2eproto = { - &descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 5384, + &descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 5398, &descriptor_table_common_2eproto_once, descriptor_table_common_2eproto_sccs, descriptor_table_common_2eproto_deps, 11, 1, schemas, file_default_instances, TableStruct_common_2eproto::offsets, file_level_metadata_common_2eproto, 11, file_level_enum_descriptors_common_2eproto, file_level_service_descriptors_common_2eproto, @@ -842,6 +842,7 @@ bool StateCode_IsValid(int value) { case 1: case 2: case 3: + case 4: return true; default: return false; diff --git a/internal/core/src/pb/common.pb.h b/internal/core/src/pb/common.pb.h index ff1327bc74..04826174b3 100644 --- a/internal/core/src/pb/common.pb.h +++ b/internal/core/src/pb/common.pb.h @@ -564,12 +564,13 @@ enum StateCode : int { Healthy = 1, Abnormal = 2, StandBy = 3, + Stopping = 4, StateCode_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), StateCode_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() }; bool StateCode_IsValid(int value); constexpr StateCode StateCode_MIN = Initializing; -constexpr StateCode StateCode_MAX = StandBy; +constexpr StateCode StateCode_MAX = Stopping; constexpr int StateCode_ARRAYSIZE = StateCode_MAX + 1; const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* StateCode_descriptor(); diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 5a7531417d..8ddacc8c2a 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -221,6 +221,11 @@ func (s *Server) Run() error { // Stop stops QueryNode's grpc service. 
func (s *Server) Stop() error { log.Debug("QueryNode stop", zap.String("Address", Params.GetAddress())) + err := s.querynode.Stop() + if err != nil { + return err + } + if s.closer != nil { if err := s.closer.Close(); err != nil { return err @@ -235,11 +240,6 @@ func (s *Server) Stop() error { log.Debug("Graceful stop grpc server...") s.grpcServer.GracefulStop() } - - err := s.querynode.Stop() - if err != nil { - return err - } s.wg.Wait() return nil } diff --git a/internal/metastore/kv/datacoord/kv_catalog_test.go b/internal/metastore/kv/datacoord/kv_catalog_test.go index 02226f106e..6b3aa8d21f 100644 --- a/internal/metastore/kv/datacoord/kv_catalog_test.go +++ b/internal/metastore/kv/datacoord/kv_catalog_test.go @@ -473,7 +473,6 @@ func Test_AlterSegmentsAndAddNewSegment(t *testing.T) { return "", errors.New("key not found") } - // TODO fubang catalog := &Catalog{txn, "a"} err := catalog.AlterSegmentsAndAddNewSegment(context.TODO(), []*datapb.SegmentInfo{droppedSegment}, segment1) assert.NoError(t, err) diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index 37982dae07..49d2b3e883 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -24,11 +24,40 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/task" ) +type Weight = int + +const ( + weightLow int = iota - 1 + weightNormal + weightHigh +) + +func GetWeight(w int) Weight { + if w > 0 { + return weightHigh + } else if w < 0 { + return weightLow + } + return weightNormal +} + +func GetTaskPriorityFromWeight(w Weight) task.Priority { + switch w { + case weightHigh: + return task.TaskPriorityHigh + case weightLow: + return task.TaskPriorityLow + default: + return task.TaskPriorityNormal + } +} + type SegmentAssignPlan struct { Segment *meta.Segment ReplicaID int64 From int64 // -1 if empty To int64 + Weight Weight } type ChannelAssignPlan struct { @@ -36,6 +65,7 @@ type ChannelAssignPlan struct { ReplicaID int64 From int64 To int64 + Weight Weight } type Balance interface { @@ -99,7 +129,7 @@ func (b *RoundRobinBalancer) getNodes(nodes []int64) []*session.NodeInfo { ret := make([]*session.NodeInfo, 0, len(nodes)) for _, n := range nodes { node := b.nodeManager.Get(n) - if node != nil { + if node != nil && !node.IsStoppingState() { ret = append(ret, node) } } diff --git a/internal/querycoordv2/balance/balance_test.go b/internal/querycoordv2/balance/balance_test.go index 64f3e6cb3f..4ef24b2d82 100644 --- a/internal/querycoordv2/balance/balance_test.go +++ b/internal/querycoordv2/balance/balance_test.go @@ -43,15 +43,17 @@ func (suite *BalanceTestSuite) TestAssignBalance() { name string nodeIDs []int64 segmentCnts []int + states []session.State deltaCnts []int assignments []*meta.Segment expectPlans []SegmentAssignPlan }{ { name: "normal assignment", - nodeIDs: []int64{1, 2}, - segmentCnts: []int{100, 200}, - deltaCnts: []int{0, -200}, + nodeIDs: []int64{1, 2, 3}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping}, + segmentCnts: []int{100, 200, 0}, + deltaCnts: []int{0, -200, 0}, assignments: []*meta.Segment{ {SegmentInfo: &datapb.SegmentInfo{ID: 1}}, {SegmentInfo: &datapb.SegmentInfo{ID: 2}}, @@ -67,6 +69,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() { name: "empty assignment", nodeIDs: []int64{}, segmentCnts: []int{}, + states: []session.State{}, deltaCnts: []int{}, assignments: []*meta.Segment{ {SegmentInfo: &datapb.SegmentInfo{ID: 1}}, @@ -83,8 +86,11 @@ func (suite 
*BalanceTestSuite) TestAssignBalance() { for i := range c.nodeIDs { nodeInfo := session.NewNodeInfo(c.nodeIDs[i], "127.0.0.1:0") nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) suite.roundRobinBalancer.nodeManager.Add(nodeInfo) - suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i]) + if !nodeInfo.IsStoppingState() { + suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i]) + } } plans := suite.roundRobinBalancer.AssignSegment(c.assignments, c.nodeIDs) suite.ElementsMatch(c.expectPlans, plans) @@ -97,15 +103,17 @@ func (suite *BalanceTestSuite) TestAssignChannel() { name string nodeIDs []int64 channelCnts []int + states []session.State deltaCnts []int assignments []*meta.DmChannel expectPlans []ChannelAssignPlan }{ { name: "normal assignment", - nodeIDs: []int64{1, 2}, - channelCnts: []int{100, 200}, - deltaCnts: []int{0, -200}, + nodeIDs: []int64{1, 2, 3}, + channelCnts: []int{100, 200, 0}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping}, + deltaCnts: []int{0, -200, 0}, assignments: []*meta.DmChannel{ {VchannelInfo: &datapb.VchannelInfo{ChannelName: "channel-1"}}, {VchannelInfo: &datapb.VchannelInfo{ChannelName: "channel-2"}}, @@ -121,6 +129,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() { name: "empty assignment", nodeIDs: []int64{}, channelCnts: []int{}, + states: []session.State{}, deltaCnts: []int{}, assignments: []*meta.DmChannel{ {VchannelInfo: &datapb.VchannelInfo{ChannelName: "channel-1"}}, @@ -137,8 +146,11 @@ func (suite *BalanceTestSuite) TestAssignChannel() { for i := range c.nodeIDs { nodeInfo := session.NewNodeInfo(c.nodeIDs[i], "127.0.0.1:0") nodeInfo.UpdateStats(session.WithChannelCnt(c.channelCnts[i])) + nodeInfo.SetState(c.states[i]) suite.roundRobinBalancer.nodeManager.Add(nodeInfo) - suite.mockScheduler.EXPECT().GetNodeChannelDelta(c.nodeIDs[i]).Return(c.deltaCnts[i]) + if !nodeInfo.IsStoppingState() { + suite.mockScheduler.EXPECT().GetNodeChannelDelta(c.nodeIDs[i]).Return(c.deltaCnts[i]) + } } plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs) suite.ElementsMatch(c.expectPlans, plans) @@ -146,6 +158,20 @@ func (suite *BalanceTestSuite) TestAssignChannel() { } } +func (suite *BalanceTestSuite) TestWeight() { + suite.Run("GetWeight", func() { + suite.Equal(weightHigh, GetWeight(10)) + suite.Equal(weightNormal, GetWeight(0)) + suite.Equal(weightLow, GetWeight(-10)) + }) + + suite.Run("GetTaskPriorityFromWeight", func() { + suite.Equal(task.TaskPriorityHigh, GetTaskPriorityFromWeight(weightHigh)) + suite.Equal(task.TaskPriorityNormal, GetTaskPriorityFromWeight(weightNormal)) + suite.Equal(task.TaskPriorityLow, GetTaskPriorityFromWeight(weightLow)) + }) +} + func TestBalanceSuite(t *testing.T) { suite.Run(t, new(BalanceTestSuite)) } diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 56e980e4c9..29878a74eb 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -28,17 +28,16 @@ import ( type RowCountBasedBalancer struct { *RoundRobinBalancer - nodeManager *session.NodeManager - dist *meta.DistributionManager - meta *meta.Meta - targetMgr *meta.TargetManager + dist *meta.DistributionManager + meta *meta.Meta + targetMgr *meta.TargetManager } func (b *RowCountBasedBalancer) AssignSegment(segments []*meta.Segment, nodes 
[]int64) []SegmentAssignPlan { - if len(nodes) == 0 { + nodeItems := b.convertToNodeItems(nodes) + if len(nodeItems) == 0 { return nil } - nodeItems := b.convertToNodeItems(nodes) queue := newPriorityQueue() for _, item := range nodeItems { queue.push(item) @@ -68,7 +67,8 @@ func (b *RowCountBasedBalancer) AssignSegment(segments []*meta.Segment, nodes [] func (b *RowCountBasedBalancer) convertToNodeItems(nodeIDs []int64) []*nodeItem { ret := make([]*nodeItem, 0, len(nodeIDs)) - for _, node := range nodeIDs { + for _, nodeInfo := range b.getNodes(nodeIDs) { + node := nodeInfo.ID() segments := b.dist.SegmentDistManager.GetByNode(node) rowcnt := 0 for _, s := range segments { @@ -108,6 +108,8 @@ func (b *RowCountBasedBalancer) balanceReplica(replica *meta.Replica) ([]Segment } nodesRowCnt := make(map[int64]int) nodesSegments := make(map[int64][]*meta.Segment) + stoppingNodesSegments := make(map[int64][]*meta.Segment) + totalCnt := 0 for _, nid := range nodes { segments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nid) @@ -120,13 +122,23 @@ func (b *RowCountBasedBalancer) balanceReplica(replica *meta.Replica) ([]Segment cnt += int(s.GetNumOfRows()) } nodesRowCnt[nid] = cnt - nodesSegments[nid] = segments + + if nodeInfo := b.nodeManager.Get(nid); nodeInfo.IsStoppingState() { + stoppingNodesSegments[nid] = segments + } else { + nodesSegments[nid] = segments + } totalCnt += cnt } - average := totalCnt / len(nodes) + if len(nodes) == len(stoppingNodesSegments) { + return b.handleStoppingNodes(replica, stoppingNodesSegments) + } + + average := totalCnt / len(nodesSegments) neededRowCnt := 0 - for _, rowcnt := range nodesRowCnt { + for nodeID := range nodesSegments { + rowcnt := nodesRowCnt[nodeID] if rowcnt < average { neededRowCnt += average - rowcnt } @@ -138,13 +150,17 @@ func (b *RowCountBasedBalancer) balanceReplica(replica *meta.Replica) ([]Segment segmentsToMove := make([]*meta.Segment, 0) + stopSegments, cnt := b.collectionStoppingSegments(stoppingNodesSegments) + segmentsToMove = append(segmentsToMove, stopSegments...) 
+ neededRowCnt -= cnt + // select segments to be moved outer: - for nodeID, rowcnt := range nodesRowCnt { + for nodeID, segments := range nodesSegments { + rowcnt := nodesRowCnt[nodeID] if rowcnt <= average { continue } - segments := nodesSegments[nodeID] sort.Slice(segments, func(i, j int) bool { return segments[i].GetNumOfRows() > segments[j].GetNumOfRows() }) @@ -168,7 +184,8 @@ outer: // allocate segments to those nodes with row cnt less than average queue := newPriorityQueue() - for nodeID, rowcnt := range nodesRowCnt { + for nodeID := range nodesSegments { + rowcnt := nodesRowCnt[nodeID] if rowcnt >= average { continue } @@ -177,19 +194,92 @@ outer: } plans := make([]SegmentAssignPlan, 0) + getPlanWeight := func(nodeID int64) Weight { + if _, ok := stoppingNodesSegments[nodeID]; ok { + return GetWeight(1) + } + return GetWeight(0) + } for _, s := range segmentsToMove { node := queue.pop().(*nodeItem) + plan := SegmentAssignPlan{ ReplicaID: replica.GetID(), From: s.Node, To: node.nodeID, Segment: s, + Weight: getPlanWeight(s.Node), } plans = append(plans, plan) node.setPriority(node.getPriority() + int(s.GetNumOfRows())) queue.push(node) } - return plans, nil + return plans, b.getChannelPlan(replica, stoppingNodesSegments) +} + +func (b *RowCountBasedBalancer) handleStoppingNodes(replica *meta.Replica, nodeSegments map[int64][]*meta.Segment) ([]SegmentAssignPlan, []ChannelAssignPlan) { + segmentPlans := make([]SegmentAssignPlan, 0, len(nodeSegments)) + channelPlans := make([]ChannelAssignPlan, 0, len(nodeSegments)) + for nodeID, segments := range nodeSegments { + for _, segment := range segments { + segmentPlan := SegmentAssignPlan{ + ReplicaID: replica.ID, + From: nodeID, + To: -1, + Segment: segment, + Weight: GetWeight(1), + } + segmentPlans = append(segmentPlans, segmentPlan) + } + for _, dmChannel := range b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID) { + channelPlan := ChannelAssignPlan{ + ReplicaID: replica.ID, + From: nodeID, + To: -1, + Channel: dmChannel, + Weight: GetWeight(1), + } + channelPlans = append(channelPlans, channelPlan) + } + } + + return segmentPlans, channelPlans +} + +func (b *RowCountBasedBalancer) collectionStoppingSegments(stoppingNodesSegments map[int64][]*meta.Segment) ([]*meta.Segment, int) { + var ( + segments []*meta.Segment + removeRowCnt int + ) + + for _, stoppingSegments := range stoppingNodesSegments { + for _, segment := range stoppingSegments { + segments = append(segments, segment) + removeRowCnt += int(segment.GetNumOfRows()) + } + } + return segments, removeRowCnt +} + +func (b *RowCountBasedBalancer) getChannelPlan(replica *meta.Replica, stoppingNodesSegments map[int64][]*meta.Segment) []ChannelAssignPlan { + // maybe it will have some strategies to balance the channel in the future + // but now, only balance the channel for the stopping nodes. + return b.getChannelPlanForStoppingNodes(replica, stoppingNodesSegments) +} + +func (b *RowCountBasedBalancer) getChannelPlanForStoppingNodes(replica *meta.Replica, stoppingNodesSegments map[int64][]*meta.Segment) []ChannelAssignPlan { + channelPlans := make([]ChannelAssignPlan, 0) + for nodeID := range stoppingNodesSegments { + dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID) + plans := b.AssignChannel(dmChannels, replica.Replica.GetNodes()) + for i := range plans { + plans[i].From = nodeID + plans[i].ReplicaID = replica.ID + plans[i].Weight = GetWeight(1) + } + channelPlans = append(channelPlans, plans...) 
+ } + return channelPlans } func NewRowCountBasedBalancer( @@ -201,7 +291,6 @@ func NewRowCountBasedBalancer( ) *RowCountBasedBalancer { return &RowCountBasedBalancer{ RoundRobinBalancer: NewRoundRobinBalancer(scheduler, nodeManager), - nodeManager: nodeManager, dist: dist, meta: meta, targetMgr: targetMgr, diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go index 32fe26b1d1..f525f510d9 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go @@ -19,6 +19,8 @@ package balance import ( "testing" + "github.com/milvus-io/milvus/internal/querycoordv2/task" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -33,9 +35,10 @@ import ( type RowCountBasedBalancerTestSuite struct { suite.Suite - balancer *RowCountBasedBalancer - kv *etcdkv.EtcdKV - broker *meta.MockBroker + balancer *RowCountBasedBalancer + kv *etcdkv.EtcdKV + broker *meta.MockBroker + mockScheduler *task.MockScheduler } func (suite *RowCountBasedBalancerTestSuite) SetupSuite() { @@ -64,7 +67,8 @@ func (suite *RowCountBasedBalancerTestSuite) SetupTest() { distManager := meta.NewDistributionManager() nodeManager := session.NewNodeManager() - suite.balancer = NewRowCountBasedBalancer(nil, nodeManager, distManager, testMeta, testTarget) + suite.mockScheduler = task.NewMockScheduler(suite.T()) + suite.balancer = NewRowCountBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget) } func (suite *RowCountBasedBalancerTestSuite) TearDownTest() { @@ -77,6 +81,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() { distributions map[int64][]*meta.Segment assignments []*meta.Segment nodes []int64 + segmentCnts []int + states []session.State expectPlans []SegmentAssignPlan }{ { @@ -90,7 +96,9 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() { {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}}, {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}}, }, - nodes: []int64{1, 2, 3}, + nodes: []int64{1, 2, 3, 4}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping}, + segmentCnts: []int{0, 1, 1, 0}, expectPlans: []SegmentAssignPlan{ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}}, From: -1, To: 2}, {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}}, From: -1, To: 1}, @@ -110,6 +118,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() { for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) 
} + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") + nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + } plans := balancer.AssignSegment(c.assignments, c.nodes) suite.ElementsMatch(c.expectPlans, plans) }) @@ -118,14 +132,21 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() { func (suite *RowCountBasedBalancerTestSuite) TestBalance() { cases := []struct { - name string - nodes []int64 - distributions map[int64][]*meta.Segment - expectPlans []SegmentAssignPlan + name string + nodes []int64 + segmentCnts []int + states []session.State + shouldMock bool + distributions map[int64][]*meta.Segment + distributionChannels map[int64][]*meta.DmChannel + expectPlans []SegmentAssignPlan + expectChannelPlans []ChannelAssignPlan }{ { - name: "normal balance", - nodes: []int64{1, 2}, + name: "normal balance", + nodes: []int64{1, 2}, + segmentCnts: []int{1, 2}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, distributions: map[int64][]*meta.Segment{ 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}}, 2: { @@ -136,10 +157,65 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { expectPlans: []SegmentAssignPlan{ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: 1, ReplicaID: 1}, }, + expectChannelPlans: []ChannelAssignPlan{}, }, { - name: "already balanced", - nodes: []int64{1, 2}, + name: "all stopping balance", + nodes: []int64{1, 2}, + segmentCnts: []int{1, 2}, + states: []session.State{session.NodeStateStopping, session.NodeStateStopping}, + distributions: map[int64][]*meta.Segment{ + 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}}, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{ + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: -1, ReplicaID: 1, Weight: weightHigh}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, From: 2, To: -1, ReplicaID: 1, Weight: weightHigh}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, From: 1, To: -1, ReplicaID: 1, Weight: weightHigh}, + }, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "part stopping balance", + nodes: []int64{1, 2, 3}, + segmentCnts: []int{1, 2, 2}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping}, + shouldMock: true, + distributions: map[int64][]*meta.Segment{ + 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}}, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 1, NumOfRows: 10}, Node: 3}, + {SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3}, + }, + }, + distributionChannels: map[int64][]*meta.DmChannel{ + 2: { + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v2"}, Node: 2}, + }, + 3: { + {VchannelInfo: 
&datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 3}, + }, + }, + expectPlans: []SegmentAssignPlan{ + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 1, NumOfRows: 10}, Node: 3}, From: 3, To: 1, ReplicaID: 1, Weight: weightHigh}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3}, From: 3, To: 1, ReplicaID: 1, Weight: weightHigh}, + }, + expectChannelPlans: []ChannelAssignPlan{ + {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 3}, From: 3, To: 1, ReplicaID: 1, Weight: weightHigh}, + }, + }, + { + name: "already balanced", + nodes: []int64{1, 2}, + segmentCnts: []int{1, 2}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, distributions: map[int64][]*meta.Segment{ 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 30}, Node: 1}}, 2: { @@ -147,10 +223,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, }, }, - expectPlans: []SegmentAssignPlan{}, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, }, } + suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0) for _, c := range cases { suite.Run(c.name, func() { suite.SetupSuite() @@ -167,6 +245,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { { SegmentID: 3, }, + { + SegmentID: 4, + }, + { + SegmentID: 5, + }, } suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( nil, segments, nil) @@ -179,8 +263,17 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } + for node, v := range c.distributionChannels { + balancer.dist.ChannelDistManager.Update(node, v...) + } + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") + nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + } segmentPlans, channelPlans := balancer.Balance() - suite.Empty(channelPlans) + suite.ElementsMatch(c.expectChannelPlans, channelPlans) suite.ElementsMatch(c.expectPlans, segmentPlans) }) } diff --git a/internal/querycoordv2/balance/utils.go b/internal/querycoordv2/balance/utils.go index 1491f28ea5..1f3c6b1857 100644 --- a/internal/querycoordv2/balance/utils.go +++ b/internal/querycoordv2/balance/utils.go @@ -56,6 +56,7 @@ func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout t ) continue } + task.SetPriority(GetTaskPriorityFromWeight(p.Weight)) ret = append(ret, task) } return ret @@ -85,6 +86,7 @@ func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout t ) continue } + task.SetPriority(GetTaskPriorityFromWeight(p.Weight)) ret = append(ret, task) } return ret diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index 5316b2cb50..bfc68531af 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -45,7 +45,12 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { segmentPlans, channelPlans := b.Balance.Balance() tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout, segmentPlans) - task.SetPriority(task.TaskPriorityLow, tasks...) 
+ task.SetPriorityWithFunc(func(t task.Task) task.Priority { + if t.Priority() == task.TaskPriorityHigh { + return task.TaskPriorityHigh + } + return task.TaskPriorityLow + }, tasks...) ret = append(ret, tasks...) tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout, channelPlans) diff --git a/internal/querycoordv2/checkers/controller.go b/internal/querycoordv2/checkers/controller.go index 96cdb9a067..88757fe493 100644 --- a/internal/querycoordv2/checkers/controller.go +++ b/internal/querycoordv2/checkers/controller.go @@ -35,6 +35,7 @@ var ( type CheckerController struct { stopCh chan struct{} + checkCh chan struct{} meta *meta.Meta dist *meta.DistributionManager targetMgr *meta.TargetManager @@ -68,6 +69,7 @@ func NewCheckerController( return &CheckerController{ stopCh: make(chan struct{}), + checkCh: make(chan struct{}), meta: meta, dist: dist, targetMgr: targetMgr, @@ -92,6 +94,11 @@ func (controller *CheckerController) Start(ctx context.Context) { case <-ticker.C: controller.check(ctx) + + case <-controller.checkCh: + ticker.Stop() + controller.check(ctx) + ticker.Reset(Params.QueryCoordCfg.CheckInterval) } } }() @@ -103,6 +110,10 @@ func (controller *CheckerController) Stop() { }) } +func (controller *CheckerController) Check() { + controller.checkCh <- struct{}{} +} + // check is the real implementation of Check func (controller *CheckerController) check(ctx context.Context) { tasks := make([]task.Task, 0) diff --git a/internal/querycoordv2/dist/dist_controller.go b/internal/querycoordv2/dist/dist_controller.go index 6bcfa10f3e..f319de0fde 100644 --- a/internal/querycoordv2/dist/dist_controller.go +++ b/internal/querycoordv2/dist/dist_controller.go @@ -63,12 +63,11 @@ func (dc *Controller) SyncAll(ctx context.Context) { wg := sync.WaitGroup{} for _, h := range dc.handlers { - handler := h wg.Add(1) - go func() { + go func(handler *distHandler) { defer wg.Done() - handler.getDistribution(ctx) - }() + handler.getDistribution(ctx, nil) + }(h) } wg.Wait() } diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index 7eb4f6c981..a59f08a009 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -58,7 +58,6 @@ func (dh *distHandler) start(ctx context.Context) { logger := log.With(zap.Int64("nodeID", dh.nodeID)) logger.Info("start dist handler") ticker := time.NewTicker(Params.QueryCoordCfg.DistPullInterval) - id := int64(1) failures := 0 for { select { @@ -69,36 +68,25 @@ func (dh *distHandler) start(ctx context.Context) { logger.Info("close dist handelr") return case <-ticker.C: - dh.mu.Lock() - cctx, cancel := context.WithTimeout(ctx, distReqTimeout) - resp, err := dh.client.GetDataDistribution(cctx, dh.nodeID, &querypb.GetDataDistributionRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_GetDistribution), - ), - }) - cancel() - - if err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - failures++ - dh.logFailureInfo(resp, err) - node := dh.nodeManager.Get(dh.nodeID) - if node != nil { - log.RatedDebug(30.0, "failed to get node's data distribution", - zap.Int64("nodeID", dh.nodeID), - zap.Time("lastHeartbeat", node.LastHeartbeat()), - ) + dh.getDistribution(ctx, func(isSuccess bool) { + if !isSuccess { + failures++ + node := dh.nodeManager.Get(dh.nodeID) + if node != nil { + log.RatedDebug(30.0, "failed to get node's data distribution", + zap.Int64("nodeID", dh.nodeID), + 
zap.Time("lastHeartbeat", node.LastHeartbeat()), + ) + } + } else { + failures = 0 } - } else { - failures = 0 - dh.handleDistResp(resp) - } - if failures >= maxFailureTimes { - log.RatedInfo(30.0, fmt.Sprintf("can not get data distribution from node %d for %d times", dh.nodeID, failures)) - // TODO: kill the querynode server and stop the loop? - } - id++ - dh.mu.Unlock() + if failures >= maxFailureTimes { + log.RatedInfo(30.0, fmt.Sprintf("can not get data distribution from node %d for %d times", dh.nodeID, failures)) + // TODO: kill the querynode server and stop the loop? + } + }) } } } @@ -219,7 +207,7 @@ func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionRespons dh.dist.LeaderViewManager.Update(resp.GetNodeID(), updates...) } -func (dh *distHandler) getDistribution(ctx context.Context) { +func (dh *distHandler) getDistribution(ctx context.Context, fn func(isSuccess bool)) { dh.mu.Lock() defer dh.mu.Unlock() cctx, cancel := context.WithTimeout(ctx, distReqTimeout) @@ -230,11 +218,16 @@ func (dh *distHandler) getDistribution(ctx context.Context) { }) cancel() - if err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + isSuccess := err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success + if isSuccess { dh.logFailureInfo(resp, err) } else { dh.handleDistResp(resp) } + + if fn != nil { + fn(isSuccess) + } } func (dh *distHandler) stop() { diff --git a/internal/querycoordv2/mocks/querynode.go b/internal/querycoordv2/mocks/querynode.go index 3c078faf59..329bf7371b 100644 --- a/internal/querycoordv2/mocks/querynode.go +++ b/internal/querycoordv2/mocks/querynode.go @@ -123,6 +123,10 @@ func (node *MockQueryNode) Start() error { return err } +func (node *MockQueryNode) Stopping() { + node.session.GoingStop() +} + func (node *MockQueryNode) Stop() { node.cancel() node.server.GracefulStop() diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index c2e98674ff..a10bbc66aa 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -559,6 +559,16 @@ func (s *Server) watchNodes(revision int64) { s.handleNodeUp(nodeID) s.metricsCacheManager.InvalidateSystemInfoMetrics() + case sessionutil.SessionUpdateEvent: + nodeID := event.Session.ServerID + addr := event.Session.Address + log.Info("stopping the node", + zap.Int64("nodeID", nodeID), + zap.String("nodeAddr", addr), + ) + s.nodeMgr.Stopping(nodeID) + s.checkerController.Check() + case sessionutil.SessionDelEvent: nodeID := event.Session.ServerID log.Info("a node down, remove it", zap.Int64("nodeID", nodeID)) diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 268cdf3e2d..48fb9b05f7 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -185,6 +185,16 @@ func (suite *ServerSuite) TestNodeUp() { } +func (suite *ServerSuite) TestNodeUpdate() { + downNode := suite.nodes[0] + downNode.Stopping() + + suite.Eventually(func() bool { + node := suite.server.nodeMgr.Get(downNode.ID) + return node.IsStoppingState() + }, 5*time.Second, time.Second) +} + func (suite *ServerSuite) TestNodeDown() { downNode := suite.nodes[0] downNode.Stop() diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index ea715eb206..0d31d6d4ba 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -608,8 +608,6 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() { replicas := 
suite.meta.ReplicaManager.GetByCollection(collection) replicas[0].AddNode(srcNode) replicas[0].AddNode(dstNode) - defer replicas[0].RemoveNode(srcNode) - defer replicas[0].RemoveNode(dstNode) suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) for partition, segments := range suite.segments[collection] { @@ -617,13 +615,21 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() { metaSegments = append(metaSegments, utils.CreateTestSegment(collection, partition, segment, srcNode, 1, "test-channel")) - if segmentOnCollection[collection] == nil { - segmentOnCollection[collection] = make([]int64, 0) - } segmentOnCollection[collection] = append(segmentOnCollection[collection], segment) } } } + suite.nodeMgr.Add(session.NewNodeInfo(1001, "localhost")) + suite.nodeMgr.Add(session.NewNodeInfo(1002, "localhost")) + defer func() { + for _, collection := range suite.collections { + replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas[0].RemoveNode(srcNode) + replicas[0].RemoveNode(dstNode) + } + suite.nodeMgr.Remove(1001) + suite.nodeMgr.Remove(1002) + }() suite.dist.SegmentDistManager.Update(srcNode, metaSegments...) // expect each collection can only trigger its own segment's balance diff --git a/internal/querycoordv2/session/node_manager.go b/internal/querycoordv2/session/node_manager.go index 749c55e7e5..f748cd0ebb 100644 --- a/internal/querycoordv2/session/node_manager.go +++ b/internal/querycoordv2/session/node_manager.go @@ -26,6 +26,7 @@ import ( type Manager interface { Add(node *NodeInfo) + Stopping(nodeID int64) Remove(nodeID int64) Get(nodeID int64) *NodeInfo GetAll() []*NodeInfo @@ -50,6 +51,14 @@ func (m *NodeManager) Remove(nodeID int64) { metrics.QueryCoordNumQueryNodes.WithLabelValues().Set(float64(len(m.nodes))) } +func (m *NodeManager) Stopping(nodeID int64) { + m.mu.Lock() + defer m.mu.Unlock() + if nodeInfo, ok := m.nodes[nodeID]; ok { + nodeInfo.SetState(NodeStateStopping) + } +} + func (m *NodeManager) Get(nodeID int64) *NodeInfo { m.mu.RLock() defer m.mu.RUnlock() @@ -72,11 +81,19 @@ func NewNodeManager() *NodeManager { } } +type State int + +const ( + NodeStateNormal = iota + NodeStateStopping +) + type NodeInfo struct { stats mu sync.RWMutex id int64 addr string + state State lastHeartbeat *atomic.Int64 } @@ -108,6 +125,18 @@ func (n *NodeInfo) LastHeartbeat() time.Time { return time.Unix(0, n.lastHeartbeat.Load()) } +func (n *NodeInfo) IsStoppingState() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.state == NodeStateStopping +} + +func (n *NodeInfo) SetState(s State) { + n.mu.Lock() + defer n.mu.Unlock() + n.state = s +} + func (n *NodeInfo) UpdateStats(opts ...StatsOption) { n.mu.Lock() for _, opt := range opts { diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index 8aaa5c3776..4986c4529e 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -56,6 +56,12 @@ func SetPriority(priority Priority, tasks ...Task) { } } +func SetPriorityWithFunc(f func(t Task) Priority, tasks ...Task) { + for i := range tasks { + tasks[i].SetPriority(f(tasks[i])) + } +} + // GetTaskType returns the task's type, // for now, only 3 types; // - only 1 grow action -> Grow diff --git a/internal/querynode/collection_test.go b/internal/querynode/collection_test.go index ce29a16525..e94f865b4e 100644 --- a/internal/querynode/collection_test.go +++ b/internal/querynode/collection_test.go @@ -53,6 +53,7 @@ func TestCollection_schema(t *testing.T) { } func 
TestCollection_vChannel(t *testing.T) { + Params.Init() collectionID := UniqueID(0) schema := genTestCollectionSchema() diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index a518a09ea1..a8b2a3f540 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -27,6 +27,10 @@ import ( "github.com/stretchr/testify/suite" ) +func init() { + rateCol, _ = newRateCollector() +} + func TestDataSyncService_DMLFlowGraphs(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index b3c1944436..be287668f0 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -45,6 +45,17 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) +// isHealthy checks if QueryNode is healthy +func (node *QueryNode) isHealthy() bool { + code := node.stateCode.Load().(commonpb.StateCode) + return code == commonpb.StateCode_Healthy +} + +func (node *QueryNode) isHealthyOrStopping() bool { + code := node.stateCode.Load().(commonpb.StateCode) + return code == commonpb.StateCode_Healthy || code == commonpb.StateCode_Stopping +} + // GetComponentStates returns information about whether the node is healthy func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { stats := &milvuspb.ComponentStates{ @@ -159,10 +170,12 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que }, } - if !node.isHealthy() { + if !node.isHealthyOrStopping() { failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID()) return failRet, nil } + node.wg.Add(1) + defer node.wg.Done() traceID, _, _ := trace.InfoFromContext(ctx) log.Ctx(ctx).Debug("received GetStatisticRequest", @@ -286,8 +299,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que // WatchDmChannels create consumers on dmChannels to receive Incremental data,which is the important part of real-time query func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { // check node healthy - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthy() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -295,6 +307,8 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC } return status, nil } + node.wg.Add(1) + defer node.wg.Done() // check target matches if in.GetBase().GetTargetID() != paramtable.GetNodeID() { @@ -369,8 +383,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { // check node healthy - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthyOrStopping() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -378,6 +391,8 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC } return status, nil } + node.wg.Add(1) + defer node.wg.Done() // check target matches if req.GetBase().GetTargetID() != paramtable.GetNodeID() { @@ -430,8 +445,7 @@ func (node *QueryNode) 
UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC // LoadSegments load historical data into query node, historical data can be vector data or index func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { // check node healthy - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthy() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -439,6 +453,9 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment } return status, nil } + node.wg.Add(1) + defer node.wg.Done() + // check target matches if in.GetBase().GetTargetID() != paramtable.GetNodeID() { status := &commonpb.Status{ @@ -511,8 +528,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment // ReleaseCollection clears all data related to this collection on the querynode func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthyOrStopping() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -520,6 +536,9 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas } return status, nil } + node.wg.Add(1) + defer node.wg.Done() + dct := &releaseCollectionTask{ baseTask: baseTask{ ctx: ctx, @@ -557,8 +576,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas // ReleasePartitions clears all data related to this partition on the querynode func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthyOrStopping() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -566,6 +584,9 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas } return status, nil } + node.wg.Add(1) + defer node.wg.Done() + dct := &releasePartitionsTask{ baseTask: baseTask{ ctx: ctx, @@ -603,9 +624,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas // ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - // check node healthy - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthyOrStopping() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -613,6 +632,9 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS } return status, nil } + node.wg.Add(1) + defer node.wg.Done() + // check target matches if in.GetBase().GetTargetID() != paramtable.GetNodeID() { status := &commonpb.Status{ @@ -651,8 +673,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS // GetSegmentInfo returns segment information of the collection on 
the queryNode, and the information includes memSize, numRow, indexName, indexID ... func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthyOrStopping() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) res := &querypb.GetSegmentInfoResponse{ Status: &commonpb.Status{ @@ -662,6 +683,9 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen } return res, nil } + node.wg.Add(1) + defer node.wg.Done() + var segmentInfos []*querypb.SegmentInfo segmentIDs := make(map[int64]struct{}) @@ -696,12 +720,6 @@ func filterSegmentInfo(segmentInfos []*querypb.SegmentInfo, segmentIDs map[int64 return filtered } -// isHealthy checks if QueryNode is healthy -func (node *QueryNode) isHealthy() bool { - code := node.stateCode.Load().(commonpb.StateCode) - return code == commonpb.StateCode_Healthy -} - // Search performs replica search tasks. func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { log.Ctx(ctx).Debug("Received SearchRequest", @@ -785,10 +803,12 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc() } }() - if !node.isHealthy() { + if !node.isHealthyOrStopping() { failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID()) return failRet, nil } + node.wg.Add(1) + defer node.wg.Done() msgID := req.GetReq().GetBase().GetMsgID() log.Ctx(ctx).Debug("Received SearchRequest", @@ -935,10 +955,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc() } }() - if !node.isHealthy() { + if !node.isHealthyOrStopping() { failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID()) return failRet, nil } + node.wg.Add(1) + defer node.wg.Done() traceID, _, _ := trace.InfoFromContext(ctx) log.Ctx(ctx).Debug("queryWithDmlChannel receives query request", @@ -1145,6 +1167,8 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()), }, nil } + node.wg.Add(1) + defer node.wg.Done() log.Info("Received SyncReplicaSegments request", zap.String("vchannelName", req.GetVchannelName())) @@ -1164,7 +1188,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn // ShowConfigurations returns the configurations of queryNode matching req.Pattern func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - if !node.isHealthy() { + if !node.isHealthyOrStopping() { log.Warn("QueryNode.ShowConfigurations failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.String("req", req.Pattern), @@ -1178,13 +1202,15 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S Configuations: nil, }, nil } + node.wg.Add(1) + defer node.wg.Done() return getComponentConfigurations(ctx, req), nil } // GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ... 
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - if !node.isHealthy() { + if !node.isHealthyOrStopping() { log.Warn("QueryNode.GetMetrics failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.String("req", req.Request), @@ -1198,6 +1224,8 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR Response: "", }, nil } + node.wg.Add(1) + defer node.wg.Done() metricType, err := metricsinfo.ParseMetricType(req.Request) if err != nil { @@ -1257,7 +1285,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get zap.Int64("msg-id", req.GetBase().GetMsgID()), zap.Int64("node-id", paramtable.GetNodeID()), ) - if !node.isHealthy() { + if !node.isHealthyOrStopping() { log.Warn("QueryNode.GetMetrics failed", zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID()))) @@ -1268,6 +1296,8 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get }, }, nil } + node.wg.Add(1) + defer node.wg.Done() // check target matches if req.GetBase().GetTargetID() != paramtable.GetNodeID() { @@ -1345,8 +1375,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel())) // check node healthy - code := node.stateCode.Load().(commonpb.StateCode) - if code != commonpb.StateCode_Healthy { + if !node.isHealthyOrStopping() { err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -1354,6 +1383,9 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi } return status, nil } + node.wg.Add(1) + defer node.wg.Done() + // check target matches if req.GetBase().GetTargetID() != paramtable.GetNodeID() { log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", node.session.ServerID)) diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 13fc2089a9..72c7458449 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -144,6 +144,20 @@ func TestImpl_WatchDmChannels(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) }) + t.Run("server stopping", func(t *testing.T) { + req := &queryPb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDmChannels, + MsgID: rand.Int63(), + }, + } + node.UpdateStateCode(commonpb.StateCode_Stopping) + defer node.UpdateStateCode(commonpb.StateCode_Healthy) + status, err := node.WatchDmChannels(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + }) + t.Run("mock release after loaded", func(t *testing.T) { mockTSReplica := &MockTSafeReplicaInterface{} @@ -253,6 +267,15 @@ func TestImpl_LoadSegments(t *testing.T) { t.Run("server unhealthy", func(t *testing.T) { node.UpdateStateCode(commonpb.StateCode_Abnormal) + defer node.UpdateStateCode(commonpb.StateCode_Healthy) + status, err := node.LoadSegments(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + }) + + t.Run("server stopping", func(t *testing.T) { + node.UpdateStateCode(commonpb.StateCode_Stopping) + defer 
node.UpdateStateCode(commonpb.StateCode_Healthy) status, err := node.LoadSegments(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) diff --git a/internal/querynode/meta_replica.go b/internal/querynode/meta_replica.go index 42193acf67..ec2ab6a470 100644 --- a/internal/querynode/meta_replica.go +++ b/internal/querynode/meta_replica.go @@ -139,6 +139,7 @@ type ReplicaInterface interface { getGrowingSegments() []*Segment getSealedSegments() []*Segment + getNoSegmentChan() <-chan struct{} } // collectionReplica is the data replication of memory data in query node. @@ -149,6 +150,7 @@ type metaReplica struct { partitions map[UniqueID]*Partition growingSegments map[UniqueID]*Segment sealedSegments map[UniqueID]*Segment + noSegmentChan chan struct{} excludedSegments map[UniqueID][]*datapb.SegmentInfo // map[collectionID]segmentIDs @@ -743,6 +745,26 @@ func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID, segType seg ).Sub(float64(rowCount)) } } + replica.sendNoSegmentSignal() +} + +func (replica *metaReplica) sendNoSegmentSignal() { + if replica.noSegmentChan == nil { + return + } + select { + case <-replica.noSegmentChan: + default: + if len(replica.growingSegments) == 0 && len(replica.sealedSegments) == 0 { + close(replica.noSegmentChan) + } + } +} + +func (replica *metaReplica) getNoSegmentChan() <-chan struct{} { + replica.noSegmentChan = make(chan struct{}) + replica.sendNoSegmentSignal() + return replica.noSegmentChan } // getSegmentByID returns the segment which id is segmentID diff --git a/internal/querynode/meta_replica_test.go b/internal/querynode/meta_replica_test.go index e7a06b0e5e..bef74aea92 100644 --- a/internal/querynode/meta_replica_test.go +++ b/internal/querynode/meta_replica_test.go @@ -17,7 +17,9 @@ package querynode import ( + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -169,6 +171,42 @@ func TestMetaReplica_segment(t *testing.T) { } }) + t.Run("test getNoSegmentChan", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + select { + case <-replica.getNoSegmentChan(): + default: + assert.FailNow(t, "fail to assert getNoSegmentChan") + } + + const segmentNum = 3 + c := replica.getNoSegmentChan() + w := sync.WaitGroup{} + w.Add(1) + go func() { + defer w.Done() + select { + case <-c: + case <-time.After(5 * time.Second): + assert.FailNow(t, "timeout getNoSegmentChan") + } + }() + for i := 0; i < segmentNum; i++ { + err := replica.addSegment(UniqueID(i), defaultPartitionID, defaultCollectionID, "", defaultSegmentVersion, defaultSegmentStartPosition, segmentTypeGrowing) + assert.NoError(t, err) + targetSeg, err := replica.getSegmentByID(UniqueID(i), segmentTypeGrowing) + assert.NoError(t, err) + assert.Equal(t, UniqueID(i), targetSeg.segmentID) + } + for i := 0; i < segmentNum; i++ { + replica.removeSegment(UniqueID(i), segmentTypeGrowing) + } + w.Wait() + }) + t.Run("test hasSegment", func(t *testing.T) { replica, err := genSimpleReplica() assert.NoError(t, err) diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 233d0b63d5..7508e3292f 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -30,6 +30,8 @@ import "C" import ( "context" "fmt" + "github.com/samber/lo" + uberatomic "go.uber.org/atomic" "os" "path" "runtime" @@ -87,6 +89,7 @@ type QueryNode struct { wg sync.WaitGroup stateCode atomic.Value + stopFlag uberatomic.Bool 
//call once initOnce sync.Once @@ -137,6 +140,7 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode { node.tSafeReplica = newTSafeReplica() node.scheduler = newTaskScheduler(ctx1, node.tSafeReplica) node.UpdateStateCode(commonpb.StateCode_Abnormal) + node.stopFlag.Store(false) return node } @@ -324,8 +328,34 @@ func (node *QueryNode) Start() error { // Stop mainly stop QueryNode's query service, historical loop and streaming loop. func (node *QueryNode) Stop() error { + if node.stopFlag.Load() { + return nil + } + node.stopFlag.Store(true) + log.Warn("Query node stop..") + err := node.session.GoingStop() + if err != nil { + log.Warn("session fail to go stopping state", zap.Error(err)) + } else { + node.UpdateStateCode(commonpb.StateCode_Stopping) + noSegmentChan := node.metaReplica.getNoSegmentChan() + select { + case <-noSegmentChan: + case <-time.After(time.Duration(Params.QueryNodeCfg.GracefulStopTimeout) * time.Second): + log.Warn("migrate data timed out", zap.Int64("server_id", paramtable.GetNodeID()), + zap.Int64s("sealed_segment", lo.Map[*Segment, int64](node.metaReplica.getSealedSegments(), func(t *Segment, i int) int64 { + return t.ID() + })), + zap.Int64s("growing_segment", lo.Map[*Segment, int64](node.metaReplica.getGrowingSegments(), func(t *Segment, i int) int64 { + return t.ID() + })), + ) + } + } + node.UpdateStateCode(commonpb.StateCode_Abnormal) + node.wg.Wait() node.queryNodeLoopCancel() // close services @@ -346,7 +376,6 @@ func (node *QueryNode) Stop() error { } node.session.Revoke(time.Second) - node.wg.Wait() return nil } diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 95360f8d3d..6b0cf3f0e0 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -108,6 +108,7 @@ func newQueryNodeMock() *QueryNode { } svr.loader = newSegmentLoader(svr.metaReplica, etcdKV, svr.vectorStorage, factory) svr.etcdKV = etcdKV + svr.etcdCli = etcdCli return svr } diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go index 96b282bfd6..fb34e5dbf5 100644 --- a/internal/util/paramtable/component_param.go +++ b/internal/util/paramtable/component_param.go @@ -34,6 +34,7 @@ const ( // DefaultIndexSliceSize defines the default slice size of index file when serializing. 
DefaultIndexSliceSize = 16 DefaultGracefulTime = 5000 //ms + DefaultGracefulStopTimeout = 30 // s DefaultThreadCoreCoefficient = 10 DefaultSessionTTL = 60 //s @@ -147,6 +148,7 @@ type commonConfig struct { LoadNumThreadRatio float64 BeamWidthRatio float64 GracefulTime int64 + GracefulStopTimeout int64 // unit: s StorageType string SimdType string @@ -198,6 +200,7 @@ func (p *commonConfig) init(base *BaseTable) { p.initLoadNumThreadRatio() p.initBeamWidthRatio() p.initGracefulTime() + p.initGracefulStopTimeout() p.initStorageType() p.initThreadCoreCoefficient() @@ -434,6 +437,10 @@ func (p *commonConfig) initGracefulTime() { p.GracefulTime = p.Base.ParseInt64WithDefault("common.gracefulTime", DefaultGracefulTime) } +func (p *commonConfig) initGracefulStopTimeout() { + p.GracefulStopTimeout = p.Base.ParseInt64WithDefault("common.gracefulStopTimeout", DefaultGracefulStopTimeout) +} + func (p *commonConfig) initStorageType() { p.StorageType = p.Base.LoadWithDefault("common.storageType", "minio") } @@ -941,6 +948,8 @@ type queryNodeConfig struct { GCHelperEnabled bool MinimumGOGCConfig int MaximumGOGCConfig int + + GracefulStopTimeout int64 } func (p *queryNodeConfig) init(base *BaseTable) { @@ -973,6 +982,8 @@ func (p *queryNodeConfig) init(base *BaseTable) { p.initGCTunerEnbaled() p.initMaximumGOGC() p.initMinimumGOGC() + + p.initGracefulStopTimeout() } // InitAlias initializes an alias for the QueryNode role. @@ -1145,6 +1156,10 @@ func (p *queryNodeConfig) initMaximumGOGC() { p.MaximumGOGCConfig = p.Base.ParseIntWithDefault("queryNode.gchelper.maximumGoGC", 200) } +func (p *queryNodeConfig) initGracefulStopTimeout() { + p.GracefulStopTimeout = p.Base.ParseInt64WithDefault("queryNode.gracefulStopTimeout", params.CommonCfg.GracefulStopTimeout) +} + // ///////////////////////////////////////////////////////////////////////////// // --- datacoord --- type dataCoordConfig struct { diff --git a/internal/util/paramtable/component_param_test.go b/internal/util/paramtable/component_param_test.go index e4e8d02b8e..bdfc774130 100644 --- a/internal/util/paramtable/component_param_test.go +++ b/internal/util/paramtable/component_param_test.go @@ -68,6 +68,13 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, Params.GracefulTime, int64(DefaultGracefulTime)) t.Logf("default grafeful time = %d", Params.GracefulTime) + assert.Equal(t, Params.GracefulStopTimeout, int64(DefaultGracefulStopTimeout)) + assert.Equal(t, params.QueryNodeCfg.GracefulStopTimeout, Params.GracefulStopTimeout) + t.Logf("default grafeful stop timeout = %d", Params.GracefulStopTimeout) + Params.Base.Save("common.gracefulStopTimeout", "50") + Params.initGracefulStopTimeout() + assert.Equal(t, Params.GracefulStopTimeout, int64(50)) + // -- proxy -- assert.Equal(t, Params.ProxySubName, "by-dev-proxy") t.Logf("ProxySubName: %s", Params.ProxySubName) @@ -291,6 +298,11 @@ func TestComponentParam(t *testing.T) { nprobe = Params.SmallIndexNProbe assert.Equal(t, int64(4), nprobe) + + Params.Base.Save("queryNode.gracefulStopTimeout", "100") + Params.initGracefulStopTimeout() + gracefulStopTimeout := Params.GracefulStopTimeout + assert.Equal(t, int64(100), gracefulStopTimeout) }) t.Run("test dataCoordConfig", func(t *testing.T) { diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 02c13b406c..e82e44a82c 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -54,6 +54,8 @@ func (t SessionEventType) String() string { return 
"SessionAddEvent" case SessionDelEvent: return "SessionDelEvent" + case SessionUpdateEvent: + return "SessionUpdateEvent" default: return "" } @@ -71,6 +73,8 @@ const ( SessionAddEvent // SessionDelEvent event type for a Session deleted SessionDelEvent + // SessionUpdateEvent event type for a Session stopping + SessionUpdateEvent ) // Session is a struct to store service's session, including ServerID, ServerName, @@ -86,6 +90,7 @@ type Session struct { ServerName string `json:"ServerName,omitempty"` Address string `json:"Address,omitempty"` Exclusive bool `json:"Exclusive,omitempty"` + Stopping bool `json:"Stopping,omitempty"` TriggerKill bool Version semver.Version `json:"Version,omitempty"` @@ -138,6 +143,7 @@ func (s *Session) UnmarshalJSON(data []byte) error { ServerName string `json:"ServerName,omitempty"` Address string `json:"Address,omitempty"` Exclusive bool `json:"Exclusive,omitempty"` + Stopping bool `json:"Stopping,omitempty"` TriggerKill bool Version string `json:"Version"` } @@ -157,6 +163,7 @@ func (s *Session) UnmarshalJSON(data []byte) error { s.ServerName = raw.ServerName s.Address = raw.Address s.Exclusive = raw.Exclusive + s.Stopping = raw.Stopping s.TriggerKill = raw.TriggerKill return nil } @@ -170,6 +177,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { ServerName string `json:"ServerName,omitempty"` Address string `json:"Address,omitempty"` Exclusive bool `json:"Exclusive,omitempty"` + Stopping bool `json:"Stopping,omitempty"` TriggerKill bool Version string `json:"Version"` }{ @@ -177,6 +185,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { ServerName: s.ServerName, Address: s.Address, Exclusive: s.Exclusive, + Stopping: s.Stopping, TriggerKill: s.TriggerKill, Version: verStr, }) @@ -325,6 +334,14 @@ func (s *Session) getServerIDWithKey(key string) (int64, error) { } } +func (s *Session) getCompleteKey() string { + key := s.ServerName + if !s.Exclusive || (s.enableActiveStandBy && s.isStandby.Load().(bool)) { + key = fmt.Sprintf("%s-%d", key, s.ServerID) + } + return path.Join(s.metaRoot, DefaultServiceRoot, key) +} + // registerService registers the service to etcd so that other services // can find that the service is online and issue subsequent operations // RegisterService will save a key-value in etcd @@ -342,11 +359,7 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er if s.enableActiveStandBy { s.updateStandby(true) } - key := s.ServerName - if !s.Exclusive || s.enableActiveStandBy { - key = fmt.Sprintf("%s-%d", key, s.ServerID) - } - completeKey := path.Join(s.metaRoot, DefaultServiceRoot, key) + completeKey := s.getCompleteKey() var ch <-chan *clientv3.LeaseKeepAliveResponse log.Debug("service begin to register to etcd", zap.String("serverName", s.ServerName), zap.Int64("ServerID", s.ServerID)) @@ -383,7 +396,7 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er } if !txnResp.Succeeded { - return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", key) + return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ServerName) } log.Debug("put session key into etcd", zap.String("key", completeKey), zap.String("value", string(sessionJSON))) @@ -497,6 +510,34 @@ func (s *Session) GetSessionsWithVersionRange(prefix string, r semver.Range) (ma return res, resp.Header.Revision, nil } +func (s *Session) GoingStop() error { + if s == nil || s.etcdCli == nil || s.leaseID == nil { + return errors.New("the session hasn't been init") + } + + 
completeKey := s.getCompleteKey() + resp, err := s.etcdCli.Get(s.ctx, completeKey, clientv3.WithCountOnly()) + if err != nil { + log.Error("fail to get the session", zap.String("key", completeKey), zap.Error(err)) + return err + } + if resp.Count == 0 { + return nil + } + s.Stopping = true + sessionJSON, err := json.Marshal(s) + if err != nil { + log.Error("fail to marshal the session", zap.String("key", completeKey)) + return err + } + _, err = s.etcdCli.Put(s.ctx, completeKey, string(sessionJSON), clientv3.WithLease(*s.leaseID)) + if err != nil { + log.Error("fail to update the session to stopping state", zap.String("key", completeKey)) + return err + } + return nil +} + // SessionEvent indicates the changes of other servers. // if a server is up, EventType is SessAddEvent. // if a server is down, EventType is SessDelEvent. @@ -596,7 +637,11 @@ func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) { if !w.validate(session) { continue } - eventType = SessionAddEvent + if session.Stopping { + eventType = SessionUpdateEvent + } else { + eventType = SessionAddEvent + } case mvccpb.DELETE: log.Debug("watch services", zap.Any("delete kv", ev.PrevKv)) diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index b1336e3bfc..4ee357cf56 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -692,6 +692,7 @@ func TestSessionEventType_String(t *testing.T) { {t: SessionNoneEvent, want: ""}, {t: SessionAddEvent, want: "SessionAddEvent"}, {t: SessionDelEvent, want: "SessionDelEvent"}, + {t: SessionUpdateEvent, want: "SessionUpdateEvent"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/scripts/check_proto_product.sh b/scripts/check_proto_product.sh index 2a909e96ce..c1cf354bb0 100755 --- a/scripts/check_proto_product.sh +++ b/scripts/check_proto_product.sh @@ -29,7 +29,9 @@ if [[ $(uname -s) == "Darwin" ]]; then export PATH="/usr/local/opt/grep/libexec/gnubin:$PATH" fi -if test -z "$(git status | grep -E "*pb.go|*pb.cc|*pb.h")"; then +check_result=$(git status | grep -E "*pb.go|*pb.cc|*pb.h") +echo "check_result: $check_result" +if test -z "$check_result"; then exit 0 else echo "The go file or cpp file generated by proto are not latest!" diff --git a/scripts/stop_graceful.sh b/scripts/stop_graceful.sh new file mode 100755 index 0000000000..0d8f753978 --- /dev/null +++ b/scripts/stop_graceful.sh @@ -0,0 +1,48 @@ +# Licensed to the LF AI & Data foundation under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +function get_milvus_process() { + return $(ps -e | grep milvus | grep -v grep | awk '{print $1}') +} + +echo "Stopping milvus..." +BASEDIR=$(dirname "$0") +timeout=120 +if [ ! 
-z "$1" ]; then + timeout=$1 +fi + +if [ -z $(get_milvus_process) ]; then + echo "No milvus process" + exit 0 +fi +kill -15 $(get_milvus_process) + +start=$(date +%s) +while : +do + sleep 1 + if [ -z $(get_milvus_process) ]; then + echo "Milvus stopped" + break + fi + if [ $(( $(date +%s) - $start )) -gt $timeout ]; then + echo "Milvus timeout stopped" + kill -15 $(get_milvus_process) + break + fi +done +