// milvus/internal/streamingcoord/server/balancer/balancer_test.go

package balancer_test
import (
	"context"
	"testing"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
	"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_manager"
	"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
	_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy"
	"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// TestBalancer recovers a "pchannel_count_fair" balancer from a mocked catalog
// holding three pchannels. It then marks test-channel-1 (term 1) as
// unavailable and triggers a rebalance. Finally it watches the assignment
// stream until a view containing three relations appears, and checks that
// those three pchannels sit on three distinct streaming nodes.
//
// The mocked streaming-node manager reports four nodes. Node 4 reports
// types.ErrStopping, and the catalog recovers test-channel-2 as assigned to
// that node.
func TestBalancer(t *testing.T) {
	paramtable.Init()

	streamingNodeManager := mock_manager.NewMockManagerClient(t)
	streamingNodeManager.EXPECT().WatchNodeChanged(mock.Anything).Return(make(chan struct{}), nil)
	streamingNodeManager.EXPECT().Assign(mock.Anything, mock.Anything).Return(nil)
	streamingNodeManager.EXPECT().Remove(mock.Anything, mock.Anything).Return(nil)
	streamingNodeManager.EXPECT().CollectAllStatus(mock.Anything).Return(map[int64]*types.StreamingNodeStatus{
		1: {
			StreamingNodeInfo: types.StreamingNodeInfo{
				ServerID: 1,
				Address:  "localhost:1",
			},
		},
		2: {
			StreamingNodeInfo: types.StreamingNodeInfo{
				ServerID: 2,
				Address:  "localhost:2",
			},
		},
		3: {
			StreamingNodeInfo: types.StreamingNodeInfo{
				ServerID: 3,
				Address:  "localhost:3",
			},
		},
		// Node 4 is stopping. Its info matches its map key so the fixture
		// cannot be confused with the healthy node 3.
		4: {
			StreamingNodeInfo: types.StreamingNodeInfo{
				ServerID: 4,
				Address:  "localhost:4",
			},
			Err: types.ErrStopping,
		},
	}, nil)

	catalog := mock_metastore.NewMockStreamingCoordCataLog(t)
	resource.InitForTest(resource.OptStreamingCatalog(catalog), resource.OptStreamingManagerClient(streamingNodeManager))
	// Clear any previously registered ListPChannel expectation before
	// installing the recovery fixture below.
	catalog.EXPECT().ListPChannel(mock.Anything).Unset()
	catalog.EXPECT().ListPChannel(mock.Anything).RunAndReturn(func(ctx context.Context) ([]*streamingpb.PChannelMeta, error) {
		return []*streamingpb.PChannelMeta{
			{
				Channel: &streamingpb.PChannelInfo{
					Name: "test-channel-1",
					Term: 1,
				},
				State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED,
				Node:  &streamingpb.StreamingNodeInfo{ServerId: 1},
			},
			{
				Channel: &streamingpb.PChannelInfo{
					Name: "test-channel-2",
					Term: 1,
				},
				State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED,
				Node:  &streamingpb.StreamingNodeInfo{ServerId: 4},
			},
			{
				Channel: &streamingpb.PChannelInfo{
					Name: "test-channel-3",
					Term: 2,
				},
				State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING,
				Node:  &streamingpb.StreamingNodeInfo{ServerId: 2},
			},
		}, nil
	})
	catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(nil).Maybe()

	ctx := context.Background()
	b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair")
	// Stop on failure. Deferring Close on a nil balancer would panic and
	// hide the actual recovery error.
	if !assert.NoError(t, err) || !assert.NotNil(t, b) {
		return
	}
	defer b.Close()

	b.MarkAsUnavailable(ctx, []types.PChannelInfo{{
		Name: "test-channel-1",
		Term: 1,
	}})
	b.Trigger(ctx)

	// Bound the wait so a balancer that never reaches the expected layout
	// fails the test instead of hanging it.
	// NOTE(review): assumes WatchChannelAssignments returns once ctx is done.
	watchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	doneErr := errors.New("done")
	err = b.WatchChannelAssignments(watchCtx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
		// Keep watching until all three pchannels appear in one view.
		if len(relations) != 3 {
			return nil
		}
		// Each pchannel should be assigned to a different node.
		nodeIDs := typeutil.NewSet[int64]()
		for _, relation := range relations {
			nodeIDs.Insert(relation.Node.ServerID)
		}
		assert.Equal(t, 3, nodeIDs.Len())
		// The sentinel error ends the watch; the caller checks that it
		// comes back unchanged.
		return doneErr
	})
	assert.ErrorIs(t, err, doneErr)
}