// Source: milvus/internal/streamingnode/server/walmanager/wal_state_test.go
// (184 lines, 4.9 KiB, Go)

package walmanager
import (
"context"
"sync"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
)
// TestInitialWALState checks the initial current and expected WAL states:
// both report types.InitialTerm, are unavailable, carry no WAL / channel
// info, and render as "(-1,false)".
func TestInitialWALState(t *testing.T) {
	currentState := initialCurrentWALState
	assert.Equal(t, types.InitialTerm, currentState.Term())
	assert.False(t, currentState.Available())
	assert.Nil(t, currentState.GetWAL())
	assert.NoError(t, currentState.GetLastError())
	// testify's Equal takes (expected, actual).
	assert.Equal(t, "(-1,false)", toStateString(currentState))

	expectedState := initialExpectedWALState
	assert.Equal(t, types.InitialTerm, expectedState.Term())
	assert.False(t, expectedState.Available())
	assert.Zero(t, expectedState.GetPChannelInfo())
	assert.Equal(t, context.Background(), expectedState.Context())
	assert.Equal(t, "(-1,false)", toStateString(expectedState))
}
// TestAvailableCurrentWALState checks that a current state built from a WAL
// takes its term from the WAL's channel info, is available, exposes that
// WAL, carries no error, and renders as "(term,true)".
func TestAvailableCurrentWALState(t *testing.T) {
	l := mock_wal.NewMockWAL(t)
	l.EXPECT().Channel().Return(types.PChannelInfo{
		Term: 1,
	})

	state := newAvailableCurrentState(l)
	assert.Equal(t, int64(1), state.Term())
	assert.True(t, state.Available())
	assert.Equal(t, l, state.GetWAL())
	// NoError (rather than Nil) is the idiomatic check for an error value.
	assert.NoError(t, state.GetLastError())
	assert.Equal(t, "(1,true)", toStateString(state))
}
// TestUnavailableCurrentWALState checks that an unavailable current state
// keeps the given term, has no WAL, reports the error it was built with, and
// renders as "(term,false)".
func TestUnavailableCurrentWALState(t *testing.T) {
	err := errors.New("test")

	state := newUnavailableCurrentState(1, err)
	assert.Equal(t, int64(1), state.Term())
	assert.False(t, state.Available())
	assert.Nil(t, state.GetWAL())
	assert.ErrorIs(t, state.GetLastError(), err)
	// testify's Equal takes (expected, actual).
	assert.Equal(t, "(1,false)", toStateString(state))
}
// TestAvailableExpectedWALState checks that an available expected state
// takes its term from the channel info, is available, returns the context
// and channel it was built with, and renders as "(term,true)".
func TestAvailableExpectedWALState(t *testing.T) {
	channel := types.PChannelInfo{}

	state := newAvailableExpectedState(context.Background(), channel)
	assert.Equal(t, int64(0), state.Term())
	assert.True(t, state.Available())
	assert.Equal(t, context.Background(), state.Context())
	assert.Equal(t, channel, state.GetPChannelInfo())
	// testify's Equal takes (expected, actual).
	assert.Equal(t, "(0,true)", toStateString(state))
}
// TestUnavailableExpectedWALState checks that an unavailable expected state
// keeps the given term, has zero channel info, returns context.Background()
// as its context, and renders as "(term,false)".
func TestUnavailableExpectedWALState(t *testing.T) {
	state := newUnavailableExpectedState(1)
	assert.Equal(t, int64(1), state.Term())
	assert.False(t, state.Available())
	assert.Zero(t, state.GetPChannelInfo())
	assert.Equal(t, context.Background(), state.Context())
	// testify's Equal takes (expected, actual).
	assert.Equal(t, "(1,false)", toStateString(state))
}
// TestIsStateBefore verifies the ordering relation between WAL states:
// neither initial state is before the other, every constructed state is
// after both initial states, and the entries of cases are listed in strictly
// increasing order (each one is after every entry preceding it).
func TestIsStateBefore(t *testing.T) {
	// initial state comparison.
	assert.False(t, isStateBefore(initialCurrentWALState, initialExpectedWALState))
	assert.False(t, isStateBefore(initialExpectedWALState, initialCurrentWALState))

	l := mock_wal.NewMockWAL(t)
	l.EXPECT().Channel().Return(types.PChannelInfo{
		Term: 1,
	})
	// Order matters: each entry must be strictly after all earlier entries.
	cases := []walState{
		newAvailableCurrentState(l),
		newUnavailableCurrentState(1, nil),
		newAvailableExpectedState(context.Background(), types.PChannelInfo{
			Term: 3,
		}),
		newUnavailableExpectedState(5),
	}

	// Messages carry the case index so a failure identifies the offending entry.
	for i, s := range cases {
		assert.True(t, isStateBefore(initialCurrentWALState, s), "initial current state should be before cases[%d]", i)
		assert.True(t, isStateBefore(initialExpectedWALState, s), "initial expected state should be before cases[%d]", i)
		assert.False(t, isStateBefore(s, initialCurrentWALState), "cases[%d] should not be before initial current state", i)
		assert.False(t, isStateBefore(s, initialExpectedWALState), "cases[%d] should not be before initial expected state", i)
	}

	for i := range cases {
		for j := 0; j < i; j++ {
			assert.True(t, isStateBefore(cases[j], cases[i]), "cases[%d] should be before cases[%d]", j, i)
			assert.False(t, isStateBefore(cases[i], cases[j]), "cases[%d] should not be before cases[%d]", i, j)
		}
	}
}
// TestStateWithCond exercises walStateWithCond under concurrency. Writers
// publish a mix of states, some with older terms, and finally targetState.
// Watchers must only ever observe forward progress until they reach
// targetState. Finally, WatchChanged must return context.DeadlineExceeded
// when no newer state arrives before the deadline.
func TestStateWithCond(t *testing.T) {
	stateCond := newWALStateWithCond(initialCurrentWALState)
	assert.Equal(t, initialCurrentWALState, stateCond.GetState())

	// test notification.
	// Watchers wait with a bounded context, so a lost notification makes them
	// fail and return instead of blocking (and leaking) forever.
	watchCtx, watchCancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer watchCancel()

	var wg sync.WaitGroup
	targetState := newUnavailableCurrentState(10, nil)
	for i := 0; i < 5; i++ {
		wg.Add(2)

		// Watcher: follow state changes until targetState is reached.
		go func() {
			defer wg.Done()
			oldState := stateCond.GetState()
			for isStateBefore(oldState, targetState) {
				if err := stateCond.WatchChanged(watchCtx, oldState); !assert.NoError(t, err) {
					// Stop here; retrying with an expired context would spin.
					return
				}
				newState := stateCond.GetState()
				assert.True(t, isStateBefore(oldState, newState))
				oldState = newState
			}
		}()

		// Writer: publish states with non-monotonic terms, then targetState.
		go func() {
			defer wg.Done()
			oldState := stateCond.GetState()
			for j := int64(0); j < 10; j++ {
				var newState currentWALState
				if j%2 == 0 {
					l := mock_wal.NewMockWAL(t)
					// NOTE(review): j%2 is always 0 in this branch, so every
					// available state here has term 0; confirm whether a wider
					// spread of terms was intended.
					l.EXPECT().Channel().Return(types.PChannelInfo{
						Term: j % 2,
					}).Maybe()
					newState = newAvailableCurrentState(l)
				} else {
					newState = newUnavailableCurrentState(j%3, nil)
				}
				stateCond.SetStateAndNotify(newState)
				// updated state should never be before the old state.
				stateNow := stateCond.GetState()
				assert.False(t, isStateBefore(stateNow, oldState))
				oldState = stateNow
			}
			stateCond.SetStateAndNotify(targetState)
		}()
	}

	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	// The guard is longer than watchCtx's deadline, so a stuck watcher is
	// normally reported by its own failed WatchChanged. Hitting the guard
	// means something did not honor cancellation. Continuing past it would be
	// meaningless, so stop the test.
	select {
	case <-time.After(5 * time.Second):
		t.Fatal("test should never block")
	case <-done:
	}

	// test cancel.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	err := stateCond.WatchChanged(ctx, targetState)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
}