mirror of https://github.com/milvus-io/milvus.git
enhance: Add restful api for devops to execute rolling upgrade (#29998)
issue: #29261 This PR Add restful api for devops to execute rolling upgrade, including suspend/resume balance and manual transfer segments/channels. --------- Signed-off-by: Wei Liu <wei.liu@zilliz.com>pull/31610/head
parent
bd44bd5ae2
commit
92971707de
|
@ -402,3 +402,102 @@ func (c *Client) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC
|
|||
return client.DeactivateChecker(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*querypb.ListQueryNodeResponse, error) {
|
||||
return client.ListQueryNode(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
return client.GetQueryNodeDistribution(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.SuspendBalance(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.ResumeBalance(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.SuspendNode(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.ResumeNode(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.TransferSegment(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.TransferChannel(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
|
||||
)
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) {
|
||||
return client.CheckQueryNodeDistribution(ctx, req)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -158,6 +158,33 @@ func Test_NewClient(t *testing.T) {
|
|||
|
||||
r30, err := client.DeactivateChecker(ctx, nil)
|
||||
retCheck(retNotNil, r30, err)
|
||||
|
||||
r31, err := client.ListQueryNode(ctx, nil)
|
||||
retCheck(retNotNil, r31, err)
|
||||
|
||||
r32, err := client.GetQueryNodeDistribution(ctx, nil)
|
||||
retCheck(retNotNil, r32, err)
|
||||
|
||||
r33, err := client.SuspendBalance(ctx, nil)
|
||||
retCheck(retNotNil, r33, err)
|
||||
|
||||
r34, err := client.ResumeBalance(ctx, nil)
|
||||
retCheck(retNotNil, r34, err)
|
||||
|
||||
r35, err := client.SuspendNode(ctx, nil)
|
||||
retCheck(retNotNil, r35, err)
|
||||
|
||||
r36, err := client.ResumeNode(ctx, nil)
|
||||
retCheck(retNotNil, r36, err)
|
||||
|
||||
r37, err := client.TransferSegment(ctx, nil)
|
||||
retCheck(retNotNil, r37, err)
|
||||
|
||||
r38, err := client.TransferChannel(ctx, nil)
|
||||
retCheck(retNotNil, r38, err)
|
||||
|
||||
r39, err := client.CheckQueryNodeDistribution(ctx, nil)
|
||||
retCheck(retNotNil, r39, err)
|
||||
}
|
||||
|
||||
client.(*Client).grpcClient = &mock.GRPCClientBase[querypb.QueryCoordClient]{
|
||||
|
|
|
@ -446,3 +446,39 @@ func (s *Server) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC
|
|||
func (s *Server) ListCheckers(ctx context.Context, req *querypb.ListCheckersRequest) (*querypb.ListCheckersResponse, error) {
|
||||
return s.queryCoord.ListCheckers(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) {
|
||||
return s.queryCoord.ListQueryNode(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
return s.queryCoord.GetQueryNodeDistribution(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.SuspendBalance(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.ResumeBalance(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.SuspendNode(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.ResumeNode(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.TransferSegment(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.TransferChannel(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) {
|
||||
return s.queryCoord.CheckQueryNodeDistribution(ctx, req)
|
||||
}
|
||||
|
|
|
@ -283,6 +283,78 @@ func Test_NewServer(t *testing.T) {
|
|||
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("ListQueryNode", func(t *testing.T) {
|
||||
req := &querypb.ListQueryNodeRequest{}
|
||||
mqc.EXPECT().ListQueryNode(mock.Anything, req).Return(&querypb.ListQueryNodeResponse{Status: merr.Success()}, nil)
|
||||
resp, err := server.ListQueryNode(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("GetQueryNodeDistribution", func(t *testing.T) {
|
||||
req := &querypb.GetQueryNodeDistributionRequest{}
|
||||
mqc.EXPECT().GetQueryNodeDistribution(mock.Anything, req).Return(&querypb.GetQueryNodeDistributionResponse{Status: merr.Success()}, nil)
|
||||
resp, err := server.GetQueryNodeDistribution(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("SuspendBalance", func(t *testing.T) {
|
||||
req := &querypb.SuspendBalanceRequest{}
|
||||
mqc.EXPECT().SuspendBalance(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.SuspendBalance(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("ResumeBalance", func(t *testing.T) {
|
||||
req := &querypb.ResumeBalanceRequest{}
|
||||
mqc.EXPECT().ResumeBalance(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.ResumeBalance(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("SuspendNode", func(t *testing.T) {
|
||||
req := &querypb.SuspendNodeRequest{}
|
||||
mqc.EXPECT().SuspendNode(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.SuspendNode(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("ResumeNode", func(t *testing.T) {
|
||||
req := &querypb.ResumeNodeRequest{}
|
||||
mqc.EXPECT().ResumeNode(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.ResumeNode(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("TransferSegment", func(t *testing.T) {
|
||||
req := &querypb.TransferSegmentRequest{}
|
||||
mqc.EXPECT().TransferSegment(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.TransferSegment(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("TransferChannel", func(t *testing.T) {
|
||||
req := &querypb.TransferChannelRequest{}
|
||||
mqc.EXPECT().TransferChannel(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.TransferChannel(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("CheckQueryNodeDistribution", func(t *testing.T) {
|
||||
req := &querypb.CheckQueryNodeDistributionRequest{}
|
||||
mqc.EXPECT().CheckQueryNodeDistribution(mock.Anything, req).Return(merr.Success(), nil)
|
||||
resp, err := server.CheckQueryNodeDistribution(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
|
||||
})
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -74,6 +74,10 @@ func (s *Server) GetStatistics(ctx context.Context, request *querypb.GetStatisti
|
|||
return s.querynode.GetStatistics(ctx, request)
|
||||
}
|
||||
|
||||
func (s *Server) GetQueryNode() types.QueryNodeComponent {
|
||||
return s.querynode
|
||||
}
|
||||
|
||||
// NewServer create a new QueryNode grpc server.
|
||||
func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
|
|
|
@ -144,6 +144,61 @@ func (_c *MockQueryCoord_CheckHealth_Call) RunAndReturn(run func(context.Context
|
|||
return _c
|
||||
}
|
||||
|
||||
// CheckQueryNodeDistribution provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) CheckQueryNodeDistribution(_a0 context.Context, _a1 *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_CheckQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQueryNodeDistribution'
|
||||
type MockQueryCoord_CheckQueryNodeDistribution_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CheckQueryNodeDistribution is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.CheckQueryNodeDistributionRequest
|
||||
func (_e *MockQueryCoord_Expecter) CheckQueryNodeDistribution(_a0 interface{}, _a1 interface{}) *MockQueryCoord_CheckQueryNodeDistribution_Call {
|
||||
return &MockQueryCoord_CheckQueryNodeDistribution_Call{Call: _e.mock.On("CheckQueryNodeDistribution", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) Run(run func(_a0 context.Context, _a1 *querypb.CheckQueryNodeDistributionRequest)) *MockQueryCoord_CheckQueryNodeDistribution_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.CheckQueryNodeDistributionRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_CheckQueryNodeDistribution_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error)) *MockQueryCoord_CheckQueryNodeDistribution_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// CreateResourceGroup provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) CreateResourceGroup(_a0 context.Context, _a1 *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
@ -529,6 +584,61 @@ func (_c *MockQueryCoord_GetPartitionStates_Call) RunAndReturn(run func(context.
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetQueryNodeDistribution provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) GetQueryNodeDistribution(_a0 context.Context, _a1 *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *querypb.GetQueryNodeDistributionResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) *querypb.GetQueryNodeDistributionResponse); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.GetQueryNodeDistributionResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_GetQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryNodeDistribution'
|
||||
type MockQueryCoord_GetQueryNodeDistribution_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetQueryNodeDistribution is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.GetQueryNodeDistributionRequest
|
||||
func (_e *MockQueryCoord_Expecter) GetQueryNodeDistribution(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetQueryNodeDistribution_Call {
|
||||
return &MockQueryCoord_GetQueryNodeDistribution_Call{Call: _e.mock.On("GetQueryNodeDistribution", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetQueryNodeDistributionRequest)) *MockQueryCoord_GetQueryNodeDistribution_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.GetQueryNodeDistributionRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) Return(_a0 *querypb.GetQueryNodeDistributionResponse, _a1 error) *MockQueryCoord_GetQueryNodeDistribution_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error)) *MockQueryCoord_GetQueryNodeDistribution_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetReplicas provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) GetReplicas(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
@ -900,6 +1010,61 @@ func (_c *MockQueryCoord_ListCheckers_Call) RunAndReturn(run func(context.Contex
|
|||
return _c
|
||||
}
|
||||
|
||||
// ListQueryNode provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) ListQueryNode(_a0 context.Context, _a1 *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *querypb.ListQueryNodeResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest) *querypb.ListQueryNodeResponse); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.ListQueryNodeResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ListQueryNodeRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_ListQueryNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListQueryNode'
|
||||
type MockQueryCoord_ListQueryNode_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListQueryNode is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.ListQueryNodeRequest
|
||||
func (_e *MockQueryCoord_Expecter) ListQueryNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ListQueryNode_Call {
|
||||
return &MockQueryCoord_ListQueryNode_Call{Call: _e.mock.On("ListQueryNode", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ListQueryNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.ListQueryNodeRequest)) *MockQueryCoord_ListQueryNode_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.ListQueryNodeRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ListQueryNode_Call) Return(_a0 *querypb.ListQueryNodeResponse, _a1 error) *MockQueryCoord_ListQueryNode_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ListQueryNode_Call) RunAndReturn(run func(context.Context, *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error)) *MockQueryCoord_ListQueryNode_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListResourceGroups provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) ListResourceGroups(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
@ -1271,6 +1436,116 @@ func (_c *MockQueryCoord_ReleasePartitions_Call) RunAndReturn(run func(context.C
|
|||
return _c
|
||||
}
|
||||
|
||||
// ResumeBalance provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) ResumeBalance(_a0 context.Context, _a1 *querypb.ResumeBalanceRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeBalanceRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_ResumeBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeBalance'
|
||||
type MockQueryCoord_ResumeBalance_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ResumeBalance is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.ResumeBalanceRequest
|
||||
func (_e *MockQueryCoord_Expecter) ResumeBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ResumeBalance_Call {
|
||||
return &MockQueryCoord_ResumeBalance_Call{Call: _e.mock.On("ResumeBalance", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ResumeBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.ResumeBalanceRequest)) *MockQueryCoord_ResumeBalance_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.ResumeBalanceRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ResumeBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_ResumeBalance_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ResumeBalance_Call) RunAndReturn(run func(context.Context, *querypb.ResumeBalanceRequest) (*commonpb.Status, error)) *MockQueryCoord_ResumeBalance_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ResumeNode provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) ResumeNode(_a0 context.Context, _a1 *querypb.ResumeNodeRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeNodeRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_ResumeNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeNode'
|
||||
type MockQueryCoord_ResumeNode_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ResumeNode is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.ResumeNodeRequest
|
||||
func (_e *MockQueryCoord_Expecter) ResumeNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ResumeNode_Call {
|
||||
return &MockQueryCoord_ResumeNode_Call{Call: _e.mock.On("ResumeNode", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ResumeNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.ResumeNodeRequest)) *MockQueryCoord_ResumeNode_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.ResumeNodeRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ResumeNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_ResumeNode_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_ResumeNode_Call) RunAndReturn(run func(context.Context, *querypb.ResumeNodeRequest) (*commonpb.Status, error)) *MockQueryCoord_ResumeNode_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAddress provides a mock function with given fields: address
|
||||
func (_m *MockQueryCoord) SetAddress(address string) {
|
||||
_m.Called(address)
|
||||
|
@ -1734,6 +2009,116 @@ func (_c *MockQueryCoord_Stop_Call) RunAndReturn(run func() error) *MockQueryCoo
|
|||
return _c
|
||||
}
|
||||
|
||||
// SuspendBalance provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) SuspendBalance(_a0 context.Context, _a1 *querypb.SuspendBalanceRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendBalanceRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_SuspendBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendBalance'
|
||||
type MockQueryCoord_SuspendBalance_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SuspendBalance is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.SuspendBalanceRequest
|
||||
func (_e *MockQueryCoord_Expecter) SuspendBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SuspendBalance_Call {
|
||||
return &MockQueryCoord_SuspendBalance_Call{Call: _e.mock.On("SuspendBalance", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_SuspendBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.SuspendBalanceRequest)) *MockQueryCoord_SuspendBalance_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.SuspendBalanceRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_SuspendBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SuspendBalance_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_SuspendBalance_Call) RunAndReturn(run func(context.Context, *querypb.SuspendBalanceRequest) (*commonpb.Status, error)) *MockQueryCoord_SuspendBalance_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SuspendNode provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) SuspendNode(_a0 context.Context, _a1 *querypb.SuspendNodeRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendNodeRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_SuspendNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendNode'
|
||||
type MockQueryCoord_SuspendNode_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SuspendNode is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.SuspendNodeRequest
|
||||
func (_e *MockQueryCoord_Expecter) SuspendNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SuspendNode_Call {
|
||||
return &MockQueryCoord_SuspendNode_Call{Call: _e.mock.On("SuspendNode", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_SuspendNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.SuspendNodeRequest)) *MockQueryCoord_SuspendNode_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.SuspendNodeRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_SuspendNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SuspendNode_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_SuspendNode_Call) RunAndReturn(run func(context.Context, *querypb.SuspendNodeRequest) (*commonpb.Status, error)) *MockQueryCoord_SuspendNode_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SyncNewCreatedPartition provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) SyncNewCreatedPartition(_a0 context.Context, _a1 *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
@ -1789,6 +2174,61 @@ func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) RunAndReturn(run func(con
|
|||
return _c
|
||||
}
|
||||
|
||||
// TransferChannel provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) TransferChannel(_a0 context.Context, _a1 *querypb.TransferChannelRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferChannelRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_TransferChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferChannel'
|
||||
type MockQueryCoord_TransferChannel_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// TransferChannel is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.TransferChannelRequest
|
||||
func (_e *MockQueryCoord_Expecter) TransferChannel(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferChannel_Call {
|
||||
return &MockQueryCoord_TransferChannel_Call{Call: _e.mock.On("TransferChannel", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_TransferChannel_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferChannelRequest)) *MockQueryCoord_TransferChannel_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.TransferChannelRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_TransferChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_TransferChannel_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_TransferChannel_Call) RunAndReturn(run func(context.Context, *querypb.TransferChannelRequest) (*commonpb.Status, error)) *MockQueryCoord_TransferChannel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// TransferNode provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) TransferNode(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
@ -1899,6 +2339,61 @@ func (_c *MockQueryCoord_TransferReplica_Call) RunAndReturn(run func(context.Con
|
|||
return _c
|
||||
}
|
||||
|
||||
// TransferSegment provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryCoord) TransferSegment(_a0 context.Context, _a1 *querypb.TransferSegmentRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferSegmentRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoord_TransferSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferSegment'
|
||||
type MockQueryCoord_TransferSegment_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// TransferSegment is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.TransferSegmentRequest
|
||||
func (_e *MockQueryCoord_Expecter) TransferSegment(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferSegment_Call {
|
||||
return &MockQueryCoord_TransferSegment_Call{Call: _e.mock.On("TransferSegment", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_TransferSegment_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferSegmentRequest)) *MockQueryCoord_TransferSegment_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.TransferSegmentRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_TransferSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_TransferSegment_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoord_TransferSegment_Call) RunAndReturn(run func(context.Context, *querypb.TransferSegmentRequest) (*commonpb.Status, error)) *MockQueryCoord_TransferSegment_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateStateCode provides a mock function with given fields: stateCode
|
||||
func (_m *MockQueryCoord) UpdateStateCode(stateCode commonpb.StateCode) {
|
||||
_m.Called(stateCode)
|
||||
|
|
|
@ -171,6 +171,76 @@ func (_c *MockQueryCoordClient_CheckHealth_Call) RunAndReturn(run func(context.C
|
|||
return _c
|
||||
}
|
||||
|
||||
// CheckQueryNodeDistribution provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) CheckQueryNodeDistribution(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_CheckQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQueryNodeDistribution'
|
||||
type MockQueryCoordClient_CheckQueryNodeDistribution_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CheckQueryNodeDistribution is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.CheckQueryNodeDistributionRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) CheckQueryNodeDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
|
||||
return &MockQueryCoordClient_CheckQueryNodeDistribution_Call{Call: _e.mock.On("CheckQueryNodeDistribution",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) Run(run func(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.CheckQueryNodeDistributionRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_CheckQueryNodeDistribution_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *MockQueryCoordClient) Close() error {
|
||||
ret := _m.Called()
|
||||
|
@ -702,6 +772,76 @@ func (_c *MockQueryCoordClient_GetPartitionStates_Call) RunAndReturn(run func(co
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetQueryNodeDistribution provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) GetQueryNodeDistribution(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *querypb.GetQueryNodeDistributionResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) *querypb.GetQueryNodeDistributionResponse); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.GetQueryNodeDistributionResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_GetQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryNodeDistribution'
|
||||
type MockQueryCoordClient_GetQueryNodeDistribution_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetQueryNodeDistribution is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.GetQueryNodeDistributionRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) GetQueryNodeDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
|
||||
return &MockQueryCoordClient_GetQueryNodeDistribution_Call{Call: _e.mock.On("GetQueryNodeDistribution",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) Run(run func(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.GetQueryNodeDistributionRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) Return(_a0 *querypb.GetQueryNodeDistributionResponse, _a1 error) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error)) *MockQueryCoordClient_GetQueryNodeDistribution_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetReplicas provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) GetReplicas(ctx context.Context, in *milvuspb.GetReplicasRequest, opts ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
@ -1122,6 +1262,76 @@ func (_c *MockQueryCoordClient_ListCheckers_Call) RunAndReturn(run func(context.
|
|||
return _c
|
||||
}
|
||||
|
||||
// ListQueryNode provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) ListQueryNode(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *querypb.ListQueryNodeResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) *querypb.ListQueryNodeResponse); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.ListQueryNodeResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_ListQueryNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListQueryNode'
|
||||
type MockQueryCoordClient_ListQueryNode_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListQueryNode is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.ListQueryNodeRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) ListQueryNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ListQueryNode_Call {
|
||||
return &MockQueryCoordClient_ListQueryNode_Call{Call: _e.mock.On("ListQueryNode",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ListQueryNode_Call) Run(run func(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ListQueryNode_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.ListQueryNodeRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ListQueryNode_Call) Return(_a0 *querypb.ListQueryNodeResponse, _a1 error) *MockQueryCoordClient_ListQueryNode_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ListQueryNode_Call) RunAndReturn(run func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error)) *MockQueryCoordClient_ListQueryNode_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListResourceGroups provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) ListResourceGroups(ctx context.Context, in *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
@ -1542,6 +1752,146 @@ func (_c *MockQueryCoordClient_ReleasePartitions_Call) RunAndReturn(run func(con
|
|||
return _c
|
||||
}
|
||||
|
||||
// ResumeBalance provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) ResumeBalance(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_ResumeBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeBalance'
|
||||
type MockQueryCoordClient_ResumeBalance_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ResumeBalance is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.ResumeBalanceRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) ResumeBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ResumeBalance_Call {
|
||||
return &MockQueryCoordClient_ResumeBalance_Call{Call: _e.mock.On("ResumeBalance",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ResumeBalance_Call) Run(run func(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ResumeBalance_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.ResumeBalanceRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ResumeBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ResumeBalance_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ResumeBalance_Call) RunAndReturn(run func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ResumeBalance_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ResumeNode provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) ResumeNode(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_ResumeNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeNode'
|
||||
type MockQueryCoordClient_ResumeNode_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ResumeNode is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.ResumeNodeRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) ResumeNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ResumeNode_Call {
|
||||
return &MockQueryCoordClient_ResumeNode_Call{Call: _e.mock.On("ResumeNode",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ResumeNode_Call) Run(run func(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ResumeNode_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.ResumeNodeRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ResumeNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ResumeNode_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_ResumeNode_Call) RunAndReturn(run func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ResumeNode_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ShowCollections provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
@ -1752,6 +2102,146 @@ func (_c *MockQueryCoordClient_ShowPartitions_Call) RunAndReturn(run func(contex
|
|||
return _c
|
||||
}
|
||||
|
||||
// SuspendBalance provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) SuspendBalance(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_SuspendBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendBalance'
|
||||
type MockQueryCoordClient_SuspendBalance_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SuspendBalance is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.SuspendBalanceRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) SuspendBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SuspendBalance_Call {
|
||||
return &MockQueryCoordClient_SuspendBalance_Call{Call: _e.mock.On("SuspendBalance",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_SuspendBalance_Call) Run(run func(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SuspendBalance_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.SuspendBalanceRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_SuspendBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SuspendBalance_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_SuspendBalance_Call) RunAndReturn(run func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SuspendBalance_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SuspendNode provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) SuspendNode(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_SuspendNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendNode'
|
||||
type MockQueryCoordClient_SuspendNode_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SuspendNode is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.SuspendNodeRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) SuspendNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SuspendNode_Call {
|
||||
return &MockQueryCoordClient_SuspendNode_Call{Call: _e.mock.On("SuspendNode",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_SuspendNode_Call) Run(run func(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SuspendNode_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.SuspendNodeRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_SuspendNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SuspendNode_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_SuspendNode_Call) RunAndReturn(run func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SuspendNode_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SyncNewCreatedPartition provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) SyncNewCreatedPartition(ctx context.Context, in *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
@ -1822,6 +2312,76 @@ func (_c *MockQueryCoordClient_SyncNewCreatedPartition_Call) RunAndReturn(run fu
|
|||
return _c
|
||||
}
|
||||
|
||||
// TransferChannel provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) TransferChannel(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_TransferChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferChannel'
|
||||
type MockQueryCoordClient_TransferChannel_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// TransferChannel is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.TransferChannelRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) TransferChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferChannel_Call {
|
||||
return &MockQueryCoordClient_TransferChannel_Call{Call: _e.mock.On("TransferChannel",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_TransferChannel_Call) Run(run func(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferChannel_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.TransferChannelRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_TransferChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferChannel_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_TransferChannel_Call) RunAndReturn(run func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferChannel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// TransferNode provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) TransferNode(ctx context.Context, in *milvuspb.TransferNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
@ -1962,6 +2522,76 @@ func (_c *MockQueryCoordClient_TransferReplica_Call) RunAndReturn(run func(conte
|
|||
return _c
|
||||
}
|
||||
|
||||
// TransferSegment provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryCoordClient) TransferSegment(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryCoordClient_TransferSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferSegment'
|
||||
type MockQueryCoordClient_TransferSegment_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// TransferSegment is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.TransferSegmentRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryCoordClient_Expecter) TransferSegment(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferSegment_Call {
|
||||
return &MockQueryCoordClient_TransferSegment_Call{Call: _e.mock.On("TransferSegment",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_TransferSegment_Call) Run(run func(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferSegment_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.TransferSegmentRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_TransferSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferSegment_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryCoordClient_TransferSegment_Call) RunAndReturn(run func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferSegment_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockQueryCoordClient creates a new instance of MockQueryCoordClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockQueryCoordClient(t interface {
|
||||
|
|
|
@ -86,13 +86,21 @@ service QueryCoord {
|
|||
returns (DescribeResourceGroupResponse) {
|
||||
}
|
||||
|
||||
// ops interfaces
|
||||
rpc ListCheckers(ListCheckersRequest) returns (ListCheckersResponse) {
|
||||
}
|
||||
rpc ActivateChecker(ActivateCheckerRequest) returns (common.Status) {
|
||||
}
|
||||
rpc DeactivateChecker(DeactivateCheckerRequest) returns (common.Status) {
|
||||
}
|
||||
|
||||
// ops interfaces
|
||||
rpc ListCheckers(ListCheckersRequest) returns (ListCheckersResponse) {}
|
||||
rpc ActivateChecker(ActivateCheckerRequest) returns (common.Status) {}
|
||||
rpc DeactivateChecker(DeactivateCheckerRequest) returns (common.Status) {}
|
||||
|
||||
rpc ListQueryNode(ListQueryNodeRequest) returns (ListQueryNodeResponse) {}
|
||||
rpc GetQueryNodeDistribution(GetQueryNodeDistributionRequest) returns (GetQueryNodeDistributionResponse) {}
|
||||
rpc SuspendBalance(SuspendBalanceRequest) returns (common.Status) {}
|
||||
rpc ResumeBalance(ResumeBalanceRequest) returns (common.Status) {}
|
||||
rpc SuspendNode(SuspendNodeRequest) returns (common.Status) {}
|
||||
rpc ResumeNode(ResumeNodeRequest) returns (common.Status) {}
|
||||
rpc TransferSegment(TransferSegmentRequest) returns (common.Status) {}
|
||||
rpc TransferChannel(TransferChannelRequest) returns (common.Status) {}
|
||||
rpc CheckQueryNodeDistribution(CheckQueryNodeDistributionRequest) returns (common.Status) {}
|
||||
}
|
||||
|
||||
service QueryNode {
|
||||
|
@ -793,3 +801,75 @@ message CollectionTarget {
|
|||
repeated ChannelTarget Channel_targets = 2;
|
||||
int64 version = 3;
|
||||
}
|
||||
message NodeInfo {
|
||||
int64 ID = 2;
|
||||
string address = 3;
|
||||
string state = 4;
|
||||
}
|
||||
|
||||
message ListQueryNodeRequest {
|
||||
common.MsgBase base = 1;
|
||||
}
|
||||
|
||||
message ListQueryNodeResponse {
|
||||
common.Status status = 1;
|
||||
repeated NodeInfo nodeInfos = 2;
|
||||
}
|
||||
|
||||
message GetQueryNodeDistributionRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 nodeID = 2;
|
||||
}
|
||||
|
||||
message GetQueryNodeDistributionResponse {
|
||||
common.Status status = 1;
|
||||
int64 ID = 2;
|
||||
repeated string channel_names = 3;
|
||||
repeated int64 sealed_segmentIDs = 4;
|
||||
}
|
||||
|
||||
message SuspendBalanceRequest {
|
||||
common.MsgBase base = 1;
|
||||
}
|
||||
|
||||
message ResumeBalanceRequest {
|
||||
common.MsgBase base = 1;
|
||||
}
|
||||
|
||||
message SuspendNodeRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 nodeID = 2;
|
||||
}
|
||||
|
||||
message ResumeNodeRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 nodeID = 2;
|
||||
}
|
||||
|
||||
message TransferSegmentRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 segmentID = 2;
|
||||
int64 source_nodeID = 3;
|
||||
int64 target_nodeID = 4;
|
||||
bool transfer_all = 5;
|
||||
bool to_all_nodes = 6;
|
||||
bool copy_mode = 7;
|
||||
}
|
||||
|
||||
message TransferChannelRequest {
|
||||
common.MsgBase base = 1;
|
||||
string channel_name = 2;
|
||||
int64 source_nodeID = 3;
|
||||
int64 target_nodeID = 4;
|
||||
bool transfer_all = 5;
|
||||
bool to_all_nodes = 6;
|
||||
bool copy_mode = 7;
|
||||
}
|
||||
|
||||
message CheckQueryNodeDistributionRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 source_nodeID = 3;
|
||||
int64 target_nodeID = 4;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -17,14 +17,18 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
management "github.com/milvus-io/milvus/internal/http"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
// this file contains proxy management restful API handler
|
||||
|
@ -32,6 +36,17 @@ import (
|
|||
const (
|
||||
mgrRouteGcPause = `/management/datacoord/garbage_collection/pause`
|
||||
mgrRouteGcResume = `/management/datacoord/garbage_collection/resume`
|
||||
|
||||
mgrSuspendQueryCoordBalance = `/management/querycoord/balance/suspend`
|
||||
mgrResumeQueryCoordBalance = `/management/querycoord/balance/resume`
|
||||
mgrTransferSegment = `/management/querycoord/transfer/segment`
|
||||
mgrTransferChannel = `/management/querycoord/transfer/channel`
|
||||
|
||||
mgrSuspendQueryNode = `/management/querycoord/node/suspend`
|
||||
mgrResumeQueryNode = `/management/querycoord/node/resume`
|
||||
mgrListQueryNode = `/management/querycoord/node/list`
|
||||
mgrGetQueryNodeDistribution = `/management/querycoord/distribution/get`
|
||||
mgrCheckQueryNodeDistribution = `/management/querycoord/distribution/check`
|
||||
)
|
||||
|
||||
var mgrRouteRegisterOnce sync.Once
|
||||
|
@ -46,6 +61,42 @@ func RegisterMgrRoute(proxy *Proxy) {
|
|||
Path: mgrRouteGcResume,
|
||||
HandlerFunc: proxy.ResumeDatacoordGC,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrListQueryNode,
|
||||
HandlerFunc: proxy.ListQueryNode,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrGetQueryNodeDistribution,
|
||||
HandlerFunc: proxy.GetQueryNodeDistribution,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrSuspendQueryCoordBalance,
|
||||
HandlerFunc: proxy.SuspendQueryCoordBalance,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrResumeQueryCoordBalance,
|
||||
HandlerFunc: proxy.ResumeQueryCoordBalance,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrSuspendQueryNode,
|
||||
HandlerFunc: proxy.SuspendQueryNode,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrResumeQueryNode,
|
||||
HandlerFunc: proxy.ResumeQueryNode,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrTransferSegment,
|
||||
HandlerFunc: proxy.TransferSegment,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrTransferChannel,
|
||||
HandlerFunc: proxy.TransferChannel,
|
||||
})
|
||||
management.Register(&management.Handler{
|
||||
Path: mgrCheckQueryNodeDistribution,
|
||||
HandlerFunc: proxy.CheckQueryNodeDistribution,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -91,3 +142,362 @@ func (node *Proxy) ResumeDatacoordGC(w http.ResponseWriter, req *http.Request) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) ListQueryNode(w http.ResponseWriter, req *http.Request) {
|
||||
resp, err := node.queryCoord.ListQueryNode(req.Context(), &querypb.ListQueryNodeRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp.GetStatus()) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, resp.GetStatus().GetReason())))
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// skip marshal status to output
|
||||
resp.Status = nil
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
w.Write(bytes)
|
||||
}
|
||||
|
||||
func (node *Proxy) GetQueryNodeDistribution(w http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := node.queryCoord.GetQueryNodeDistribution(req.Context(), &querypb.GetQueryNodeDistributionRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
NodeID: nodeID,
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp.GetStatus()) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, resp.GetStatus().GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// skip marshal status to output
|
||||
resp.Status = nil
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
w.Write(bytes)
|
||||
}
|
||||
|
||||
func (node *Proxy) SuspendQueryCoordBalance(w http.ResponseWriter, req *http.Request) {
|
||||
resp, err := node.queryCoord.SuspendBalance(req.Context(), &querypb.SuspendBalanceRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend balance, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend balance, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) ResumeQueryCoordBalance(w http.ResponseWriter, req *http.Request) {
|
||||
resp, err := node.queryCoord.ResumeBalance(req.Context(), &querypb.ResumeBalanceRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume balance, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume balance, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) SuspendQueryNode(w http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
resp, err := node.queryCoord.SuspendNode(req.Context(), &querypb.SuspendNodeRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
NodeID: nodeID,
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) ResumeQueryNode(w http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
resp, err := node.queryCoord.ResumeNode(req.Context(), &querypb.ResumeNodeRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
NodeID: nodeID,
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) TransferSegment(w http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
request := &querypb.TransferSegmentRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
}
|
||||
|
||||
source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": failed to transfer segment", %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.SourceNodeID = source
|
||||
|
||||
target := req.FormValue("target_node_id")
|
||||
if len(target) == 0 {
|
||||
request.ToAllNodes = true
|
||||
} else {
|
||||
value, err := strconv.ParseInt(target, 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.TargetNodeID = value
|
||||
}
|
||||
|
||||
segmentID := req.FormValue("segment_id")
|
||||
if len(segmentID) == 0 {
|
||||
request.TransferAll = true
|
||||
} else {
|
||||
value, err := strconv.ParseInt(segmentID, 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.TargetNodeID = value
|
||||
}
|
||||
|
||||
copyMode := req.FormValue("copy_mode")
|
||||
if len(copyMode) == 0 {
|
||||
request.CopyMode = true
|
||||
} else {
|
||||
value, err := strconv.ParseBool(copyMode)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.CopyMode = value
|
||||
}
|
||||
|
||||
resp, err := node.queryCoord.TransferSegment(req.Context(), request)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) TransferChannel(w http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
request := &querypb.TransferChannelRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
}
|
||||
|
||||
source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": failed to transfer channel", %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.SourceNodeID = source
|
||||
|
||||
target := req.FormValue("target_node_id")
|
||||
if len(target) == 0 {
|
||||
request.ToAllNodes = true
|
||||
} else {
|
||||
value, err := strconv.ParseInt(target, 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.TargetNodeID = value
|
||||
}
|
||||
|
||||
channel := req.FormValue("channel_name")
|
||||
if len(channel) == 0 {
|
||||
request.TransferAll = true
|
||||
} else {
|
||||
request.ChannelName = channel
|
||||
}
|
||||
|
||||
copyMode := req.FormValue("copy_mode")
|
||||
if len(copyMode) == 0 {
|
||||
request.CopyMode = false
|
||||
} else {
|
||||
value, err := strconv.ParseBool(copyMode)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
request.CopyMode = value
|
||||
}
|
||||
|
||||
resp, err := node.queryCoord.TransferChannel(req.Context(), request)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
||||
func (node *Proxy) CheckQueryNodeDistribution(w http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": failed to check whether query node has same distribution", %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
target, err := strconv.ParseInt(req.FormValue("target_node_id"), 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
resp, err := node.queryCoord.CheckQueryNodeDistribution(req.Context(), &querypb.CheckQueryNodeDistributionRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
SourceNodeID: source,
|
||||
TargetNodeID: target,
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
if !merr.Ok(resp) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, resp.GetReason())))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"msg": "OK"}`))
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
@ -30,19 +31,25 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type ProxyManagementSuite struct {
|
||||
suite.Suite
|
||||
|
||||
datacoord *mocks.MockDataCoordClient
|
||||
proxy *Proxy
|
||||
querycoord *mocks.MockQueryCoordClient
|
||||
datacoord *mocks.MockDataCoordClient
|
||||
proxy *Proxy
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) SetupTest() {
|
||||
s.datacoord = mocks.NewMockDataCoordClient(s.T())
|
||||
s.querycoord = mocks.NewMockQueryCoordClient(s.T())
|
||||
|
||||
s.proxy = &Proxy{
|
||||
dataCoord: s.datacoord,
|
||||
dataCoord: s.datacoord,
|
||||
queryCoord: s.querycoord,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -158,6 +165,527 @@ func (s *ProxyManagementSuite) TestResumeDatacoordGC() {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestListQueryNode() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(&querypb.ListQueryNodeResponse{
|
||||
Status: merr.Success(),
|
||||
NodeInfos: []*querypb.NodeInfo{
|
||||
{
|
||||
ID: 1,
|
||||
Address: "localhost",
|
||||
State: "Healthy",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ListQueryNode(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"nodeInfos":[{"ID":1,"address":"localhost","state":"Healthy"}]}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ListQueryNode(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(&querypb.ListQueryNodeResponse{
|
||||
Status: merr.Status(merr.ErrServiceNotReady),
|
||||
}, nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ListQueryNode(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestGetQueryNodeDistribution() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(&querypb.GetQueryNodeDistributionResponse{
|
||||
Status: merr.Success(),
|
||||
ID: 1,
|
||||
ChannelNames: []string{"channel-1"},
|
||||
SealedSegmentIDs: []int64{1, 2, 3},
|
||||
}, nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.GetQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"ID":1,"channel_names":["channel-1"],"sealed_segmentIDs":[1,2,3]}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
// test invalid request body
|
||||
req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, nil)
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.GetQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test miss requested param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.GetQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test rpc return error
|
||||
s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err = http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.GetQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.GetQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestSuspendQueryCoordBalance() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryCoordBalance(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryCoordBalance(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryCoordBalance(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestResumeQueryCoordBalance() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryCoordBalance(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryCoordBalance(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryCoordBalance(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestSuspendQueryNode() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryNode(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
// test invalid request body
|
||||
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, nil)
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryNode(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test miss requested param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryNode(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test rpc return error
|
||||
s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err = http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryNode(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.SuspendQueryNode(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestResumeQueryNode() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryNode(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
// test invalid request body
|
||||
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, nil)
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryNode(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test miss requested param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryNode(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test rpc return error
|
||||
s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err = http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryNode(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.ResumeQueryNode(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestTransferSegment() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1©_mode=false"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.TransferSegment(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
|
||||
// test use default param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.TransferSegment(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
// test invalid request body
|
||||
req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, nil)
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.TransferSegment(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test miss requested param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.TransferSegment(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test rpc return error
|
||||
s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.TransferSegment(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.TransferSegment(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestTransferChannel() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1©_mode=false"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.TransferChannel(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
|
||||
// test use default param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.TransferChannel(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
// test invalid request body
|
||||
req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, nil)
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.TransferChannel(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test miss requested param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.TransferChannel(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test rpc return error
|
||||
s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.TransferChannel(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.TransferChannel(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProxyManagementSuite) TestCheckQueryNodeDistribution() {
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.CheckQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(`{"msg": "OK"}`, recorder.Body.String())
|
||||
})
|
||||
|
||||
s.Run("return_error", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
// test invalid request body
|
||||
req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, nil)
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.CheckQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test miss requested param
|
||||
req, err = http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.CheckQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// test rpc return error
|
||||
s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
|
||||
req, err = http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder = httptest.NewRecorder()
|
||||
s.proxy.CheckQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
|
||||
s.Run("return_failure", func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
recorder := httptest.NewRecorder()
|
||||
s.proxy.CheckQueryNodeDistribution(recorder, req)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManagement(t *testing.T) {
|
||||
suite.Run(t, new(ProxyManagementSuite))
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/task"
|
||||
|
@ -57,8 +59,8 @@ var (
|
|||
)
|
||||
|
||||
type Balance interface {
|
||||
AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan
|
||||
AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan
|
||||
AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan
|
||||
AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan
|
||||
BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)
|
||||
}
|
||||
|
||||
|
@ -67,7 +69,15 @@ type RoundRobinBalancer struct {
|
|||
nodeManager *session.NodeManager
|
||||
}
|
||||
|
||||
func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
|
||||
func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
|
||||
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
|
||||
if !manualBalance {
|
||||
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
|
||||
info := b.nodeManager.Get(node)
|
||||
return info != nil && info.GetState() == session.NodeStateNormal
|
||||
})
|
||||
}
|
||||
|
||||
nodesInfo := b.getNodes(nodes)
|
||||
if len(nodesInfo) == 0 {
|
||||
return nil
|
||||
|
@ -90,7 +100,14 @@ func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.
|
|||
return ret
|
||||
}
|
||||
|
||||
func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
|
||||
func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
|
||||
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
|
||||
if !manualBalance {
|
||||
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
|
||||
info := b.nodeManager.Get(node)
|
||||
return info != nil && info.GetState() == session.NodeStateNormal
|
||||
})
|
||||
}
|
||||
nodesInfo := b.getNodes(nodes)
|
||||
if len(nodesInfo) == 0 {
|
||||
return nil
|
||||
|
|
|
@ -97,7 +97,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
|
|||
suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
|
||||
}
|
||||
}
|
||||
plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs)
|
||||
plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs, false)
|
||||
suite.ElementsMatch(c.expectPlans, plans)
|
||||
})
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
|
|||
suite.mockScheduler.EXPECT().GetNodeChannelDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
|
||||
}
|
||||
}
|
||||
plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs)
|
||||
plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs, false)
|
||||
suite.ElementsMatch(c.expectPlans, plans)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -20,13 +20,13 @@ func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter {
|
|||
return &MockBalancer_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AssignChannel provides a mock function with given fields: channels, nodes
|
||||
func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
|
||||
ret := _m.Called(channels, nodes)
|
||||
// AssignChannel provides a mock function with given fields: channels, nodes, manualBalance
|
||||
func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
|
||||
ret := _m.Called(channels, nodes, manualBalance)
|
||||
|
||||
var r0 []ChannelAssignPlan
|
||||
if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64) []ChannelAssignPlan); ok {
|
||||
r0 = rf(channels, nodes)
|
||||
if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok {
|
||||
r0 = rf(channels, nodes, manualBalance)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]ChannelAssignPlan)
|
||||
|
@ -44,13 +44,14 @@ type MockBalancer_AssignChannel_Call struct {
|
|||
// AssignChannel is a helper method to define mock.On call
|
||||
// - channels []*meta.DmChannel
|
||||
// - nodes []int64
|
||||
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call {
|
||||
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)}
|
||||
// - manualBalance bool
|
||||
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call {
|
||||
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes, manualBalance)}
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64)) *MockBalancer_AssignChannel_Call {
|
||||
func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]*meta.DmChannel), args[1].([]int64))
|
||||
run(args[0].([]*meta.DmChannel), args[1].([]int64), args[2].(bool))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -60,18 +61,18 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call {
|
||||
func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// AssignSegment provides a mock function with given fields: collectionID, segments, nodes
|
||||
func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
|
||||
ret := _m.Called(collectionID, segments, nodes)
|
||||
// AssignSegment provides a mock function with given fields: collectionID, segments, nodes, manualBalance
|
||||
func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
|
||||
ret := _m.Called(collectionID, segments, nodes, manualBalance)
|
||||
|
||||
var r0 []SegmentAssignPlan
|
||||
if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64) []SegmentAssignPlan); ok {
|
||||
r0 = rf(collectionID, segments, nodes)
|
||||
if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok {
|
||||
r0 = rf(collectionID, segments, nodes, manualBalance)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]SegmentAssignPlan)
|
||||
|
@ -90,13 +91,14 @@ type MockBalancer_AssignSegment_Call struct {
|
|||
// - collectionID int64
|
||||
// - segments []*meta.Segment
|
||||
// - nodes []int64
|
||||
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call {
|
||||
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)}
|
||||
// - manualBalance bool
|
||||
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call {
|
||||
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes, manualBalance)}
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call {
|
||||
func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64))
|
||||
run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64), args[3].(bool))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -106,7 +108,7 @@ func (_c *MockBalancer_AssignSegment_Call) Return(_a0 []SegmentAssignPlan) *Mock
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
|
||||
func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -41,7 +41,15 @@ type RowCountBasedBalancer struct {
|
|||
|
||||
// AssignSegment, when row count based balancer assign segments, it will assign segment to node with least global row count.
|
||||
// try to make every query node has same row count.
|
||||
func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
|
||||
func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
|
||||
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
|
||||
if !manualBalance {
|
||||
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
|
||||
info := b.nodeManager.Get(node)
|
||||
return info != nil && info.GetState() == session.NodeStateNormal
|
||||
})
|
||||
}
|
||||
|
||||
nodeItems := b.convertToNodeItemsBySegment(nodes)
|
||||
if len(nodeItems) == 0 {
|
||||
return nil
|
||||
|
@ -75,7 +83,15 @@ func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*me
|
|||
|
||||
// AssignSegment, when row count based balancer assign segments, it will assign channel to node with least global channel count.
|
||||
// try to make every query node has channel count
|
||||
func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
|
||||
func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
|
||||
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
|
||||
if !manualBalance {
|
||||
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
|
||||
info := b.nodeManager.Get(node)
|
||||
return info != nil && info.GetState() == session.NodeStateNormal
|
||||
})
|
||||
}
|
||||
|
||||
nodeItems := b.convertToNodeItemsByChannel(nodes)
|
||||
nodeItems = lo.Shuffle(nodeItems)
|
||||
if len(nodeItems) == 0 {
|
||||
|
@ -215,7 +231,7 @@ func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, on
|
|||
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
|
||||
segment.GetLevel() != datapb.SegmentLevel_L0
|
||||
})
|
||||
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes)
|
||||
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes, false)
|
||||
for i := range plans {
|
||||
plans[i].From = nodeID
|
||||
plans[i].Replica = replica
|
||||
|
@ -283,7 +299,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
|
|||
return nil
|
||||
}
|
||||
|
||||
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, nodesWithLessRow)
|
||||
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, nodesWithLessRow, false)
|
||||
for i := range segmentPlans {
|
||||
segmentPlans[i].From = segmentPlans[i].Segment.Node
|
||||
segmentPlans[i].Replica = replica
|
||||
|
@ -296,7 +312,7 @@ func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, on
|
|||
channelPlans := make([]ChannelAssignPlan, 0)
|
||||
for _, nodeID := range offlineNodes {
|
||||
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
|
||||
plans := b.AssignChannel(dmChannels, onlineNodes)
|
||||
plans := b.AssignChannel(dmChannels, onlineNodes, false)
|
||||
for i := range plans {
|
||||
plans[i].From = nodeID
|
||||
plans[i].Replica = replica
|
||||
|
@ -334,7 +350,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
|
|||
return nil
|
||||
}
|
||||
|
||||
channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel)
|
||||
channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false)
|
||||
for i := range channelPlans {
|
||||
channelPlans[i].From = channelPlans[i].Channel.Node
|
||||
channelPlans[i].Replica = replica
|
||||
|
|
|
@ -136,12 +136,67 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
|
|||
nodeInfo.SetState(c.states[i])
|
||||
suite.balancer.nodeManager.Add(nodeInfo)
|
||||
}
|
||||
plans := balancer.AssignSegment(0, c.assignments, c.nodes)
|
||||
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
|
||||
assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, plans)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RowCountBasedBalancerTestSuite) TestSuspendNode() {
|
||||
cases := []struct {
|
||||
name string
|
||||
distributions map[int64][]*meta.Segment
|
||||
assignments []*meta.Segment
|
||||
nodes []int64
|
||||
segmentCnts []int
|
||||
states []session.State
|
||||
expectPlans []SegmentAssignPlan
|
||||
}{
|
||||
{
|
||||
name: "test suspend node",
|
||||
distributions: map[int64][]*meta.Segment{
|
||||
2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}},
|
||||
3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}},
|
||||
},
|
||||
assignments: []*meta.Segment{
|
||||
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}},
|
||||
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}},
|
||||
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}},
|
||||
},
|
||||
nodes: []int64{1, 2, 3, 4},
|
||||
states: []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend},
|
||||
segmentCnts: []int{0, 1, 1, 0},
|
||||
expectPlans: []SegmentAssignPlan{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
suite.Run(c.name, func() {
|
||||
// I do not find a better way to do the setup and teardown work for subtests yet.
|
||||
// If you do, please replace with it.
|
||||
suite.SetupSuite()
|
||||
defer suite.TearDownTest()
|
||||
balancer := suite.balancer
|
||||
for node, s := range c.distributions {
|
||||
balancer.dist.SegmentDistManager.Update(node, s...)
|
||||
}
|
||||
for i := range c.nodes {
|
||||
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: c.nodes[i],
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
})
|
||||
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
|
||||
nodeInfo.SetState(c.states[i])
|
||||
suite.balancer.nodeManager.Add(nodeInfo)
|
||||
}
|
||||
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
|
||||
// all node has been suspend, so no node to assign segment
|
||||
suite.ElementsMatch(c.expectPlans, plans)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
|
||||
cases := []struct {
|
||||
name string
|
||||
|
@ -888,7 +943,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
|
|||
NumOfGrowingRows: 50,
|
||||
}
|
||||
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
|
||||
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions))
|
||||
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
|
||||
for _, p := range plans {
|
||||
suite.Equal(int64(2), p.To)
|
||||
}
|
||||
|
|
|
@ -50,7 +50,15 @@ func NewScoreBasedBalancer(scheduler task.Scheduler,
|
|||
}
|
||||
|
||||
// AssignSegment got a segment list, and try to assign each segment to node's with lowest score
|
||||
func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
|
||||
func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
|
||||
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
|
||||
if !manualBalance {
|
||||
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
|
||||
info := b.nodeManager.Get(node)
|
||||
return info != nil && info.GetState() == session.NodeStateNormal
|
||||
})
|
||||
}
|
||||
|
||||
// calculate each node's score
|
||||
nodeItems := b.convertToNodeItems(collectionID, nodes)
|
||||
if len(nodeItems) == 0 {
|
||||
|
@ -87,7 +95,8 @@ func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.
|
|||
sourceNode := nodeItemsMap[s.Node]
|
||||
// if segment's node exist, which means this segment comes from balancer. we should consider the benefit
|
||||
// if the segment reassignment doesn't got enough benefit, we should skip this reassignment
|
||||
if sourceNode != nil && !b.hasEnoughBenefit(sourceNode, targetNode, priorityChange) {
|
||||
// notice: we should skip benefit check for manual balance
|
||||
if !manualBalance && sourceNode != nil && !b.hasEnoughBenefit(sourceNode, targetNode, priorityChange) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -249,7 +258,7 @@ func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlin
|
|||
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
|
||||
segment.GetLevel() != datapb.SegmentLevel_L0
|
||||
})
|
||||
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes)
|
||||
plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes, false)
|
||||
for i := range plans {
|
||||
plans[i].From = nodeID
|
||||
plans[i].Replica = replica
|
||||
|
@ -313,7 +322,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [
|
|||
return nil
|
||||
}
|
||||
|
||||
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, onlineNodes)
|
||||
segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, onlineNodes, false)
|
||||
for i := range segmentPlans {
|
||||
segmentPlans[i].From = segmentPlans[i].Segment.Node
|
||||
segmentPlans[i].Replica = replica
|
||||
|
|
|
@ -232,13 +232,68 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() {
|
|||
suite.balancer.nodeManager.Add(nodeInfo)
|
||||
}
|
||||
for i := range c.collectionIDs {
|
||||
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes)
|
||||
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false)
|
||||
assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans[i], plans)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ScoreBasedBalancerTestSuite) TestSuspendNode() {
|
||||
cases := []struct {
|
||||
name string
|
||||
distributions map[int64][]*meta.Segment
|
||||
assignments []*meta.Segment
|
||||
nodes []int64
|
||||
segmentCnts []int
|
||||
states []session.State
|
||||
expectPlans []SegmentAssignPlan
|
||||
}{
|
||||
{
|
||||
name: "test suspend node",
|
||||
distributions: map[int64][]*meta.Segment{
|
||||
2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}},
|
||||
3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}},
|
||||
},
|
||||
assignments: []*meta.Segment{
|
||||
{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}},
|
||||
{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}},
|
||||
{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}},
|
||||
},
|
||||
nodes: []int64{1, 2, 3, 4},
|
||||
states: []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend},
|
||||
segmentCnts: []int{0, 1, 1, 0},
|
||||
expectPlans: []SegmentAssignPlan{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
suite.Run(c.name, func() {
|
||||
// I do not find a better way to do the setup and teardown work for subtests yet.
|
||||
// If you do, please replace with it.
|
||||
suite.SetupSuite()
|
||||
defer suite.TearDownTest()
|
||||
balancer := suite.balancer
|
||||
for node, s := range c.distributions {
|
||||
balancer.dist.SegmentDistManager.Update(node, s...)
|
||||
}
|
||||
for i := range c.nodes {
|
||||
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: c.nodes[i],
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
})
|
||||
nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
|
||||
nodeInfo.SetState(c.states[i])
|
||||
suite.balancer.nodeManager.Add(nodeInfo)
|
||||
}
|
||||
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
|
||||
// all node has been suspend, so no node to assign segment
|
||||
suite.ElementsMatch(c.expectPlans, plans)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
|
||||
suite.SetupSuite()
|
||||
defer suite.TearDownTest()
|
||||
|
@ -279,7 +334,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
|
|||
NumOfGrowingRows: 50,
|
||||
}
|
||||
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
|
||||
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions))
|
||||
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
|
||||
for _, p := range plans {
|
||||
suite.Equal(int64(2), p.To)
|
||||
}
|
||||
|
|
|
@ -222,7 +222,7 @@ func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []*
|
|||
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool {
|
||||
return !outboundNodes.Contain(node)
|
||||
})
|
||||
plans := c.balancer.AssignChannel(channels, availableNodes)
|
||||
plans := c.balancer.AssignChannel(channels, availableNodes, false)
|
||||
for i := range plans {
|
||||
plans[i].Replica = replica
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ func (suite *ChannelCheckerTestSuite) setNodeAvailable(nodes ...int64) {
|
|||
|
||||
func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance {
|
||||
balancer := balance.NewMockBalancer(suite.T())
|
||||
balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64) []balance.ChannelAssignPlan {
|
||||
balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan {
|
||||
plans := make([]balance.ChannelAssignPlan, 0, len(channels))
|
||||
for i, c := range channels {
|
||||
plan := balance.ChannelAssignPlan{
|
||||
|
|
|
@ -134,11 +134,11 @@ func (suite *CheckerControllerSuite) TestBasic() {
|
|||
|
||||
assignSegCounter := atomic.NewInt32(0)
|
||||
assingChanCounter := atomic.NewInt32(0)
|
||||
suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64) []balance.SegmentAssignPlan {
|
||||
suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan {
|
||||
assignSegCounter.Inc()
|
||||
return nil
|
||||
})
|
||||
suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64) []balance.ChannelAssignPlan {
|
||||
suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan {
|
||||
assingChanCounter.Inc()
|
||||
return nil
|
||||
})
|
||||
|
|
|
@ -400,7 +400,7 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments []
|
|||
SegmentInfo: s,
|
||||
}
|
||||
})
|
||||
shardPlans := c.balancer.AssignSegment(replica.CollectionID, segmentInfos, availableNodes)
|
||||
shardPlans := c.balancer.AssignSegment(replica.CollectionID, segmentInfos, availableNodes, false)
|
||||
for i := range shardPlans {
|
||||
shardPlans[i].Replica = replica
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() {
|
|||
|
||||
func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance {
|
||||
balancer := balance.NewMockBalancer(suite.T())
|
||||
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan {
|
||||
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan {
|
||||
plans := make([]balance.SegmentAssignPlan, 0, len(segments))
|
||||
for i, s := range segments {
|
||||
plan := balance.SegmentAssignPlan{
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -87,78 +88,61 @@ func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentIn
|
|||
return lo.Values(infos)
|
||||
}
|
||||
|
||||
// parseBalanceRequest parses the load balance request,
|
||||
// returns the collection, replica, and segments
|
||||
func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRequest, replica *meta.Replica) error {
|
||||
srcNode := req.GetSourceNodeIDs()[0]
|
||||
dstNodeSet := typeutil.NewUniqueSet(req.GetDstNodeIDs()...)
|
||||
if dstNodeSet.Len() == 0 {
|
||||
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
|
||||
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool {
|
||||
stop, err := s.nodeMgr.IsStoppingNode(node)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return !outboundNodes.Contain(node) && !stop
|
||||
})
|
||||
dstNodeSet.Insert(availableNodes...)
|
||||
// generate balance segment task and submit to scheduler
|
||||
// if sync is true, this func call will wait task to finish, until reach the segment task timeout
|
||||
// if copyMode is true, this func call will generate a load segment task, instead a balance segment task
|
||||
func (s *Server) balanceSegments(ctx context.Context,
|
||||
collectionID int64,
|
||||
replica *meta.Replica,
|
||||
srcNode int64,
|
||||
dstNodes []int64,
|
||||
segments []*meta.Segment,
|
||||
sync bool,
|
||||
copyMode bool,
|
||||
) error {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("srcNode", srcNode))
|
||||
plans := s.balancer.AssignSegment(collectionID, segments, dstNodes, true)
|
||||
for i := range plans {
|
||||
plans[i].From = srcNode
|
||||
plans[i].Replica = replica
|
||||
}
|
||||
dstNodeSet.Remove(srcNode)
|
||||
|
||||
toBalance := typeutil.NewSet[*meta.Segment]()
|
||||
// Only balance segments in targets
|
||||
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode))
|
||||
segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {
|
||||
return s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
|
||||
})
|
||||
allSegments := make(map[int64]*meta.Segment)
|
||||
for _, segment := range segments {
|
||||
allSegments[segment.GetID()] = segment
|
||||
}
|
||||
|
||||
if len(req.GetSealedSegmentIDs()) == 0 {
|
||||
toBalance.Insert(segments...)
|
||||
} else {
|
||||
for _, segmentID := range req.GetSealedSegmentIDs() {
|
||||
segment, ok := allSegments[segmentID]
|
||||
if !ok {
|
||||
return fmt.Errorf("segment %d not found in source node %d", segmentID, srcNode)
|
||||
}
|
||||
toBalance.Insert(segment)
|
||||
}
|
||||
}
|
||||
|
||||
log := log.With(
|
||||
zap.Int64("collectionID", req.GetCollectionID()),
|
||||
zap.Int64("srcNodeID", srcNode),
|
||||
zap.Int64s("destNodeIDs", dstNodeSet.Collect()),
|
||||
)
|
||||
plans := s.balancer.AssignSegment(req.GetCollectionID(), toBalance.Collect(), dstNodeSet.Collect())
|
||||
tasks := make([]task.Task, 0, len(plans))
|
||||
for _, plan := range plans {
|
||||
log.Info("manually balance segment...",
|
||||
zap.Int64("destNodeID", plan.To),
|
||||
zap.Int64("replica", plan.Replica.ID),
|
||||
zap.String("channel", plan.Segment.InsertChannel),
|
||||
zap.Int64("from", plan.From),
|
||||
zap.Int64("to", plan.To),
|
||||
zap.Int64("segmentID", plan.Segment.GetID()),
|
||||
)
|
||||
task, err := task.NewSegmentTask(ctx,
|
||||
actions := make([]task.Action, 0)
|
||||
loadAction := task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical)
|
||||
actions = append(actions, loadAction)
|
||||
if !copyMode {
|
||||
// if in copy mode, the release action will be skip
|
||||
releaseAction := task.NewSegmentActionWithScope(plan.From, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical)
|
||||
actions = append(actions, releaseAction)
|
||||
}
|
||||
|
||||
task, err := task.NewSegmentTask(s.ctx,
|
||||
Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond),
|
||||
task.WrapIDSource(req.GetBase().GetMsgID()),
|
||||
req.GetCollectionID(),
|
||||
replica,
|
||||
task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical),
|
||||
task.NewSegmentActionWithScope(srcNode, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical),
|
||||
utils.ManualBalance,
|
||||
collectionID,
|
||||
plan.Replica,
|
||||
actions...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn("create segment task for balance failed",
|
||||
zap.Int64("collection", req.GetCollectionID()),
|
||||
zap.Int64("replica", replica.GetID()),
|
||||
zap.Int64("replica", plan.Replica.ID),
|
||||
zap.String("channel", plan.Segment.InsertChannel),
|
||||
zap.Int64("from", srcNode),
|
||||
zap.Int64("from", plan.From),
|
||||
zap.Int64("to", plan.To),
|
||||
zap.Int64("segmentID", plan.Segment.GetID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
task.SetReason("manual balance")
|
||||
err = s.taskScheduler.Add(task)
|
||||
if err != nil {
|
||||
task.Cancel(err)
|
||||
|
@ -166,7 +150,92 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
|
|||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
return task.Wait(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), tasks...)
|
||||
|
||||
if sync {
|
||||
err := task.Wait(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), tasks...)
|
||||
if err != nil {
|
||||
msg := "failed to wait all balance task finished"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return errors.Wrap(err, msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generate balance channel task and submit to scheduler
|
||||
// if sync is true, this func call will wait task to finish, until reach the channel task timeout
|
||||
// if copyMode is true, this func call will generate a load channel task, instead a balance channel task
|
||||
func (s *Server) balanceChannels(ctx context.Context,
|
||||
collectionID int64,
|
||||
replica *meta.Replica,
|
||||
srcNode int64,
|
||||
dstNodes []int64,
|
||||
channels []*meta.DmChannel,
|
||||
sync bool,
|
||||
copyMode bool,
|
||||
) error {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID))
|
||||
|
||||
plans := s.balancer.AssignChannel(channels, dstNodes, true)
|
||||
for i := range plans {
|
||||
plans[i].From = srcNode
|
||||
plans[i].Replica = replica
|
||||
}
|
||||
|
||||
tasks := make([]task.Task, 0, len(plans))
|
||||
for _, plan := range plans {
|
||||
log.Info("manually balance channel...",
|
||||
zap.Int64("replica", plan.Replica.ID),
|
||||
zap.String("channel", plan.Channel.GetChannelName()),
|
||||
zap.Int64("from", plan.From),
|
||||
zap.Int64("to", plan.To),
|
||||
)
|
||||
|
||||
actions := make([]task.Action, 0)
|
||||
loadAction := task.NewChannelAction(plan.To, task.ActionTypeGrow, plan.Channel.GetChannelName())
|
||||
actions = append(actions, loadAction)
|
||||
if !copyMode {
|
||||
// if in copy mode, the release action will be skip
|
||||
releaseAction := task.NewChannelAction(plan.From, task.ActionTypeReduce, plan.Channel.GetChannelName())
|
||||
actions = append(actions, releaseAction)
|
||||
}
|
||||
task, err := task.NewChannelTask(s.ctx,
|
||||
Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond),
|
||||
utils.ManualBalance,
|
||||
collectionID,
|
||||
plan.Replica,
|
||||
actions...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn("create channel task for balance failed",
|
||||
zap.Int64("replica", plan.Replica.ID),
|
||||
zap.String("channel", plan.Channel.GetChannelName()),
|
||||
zap.Int64("from", plan.From),
|
||||
zap.Int64("to", plan.To),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
task.SetReason("manual balance")
|
||||
err = s.taskScheduler.Add(task)
|
||||
if err != nil {
|
||||
task.Cancel(err)
|
||||
return err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
if sync {
|
||||
err := task.Wait(ctx, Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), tasks...)
|
||||
if err != nil {
|
||||
msg := "failed to wait all balance task finished"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return errors.Wrap(err, msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(dragondriver): add more detail metrics
|
||||
|
|
|
@ -232,7 +232,7 @@ func (m *ReplicaManager) GetByCollection(collectionID typeutil.UniqueID) []*Repl
|
|||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
replicas := make([]*Replica, 0, 3)
|
||||
replicas := make([]*Replica, 0)
|
||||
for _, replica := range m.replicas {
|
||||
if replica.CollectionID == collectionID {
|
||||
replicas = append(replicas, replica)
|
||||
|
@ -255,6 +255,20 @@ func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.Un
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *ReplicaManager) GetByNode(nodeID typeutil.UniqueID) []*Replica {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
replicas := make([]*Replica, 0)
|
||||
for _, replica := range m.replicas {
|
||||
if replica.nodes.Contain(nodeID) {
|
||||
replicas = append(replicas, replica)
|
||||
}
|
||||
}
|
||||
|
||||
return replicas
|
||||
}
|
||||
|
||||
func (m *ReplicaManager) GetByCollectionAndRG(collectionID int64, rgName string) []*Replica {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
|
|
@ -102,6 +102,7 @@ func (suite *ReplicaManagerSuite) TestGet() {
|
|||
for _, replica := range replicas {
|
||||
suite.Equal(collection, replica.GetCollectionID())
|
||||
suite.Equal(replica, mgr.Get(replica.GetID()))
|
||||
suite.Equal(len(replica.Replica.GetNodes()), replica.Len())
|
||||
suite.Equal(replica.Replica.GetNodes(), replica.GetNodes())
|
||||
replicaNodes[replica.GetID()] = replica.Replica.GetNodes()
|
||||
nodes = append(nodes, replica.Replica.Nodes...)
|
||||
|
@ -117,6 +118,24 @@ func (suite *ReplicaManagerSuite) TestGet() {
|
|||
}
|
||||
}
|
||||
|
||||
func (suite *ReplicaManagerSuite) TestGetByNode() {
|
||||
mgr := suite.mgr
|
||||
|
||||
randomNodeID := int64(11111)
|
||||
testReplica1, err := mgr.spawn(3002, DefaultResourceGroupName)
|
||||
suite.NoError(err)
|
||||
testReplica1.AddNode(randomNodeID)
|
||||
|
||||
testReplica2, err := mgr.spawn(3002, DefaultResourceGroupName)
|
||||
suite.NoError(err)
|
||||
testReplica2.AddNode(randomNodeID)
|
||||
|
||||
mgr.Put(testReplica1, testReplica2)
|
||||
|
||||
replicas := mgr.GetByNode(randomNodeID)
|
||||
suite.Len(replicas, 2)
|
||||
}
|
||||
|
||||
func (suite *ReplicaManagerSuite) TestRecover() {
|
||||
mgr := suite.mgr
|
||||
|
||||
|
|
|
@ -0,0 +1,888 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querycoordv2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/kv"
|
||||
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
||||
"github.com/milvus-io/milvus/internal/metastore"
|
||||
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/dist"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/job"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/params"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/task"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type OpsServiceSuite struct {
|
||||
suite.Suite
|
||||
|
||||
// Dependencies
|
||||
kv kv.MetaKv
|
||||
store metastore.QueryCoordCatalog
|
||||
dist *meta.DistributionManager
|
||||
meta *meta.Meta
|
||||
targetMgr *meta.TargetManager
|
||||
broker *meta.MockBroker
|
||||
targetObserver *observers.TargetObserver
|
||||
cluster *session.MockCluster
|
||||
nodeMgr *session.NodeManager
|
||||
jobScheduler *job.Scheduler
|
||||
taskScheduler *task.MockScheduler
|
||||
balancer balance.Balance
|
||||
|
||||
distMgr *meta.DistributionManager
|
||||
distController *dist.MockController
|
||||
checkerController *checkers.CheckerController
|
||||
|
||||
// Test object
|
||||
server *Server
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) SetupTest() {
|
||||
config := params.GenerateEtcdConfig()
|
||||
cli, err := etcd.GetEtcdClient(
|
||||
config.UseEmbedEtcd.GetAsBool(),
|
||||
config.EtcdUseSSL.GetAsBool(),
|
||||
config.Endpoints.GetAsStrings(),
|
||||
config.EtcdTLSCert.GetValue(),
|
||||
config.EtcdTLSKey.GetValue(),
|
||||
config.EtcdTLSCACert.GetValue(),
|
||||
config.EtcdTLSMinVersion.GetValue())
|
||||
suite.Require().NoError(err)
|
||||
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
|
||||
|
||||
suite.store = querycoord.NewCatalog(suite.kv)
|
||||
suite.dist = meta.NewDistributionManager()
|
||||
suite.nodeMgr = session.NewNodeManager()
|
||||
suite.meta = meta.NewMeta(params.RandomIncrementIDAllocator(), suite.store, suite.nodeMgr)
|
||||
suite.broker = meta.NewMockBroker(suite.T())
|
||||
suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
|
||||
suite.targetObserver = observers.NewTargetObserver(
|
||||
suite.meta,
|
||||
suite.targetMgr,
|
||||
suite.dist,
|
||||
suite.broker,
|
||||
suite.cluster,
|
||||
)
|
||||
suite.cluster = session.NewMockCluster(suite.T())
|
||||
suite.jobScheduler = job.NewScheduler()
|
||||
suite.taskScheduler = task.NewMockScheduler(suite.T())
|
||||
suite.jobScheduler.Start()
|
||||
suite.balancer = balance.NewScoreBasedBalancer(
|
||||
suite.taskScheduler,
|
||||
suite.nodeMgr,
|
||||
suite.dist,
|
||||
suite.meta,
|
||||
suite.targetMgr,
|
||||
)
|
||||
meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()
|
||||
suite.distMgr = meta.NewDistributionManager()
|
||||
suite.distController = dist.NewMockController(suite.T())
|
||||
|
||||
suite.checkerController = checkers.NewCheckerController(suite.meta, suite.distMgr,
|
||||
suite.targetMgr, suite.balancer, suite.nodeMgr, suite.taskScheduler, suite.broker)
|
||||
|
||||
suite.server = &Server{
|
||||
kv: suite.kv,
|
||||
store: suite.store,
|
||||
session: sessionutil.NewSessionWithEtcd(context.Background(), Params.EtcdCfg.MetaRootPath.GetValue(), cli),
|
||||
metricsCacheManager: metricsinfo.NewMetricsCacheManager(),
|
||||
dist: suite.dist,
|
||||
meta: suite.meta,
|
||||
targetMgr: suite.targetMgr,
|
||||
broker: suite.broker,
|
||||
targetObserver: suite.targetObserver,
|
||||
nodeMgr: suite.nodeMgr,
|
||||
cluster: suite.cluster,
|
||||
jobScheduler: suite.jobScheduler,
|
||||
taskScheduler: suite.taskScheduler,
|
||||
balancer: suite.balancer,
|
||||
distController: suite.distController,
|
||||
ctx: context.Background(),
|
||||
checkerController: suite.checkerController,
|
||||
}
|
||||
suite.server.collectionObserver = observers.NewCollectionObserver(
|
||||
suite.server.dist,
|
||||
suite.server.meta,
|
||||
suite.server.targetMgr,
|
||||
suite.targetObserver,
|
||||
&checkers.CheckerController{},
|
||||
)
|
||||
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestActiveCheckers() {
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := suite.server.ListCheckers(ctx, &querypb.ListCheckersRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp.Status))
|
||||
|
||||
resp1, err := suite.server.DeactivateChecker(ctx, &querypb.DeactivateCheckerRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp1))
|
||||
|
||||
resp2, err := suite.server.ActivateChecker(ctx, &querypb.ActivateCheckerRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp2))
|
||||
|
||||
// test active success
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err = suite.server.ListCheckers(ctx, &querypb.ListCheckersRequest{})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp.Status))
|
||||
suite.Len(resp.GetCheckerInfos(), 5)
|
||||
|
||||
resp4, err := suite.server.DeactivateChecker(ctx, &querypb.DeactivateCheckerRequest{
|
||||
CheckerID: int32(utils.ChannelChecker),
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp4))
|
||||
suite.False(suite.checkerController.IsActive(utils.ChannelChecker))
|
||||
|
||||
resp5, err := suite.server.ActivateChecker(ctx, &querypb.ActivateCheckerRequest{
|
||||
CheckerID: int32(utils.ChannelChecker),
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp5))
|
||||
suite.True(suite.checkerController.IsActive(utils.ChannelChecker))
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestListQueryNode() {
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := suite.server.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
|
||||
suite.NoError(err)
|
||||
suite.Equal(0, len(resp.GetNodeInfos()))
|
||||
suite.False(merr.Ok(resp.Status))
|
||||
// test server healthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 111,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
resp, err = suite.server.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, len(resp.GetNodeInfos()))
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestGetQueryNodeDistribution() {
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp.Status))
|
||||
|
||||
// test node not found
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err = suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp.Status))
|
||||
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
// test success
|
||||
channels := []*meta.DmChannel{
|
||||
{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: 1,
|
||||
ChannelName: "channel1",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: 1,
|
||||
ChannelName: "channel2",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
}
|
||||
|
||||
segments := []*meta.Segment{
|
||||
{
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 1,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "channel1",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
{
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 2,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "channel2",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
}
|
||||
suite.dist.ChannelDistManager.Update(1, channels...)
|
||||
suite.dist.SegmentDistManager.Update(1, segments...)
|
||||
|
||||
resp, err = suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: 1,
|
||||
})
|
||||
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp.Status))
|
||||
suite.Equal(2, len(resp.GetChannelNames()))
|
||||
suite.Equal(2, len(resp.GetSealedSegmentIDs()))
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestCheckQueryNodeDistribution() {
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
// test node not found
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
|
||||
TargetNodeID: 2,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
|
||||
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
|
||||
SourceNodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
// test success
|
||||
channels := []*meta.DmChannel{
|
||||
{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: 1,
|
||||
ChannelName: "channel1",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: 1,
|
||||
ChannelName: "channel2",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
}
|
||||
|
||||
segments := []*meta.Segment{
|
||||
{
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 1,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "channel1",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
{
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 2,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "channel2",
|
||||
},
|
||||
Node: 1,
|
||||
},
|
||||
}
|
||||
suite.dist.ChannelDistManager.Update(1, channels...)
|
||||
suite.dist.SegmentDistManager.Update(1, segments...)
|
||||
|
||||
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
|
||||
SourceNodeID: 1,
|
||||
TargetNodeID: 2,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.dist.ChannelDistManager.Update(2, channels...)
|
||||
suite.dist.SegmentDistManager.Update(2, segments...)
|
||||
resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
|
||||
SourceNodeID: 1,
|
||||
TargetNodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestSuspendAndResumeBalance() {
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := suite.server.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
resp, err = suite.server.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
// test suspend success
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err = suite.server.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
suite.False(suite.checkerController.IsActive(utils.BalanceChecker))
|
||||
|
||||
resp, err = suite.server.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
suite.True(suite.checkerController.IsActive(utils.BalanceChecker))
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestSuspendAndResumeNode() {
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
// test node not found
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err = suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{
|
||||
NodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{
|
||||
NodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
// test success
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err = suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{
|
||||
NodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
node := suite.nodeMgr.Get(1)
|
||||
suite.Equal(session.NodeStateSuspend, node.GetState())
|
||||
|
||||
resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{
|
||||
NodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
node = suite.nodeMgr.Get(1)
|
||||
suite.Equal(session.NodeStateNormal, node.GetState())
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestTransferSegment() {
|
||||
ctx := context.Background()
|
||||
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
resp, err := suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
// test source node not healthy
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
|
||||
collectionID := int64(1)
|
||||
partitionID := int64(1)
|
||||
replicaID := int64(1)
|
||||
nodes := []int64{1, 2, 3, 4}
|
||||
replica := utils.CreateTestReplica(replicaID, collectionID, nodes)
|
||||
suite.meta.ReplicaManager.Put(replica)
|
||||
collection := utils.CreateTestCollection(collectionID, 1)
|
||||
partition := utils.CreateTestPartition(partitionID, collectionID)
|
||||
suite.meta.PutCollection(collection, partition)
|
||||
segmentIDs := []int64{1, 2, 3, 4}
|
||||
channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"}
|
||||
|
||||
// test target node not healthy
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 2,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
|
||||
// test segment not exist in node
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
SegmentID: segmentIDs[0],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
segments := []*datapb.SegmentInfo{
|
||||
{
|
||||
ID: segmentIDs[0],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[0],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
{
|
||||
ID: segmentIDs[1],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[1],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
{
|
||||
ID: segmentIDs[2],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[2],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
{
|
||||
ID: segmentIDs[3],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[3],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
}
|
||||
|
||||
channels := []*datapb.VchannelInfo{
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[0],
|
||||
},
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[1],
|
||||
},
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[2],
|
||||
},
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[3],
|
||||
},
|
||||
}
|
||||
segmentInfos := lo.Map(segments, func(segment *datapb.SegmentInfo, _ int) *meta.Segment {
|
||||
return &meta.Segment{
|
||||
SegmentInfo: segment,
|
||||
Node: nodes[0],
|
||||
}
|
||||
})
|
||||
chanenlInfos := lo.Map(channels, func(channel *datapb.VchannelInfo, _ int) *meta.DmChannel {
|
||||
return &meta.DmChannel{
|
||||
VchannelInfo: channel,
|
||||
Node: nodes[0],
|
||||
}
|
||||
})
|
||||
suite.dist.SegmentDistManager.Update(1, segmentInfos[0])
|
||||
|
||||
// test segment not exist in current target, expect no task assign and success
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
SegmentID: segmentIDs[0],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
|
||||
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil)
|
||||
suite.targetMgr.UpdateCollectionNextTarget(1)
|
||||
suite.targetMgr.UpdateCollectionCurrentTarget(1)
|
||||
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
|
||||
suite.dist.ChannelDistManager.Update(1, chanenlInfos...)
|
||||
|
||||
for _, node := range nodes {
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: node,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node)
|
||||
}
|
||||
|
||||
// test transfer segment success, expect generate 1 balance segment task
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 2)
|
||||
suite.Equal(actions[0].Node(), int64(2))
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
SegmentID: segmentIDs[0],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
|
||||
// test copy mode, expect generate 1 load segment task
|
||||
suite.taskScheduler.ExpectedCalls = nil
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 1)
|
||||
suite.Equal(actions[0].Node(), int64(2))
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
SegmentID: segmentIDs[0],
|
||||
CopyMode: true,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
|
||||
// test transfer all segments, expect generate 4 load segment task
|
||||
suite.taskScheduler.ExpectedCalls = nil
|
||||
counter := atomic.NewInt64(0)
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 2)
|
||||
suite.Equal(actions[0].Node(), int64(2))
|
||||
counter.Inc()
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
TransferAll: true,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
suite.Equal(counter.Load(), int64(4))
|
||||
|
||||
// test transfer all segment to all nodes, expect generate 4 load segment task
|
||||
suite.taskScheduler.ExpectedCalls = nil
|
||||
counter = atomic.NewInt64(0)
|
||||
nodeSet := typeutil.NewUniqueSet()
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 2)
|
||||
nodeSet.Insert(actions[0].Node())
|
||||
counter.Inc()
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TransferAll: true,
|
||||
ToAllNodes: true,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
suite.Equal(counter.Load(), int64(4))
|
||||
suite.Len(nodeSet.Collect(), 3)
|
||||
}
|
||||
|
||||
func (suite *OpsServiceSuite) TestTransferChannel() {
|
||||
ctx := context.Background()
|
||||
|
||||
// test server unhealthy
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
resp, err := suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
// test source node not healthy
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: 1,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
|
||||
collectionID := int64(1)
|
||||
partitionID := int64(1)
|
||||
replicaID := int64(1)
|
||||
nodes := []int64{1, 2, 3, 4}
|
||||
replica := utils.CreateTestReplica(replicaID, collectionID, nodes)
|
||||
suite.meta.ReplicaManager.Put(replica)
|
||||
collection := utils.CreateTestCollection(collectionID, 1)
|
||||
partition := utils.CreateTestPartition(partitionID, collectionID)
|
||||
suite.meta.PutCollection(collection, partition)
|
||||
segmentIDs := []int64{1, 2, 3, 4}
|
||||
channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"}
|
||||
|
||||
// test target node not healthy
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: 2,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
|
||||
segments := []*datapb.SegmentInfo{
|
||||
{
|
||||
ID: segmentIDs[0],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[0],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
{
|
||||
ID: segmentIDs[1],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[1],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
{
|
||||
ID: segmentIDs[2],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[2],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
{
|
||||
ID: segmentIDs[3],
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
InsertChannel: channelNames[3],
|
||||
NumOfRows: 1,
|
||||
},
|
||||
}
|
||||
|
||||
channels := []*datapb.VchannelInfo{
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[0],
|
||||
},
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[1],
|
||||
},
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[2],
|
||||
},
|
||||
{
|
||||
CollectionID: collectionID,
|
||||
ChannelName: channelNames[3],
|
||||
},
|
||||
}
|
||||
segmentInfos := lo.Map(segments, func(segment *datapb.SegmentInfo, _ int) *meta.Segment {
|
||||
return &meta.Segment{
|
||||
SegmentInfo: segment,
|
||||
Node: nodes[0],
|
||||
}
|
||||
})
|
||||
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
|
||||
chanenlInfos := lo.Map(channels, func(channel *datapb.VchannelInfo, _ int) *meta.DmChannel {
|
||||
return &meta.DmChannel{
|
||||
VchannelInfo: channel,
|
||||
Node: nodes[0],
|
||||
}
|
||||
})
|
||||
|
||||
// test channel not exist in node
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
ChannelName: channelNames[0],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.False(merr.Ok(resp))
|
||||
|
||||
suite.dist.ChannelDistManager.Update(1, chanenlInfos[0])
|
||||
|
||||
// test channel not exist in current target, expect no task assign and success
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
ChannelName: channelNames[0],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
|
||||
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil)
|
||||
suite.targetMgr.UpdateCollectionNextTarget(1)
|
||||
suite.targetMgr.UpdateCollectionCurrentTarget(1)
|
||||
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
|
||||
suite.dist.ChannelDistManager.Update(1, chanenlInfos...)
|
||||
|
||||
for _, node := range nodes {
|
||||
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
|
||||
NodeID: node,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node)
|
||||
}
|
||||
|
||||
// test transfer channel success, expect generate 1 balance channel task
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 2)
|
||||
suite.Equal(actions[0].Node(), int64(2))
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
ChannelName: channelNames[0],
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
|
||||
// test copy mode, expect generate 1 load segment task
|
||||
suite.taskScheduler.ExpectedCalls = nil
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 1)
|
||||
suite.Equal(actions[0].Node(), int64(2))
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
ChannelName: channelNames[0],
|
||||
CopyMode: true,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
|
||||
// test transfer all channels, expect generate 4 load segment task
|
||||
suite.taskScheduler.ExpectedCalls = nil
|
||||
counter := atomic.NewInt64(0)
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 2)
|
||||
suite.Equal(actions[0].Node(), int64(2))
|
||||
counter.Inc()
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TargetNodeID: nodes[1],
|
||||
TransferAll: true,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
suite.Equal(counter.Load(), int64(4))
|
||||
|
||||
// test transfer all channels to all nodes, expect generate 4 load segment task
|
||||
suite.taskScheduler.ExpectedCalls = nil
|
||||
counter = atomic.NewInt64(0)
|
||||
nodeSet := typeutil.NewUniqueSet()
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
|
||||
actions := t.Actions()
|
||||
suite.Equal(len(actions), 2)
|
||||
nodeSet.Insert(actions[0].Node())
|
||||
counter.Inc()
|
||||
return nil
|
||||
})
|
||||
resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: nodes[0],
|
||||
TransferAll: true,
|
||||
ToAllNodes: true,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.True(merr.Ok(resp))
|
||||
suite.Equal(counter.Load(), int64(4))
|
||||
suite.Len(nodeSet.Collect(), 3)
|
||||
}
|
||||
|
||||
func TestOpsService(t *testing.T) {
|
||||
suite.Run(t, new(OpsServiceSuite))
|
||||
}
|
|
@ -19,10 +19,14 @@ package querycoordv2
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
@ -93,3 +97,368 @@ func (s *Server) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC
|
|||
}
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// return all available node list, for each node, return it's (nodeID, ip_address)
|
||||
func (s *Server) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) {
|
||||
log := log.Ctx(ctx)
|
||||
log.Info("ListQueryNode request received")
|
||||
|
||||
errMsg := "failed to list querynode state"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return &querypb.ListQueryNodeResponse{
|
||||
Status: merr.Status(errors.Wrap(err, errMsg)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
nodes := lo.Map(s.nodeMgr.GetAll(), func(nodeInfo *session.NodeInfo, _ int) *querypb.NodeInfo {
|
||||
return &querypb.NodeInfo{
|
||||
ID: nodeInfo.ID(),
|
||||
Address: nodeInfo.Addr(),
|
||||
State: nodeInfo.GetState().String(),
|
||||
}
|
||||
})
|
||||
|
||||
return &querypb.ListQueryNodeResponse{
|
||||
Status: merr.Success(),
|
||||
NodeInfos: nodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// return query node's data distribution, for given nodeID, return it's (channel_name_list, sealed_segment_list)
|
||||
func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
log := log.Ctx(ctx).With(zap.Int64("nodeID", req.GetNodeID()))
|
||||
log.Info("GetQueryNodeDistribution request received")
|
||||
|
||||
errMsg := "failed to get query node distribution"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return &querypb.GetQueryNodeDistributionResponse{
|
||||
Status: merr.Status(errors.Wrap(err, errMsg)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if s.nodeMgr.Get(req.GetNodeID()) == nil {
|
||||
err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg)
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return &querypb.GetQueryNodeDistributionResponse{
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetNodeID()))
|
||||
channels := s.dist.ChannelDistManager.GetByNode(req.NodeID)
|
||||
return &querypb.GetQueryNodeDistributionResponse{
|
||||
Status: merr.Success(),
|
||||
ChannelNames: lo.Map(channels, func(c *meta.DmChannel, _ int) string { return c.GetChannelName() }),
|
||||
SealedSegmentIDs: lo.Map(segments, func(s *meta.Segment, _ int) int64 { return s.GetID() }),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// suspend background balance for all query node, include stopping balance and auto balance
|
||||
func (s *Server) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
log.Info("SuspendBalance request received")
|
||||
|
||||
errMsg := "failed to suspend balance for all querynode"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
err := s.checkerController.Deactivate(utils.BalanceChecker)
|
||||
if err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// resume background balance for all query node, include stopping balance and auto balance
|
||||
func (s *Server) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
log.Info("ResumeBalance request received")
|
||||
|
||||
errMsg := "failed to resume balance for all querynode"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
err := s.checkerController.Activate(utils.BalanceChecker)
|
||||
if err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// suspend node from resource operation, for given node, suspend load_segment/sub_channel operations
|
||||
func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
log.Info("SuspendNode request received", zap.Int64("nodeID", req.GetNodeID()))
|
||||
|
||||
errMsg := "failed to suspend query node"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
if s.nodeMgr.Get(req.GetNodeID()) == nil {
|
||||
err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg)
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
err := s.nodeMgr.Suspend(req.GetNodeID())
|
||||
if err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// resume node from resource operation, for given node, resume load_segment/sub_channel operations
|
||||
func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
log.Info("ResumeNode request received", zap.Int64("nodeID", req.GetNodeID()))
|
||||
|
||||
errMsg := "failed to resume query node"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, errMsg)), nil
|
||||
}
|
||||
|
||||
if s.nodeMgr.Get(req.GetNodeID()) == nil {
|
||||
err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg)
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
err := s.nodeMgr.Resume(req.GetNodeID())
|
||||
if err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, errMsg)), nil
|
||||
}
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// transfer segment from source to target,
|
||||
// if no segment_id specified, default to transfer all segment on the source node.
|
||||
// if no target_nodeId specified, default to move segment to all other nodes
|
||||
func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
log.Info("TransferSegment request received",
|
||||
zap.Int64("source", req.GetSourceNodeID()),
|
||||
zap.Int64("dest", req.GetTargetNodeID()),
|
||||
zap.Int64("segment", req.GetSegmentID()))
|
||||
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
msg := "failed to load balance"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, msg)), nil
|
||||
}
|
||||
|
||||
// check whether srcNode is healthy
|
||||
srcNode := req.GetSourceNodeID()
|
||||
if err := s.isStoppingNode(srcNode); err != nil {
|
||||
err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID())
|
||||
for _, replica := range replicas {
|
||||
// when no dst node specified, default to use all other nodes in same
|
||||
dstNodeSet := typeutil.NewUniqueSet()
|
||||
if req.GetToAllNodes() {
|
||||
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
|
||||
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) })
|
||||
dstNodeSet.Insert(availableNodes...)
|
||||
} else {
|
||||
// check whether dstNode is healthy
|
||||
if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil {
|
||||
err := merr.WrapErrNodeNotAvailable(srcNode, "the target node is invalid")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
dstNodeSet.Insert(req.GetTargetNodeID())
|
||||
}
|
||||
dstNodeSet.Remove(srcNode)
|
||||
|
||||
// check sealed segment list
|
||||
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode))
|
||||
|
||||
toBalance := typeutil.NewSet[*meta.Segment]()
|
||||
if req.GetTransferAll() {
|
||||
toBalance.Insert(segments...)
|
||||
} else {
|
||||
// check whether sealed segment exist
|
||||
segment, ok := lo.Find(segments, func(s *meta.Segment) bool { return s.GetID() == req.GetSegmentID() })
|
||||
if !ok {
|
||||
err := merr.WrapErrSegmentNotFound(req.GetSegmentID(), "segment not found in source node")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
|
||||
if !existInTarget {
|
||||
log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", req.GetSegmentID()))
|
||||
} else {
|
||||
toBalance.Insert(segment)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.balanceSegments(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), false, req.GetCopyMode())
|
||||
if err != nil {
|
||||
msg := "failed to balance segments"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, msg)), nil
|
||||
}
|
||||
}
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// transfer channel from source to target,
|
||||
// if no channel_name specified, default to transfer all channel on the source node.
|
||||
// if no target_nodeId specified, default to move channel to all other nodes
|
||||
func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
log.Info("TransferChannel request received",
|
||||
zap.Int64("source", req.GetSourceNodeID()),
|
||||
zap.Int64("dest", req.GetTargetNodeID()),
|
||||
zap.String("channel", req.GetChannelName()))
|
||||
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
msg := "failed to load balance"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, msg)), nil
|
||||
}
|
||||
|
||||
// check whether srcNode is healthy
|
||||
srcNode := req.GetSourceNodeID()
|
||||
if err := s.isStoppingNode(srcNode); err != nil {
|
||||
err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID())
|
||||
for _, replica := range replicas {
|
||||
// when no dst node specified, default to use all other nodes in same
|
||||
dstNodeSet := typeutil.NewUniqueSet()
|
||||
if req.GetToAllNodes() {
|
||||
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
|
||||
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) })
|
||||
dstNodeSet.Insert(availableNodes...)
|
||||
} else {
|
||||
// check whether dstNode is healthy
|
||||
if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil {
|
||||
err := merr.WrapErrNodeNotAvailable(srcNode, "the target node is invalid")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
dstNodeSet.Insert(req.GetTargetNodeID())
|
||||
}
|
||||
dstNodeSet.Remove(srcNode)
|
||||
|
||||
// check sealed segment list
|
||||
channels := s.dist.ChannelDistManager.GetByCollectionAndNode(replica.CollectionID, srcNode)
|
||||
toBalance := typeutil.NewSet[*meta.DmChannel]()
|
||||
if req.GetTransferAll() {
|
||||
toBalance.Insert(channels...)
|
||||
} else {
|
||||
// check whether sealed segment exist
|
||||
channel, ok := lo.Find(channels, func(ch *meta.DmChannel) bool { return ch.GetChannelName() == req.GetChannelName() })
|
||||
if !ok {
|
||||
err := merr.WrapErrChannelNotFound(req.GetChannelName(), "channel not found in source node")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
existInTarget := s.targetMgr.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil
|
||||
if !existInTarget {
|
||||
log.Info("channel doesn't exist in current target, skip it", zap.String("channelName", channel.GetChannelName()))
|
||||
} else {
|
||||
toBalance.Insert(channel)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.balanceChannels(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), false, req.GetCopyMode())
|
||||
if err != nil {
|
||||
msg := "failed to balance channels"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, msg)), nil
|
||||
}
|
||||
}
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) {
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
log.Info("CheckQueryNodeDistribution request received",
|
||||
zap.Int64("source", req.GetSourceNodeID()),
|
||||
zap.Int64("dest", req.GetTargetNodeID()))
|
||||
|
||||
errMsg := "failed to check query node distribution"
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
sourceNode := s.nodeMgr.Get(req.GetSourceNodeID())
|
||||
if sourceNode == nil {
|
||||
err := merr.WrapErrNodeNotFound(req.GetSourceNodeID(), "source node not found")
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
targetNode := s.nodeMgr.Get(req.GetTargetNodeID())
|
||||
if targetNode == nil {
|
||||
err := merr.WrapErrNodeNotFound(req.GetTargetNodeID(), "target node not found")
|
||||
log.Warn(errMsg, zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
// check channel list
|
||||
channelOnSrc := s.dist.ChannelDistManager.GetByNode(req.GetSourceNodeID())
|
||||
channelOnDst := s.dist.ChannelDistManager.GetByNode(req.GetTargetNodeID())
|
||||
channelDstMap := lo.SliceToMap(channelOnDst, func(ch *meta.DmChannel) (string, *meta.DmChannel) {
|
||||
return ch.GetChannelName(), ch
|
||||
})
|
||||
for _, ch := range channelOnSrc {
|
||||
if _, ok := channelDstMap[ch.GetChannelName()]; !ok {
|
||||
return merr.Status(merr.WrapErrChannelLack(ch.GetChannelName())), nil
|
||||
}
|
||||
}
|
||||
channelSrcMap := lo.SliceToMap(channelOnSrc, func(ch *meta.DmChannel) (string, *meta.DmChannel) {
|
||||
return ch.GetChannelName(), ch
|
||||
})
|
||||
for _, ch := range channelOnDst {
|
||||
if _, ok := channelSrcMap[ch.GetChannelName()]; !ok {
|
||||
return merr.Status(merr.WrapErrChannelLack(ch.GetChannelName())), nil
|
||||
}
|
||||
}
|
||||
|
||||
// check segment list
|
||||
segmentOnSrc := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetSourceNodeID()))
|
||||
segmentOnDst := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetTargetNodeID()))
|
||||
segmentDstMap := lo.SliceToMap(segmentOnDst, func(s *meta.Segment) (int64, *meta.Segment) {
|
||||
return s.GetID(), s
|
||||
})
|
||||
for _, s := range segmentOnSrc {
|
||||
if _, ok := segmentDstMap[s.GetID()]; !ok {
|
||||
return merr.Status(merr.WrapErrSegmentLack(s.GetID())), nil
|
||||
}
|
||||
}
|
||||
segmentSrcMap := lo.SliceToMap(segmentOnSrc, func(s *meta.Segment) (int64, *meta.Segment) {
|
||||
return s.GetID(), s
|
||||
})
|
||||
for _, s := range segmentOnDst {
|
||||
if _, ok := segmentSrcMap[s.GetID()]; !ok {
|
||||
return merr.Status(merr.WrapErrSegmentLack(s.GetID())), nil
|
||||
}
|
||||
}
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
|
|
@ -682,24 +682,67 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
|
|||
return merr.Status(errors.Wrap(err,
|
||||
fmt.Sprintf("can't balance, because the source node[%d] is invalid", srcNode))), nil
|
||||
}
|
||||
for _, dstNode := range req.GetDstNodeIDs() {
|
||||
if !replica.Contains(dstNode) {
|
||||
err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica")
|
||||
log.Warn("failed to balance to the destination node", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
|
||||
// when no dst node specified, default to use all other nodes in same
|
||||
dstNodeSet := typeutil.NewUniqueSet()
|
||||
if len(req.GetDstNodeIDs()) == 0 {
|
||||
outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica)
|
||||
availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) })
|
||||
dstNodeSet.Insert(availableNodes...)
|
||||
} else {
|
||||
for _, dstNode := range req.GetDstNodeIDs() {
|
||||
if !replica.Contains(dstNode) {
|
||||
err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica")
|
||||
log.Warn("failed to balance to the destination node", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
dstNodeSet.Insert(dstNode)
|
||||
}
|
||||
}
|
||||
|
||||
// check whether dstNode is healthy
|
||||
for dstNode := range dstNodeSet {
|
||||
if err := s.isStoppingNode(dstNode); err != nil {
|
||||
return merr.Status(errors.Wrap(err,
|
||||
fmt.Sprintf("can't balance, because the destination node[%d] is invalid", dstNode))), nil
|
||||
}
|
||||
}
|
||||
|
||||
err := s.balanceSegments(ctx, req, replica)
|
||||
// check sealed segment list
|
||||
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(req.GetCollectionID()), meta.WithNodeID(srcNode))
|
||||
segmentsMap := lo.SliceToMap(segments, func(s *meta.Segment) (int64, *meta.Segment) {
|
||||
return s.GetID(), s
|
||||
})
|
||||
|
||||
toBalance := typeutil.NewSet[*meta.Segment]()
|
||||
if len(req.GetSealedSegmentIDs()) == 0 {
|
||||
toBalance.Insert(segments...)
|
||||
} else {
|
||||
// check whether sealed segment exist
|
||||
for _, segmentID := range req.GetSealedSegmentIDs() {
|
||||
segment, ok := segmentsMap[segmentID]
|
||||
if !ok {
|
||||
err := merr.WrapErrSegmentNotFound(segmentID, "segment not found in source node")
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
// Only balance segments in targets
|
||||
existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
|
||||
if !existInTarget {
|
||||
log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", segmentID))
|
||||
continue
|
||||
}
|
||||
toBalance.Insert(segment)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.balanceSegments(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), true, false)
|
||||
if err != nil {
|
||||
msg := "failed to balance segments"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
return merr.Status(errors.Wrap(err, msg)), nil
|
||||
}
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1174,6 +1174,51 @@ func (suite *ServiceSuite) TestLoadBalance() {
|
|||
suite.Equal(resp.GetCode(), merr.Code(merr.ErrServiceNotReady))
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) TestLoadBalanceWithNoDstNode() {
|
||||
suite.loadAll()
|
||||
ctx := context.Background()
|
||||
server := suite.server
|
||||
|
||||
// Test get balance first segment
|
||||
for _, collection := range suite.collections {
|
||||
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
|
||||
nodes := replicas[0].GetNodes()
|
||||
srcNode := nodes[0]
|
||||
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
|
||||
suite.updateSegmentDist(collection, srcNode)
|
||||
segments := suite.getAllSegments(collection)
|
||||
req := &querypb.LoadBalanceRequest{
|
||||
CollectionID: collection,
|
||||
SourceNodeIDs: []int64{srcNode},
|
||||
SealedSegmentIDs: segments,
|
||||
}
|
||||
suite.taskScheduler.ExpectedCalls = make([]*mock.Call, 0)
|
||||
suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(task task.Task) {
|
||||
actions := task.Actions()
|
||||
suite.Len(actions, 2)
|
||||
growAction, reduceAction := actions[0], actions[1]
|
||||
suite.Contains(nodes, growAction.Node())
|
||||
suite.Equal(srcNode, reduceAction.Node())
|
||||
task.Cancel(nil)
|
||||
}).Return(nil)
|
||||
resp, err := server.LoadBalance(ctx, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
suite.taskScheduler.AssertExpectations(suite.T())
|
||||
}
|
||||
|
||||
// Test when server is not healthy
|
||||
server.UpdateStateCode(commonpb.StateCode_Initializing)
|
||||
req := &querypb.LoadBalanceRequest{
|
||||
CollectionID: suite.collections[0],
|
||||
SourceNodeIDs: []int64{1},
|
||||
DstNodeIDs: []int64{100 + 1},
|
||||
}
|
||||
resp, err := server.LoadBalance(ctx, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(resp.GetCode(), merr.Code(merr.ErrServiceNotReady))
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
|
||||
suite.loadAll()
|
||||
ctx := context.Background()
|
||||
|
|
|
@ -23,8 +23,11 @@ import (
|
|||
|
||||
"github.com/blang/semver/v4"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
|
@ -33,6 +36,9 @@ type Manager interface {
|
|||
Remove(nodeID int64)
|
||||
Get(nodeID int64) *NodeInfo
|
||||
GetAll() []*NodeInfo
|
||||
|
||||
Suspend(nodeID int64) error
|
||||
Resume(nodeID int64) error
|
||||
}
|
||||
|
||||
type NodeManager struct {
|
||||
|
@ -62,6 +68,42 @@ func (m *NodeManager) Stopping(nodeID int64) {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *NodeManager) Suspend(nodeID int64) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
nodeInfo, ok := m.nodes[nodeID]
|
||||
if !ok {
|
||||
return merr.WrapErrNodeNotFound(nodeID)
|
||||
}
|
||||
switch nodeInfo.GetState() {
|
||||
case NodeStateNormal:
|
||||
nodeInfo.SetState(NodeStateSuspend)
|
||||
return nil
|
||||
default:
|
||||
log.Warn("failed to suspend query node", zap.Int64("nodeID", nodeID), zap.String("state", nodeInfo.GetState().String()))
|
||||
return merr.WrapErrNodeStateUnexpected(nodeID, nodeInfo.GetState().String(), "failed to suspend a query node")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *NodeManager) Resume(nodeID int64) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
nodeInfo, ok := m.nodes[nodeID]
|
||||
if !ok {
|
||||
return merr.WrapErrNodeNotFound(nodeID)
|
||||
}
|
||||
|
||||
switch nodeInfo.GetState() {
|
||||
case NodeStateSuspend:
|
||||
nodeInfo.SetState(NodeStateNormal)
|
||||
return nil
|
||||
|
||||
default:
|
||||
log.Warn("failed to resume query node", zap.Int64("nodeID", nodeID), zap.String("state", nodeInfo.GetState().String()))
|
||||
return merr.WrapErrNodeStateUnexpected(nodeID, nodeInfo.GetState().String(), "failed to resume query node")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *NodeManager) IsStoppingNode(nodeID int64) (bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
@ -98,8 +140,9 @@ func NewNodeManager() *NodeManager {
|
|||
type State int
|
||||
|
||||
const (
|
||||
NodeStateNormal = iota
|
||||
NodeStateStopping
|
||||
NormalStateName = "Normal"
|
||||
StoppingStateName = "Stopping"
|
||||
SuspendStateName = "Suspend"
|
||||
)
|
||||
|
||||
type ImmutableNodeInfo struct {
|
||||
|
@ -109,6 +152,22 @@ type ImmutableNodeInfo struct {
|
|||
Version semver.Version
|
||||
}
|
||||
|
||||
const (
|
||||
NodeStateNormal State = iota
|
||||
NodeStateStopping
|
||||
NodeStateSuspend
|
||||
)
|
||||
|
||||
var stateNameMap = map[State]string{
|
||||
NodeStateNormal: NormalStateName,
|
||||
NodeStateStopping: StoppingStateName,
|
||||
NodeStateSuspend: SuspendStateName,
|
||||
}
|
||||
|
||||
func (s State) String() string {
|
||||
return stateNameMap[s]
|
||||
}
|
||||
|
||||
type NodeInfo struct {
|
||||
stats
|
||||
mu sync.RWMutex
|
||||
|
@ -161,6 +220,12 @@ func (n *NodeInfo) SetState(s State) {
|
|||
n.state = s
|
||||
}
|
||||
|
||||
func (n *NodeInfo) GetState() State {
|
||||
n.mu.RLock()
|
||||
defer n.mu.RUnlock()
|
||||
return n.state
|
||||
}
|
||||
|
||||
func (n *NodeInfo) UpdateStats(opts ...StatsOption) {
|
||||
n.mu.Lock()
|
||||
for _, opt := range opts {
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type NodeManagerSuite struct {
|
||||
suite.Suite
|
||||
|
||||
nodeManager *NodeManager
|
||||
}
|
||||
|
||||
func (s *NodeManagerSuite) SetupTest() {
|
||||
s.nodeManager = NewNodeManager()
|
||||
}
|
||||
|
||||
func (s *NodeManagerSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *NodeManagerSuite) TestNodeOperation() {
|
||||
s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{
|
||||
NodeID: 2,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{
|
||||
NodeID: 3,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
}))
|
||||
|
||||
s.NotNil(s.nodeManager.Get(1))
|
||||
s.Len(s.nodeManager.GetAll(), 3)
|
||||
s.nodeManager.Remove(1)
|
||||
s.Nil(s.nodeManager.Get(1))
|
||||
s.Len(s.nodeManager.GetAll(), 2)
|
||||
|
||||
s.nodeManager.Stopping(2)
|
||||
s.True(s.nodeManager.IsStoppingNode(2))
|
||||
err := s.nodeManager.Resume(2)
|
||||
s.ErrorIs(err, merr.ErrNodeStateUnexpected)
|
||||
s.True(s.nodeManager.IsStoppingNode(2))
|
||||
node := s.nodeManager.Get(2)
|
||||
node.SetState(NodeStateNormal)
|
||||
s.False(s.nodeManager.IsStoppingNode(2))
|
||||
|
||||
err = s.nodeManager.Resume(3)
|
||||
s.ErrorIs(err, merr.ErrNodeStateUnexpected)
|
||||
|
||||
s.nodeManager.Suspend(3)
|
||||
node = s.nodeManager.Get(3)
|
||||
s.NotNil(node)
|
||||
s.Equal(NodeStateSuspend, node.GetState())
|
||||
s.nodeManager.Resume(3)
|
||||
node = s.nodeManager.Get(3)
|
||||
s.NotNil(node)
|
||||
s.Equal(NodeStateNormal, node.GetState())
|
||||
}
|
||||
|
||||
func (s *NodeManagerSuite) TestNodeInfo() {
|
||||
node := NewNodeInfo(ImmutableNodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "localhost",
|
||||
Hostname: "localhost",
|
||||
})
|
||||
s.Equal(int64(1), node.ID())
|
||||
s.Equal("localhost", node.Addr())
|
||||
node.setChannelCnt(1)
|
||||
node.setSegmentCnt(1)
|
||||
s.Equal(1, node.ChannelCnt())
|
||||
s.Equal(1, node.SegmentCnt())
|
||||
|
||||
node.UpdateStats(WithSegmentCnt(5))
|
||||
node.UpdateStats(WithChannelCnt(5))
|
||||
s.Equal(5, node.ChannelCnt())
|
||||
s.Equal(5, node.SegmentCnt())
|
||||
|
||||
node.SetLastHeartbeat(time.Now())
|
||||
s.NotNil(node.LastHeartbeat())
|
||||
}
|
||||
|
||||
func TestNodeManagerSuite(t *testing.T) {
|
||||
suite.Run(t, new(NodeManagerSuite))
|
||||
}
|
|
@ -27,6 +27,7 @@ const (
|
|||
BalanceCheckerName = "balance_checker"
|
||||
IndexCheckerName = "index_checker"
|
||||
LeaderCheckerName = "leader_checker"
|
||||
ManualBalanceName = "manual_balance"
|
||||
)
|
||||
|
||||
type CheckerType int32
|
||||
|
@ -37,6 +38,7 @@ const (
|
|||
BalanceChecker
|
||||
IndexChecker
|
||||
LeaderChecker
|
||||
ManualBalance
|
||||
)
|
||||
|
||||
var checkerNames = map[CheckerType]string{
|
||||
|
@ -45,6 +47,7 @@ var checkerNames = map[CheckerType]string{
|
|||
BalanceChecker: BalanceCheckerName,
|
||||
IndexChecker: IndexCheckerName,
|
||||
LeaderChecker: LeaderCheckerName,
|
||||
ManualBalance: ManualBalanceName,
|
||||
}
|
||||
|
||||
func (s CheckerType) String() string {
|
||||
|
|
|
@ -141,3 +141,39 @@ func (m *GrpcQueryCoordClient) ActivateChecker(ctx context.Context, in *querypb.
|
|||
func (m *GrpcQueryCoordClient) DeactivateChecker(ctx context.Context, in *querypb.DeactivateCheckerRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) {
|
||||
return &querypb.ListQueryNodeResponse{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) {
|
||||
return &querypb.GetQueryNodeDistributionResponse{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryCoordClient) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
|
|
@ -90,11 +90,12 @@ var (
|
|||
ErrDatabaseInvalidName = newMilvusError("invalid database name", 802, false)
|
||||
|
||||
// Node related
|
||||
ErrNodeNotFound = newMilvusError("node not found", 901, false)
|
||||
ErrNodeOffline = newMilvusError("node offline", 902, false)
|
||||
ErrNodeLack = newMilvusError("node lacks", 903, false)
|
||||
ErrNodeNotMatch = newMilvusError("node not match", 904, false)
|
||||
ErrNodeNotAvailable = newMilvusError("node not available", 905, false)
|
||||
ErrNodeNotFound = newMilvusError("node not found", 901, false)
|
||||
ErrNodeOffline = newMilvusError("node offline", 902, false)
|
||||
ErrNodeLack = newMilvusError("node lacks", 903, false)
|
||||
ErrNodeNotMatch = newMilvusError("node not match", 904, false)
|
||||
ErrNodeNotAvailable = newMilvusError("node not available", 905, false)
|
||||
ErrNodeStateUnexpected = newMilvusError("node state unexpected", 906, false)
|
||||
|
||||
// IO related
|
||||
ErrIoKeyNotFound = newMilvusError("key not found", 1000, false)
|
||||
|
|
|
@ -120,6 +120,7 @@ func (s *ErrSuite) TestWrap() {
|
|||
s.ErrorIs(WrapErrNodeNotFound(1, "failed to get node"), ErrNodeNotFound)
|
||||
s.ErrorIs(WrapErrNodeOffline(1, "failed to access node"), ErrNodeOffline)
|
||||
s.ErrorIs(WrapErrNodeLack(3, 1, "need more nodes"), ErrNodeLack)
|
||||
s.ErrorIs(WrapErrNodeStateUnexpected(1, "Stopping", "failed to suspend node"), ErrNodeStateUnexpected)
|
||||
|
||||
// IO related
|
||||
s.ErrorIs(WrapErrIoKeyNotFound("test_key", "failed to read"), ErrIoKeyNotFound)
|
||||
|
|
|
@ -731,6 +731,14 @@ func WrapErrNodeNotAvailable(id int64, msg ...string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func WrapErrNodeStateUnexpected(id int64, state string, msg ...string) error {
|
||||
err := wrapFields(ErrNodeStateUnexpected, value("node", id), value("state", state))
|
||||
if len(msg) > 0 {
|
||||
err = errors.Wrap(err, strings.Join(msg, "->"))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func WrapErrNodeNotMatch(expectedNodeID, actualNodeID int64, msg ...string) error {
|
||||
err := wrapFields(ErrNodeNotMatch,
|
||||
value("expectedNodeID", expectedNodeID),
|
||||
|
|
|
@ -0,0 +1,364 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package rollingupgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
type ManualRollingUpgradeSuite struct {
|
||||
integration.MiniClusterSuite
|
||||
}
|
||||
|
||||
func (s *ManualRollingUpgradeSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
params := paramtable.Get()
|
||||
params.Save(params.QueryCoordCfg.BalanceCheckInterval.Key, "2000")
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
s.Require().NoError(s.SetupEmbedEtcd())
|
||||
}
|
||||
|
||||
func (s *ManualRollingUpgradeSuite) TearDownSuite() {
|
||||
params := paramtable.Get()
|
||||
params.Reset(params.QueryCoordCfg.BalanceCheckInterval.Key)
|
||||
|
||||
s.TearDownEmbedEtcd()
|
||||
}
|
||||
|
||||
func (s *ManualRollingUpgradeSuite) TestTransfer() {
|
||||
c := s.Cluster
|
||||
ctx, cancel := context.WithCancel(c.GetContext())
|
||||
defer cancel()
|
||||
|
||||
prefix := "TestTransfer"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
dim := 128
|
||||
rowNum := 3000
|
||||
insertRound := 5
|
||||
|
||||
schema := integration.ConstructSchema(collectionName, dim, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
|
||||
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: 2,
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
err = merr.Error(createCollectionStatus)
|
||||
if err != nil {
|
||||
log.Warn("createCollectionStatus fail reason", zap.Error(err))
|
||||
}
|
||||
|
||||
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
|
||||
showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(showCollectionsResp.GetStatus()))
|
||||
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
|
||||
|
||||
// insert data, and flush generate segment
|
||||
pkFieldData := integration.NewInt64FieldData(integration.Int64Field, rowNum)
|
||||
hashKeys := integration.GenerateHashKeys(rowNum)
|
||||
for i := range lo.Range(insertRound) {
|
||||
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldsData: []*schemapb.FieldData{pkFieldData, pkFieldData},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.False(merr.Ok(insertResult.GetStatus()))
|
||||
log.Info("Insert succeed", zap.Int("round", i+1))
|
||||
resp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
|
||||
DbName: dbName,
|
||||
CollectionNames: []string{collectionName},
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp.GetStatus()))
|
||||
}
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
if err != nil {
|
||||
log.Warn("createIndexStatus fail reason", zap.Error(err))
|
||||
}
|
||||
|
||||
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
|
||||
log.Info("Create index done")
|
||||
|
||||
// load
|
||||
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(loadStatus)
|
||||
if err != nil {
|
||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||
}
|
||||
s.WaitForLoad(ctx, collectionName)
|
||||
log.Info("Load collection done")
|
||||
|
||||
// suspend balance
|
||||
resp2, err := s.Cluster.QueryCoord.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp2))
|
||||
|
||||
// get origin qn
|
||||
qnServer1 := s.Cluster.QueryNode
|
||||
qn1 := qnServer1.GetQueryNode()
|
||||
|
||||
// add new querynode
|
||||
qnSever2 := s.Cluster.AddQueryNode()
|
||||
time.Sleep(5 * time.Second)
|
||||
qn2 := qnSever2.GetQueryNode()
|
||||
|
||||
// expected 2 querynode found
|
||||
resp3, err := s.Cluster.QueryCoordClient.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
|
||||
s.NoError(err)
|
||||
s.Len(resp3.GetNodeInfos(), 2)
|
||||
|
||||
// due to balance has been suspended, qn2 won't have any segment/channel distribution
|
||||
resp4, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: qn2.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Len(resp4.GetChannelNames(), 0)
|
||||
s.Len(resp4.GetSealedSegmentIDs(), 0)
|
||||
|
||||
resp5, err := s.Cluster.QueryCoordClient.TransferChannel(ctx, &querypb.TransferChannelRequest{
|
||||
SourceNodeID: qn1.GetNodeID(),
|
||||
TargetNodeID: qn2.GetNodeID(),
|
||||
TransferAll: true,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp5))
|
||||
|
||||
// wait for transfer channel done
|
||||
s.Eventually(func() bool {
|
||||
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: qn1.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
return len(resp.GetChannelNames()) == 0
|
||||
}, 10*time.Second, 1*time.Second)
|
||||
|
||||
// test transfer segment
|
||||
resp6, err := s.Cluster.QueryCoordClient.TransferSegment(ctx, &querypb.TransferSegmentRequest{
|
||||
SourceNodeID: qn1.GetNodeID(),
|
||||
TargetNodeID: qn2.GetNodeID(),
|
||||
TransferAll: true,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp6))
|
||||
|
||||
// wait for transfer segment done
|
||||
s.Eventually(func() bool {
|
||||
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: qn1.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
return len(resp.GetSealedSegmentIDs()) == 0
|
||||
}, 10*time.Second, 1*time.Second)
|
||||
|
||||
// resume balance, segment/channel will be balance to qn1
|
||||
resp7, err := s.Cluster.QueryCoord.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp7))
|
||||
|
||||
s.Eventually(func() bool {
|
||||
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: qn1.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
return len(resp.GetSealedSegmentIDs()) > 0 || len(resp.GetChannelNames()) > 0
|
||||
}, 10*time.Second, 1*time.Second)
|
||||
|
||||
log.Info("==================")
|
||||
log.Info("==================")
|
||||
log.Info("TestManualRollingUpgrade succeed")
|
||||
log.Info("==================")
|
||||
log.Info("==================")
|
||||
}
|
||||
|
||||
func (s *ManualRollingUpgradeSuite) TestSuspendNode() {
|
||||
c := s.Cluster
|
||||
ctx, cancel := context.WithCancel(c.GetContext())
|
||||
defer cancel()
|
||||
|
||||
prefix := "TestSuspendNode"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
dim := 128
|
||||
rowNum := 3000
|
||||
insertRound := 5
|
||||
|
||||
schema := integration.ConstructSchema(collectionName, dim, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
|
||||
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: 2,
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
err = merr.Error(createCollectionStatus)
|
||||
if err != nil {
|
||||
log.Warn("createCollectionStatus fail reason", zap.Error(err))
|
||||
}
|
||||
|
||||
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
|
||||
showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(showCollectionsResp.GetStatus()))
|
||||
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
|
||||
|
||||
// insert data, and flush generate segment
|
||||
pkFieldData := integration.NewInt64FieldData(integration.Int64Field, rowNum)
|
||||
hashKeys := integration.GenerateHashKeys(rowNum)
|
||||
for i := range lo.Range(insertRound) {
|
||||
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldsData: []*schemapb.FieldData{pkFieldData, pkFieldData},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.False(merr.Ok(insertResult.GetStatus()))
|
||||
log.Info("Insert succeed", zap.Int("round", i+1))
|
||||
resp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
|
||||
DbName: dbName,
|
||||
CollectionNames: []string{collectionName},
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp.GetStatus()))
|
||||
}
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
if err != nil {
|
||||
log.Warn("createIndexStatus fail reason", zap.Error(err))
|
||||
}
|
||||
|
||||
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
|
||||
log.Info("Create index done")
|
||||
|
||||
// add new querynode
|
||||
qnSever2 := s.Cluster.AddQueryNode()
|
||||
time.Sleep(5 * time.Second)
|
||||
qn2 := qnSever2.GetQueryNode()
|
||||
|
||||
// expected 2 querynode found
|
||||
resp3, err := s.Cluster.QueryCoordClient.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{})
|
||||
s.NoError(err)
|
||||
s.Len(resp3.GetNodeInfos(), 2)
|
||||
|
||||
// suspend Node
|
||||
resp2, err := s.Cluster.QueryCoord.SuspendNode(ctx, &querypb.SuspendNodeRequest{
|
||||
NodeID: qn2.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp2))
|
||||
|
||||
// load
|
||||
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(loadStatus)
|
||||
if err != nil {
|
||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||
}
|
||||
s.WaitForLoad(ctx, collectionName)
|
||||
log.Info("Load collection done")
|
||||
|
||||
// due to node has been suspended, no segment/channel will be loaded to this qn
|
||||
resp4, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: qn2.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Len(resp4.GetChannelNames(), 0)
|
||||
s.Len(resp4.GetSealedSegmentIDs(), 0)
|
||||
|
||||
// resume node, segment/channel will be balance to qn2
|
||||
resp5, err := s.Cluster.QueryCoord.ResumeNode(ctx, &querypb.ResumeNodeRequest{
|
||||
NodeID: qn2.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(merr.Ok(resp5))
|
||||
|
||||
s.Eventually(func() bool {
|
||||
resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{
|
||||
NodeID: qn2.GetNodeID(),
|
||||
})
|
||||
s.NoError(err)
|
||||
return len(resp.GetSealedSegmentIDs()) > 0 || len(resp.GetChannelNames()) > 0
|
||||
}, 10*time.Second, 1*time.Second)
|
||||
|
||||
log.Info("==================")
|
||||
log.Info("==================")
|
||||
log.Info("TestSuspendNode succeed")
|
||||
log.Info("==================")
|
||||
log.Info("==================")
|
||||
}
|
||||
|
||||
func TestManualRollingUpgrade(t *testing.T) {
|
||||
suite.Run(t, new(ManualRollingUpgradeSuite))
|
||||
}
|
Loading…
Reference in New Issue