milvus/internal/streamingnode/client/manager/manager_test.go

175 lines
5.9 KiB
Go

package manager
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_lazygrpc"
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_resolver"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/pkg/v2/mocks/proto/mock_streamingpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// TestManager drives a managerClientImpl wired to mocked resolver and lazy gRPC
// service dependencies. It covers WatchNodeChanged, CollectAllStatus, Assign
// and Remove, and finally expects every one of those calls to fail once the
// client has been closed.
func TestManager(t *testing.T) {
	ctx := context.Background()

	builder := mock_resolver.NewMockBuilder(t)
	lazyService := mock_lazygrpc.NewMockService[streamingpb.StreamingNodeManagerServiceClient](t)
	client := &managerClientImpl{
		lifetime: typeutil.NewLifetime(),
		stopped:  make(chan struct{}),
		rb:       builder,
		service:  lazyService,
	}

	mockResolver := mock_resolver.NewMockResolver(t)
	builder.EXPECT().Resolver().Return(mockResolver)
	mockResolver.EXPECT().Watch(mock.Anything, mock.Anything).RunAndReturn(
		func(_ context.Context, notify func(discoverer.VersionedState) error) error {
			// Push one (empty) state and then end the watch. The callback's result
			// is discarded; the test only observes the channel below.
			_ = notify(discoverer.VersionedState{})
			return context.Canceled
		})

	// WatchNodeChanged: expect one notification, then a closed channel.
	changes, err := client.WatchNodeChanged(ctx)
	assert.NoError(t, err)
	_, open := <-changes
	assert.True(t, open)
	_, open = <-changes
	assert.False(t, open)

	// CollectAllStatus with a resolver state that carries no addresses:
	// no node status is expected.
	mockResolver.EXPECT().GetLatestState().RunAndReturn(func() discoverer.VersionedState {
		return discoverer.VersionedState{
			Version: typeutil.VersionInt64(1),
			State:   resolver.State{},
		}
	})
	statuses, err := client.CollectAllStatus(ctx)
	assert.NoError(t, err)
	assert.Len(t, statuses, 0)

	// CollectAllStatus with addresses present in the resolver state.
	serviceClient := mock_streamingpb.NewMockStreamingNodeManagerServiceClient(t)
	lazyService.EXPECT().GetService(mock.Anything).RunAndReturn(
		func(_ context.Context) (streamingpb.StreamingNodeManagerServiceClient, error) {
			return serviceClient, nil
		})
	serviceClient.EXPECT().CollectStatus(mock.Anything, mock.Anything).RunAndReturn(
		func(_ context.Context, _ *streamingpb.StreamingNodeManagerCollectStatusRequest, _ ...grpc.CallOption) (*streamingpb.StreamingNodeManagerCollectStatusResponse, error) {
			return &streamingpb.StreamingNodeManagerCollectStatusResponse{}, nil
		})

	// Each GetLatestState call serves the next snapshot (server ID -> stopping
	// flag); a call beyond the second would index out of range.
	next := 0
	snapshots := []map[uint64]bool{
		{1: false, 2: false, 3: true},
		{1: true, 2: false},
	}
	mockResolver.EXPECT().GetLatestState().Unset()
	mockResolver.EXPECT().GetLatestState().RunAndReturn(func() discoverer.VersionedState {
		snapshot := newVersionedState(int64(next), snapshots[next])
		next++
		return snapshot
	})
	statuses, err = client.CollectAllStatus(ctx)
	assert.NoError(t, err)
	assert.Len(t, statuses, 3)
	// Node 3 appears only in the first snapshot; node 1 is flagged as stopping
	// in the second one.
	assert.ErrorIs(t, statuses[3].Err, types.ErrNotAlive)
	assert.ErrorIs(t, statuses[1].Err, types.ErrStopping)

	// Assign: the request context must carry the target node's server ID.
	assignTarget := int64(2)
	serviceClient.EXPECT().Assign(mock.Anything, mock.Anything).RunAndReturn(
		func(callCtx context.Context, _ *streamingpb.StreamingNodeManagerAssignRequest, _ ...grpc.CallOption) (*streamingpb.StreamingNodeManagerAssignResponse, error) {
			picked, found := contextutil.GetPickServerID(callCtx)
			assert.True(t, found)
			assert.Equal(t, assignTarget, picked)
			return nil, nil
		})
	err = client.Assign(ctx, types.PChannelInfoAssigned{
		Channel: types.PChannelInfo{Name: "p", Term: 1},
		Node:    types.StreamingNodeInfo{ServerID: assignTarget},
	})
	assert.NoError(t, err)

	// Remove: same server ID check as for Assign.
	removeTarget := int64(2)
	serviceClient.EXPECT().Remove(mock.Anything, mock.Anything).RunAndReturn(
		func(callCtx context.Context, _ *streamingpb.StreamingNodeManagerRemoveRequest, _ ...grpc.CallOption) (*streamingpb.StreamingNodeManagerRemoveResponse, error) {
			picked, found := contextutil.GetPickServerID(callCtx)
			assert.True(t, found)
			assert.Equal(t, removeTarget, picked)
			return nil, nil
		})
	err = client.Remove(ctx, types.PChannelInfoAssigned{
		Channel: types.PChannelInfo{Name: "p", Term: 1},
		Node:    types.StreamingNodeInfo{ServerID: removeTarget},
	})
	assert.NoError(t, err)

	// Close: both dependencies are closed, and every later call must fail.
	lazyService.EXPECT().Close().Return()
	builder.EXPECT().Close().Return()
	client.Close()

	statuses, err = client.CollectAllStatus(ctx)
	assert.Nil(t, statuses)
	assert.Error(t, err)

	err = client.Assign(ctx, types.PChannelInfoAssigned{})
	assert.Error(t, err)

	err = client.Remove(ctx, types.PChannelInfoAssigned{})
	assert.Error(t, err)

	changes, err = client.WatchNodeChanged(ctx)
	assert.Nil(t, changes)
	assert.Error(t, err)
}
// newVersionedState builds a discoverer.VersionedState with the given version
// and one resolver address per entry of serverIDs. The map key is the server
// ID (also used as the port in "localhost:<id>") and the map value is the
// session's Stopping flag. Addresses follow map iteration order, so their
// order is unspecified.
func newVersionedState(version int64, serverIDs map[uint64]bool) discoverer.VersionedState {
	addrs := make([]resolver.Address, 0, len(serverIDs))
	for id, isStopping := range serverIDs {
		session := &sessionutil.SessionRaw{
			ServerID: int64(id),
			Stopping: isStopping,
		}
		addrs = append(addrs, resolver.Address{
			Addr:               fmt.Sprintf("localhost:%d", id),
			BalancerAttributes: attributes.WithSession(new(attributes.Attributes), session),
		})
	}
	return discoverer.VersionedState{
		Version: typeutil.VersionInt64(version),
		State:   resolver.State{Addresses: addrs},
	}
}
// TestDial starts the embedded etcd server, builds a manager client on top of
// its client, and closes the manager client again after a short delay.
//
// assert.* only records a failure and lets the test continue, so each setup
// step is guarded by the assertion's boolean result. A failed step then ends
// the test there instead of running into a nil dereference further down.
func TestDial(t *testing.T) {
	paramtable.Init()

	err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
	// Deferred before the error check so the stop call still runs on every
	// path, exactly as it did before the early returns were added.
	defer etcd.StopEtcdServer()
	if !assert.NoError(t, err) {
		return
	}

	c, err := etcd.GetEmbedEtcdClient()
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	client := NewManagerClient(c)
	if !assert.NotNil(t, client) {
		return
	}

	// NOTE(review): presumably gives the client's background work a moment to
	// start before Close is exercised — confirm the delay is still needed.
	time.Sleep(100 * time.Millisecond)
	client.Close()
}