milvus/internal/indexnode/task_scheduler_test.go

209 lines
5.0 KiB
Go

package indexnode
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/stretchr/testify/assert"
)
// fakeTaskState enumerates the stages a fake task moves through, in the
// order the stage methods record them.
type fakeTaskState int

// The constants are declared with an explicit type so they are fakeTaskState
// values (not untyped ints) wherever they are used, including when they are
// stored in an interface and compared by dynamic type.
const (
	fakeTaskInited fakeTaskState = iota
	fakeTaskEnqueued
	fakeTaskPrepared
	fakeTaskLoadedData
	fakeTaskBuiltIndex
	fakeTaskSavedIndexes
)
// stagectx is a context.Context implementation for tests. It tracks the
// current stage of a fake task and reports itself as canceled once Done is
// called while the current stage equals state2cancel.
type stagectx struct {
	mu           sync.Mutex    // guards curstate and the closing of ch
	curstate     fakeTaskState // stage most recently recorded via setState
	state2cancel fakeTaskState // stage at which Done closes ch
	ch           chan struct{} // channel handed out by Done
}

var _ context.Context = &stagectx{}

// Deadline reports that no deadline is set (ok is always false).
func (s *stagectx) Deadline() (time.Time, bool) {
	return time.Now(), false
}

// Done returns the cancellation channel. If the current stage equals
// state2cancel the channel is closed before it is returned. The close is
// guarded so that calling Done repeatedly in that stage does not panic with
// "close of closed channel".
func (s *stagectx) Done() <-chan struct{} {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.curstate == s.state2cancel {
		select {
		case <-s.ch:
			// Already closed by an earlier call; nothing to do.
		default:
			close(s.ch)
		}
	}
	return s.ch
}

// Err returns a "canceled" error once the channel has been closed, and nil
// before that.
func (s *stagectx) Err() error {
	select {
	case <-s.ch:
		return fmt.Errorf("canceled")
	default:
		return nil
	}
}

// Value always returns nil; stagectx carries no values.
func (s *stagectx) Value(k interface{}) interface{} {
	return nil
}

// setState records state as the current stage.
func (s *stagectx) setState(state fakeTaskState) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.curstate = state
}
// _taskwg is incremented by every OnEnqueue call and decremented by every
// Reset call, so Wait returns once each enqueued fake task has been reset.
var _taskwg sync.WaitGroup

// fakeTask is a task implementation for tests. Each stage method records
// the stage it represents and returns the error configured for that stage.
type fakeTask struct {
	id            int                     // used to build the task name
	ctx           context.Context         // always a *stagectx; the stage methods type-assert it
	state         fakeTaskState           // stage most recently entered
	reterr        map[fakeTaskState]error // error to return from each stage (nil map means no errors)
	retstate      commonpb.IndexState     // last state stored via SetState
	expectedState commonpb.IndexState     // state the test expects to find in retstate
	failReason    string                  // last failure reason stored via SetState
}

var _ task = (*fakeTask)(nil)

// advance records next as the current stage on both the task and its
// stagectx, then returns the error configured for that stage, if any.
func (t *fakeTask) advance(next fakeTaskState) error {
	t.state = next
	sctx := t.ctx.(*stagectx)
	sctx.setState(next)
	return t.reterr[next]
}

// Name returns "fake-task-<id>".
func (t *fakeTask) Name() string {
	name := fmt.Sprintf("fake-task-%d", t.id)
	return name
}

// Ctx returns the task's context.
func (t *fakeTask) Ctx() context.Context {
	return t.ctx
}

// OnEnqueue registers the task with _taskwg and enters the enqueued stage.
func (t *fakeTask) OnEnqueue(ctx context.Context) error {
	_taskwg.Add(1)
	return t.advance(fakeTaskEnqueued)
}

// Prepare enters the prepared stage.
func (t *fakeTask) Prepare(ctx context.Context) error {
	return t.advance(fakeTaskPrepared)
}

// LoadData enters the loaded-data stage.
func (t *fakeTask) LoadData(ctx context.Context) error {
	return t.advance(fakeTaskLoadedData)
}

// BuildIndex enters the built-index stage.
func (t *fakeTask) BuildIndex(ctx context.Context) error {
	return t.advance(fakeTaskBuiltIndex)
}

// SaveIndexFiles enters the saved-indexes stage.
func (t *fakeTask) SaveIndexFiles(ctx context.Context) error {
	return t.advance(fakeTaskSavedIndexes)
}

// Reset marks the task as done in _taskwg.
func (t *fakeTask) Reset() {
	_taskwg.Done()
}

// SetState stores the resulting index state together with its failure reason.
func (t *fakeTask) SetState(state commonpb.IndexState, failReason string) {
	t.retstate, t.failReason = state, failReason
}

