milvus/internal/streamingnode/client/handler/assignment/watcher_test.go

92 lines
2.4 KiB
Go

package assignment
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_resolver"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// TestWatcher exercises the assignment watcher against a mocked resolver.
//
// The mocked resolver forwards every state pushed onto ch to the watcher's
// callback until ch is closed or the watch context is done. The test checks
// that:
//   - before any state is delivered, Get returns nil and Watch fails with the
//     context's deadline error;
//   - after a state carrying assignments for two pchannels is delivered,
//     Watch with a nil previous assignment succeeds and Get reports the term
//     advertised for each channel;
//   - Watch fails with the deadline error when given the channel's current
//     assignment as previous, since no newer assignment is ever delivered.
func TestWatcher(t *testing.T) {
	r := mock_resolver.NewMockResolver(t)
	ch := make(chan discoverer.VersionedState)
	r.EXPECT().Watch(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, f func(s discoverer.VersionedState) error) error {
		for {
			select {
			case v, ok := <-ch:
				if !ok {
					return nil
				}
				// Surface callback failures instead of silently dropping them.
				if err := f(v); err != nil {
					return err
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})
	w := NewWatcher(r)
	defer w.Close()

	// Nothing has been discovered yet: Get finds no assignment and Watch
	// must give up when the short deadline fires.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	a := w.Get(ctx, "test_pchannel")
	assert.Nil(t, a)
	err := w.Watch(ctx, "test_pchannel", nil)
	assert.ErrorIs(t, err, context.DeadlineExceeded)

	state := discoverer.VersionedState{
		Version: typeutil.VersionInt64(1),
		State: resolver.State{
			Addresses: []resolver.Address{
				{
					Addr: "test_addr",
					BalancerAttributes: attributes.WithChannelAssignmentInfo(
						new(attributes.Attributes),
						&types.StreamingNodeAssignment{
							NodeInfo: types.StreamingNodeInfo{
								ServerID: 1,
								Address:  "test_addr",
							},
							Channels: map[string]types.PChannelInfo{
								"test_pchannel": {
									Name: "test_pchannel",
									Term: 1,
								},
								"test_pchannel_2": {
									Name: "test_pchannel_2",
									Term: 2,
								},
							},
						},
					),
				},
			},
		},
	}
	// ch is unbuffered, so this send completes only once the mocked watch
	// loop receives it; bound the wait so a stalled loop fails the test
	// instead of hanging it.
	select {
	case ch <- state:
	case <-time.After(5 * time.Second):
		t.Fatal("mocked resolver did not consume the discovered state")
	}

	// ctx above has already expired, so the calls below use a background
	// context rather than reusing it.
	err = w.Watch(context.Background(), "test_pchannel", nil)
	assert.NoError(t, err)
	a1 := w.Get(context.Background(), "test_pchannel")
	if !assert.NotNil(t, a1) {
		// Stop here: dereferencing a nil assignment below would panic.
		t.FailNow()
	}
	assert.Equal(t, int64(1), a1.Channel.Term)

	err = w.Watch(context.Background(), "test_pchannel_2", nil)
	assert.NoError(t, err)
	a2 := w.Get(context.Background(), "test_pchannel_2")
	if !assert.NotNil(t, a2) {
		t.FailNow()
	}
	assert.Equal(t, int64(2), a2.Channel.Term)

	// Watching test_pchannel with its own current assignment as previous
	// must not return before the deadline: no newer assignment for that
	// channel is ever delivered.
	ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel2()
	err = w.Watch(ctx2, "test_pchannel", a1)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
}