enhance: Add RESTful API for DevOps to execute rolling upgrade (#29998)

issue: #29261
This PR adds a RESTful API that lets DevOps drive a rolling upgrade, including
suspending/resuming balance and manually transferring segments/channels (a usage sketch follows below).

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2024-03-27 16:15:19 +08:00 committed by GitHub
parent bd44bd5ae2
commit 92971707de
37 changed files with 4697 additions and 127 deletions
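
As a rough sketch of the intended DevOps flow, hypothetical only: the HTTP routes, port, and payloads below are assumptions for illustration and are not taken from this PR; the authoritative routes are the ones registered by the handler changes in this commit. The call order mirrors the commit description: suspend balance, fence the node, drain it.

// Hypothetical sketch: routes and payloads are assumed, not from this PR.
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

// post issues one management call and prints the HTTP status.
func post(url, body string) {
	resp, err := http.Post(url, "application/json", bytes.NewBufferString(body))
	if err != nil {
		fmt.Println(url, "error:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println(url, "->", resp.Status)
}

func main() {
	base := "http://localhost:9091/api/v1" // assumed management address
	post(base+"/balance/suspend", `{}`)                                           // assumed route
	post(base+"/node/suspend", `{"node_id": 1}`)                                  // assumed route
	post(base+"/segment/transfer", `{"source_node_id": 1, "transfer_all": true}`) // assumed route
	post(base+"/channel/transfer", `{"source_node_id": 1, "transfer_all": true}`) // assumed route
}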

View File

@ -402,3 +402,102 @@ func (c *Client) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC
return client.DeactivateChecker(ctx, req)
})
}
func (c *Client) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*querypb.ListQueryNodeResponse, error) {
return client.ListQueryNode(ctx, req)
})
}
func (c *Client) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*querypb.GetQueryNodeDistributionResponse, error) {
return client.GetQueryNodeDistribution(ctx, req)
})
}
func (c *Client) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.SuspendBalance(ctx, req)
})
}
func (c *Client) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.ResumeBalance(ctx, req)
})
}
func (c *Client) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.SuspendNode(ctx, req)
})
}
func (c *Client) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.ResumeNode(ctx, req)
})
}
func (c *Client) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.TransferSegment(ctx, req)
})
}
func (c *Client) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.TransferChannel(ctx, req)
})
}
func (c *Client) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
return client.CheckQueryNodeDistribution(ctx, req)
})
}
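
Each wrapper above follows the same recipe: clone the request, stamp its MsgBase with the local node ID and the coordinator's target ID, then forward over gRPC. A minimal drain routine built on top of them might look like the sketch below; drainNode is a hypothetical helper, qcClient is assumed to be an initialized *Client from this package, and the Go field names are the ones generated from the proto messages added in this PR. Callers should also inspect the returned *commonpb.Status, elided here for brevity.

// Hypothetical sketch built on the wrappers above.
func drainNode(ctx context.Context, qcClient *Client, nodeID int64) error {
	// Stop automatic balancing so it does not race the manual transfers.
	if _, err := qcClient.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{}); err != nil {
		return err
	}
	// Fence the node so nothing new is assigned to it.
	if _, err := qcClient.SuspendNode(ctx, &querypb.SuspendNodeRequest{NodeID: nodeID}); err != nil {
		return err
	}
	// Move everything the node currently serves onto the remaining nodes.
	if _, err := qcClient.TransferSegment(ctx, &querypb.TransferSegmentRequest{
		SourceNodeID: nodeID, TransferAll: true, ToAllNodes: true,
	}); err != nil {
		return err
	}
	_, err := qcClient.TransferChannel(ctx, &querypb.TransferChannelRequest{
		SourceNodeID: nodeID, TransferAll: true, ToAllNodes: true,
	})
	return err
}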

View File

@ -158,6 +158,33 @@ func Test_NewClient(t *testing.T) {
r30, err := client.DeactivateChecker(ctx, nil)
retCheck(retNotNil, r30, err)
r31, err := client.ListQueryNode(ctx, nil)
retCheck(retNotNil, r31, err)
r32, err := client.GetQueryNodeDistribution(ctx, nil)
retCheck(retNotNil, r32, err)
r33, err := client.SuspendBalance(ctx, nil)
retCheck(retNotNil, r33, err)
r34, err := client.ResumeBalance(ctx, nil)
retCheck(retNotNil, r34, err)
r35, err := client.SuspendNode(ctx, nil)
retCheck(retNotNil, r35, err)
r36, err := client.ResumeNode(ctx, nil)
retCheck(retNotNil, r36, err)
r37, err := client.TransferSegment(ctx, nil)
retCheck(retNotNil, r37, err)
r38, err := client.TransferChannel(ctx, nil)
retCheck(retNotNil, r38, err)
r39, err := client.CheckQueryNodeDistribution(ctx, nil)
retCheck(retNotNil, r39, err)
}
client.(*Client).grpcClient = &mock.GRPCClientBase[querypb.QueryCoordClient]{

View File

@ -446,3 +446,39 @@ func (s *Server) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC
func (s *Server) ListCheckers(ctx context.Context, req *querypb.ListCheckersRequest) (*querypb.ListCheckersResponse, error) {
return s.queryCoord.ListCheckers(ctx, req)
}
func (s *Server) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) {
return s.queryCoord.ListQueryNode(ctx, req)
}
func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) {
return s.queryCoord.GetQueryNodeDistribution(ctx, req)
}
func (s *Server) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest) (*commonpb.Status, error) {
return s.queryCoord.SuspendBalance(ctx, req)
}
func (s *Server) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest) (*commonpb.Status, error) {
return s.queryCoord.ResumeBalance(ctx, req)
}
func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest) (*commonpb.Status, error) {
return s.queryCoord.SuspendNode(ctx, req)
}
func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) (*commonpb.Status, error) {
return s.queryCoord.ResumeNode(ctx, req)
}
func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest) (*commonpb.Status, error) {
return s.queryCoord.TransferSegment(ctx, req)
}
func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) {
return s.queryCoord.TransferChannel(ctx, req)
}
func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) {
return s.queryCoord.CheckQueryNodeDistribution(ctx, req)
}

View File

@ -283,6 +283,78 @@ func Test_NewServer(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("ListQueryNode", func(t *testing.T) {
req := &querypb.ListQueryNodeRequest{}
mqc.EXPECT().ListQueryNode(mock.Anything, req).Return(&querypb.ListQueryNodeResponse{Status: merr.Success()}, nil)
resp, err := server.ListQueryNode(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("GetQueryNodeDistribution", func(t *testing.T) {
req := &querypb.GetQueryNodeDistributionRequest{}
mqc.EXPECT().GetQueryNodeDistribution(mock.Anything, req).Return(&querypb.GetQueryNodeDistributionResponse{Status: merr.Success()}, nil)
resp, err := server.GetQueryNodeDistribution(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("SuspendBalance", func(t *testing.T) {
req := &querypb.SuspendBalanceRequest{}
mqc.EXPECT().SuspendBalance(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.SuspendBalance(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("ResumeBalance", func(t *testing.T) {
req := &querypb.ResumeBalanceRequest{}
mqc.EXPECT().ResumeBalance(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.ResumeBalance(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("SuspendNode", func(t *testing.T) {
req := &querypb.SuspendNodeRequest{}
mqc.EXPECT().SuspendNode(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.SuspendNode(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("ResumeNode", func(t *testing.T) {
req := &querypb.ResumeNodeRequest{}
mqc.EXPECT().ResumeNode(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.ResumeNode(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("TransferSegment", func(t *testing.T) {
req := &querypb.TransferSegmentRequest{}
mqc.EXPECT().TransferSegment(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.TransferSegment(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("TransferChannel", func(t *testing.T) {
req := &querypb.TransferChannelRequest{}
mqc.EXPECT().TransferChannel(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.TransferChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("CheckQueryNodeDistribution", func(t *testing.T) {
req := &querypb.CheckQueryNodeDistributionRequest{}
mqc.EXPECT().CheckQueryNodeDistribution(mock.Anything, req).Return(merr.Success(), nil)
resp, err := server.CheckQueryNodeDistribution(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
err = server.Stop()
assert.NoError(t, err)
}

View File

@ -74,6 +74,10 @@ func (s *Server) GetStatistics(ctx context.Context, request *querypb.GetStatisti
return s.querynode.GetStatistics(ctx, request)
}
func (s *Server) GetQueryNode() types.QueryNodeComponent {
return s.querynode
}
// NewServer creates a new QueryNode grpc server.
func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) {
ctx1, cancel := context.WithCancel(ctx)

View File

@ -144,6 +144,61 @@ func (_c *MockQueryCoord_CheckHealth_Call) RunAndReturn(run func(context.Context
return _c
}
// CheckQueryNodeDistribution provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) CheckQueryNodeDistribution(_a0 context.Context, _a1 *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_CheckQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQueryNodeDistribution'
type MockQueryCoord_CheckQueryNodeDistribution_Call struct {
*mock.Call
}
// CheckQueryNodeDistribution is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.CheckQueryNodeDistributionRequest
func (_e *MockQueryCoord_Expecter) CheckQueryNodeDistribution(_a0 interface{}, _a1 interface{}) *MockQueryCoord_CheckQueryNodeDistribution_Call {
return &MockQueryCoord_CheckQueryNodeDistribution_Call{Call: _e.mock.On("CheckQueryNodeDistribution", _a0, _a1)}
}
func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) Run(run func(_a0 context.Context, _a1 *querypb.CheckQueryNodeDistributionRequest)) *MockQueryCoord_CheckQueryNodeDistribution_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.CheckQueryNodeDistributionRequest))
})
return _c
}
func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_CheckQueryNodeDistribution_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error)) *MockQueryCoord_CheckQueryNodeDistribution_Call {
_c.Call.Return(run)
return _c
}
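
The generated expecter gives each new method a typed On/Run/Return chain. A small sketch of using it in a test, assuming the mockery-generated NewMockQueryCoord constructor from this package; the Run body is illustrative:

mqc := NewMockQueryCoord(t)
mqc.EXPECT().
	CheckQueryNodeDistribution(mock.Anything, mock.Anything).
	Run(func(_ context.Context, req *querypb.CheckQueryNodeDistributionRequest) {
		_ = req.GetBase() // observe the forwarded request if needed
	}).
	Return(merr.Success(), nil)
// The canned status comes back and the expectation is marked satisfied.
status, err := mqc.CheckQueryNodeDistribution(context.Background(), &querypb.CheckQueryNodeDistributionRequest{})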
// CreateResourceGroup provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) CreateResourceGroup(_a0 context.Context, _a1 *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
@ -529,6 +584,61 @@ func (_c *MockQueryCoord_GetPartitionStates_Call) RunAndReturn(run func(context.
return _c
}
// GetQueryNodeDistribution provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) GetQueryNodeDistribution(_a0 context.Context, _a1 *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) {
ret := _m.Called(_a0, _a1)
var r0 *querypb.GetQueryNodeDistributionResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) *querypb.GetQueryNodeDistributionResponse); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.GetQueryNodeDistributionResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_GetQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryNodeDistribution'
type MockQueryCoord_GetQueryNodeDistribution_Call struct {
*mock.Call
}
// GetQueryNodeDistribution is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.GetQueryNodeDistributionRequest
func (_e *MockQueryCoord_Expecter) GetQueryNodeDistribution(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetQueryNodeDistribution_Call {
return &MockQueryCoord_GetQueryNodeDistribution_Call{Call: _e.mock.On("GetQueryNodeDistribution", _a0, _a1)}
}
func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetQueryNodeDistributionRequest)) *MockQueryCoord_GetQueryNodeDistribution_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.GetQueryNodeDistributionRequest))
})
return _c
}
func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) Return(_a0 *querypb.GetQueryNodeDistributionResponse, _a1 error) *MockQueryCoord_GetQueryNodeDistribution_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error)) *MockQueryCoord_GetQueryNodeDistribution_Call {
_c.Call.Return(run)
return _c
}
// GetReplicas provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) GetReplicas(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) {
ret := _m.Called(_a0, _a1)
@ -900,6 +1010,61 @@ func (_c *MockQueryCoord_ListCheckers_Call) RunAndReturn(run func(context.Contex
return _c
}
// ListQueryNode provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) ListQueryNode(_a0 context.Context, _a1 *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) {
ret := _m.Called(_a0, _a1)
var r0 *querypb.ListQueryNodeResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest) *querypb.ListQueryNodeResponse); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.ListQueryNodeResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ListQueryNodeRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_ListQueryNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListQueryNode'
type MockQueryCoord_ListQueryNode_Call struct {
*mock.Call
}
// ListQueryNode is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.ListQueryNodeRequest
func (_e *MockQueryCoord_Expecter) ListQueryNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ListQueryNode_Call {
return &MockQueryCoord_ListQueryNode_Call{Call: _e.mock.On("ListQueryNode", _a0, _a1)}
}
func (_c *MockQueryCoord_ListQueryNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.ListQueryNodeRequest)) *MockQueryCoord_ListQueryNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.ListQueryNodeRequest))
})
return _c
}
func (_c *MockQueryCoord_ListQueryNode_Call) Return(_a0 *querypb.ListQueryNodeResponse, _a1 error) *MockQueryCoord_ListQueryNode_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_ListQueryNode_Call) RunAndReturn(run func(context.Context, *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error)) *MockQueryCoord_ListQueryNode_Call {
_c.Call.Return(run)
return _c
}
// ListResourceGroups provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) ListResourceGroups(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) {
ret := _m.Called(_a0, _a1)
@ -1271,6 +1436,116 @@ func (_c *MockQueryCoord_ReleasePartitions_Call) RunAndReturn(run func(context.C
return _c
}
// ResumeBalance provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) ResumeBalance(_a0 context.Context, _a1 *querypb.ResumeBalanceRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeBalanceRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_ResumeBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeBalance'
type MockQueryCoord_ResumeBalance_Call struct {
*mock.Call
}
// ResumeBalance is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.ResumeBalanceRequest
func (_e *MockQueryCoord_Expecter) ResumeBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ResumeBalance_Call {
return &MockQueryCoord_ResumeBalance_Call{Call: _e.mock.On("ResumeBalance", _a0, _a1)}
}
func (_c *MockQueryCoord_ResumeBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.ResumeBalanceRequest)) *MockQueryCoord_ResumeBalance_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.ResumeBalanceRequest))
})
return _c
}
func (_c *MockQueryCoord_ResumeBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_ResumeBalance_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_ResumeBalance_Call) RunAndReturn(run func(context.Context, *querypb.ResumeBalanceRequest) (*commonpb.Status, error)) *MockQueryCoord_ResumeBalance_Call {
_c.Call.Return(run)
return _c
}
// ResumeNode provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) ResumeNode(_a0 context.Context, _a1 *querypb.ResumeNodeRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeNodeRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_ResumeNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeNode'
type MockQueryCoord_ResumeNode_Call struct {
*mock.Call
}
// ResumeNode is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.ResumeNodeRequest
func (_e *MockQueryCoord_Expecter) ResumeNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ResumeNode_Call {
return &MockQueryCoord_ResumeNode_Call{Call: _e.mock.On("ResumeNode", _a0, _a1)}
}
func (_c *MockQueryCoord_ResumeNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.ResumeNodeRequest)) *MockQueryCoord_ResumeNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.ResumeNodeRequest))
})
return _c
}
func (_c *MockQueryCoord_ResumeNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_ResumeNode_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_ResumeNode_Call) RunAndReturn(run func(context.Context, *querypb.ResumeNodeRequest) (*commonpb.Status, error)) *MockQueryCoord_ResumeNode_Call {
_c.Call.Return(run)
return _c
}
// SetAddress provides a mock function with given fields: address
func (_m *MockQueryCoord) SetAddress(address string) {
_m.Called(address)
@ -1734,6 +2009,116 @@ func (_c *MockQueryCoord_Stop_Call) RunAndReturn(run func() error) *MockQueryCoo
return _c
}
// SuspendBalance provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) SuspendBalance(_a0 context.Context, _a1 *querypb.SuspendBalanceRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendBalanceRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_SuspendBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendBalance'
type MockQueryCoord_SuspendBalance_Call struct {
*mock.Call
}
// SuspendBalance is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.SuspendBalanceRequest
func (_e *MockQueryCoord_Expecter) SuspendBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SuspendBalance_Call {
return &MockQueryCoord_SuspendBalance_Call{Call: _e.mock.On("SuspendBalance", _a0, _a1)}
}
func (_c *MockQueryCoord_SuspendBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.SuspendBalanceRequest)) *MockQueryCoord_SuspendBalance_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.SuspendBalanceRequest))
})
return _c
}
func (_c *MockQueryCoord_SuspendBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SuspendBalance_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_SuspendBalance_Call) RunAndReturn(run func(context.Context, *querypb.SuspendBalanceRequest) (*commonpb.Status, error)) *MockQueryCoord_SuspendBalance_Call {
_c.Call.Return(run)
return _c
}
// SuspendNode provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) SuspendNode(_a0 context.Context, _a1 *querypb.SuspendNodeRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendNodeRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_SuspendNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendNode'
type MockQueryCoord_SuspendNode_Call struct {
*mock.Call
}
// SuspendNode is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.SuspendNodeRequest
func (_e *MockQueryCoord_Expecter) SuspendNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SuspendNode_Call {
return &MockQueryCoord_SuspendNode_Call{Call: _e.mock.On("SuspendNode", _a0, _a1)}
}
func (_c *MockQueryCoord_SuspendNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.SuspendNodeRequest)) *MockQueryCoord_SuspendNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.SuspendNodeRequest))
})
return _c
}
func (_c *MockQueryCoord_SuspendNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SuspendNode_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_SuspendNode_Call) RunAndReturn(run func(context.Context, *querypb.SuspendNodeRequest) (*commonpb.Status, error)) *MockQueryCoord_SuspendNode_Call {
_c.Call.Return(run)
return _c
}
// SyncNewCreatedPartition provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) SyncNewCreatedPartition(_a0 context.Context, _a1 *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
@ -1789,6 +2174,61 @@ func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) RunAndReturn(run func(con
return _c
}
// TransferChannel provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) TransferChannel(_a0 context.Context, _a1 *querypb.TransferChannelRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferChannelRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_TransferChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferChannel'
type MockQueryCoord_TransferChannel_Call struct {
*mock.Call
}
// TransferChannel is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.TransferChannelRequest
func (_e *MockQueryCoord_Expecter) TransferChannel(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferChannel_Call {
return &MockQueryCoord_TransferChannel_Call{Call: _e.mock.On("TransferChannel", _a0, _a1)}
}
func (_c *MockQueryCoord_TransferChannel_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferChannelRequest)) *MockQueryCoord_TransferChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.TransferChannelRequest))
})
return _c
}
func (_c *MockQueryCoord_TransferChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_TransferChannel_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_TransferChannel_Call) RunAndReturn(run func(context.Context, *querypb.TransferChannelRequest) (*commonpb.Status, error)) *MockQueryCoord_TransferChannel_Call {
_c.Call.Return(run)
return _c
}
// TransferNode provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) TransferNode(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
@ -1899,6 +2339,61 @@ func (_c *MockQueryCoord_TransferReplica_Call) RunAndReturn(run func(context.Con
return _c
}
// TransferSegment provides a mock function with given fields: _a0, _a1
func (_m *MockQueryCoord) TransferSegment(_a0 context.Context, _a1 *querypb.TransferSegmentRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferSegmentRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_TransferSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferSegment'
type MockQueryCoord_TransferSegment_Call struct {
*mock.Call
}
// TransferSegment is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.TransferSegmentRequest
func (_e *MockQueryCoord_Expecter) TransferSegment(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferSegment_Call {
return &MockQueryCoord_TransferSegment_Call{Call: _e.mock.On("TransferSegment", _a0, _a1)}
}
func (_c *MockQueryCoord_TransferSegment_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferSegmentRequest)) *MockQueryCoord_TransferSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.TransferSegmentRequest))
})
return _c
}
func (_c *MockQueryCoord_TransferSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_TransferSegment_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoord_TransferSegment_Call) RunAndReturn(run func(context.Context, *querypb.TransferSegmentRequest) (*commonpb.Status, error)) *MockQueryCoord_TransferSegment_Call {
_c.Call.Return(run)
return _c
}
// UpdateStateCode provides a mock function with given fields: stateCode
func (_m *MockQueryCoord) UpdateStateCode(stateCode commonpb.StateCode) {
_m.Called(stateCode)

View File

@ -171,6 +171,76 @@ func (_c *MockQueryCoordClient_CheckHealth_Call) RunAndReturn(run func(context.C
return _c
}
// CheckQueryNodeDistribution provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) CheckQueryNodeDistribution(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_CheckQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQueryNodeDistribution'
type MockQueryCoordClient_CheckQueryNodeDistribution_Call struct {
*mock.Call
}
// CheckQueryNodeDistribution is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.CheckQueryNodeDistributionRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) CheckQueryNodeDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
return &MockQueryCoordClient_CheckQueryNodeDistribution_Call{Call: _e.mock.On("CheckQueryNodeDistribution",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) Run(run func(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.CheckQueryNodeDistributionRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
_c.Call.Return(run)
return _c
}
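
Compared with the component mock, the client mock also has to fold variadic grpc.CallOptions into testify's argument list, which is exactly what the loop inside Run reconstructs. When a test expects no options it can simply match the two fixed arguments; a sketch, assuming the NewMockQueryCoordClient constructor from this file:

mc := NewMockQueryCoordClient(t)
mc.EXPECT().
	CheckQueryNodeDistribution(mock.Anything, mock.Anything).
	Return(merr.Success(), nil)
// A call made with zero CallOptions matches the two-argument expectation.
status, err := mc.CheckQueryNodeDistribution(context.Background(), &querypb.CheckQueryNodeDistributionRequest{})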
// Close provides a mock function with given fields:
func (_m *MockQueryCoordClient) Close() error {
ret := _m.Called()
@ -702,6 +772,76 @@ func (_c *MockQueryCoordClient_GetPartitionStates_Call) RunAndReturn(run func(co
return _c
}
// GetQueryNodeDistribution provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) GetQueryNodeDistribution(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *querypb.GetQueryNodeDistributionResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) *querypb.GetQueryNodeDistributionResponse); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.GetQueryNodeDistributionResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_GetQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryNodeDistribution'
type MockQueryCoordClient_GetQueryNodeDistribution_Call struct {
*mock.Call
}
// GetQueryNodeDistribution is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.GetQueryNodeDistributionRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) GetQueryNodeDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
return &MockQueryCoordClient_GetQueryNodeDistribution_Call{Call: _e.mock.On("GetQueryNodeDistribution",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) Run(run func(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.GetQueryNodeDistributionRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) Return(_a0 *querypb.GetQueryNodeDistributionResponse, _a1 error) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error)) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
_c.Call.Return(run)
return _c
}
// GetReplicas provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) GetReplicas(ctx context.Context, in *milvuspb.GetReplicasRequest, opts ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error) {
_va := make([]interface{}, len(opts))
@ -1122,6 +1262,76 @@ func (_c *MockQueryCoordClient_ListCheckers_Call) RunAndReturn(run func(context.
return _c
}
// ListQueryNode provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) ListQueryNode(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *querypb.ListQueryNodeResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) *querypb.ListQueryNodeResponse); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.ListQueryNodeResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_ListQueryNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListQueryNode'
type MockQueryCoordClient_ListQueryNode_Call struct {
*mock.Call
}
// ListQueryNode is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.ListQueryNodeRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) ListQueryNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ListQueryNode_Call {
return &MockQueryCoordClient_ListQueryNode_Call{Call: _e.mock.On("ListQueryNode",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_ListQueryNode_Call) Run(run func(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ListQueryNode_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.ListQueryNodeRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_ListQueryNode_Call) Return(_a0 *querypb.ListQueryNodeResponse, _a1 error) *MockQueryCoordClient_ListQueryNode_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_ListQueryNode_Call) RunAndReturn(run func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error)) *MockQueryCoordClient_ListQueryNode_Call {
_c.Call.Return(run)
return _c
}
// ListResourceGroups provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) ListResourceGroups(ctx context.Context, in *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error) {
_va := make([]interface{}, len(opts))
@ -1542,6 +1752,146 @@ func (_c *MockQueryCoordClient_ReleasePartitions_Call) RunAndReturn(run func(con
return _c
}
// ResumeBalance provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) ResumeBalance(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_ResumeBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeBalance'
type MockQueryCoordClient_ResumeBalance_Call struct {
*mock.Call
}
// ResumeBalance is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.ResumeBalanceRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) ResumeBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ResumeBalance_Call {
return &MockQueryCoordClient_ResumeBalance_Call{Call: _e.mock.On("ResumeBalance",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_ResumeBalance_Call) Run(run func(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ResumeBalance_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.ResumeBalanceRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_ResumeBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ResumeBalance_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_ResumeBalance_Call) RunAndReturn(run func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ResumeBalance_Call {
_c.Call.Return(run)
return _c
}
// ResumeNode provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) ResumeNode(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_ResumeNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeNode'
type MockQueryCoordClient_ResumeNode_Call struct {
*mock.Call
}
// ResumeNode is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.ResumeNodeRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) ResumeNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ResumeNode_Call {
return &MockQueryCoordClient_ResumeNode_Call{Call: _e.mock.On("ResumeNode",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_ResumeNode_Call) Run(run func(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ResumeNode_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.ResumeNodeRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_ResumeNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ResumeNode_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_ResumeNode_Call) RunAndReturn(run func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ResumeNode_Call {
_c.Call.Return(run)
return _c
}
// ShowCollections provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) {
_va := make([]interface{}, len(opts))
@ -1752,6 +2102,146 @@ func (_c *MockQueryCoordClient_ShowPartitions_Call) RunAndReturn(run func(contex
return _c
}
// SuspendBalance provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) SuspendBalance(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_SuspendBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendBalance'
type MockQueryCoordClient_SuspendBalance_Call struct {
*mock.Call
}
// SuspendBalance is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.SuspendBalanceRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) SuspendBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SuspendBalance_Call {
return &MockQueryCoordClient_SuspendBalance_Call{Call: _e.mock.On("SuspendBalance",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_SuspendBalance_Call) Run(run func(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SuspendBalance_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.SuspendBalanceRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_SuspendBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SuspendBalance_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_SuspendBalance_Call) RunAndReturn(run func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SuspendBalance_Call {
_c.Call.Return(run)
return _c
}
// SuspendNode provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) SuspendNode(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_SuspendNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendNode'
type MockQueryCoordClient_SuspendNode_Call struct {
*mock.Call
}
// SuspendNode is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.SuspendNodeRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) SuspendNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SuspendNode_Call {
return &MockQueryCoordClient_SuspendNode_Call{Call: _e.mock.On("SuspendNode",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_SuspendNode_Call) Run(run func(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SuspendNode_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.SuspendNodeRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_SuspendNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SuspendNode_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_SuspendNode_Call) RunAndReturn(run func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SuspendNode_Call {
_c.Call.Return(run)
return _c
}
// SyncNewCreatedPartition provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) SyncNewCreatedPartition(ctx context.Context, in *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
@ -1822,6 +2312,76 @@ func (_c *MockQueryCoordClient_SyncNewCreatedPartition_Call) RunAndReturn(run fu
return _c
}
// TransferChannel provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) TransferChannel(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_TransferChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferChannel'
type MockQueryCoordClient_TransferChannel_Call struct {
*mock.Call
}
// TransferChannel is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.TransferChannelRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) TransferChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferChannel_Call {
return &MockQueryCoordClient_TransferChannel_Call{Call: _e.mock.On("TransferChannel",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_TransferChannel_Call) Run(run func(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.TransferChannelRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_TransferChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferChannel_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_TransferChannel_Call) RunAndReturn(run func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferChannel_Call {
_c.Call.Return(run)
return _c
}
// TransferNode provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) TransferNode(ctx context.Context, in *milvuspb.TransferNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
@ -1962,6 +2522,76 @@ func (_c *MockQueryCoordClient_TransferReplica_Call) RunAndReturn(run func(conte
return _c
}
// TransferSegment provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryCoordClient) TransferSegment(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoordClient_TransferSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferSegment'
type MockQueryCoordClient_TransferSegment_Call struct {
*mock.Call
}
// TransferSegment is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.TransferSegmentRequest
// - opts ...grpc.CallOption
func (_e *MockQueryCoordClient_Expecter) TransferSegment(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferSegment_Call {
return &MockQueryCoordClient_TransferSegment_Call{Call: _e.mock.On("TransferSegment",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryCoordClient_TransferSegment_Call) Run(run func(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.TransferSegmentRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryCoordClient_TransferSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferSegment_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryCoordClient_TransferSegment_Call) RunAndReturn(run func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferSegment_Call {
_c.Call.Return(run)
return _c
}
// NewMockQueryCoordClient creates a new instance of MockQueryCoordClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockQueryCoordClient(t interface {

View File

@ -86,13 +86,21 @@ service QueryCoord {
returns (DescribeResourceGroupResponse) {
}
// ops interfaces
rpc ListCheckers(ListCheckersRequest) returns (ListCheckersResponse) {
}
rpc ActivateChecker(ActivateCheckerRequest) returns (common.Status) {
}
rpc DeactivateChecker(DeactivateCheckerRequest) returns (common.Status) {
}
// ops interfaces
rpc ListCheckers(ListCheckersRequest) returns (ListCheckersResponse) {}
rpc ActivateChecker(ActivateCheckerRequest) returns (common.Status) {}
rpc DeactivateChecker(DeactivateCheckerRequest) returns (common.Status) {}
rpc ListQueryNode(ListQueryNodeRequest) returns (ListQueryNodeResponse) {}
rpc GetQueryNodeDistribution(GetQueryNodeDistributionRequest) returns (GetQueryNodeDistributionResponse) {}
rpc SuspendBalance(SuspendBalanceRequest) returns (common.Status) {}
rpc ResumeBalance(ResumeBalanceRequest) returns (common.Status) {}
rpc SuspendNode(SuspendNodeRequest) returns (common.Status) {}
rpc ResumeNode(ResumeNodeRequest) returns (common.Status) {}
rpc TransferSegment(TransferSegmentRequest) returns (common.Status) {}
rpc TransferChannel(TransferChannelRequest) returns (common.Status) {}
rpc CheckQueryNodeDistribution(CheckQueryNodeDistributionRequest) returns (common.Status) {}
}
service QueryNode {
@ -793,3 +801,75 @@ message CollectionTarget {
repeated ChannelTarget Channel_targets = 2;
int64 version = 3;
}
message NodeInfo {
int64 ID = 2;
string address = 3;
string state = 4;
}
message ListQueryNodeRequest {
common.MsgBase base = 1;
}
message ListQueryNodeResponse {
common.Status status = 1;
repeated NodeInfo nodeInfos = 2;
}
message GetQueryNodeDistributionRequest {
common.MsgBase base = 1;
int64 nodeID = 2;
}
message GetQueryNodeDistributionResponse {
common.Status status = 1;
int64 ID = 2;
repeated string channel_names = 3;
repeated int64 sealed_segmentIDs = 4;
}
message SuspendBalanceRequest {
common.MsgBase base = 1;
}
message ResumeBalanceRequest {
common.MsgBase base = 1;
}
message SuspendNodeRequest {
common.MsgBase base = 1;
int64 nodeID = 2;
}
message ResumeNodeRequest {
common.MsgBase base = 1;
int64 nodeID = 2;
}
message TransferSegmentRequest {
common.MsgBase base = 1;
int64 segmentID = 2;
int64 source_nodeID = 3;
int64 target_nodeID = 4;
bool transfer_all = 5;
bool to_all_nodes = 6;
bool copy_mode = 7;
}
message TransferChannelRequest {
common.MsgBase base = 1;
string channel_name = 2;
int64 source_nodeID = 3;
int64 target_nodeID = 4;
bool transfer_all = 5;
bool to_all_nodes = 6;
bool copy_mode = 7;
}
message CheckQueryNodeDistributionRequest {
common.MsgBase base = 1;
int64 source_nodeID = 3;
int64 target_nodeID = 4;
}
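The transfer requests above are deliberately permissive: leaving the segment or channel unset means transfer_all, leaving the target unset means to_all_nodes, and copy_mode controls whether the source keeps serving its copy while the target loads. A minimal sketch of building a drain-node request (illustrative only; it assumes the generated querypb Go bindings and the commonpbutil helper used elsewhere in this PR):
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
)

func main() {
	// Drain node 1 for a rolling upgrade: move every sealed segment to all
	// remaining nodes, keeping the source copy until the targets are loaded.
	req := &querypb.TransferSegmentRequest{
		Base:         commonpbutil.NewMsgBase(),
		SourceNodeID: 1,
		TransferAll:  true, // no segmentID given
		ToAllNodes:   true, // no target node given
		CopyMode:     true, // load on targets before releasing the source
	}
	fmt.Println(req.String())
}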

View File

@ -17,14 +17,18 @@
package proxy
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"sync"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
management "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// this file contains the proxy management RESTful API handlers
@ -32,6 +36,17 @@ import (
const (
mgrRouteGcPause = `/management/datacoord/garbage_collection/pause`
mgrRouteGcResume = `/management/datacoord/garbage_collection/resume`
mgrSuspendQueryCoordBalance = `/management/querycoord/balance/suspend`
mgrResumeQueryCoordBalance = `/management/querycoord/balance/resume`
mgrTransferSegment = `/management/querycoord/transfer/segment`
mgrTransferChannel = `/management/querycoord/transfer/channel`
mgrSuspendQueryNode = `/management/querycoord/node/suspend`
mgrResumeQueryNode = `/management/querycoord/node/resume`
mgrListQueryNode = `/management/querycoord/node/list`
mgrGetQueryNodeDistribution = `/management/querycoord/distribution/get`
mgrCheckQueryNodeDistribution = `/management/querycoord/distribution/check`
)
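All of these routes accept form-encoded POST bodies, which the handlers below read via req.ParseForm. A hedged usage sketch against the suspend-balance route (the listener address is an assumption; substitute your proxy's management endpoint):
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	base := "http://localhost:9091" // assumed management listener address
	resp, err := http.PostForm(base+"/management/querycoord/balance/suspend", url.Values{})
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(body)) // 200 {"msg": "OK"} on success
}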
var mgrRouteRegisterOnce sync.Once
@ -46,6 +61,42 @@ func RegisterMgrRoute(proxy *Proxy) {
Path: mgrRouteGcResume,
HandlerFunc: proxy.ResumeDatacoordGC,
})
management.Register(&management.Handler{
Path: mgrListQueryNode,
HandlerFunc: proxy.ListQueryNode,
})
management.Register(&management.Handler{
Path: mgrGetQueryNodeDistribution,
HandlerFunc: proxy.GetQueryNodeDistribution,
})
management.Register(&management.Handler{
Path: mgrSuspendQueryCoordBalance,
HandlerFunc: proxy.SuspendQueryCoordBalance,
})
management.Register(&management.Handler{
Path: mgrResumeQueryCoordBalance,
HandlerFunc: proxy.ResumeQueryCoordBalance,
})
management.Register(&management.Handler{
Path: mgrSuspendQueryNode,
HandlerFunc: proxy.SuspendQueryNode,
})
management.Register(&management.Handler{
Path: mgrResumeQueryNode,
HandlerFunc: proxy.ResumeQueryNode,
})
management.Register(&management.Handler{
Path: mgrTransferSegment,
HandlerFunc: proxy.TransferSegment,
})
management.Register(&management.Handler{
Path: mgrTransferChannel,
HandlerFunc: proxy.TransferChannel,
})
management.Register(&management.Handler{
Path: mgrCheckQueryNodeDistribution,
HandlerFunc: proxy.CheckQueryNodeDistribution,
})
})
}
@ -91,3 +142,362 @@ func (node *Proxy) ResumeDatacoordGC(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
func (node *Proxy) ListQueryNode(w http.ResponseWriter, req *http.Request) {
resp, err := node.queryCoord.ListQueryNode(req.Context(), &querypb.ListQueryNodeRequest{
Base: commonpbutil.NewMsgBase(),
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, err.Error())))
return
}
if !merr.Ok(resp.GetStatus()) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, resp.GetStatus().GetReason())))
return
}
// drop status so it is not marshaled into the output
resp.Status = nil
bytes, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, err.Error())))
return
}
w.WriteHeader(http.StatusOK)
w.Write(bytes)
}
func (node *Proxy) GetQueryNodeDistribution(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
return
}
nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
return
}
resp, err := node.queryCoord.GetQueryNodeDistribution(req.Context(), &querypb.GetQueryNodeDistributionRequest{
Base: commonpbutil.NewMsgBase(),
NodeID: nodeID,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
return
}
if !merr.Ok(resp.GetStatus()) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, resp.GetStatus().GetReason())))
return
}
// drop status so it is not marshaled into the output
resp.Status = nil
bytes, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
return
}
w.WriteHeader(http.StatusOK)
w.Write(bytes)
}
func (node *Proxy) SuspendQueryCoordBalance(w http.ResponseWriter, req *http.Request) {
resp, err := node.queryCoord.SuspendBalance(req.Context(), &querypb.SuspendBalanceRequest{
Base: commonpbutil.NewMsgBase(),
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend balance, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend balance, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
func (node *Proxy) ResumeQueryCoordBalance(w http.ResponseWriter, req *http.Request) {
resp, err := node.queryCoord.ResumeBalance(req.Context(), &querypb.ResumeBalanceRequest{
Base: commonpbutil.NewMsgBase(),
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume balance, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume balance, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
func (node *Proxy) SuspendQueryNode(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
return
}
nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
return
}
resp, err := node.queryCoord.SuspendNode(req.Context(), &querypb.SuspendNodeRequest{
Base: commonpbutil.NewMsgBase(),
NodeID: nodeID,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
func (node *Proxy) ResumeQueryNode(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
return
}
nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
return
}
resp, err := node.queryCoord.ResumeNode(req.Context(), &querypb.ResumeNodeRequest{
Base: commonpbutil.NewMsgBase(),
NodeID: nodeID,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
func (node *Proxy) TransferSegment(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
return
}
request := &querypb.TransferSegmentRequest{
Base: commonpbutil.NewMsgBase(),
}
source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
return
}
request.SourceNodeID = source
target := req.FormValue("target_node_id")
if len(target) == 0 {
request.ToAllNodes = true
} else {
value, err := strconv.ParseInt(target, 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
return
}
request.TargetNodeID = value
}
segmentID := req.FormValue("segment_id")
if len(segmentID) == 0 {
request.TransferAll = true
} else {
value, err := strconv.ParseInt(segmentID, 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
return
}
request.SegmentID = value
}
copyMode := req.FormValue("copy_mode")
if len(copyMode) == 0 {
request.CopyMode = true
} else {
value, err := strconv.ParseBool(copyMode)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
return
}
request.CopyMode = value
}
resp, err := node.queryCoord.TransferSegment(req.Context(), request)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
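Note the defaults encoded above: an empty target_node_id turns on to_all_nodes, an empty segment_id turns on transfer_all, and copy_mode defaults to true for segments (TransferChannel below defaults it to false). A client-side sketch of moving a single segment, with placeholder IDs and an assumed listener address:
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	resp, err := http.PostForm(
		"http://localhost:9091/management/querycoord/transfer/segment", // assumed address
		url.Values{
			"source_node_id": {"1"},
			"target_node_id": {"2"},
			"segment_id":     {"100"},
			"copy_mode":      {"false"},
		},
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(body))
}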
func (node *Proxy) TransferChannel(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
return
}
request := &querypb.TransferChannelRequest{
Base: commonpbutil.NewMsgBase(),
}
source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
return
}
request.SourceNodeID = source
target := req.FormValue("target_node_id")
if len(target) == 0 {
request.ToAllNodes = true
} else {
value, err := strconv.ParseInt(target, 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
return
}
request.TargetNodeID = value
}
channel := req.FormValue("channel_name")
if len(channel) == 0 {
request.TransferAll = true
} else {
request.ChannelName = channel
}
copyMode := req.FormValue("copy_mode")
if len(copyMode) == 0 {
request.CopyMode = false
} else {
value, err := strconv.ParseBool(copyMode)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
return
}
request.CopyMode = value
}
resp, err := node.queryCoord.TransferChannel(req.Context(), request)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
func (node *Proxy) CheckQueryNodeDistribution(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
return
}
source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
return
}
target, err := strconv.ParseInt(req.FormValue("target_node_id"), 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
return
}
resp, err := node.queryCoord.CheckQueryNodeDistribution(req.Context(), &querypb.CheckQueryNodeDistributionRequest{
Base: commonpbutil.NewMsgBase(),
SourceNodeID: source,
TargetNodeID: target,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
return
}
if !merr.Ok(resp) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, resp.GetReason())))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"msg": "OK"}`))
}
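Taken together, these handlers cover the rolling-upgrade flow this PR targets. One plausible per-node sequence, sketched below, is: suspend balance, suspend the node, move its segments and channels off, inspect the distribution, upgrade, then resume. This is an illustration rather than an official runbook; the node ID and listener address are placeholders:
package main

import (
	"fmt"
	"net/http"
	"net/url"
)

func main() {
	base := "http://localhost:9091" // assumed management listener address
	steps := []struct {
		path string
		form url.Values
	}{
		{"/management/querycoord/balance/suspend", nil},
		{"/management/querycoord/node/suspend", url.Values{"node_id": {"1"}}},
		{"/management/querycoord/transfer/segment", url.Values{"source_node_id": {"1"}}},
		{"/management/querycoord/transfer/channel", url.Values{"source_node_id": {"1"}}},
		{"/management/querycoord/distribution/get", url.Values{"node_id": {"1"}}},
		// ...upgrade and restart the node, then:
		{"/management/querycoord/node/resume", url.Values{"node_id": {"1"}}},
		{"/management/querycoord/balance/resume", nil},
	}
	for _, s := range steps {
		resp, err := http.PostForm(base+s.path, s.form)
		if err != nil {
			panic(err)
		}
		resp.Body.Close()
		fmt.Println(s.path, resp.StatusCode)
	}
}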

View File

@ -20,6 +20,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/cockroachdb/errors"
@ -30,19 +31,25 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type ProxyManagementSuite struct {
suite.Suite
datacoord *mocks.MockDataCoordClient
proxy *Proxy
querycoord *mocks.MockQueryCoordClient
datacoord *mocks.MockDataCoordClient
proxy *Proxy
}
func (s *ProxyManagementSuite) SetupTest() {
s.datacoord = mocks.NewMockDataCoordClient(s.T())
s.querycoord = mocks.NewMockQueryCoordClient(s.T())
s.proxy = &Proxy{
dataCoord: s.datacoord,
dataCoord: s.datacoord,
queryCoord: s.querycoord,
}
}
@ -158,6 +165,527 @@ func (s *ProxyManagementSuite) TestResumeDatacoordGC() {
})
}
func (s *ProxyManagementSuite) TestListQueryNode() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(&querypb.ListQueryNodeResponse{
Status: merr.Success(),
NodeInfos: []*querypb.NodeInfo{
{
ID: 1,
Address: "localhost",
State: "Healthy",
},
},
}, nil)
req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ListQueryNode(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"nodeInfos":[{"ID":1,"address":"localhost","state":"Healthy"}]}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ListQueryNode(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(&querypb.ListQueryNodeResponse{
Status: merr.Status(merr.ErrServiceNotReady),
}, nil)
req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ListQueryNode(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestGetQueryNodeDistribution() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(&querypb.GetQueryNodeDistributionResponse{
Status: merr.Success(),
ID: 1,
ChannelNames: []string{"channel-1"},
SealedSegmentIDs: []int64{1, 2, 3},
}, nil)
req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.GetQueryNodeDistribution(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"ID":1,"channel_names":["channel-1"],"sealed_segmentIDs":[1,2,3]}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
// test invalid request body
req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.GetQueryNodeDistribution(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test missing required param
req, err = http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader(""))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.GetQueryNodeDistribution(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test rpc return error
s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err = http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.GetQueryNodeDistribution(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.GetQueryNodeDistribution(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestSuspendQueryCoordBalance() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.SuspendQueryCoordBalance(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.SuspendQueryCoordBalance(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.SuspendQueryCoordBalance(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestResumeQueryCoordBalance() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ResumeQueryCoordBalance(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ResumeQueryCoordBalance(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ResumeQueryCoordBalance(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestSuspendQueryNode() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.SuspendQueryNode(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
// test invalid request body
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.SuspendQueryNode(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test missing required param
req, err = http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader(""))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.SuspendQueryNode(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test rpc return error
s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err = http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.SuspendQueryNode(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.SuspendQueryNode(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestResumeQueryNode() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.ResumeQueryNode(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
// test invalid request body
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.ResumeQueryNode(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test missing required param
req, err = http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader(""))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.ResumeQueryNode(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test rpc return error
s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err = http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.ResumeQueryNode(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.ResumeQueryNode(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestTransferSegment() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1&copy_mode=false"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.TransferSegment(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
// test using default params
req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.TransferSegment(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
// test invalid request body
req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.TransferSegment(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test missing required param
req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader(""))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.TransferSegment(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test rpc return error
s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.TransferSegment(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.TransferSegment(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestTransferChannel() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1&copy_mode=false"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.TransferChannel(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
// test using default params
req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.TransferChannel(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
// test invalid request body
req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.TransferChannel(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test missing required param
req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader(""))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.TransferChannel(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test rpc return error
s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.TransferChannel(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.TransferChannel(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func (s *ProxyManagementSuite) TestCheckQueryNodeDistribution() {
s.Run("normal", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.CheckQueryNodeDistribution(recorder, req)
s.Equal(http.StatusOK, recorder.Code)
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
})
s.Run("return_error", func() {
s.SetupTest()
defer s.TearDownTest()
// test invalid request body
req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, nil)
s.Require().NoError(err)
recorder := httptest.NewRecorder()
s.proxy.CheckQueryNodeDistribution(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test missing required param
req, err = http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader(""))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.CheckQueryNodeDistribution(recorder, req)
s.Equal(http.StatusBadRequest, recorder.Code)
// test rpc return error
s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
req, err = http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder = httptest.NewRecorder()
s.proxy.CheckQueryNodeDistribution(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
s.Run("return_failure", func() {
s.SetupTest()
defer s.TearDownTest()
s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
s.Require().NoError(err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
s.proxy.CheckQueryNodeDistribution(recorder, req)
s.Equal(http.StatusInternalServerError, recorder.Code)
})
}
func TestProxyManagement(t *testing.T) {
suite.Run(t, new(ProxyManagementSuite))
}

View File

@ -20,6 +20,8 @@ import (
"fmt"
"sort"
"github.com/samber/lo"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
@ -57,8 +59,8 @@ var (
)
type Balance interface {
AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan
AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan
AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan
AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan
BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)
}
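The new manualBalance flag separates operator-driven transfers from the automatic checkers: automatic callers pass false, so suspended and stopping nodes are filtered out before assignment, while the ops transfer path passes true and takes the node list as given (the score-based balancer additionally skips its benefit check). A self-contained sketch of that filtering, with a stand-in for the session.NodeManager lookup:
package main

import (
	"fmt"

	"github.com/samber/lo"
)

// nodeState stands in for session.NodeManager; illustrative only.
var nodeState = map[int64]string{1: "Normal", 2: "Suspend", 3: "Normal"}

func candidateNodes(nodes []int64, manualBalance bool) []int64 {
	if manualBalance {
		return nodes // the operator picked the nodes; honor the choice
	}
	// automatic balance never targets suspended or stopping nodes
	return lo.Filter(nodes, func(n int64, _ int) bool {
		return nodeState[n] == "Normal"
	})
}

func main() {
	fmt.Println(candidateNodes([]int64{1, 2, 3}, false)) // [1 3]
	fmt.Println(candidateNodes([]int64{1, 2, 3}, true))  // [1 2 3]
}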
@ -67,7 +69,15 @@ type RoundRobinBalancer struct {
nodeManager *session.NodeManager
}
func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
// filter out suspended and stopping nodes during assignment; skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
info := b.nodeManager.Get(node)
return info != nil && info.GetState() == session.NodeStateNormal
})
}
nodesInfo := b.getNodes(nodes)
if len(nodesInfo) == 0 {
return nil
@ -90,7 +100,14 @@ func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.
return ret
}
func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
// filter out suspended and stopping nodes during assignment; skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
info := b.nodeManager.Get(node)
return info != nil && info.GetState() == session.NodeStateNormal
})
}
nodesInfo := b.getNodes(nodes)
if len(nodesInfo) == 0 {
return nil

View File

@ -97,7 +97,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs)
plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs, false)
suite.ElementsMatch(c.expectPlans, plans)
})
}
@ -161,7 +161,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
suite.mockScheduler.EXPECT().GetNodeChannelDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs)
plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs, false)
suite.ElementsMatch(c.expectPlans, plans)
})
}

View File

@ -20,13 +20,13 @@ func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter {
return &MockBalancer_Expecter{mock: &_m.Mock}
}
// AssignChannel provides a mock function with given fields: channels, nodes
func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
ret := _m.Called(channels, nodes)
// AssignChannel provides a mock function with given fields: channels, nodes, manualBalance
func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
ret := _m.Called(channels, nodes, manualBalance)
var r0 []ChannelAssignPlan
if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64) []ChannelAssignPlan); ok {
r0 = rf(channels, nodes)
if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok {
r0 = rf(channels, nodes, manualBalance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]ChannelAssignPlan)
@ -44,13 +44,14 @@ type MockBalancer_AssignChannel_Call struct {
// AssignChannel is a helper method to define mock.On call
// - channels []*meta.DmChannel
// - nodes []int64
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call {
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)}
// - manualBalance bool
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call {
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes, manualBalance)}
}
func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64)) *MockBalancer_AssignChannel_Call {
func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*meta.DmChannel), args[1].([]int64))
run(args[0].([]*meta.DmChannel), args[1].([]int64), args[2].(bool))
})
return _c
}
@ -60,18 +61,18 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock
return _c
}
func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call {
func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call {
_c.Call.Return(run)
return _c
}
// AssignSegment provides a mock function with given fields: collectionID, segments, nodes
func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
ret := _m.Called(collectionID, segments, nodes)
// AssignSegment provides a mock function with given fields: collectionID, segments, nodes, manualBalance
func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
ret := _m.Called(collectionID, segments, nodes, manualBalance)
var r0 []SegmentAssignPlan
if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64) []SegmentAssignPlan); ok {
r0 = rf(collectionID, segments, nodes)
if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok {
r0 = rf(collectionID, segments, nodes, manualBalance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]SegmentAssignPlan)
@ -90,13 +91,14 @@ type MockBalancer_AssignSegment_Call struct {
// - collectionID int64
// - segments []*meta.Segment
// - nodes []int64
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)}
// - manualBalance bool
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes, manualBalance)}
}
func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call {
func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64))
run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64), args[3].(bool))
})
return _c
}
@ -106,7 +108,7 @@ func (_c *MockBalancer_AssignSegment_Call) Return(_a0 []SegmentAssignPlan) *Mock
return _c
}
func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
_c.Call.Return(run)
return _c
}

View File

@ -41,7 +41,15 @@ type RowCountBasedBalancer struct {
// AssignSegment assigns each segment to the node with the least global row count,
// trying to make every query node hold the same row count.
func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
// filter out suspended and stopping nodes during assignment; skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
info := b.nodeManager.Get(node)
return info != nil && info.GetState() == session.NodeStateNormal
})
}
nodeItems := b.convertToNodeItemsBySegment(nodes)
if len(nodeItems) == 0 {
return nil
@ -75,7 +83,15 @@ func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*me
// AssignChannel assigns each channel to the node with the least global channel count,
// trying to make every query node hold the same channel count.
func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
// filter out suspended and stopping nodes during assignment; skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
info := b.nodeManager.Get(node)
return info != nil && info.GetState() == session.NodeStateNormal
})
}
nodeItems := b.convertToNodeItemsByChannel(nodes)
nodeItems = lo.Shuffle(nodeItems)
if len(nodeItems) == 0 {
@ -215,7 +231,7 @@ func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, on
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
segment.GetLevel() != datapb.SegmentLevel_L0
})
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes)
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -283,7 +299,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
return nil
}
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, nodesWithLessRow)
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, nodesWithLessRow, false)
for i := range segmentPlans {
segmentPlans[i].From = segmentPlans[i].Segment.Node
segmentPlans[i].Replica = replica
@ -296,7 +312,7 @@ func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, on
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range offlineNodes {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
plans := b.AssignChannel(dmChannels, onlineNodes)
plans := b.AssignChannel(dmChannels, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -334,7 +350,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
return nil
}
channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel)
channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false)
for i := range channelPlans {
channelPlans[i].From = channelPlans[i].Channel.Node
channelPlans[i].Replica = replica

View File

@ -136,12 +136,67 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
plans := balancer.AssignSegment(0, c.assignments, c.nodes)
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, plans)
})
}
}
func (suite *RowCountBasedBalancerTestSuite) TestSuspendNode() {
cases := []struct {
name string
distributions map[int64][]*meta.Segment
assignments []*meta.Segment
nodes []int64
segmentCnts []int
states []session.State
expectPlans []SegmentAssignPlan
}{
{
name: "test suspend node",
distributions: map[int64][]*meta.Segment{
2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}},
3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}},
},
assignments: []*meta.Segment{
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}},
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}},
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}},
},
nodes: []int64{1, 2, 3, 4},
states: []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend},
segmentCnts: []int{0, 1, 1, 0},
expectPlans: []SegmentAssignPlan{},
},
}
for _, c := range cases {
suite.Run(c.name, func() {
// I haven't found a better way to do setup and teardown for subtests yet.
// If you know one, please replace this.
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: c.nodes[i],
Address: "localhost",
Hostname: "localhost",
})
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
// all nodes have been suspended, so there is no node to assign segments to
suite.ElementsMatch(c.expectPlans, plans)
})
}
}
func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
cases := []struct {
name string
@ -888,7 +943,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
NumOfGrowingRows: 50,
}
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions))
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
for _, p := range plans {
suite.Equal(int64(2), p.To)
}

View File

@ -50,7 +50,15 @@ func NewScoreBasedBalancer(scheduler task.Scheduler,
}
// AssignSegment takes a segment list and tries to assign each segment to the node with the lowest score
func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
// filter out suspended and stopping nodes during assignment; skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
info := b.nodeManager.Get(node)
return info != nil && info.GetState() == session.NodeStateNormal
})
}
// calculate each node's score
nodeItems := b.convertToNodeItems(collectionID, nodes)
if len(nodeItems) == 0 {
@ -87,7 +95,8 @@ func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.
sourceNode := nodeItemsMap[s.Node]
// if the segment's node exists, the segment comes from the balancer, so we should weigh the benefit;
// if the reassignment doesn't bring enough benefit, skip it
if sourceNode != nil && !b.hasEnoughBenefit(sourceNode, targetNode, priorityChange) {
// note: the benefit check is skipped for manual balance
if !manualBalance && sourceNode != nil && !b.hasEnoughBenefit(sourceNode, targetNode, priorityChange) {
return
}
@ -249,7 +258,7 @@ func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlin
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
segment.GetLevel() != datapb.SegmentLevel_L0
})
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes)
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -313,7 +322,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [
return nil
}
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, onlineNodes)
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, onlineNodes, false)
for i := range segmentPlans {
segmentPlans[i].From = segmentPlans[i].Segment.Node
segmentPlans[i].Replica = replica

View File

@ -232,13 +232,68 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() {
suite.balancer.nodeManager.Add(nodeInfo)
}
for i := range c.collectionIDs {
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes)
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false)
assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans[i], plans)
}
})
}
}
func (suite *ScoreBasedBalancerTestSuite) TestSuspendNode() {
cases := []struct {
name string
distributions map[int64][]*meta.Segment
assignments []*meta.Segment
nodes []int64
segmentCnts []int
states []session.State
expectPlans []SegmentAssignPlan
}{
{
name: "test suspend node",
distributions: map[int64][]*meta.Segment{
2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}},
3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}},
},
assignments: []*meta.Segment{
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}},
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}},
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}},
},
nodes: []int64{1, 2, 3, 4},
states: []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend},
segmentCnts: []int{0, 1, 1, 0},
expectPlans: []SegmentAssignPlan{},
},
}
for _, c := range cases {
suite.Run(c.name, func() {
// I haven't found a better way to do setup and teardown for subtests yet.
// If you know one, please replace this.
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
for i := range c.nodes {
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: c.nodes[i],
Address: "localhost",
Hostname: "localhost",
})
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
// all nodes have been suspended, so there is no node to assign segments to
suite.ElementsMatch(c.expectPlans, plans)
})
}
}
func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
suite.SetupSuite()
defer suite.TearDownTest()
@ -279,7 +334,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
NumOfGrowingRows: 50,
}
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions))
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
for _, p := range plans {
suite.Equal(int64(2), p.To)
}

View File

@ -222,7 +222,7 @@ func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []*
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool {
return !outboundNodes.Contain(node)
})
plans := c.balancer.AssignChannel(channels, availableNodes)
plans := c.balancer.AssignChannel(channels, availableNodes, false)
for i := range plans {
plans[i].Replica = replica
}

View File

@ -100,7 +100,7 @@ func (suite *ChannelCheckerTestSuite) setNodeAvailable(nodes ...int64) {
func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance {
balancer := balance.NewMockBalancer(suite.T())
balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64) []balance.ChannelAssignPlan {
balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan {
plans := make([]balance.ChannelAssignPlan, 0, len(channels))
for i, c := range channels {
plan := balance.ChannelAssignPlan{

View File

@ -134,11 +134,11 @@ func (suite *CheckerControllerSuite) TestBasic() {
assignSegCounter := atomic.NewInt32(0)
assingChanCounter := atomic.NewInt32(0)
suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64) []balance.SegmentAssignPlan {
suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan {
assignSegCounter.Inc()
return nil
})
suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64) []balance.ChannelAssignPlan {
suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan {
assingChanCounter.Inc()
return nil
})

View File

@ -400,7 +400,7 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments []
SegmentInfo: s,
}
})
shardPlans := c.balancer.AssignSegment(replica.CollectionID, segmentInfos, availableNodes)
shardPlans := c.balancer.AssignSegment(replica.CollectionID, segmentInfos, availableNodes, false)
for i := range shardPlans {
shardPlans[i].Replica = replica
}

View File

@ -87,7 +87,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() {
func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance {
balancer := balance.NewMockBalancer(suite.T())
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan {
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan {
plans := make([]balance.SegmentAssignPlan, 0, len(segments))
for i, s := range segments {
plan := balance.SegmentAssignPlan{

View File

@ -22,6 +22,7 @@ import (
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
@ -87,78 +88,61 @@ func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentIn
return lo.Values(infos)
}
// parseBalanceRequest parses the load balance request,
// returns the collection, replica, and segments
func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRequest, replica *meta.Replica) error {
srcNode := req.GetSourceNodeIDs()[0]
dstNodeSet := typeutil.NewUniqueSet(req.GetDstNodeIDs()...)
if dstNodeSet.Len() == 0 {
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool {
stop, err := s.nodeMgr.IsStoppingNode(node)
if err != nil {
return false
}
return !outboundNodes.Contain(node) && !stop
})
dstNodeSet.Insert(availableNodes...)
// generate balance segment tasks and submit them to the scheduler
// if sync is true, this call waits for the tasks to finish, up to the segment task timeout
// if copyMode is true, this call generates a load segment task instead of a balance segment task
func (s *Server) balanceSegments(ctx context.Context,
collectionID int64,
replica *meta.Replica,
srcNode int64,
dstNodes []int64,
segments []*meta.Segment,
sync bool,
copyMode bool,
) error {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("srcNode", srcNode))
plans := s.balancer.AssignSegment(collectionID, segments, dstNodes, true)
for i := range plans {
plans[i].From = srcNode
plans[i].Replica = replica
}
dstNodeSet.Remove(srcNode)
toBalance := typeutil.NewSet[*meta.Segment]()
// Only balance segments in targets
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode))
segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {
return s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
})
allSegments := make(map[int64]*meta.Segment)
for _, segment := range segments {
allSegments[segment.GetID()] = segment
}
if len(req.GetSealedSegmentIDs()) == 0 {
toBalance.Insert(segments...)
} else {
for _, segmentID := range req.GetSealedSegmentIDs() {
segment, ok := allSegments[segmentID]
if !ok {
return fmt.Errorf("segment %d not found in source node %d", segmentID, srcNode)
}
toBalance.Insert(segment)
}
}
log := log.With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64("srcNodeID", srcNode),
zap.Int64s("destNodeIDs", dstNodeSet.Collect()),
)
plans := s.balancer.AssignSegment(req.GetCollectionID(), toBalance.Collect(), dstNodeSet.Collect())
tasks := make([]task.Task, 0, len(plans))
for _, plan := range plans {
log.Info("manually balance segment...",
zap.Int64("destNodeID", plan.To),
zap.Int64("replica", plan.Replica.ID),
zap.String("channel", plan.Segment.InsertChannel),
zap.Int64("from", plan.From),
zap.Int64("to", plan.To),
zap.Int64("segmentID", plan.Segment.GetID()),
)
task, err := task.NewSegmentTask(ctx,
actions := make([]task.Action, 0)
loadAction := task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical)
actions = append(actions, loadAction)
if !copyMode {
// in copy mode, the release action is skipped
releaseAction := task.NewSegmentActionWithScope(plan.From, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical)
actions = append(actions, releaseAction)
}
task, err := task.NewSegmentTask(s.ctx,
Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond),
task.WrapIDSource(req.GetBase().GetMsgID()),
req.GetCollectionID(),
replica,
task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical),
task.NewSegmentActionWithScope(srcNode, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical),
utils.ManualBalance,
collectionID,
plan.Replica,
actions...,
)
if err != nil {
log.Warn("create segment task for balance failed",
zap.Int64("collection", req.GetCollectionID()),
zap.Int64("replica", replica.GetID()),
zap.Int64("replica", plan.Replica.ID),
zap.String("channel", plan.Segment.InsertChannel),
zap.Int64("from", srcNode),
zap.Int64("from", plan.From),
zap.Int64("to", plan.To),
zap.Int64("segmentID", plan.Segment.GetID()),
zap.Error(err),
)
continue
}
task.SetReason("manual balance")
err = s.taskScheduler.Add(task)
if err != nil {
task.Cancel(err)
@ -166,7 +150,92 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
}
tasks = append(tasks, task)
}
return task.Wait(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), tasks...)
if sync {
err := task.Wait(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), tasks...)
if err != nil {
msg := "failed to wait all balance task finished"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
}
return nil
}
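Reviewer note: a minimal sketch (not part of this patch) of how the two flags are meant to be driven; collectionID, replica, srcNode, dstNode, and segments below are placeholders:
// copyMode=true keeps the segment on the source by skipping the release
// action; sync=false returns as soon as the tasks are queued.
err := s.balanceSegments(ctx, collectionID, replica, srcNode,
[]int64{dstNode}, segments, false /* sync */, true /* copyMode */)
if err != nil {
log.Warn("manual segment balance failed", zap.Error(err))
}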
// generate balance channel tasks and submit them to the scheduler
// if sync is true, this call waits for the tasks to finish, up to the channel task timeout
// if copyMode is true, this call generates a load channel task instead of a balance channel task
func (s *Server) balanceChannels(ctx context.Context,
collectionID int64,
replica *meta.Replica,
srcNode int64,
dstNodes []int64,
channels []*meta.DmChannel,
sync bool,
copyMode bool,
) error {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID))
plans := s.balancer.AssignChannel(channels, dstNodes, true)
for i := range plans {
plans[i].From = srcNode
plans[i].Replica = replica
}
tasks := make([]task.Task, 0, len(plans))
for _, plan := range plans {
log.Info("manually balance channel...",
zap.Int64("replica", plan.Replica.ID),
zap.String("channel", plan.Channel.GetChannelName()),
zap.Int64("from", plan.From),
zap.Int64("to", plan.To),
)
actions := make([]task.Action, 0)
loadAction := task.NewChannelAction(plan.To, task.ActionTypeGrow, plan.Channel.GetChannelName())
actions = append(actions, loadAction)
if !copyMode {
// in copy mode, the release action is skipped
releaseAction := task.NewChannelAction(plan.From, task.ActionTypeReduce, plan.Channel.GetChannelName())
actions = append(actions, releaseAction)
}
task, err := task.NewChannelTask(s.ctx,
Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond),
utils.ManualBalance,
collectionID,
plan.Replica,
actions...,
)
if err != nil {
log.Warn("create channel task for balance failed",
zap.Int64("replica", plan.Replica.ID),
zap.String("channel", plan.Channel.GetChannelName()),
zap.Int64("from", plan.From),
zap.Int64("to", plan.To),
zap.Error(err),
)
continue
}
task.SetReason("manual balance")
err = s.taskScheduler.Add(task)
if err != nil {
task.Cancel(err)
return err
}
tasks = append(tasks, task)
}
if sync {
err := task.Wait(ctx, Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), tasks...)
if err != nil {
msg := "failed to wait all balance task finished"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
}
return nil
}
// TODO(dragondriver): add more detail metrics

View File

@ -232,7 +232,7 @@ func (m *ReplicaManager) GetByCollection(collectionID typeutil.UniqueID) []*Repl
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
replicas := make([]*Replica, 0, 3)
replicas := make([]*Replica, 0)
for _, replica := range m.replicas {
if replica.CollectionID == collectionID {
replicas = append(replicas, replica)
@ -255,6 +255,20 @@ func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.Un
return nil
}
func (m *ReplicaManager) GetByNode(nodeID typeutil.UniqueID) []*Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
replicas := make([]*Replica, 0)
for _, replica := range m.replicas {
if replica.nodes.Contain(nodeID) {
replicas = append(replicas, replica)
}
}
return replicas
}
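Reviewer note: a small illustrative use of the new accessor (nodeID is a placeholder); it backs the transfer APIs, which must act on every replica that still places data on the source node:
// Collect the replicas touching a node before draining it.
for _, replica := range m.GetByNode(nodeID) {
// drive one transfer pass per replica here
_ = replica.GetCollectionID()
}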
func (m *ReplicaManager) GetByCollectionAndRG(collectionID int64, rgName string) []*Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()

View File

@ -102,6 +102,7 @@ func (suite *ReplicaManagerSuite) TestGet() {
for _, replica := range replicas {
suite.Equal(collection, replica.GetCollectionID())
suite.Equal(replica, mgr.Get(replica.GetID()))
suite.Equal(len(replica.Replica.GetNodes()), replica.Len())
suite.Equal(replica.Replica.GetNodes(), replica.GetNodes())
replicaNodes[replica.GetID()] = replica.Replica.GetNodes()
nodes = append(nodes, replica.Replica.Nodes...)
@ -117,6 +118,24 @@ func (suite *ReplicaManagerSuite) TestGet() {
}
}
func (suite *ReplicaManagerSuite) TestGetByNode() {
mgr := suite.mgr
randomNodeID := int64(11111)
testReplica1, err := mgr.spawn(3002, DefaultResourceGroupName)
suite.NoError(err)
testReplica1.AddNode(randomNodeID)
testReplica2, err := mgr.spawn(3002, DefaultResourceGroupName)
suite.NoError(err)
testReplica2.AddNode(randomNodeID)
mgr.Put(testReplica1, testReplica2)
replicas := mgr.GetByNode(randomNodeID)
suite.Len(replicas, 2)
}
func (suite *ReplicaManagerSuite) TestRecover() {
mgr := suite.mgr

View File

@ -0,0 +1,888 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
"github.com/milvus-io/milvus/internal/querycoordv2/dist"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type OpsServiceSuite struct {
suite.Suite
// Dependencies
kv kv.MetaKv
store metastore.QueryCoordCatalog
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
broker *meta.MockBroker
targetObserver *observers.TargetObserver
cluster *session.MockCluster
nodeMgr *session.NodeManager
jobScheduler *job.Scheduler
taskScheduler *task.MockScheduler
balancer balance.Balance
distMgr *meta.DistributionManager
distController *dist.MockController
checkerController *checkers.CheckerController
// Test object
server *Server
}
func (suite *OpsServiceSuite) SetupSuite() {
paramtable.Init()
}
func (suite *OpsServiceSuite) SetupTest() {
config := params.GenerateEtcdConfig()
cli, err := etcd.GetEtcdClient(
config.UseEmbedEtcd.GetAsBool(),
config.EtcdUseSSL.GetAsBool(),
config.Endpoints.GetAsStrings(),
config.EtcdTLSCert.GetValue(),
config.EtcdTLSKey.GetValue(),
config.EtcdTLSCACert.GetValue(),
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.store = querycoord.NewCatalog(suite.kv)
suite.dist = meta.NewDistributionManager()
suite.nodeMgr = session.NewNodeManager()
suite.meta = meta.NewMeta(params.RandomIncrementIDAllocator(), suite.store, suite.nodeMgr)
suite.broker = meta.NewMockBroker(suite.T())
suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
// create the mock cluster before wiring it into the target observer
suite.cluster = session.NewMockCluster(suite.T())
suite.targetObserver = observers.NewTargetObserver(
suite.meta,
suite.targetMgr,
suite.dist,
suite.broker,
suite.cluster,
)
suite.jobScheduler = job.NewScheduler()
suite.taskScheduler = task.NewMockScheduler(suite.T())
suite.jobScheduler.Start()
suite.balancer = balance.NewScoreBasedBalancer(
suite.taskScheduler,
suite.nodeMgr,
suite.dist,
suite.meta,
suite.targetMgr,
)
meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()
suite.distMgr = meta.NewDistributionManager()
suite.distController = dist.NewMockController(suite.T())
suite.checkerController = checkers.NewCheckerController(suite.meta, suite.distMgr,
suite.targetMgr, suite.balancer, suite.nodeMgr, suite.taskScheduler, suite.broker)
suite.server = &Server{
kv: suite.kv,
store: suite.store,
session: sessionutil.NewSessionWithEtcd(context.Background(), Params.EtcdCfg.MetaRootPath.GetValue(), cli),
metricsCacheManager: metricsinfo.NewMetricsCacheManager(),
dist: suite.dist,
meta: suite.meta,
targetMgr: suite.targetMgr,
broker: suite.broker,
targetObserver: suite.targetObserver,
nodeMgr: suite.nodeMgr,
cluster: suite.cluster,
jobScheduler: suite.jobScheduler,
taskScheduler: suite.taskScheduler,
balancer: suite.balancer,
distController: suite.distController,
ctx: context.Background(),
checkerController: suite.checkerController,
}
suite.server.collectionObserver = observers.NewCollectionObserver(
suite.server.dist,
suite.server.meta,
suite.server.targetMgr,
suite.targetObserver,
&checkers.CheckerController{},
)
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
}
func (suite *OpsServiceSuite) TestActiveCheckers() {
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := suite.server.ListCheckers(ctx, &querypb.ListCheckersRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp.Status))
resp1, err := suite.server.DeactivateChecker(ctx, &querypb.DeactivateCheckerRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp1))
resp2, err := suite.server.ActivateChecker(ctx, &querypb.ActivateCheckerRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp2))
// test active success
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err = suite.server.ListCheckers(ctx, &querypb.ListCheckersRequest{})
suite.NoError(err)
suite.True(merr.Ok(resp.Status))
suite.Len(resp.GetCheckerInfos(), 5)
resp4, err := suite.server.DeactivateChecker(ctx, &querypb.DeactivateCheckerRequest{
CheckerID: int32(utils.ChannelChecker),
})
suite.NoError(err)
suite.True(merr.Ok(resp4))
suite.False(suite.checkerController.IsActive(utils.ChannelChecker))
resp5, err := suite.server.ActivateChecker(ctx, &querypb.ActivateCheckerRequest{
CheckerID: int32(utils.ChannelChecker),
})
suite.NoError(err)
suite.True(merr.Ok(resp5))
suite.True(suite.checkerController.IsActive(utils.ChannelChecker))
}
func (suite *OpsServiceSuite) TestListQueryNode() {
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := suite.server.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
suite.NoError(err)
suite.Equal(0, len(resp.GetNodeInfos()))
suite.False(merr.Ok(resp.Status))
// test server healthy
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 111,
Address: "localhost",
Hostname: "localhost",
}))
resp, err = suite.server.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
suite.NoError(err)
suite.Equal(1, len(resp.GetNodeInfos()))
}
func (suite *OpsServiceSuite) TestGetQueryNodeDistribution() {
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp.Status))
// test node not found
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err = suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: 1,
})
suite.NoError(err)
suite.False(merr.Ok(resp.Status))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
// test success
channels := []*meta.DmChannel{
{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: 1,
ChannelName: "channel1",
},
Node: 1,
},
{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: 1,
ChannelName: "channel2",
},
Node: 1,
},
}
segments := []*meta.Segment{
{
SegmentInfo: &datapb.SegmentInfo{
ID: 1,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "channel1",
},
Node: 1,
},
{
SegmentInfo: &datapb.SegmentInfo{
ID: 2,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "channel2",
},
Node: 1,
},
}
suite.dist.ChannelDistManager.Update(1, channels...)
suite.dist.SegmentDistManager.Update(1, segments...)
resp, err = suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: 1,
})
suite.NoError(err)
suite.True(merr.Ok(resp.Status))
suite.Equal(2, len(resp.GetChannelNames()))
suite.Equal(2, len(resp.GetSealedSegmentIDs()))
}
func (suite *OpsServiceSuite) TestCheckQueryNodeDistribution() {
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
// test node not found
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
TargetNodeID: 2,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
SourceNodeID: 1,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
// test success
channels := []*meta.DmChannel{
{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: 1,
ChannelName: "channel1",
},
Node: 1,
},
{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: 1,
ChannelName: "channel2",
},
Node: 1,
},
}
segments := []*meta.Segment{
{
SegmentInfo: &datapb.SegmentInfo{
ID: 1,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "channel1",
},
Node: 1,
},
{
SegmentInfo: &datapb.SegmentInfo{
ID: 2,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "channel2",
},
Node: 1,
},
}
suite.dist.ChannelDistManager.Update(1, channels...)
suite.dist.SegmentDistManager.Update(1, segments...)
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
SourceNodeID: 1,
TargetNodeID: 2,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.dist.ChannelDistManager.Update(2, channels...)
suite.dist.SegmentDistManager.Update(2, segments...)
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
SourceNodeID: 1,
TargetNodeID: 1,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
}
func (suite *OpsServiceSuite) TestSuspendAndResumeBalance() {
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := suite.server.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
resp, err = suite.server.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
// test suspend success
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err = suite.server.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.False(suite.checkerController.IsActive(utils.BalanceChecker))
resp, err = suite.server.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.True(suite.checkerController.IsActive(utils.BalanceChecker))
}
func (suite *OpsServiceSuite) TestSuspendAndResumeNode() {
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
// test node not found
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err = suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{
NodeID: 1,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{
NodeID: 1,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
// test success
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err = suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{
NodeID: 1,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
node := suite.nodeMgr.Get(1)
suite.Equal(session.NodeStateSuspend, node.GetState())
resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{
NodeID: 1,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
node = suite.nodeMgr.Get(1)
suite.Equal(session.NodeStateNormal, node.GetState())
}
func (suite *OpsServiceSuite) TestTransferSegment() {
ctx := context.Background()
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err := suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
// test source node not healthy
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: 1,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
collectionID := int64(1)
partitionID := int64(1)
replicaID := int64(1)
nodes := []int64{1, 2, 3, 4}
replica := utils.CreateTestReplica(replicaID, collectionID, nodes)
suite.meta.ReplicaManager.Put(replica)
collection := utils.CreateTestCollection(collectionID, 1)
partition := utils.CreateTestPartition(partitionID, collectionID)
suite.meta.PutCollection(collection, partition)
segmentIDs := []int64{1, 2, 3, 4}
channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"}
// test target node not healthy
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 2,
Address: "localhost",
Hostname: "localhost",
}))
// test segment does not exist on the node
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
SegmentID: segmentIDs[0],
})
suite.NoError(err)
suite.False(merr.Ok(resp))
segments := []*datapb.SegmentInfo{
{
ID: segmentIDs[0],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[0],
NumOfRows: 1,
},
{
ID: segmentIDs[1],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[1],
NumOfRows: 1,
},
{
ID: segmentIDs[2],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[2],
NumOfRows: 1,
},
{
ID: segmentIDs[3],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[3],
NumOfRows: 1,
},
}
channels := []*datapb.VchannelInfo{
{
CollectionID: collectionID,
ChannelName: channelNames[0],
},
{
CollectionID: collectionID,
ChannelName: channelNames[1],
},
{
CollectionID: collectionID,
ChannelName: channelNames[2],
},
{
CollectionID: collectionID,
ChannelName: channelNames[3],
},
}
segmentInfos := lo.Map(segments, func(segment *datapb.SegmentInfo, _ int) *meta.Segment {
return &meta.Segment{
SegmentInfo: segment,
Node: nodes[0],
}
})
channelInfos := lo.Map(channels, func(channel *datapb.VchannelInfo, _ int) *meta.DmChannel {
return &meta.DmChannel{
VchannelInfo: channel,
Node: nodes[0],
}
})
suite.dist.SegmentDistManager.Update(1, segmentInfos[0])
// test segment does not exist in the current target; expect no task assigned and success
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
SegmentID: segmentIDs[0],
})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil)
suite.targetMgr.UpdateCollectionNextTarget(1)
suite.targetMgr.UpdateCollectionCurrentTarget(1)
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
suite.dist.ChannelDistManager.Update(1, channelInfos...)
for _, node := range nodes {
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: node,
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node)
}
// test transfer segment success; expect 1 balance segment task generated
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 2)
suite.Equal(actions[0].Node(), int64(2))
return nil
})
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
SegmentID: segmentIDs[0],
})
suite.NoError(err)
suite.True(merr.Ok(resp))
// test copy mode; expect 1 load segment task generated
suite.taskScheduler.ExpectedCalls = nil
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 1)
suite.Equal(actions[0].Node(), int64(2))
return nil
})
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
SegmentID: segmentIDs[0],
CopyMode: true,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
// test transfer all segments; expect 4 balance segment tasks generated
suite.taskScheduler.ExpectedCalls = nil
counter := atomic.NewInt64(0)
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 2)
suite.Equal(actions[0].Node(), int64(2))
counter.Inc()
return nil
})
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
TransferAll: true,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.Equal(counter.Load(), int64(4))
// test transfer all segments to all nodes; expect 4 balance segment tasks generated
suite.taskScheduler.ExpectedCalls = nil
counter = atomic.NewInt64(0)
nodeSet := typeutil.NewUniqueSet()
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 2)
nodeSet.Insert(actions[0].Node())
counter.Inc()
return nil
})
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: nodes[0],
TransferAll: true,
ToAllNodes: true,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.Equal(counter.Load(), int64(4))
suite.Len(nodeSet.Collect(), 3)
}
func (suite *OpsServiceSuite) TestTransferChannel() {
ctx := context.Background()
// test server unhealthy
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err := suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
// test source node not healthy
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: 1,
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
collectionID := int64(1)
partitionID := int64(1)
replicaID := int64(1)
nodes := []int64{1, 2, 3, 4}
replica := utils.CreateTestReplica(replicaID, collectionID, nodes)
suite.meta.ReplicaManager.Put(replica)
collection := utils.CreateTestCollection(collectionID, 1)
partition := utils.CreateTestPartition(partitionID, collectionID)
suite.meta.PutCollection(collection, partition)
segmentIDs := []int64{1, 2, 3, 4}
channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"}
// test target node not healthy
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 2,
Address: "localhost",
Hostname: "localhost",
}))
segments := []*datapb.SegmentInfo{
{
ID: segmentIDs[0],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[0],
NumOfRows: 1,
},
{
ID: segmentIDs[1],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[1],
NumOfRows: 1,
},
{
ID: segmentIDs[2],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[2],
NumOfRows: 1,
},
{
ID: segmentIDs[3],
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channelNames[3],
NumOfRows: 1,
},
}
channels := []*datapb.VchannelInfo{
{
CollectionID: collectionID,
ChannelName: channelNames[0],
},
{
CollectionID: collectionID,
ChannelName: channelNames[1],
},
{
CollectionID: collectionID,
ChannelName: channelNames[2],
},
{
CollectionID: collectionID,
ChannelName: channelNames[3],
},
}
segmentInfos := lo.Map(segments, func(segment *datapb.SegmentInfo, _ int) *meta.Segment {
return &meta.Segment{
SegmentInfo: segment,
Node: nodes[0],
}
})
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
channelInfos := lo.Map(channels, func(channel *datapb.VchannelInfo, _ int) *meta.DmChannel {
return &meta.DmChannel{
VchannelInfo: channel,
Node: nodes[0],
}
})
// test channel does not exist on the node
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
ChannelName: channelNames[0],
})
suite.NoError(err)
suite.False(merr.Ok(resp))
suite.dist.ChannelDistManager.Update(1, channelInfos[0])
// test channel does not exist in the current target; expect no task assigned and success
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
ChannelName: channelNames[0],
})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil)
suite.targetMgr.UpdateCollectionNextTarget(1)
suite.targetMgr.UpdateCollectionCurrentTarget(1)
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
suite.dist.ChannelDistManager.Update(1, channelInfos...)
for _, node := range nodes {
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: node,
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node)
}
// test transfer channel success; expect 1 balance channel task generated
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 2)
suite.Equal(actions[0].Node(), int64(2))
return nil
})
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
ChannelName: channelNames[0],
})
suite.NoError(err)
suite.True(merr.Ok(resp))
// test copy mode; expect 1 load channel task generated
suite.taskScheduler.ExpectedCalls = nil
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 1)
suite.Equal(actions[0].Node(), int64(2))
return nil
})
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
ChannelName: channelNames[0],
CopyMode: true,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
// test transfer all channels; expect 4 balance channel tasks generated
suite.taskScheduler.ExpectedCalls = nil
counter := atomic.NewInt64(0)
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 2)
suite.Equal(actions[0].Node(), int64(2))
counter.Inc()
return nil
})
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TargetNodeID: nodes[1],
TransferAll: true,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.Equal(counter.Load(), int64(4))
// test transfer all channels to all nodes; expect 4 balance channel tasks generated
suite.taskScheduler.ExpectedCalls = nil
counter = atomic.NewInt64(0)
nodeSet := typeutil.NewUniqueSet()
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
actions := t.Actions()
suite.Equal(len(actions), 2)
nodeSet.Insert(actions[0].Node())
counter.Inc()
return nil
})
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: nodes[0],
TransferAll: true,
ToAllNodes: true,
})
suite.NoError(err)
suite.True(merr.Ok(resp))
suite.Equal(counter.Load(), int64(4))
suite.Len(nodeSet.Collect(), 3)
}
func TestOpsService(t *testing.T) {
suite.Run(t, new(OpsServiceSuite))
}

View File

@ -19,10 +19,14 @@ package querycoordv2
import (
"context"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -93,3 +97,368 @@ func (s *Server) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC
}
return merr.Success(), nil
}
// returns the list of all available nodes; for each node, returns its (nodeID, ip_address)
func (s *Server) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) {
log := log.Ctx(ctx)
log.Info("ListQueryNode request received")
errMsg := "failed to list querynode state"
if err := merr.CheckHealthy(s.State()); err != nil {
log.Warn(errMsg, zap.Error(err))
return &querypb.ListQueryNodeResponse{
Status: merr.Status(errors.Wrap(err, errMsg)),
}, nil
}
nodes := lo.Map(s.nodeMgr.GetAll(), func(nodeInfo *session.NodeInfo, _ int) *querypb.NodeInfo {
return &querypb.NodeInfo{
ID: nodeInfo.ID(),
Address: nodeInfo.Addr(),
State: nodeInfo.GetState().String(),
}
})
return &querypb.ListQueryNodeResponse{
Status: merr.Success(),
NodeInfos: nodes,
}, nil
}
// returns a query node's data distribution; for the given nodeID, returns its (channel_name_list, sealed_segment_list)
func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) {
log := log.Ctx(ctx).With(zap.Int64("nodeID", req.GetNodeID()))
log.Info("GetQueryNodeDistribution request received")
errMsg := "failed to get query node distribution"
if err := merr.CheckHealthy(s.State()); err != nil {
log.Warn(errMsg, zap.Error(err))
return &querypb.GetQueryNodeDistributionResponse{
Status: merr.Status(errors.Wrap(err, errMsg)),
}, nil
}
if s.nodeMgr.Get(req.GetNodeID()) == nil {
err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg)
log.Warn(errMsg, zap.Error(err))
return &querypb.GetQueryNodeDistributionResponse{
Status: merr.Status(err),
}, nil
}
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetNodeID()))
channels := s.dist.ChannelDistManager.GetByNode(req.GetNodeID())
return &querypb.GetQueryNodeDistributionResponse{
Status: merr.Success(),
ChannelNames: lo.Map(channels, func(c *meta.DmChannel, _ int) string { return c.GetChannelName() }),
SealedSegmentIDs: lo.Map(segments, func(s *meta.Segment, _ int) int64 { return s.GetID() }),
}, nil
}
// suspends background balance for all query nodes, including stopping balance and auto balance
func (s *Server) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("SuspendBalance request received")
errMsg := "failed to suspend balance for all querynode"
if err := merr.CheckHealthy(s.State()); err != nil {
return merr.Status(err), nil
}
err := s.checkerController.Deactivate(utils.BalanceChecker)
if err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
return merr.Success(), nil
}
// resumes background balance for all query nodes, including stopping balance and auto balance
func (s *Server) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("ResumeBalance request received")
errMsg := "failed to resume balance for all querynode"
if err := merr.CheckHealthy(s.State()); err != nil {
return merr.Status(err), nil
}
err := s.checkerController.Activate(utils.BalanceChecker)
if err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
return merr.Success(), nil
}
// suspends a node from resource operations; for the given node, suspends load_segment/sub_channel operations
func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("SuspendNode request received", zap.Int64("nodeID", req.GetNodeID()))
errMsg := "failed to suspend query node"
if err := merr.CheckHealthy(s.State()); err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
if s.nodeMgr.Get(req.GetNodeID()) == nil {
err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg)
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
err := s.nodeMgr.Suspend(req.GetNodeID())
if err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
return merr.Success(), nil
}
// resumes a node for resource operations; for the given node, resumes load_segment/sub_channel operations
func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("ResumeNode request received", zap.Int64("nodeID", req.GetNodeID()))
errMsg := "failed to resume query node"
if err := merr.CheckHealthy(s.State()); err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(errors.Wrap(err, errMsg)), nil
}
if s.nodeMgr.Get(req.GetNodeID()) == nil {
err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg)
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
err := s.nodeMgr.Resume(req.GetNodeID())
if err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(errors.Wrap(err, errMsg)), nil
}
return merr.Success(), nil
}
// transfers a segment from the source node to the target node.
// if no segment_id is specified, defaults to transferring all segments on the source node.
// if no target_nodeId is specified, defaults to moving segments to all other nodes
func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("TransferSegment request received",
zap.Int64("source", req.GetSourceNodeID()),
zap.Int64("dest", req.GetTargetNodeID()),
zap.Int64("segment", req.GetSegmentID()))
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to load balance"
log.Warn(msg, zap.Error(err))
return merr.Status(errors.Wrap(err, msg)), nil
}
// check whether srcNode is healthy
srcNode := req.GetSourceNodeID()
if err := s.isStoppingNode(srcNode); err != nil {
err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid")
return merr.Status(err), nil
}
replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID())
for _, replica := range replicas {
// when no dst node is specified, default to all other nodes in the same replica
dstNodeSet := typeutil.NewUniqueSet()
if req.GetToAllNodes() {
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) })
dstNodeSet.Insert(availableNodes...)
} else {
// check whether dstNode is healthy
if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil {
err := merr.WrapErrNodeNotAvailable(srcNode, "the target node is invalid")
return merr.Status(err), nil
}
dstNodeSet.Insert(req.GetTargetNodeID())
}
dstNodeSet.Remove(srcNode)
// check sealed segment list
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode))
toBalance := typeutil.NewSet[*meta.Segment]()
if req.GetTransferAll() {
toBalance.Insert(segments...)
} else {
// check whether the sealed segment exists
segment, ok := lo.Find(segments, func(s *meta.Segment) bool { return s.GetID() == req.GetSegmentID() })
if !ok {
err := merr.WrapErrSegmentNotFound(req.GetSegmentID(), "segment not found in source node")
return merr.Status(err), nil
}
existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
if !existInTarget {
log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", req.GetSegmentID()))
} else {
toBalance.Insert(segment)
}
}
err := s.balanceSegments(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), false, req.GetCopyMode())
if err != nil {
msg := "failed to balance segments"
log.Warn(msg, zap.Error(err))
return merr.Status(errors.Wrap(err, msg)), nil
}
}
return merr.Success(), nil
}
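Reviewer note: an illustrative request (field values are placeholders) combining the new flags, mirroring what the ops-service tests exercise:
// Drain node 1 for an upgrade: move every sealed segment to all other
// replica members while keeping a copy on the source.
status, err := s.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: 1,
TransferAll:  true,
ToAllNodes:   true,
CopyMode:     true,
})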
// transfers a channel from the source node to the target node.
// if no channel_name is specified, defaults to transferring all channels on the source node.
// if no target_nodeId is specified, defaults to moving channels to all other nodes
func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("TransferChannel request received",
zap.Int64("source", req.GetSourceNodeID()),
zap.Int64("dest", req.GetTargetNodeID()),
zap.String("channel", req.GetChannelName()))
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to load balance"
log.Warn(msg, zap.Error(err))
return merr.Status(errors.Wrap(err, msg)), nil
}
// check whether srcNode is healthy
srcNode := req.GetSourceNodeID()
if err := s.isStoppingNode(srcNode); err != nil {
err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid")
return merr.Status(err), nil
}
replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID())
for _, replica := range replicas {
// when no dst node is specified, default to all other nodes in the same replica
dstNodeSet := typeutil.NewUniqueSet()
if req.GetToAllNodes() {
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) })
dstNodeSet.Insert(availableNodes...)
} else {
// check whether dstNode is healthy
if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil {
err := merr.WrapErrNodeNotAvailable(req.GetTargetNodeID(), "the target node is invalid")
return merr.Status(err), nil
}
dstNodeSet.Insert(req.GetTargetNodeID())
}
dstNodeSet.Remove(srcNode)
// check the channel list
channels := s.dist.ChannelDistManager.GetByCollectionAndNode(replica.CollectionID, srcNode)
toBalance := typeutil.NewSet[*meta.DmChannel]()
if req.GetTransferAll() {
toBalance.Insert(channels...)
} else {
// check whether the channel exists
channel, ok := lo.Find(channels, func(ch *meta.DmChannel) bool { return ch.GetChannelName() == req.GetChannelName() })
if !ok {
err := merr.WrapErrChannelNotFound(req.GetChannelName(), "channel not found in source node")
return merr.Status(err), nil
}
existInTarget := s.targetMgr.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil
if !existInTarget {
log.Info("channel doesn't exist in current target, skip it", zap.String("channelName", channel.GetChannelName()))
} else {
toBalance.Insert(channel)
}
}
err := s.balanceChannels(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), false, req.GetCopyMode())
if err != nil {
msg := "failed to balance channels"
log.Warn(msg, zap.Error(err))
return merr.Status(errors.Wrap(err, msg)), nil
}
}
return merr.Success(), nil
}
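Reviewer note: the channel variant takes the same shape; a single-channel move looks like this (values are placeholders):
// Hand one DML channel over from node 1 to node 2.
status, err := s.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: 1,
TargetNodeID: 2,
ChannelName:  "channel-1",
})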
func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
log.Info("CheckQueryNodeDistribution request received",
zap.Int64("source", req.GetSourceNodeID()),
zap.Int64("dest", req.GetTargetNodeID()))
errMsg := "failed to check query node distribution"
if err := merr.CheckHealthy(s.State()); err != nil {
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
sourceNode := s.nodeMgr.Get(req.GetSourceNodeID())
if sourceNode == nil {
err := merr.WrapErrNodeNotFound(req.GetSourceNodeID(), "source node not found")
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
targetNode := s.nodeMgr.Get(req.GetTargetNodeID())
if targetNode == nil {
err := merr.WrapErrNodeNotFound(req.GetTargetNodeID(), "target node not found")
log.Warn(errMsg, zap.Error(err))
return merr.Status(err), nil
}
// check channel list
channelOnSrc := s.dist.ChannelDistManager.GetByNode(req.GetSourceNodeID())
channelOnDst := s.dist.ChannelDistManager.GetByNode(req.GetTargetNodeID())
channelDstMap := lo.SliceToMap(channelOnDst, func(ch *meta.DmChannel) (string, *meta.DmChannel) {
return ch.GetChannelName(), ch
})
for _, ch := range channelOnSrc {
if _, ok := channelDstMap[ch.GetChannelName()]; !ok {
return merr.Status(merr.WrapErrChannelLack(ch.GetChannelName())), nil
}
}
channelSrcMap := lo.SliceToMap(channelOnSrc, func(ch *meta.DmChannel) (string, *meta.DmChannel) {
return ch.GetChannelName(), ch
})
for _, ch := range channelOnDst {
if _, ok := channelSrcMap[ch.GetChannelName()]; !ok {
return merr.Status(merr.WrapErrChannelLack(ch.GetChannelName())), nil
}
}
// check segment list
segmentOnSrc := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetSourceNodeID()))
segmentOnDst := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetTargetNodeID()))
segmentDstMap := lo.SliceToMap(segmentOnDst, func(s *meta.Segment) (int64, *meta.Segment) {
return s.GetID(), s
})
for _, s := range segmentOnSrc {
if _, ok := segmentDstMap[s.GetID()]; !ok {
return merr.Status(merr.WrapErrSegmentLack(s.GetID())), nil
}
}
segmentSrcMap := lo.SliceToMap(segmentOnSrc, func(s *meta.Segment) (int64, *meta.Segment) {
return s.GetID(), s
})
for _, s := range segmentOnDst {
if _, ok := segmentSrcMap[s.GetID()]; !ok {
return merr.Status(merr.WrapErrSegmentLack(s.GetID())), nil
}
}
return merr.Success(), nil
}
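Reviewer note: taken together, the endpoints in this file support a drain-and-verify rolling-upgrade flow; a sketch follows (the ordering is a suggestion, and oldNode/newNode are placeholders):
// 1. stop background balance so the checker does not undo manual moves
_, _ = s.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
// 2. fence the node that is about to be upgraded
_, _ = s.SuspendNode(ctx, &querypb.SuspendNodeRequest{NodeID: oldNode})
// 3. copy its data onto the replacement node
_, _ = s.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: oldNode, TargetNodeID: newNode, TransferAll: true, CopyMode: true,
})
// 4. verify the replacement serves the same channels and segments
_, _ = s.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
SourceNodeID: oldNode, TargetNodeID: newNode,
})
// 5. restore automatic balancing once the old node is gone
_, _ = s.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})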

View File

@ -682,24 +682,67 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
return merr.Status(errors.Wrap(err,
fmt.Sprintf("can't balance, because the source node[%d] is invalid", srcNode))), nil
}
for _, dstNode := range req.GetDstNodeIDs() {
if !replica.Contains(dstNode) {
err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica")
log.Warn("failed to balance to the destination node", zap.Error(err))
return merr.Status(err), nil
// when no dst node is specified, default to all other nodes in the same replica
dstNodeSet := typeutil.NewUniqueSet()
if len(req.GetDstNodeIDs()) == 0 {
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) })
dstNodeSet.Insert(availableNodes...)
} else {
for _, dstNode := range req.GetDstNodeIDs() {
if !replica.Contains(dstNode) {
err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica")
log.Warn("failed to balance to the destination node", zap.Error(err))
return merr.Status(err), nil
}
dstNodeSet.Insert(dstNode)
}
}
// check whether dstNode is healthy
for dstNode := range dstNodeSet {
if err := s.isStoppingNode(dstNode); err != nil {
return merr.Status(errors.Wrap(err,
fmt.Sprintf("can't balance, because the destination node[%d] is invalid", dstNode))), nil
}
}
err := s.balanceSegments(ctx, req, replica)
// check sealed segment list
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(req.GetCollectionID()), meta.WithNodeID(srcNode))
segmentsMap := lo.SliceToMap(segments, func(s *meta.Segment) (int64, *meta.Segment) {
return s.GetID(), s
})
toBalance := typeutil.NewSet[*meta.Segment]()
if len(req.GetSealedSegmentIDs()) == 0 {
toBalance.Insert(segments...)
} else {
// check whether the sealed segments exist
for _, segmentID := range req.GetSealedSegmentIDs() {
segment, ok := segmentsMap[segmentID]
if !ok {
err := merr.WrapErrSegmentNotFound(segmentID, "segment not found in source node")
return merr.Status(err), nil
}
// Only balance segments in targets
existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
if !existInTarget {
log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", segmentID))
continue
}
toBalance.Insert(segment)
}
}
err := s.balanceSegments(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), true, false)
if err != nil {
msg := "failed to balance segments"
log.Warn(msg, zap.Error(err))
return merr.Status(errors.Wrap(err, msg)), nil
}
return merr.Success(), nil
}

View File

@ -1174,6 +1174,51 @@ func (suite *ServiceSuite) TestLoadBalance() {
suite.Equal(resp.GetCode(), merr.Code(merr.ErrServiceNotReady))
}
func (suite *ServiceSuite) TestLoadBalanceWithNoDstNode() {
suite.loadAll()
ctx := context.Background()
server := suite.server
// Test get balance first segment
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
nodes := replicas[0].GetNodes()
srcNode := nodes[0]
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateSegmentDist(collection, srcNode)
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
CollectionID: collection,
SourceNodeIDs: []int64{srcNode},
SealedSegmentIDs: segments,
}
suite.taskScheduler.ExpectedCalls = make([]*mock.Call, 0)
suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(task task.Task) {
actions := task.Actions()
suite.Len(actions, 2)
growAction, reduceAction := actions[0], actions[1]
suite.Contains(nodes, growAction.Node())
suite.Equal(srcNode, reduceAction.Node())
task.Cancel(nil)
}).Return(nil)
resp, err := server.LoadBalance(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.taskScheduler.AssertExpectations(suite.T())
}
// Test when server is not healthy
server.UpdateStateCode(commonpb.StateCode_Initializing)
req := &querypb.LoadBalanceRequest{
CollectionID: suite.collections[0],
SourceNodeIDs: []int64{1},
DstNodeIDs: []int64{100 + 1},
}
resp, err := server.LoadBalance(ctx, req)
suite.NoError(err)
suite.Equal(resp.GetCode(), merr.Code(merr.ErrServiceNotReady))
}
func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
suite.loadAll()
ctx := context.Background()

View File

@ -23,8 +23,11 @@ import (
"github.com/blang/semver/v4"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type Manager interface {
@ -33,6 +36,9 @@ type Manager interface {
Remove(nodeID int64)
Get(nodeID int64) *NodeInfo
GetAll() []*NodeInfo
Suspend(nodeID int64) error
Resume(nodeID int64) error
}
type NodeManager struct {
@ -62,6 +68,42 @@ func (m *NodeManager) Stopping(nodeID int64) {
}
}
func (m *NodeManager) Suspend(nodeID int64) error {
m.mu.Lock()
defer m.mu.Unlock()
nodeInfo, ok := m.nodes[nodeID]
if !ok {
return merr.WrapErrNodeNotFound(nodeID)
}
switch nodeInfo.GetState() {
case NodeStateNormal:
nodeInfo.SetState(NodeStateSuspend)
return nil
default:
log.Warn("failed to suspend query node", zap.Int64("nodeID", nodeID), zap.String("state", nodeInfo.GetState().String()))
return merr.WrapErrNodeStateUnexpected(nodeID, nodeInfo.GetState().String(), "failed to suspend a query node")
}
}
func (m *NodeManager) Resume(nodeID int64) error {
m.mu.Lock()
defer m.mu.Unlock()
nodeInfo, ok := m.nodes[nodeID]
if !ok {
return merr.WrapErrNodeNotFound(nodeID)
}
switch nodeInfo.GetState() {
case NodeStateSuspend:
nodeInfo.SetState(NodeStateNormal)
return nil
default:
log.Warn("failed to resume query node", zap.Int64("nodeID", nodeID), zap.String("state", nodeInfo.GetState().String()))
return merr.WrapErrNodeStateUnexpected(nodeID, nodeInfo.GetState().String(), "failed to resume query node")
}
}
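Reviewer note: the two methods implement a strict two-state toggle; a minimal sketch of the allowed transitions (IDs are placeholders):
// Suspend: NodeStateNormal  -> NodeStateSuspend (anything else is rejected)
// Resume:  NodeStateSuspend -> NodeStateNormal  (anything else is rejected)
nm := NewNodeManager()
nm.Add(NewNodeInfo(ImmutableNodeInfo{NodeID: 1, Address: "localhost", Hostname: "localhost"}))
_ = nm.Suspend(1)   // Normal -> Suspend
_ = nm.Resume(1)    // Suspend -> Normal
err := nm.Resume(1) // already Normal: returns ErrNodeStateUnexpected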
func (m *NodeManager) IsStoppingNode(nodeID int64) (bool, error) {
m.mu.RLock()
defer m.mu.RUnlock()
@ -98,8 +140,9 @@ func NewNodeManager() *NodeManager {
type State int
const (
NodeStateNormal = iota
NodeStateStopping
NormalStateName = "Normal"
StoppingStateName = "Stopping"
SuspendStateName = "Suspend"
)
type ImmutableNodeInfo struct {
@ -109,6 +152,22 @@ type ImmutableNodeInfo struct {
Version semver.Version
}
const (
NodeStateNormal State = iota
NodeStateStopping
NodeStateSuspend
)
var stateNameMap = map[State]string{
NodeStateNormal: NormalStateName,
NodeStateStopping: StoppingStateName,
NodeStateSuspend: SuspendStateName,
}
func (s State) String() string {
return stateNameMap[s]
}
type NodeInfo struct {
stats
mu sync.RWMutex
@ -161,6 +220,12 @@ func (n *NodeInfo) SetState(s State) {
n.state = s
}
func (n *NodeInfo) GetState() State {
n.mu.RLock()
defer n.mu.RUnlock()
return n.state
}
func (n *NodeInfo) UpdateStats(opts ...StatsOption) {
n.mu.Lock()
for _, opt := range opts {

View File

@ -0,0 +1,110 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type NodeManagerSuite struct {
suite.Suite
nodeManager *NodeManager
}
func (s *NodeManagerSuite) SetupTest() {
s.nodeManager = NewNodeManager()
}
func (s *NodeManagerSuite) TearDownTest() {
}
func (s *NodeManagerSuite) TestNodeOperation() {
s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{
NodeID: 2,
Address: "localhost",
Hostname: "localhost",
}))
s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{
NodeID: 3,
Address: "localhost",
Hostname: "localhost",
}))
s.NotNil(s.nodeManager.Get(1))
s.Len(s.nodeManager.GetAll(), 3)
s.nodeManager.Remove(1)
s.Nil(s.nodeManager.Get(1))
s.Len(s.nodeManager.GetAll(), 2)
s.nodeManager.Stopping(2)
s.True(s.nodeManager.IsStoppingNode(2))
err := s.nodeManager.Resume(2)
s.ErrorIs(err, merr.ErrNodeStateUnexpected)
s.True(s.nodeManager.IsStoppingNode(2))
node := s.nodeManager.Get(2)
node.SetState(NodeStateNormal)
s.False(s.nodeManager.IsStoppingNode(2))
err = s.nodeManager.Resume(3)
s.ErrorIs(err, merr.ErrNodeStateUnexpected)
err = s.nodeManager.Suspend(3)
s.NoError(err)
node = s.nodeManager.Get(3)
s.NotNil(node)
s.Equal(NodeStateSuspend, node.GetState())
err = s.nodeManager.Resume(3)
s.NoError(err)
node = s.nodeManager.Get(3)
s.NotNil(node)
s.Equal(NodeStateNormal, node.GetState())
}
func (s *NodeManagerSuite) TestNodeInfo() {
node := NewNodeInfo(ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
})
s.Equal(int64(1), node.ID())
s.Equal("localhost", node.Addr())
node.setChannelCnt(1)
node.setSegmentCnt(1)
s.Equal(1, node.ChannelCnt())
s.Equal(1, node.SegmentCnt())
node.UpdateStats(WithSegmentCnt(5))
node.UpdateStats(WithChannelCnt(5))
s.Equal(5, node.ChannelCnt())
s.Equal(5, node.SegmentCnt())
node.SetLastHeartbeat(time.Now())
s.NotNil(node.LastHeartbeat())
}
func TestNodeManagerSuite(t *testing.T) {
suite.Run(t, new(NodeManagerSuite))
}

View File

@ -27,6 +27,7 @@ const (
BalanceCheckerName = "balance_checker"
IndexCheckerName = "index_checker"
LeaderCheckerName = "leader_checker"
ManualBalanceName = "manual_balance"
)
type CheckerType int32
@ -37,6 +38,7 @@ const (
BalanceChecker
IndexChecker
LeaderChecker
ManualBalance
)
var checkerNames = map[CheckerType]string{
@ -45,6 +47,7 @@ var checkerNames = map[CheckerType]string{
BalanceChecker: BalanceCheckerName,
IndexChecker: IndexCheckerName,
LeaderChecker: LeaderCheckerName,
ManualBalance: ManualBalanceName,
}
func (s CheckerType) String() string {

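The ManualBalance checker type plugs into the same name registry as the built-in checkers, so logs and metrics render it uniformly. Editor's fragment (assumes String() returns checkerNames[s], as the surrounding code suggests):
var t CheckerType = ManualBalance
_ = t.String() // "manual_balance"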
View File

@ -141,3 +141,39 @@ func (m *GrpcQueryCoordClient) ActivateChecker(ctx context.Context, in *querypb.
func (m *GrpcQueryCoordClient) DeactivateChecker(ctx context.Context, in *querypb.DeactivateCheckerRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) {
return &querypb.ListQueryNodeResponse{}, m.Err
}
func (m *GrpcQueryCoordClient) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) {
return &querypb.GetQueryNodeDistributionResponse{}, m.Err
}
func (m *GrpcQueryCoordClient) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}

View File

@ -90,11 +90,12 @@ var (
ErrDatabaseInvalidName = newMilvusError("invalid database name", 802, false)
// Node related
ErrNodeNotFound = newMilvusError("node not found", 901, false)
ErrNodeOffline = newMilvusError("node offline", 902, false)
ErrNodeLack = newMilvusError("node lacks", 903, false)
ErrNodeNotMatch = newMilvusError("node not match", 904, false)
ErrNodeNotAvailable = newMilvusError("node not available", 905, false)
ErrNodeNotFound = newMilvusError("node not found", 901, false)
ErrNodeOffline = newMilvusError("node offline", 902, false)
ErrNodeLack = newMilvusError("node lacks", 903, false)
ErrNodeNotMatch = newMilvusError("node not match", 904, false)
ErrNodeNotAvailable = newMilvusError("node not available", 905, false)
ErrNodeStateUnexpected = newMilvusError("node state unexpected", 906, false)
// IO related
ErrIoKeyNotFound = newMilvusError("key not found", 1000, false)

View File

@ -120,6 +120,7 @@ func (s *ErrSuite) TestWrap() {
s.ErrorIs(WrapErrNodeNotFound(1, "failed to get node"), ErrNodeNotFound)
s.ErrorIs(WrapErrNodeOffline(1, "failed to access node"), ErrNodeOffline)
s.ErrorIs(WrapErrNodeLack(3, 1, "need more nodes"), ErrNodeLack)
s.ErrorIs(WrapErrNodeStateUnexpected(1, "Stopping", "failed to suspend node"), ErrNodeStateUnexpected)
// IO related
s.ErrorIs(WrapErrIoKeyNotFound("test_key", "failed to read"), ErrIoKeyNotFound)

View File

@ -731,6 +731,14 @@ func WrapErrNodeNotAvailable(id int64, msg ...string) error {
return err
}
func WrapErrNodeStateUnexpected(id int64, state string, msg ...string) error {
err := wrapFields(ErrNodeStateUnexpected, value("node", id), value("state", state))
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "->"))
}
return err
}
func WrapErrNodeNotMatch(expectedNodeID, actualNodeID int64, msg ...string) error {
err := wrapFields(ErrNodeNotMatch,
value("expectedNodeID", expectedNodeID),

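For context, the new wrapper composes like the existing WrapErrNode* helpers: the sentinel remains matchable with errors.Is, while the node ID and state are attached as fields. An editor's sketch (exact message layout approximate):
err := merr.WrapErrNodeStateUnexpected(2, "Stopping", "failed to suspend node")
// errors.Is(err, merr.ErrNodeStateUnexpected) == true
// err.Error() reads roughly: "failed to suspend node: node state unexpected[node=2][state=Stopping]"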
View File

@ -0,0 +1,364 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rollingupgrade
import (
"context"
"math/rand"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
)
type ManualRollingUpgradeSuite struct {
integration.MiniClusterSuite
}
func (s *ManualRollingUpgradeSuite) SetupSuite() {
paramtable.Init()
params := paramtable.Get()
params.Save(params.QueryCoordCfg.BalanceCheckInterval.Key, "2000")
rand.Seed(time.Now().UnixNano())
s.Require().NoError(s.SetupEmbedEtcd())
}
func (s *ManualRollingUpgradeSuite) TearDownSuite() {
params := paramtable.Get()
params.Reset(params.QueryCoordCfg.BalanceCheckInterval.Key)
s.TearDownEmbedEtcd()
}
func (s *ManualRollingUpgradeSuite) TestTransfer() {
c := s.Cluster
ctx, cancel := context.WithCancel(c.GetContext())
defer cancel()
prefix := "TestTransfer"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
dim := 128
rowNum := 3000
insertRound := 5
schema := integration.ConstructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
if err != nil {
log.Warn("createCollectionStatus fail reason", zap.Error(err))
}
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
// insert data, then flush to generate sealed segments
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
for i := range lo.Range(insertRound) {
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
log.Info("Insert succeeded", zap.Int("round", i+1))
resp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
s.True(merr.Ok(resp.GetStatus()))
}
// create index
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
log.Info("Create index done")
// load
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
s.WaitForLoad(ctx, collectionName)
log.Info("Load collection done")
// suspend balance
resp2, err := s.Cluster.QueryCoord.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
s.NoError(err)
s.True(merr.Ok(resp2))
// get the original query node
qnServer1 := s.Cluster.QueryNode
qn1 := qnServer1.GetQueryNode()
// add a new query node
qnServer2 := s.Cluster.AddQueryNode()
time.Sleep(5 * time.Second)
qn2 := qnServer2.GetQueryNode()
// expect 2 query nodes to be listed
resp3, err := s.Cluster.QueryCoordClient.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
s.NoError(err)
s.Len(resp3.GetNodeInfos(), 2)
// since balance is suspended, qn2 should have no segment/channel distribution
resp4, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: qn2.GetNodeID(),
})
s.NoError(err)
s.Len(resp4.GetChannelNames(), 0)
s.Len(resp4.GetSealedSegmentIDs(), 0)
resp5, err := s.Cluster.QueryCoordClient.TransferChannel(ctx, &querypb.TransferChannelRequest{
SourceNodeID: qn1.GetNodeID(),
TargetNodeID: qn2.GetNodeID(),
TransferAll: true,
})
s.NoError(err)
s.True(merr.Ok(resp5))
// wait for the channel transfer to finish
s.Eventually(func() bool {
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: qn1.GetNodeID(),
})
s.NoError(err)
return len(resp.GetChannelNames()) == 0
}, 10*time.Second, 1*time.Second)
// test transfer segment
resp6, err := s.Cluster.QueryCoordClient.TransferSegment(ctx, &querypb.TransferSegmentRequest{
SourceNodeID: qn1.GetNodeID(),
TargetNodeID: qn2.GetNodeID(),
TransferAll: true,
})
s.NoError(err)
s.True(merr.Ok(resp6))
// wait for the segment transfer to finish
s.Eventually(func() bool {
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: qn1.GetNodeID(),
})
s.NoError(err)
return len(resp.GetSealedSegmentIDs()) == 0
}, 10*time.Second, 1*time.Second)
// resume balance; segments/channels will be balanced back to qn1
resp7, err := s.Cluster.QueryCoord.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
s.NoError(err)
s.True(merr.Ok(resp7))
s.Eventually(func() bool {
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: qn1.GetNodeID(),
})
s.NoError(err)
return len(resp.GetSealedSegmentIDs()) > 0 || len(resp.GetChannelNames()) > 0
}, 10*time.Second, 1*time.Second)
log.Info("==================")
log.Info("==================")
log.Info("TestManualRollingUpgrade succeed")
log.Info("==================")
log.Info("==================")
}
func (s *ManualRollingUpgradeSuite) TestSuspendNode() {
c := s.Cluster
ctx, cancel := context.WithCancel(c.GetContext())
defer cancel()
prefix := "TestSuspendNode"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
dim := 128
rowNum := 3000
insertRound := 5
schema := integration.ConstructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
if err != nil {
log.Warn("createCollectionStatus fail reason", zap.Error(err))
}
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
// insert data, then flush to generate sealed segments
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
for i := range lo.Range(insertRound) {
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
log.Info("Insert succeeded", zap.Int("round", i+1))
resp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
s.True(merr.Ok(resp.GetStatus()))
}
// create index
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
log.Info("Create index done")
// add a new query node
qnServer2 := s.Cluster.AddQueryNode()
time.Sleep(5 * time.Second)
qn2 := qnServer2.GetQueryNode()
// expect 2 query nodes to be listed
resp3, err := s.Cluster.QueryCoordClient.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
s.NoError(err)
s.Len(resp3.GetNodeInfos(), 2)
// suspend the new node
resp2, err := s.Cluster.QueryCoord.SuspendNode(ctx, &querypb.SuspendNodeRequest{
NodeID: qn2.GetNodeID(),
})
s.NoError(err)
s.True(merr.Ok(resp2))
// load
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
s.WaitForLoad(ctx, collectionName)
log.Info("Load collection done")
// since the node is suspended, no segments/channels will be loaded onto it
resp4, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: qn2.GetNodeID(),
})
s.NoError(err)
s.Len(resp4.GetChannelNames(), 0)
s.Len(resp4.GetSealedSegmentIDs(), 0)
// resume the node; segments/channels will be balanced to qn2
resp5, err := s.Cluster.QueryCoord.ResumeNode(ctx, &querypb.ResumeNodeRequest{
NodeID: qn2.GetNodeID(),
})
s.NoError(err)
s.True(merr.Ok(resp5))
s.Eventually(func() bool {
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
NodeID: qn2.GetNodeID(),
})
s.NoError(err)
return len(resp.GetSealedSegmentIDs()) > 0 || len(resp.GetChannelNames()) > 0
}, 10*time.Second, 1*time.Second)
log.Info("==================")
log.Info("==================")
log.Info("TestSuspendNode succeed")
log.Info("==================")
log.Info("==================")
}
func TestManualRollingUpgrade(t *testing.T) {
suite.Run(t, new(ManualRollingUpgradeSuite))
}
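For devops reference, the manual rolling-upgrade flow that TestTransfer exercises reduces to the following RPC sequence against QueryCoord. This is an editor's condensed sketch, not part of the test: client is a querypb.QueryCoordClient, oldNode/newNode are hypothetical node IDs, and error handling plus the polling loops are elided.
// 1. Freeze automatic balancing so the balancer cannot fight the manual moves.
client.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
// 2. Drain the node being upgraded.
client.TransferChannel(ctx, &querypb.TransferChannelRequest{SourceNodeID: oldNode, TargetNodeID: newNode, TransferAll: true})
client.TransferSegment(ctx, &querypb.TransferSegmentRequest{SourceNodeID: oldNode, TargetNodeID: newNode, TransferAll: true})
// 3. Poll GetQueryNodeDistribution until oldNode reports no channels and no sealed segments.
client.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{NodeID: oldNode})
// 4. Upgrade/restart the node, then re-enable automatic balancing.
client.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})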