// GetState returns the state most recently stored via SetState.
func (t *fakeTask) GetState() commonpb.IndexState {
	return t.retstate
}
// idLock guards id, the next identifier handed out by newTask.
var (
	idLock sync.Mutex
	id     = 0
)

// newTask builds a fake task with a fresh id.
//
// cancelStage is the stage at which the task's stagectx closes its Done
// channel, reterror maps stages to the error that stage should return (may
// be nil), and expectedState is stored on the task for the test to compare
// against later.
func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expectedState commonpb.IndexState) task {
	// Reserve the next id under the lock.
	idLock.Lock()
	taskID := id
	id++
	idLock.Unlock()

	sctx := &stagectx{
		curstate:     fakeTaskInited,
		state2cancel: cancelStage,
		ch:           make(chan struct{}),
	}

	ft := &fakeTask{
		id:            taskID,
		ctx:           sctx,
		state:         fakeTaskInited,
		reterr:        reterror,
		retstate:      commonpb.IndexState_IndexStateNone,
		expectedState: expectedState,
	}
	return ft
}
// TestIndexTaskScheduler runs fake tasks through a TaskScheduler in two
// rounds.
//
// Round one enqueues seven tasks on a started scheduler: five that only
// differ in the stage at which their context is canceled, one whose
// load-data stage returns ErrNoSuchKey, and one whose save stage returns a
// generic error. It then checks each task's final index state and stage.
//
// Round two enqueues 1024 tasks before the scheduler is started, checks
// that one further enqueue is rejected, then starts the scheduler and
// checks that all 1024 tasks end up Finished.
func TestIndexTaskScheduler(t *testing.T) {
	Params.Init()

	scheduler := NewTaskScheduler(context.TODO())
	scheduler.Start()

	tasks := []task{
		newTask(fakeTaskEnqueued, nil, commonpb.IndexState_Failed),
		newTask(fakeTaskLoadedData, nil, commonpb.IndexState_Failed),
		newTask(fakeTaskPrepared, nil, commonpb.IndexState_Failed),
		newTask(fakeTaskBuiltIndex, nil, commonpb.IndexState_Failed),
		newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished),
		newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskLoadedData: ErrNoSuchKey}, commonpb.IndexState_Failed),
		newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: fmt.Errorf("auth failed")}, commonpb.IndexState_Retry),
	}
	for _, tsk := range tasks {
		assert.NoError(t, scheduler.IndexBuildQueue.Enqueue(tsk))
	}
	_taskwg.Wait()
	scheduler.Close()
	scheduler.wg.Wait()

	// The tasks without injected errors must end in their expected state
	// and at the stage configured as their cancel stage.
	for _, tsk := range tasks[:len(tasks)-2] {
		sctx := tsk.Ctx().(*stagectx)
		assert.Equal(t, tsk.(*fakeTask).expectedState, tsk.GetState())
		assert.Equal(t, sctx.state2cancel, sctx.curstate)
	}

	// The two error-injecting tasks must end at the stage that returned
	// the error.
	noKeyTask := tasks[len(tasks)-2]
	assert.Equal(t, noKeyTask.(*fakeTask).expectedState, noKeyTask.GetState())
	assert.Equal(t, fakeTaskState(fakeTaskLoadedData), noKeyTask.Ctx().(*stagectx).curstate)

	authTask := tasks[len(tasks)-1]
	assert.Equal(t, authTask.(*fakeTask).expectedState, authTask.GetState())
	assert.Equal(t, fakeTaskState(fakeTaskSavedIndexes), authTask.Ctx().(*stagectx).curstate)

	// Round two: fill the queue of a scheduler that has not been started.
	// NOTE(review): 1024 presumably matches the queue's capacity, since the
	// next enqueue is expected to fail — confirm against the queue code.
	const taskCount = 1024
	scheduler = NewTaskScheduler(context.TODO())
	tasks = make([]task, 0, taskCount)
	for i := 0; i < taskCount; i++ {
		tsk := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)
		tasks = append(tasks, tsk)
		assert.NoError(t, scheduler.IndexBuildQueue.Enqueue(tsk))
	}

	failTask := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)
	err := scheduler.IndexBuildQueue.Enqueue(failTask)
	assert.Error(t, err)
	// Reset decrements _taskwg; presumably this balances an OnEnqueue call
	// made for the rejected task so the Wait below does not block forever.
	failTask.Reset()

	scheduler.Start()
	_taskwg.Wait()
	scheduler.Close()
	scheduler.wg.Wait()

	for _, tsk := range tasks {
		assert.Equal(t, commonpb.IndexState_Finished, tsk.GetState())
	}
}