// milvus/internal/querycoordv2/session/cluster_test.go
//
// 380 lines
// 12 KiB
// Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"net"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/mocks"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// bufSize is 1 MiB (1024 * 1024 bytes).
// NOTE(review): nothing in the visible part of this file references it —
// presumably a leftover buffer size for an in-memory listener; confirm it is
// still needed before keeping or removing it.
const bufSize = 1024 * 1024
// ClusterTestSuite exercises QueryCluster against real gRPC servers that are
// backed by mock QueryNode implementations.
type ClusterTestSuite struct {
suite.Suite
// svrs holds the gRPC servers started by setupServers; TearDownSuite stops them.
svrs []*grpc.Server
// listeners holds one TCP listener per server; setupCluster reads each
// listener's address to register the corresponding node.
listeners []net.Listener
// cluster is the QueryCluster under test, rebuilt by setupCluster before each test.
cluster *QueryCluster
// nodeManager is the node registry handed to the cluster, rebuilt together with it.
nodeManager *NodeManager
}
// SetupSuite initializes the parameter table, overrides
// grpc.client.maxMaxAttempts to "1" and then starts the mock gRPC servers
// shared by every test in the suite.
func (suite *ClusterTestSuite) SetupSuite() {
	paramtable.Init()

	params := paramtable.Get()
	params.Save("grpc.client.maxMaxAttempts", "1")

	suite.setupServers()
}
// TearDownSuite restores grpc.client.maxMaxAttempts to
// paramtable.DefaultMaxAttempts and gracefully stops every gRPC server that
// was recorded in suite.svrs.
func (suite *ClusterTestSuite) TearDownSuite() {
	defaultAttempts := strconv.FormatInt(paramtable.DefaultMaxAttempts, 10)
	paramtable.Get().Save("grpc.client.maxMaxAttempts", defaultAttempts)

	for i := range suite.svrs {
		suite.svrs[i].GracefulStop()
	}
}
// SetupTest rebuilds the node manager and the cluster under test by
// delegating to setupCluster.
func (suite *ClusterTestSuite) SetupTest() {
suite.setupCluster()
}
// TearDownTest stops the cluster that setupCluster created for the test.
func (suite *ClusterTestSuite) TearDownTest() {
suite.cluster.Stop()
}
// setupServers starts one gRPC server per mock QueryNode returned by
// createTestServers, each listening on an ephemeral TCP port, and records the
// servers and listeners on the suite. It returns only after a blocking client
// connection to every listener has succeeded.
func (suite *ClusterTestSuite) setupServers() {
	svrs := suite.createTestServers()
	for _, svr := range svrs {
		lis, err := net.Listen("tcp", ":0")
		// Abort right away: continuing with a nil listener would panic in
		// Serve and in the Addr call below.
		suite.Require().NoError(err)
		suite.listeners = append(suite.listeners, lis)

		s := grpc.NewServer()
		querypb.RegisterQueryNodeServer(s, svr)
		suite.svrs = append(suite.svrs, s)

		go func() {
			// Serve blocks for the whole lifetime of the server and returns
			// nil once Stop/GracefulStop is called, so it must not be polled
			// with a timeout or retried. Only a non-nil error is reported,
			// which keeps the clean-shutdown path from touching testing.T
			// after the suite has finished.
			if err := s.Serve(lis); err != nil {
				suite.T().Errorf("grpc server on %s stopped serving unexpectedly: %v", lis.Addr(), err)
			}
		}()
	}

	// check server ready to serve
	for _, lis := range suite.listeners {
		// Bound the blocking dial so a server that never comes up fails the
		// suite instead of hanging it.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials()))
		cancel()
		// conn is nil when dialing failed; stop before calling Close on it.
		suite.Require().NoError(err)
		suite.NoError(conn.Close())
	}
}
// setupCluster creates a fresh node manager, registers one node per listener
// (node ID = listener index, address = listener address, hostname
// "localhost") and builds the cluster under test on top of it.
func (suite *ClusterTestSuite) setupCluster() {
	manager := NewNodeManager()
	suite.nodeManager = manager

	for idx := range suite.listeners {
		info := ImmutableNodeInfo{
			NodeID:   int64(idx),
			Address:  suite.listeners[idx].Addr().String(),
			Hostname: "localhost",
		}
		manager.Add(NewNodeInfo(info))
	}

	suite.cluster = NewCluster(manager, DefaultQueryNodeCreator)
}
// createTestServers returns the two mock QueryNode servers used by the suite:
// index 0 is the default (succeeding) server, index 1 the one whose RPCs
// always report a failure status.
func (suite *ClusterTestSuite) createTestServers() []querypb.QueryNodeServer {
	return []querypb.QueryNodeServer{
		suite.createDefaultMockServer(),
		suite.createFailedMockServer(),
	}
}
// createDefaultMockServer builds a mock QueryNode server on which every
// registered RPC optionally matches any context plus a request of the
// expected type and answers with a success status. GetComponentStates
// additionally reports the Healthy state code.
func (suite *ClusterTestSuite) createDefaultMockServer() querypb.QueryNodeServer {
	var (
		ok     = merr.Success()
		server = mocks.NewMockQueryNodeServer(suite.T())
	)
	// Matcher for the context argument, shared by every expectation below.
	anyCtx := mock.Anything

	// TODO: register more mock methods

	// RPCs that answer with a bare status.
	server.EXPECT().LoadSegments(anyCtx, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).Maybe().Return(ok, nil)
	server.EXPECT().WatchDmChannels(anyCtx, mock.AnythingOfType("*querypb.WatchDmChannelsRequest")).Maybe().Return(ok, nil)
	server.EXPECT().UnsubDmChannel(anyCtx, mock.AnythingOfType("*querypb.UnsubDmChannelRequest")).Maybe().Return(ok, nil)
	server.EXPECT().ReleaseSegments(anyCtx, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).Maybe().Return(ok, nil)
	server.EXPECT().LoadPartitions(anyCtx, mock.AnythingOfType("*querypb.LoadPartitionsRequest")).Maybe().Return(ok, nil)
	server.EXPECT().ReleasePartitions(anyCtx, mock.AnythingOfType("*querypb.ReleasePartitionsRequest")).Maybe().Return(ok, nil)

	// RPCs that wrap the status in a response message.
	distResp := &querypb.GetDataDistributionResponse{Status: ok}
	server.EXPECT().GetDataDistribution(anyCtx, mock.AnythingOfType("*querypb.GetDataDistributionRequest")).Maybe().Return(distResp, nil)

	metricsResp := &milvuspb.GetMetricsResponse{Status: ok}
	server.EXPECT().GetMetrics(anyCtx, mock.AnythingOfType("*milvuspb.GetMetricsRequest")).Maybe().Return(metricsResp, nil)

	server.EXPECT().SyncDistribution(anyCtx, mock.AnythingOfType("*querypb.SyncDistributionRequest")).Maybe().Return(ok, nil)

	states := &milvuspb.ComponentStates{
		State:  &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy},
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}
	server.EXPECT().GetComponentStates(anyCtx, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")).Maybe().Return(states, nil)

	return server
}
// createFailedMockServer builds a mock QueryNode server on which every
// registered RPC optionally matches any context plus a request of the
// expected type and answers, without a transport error, with an
// UnexpectedError status whose reason is "unexpected error".
// GetComponentStates is the exception: its status is Success while the
// reported state code is Abnormal.
func (suite *ClusterTestSuite) createFailedMockServer() querypb.QueryNodeServer {
	failed := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_UnexpectedError,
		Reason:    "unexpected error",
	}
	server := mocks.NewMockQueryNodeServer(suite.T())
	// Matcher for the context argument, shared by every expectation below.
	anyCtx := mock.Anything

	// TODO: register more mock methods

	// RPCs that answer with a bare status.
	server.EXPECT().LoadSegments(anyCtx, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).Maybe().Return(failed, nil)
	server.EXPECT().WatchDmChannels(anyCtx, mock.AnythingOfType("*querypb.WatchDmChannelsRequest")).Maybe().Return(failed, nil)
	server.EXPECT().UnsubDmChannel(anyCtx, mock.AnythingOfType("*querypb.UnsubDmChannelRequest")).Maybe().Return(failed, nil)
	server.EXPECT().ReleaseSegments(anyCtx, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).Maybe().Return(failed, nil)
	server.EXPECT().LoadPartitions(anyCtx, mock.AnythingOfType("*querypb.LoadPartitionsRequest")).Maybe().Return(failed, nil)
	server.EXPECT().ReleasePartitions(anyCtx, mock.AnythingOfType("*querypb.ReleasePartitionsRequest")).Maybe().Return(failed, nil)

	// RPCs that wrap the status in a response message.
	distResp := &querypb.GetDataDistributionResponse{Status: failed}
	server.EXPECT().GetDataDistribution(anyCtx, mock.AnythingOfType("*querypb.GetDataDistributionRequest")).Maybe().Return(distResp, nil)

	metricsResp := &milvuspb.GetMetricsResponse{Status: failed}
	server.EXPECT().GetMetrics(anyCtx, mock.AnythingOfType("*milvuspb.GetMetricsRequest")).Maybe().Return(metricsResp, nil)

	server.EXPECT().SyncDistribution(anyCtx, mock.AnythingOfType("*querypb.SyncDistributionRequest")).Maybe().Return(failed, nil)

	states := &milvuspb.ComponentStates{
		State:  &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Abnormal},
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}
	server.EXPECT().GetComponentStates(anyCtx, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")).Maybe().Return(states, nil)

	return server
}
// TestLoadSegments checks that LoadSegments relays the status produced by each
// mock node and returns an error for a node ID that setupCluster never
// registered.
func (suite *ClusterTestSuite) TestLoadSegments() {
	ctx := context.TODO()

	// Node 0 is the default mock server: the call must succeed.
	status, err := suite.cluster.LoadSegments(ctx, 0, &querypb.LoadSegmentsRequest{
		Base:  &commonpb.MsgBase{},
		Infos: []*querypb.SegmentLoadInfo{{}},
	})
	suite.NoError(err)
	// merr.Ok only returns a bool; without asserting it the success status
	// was never actually verified.
	suite.True(merr.Ok(status))

	// Node 1 is the failing mock server: no transport error, but a failure status.
	status, err = suite.cluster.LoadSegments(ctx, 1, &querypb.LoadSegmentsRequest{
		Base:  &commonpb.MsgBase{},
		Infos: []*querypb.SegmentLoadInfo{{}},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())

	// Node 3 is not registered (only IDs 0 and 1 are), so an error is expected.
	_, err = suite.cluster.LoadSegments(ctx, 3, &querypb.LoadSegmentsRequest{
		Base:  &commonpb.MsgBase{},
		Infos: []*querypb.SegmentLoadInfo{{}},
	})
	suite.Error(err)
	suite.IsType(WrapErrNodeNotFound(3), err)
}
// TestWatchDmChannels checks that WatchDmChannels relays the status produced
// by the succeeding node (ID 0) and by the failing node (ID 1).
func (suite *ClusterTestSuite) TestWatchDmChannels() {
	ctx := context.TODO()

	// Node 0 is the default mock server: the call must succeed.
	status, err := suite.cluster.WatchDmChannels(ctx, 0, &querypb.WatchDmChannelsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	// merr.Ok only returns a bool; it has to be asserted to check anything.
	suite.True(merr.Ok(status))

	// Node 1 is the failing mock server: no transport error, but a failure status.
	status, err = suite.cluster.WatchDmChannels(ctx, 1, &querypb.WatchDmChannelsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())
}
// TestUnsubDmChannel checks that UnsubDmChannel relays the status produced by
// the succeeding node (ID 0) and by the failing node (ID 1).
func (suite *ClusterTestSuite) TestUnsubDmChannel() {
	ctx := context.TODO()

	// Node 0 is the default mock server: the call must succeed.
	status, err := suite.cluster.UnsubDmChannel(ctx, 0, &querypb.UnsubDmChannelRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	// merr.Ok only returns a bool; it has to be asserted to check anything.
	suite.True(merr.Ok(status))

	// Node 1 is the failing mock server: no transport error, but a failure status.
	status, err = suite.cluster.UnsubDmChannel(ctx, 1, &querypb.UnsubDmChannelRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())
}
// TestReleaseSegments checks that ReleaseSegments relays the status produced
// by the succeeding node (ID 0) and by the failing node (ID 1).
func (suite *ClusterTestSuite) TestReleaseSegments() {
	ctx := context.TODO()

	// Node 0 is the default mock server: the call must succeed.
	status, err := suite.cluster.ReleaseSegments(ctx, 0, &querypb.ReleaseSegmentsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	// merr.Ok only returns a bool; it has to be asserted to check anything.
	suite.True(merr.Ok(status))

	// Node 1 is the failing mock server: no transport error, but a failure status.
	status, err = suite.cluster.ReleaseSegments(ctx, 1, &querypb.ReleaseSegmentsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())
}
// TestLoadAndReleasePartitions checks that LoadPartitions and
// ReleasePartitions each relay the status produced by the succeeding node
// (ID 0) and by the failing node (ID 1).
func (suite *ClusterTestSuite) TestLoadAndReleasePartitions() {
	ctx := context.TODO()

	// LoadPartitions on the default mock server must succeed.
	status, err := suite.cluster.LoadPartitions(ctx, 0, &querypb.LoadPartitionsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	// merr.Ok only returns a bool; it has to be asserted to check anything.
	suite.True(merr.Ok(status))

	// LoadPartitions on the failing mock server: failure status, no transport error.
	status, err = suite.cluster.LoadPartitions(ctx, 1, &querypb.LoadPartitionsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())

	// ReleasePartitions on the default mock server must succeed.
	status, err = suite.cluster.ReleasePartitions(ctx, 0, &querypb.ReleasePartitionsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.True(merr.Ok(status))

	// ReleasePartitions on the failing mock server: failure status, no transport error.
	status, err = suite.cluster.ReleasePartitions(ctx, 1, &querypb.ReleasePartitionsRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())
}
// TestGetDataDistribution checks that GetDataDistribution relays the response
// status of the succeeding node (ID 0) and of the failing node (ID 1).
func (suite *ClusterTestSuite) TestGetDataDistribution() {
	bg := context.TODO()
	// Every call gets its own freshly built request, as in a real caller.
	newReq := func() *querypb.GetDataDistributionRequest {
		return &querypb.GetDataDistributionRequest{Base: &commonpb.MsgBase{}}
	}

	healthy, err := suite.cluster.GetDataDistribution(bg, 0, newReq())
	suite.NoError(err)
	suite.Equal(merr.Success(), healthy.GetStatus())

	failing, err := suite.cluster.GetDataDistribution(bg, 1, newReq())
	suite.NoError(err)
	failure := failing.GetStatus()
	suite.Equal(commonpb.ErrorCode_UnexpectedError, failure.GetErrorCode())
	suite.Equal("unexpected error", failure.GetReason())
}
// TestGetMetrics checks that GetMetrics relays the response status of the
// succeeding node (ID 0) and of the failing node (ID 1).
func (suite *ClusterTestSuite) TestGetMetrics() {
	bg := context.TODO()

	healthy, err := suite.cluster.GetMetrics(bg, 0, &milvuspb.GetMetricsRequest{})
	suite.NoError(err)
	suite.Equal(merr.Success(), healthy.GetStatus())

	failing, err := suite.cluster.GetMetrics(bg, 1, &milvuspb.GetMetricsRequest{})
	suite.NoError(err)
	failure := failing.GetStatus()
	suite.Equal(commonpb.ErrorCode_UnexpectedError, failure.GetErrorCode())
	suite.Equal("unexpected error", failure.GetReason())
}
// TestSyncDistribution checks that SyncDistribution relays the status produced
// by the succeeding node (ID 0) and by the failing node (ID 1).
func (suite *ClusterTestSuite) TestSyncDistribution() {
	ctx := context.TODO()

	// Node 0 is the default mock server: the call must succeed.
	status, err := suite.cluster.SyncDistribution(ctx, 0, &querypb.SyncDistributionRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	// merr.Ok only returns a bool; it has to be asserted to check anything.
	suite.True(merr.Ok(status))

	// Node 1 is the failing mock server: no transport error, but a failure status.
	status, err = suite.cluster.SyncDistribution(ctx, 1, &querypb.SyncDistributionRequest{
		Base: &commonpb.MsgBase{},
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
	suite.Equal("unexpected error", status.GetReason())
}
// TestGetComponentStates checks that GetComponentStates reports the Healthy
// state code for node 0 and the Abnormal state code for node 1.
func (suite *ClusterTestSuite) TestGetComponentStates() {
	ctx := context.TODO()

	// testify's Equal takes (expected, actual); keeping that order makes the
	// failure output label the two values correctly. The getters are nil-safe,
	// so a missing State fails the assertion instead of panicking.
	status, err := suite.cluster.GetComponentStates(ctx, 0)
	suite.NoError(err)
	suite.Equal(commonpb.StateCode_Healthy, status.GetState().GetStateCode())

	status, err = suite.cluster.GetComponentStates(ctx, 1)
	suite.NoError(err)
	suite.Equal(commonpb.StateCode_Abnormal, status.GetState().GetStateCode())
}
// TestClusterSuite is the go test entry point that runs ClusterTestSuite.
func TestClusterSuite(t *testing.T) {
	suite.Run(t, &ClusterTestSuite{})
}