Support the graceful stop for the query node (#20851)

Signed-off-by: SimFG <bang.fu@zilliz.com>

SimFG 2022-12-06 22:59:19 +08:00 committed by GitHub
parent 5e49b095f5
commit f8cff79804
37 changed files with 730 additions and 150 deletions
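
At a high level, the change works like this: the query node gains a new Stopping state, keeps answering search/query requests while it is stopping, and waits for QueryCoord to move its segments and channels to other nodes before shutting down (force-quitting after the new gracefulStopTimeout, in seconds, elapses); on the coordinator side, the balancer skips stopping nodes as assignment targets and generates high-priority tasks to drain them. A minimal, self-contained sketch of the node-side wait (releasedAll and drainTimeout are hypothetical stand-ins for the diff's noSegmentChan and GracefulStopTimeout):

package main

import (
    "fmt"
    "time"
)

// gracefulStop sketches the drain-or-timeout wait added to QueryNode.Stop:
// releasedAll is closed once the last segment has been released, drainTimeout
// plays the role of common.gracefulStopTimeout.
func gracefulStop(releasedAll <-chan struct{}, drainTimeout time.Duration) {
    fmt.Println("state -> Stopping") // the coordinator starts migrating data off this node
    select {
    case <-releasedAll:
        fmt.Println("all segments released, shutting down cleanly")
    case <-time.After(drainTimeout):
        fmt.Println("graceful stop timed out, forcing shutdown")
    }
}

func main() {
    released := make(chan struct{})
    go func() {
        time.Sleep(100 * time.Millisecond) // pretend the coordinator finished the migration
        close(released)
    }()
    gracefulStop(released, 30*time.Second)
}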


@ -199,6 +199,10 @@ func (mr *MilvusRoles) setupLogger() {
func (mr *MilvusRoles) Run(local bool, alias string) {
log.Info("starting running Milvus components")
ctx, cancel := context.WithCancel(context.Background())
defer func() {
// some deferred Stop has race with context cancel
cancel()
}()
mr.printLDPreLoad()
// only standalone enable localMsg
@ -322,7 +326,4 @@ func (mr *MilvusRoles) Run(local bool, alias string) {
syscall.SIGQUIT)
sig := <-sc
log.Error("Get signal to exit\n", zap.String("signal", sig.String()))
// some deferred Stop has race with context cancel
cancel()
}
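
Moving cancel() into a defer at the top of Run leans on Go's LIFO defer order: the component Stop calls that are deferred later in Run now complete before the context is cancelled, which is the race the comment above refers to. A tiny illustration of that ordering:

package main

import (
    "context"
    "fmt"
)

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()                      // registered first, so it runs last
    defer fmt.Println("component Stop") // registered later, so it runs first
    fmt.Println("running, ctx err:", ctx.Err())
    // Prints "running, ctx err: <nil>", then "component Stop"; only then is ctx cancelled.
}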


@ -374,6 +374,7 @@ common:
entityExpiration: -1 # Entity expiration in seconds, CAUTION make sure entityExpiration >= retentionDuration and -1 means never expire
gracefulTime: 5000 # milliseconds. it represents the interval (in ms) by which the request arrival time needs to be subtracted in the case of Bounded Consistency.
gracefulStopTimeout: 30 # seconds. it will force quit the server if the graceful stop process is not completed during this time.
# Default value: auto
# Valid values: [auto, avx512, avx2, avx, sse4_2]

go.mod

@ -28,7 +28,7 @@ require (
github.com/klauspost/compress v1.14.2
github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api v0.0.0-20221019080323-84e9fa2f9e45
github.com/milvus-io/milvus-proto/go-api v0.0.0-20221126103108-4d988e37ebf2
github.com/minio/minio-go/v7 v7.0.17
github.com/opentracing/opentracing-go v1.2.0
github.com/panjf2000/ants/v2 v2.4.8
@ -88,10 +88,8 @@ require (
github.com/facebookgo/ensure v0.0.0-20200202191622-63f1cf65ac4c // indirect
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4 // indirect
github.com/fatih/color v1.10.0 // indirect
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.13.0 // indirect
@ -109,15 +107,12 @@ require (
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/json-iterator/go v1.1.11 // indirect
github.com/klauspost/asmfmt v1.3.1 // indirect
github.com/klauspost/cpuid v1.3.1 // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06 // indirect
github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b // indirect
github.com/leodido/go-urn v1.2.0 // indirect
github.com/linkedin/goavro/v2 v2.11.1 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
@ -147,7 +142,6 @@ require (
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cobra v1.1.3 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.4.0 // indirect
@ -175,13 +169,13 @@ require (
go.opentelemetry.io/otel/sdk/metric v0.20.0 // indirect
go.opentelemetry.io/otel/trace v0.20.0 // indirect
go.opentelemetry.io/proto/otlp v0.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
go.uber.org/multierr v1.6.0
golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/text v0.3.7
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.org/x/tools v0.1.9 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect

go.sum

@ -197,7 +197,6 @@ github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4 h1:7HZCaLC5+BZpm
github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg=
github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c=
@ -457,9 +456,7 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06 h1:vN4d3jSss3ExzUn2cE0WctxztfOgiKvMKnDrydBsg00=
github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06/go.mod h1:++9BgZujZd4v0ZTZCb5iPsaomXdZWyxotIAh1IiDm44=
github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b h1:xYEM2oBUhBEhQjrV+KJ9lEWDWYZoNVZUaBF++Wyljq4=
github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b/go.mod h1:V0HF/ZBlN86HqewcDC/cVxMmYDiRukWjSrgKLUAn9Js=
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76 h1:IVlcvV0CjvfBYYod5ePe89l+3LBAl//6n9kJ9Vr2i0k=
@ -493,8 +490,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyex
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api v0.0.0-20221019080323-84e9fa2f9e45 h1:QxGQqRtJbbdMf/jxodjYaTW8ZcqdyPDsuhIb9Y2jnwk=
github.com/milvus-io/milvus-proto/go-api v0.0.0-20221019080323-84e9fa2f9e45/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk=
github.com/milvus-io/milvus-proto/go-api v0.0.0-20221126103108-4d988e37ebf2 h1:kf0QwCVxVTV8HLX0rTsOFs4l/vw7MoUjKFVkd4yQRVM=
github.com/milvus-io/milvus-proto/go-api v0.0.0-20221126103108-4d988e37ebf2/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk=
github.com/milvus-io/pulsar-client-go v0.6.8 h1:fZdZH73aPRszu2fazyeeahQEz34tyn1Pt9EkqJmV100=
github.com/milvus-io/pulsar-client-go v0.6.8/go.mod h1:oFIlYIk23tamkSLttw849qphmMIpHY8ztEBWDWJW+sc=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=


@ -468,13 +468,13 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE(
"ateUser\020\024\022\032\n\026PrivilegeDropOwnership\020\025\022\034\n"
"\030PrivilegeSelectOwnership\020\026\022\034\n\030Privilege"
"ManageOwnership\020\027\022\027\n\023PrivilegeSelectUser"
"\020\030*E\n\tStateCode\022\020\n\014Initializing\020\000\022\013\n\007Hea"
"lthy\020\001\022\014\n\010Abnormal\020\002\022\013\n\007StandBy\020\003:^\n\021pri"
"vilege_ext_obj\022\037.google.protobuf.Message"
"Options\030\351\007 \001(\0132!.milvus.proto.common.Pri"
"vilegeExtBU\n\016io.milvus.grpcB\013CommonProto"
"P\001Z1github.com/milvus-io/milvus-proto/go"
"-api/commonpb\240\001\001b\006proto3"
"\020\030*S\n\tStateCode\022\020\n\014Initializing\020\000\022\013\n\007Hea"
"lthy\020\001\022\014\n\010Abnormal\020\002\022\013\n\007StandBy\020\003\022\014\n\010Sto"
"pping\020\004:^\n\021privilege_ext_obj\022\037.google.pr"
"otobuf.MessageOptions\030\351\007 \001(\0132!.milvus.pr"
"oto.common.PrivilegeExtBU\n\016io.milvus.grp"
"cB\013CommonProtoP\001Z1github.com/milvus-io/m"
"ilvus-proto/go-api/commonpb\240\001\001b\006proto3"
;
static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_common_2eproto_deps[1] = {
&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto,
@ -495,7 +495,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_com
static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_common_2eproto_once;
static bool descriptor_table_common_2eproto_initialized = false;
const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_common_2eproto = {
&descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 5384,
&descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 5398,
&descriptor_table_common_2eproto_once, descriptor_table_common_2eproto_sccs, descriptor_table_common_2eproto_deps, 11, 1,
schemas, file_default_instances, TableStruct_common_2eproto::offsets,
file_level_metadata_common_2eproto, 11, file_level_enum_descriptors_common_2eproto, file_level_service_descriptors_common_2eproto,
@ -842,6 +842,7 @@ bool StateCode_IsValid(int value) {
case 1:
case 2:
case 3:
case 4:
return true;
default:
return false;


@ -564,12 +564,13 @@ enum StateCode : int {
Healthy = 1,
Abnormal = 2,
StandBy = 3,
Stopping = 4,
StateCode_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(),
StateCode_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max()
};
bool StateCode_IsValid(int value);
constexpr StateCode StateCode_MIN = Initializing;
constexpr StateCode StateCode_MAX = StandBy;
constexpr StateCode StateCode_MAX = Stopping;
constexpr int StateCode_ARRAYSIZE = StateCode_MAX + 1;
const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* StateCode_descriptor();


@ -221,6 +221,11 @@ func (s *Server) Run() error {
// Stop stops QueryNode's grpc service.
func (s *Server) Stop() error {
log.Debug("QueryNode stop", zap.String("Address", Params.GetAddress()))
err := s.querynode.Stop()
if err != nil {
return err
}
if s.closer != nil {
if err := s.closer.Close(); err != nil {
return err
@ -235,11 +240,6 @@ func (s *Server) Stop() error {
log.Debug("Graceful stop grpc server...")
s.grpcServer.GracefulStop()
}
err := s.querynode.Stop()
if err != nil {
return err
}
s.wg.Wait()
return nil
}
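
The reordering above is deliberate: s.querynode.Stop() now runs before grpcServer.GracefulStop(), so the node can keep serving gRPC (for the coordinator's release/sync calls and in-flight search traffic) while it drains, and the listener is only shut down afterwards. A condensed sequence with stub types, just to show the order:

package main

import "fmt"

type queryNode struct{}

func (q *queryNode) Stop() error { fmt.Println("query node drained"); return nil }

type rpcServer struct{}

func (s *rpcServer) GracefulStop() { fmt.Println("gRPC listener closed") }

func main() {
    qn, srv := &queryNode{}, &rpcServer{}
    if err := qn.Stop(); err != nil { // drain first, while RPCs are still being served
        fmt.Println("stop failed:", err)
        return
    }
    srv.GracefulStop() // then stop accepting RPCs
}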


@ -473,7 +473,6 @@ func Test_AlterSegmentsAndAddNewSegment(t *testing.T) {
return "", errors.New("key not found")
}
// TODO fubang
catalog := &Catalog{txn, "a"}
err := catalog.AlterSegmentsAndAddNewSegment(context.TODO(), []*datapb.SegmentInfo{droppedSegment}, segment1)
assert.NoError(t, err)


@ -24,11 +24,40 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/task"
)
type Weight = int
const (
weightLow int = iota - 1
weightNormal
weightHigh
)
func GetWeight(w int) Weight {
if w > 0 {
return weightHigh
} else if w < 0 {
return weightLow
}
return weightNormal
}
func GetTaskPriorityFromWeight(w Weight) task.Priority {
switch w {
case weightHigh:
return task.TaskPriorityHigh
case weightLow:
return task.TaskPriorityLow
default:
return task.TaskPriorityNormal
}
}
type SegmentAssignPlan struct {
Segment *meta.Segment
ReplicaID int64
From int64 // -1 if empty
To int64
Weight Weight
}
type ChannelAssignPlan struct {
@ -36,6 +65,7 @@ type ChannelAssignPlan struct {
ReplicaID int64
From int64
To int64
Weight Weight
}
type Balance interface {
@ -99,7 +129,7 @@ func (b *RoundRobinBalancer) getNodes(nodes []int64) []*session.NodeInfo {
ret := make([]*session.NodeInfo, 0, len(nodes))
for _, n := range nodes {
node := b.nodeManager.Get(n)
if node != nil {
if node != nil && !node.IsStoppingState() {
ret = append(ret, node)
}
}
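
The Weight helpers above are the hook that lets stopping-node work jump the queue: the row-count balancer tags plans that move data off a stopping node with GetWeight(1), task creation maps that weight back to a task priority (task.SetPriority(GetTaskPriorityFromWeight(p.Weight)) later in this diff), and the balance checker keeps tasks that are already high priority while demoting the rest to low. A self-contained sketch of the weight-to-priority mapping, with the types re-declared in simplified form so it runs on its own:

package main

import "fmt"

// Simplified re-declarations of the helpers introduced above, so the example
// is runnable outside the querycoordv2 packages.
type Weight = int

type Priority int

const (
    weightLow Weight = iota - 1
    weightNormal
    weightHigh
)

const (
    TaskPriorityLow Priority = iota
    TaskPriorityNormal
    TaskPriorityHigh
)

func GetWeight(w int) Weight {
    if w > 0 {
        return weightHigh
    } else if w < 0 {
        return weightLow
    }
    return weightNormal
}

func GetTaskPriorityFromWeight(w Weight) Priority {
    switch w {
    case weightHigh:
        return TaskPriorityHigh
    case weightLow:
        return TaskPriorityLow
    default:
        return TaskPriorityNormal
    }
}

func main() {
    // A plan generated for a stopping node carries GetWeight(1) ...
    w := GetWeight(1)
    // ... which becomes a high-priority balance task downstream.
    fmt.Println(w == weightHigh, GetTaskPriorityFromWeight(w) == TaskPriorityHigh) // true true
}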


@ -43,15 +43,17 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
name string
nodeIDs []int64
segmentCnts []int
states []session.State
deltaCnts []int
assignments []*meta.Segment
expectPlans []SegmentAssignPlan
}{
{
name: "normal assignment",
nodeIDs: []int64{1, 2},
segmentCnts: []int{100, 200},
deltaCnts: []int{0, -200},
nodeIDs: []int64{1, 2, 3},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping},
segmentCnts: []int{100, 200, 0},
deltaCnts: []int{0, -200, 0},
assignments: []*meta.Segment{
{SegmentInfo: &datapb.SegmentInfo{ID: 1}},
{SegmentInfo: &datapb.SegmentInfo{ID: 2}},
@ -67,6 +69,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
name: "empty assignment",
nodeIDs: []int64{},
segmentCnts: []int{},
states: []session.State{},
deltaCnts: []int{},
assignments: []*meta.Segment{
{SegmentInfo: &datapb.SegmentInfo{ID: 1}},
@ -83,9 +86,12 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
for i := range c.nodeIDs {
nodeInfo := session.NewNodeInfo(c.nodeIDs[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
nodeInfo.SetState(c.states[i])
suite.roundRobinBalancer.nodeManager.Add(nodeInfo)
if !nodeInfo.IsStoppingState() {
suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignSegment(c.assignments, c.nodeIDs)
suite.ElementsMatch(c.expectPlans, plans)
})
@ -97,15 +103,17 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
name string
nodeIDs []int64
channelCnts []int
states []session.State
deltaCnts []int
assignments []*meta.DmChannel
expectPlans []ChannelAssignPlan
}{
{
name: "normal assignment",
nodeIDs: []int64{1, 2},
channelCnts: []int{100, 200},
deltaCnts: []int{0, -200},
nodeIDs: []int64{1, 2, 3},
channelCnts: []int{100, 200, 0},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping},
deltaCnts: []int{0, -200, 0},
assignments: []*meta.DmChannel{
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "channel-1"}},
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "channel-2"}},
@ -121,6 +129,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
name: "empty assignment",
nodeIDs: []int64{},
channelCnts: []int{},
states: []session.State{},
deltaCnts: []int{},
assignments: []*meta.DmChannel{
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "channel-1"}},
@ -137,15 +146,32 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
for i := range c.nodeIDs {
nodeInfo := session.NewNodeInfo(c.nodeIDs[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithChannelCnt(c.channelCnts[i]))
nodeInfo.SetState(c.states[i])
suite.roundRobinBalancer.nodeManager.Add(nodeInfo)
if !nodeInfo.IsStoppingState() {
suite.mockScheduler.EXPECT().GetNodeChannelDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs)
suite.ElementsMatch(c.expectPlans, plans)
})
}
}
func (suite *BalanceTestSuite) TestWeight() {
suite.Run("GetWeight", func() {
suite.Equal(weightHigh, GetWeight(10))
suite.Equal(weightNormal, GetWeight(0))
suite.Equal(weightLow, GetWeight(-10))
})
suite.Run("GetTaskPriorityFromWeight", func() {
suite.Equal(task.TaskPriorityHigh, GetTaskPriorityFromWeight(weightHigh))
suite.Equal(task.TaskPriorityNormal, GetTaskPriorityFromWeight(weightNormal))
suite.Equal(task.TaskPriorityLow, GetTaskPriorityFromWeight(weightLow))
})
}
func TestBalanceSuite(t *testing.T) {
suite.Run(t, new(BalanceTestSuite))
}


@ -28,17 +28,16 @@ import (
type RowCountBasedBalancer struct {
*RoundRobinBalancer
nodeManager *session.NodeManager
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
}
func (b *RowCountBasedBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
if len(nodes) == 0 {
nodeItems := b.convertToNodeItems(nodes)
if len(nodeItems) == 0 {
return nil
}
nodeItems := b.convertToNodeItems(nodes)
queue := newPriorityQueue()
for _, item := range nodeItems {
queue.push(item)
@ -68,7 +67,8 @@ func (b *RowCountBasedBalancer) AssignSegment(segments []*meta.Segment, nodes []
func (b *RowCountBasedBalancer) convertToNodeItems(nodeIDs []int64) []*nodeItem {
ret := make([]*nodeItem, 0, len(nodeIDs))
for _, node := range nodeIDs {
for _, nodeInfo := range b.getNodes(nodeIDs) {
node := nodeInfo.ID()
segments := b.dist.SegmentDistManager.GetByNode(node)
rowcnt := 0
for _, s := range segments {
@ -108,6 +108,8 @@ func (b *RowCountBasedBalancer) balanceReplica(replica *meta.Replica) ([]Segment
}
nodesRowCnt := make(map[int64]int)
nodesSegments := make(map[int64][]*meta.Segment)
stoppingNodesSegments := make(map[int64][]*meta.Segment)
totalCnt := 0
for _, nid := range nodes {
segments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nid)
@ -120,13 +122,23 @@ func (b *RowCountBasedBalancer) balanceReplica(replica *meta.Replica) ([]Segment
cnt += int(s.GetNumOfRows())
}
nodesRowCnt[nid] = cnt
if nodeInfo := b.nodeManager.Get(nid); nodeInfo.IsStoppingState() {
stoppingNodesSegments[nid] = segments
} else {
nodesSegments[nid] = segments
}
totalCnt += cnt
}
average := totalCnt / len(nodes)
if len(nodes) == len(stoppingNodesSegments) {
return b.handleStoppingNodes(replica, stoppingNodesSegments)
}
average := totalCnt / len(nodesSegments)
neededRowCnt := 0
for _, rowcnt := range nodesRowCnt {
for nodeID := range nodesSegments {
rowcnt := nodesRowCnt[nodeID]
if rowcnt < average {
neededRowCnt += average - rowcnt
}
@ -138,13 +150,17 @@ func (b *RowCountBasedBalancer) balanceReplica(replica *meta.Replica) ([]Segment
segmentsToMove := make([]*meta.Segment, 0)
stopSegments, cnt := b.collectionStoppingSegments(stoppingNodesSegments)
segmentsToMove = append(segmentsToMove, stopSegments...)
neededRowCnt -= cnt
// select segments to be moved
outer:
for nodeID, rowcnt := range nodesRowCnt {
for nodeID, segments := range nodesSegments {
rowcnt := nodesRowCnt[nodeID]
if rowcnt <= average {
continue
}
segments := nodesSegments[nodeID]
sort.Slice(segments, func(i, j int) bool {
return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
})
@ -168,7 +184,8 @@ outer:
// allocate segments to those nodes with row cnt less than average
queue := newPriorityQueue()
for nodeID, rowcnt := range nodesRowCnt {
for nodeID := range nodesSegments {
rowcnt := nodesRowCnt[nodeID]
if rowcnt >= average {
continue
}
@ -177,19 +194,92 @@ outer:
}
plans := make([]SegmentAssignPlan, 0)
getPlanWeight := func(nodeID int64) Weight {
if _, ok := stoppingNodesSegments[nodeID]; ok {
return GetWeight(1)
}
return GetWeight(0)
}
for _, s := range segmentsToMove {
node := queue.pop().(*nodeItem)
plan := SegmentAssignPlan{
ReplicaID: replica.GetID(),
From: s.Node,
To: node.nodeID,
Segment: s,
Weight: getPlanWeight(s.Node),
}
plans = append(plans, plan)
node.setPriority(node.getPriority() + int(s.GetNumOfRows()))
queue.push(node)
}
return plans, nil
return plans, b.getChannelPlan(replica, stoppingNodesSegments)
}
func (b *RowCountBasedBalancer) handleStoppingNodes(replica *meta.Replica, nodeSegments map[int64][]*meta.Segment) ([]SegmentAssignPlan, []ChannelAssignPlan) {
segmentPlans := make([]SegmentAssignPlan, 0, len(nodeSegments))
channelPlans := make([]ChannelAssignPlan, 0, len(nodeSegments))
for nodeID, segments := range nodeSegments {
for _, segment := range segments {
segmentPlan := SegmentAssignPlan{
ReplicaID: replica.ID,
From: nodeID,
To: -1,
Segment: segment,
Weight: GetWeight(1),
}
segmentPlans = append(segmentPlans, segmentPlan)
}
for _, dmChannel := range b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID) {
channelPlan := ChannelAssignPlan{
ReplicaID: replica.ID,
From: nodeID,
To: -1,
Channel: dmChannel,
Weight: GetWeight(1),
}
channelPlans = append(channelPlans, channelPlan)
}
}
return segmentPlans, channelPlans
}
func (b *RowCountBasedBalancer) collectionStoppingSegments(stoppingNodesSegments map[int64][]*meta.Segment) ([]*meta.Segment, int) {
var (
segments []*meta.Segment
removeRowCnt int
)
for _, stoppingSegments := range stoppingNodesSegments {
for _, segment := range stoppingSegments {
segments = append(segments, segment)
removeRowCnt += int(segment.GetNumOfRows())
}
}
return segments, removeRowCnt
}
func (b *RowCountBasedBalancer) getChannelPlan(replica *meta.Replica, stoppingNodesSegments map[int64][]*meta.Segment) []ChannelAssignPlan {
// maybe it will have some strategies to balance the channel in the future
// but now, only balance the channel for the stopping nodes.
return b.getChannelPlanForStoppingNodes(replica, stoppingNodesSegments)
}
func (b *RowCountBasedBalancer) getChannelPlanForStoppingNodes(replica *meta.Replica, stoppingNodesSegments map[int64][]*meta.Segment) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for nodeID := range stoppingNodesSegments {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
plans := b.AssignChannel(dmChannels, replica.Replica.GetNodes())
for i := range plans {
plans[i].From = nodeID
plans[i].ReplicaID = replica.ID
plans[i].Weight = GetWeight(1)
}
channelPlans = append(channelPlans, plans...)
}
return channelPlans
}
func NewRowCountBasedBalancer(
@ -201,7 +291,6 @@ func NewRowCountBasedBalancer(
) *RowCountBasedBalancer {
return &RowCountBasedBalancer{
RoundRobinBalancer: NewRoundRobinBalancer(scheduler, nodeManager),
nodeManager: nodeManager,
dist: dist,
meta: meta,
targetMgr: targetMgr,


@ -19,6 +19,8 @@ package balance
import (
"testing"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
@ -36,6 +38,7 @@ type RowCountBasedBalancerTestSuite struct {
balancer *RowCountBasedBalancer
kv *etcdkv.EtcdKV
broker *meta.MockBroker
mockScheduler *task.MockScheduler
}
func (suite *RowCountBasedBalancerTestSuite) SetupSuite() {
@ -64,7 +67,8 @@ func (suite *RowCountBasedBalancerTestSuite) SetupTest() {
distManager := meta.NewDistributionManager()
nodeManager := session.NewNodeManager()
suite.balancer = NewRowCountBasedBalancer(nil, nodeManager, distManager, testMeta, testTarget)
suite.mockScheduler = task.NewMockScheduler(suite.T())
suite.balancer = NewRowCountBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget)
}
func (suite *RowCountBasedBalancerTestSuite) TearDownTest() {
@ -77,6 +81,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
distributions map[int64][]*meta.Segment
assignments []*meta.Segment
nodes []int64
segmentCnts []int
states []session.State
expectPlans []SegmentAssignPlan
}{
{
@ -90,7 +96,9 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}},
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}},
},
nodes: []int64{1, 2, 3},
nodes: []int64{1, 2, 3, 4},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping},
segmentCnts: []int{0, 1, 1, 0},
expectPlans: []SegmentAssignPlan{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}}, From: -1, To: 2},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}}, From: -1, To: 1},
@ -110,6 +118,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
plans := balancer.AssignSegment(c.assignments, c.nodes)
suite.ElementsMatch(c.expectPlans, plans)
})
@ -120,12 +134,19 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
cases := []struct {
name string
nodes []int64
segmentCnts []int
states []session.State
shouldMock bool
distributions map[int64][]*meta.Segment
distributionChannels map[int64][]*meta.DmChannel
expectPlans []SegmentAssignPlan
expectChannelPlans []ChannelAssignPlan
}{
{
name: "normal balance",
nodes: []int64{1, 2},
segmentCnts: []int{1, 2},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal},
distributions: map[int64][]*meta.Segment{
1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}},
2: {
@ -136,10 +157,65 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
expectPlans: []SegmentAssignPlan{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: 1, ReplicaID: 1},
},
expectChannelPlans: []ChannelAssignPlan{},
},
{
name: "all stopping balance",
nodes: []int64{1, 2},
segmentCnts: []int{1, 2},
states: []session.State{session.NodeStateStopping, session.NodeStateStopping},
distributions: map[int64][]*meta.Segment{
1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2},
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
},
expectPlans: []SegmentAssignPlan{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: -1, ReplicaID: 1, Weight: weightHigh},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, From: 2, To: -1, ReplicaID: 1, Weight: weightHigh},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, From: 1, To: -1, ReplicaID: 1, Weight: weightHigh},
},
expectChannelPlans: []ChannelAssignPlan{},
},
{
name: "part stopping balance",
nodes: []int64{1, 2, 3},
segmentCnts: []int{1, 2, 2},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateStopping},
shouldMock: true,
distributions: map[int64][]*meta.Segment{
1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}},
2: {
{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2},
{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2},
},
3: {
{SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 1, NumOfRows: 10}, Node: 3},
{SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3},
},
},
distributionChannels: map[int64][]*meta.DmChannel{
2: {
{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v2"}, Node: 2},
},
3: {
{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 3},
},
},
expectPlans: []SegmentAssignPlan{
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 1, NumOfRows: 10}, Node: 3}, From: 3, To: 1, ReplicaID: 1, Weight: weightHigh},
{Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3}, From: 3, To: 1, ReplicaID: 1, Weight: weightHigh},
},
expectChannelPlans: []ChannelAssignPlan{
{Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 3}, From: 3, To: 1, ReplicaID: 1, Weight: weightHigh},
},
},
{
name: "already balanced",
nodes: []int64{1, 2},
segmentCnts: []int{1, 2},
states: []session.State{session.NodeStateNormal, session.NodeStateNormal},
distributions: map[int64][]*meta.Segment{
1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 30}, Node: 1}},
2: {
@ -148,9 +224,11 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
},
},
expectPlans: []SegmentAssignPlan{},
expectChannelPlans: []ChannelAssignPlan{},
},
}
suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0)
for _, c := range cases {
suite.Run(c.name, func() {
suite.SetupSuite()
@ -167,6 +245,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
{
SegmentID: 3,
},
{
SegmentID: 4,
},
{
SegmentID: 5,
},
}
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
nil, segments, nil)
@ -179,8 +263,17 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for node, v := range c.distributionChannels {
balancer.dist.ChannelDistManager.Update(node, v...)
}
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0")
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
segmentPlans, channelPlans := balancer.Balance()
suite.Empty(channelPlans)
suite.ElementsMatch(c.expectChannelPlans, channelPlans)
suite.ElementsMatch(c.expectPlans, segmentPlans)
})
}


@ -56,6 +56,7 @@ func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout t
)
continue
}
task.SetPriority(GetTaskPriorityFromWeight(p.Weight))
ret = append(ret, task)
}
return ret
@ -85,6 +86,7 @@ func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout t
)
continue
}
task.SetPriority(GetTaskPriorityFromWeight(p.Weight))
ret = append(ret, task)
}
return ret


@ -45,7 +45,12 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
segmentPlans, channelPlans := b.Balance.Balance()
tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout, segmentPlans)
task.SetPriority(task.TaskPriorityLow, tasks...)
task.SetPriorityWithFunc(func(t task.Task) task.Priority {
if t.Priority() == task.TaskPriorityHigh {
return task.TaskPriorityHigh
}
return task.TaskPriorityLow
}, tasks...)
ret = append(ret, tasks...)
tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout, channelPlans)


@ -35,6 +35,7 @@ var (
type CheckerController struct {
stopCh chan struct{}
checkCh chan struct{}
meta *meta.Meta
dist *meta.DistributionManager
targetMgr *meta.TargetManager
@ -68,6 +69,7 @@ func NewCheckerController(
return &CheckerController{
stopCh: make(chan struct{}),
checkCh: make(chan struct{}),
meta: meta,
dist: dist,
targetMgr: targetMgr,
@ -92,6 +94,11 @@ func (controller *CheckerController) Start(ctx context.Context) {
case <-ticker.C:
controller.check(ctx)
case <-controller.checkCh:
ticker.Stop()
controller.check(ctx)
ticker.Reset(Params.QueryCoordCfg.CheckInterval)
}
}
}()
@ -103,6 +110,10 @@ func (controller *CheckerController) Stop() {
})
}
func (controller *CheckerController) Check() {
controller.checkCh <- struct{}{}
}
// check is the real implementation of Check
func (controller *CheckerController) check(ctx context.Context) {
tasks := make([]task.Task, 0)


@ -63,12 +63,11 @@ func (dc *Controller) SyncAll(ctx context.Context) {
wg := sync.WaitGroup{}
for _, h := range dc.handlers {
handler := h
wg.Add(1)
go func() {
go func(handler *distHandler) {
defer wg.Done()
handler.getDistribution(ctx)
}()
handler.getDistribution(ctx, nil)
}(h)
}
wg.Wait()
}


@ -58,7 +58,6 @@ func (dh *distHandler) start(ctx context.Context) {
logger := log.With(zap.Int64("nodeID", dh.nodeID))
logger.Info("start dist handler")
ticker := time.NewTicker(Params.QueryCoordCfg.DistPullInterval)
id := int64(1)
failures := 0
for {
select {
@ -69,18 +68,9 @@ func (dh *distHandler) start(ctx context.Context) {
logger.Info("close dist handelr")
return
case <-ticker.C:
dh.mu.Lock()
cctx, cancel := context.WithTimeout(ctx, distReqTimeout)
resp, err := dh.client.GetDataDistribution(cctx, dh.nodeID, &querypb.GetDataDistributionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetDistribution),
),
})
cancel()
if err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
dh.getDistribution(ctx, func(isSuccess bool) {
if !isSuccess {
failures++
dh.logFailureInfo(resp, err)
node := dh.nodeManager.Get(dh.nodeID)
if node != nil {
log.RatedDebug(30.0, "failed to get node's data distribution",
@ -90,15 +80,13 @@ func (dh *distHandler) start(ctx context.Context) {
}
} else {
failures = 0
dh.handleDistResp(resp)
}
if failures >= maxFailureTimes {
log.RatedInfo(30.0, fmt.Sprintf("can not get data distribution from node %d for %d times", dh.nodeID, failures))
// TODO: kill the querynode server and stop the loop?
}
id++
dh.mu.Unlock()
})
}
}
}
@ -219,7 +207,7 @@ func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionRespons
dh.dist.LeaderViewManager.Update(resp.GetNodeID(), updates...)
}
func (dh *distHandler) getDistribution(ctx context.Context) {
func (dh *distHandler) getDistribution(ctx context.Context, fn func(isSuccess bool)) {
dh.mu.Lock()
defer dh.mu.Unlock()
cctx, cancel := context.WithTimeout(ctx, distReqTimeout)
@ -230,11 +218,16 @@ func (dh *distHandler) getDistribution(ctx context.Context) {
})
cancel()
if err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
isFailure := err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success
if isFailure {
dh.logFailureInfo(resp, err)
} else {
dh.handleDistResp(resp)
}
if fn != nil {
fn(!isFailure)
}
}
func (dh *distHandler) stop() {


@ -123,6 +123,10 @@ func (node *MockQueryNode) Start() error {
return err
}
func (node *MockQueryNode) Stopping() {
node.session.GoingStop()
}
func (node *MockQueryNode) Stop() {
node.cancel()
node.server.GracefulStop()


@ -559,6 +559,16 @@ func (s *Server) watchNodes(revision int64) {
s.handleNodeUp(nodeID)
s.metricsCacheManager.InvalidateSystemInfoMetrics()
case sessionutil.SessionUpdateEvent:
nodeID := event.Session.ServerID
addr := event.Session.Address
log.Info("stopping the node",
zap.Int64("nodeID", nodeID),
zap.String("nodeAddr", addr),
)
s.nodeMgr.Stopping(nodeID)
s.checkerController.Check()
case sessionutil.SessionDelEvent:
nodeID := event.Session.ServerID
log.Info("a node down, remove it", zap.Int64("nodeID", nodeID))


@ -185,6 +185,16 @@ func (suite *ServerSuite) TestNodeUp() {
}
func (suite *ServerSuite) TestNodeUpdate() {
downNode := suite.nodes[0]
downNode.Stopping()
suite.Eventually(func() bool {
node := suite.server.nodeMgr.Get(downNode.ID)
return node.IsStoppingState()
}, 5*time.Second, time.Second)
}
func (suite *ServerSuite) TestNodeDown() {
downNode := suite.nodes[0]
downNode.Stop()


@ -608,8 +608,6 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas[0].AddNode(srcNode)
replicas[0].AddNode(dstNode)
defer replicas[0].RemoveNode(srcNode)
defer replicas[0].RemoveNode(dstNode)
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
for partition, segments := range suite.segments[collection] {
@ -617,13 +615,21 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
metaSegments = append(metaSegments,
utils.CreateTestSegment(collection, partition, segment, srcNode, 1, "test-channel"))
if segmentOnCollection[collection] == nil {
segmentOnCollection[collection] = make([]int64, 0)
}
segmentOnCollection[collection] = append(segmentOnCollection[collection], segment)
}
}
}
suite.nodeMgr.Add(session.NewNodeInfo(1001, "localhost"))
suite.nodeMgr.Add(session.NewNodeInfo(1002, "localhost"))
defer func() {
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas[0].RemoveNode(srcNode)
replicas[0].RemoveNode(dstNode)
}
suite.nodeMgr.Remove(1001)
suite.nodeMgr.Remove(1002)
}()
suite.dist.SegmentDistManager.Update(srcNode, metaSegments...)
// expect each collection can only trigger its own segment's balance


@ -26,6 +26,7 @@ import (
type Manager interface {
Add(node *NodeInfo)
Stopping(nodeID int64)
Remove(nodeID int64)
Get(nodeID int64) *NodeInfo
GetAll() []*NodeInfo
@ -50,6 +51,14 @@ func (m *NodeManager) Remove(nodeID int64) {
metrics.QueryCoordNumQueryNodes.WithLabelValues().Set(float64(len(m.nodes)))
}
func (m *NodeManager) Stopping(nodeID int64) {
m.mu.Lock()
defer m.mu.Unlock()
if nodeInfo, ok := m.nodes[nodeID]; ok {
nodeInfo.SetState(NodeStateStopping)
}
}
func (m *NodeManager) Get(nodeID int64) *NodeInfo {
m.mu.RLock()
defer m.mu.RUnlock()
@ -72,11 +81,19 @@ func NewNodeManager() *NodeManager {
}
}
type State int
const (
NodeStateNormal = iota
NodeStateStopping
)
type NodeInfo struct {
stats
mu sync.RWMutex
id int64
addr string
state State
lastHeartbeat *atomic.Int64
}
@ -108,6 +125,18 @@ func (n *NodeInfo) LastHeartbeat() time.Time {
return time.Unix(0, n.lastHeartbeat.Load())
}
func (n *NodeInfo) IsStoppingState() bool {
n.mu.RLock()
defer n.mu.RUnlock()
return n.state == NodeStateStopping
}
func (n *NodeInfo) SetState(s State) {
n.mu.Lock()
defer n.mu.Unlock()
n.state = s
}
func (n *NodeInfo) UpdateStats(opts ...StatsOption) {
n.mu.Lock()
for _, opt := range opts {


@ -56,6 +56,12 @@ func SetPriority(priority Priority, tasks ...Task) {
}
}
func SetPriorityWithFunc(f func(t Task) Priority, tasks ...Task) {
for i := range tasks {
tasks[i].SetPriority(f(tasks[i]))
}
}
// GetTaskType returns the task's type,
// for now, only 3 types;
// - only 1 grow action -> Grow


@ -53,6 +53,7 @@ func TestCollection_schema(t *testing.T) {
}
func TestCollection_vChannel(t *testing.T) {
Params.Init()
collectionID := UniqueID(0)
schema := genTestCollectionSchema()


@ -27,6 +27,10 @@ import (
"github.com/stretchr/testify/suite"
)
func init() {
rateCol, _ = newRateCollector()
}
func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()


@ -45,6 +45,17 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// isHealthy checks if QueryNode is healthy
func (node *QueryNode) isHealthy() bool {
code := node.stateCode.Load().(commonpb.StateCode)
return code == commonpb.StateCode_Healthy
}
func (node *QueryNode) isHealthyOrStopping() bool {
code := node.stateCode.Load().(commonpb.StateCode)
return code == commonpb.StateCode_Healthy || code == commonpb.StateCode_Stopping
}
// GetComponentStates returns information about whether the node is healthy
func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
stats := &milvuspb.ComponentStates{
@ -159,10 +170,12 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
},
}
if !node.isHealthy() {
if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
return failRet, nil
}
node.wg.Add(1)
defer node.wg.Done()
traceID, _, _ := trace.InfoFromContext(ctx)
log.Ctx(ctx).Debug("received GetStatisticRequest",
@ -286,8 +299,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
// WatchDmChannels creates consumers on dmChannels to receive incremental data, which is the important part of real-time query
func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
// check node healthy
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthy() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -295,6 +307,8 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
// check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() {
@ -369,8 +383,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
// check node healthy
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -378,6 +391,8 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
// check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() {
@ -430,8 +445,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
// LoadSegments load historical data into query node, historical data can be vector data or index
func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
// check node healthy
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthy() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -439,6 +453,9 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
// check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() {
status := &commonpb.Status{
@ -511,8 +528,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
// ReleaseCollection clears all data related to this collection on the querynode
func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -520,6 +536,9 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
dct := &releaseCollectionTask{
baseTask: baseTask{
ctx: ctx,
@ -557,8 +576,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
// ReleasePartitions clears all data related to this partition on the querynode
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -566,6 +584,9 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
dct := &releasePartitionsTask{
baseTask: baseTask{
ctx: ctx,
@ -603,9 +624,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
// ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID
func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
// check node healthy
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -613,6 +632,9 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
// check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() {
status := &commonpb.Status{
@ -651,8 +673,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
// GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ...
func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
res := &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
@ -662,6 +683,9 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen
}
return res, nil
}
node.wg.Add(1)
defer node.wg.Done()
var segmentInfos []*querypb.SegmentInfo
segmentIDs := make(map[int64]struct{})
@ -696,12 +720,6 @@ func filterSegmentInfo(segmentInfos []*querypb.SegmentInfo, segmentIDs map[int64
return filtered
}
// isHealthy checks if QueryNode is healthy
func (node *QueryNode) isHealthy() bool {
code := node.stateCode.Load().(commonpb.StateCode)
return code == commonpb.StateCode_Healthy
}
// Search performs replica search tasks.
func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
log.Ctx(ctx).Debug("Received SearchRequest",
@ -785,10 +803,12 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.isHealthy() {
if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
return failRet, nil
}
node.wg.Add(1)
defer node.wg.Done()
msgID := req.GetReq().GetBase().GetMsgID()
log.Ctx(ctx).Debug("Received SearchRequest",
@ -935,10 +955,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.isHealthy() {
if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
return failRet, nil
}
node.wg.Add(1)
defer node.wg.Done()
traceID, _, _ := trace.InfoFromContext(ctx)
log.Ctx(ctx).Debug("queryWithDmlChannel receives query request",
@ -1145,6 +1167,8 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()),
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
log.Info("Received SyncReplicaSegments request", zap.String("vchannelName", req.GetVchannelName()))
@ -1164,7 +1188,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
// ShowConfigurations returns the configurations of queryNode matching req.Pattern
func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
if !node.isHealthy() {
if !node.isHealthyOrStopping() {
log.Warn("QueryNode.ShowConfigurations failed",
zap.Int64("nodeId", paramtable.GetNodeID()),
zap.String("req", req.Pattern),
@ -1178,13 +1202,15 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
Configuations: nil,
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
return getComponentConfigurations(ctx, req), nil
}
// GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ...
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
if !node.isHealthy() {
if !node.isHealthyOrStopping() {
log.Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeId", paramtable.GetNodeID()),
zap.String("req", req.Request),
@ -1198,6 +1224,8 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
Response: "",
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil {
@ -1257,7 +1285,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
zap.Int64("msg-id", req.GetBase().GetMsgID()),
zap.Int64("node-id", paramtable.GetNodeID()),
)
if !node.isHealthy() {
if !node.isHealthyOrStopping() {
log.Warn("QueryNode.GetMetrics failed",
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID())))
@ -1268,6 +1296,8 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
},
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
// check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() {
@ -1345,8 +1375,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()))
// check node healthy
code := node.stateCode.Load().(commonpb.StateCode)
if code != commonpb.StateCode_Healthy {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -1354,6 +1383,9 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
// check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() {
log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", node.session.ServerID))


@ -144,6 +144,20 @@ func TestImpl_WatchDmChannels(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("server stopping", func(t *testing.T) {
req := &queryPb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(),
},
}
node.UpdateStateCode(commonpb.StateCode_Stopping)
defer node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err := node.WatchDmChannels(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("mock release after loaded", func(t *testing.T) {
mockTSReplica := &MockTSafeReplicaInterface{}
@ -253,6 +267,15 @@ func TestImpl_LoadSegments(t *testing.T) {
t.Run("server unhealthy", func(t *testing.T) {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
defer node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err := node.LoadSegments(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("server stopping", func(t *testing.T) {
node.UpdateStateCode(commonpb.StateCode_Stopping)
defer node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err := node.LoadSegments(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)


@ -139,6 +139,7 @@ type ReplicaInterface interface {
getGrowingSegments() []*Segment
getSealedSegments() []*Segment
getNoSegmentChan() <-chan struct{}
}
// collectionReplica is the data replication of memory data in query node.
@ -149,6 +150,7 @@ type metaReplica struct {
partitions map[UniqueID]*Partition
growingSegments map[UniqueID]*Segment
sealedSegments map[UniqueID]*Segment
noSegmentChan chan struct{}
excludedSegments map[UniqueID][]*datapb.SegmentInfo // map[collectionID]segmentIDs
@ -743,6 +745,26 @@ func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID, segType seg
).Sub(float64(rowCount))
}
}
replica.sendNoSegmentSignal()
}
func (replica *metaReplica) sendNoSegmentSignal() {
if replica.noSegmentChan == nil {
return
}
select {
case <-replica.noSegmentChan:
default:
if len(replica.growingSegments) == 0 && len(replica.sealedSegments) == 0 {
close(replica.noSegmentChan)
}
}
}
func (replica *metaReplica) getNoSegmentChan() <-chan struct{} {
replica.noSegmentChan = make(chan struct{})
replica.sendNoSegmentSignal()
return replica.noSegmentChan
}
// getSegmentByID returns the segment which id is segmentID
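
The select/default in sendNoSegmentSignal is what makes the channel safe to close exactly once: if noSegmentChan is already closed, the first case is taken and nothing happens; otherwise the channel is closed only once both segment maps are empty. A self-contained sketch of that close-once-when-empty pattern (tracker is a hypothetical stand-in for metaReplica):

package main

import "fmt"

// tracker mimics the noSegmentChan bookkeeping: empty is closed at most once,
// when the tracked count drops to zero.
type tracker struct {
    items int
    empty chan struct{}
}

func (t *tracker) remove() {
    t.items--
    t.signalIfEmpty()
}

func (t *tracker) signalIfEmpty() {
    if t.empty == nil {
        return
    }
    select {
    case <-t.empty: // already closed, do nothing
    default:
        if t.items == 0 {
            close(t.empty)
        }
    }
}

func main() {
    t := &tracker{items: 2, empty: make(chan struct{})}
    t.remove()
    t.remove()
    <-t.empty // returns immediately: the channel was closed on the last remove
    fmt.Println("drained")
}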


@ -17,7 +17,9 @@
package querynode
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -169,6 +171,42 @@ func TestMetaReplica_segment(t *testing.T) {
}
})
t.Run("test getNoSegmentChan", func(t *testing.T) {
replica, err := genSimpleReplica()
assert.NoError(t, err)
defer replica.freeAll()
select {
case <-replica.getNoSegmentChan():
default:
assert.FailNow(t, "fail to assert getNoSegmentChan")
}
const segmentNum = 3
c := replica.getNoSegmentChan()
w := sync.WaitGroup{}
w.Add(1)
go func() {
defer w.Done()
select {
case <-c:
case <-time.After(5 * time.Second):
assert.FailNow(t, "timeout getNoSegmentChan")
}
}()
for i := 0; i < segmentNum; i++ {
err := replica.addSegment(UniqueID(i), defaultPartitionID, defaultCollectionID, "", defaultSegmentVersion, defaultSegmentStartPosition, segmentTypeGrowing)
assert.NoError(t, err)
targetSeg, err := replica.getSegmentByID(UniqueID(i), segmentTypeGrowing)
assert.NoError(t, err)
assert.Equal(t, UniqueID(i), targetSeg.segmentID)
}
for i := 0; i < segmentNum; i++ {
replica.removeSegment(UniqueID(i), segmentTypeGrowing)
}
w.Wait()
})
t.Run("test hasSegment", func(t *testing.T) {
replica, err := genSimpleReplica()
assert.NoError(t, err)


@ -30,6 +30,8 @@ import "C"
import (
"context"
"fmt"
"github.com/samber/lo"
uberatomic "go.uber.org/atomic"
"os"
"path"
"runtime"
@ -87,6 +89,7 @@ type QueryNode struct {
wg sync.WaitGroup
stateCode atomic.Value
stopFlag uberatomic.Bool
//call once
initOnce sync.Once
@ -137,6 +140,7 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode {
node.tSafeReplica = newTSafeReplica()
node.scheduler = newTaskScheduler(ctx1, node.tSafeReplica)
node.UpdateStateCode(commonpb.StateCode_Abnormal)
node.stopFlag.Store(false)
return node
}
@ -324,8 +328,34 @@ func (node *QueryNode) Start() error {
// Stop mainly stop QueryNode's query service, historical loop and streaming loop.
func (node *QueryNode) Stop() error {
if node.stopFlag.Load() {
return nil
}
node.stopFlag.Store(true)
log.Warn("Query node stop..")
err := node.session.GoingStop()
if err != nil {
log.Warn("session fail to go stopping state", zap.Error(err))
} else {
node.UpdateStateCode(commonpb.StateCode_Stopping)
noSegmentChan := node.metaReplica.getNoSegmentChan()
select {
case <-noSegmentChan:
case <-time.After(time.Duration(Params.QueryNodeCfg.GracefulStopTimeout) * time.Second):
log.Warn("migrate data timed out", zap.Int64("server_id", paramtable.GetNodeID()),
zap.Int64s("sealed_segment", lo.Map[*Segment, int64](node.metaReplica.getSealedSegments(), func(t *Segment, i int) int64 {
return t.ID()
})),
zap.Int64s("growing_segment", lo.Map[*Segment, int64](node.metaReplica.getGrowingSegments(), func(t *Segment, i int) int64 {
return t.ID()
})),
)
}
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
node.wg.Wait()
node.queryNodeLoopCancel()
// close services
@ -346,7 +376,6 @@ func (node *QueryNode) Stop() error {
}
node.session.Revoke(time.Second)
node.wg.Wait()
return nil
}
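
Stop above pairs with the guard this diff adds to the RPC handlers in impl.go: each handler admits requests while the node is Healthy or Stopping and registers itself on node.wg, so the node.wg.Wait() in Stop lets in-flight search/query traffic finish before teardown continues. A condensed, self-contained sketch of that pattern (the node type and handleRPC are hypothetical stand-ins, not the real handler signatures):

package main

import (
    "errors"
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

// node mimics the state-code check plus WaitGroup accounting used by the
// querynode handlers.
type node struct {
    state atomic.Value // "healthy", "stopping" or "abnormal"
    wg    sync.WaitGroup
}

func (n *node) healthyOrStopping() bool {
    s := n.state.Load().(string)
    return s == "healthy" || s == "stopping"
}

func (n *node) handleRPC() error {
    if !n.healthyOrStopping() {
        return errors.New("node is not ready")
    }
    n.wg.Add(1)
    defer n.wg.Done()
    time.Sleep(50 * time.Millisecond) // pretend to serve a search/query request
    return nil
}

func (n *node) stop() {
    n.state.Store("stopping")
    // ... wait for segments to be released, as in QueryNode.Stop above ...
    n.state.Store("abnormal")
    n.wg.Wait() // in-flight requests finish before teardown continues
    fmt.Println("stopped")
}

func main() {
    n := &node{}
    n.state.Store("healthy")
    go func() { _ = n.handleRPC() }()
    time.Sleep(10 * time.Millisecond)
    n.stop()
}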


@ -108,6 +108,7 @@ func newQueryNodeMock() *QueryNode {
}
svr.loader = newSegmentLoader(svr.metaReplica, etcdKV, svr.vectorStorage, factory)
svr.etcdKV = etcdKV
svr.etcdCli = etcdCli
return svr
}


@ -34,6 +34,7 @@ const (
// DefaultIndexSliceSize defines the default slice size of index file when serializing.
DefaultIndexSliceSize = 16
DefaultGracefulTime = 5000 //ms
DefaultGracefulStopTimeout = 30 // s
DefaultThreadCoreCoefficient = 10
DefaultSessionTTL = 60 //s
@ -147,6 +148,7 @@ type commonConfig struct {
LoadNumThreadRatio float64
BeamWidthRatio float64
GracefulTime int64
GracefulStopTimeout int64 // unit: s
StorageType string
SimdType string
@ -198,6 +200,7 @@ func (p *commonConfig) init(base *BaseTable) {
p.initLoadNumThreadRatio()
p.initBeamWidthRatio()
p.initGracefulTime()
p.initGracefulStopTimeout()
p.initStorageType()
p.initThreadCoreCoefficient()
@ -434,6 +437,10 @@ func (p *commonConfig) initGracefulTime() {
p.GracefulTime = p.Base.ParseInt64WithDefault("common.gracefulTime", DefaultGracefulTime)
}
func (p *commonConfig) initGracefulStopTimeout() {
p.GracefulStopTimeout = p.Base.ParseInt64WithDefault("common.gracefulStopTimeout", DefaultGracefulStopTimeout)
}
func (p *commonConfig) initStorageType() {
p.StorageType = p.Base.LoadWithDefault("common.storageType", "minio")
}
@ -941,6 +948,8 @@ type queryNodeConfig struct {
GCHelperEnabled bool
MinimumGOGCConfig int
MaximumGOGCConfig int
GracefulStopTimeout int64
}
func (p *queryNodeConfig) init(base *BaseTable) {
@ -973,6 +982,8 @@ func (p *queryNodeConfig) init(base *BaseTable) {
p.initGCTunerEnbaled()
p.initMaximumGOGC()
p.initMinimumGOGC()
p.initGracefulStopTimeout()
}
// InitAlias initializes an alias for the QueryNode role.
@ -1145,6 +1156,10 @@ func (p *queryNodeConfig) initMaximumGOGC() {
p.MaximumGOGCConfig = p.Base.ParseIntWithDefault("queryNode.gchelper.maximumGoGC", 200)
}
func (p *queryNodeConfig) initGracefulStopTimeout() {
p.GracefulStopTimeout = p.Base.ParseInt64WithDefault("queryNode.gracefulStopTimeout", params.CommonCfg.GracefulStopTimeout)
}
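A rough sketch of the lookup precedence this establishes (the helper below is hypothetical, not the real paramtable API): the query-node key wins if present, then the common key, then the compiled-in default of 30 seconds.

package main

import "fmt"

// resolveGracefulStopTimeout illustrates the fallback order:
// queryNode.gracefulStopTimeout > common.gracefulStopTimeout > DefaultGracefulStopTimeout.
func resolveGracefulStopTimeout(cfg map[string]int64) int64 {
	if v, ok := cfg["queryNode.gracefulStopTimeout"]; ok {
		return v
	}
	if v, ok := cfg["common.gracefulStopTimeout"]; ok {
		return v
	}
	return 30 // seconds, matching DefaultGracefulStopTimeout
}

func main() {
	fmt.Println(resolveGracefulStopTimeout(map[string]int64{"common.gracefulStopTimeout": 50})) // 50
	fmt.Println(resolveGracefulStopTimeout(nil))                                                // 30
}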
// /////////////////////////////////////////////////////////////////////////////
// --- datacoord ---
type dataCoordConfig struct {

View File

@ -68,6 +68,13 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, Params.GracefulTime, int64(DefaultGracefulTime))
t.Logf("default grafeful time = %d", Params.GracefulTime)
assert.Equal(t, Params.GracefulStopTimeout, int64(DefaultGracefulStopTimeout))
assert.Equal(t, params.QueryNodeCfg.GracefulStopTimeout, Params.GracefulStopTimeout)
t.Logf("default grafeful stop timeout = %d", Params.GracefulStopTimeout)
Params.Base.Save("common.gracefulStopTimeout", "50")
Params.initGracefulStopTimeout()
assert.Equal(t, Params.GracefulStopTimeout, int64(50))
// -- proxy --
assert.Equal(t, Params.ProxySubName, "by-dev-proxy")
t.Logf("ProxySubName: %s", Params.ProxySubName)
@ -291,6 +298,11 @@ func TestComponentParam(t *testing.T) {
nprobe = Params.SmallIndexNProbe
assert.Equal(t, int64(4), nprobe)
Params.Base.Save("queryNode.gracefulStopTimeout", "100")
Params.initGracefulStopTimeout()
gracefulStopTimeout := Params.GracefulStopTimeout
assert.Equal(t, int64(100), gracefulStopTimeout)
})
t.Run("test dataCoordConfig", func(t *testing.T) {

View File

@ -54,6 +54,8 @@ func (t SessionEventType) String() string {
return "SessionAddEvent"
case SessionDelEvent:
return "SessionDelEvent"
case SessionUpdateEvent:
return "SessionUpdateEvent"
default:
return ""
}
@ -71,6 +73,8 @@ const (
SessionAddEvent
// SessionDelEvent event type for a Session deleted
SessionDelEvent
// SessionUpdateEvent event type for a Session being updated, e.g. marked as stopping
SessionUpdateEvent
)
// Session is a struct to store service's session, including ServerID, ServerName,
@ -86,6 +90,7 @@ type Session struct {
ServerName string `json:"ServerName,omitempty"`
Address string `json:"Address,omitempty"`
Exclusive bool `json:"Exclusive,omitempty"`
Stopping bool `json:"Stopping,omitempty"`
TriggerKill bool
Version semver.Version `json:"Version,omitempty"`
@ -138,6 +143,7 @@ func (s *Session) UnmarshalJSON(data []byte) error {
ServerName string `json:"ServerName,omitempty"`
Address string `json:"Address,omitempty"`
Exclusive bool `json:"Exclusive,omitempty"`
Stopping bool `json:"Stopping,omitempty"`
TriggerKill bool
Version string `json:"Version"`
}
@ -157,6 +163,7 @@ func (s *Session) UnmarshalJSON(data []byte) error {
s.ServerName = raw.ServerName
s.Address = raw.Address
s.Exclusive = raw.Exclusive
s.Stopping = raw.Stopping
s.TriggerKill = raw.TriggerKill
return nil
}
@ -170,6 +177,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
ServerName string `json:"ServerName,omitempty"`
Address string `json:"Address,omitempty"`
Exclusive bool `json:"Exclusive,omitempty"`
Stopping bool `json:"Stopping,omitempty"`
TriggerKill bool
Version string `json:"Version"`
}{
@ -177,6 +185,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
ServerName: s.ServerName,
Address: s.Address,
Exclusive: s.Exclusive,
Stopping: s.Stopping,
TriggerKill: s.TriggerKill,
Version: verStr,
})
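To make the serialized form concrete, a self-contained sketch with a hypothetical stand-in struct that mirrors the JSON tags above; the field values are invented examples.

package main

import (
	"encoding/json"
	"fmt"
)

// SessionStub mirrors the JSON tags shown above so the Stopping flag is visible
// in the encoded value that gets written back to etcd.
type SessionStub struct {
	ServerName  string `json:"ServerName,omitempty"`
	Address     string `json:"Address,omitempty"`
	Exclusive   bool   `json:"Exclusive,omitempty"`
	Stopping    bool   `json:"Stopping,omitempty"`
	TriggerKill bool
	Version     string `json:"Version"`
}

func main() {
	b, _ := json.Marshal(SessionStub{
		ServerName: "querynode",
		Address:    "127.0.0.1:21123",
		Stopping:   true,
		Version:    "2.2.0",
	})
	fmt.Println(string(b))
	// {"ServerName":"querynode","Address":"127.0.0.1:21123","Stopping":true,"TriggerKill":false,"Version":"2.2.0"}
}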
@ -325,6 +334,14 @@ func (s *Session) getServerIDWithKey(key string) (int64, error) {
}
}
func (s *Session) getCompleteKey() string {
key := s.ServerName
if !s.Exclusive || (s.enableActiveStandBy && s.isStandby.Load().(bool)) {
key = fmt.Sprintf("%s-%d", key, s.ServerID)
}
return path.Join(s.metaRoot, DefaultServiceRoot, key)
}
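A hedged illustration of the key shapes getCompleteKey produces; the metaRoot and service-root values below are assumptions used only for the example.

package main

import (
	"fmt"
	"path"
)

func main() {
	metaRoot, serviceRoot := "by-dev/meta", "session" // assumed example values
	// Exclusive role that is not in active-standby mode, e.g. a coordinator: plain name.
	fmt.Println(path.Join(metaRoot, serviceRoot, "rootcoord")) // by-dev/meta/session/rootcoord
	// Non-exclusive role, e.g. a querynode with ServerID 3: name-ID suffix.
	fmt.Println(path.Join(metaRoot, serviceRoot, fmt.Sprintf("%s-%d", "querynode", 3))) // by-dev/meta/session/querynode-3
}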
// registerService registers the service to etcd so that other services
// can find that the service is online and issue subsequent operations
// RegisterService will save a key-value in etcd
@ -342,11 +359,7 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er
if s.enableActiveStandBy {
s.updateStandby(true)
}
key := s.ServerName
if !s.Exclusive || s.enableActiveStandBy {
key = fmt.Sprintf("%s-%d", key, s.ServerID)
}
completeKey := path.Join(s.metaRoot, DefaultServiceRoot, key)
completeKey := s.getCompleteKey()
var ch <-chan *clientv3.LeaseKeepAliveResponse
log.Debug("service begin to register to etcd", zap.String("serverName", s.ServerName), zap.Int64("ServerID", s.ServerID))
@ -383,7 +396,7 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er
}
if !txnResp.Succeeded {
return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", key)
return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ServerName)
}
log.Debug("put session key into etcd", zap.String("key", completeKey), zap.String("value", string(sessionJSON)))
@ -497,6 +510,34 @@ func (s *Session) GetSessionsWithVersionRange(prefix string, r semver.Range) (ma
return res, resp.Header.Revision, nil
}
func (s *Session) GoingStop() error {
if s == nil || s.etcdCli == nil || s.leaseID == nil {
return errors.New("the session hasn't been init")
}
completeKey := s.getCompleteKey()
resp, err := s.etcdCli.Get(s.ctx, completeKey, clientv3.WithCountOnly())
if err != nil {
log.Error("fail to get the session", zap.String("key", completeKey), zap.Error(err))
return err
}
if resp.Count == 0 {
return nil
}
s.Stopping = true
sessionJSON, err := json.Marshal(s)
if err != nil {
log.Error("fail to marshal the session", zap.String("key", completeKey))
return err
}
_, err = s.etcdCli.Put(s.ctx, completeKey, string(sessionJSON), clientv3.WithLease(*s.leaseID))
if err != nil {
log.Error("fail to update the session to stopping state", zap.String("key", completeKey))
return err
}
return nil
}
// SessionEvent indicates the changes of other servers.
// if a server is up, EventType is SessAddEvent.
// if a server is down, EventType is SessDelEvent.
@ -596,7 +637,11 @@ func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) {
if !w.validate(session) {
continue
}
if session.Stopping {
eventType = SessionUpdateEvent
} else {
eventType = SessionAddEvent
}
case mvccpb.DELETE:
log.Debug("watch services",
zap.Any("delete kv", ev.PrevKv))

View File

@ -692,6 +692,7 @@ func TestSessionEventType_String(t *testing.T) {
{t: SessionNoneEvent, want: ""},
{t: SessionAddEvent, want: "SessionAddEvent"},
{t: SessionDelEvent, want: "SessionDelEvent"},
{t: SessionUpdateEvent, want: "SessionUpdateEvent"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@ -29,7 +29,9 @@ if [[ $(uname -s) == "Darwin" ]]; then
export PATH="/usr/local/opt/grep/libexec/gnubin:$PATH"
fi
if test -z "$(git status | grep -E "*pb.go|*pb.cc|*pb.h")"; then
check_result=$(git status | grep -E "*pb.go|*pb.cc|*pb.h")
echo "check_result: $check_result"
if test -z "$check_result"; then
exit 0
else
echo "The go file or cpp file generated by proto are not latest!"

48
scripts/stop_graceful.sh Executable file
View File

@ -0,0 +1,48 @@
# Licensed to the LF AI & Data foundation under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function get_milvus_process() {
ps -e | grep milvus | grep -v grep | awk '{print $1}'
}
echo "Stopping milvus..."
BASEDIR=$(dirname "$0")
timeout=120
if [ ! -z "$1" ]; then
timeout=$1
fi
if [ -z "$(get_milvus_process)" ]; then
echo "No Milvus process found"
exit 0
fi
kill -15 $(get_milvus_process)
start=$(date +%s)
while :
do
sleep 1
if [ -z "$(get_milvus_process)" ]; then
echo "Milvus stopped"
break
fi
if [ $(( $(date +%s) - $start )) -gt $timeout ]; then
echo "Milvus timeout stopped"
kill -15 $(get_milvus_process)
break
fi
done
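Usage note: the script takes an optional timeout in seconds, so running scripts/stop_graceful.sh 180 waits up to 180 seconds for Milvus to exit before giving up, while running it with no argument uses the default of 120 seconds.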