milvus/internal/datanode/index/scheduler_test.go

229 lines
5.5 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// fakeTaskState identifies the stage a fakeTask has most recently entered.
type fakeTaskState int

// The stage constants carry the fakeTaskState type explicitly. Declared with a
// bare iota they would be untyped integers and turn into plain int when stored
// in an interface value (for example when passed to an assertion helper).
const (
	fakeTaskInited       fakeTaskState = iota // starting stage assigned by newTask
	fakeTaskEnqueued                          // recorded by OnEnqueue
	fakeTaskPrepared                          // recorded by PreExecute
	fakeTaskLoadedData                        // recorded by LoadData
	fakeTaskBuiltIndex                        // recorded by Execute
	fakeTaskSavedIndexes                      // recorded by PostExecute
)
// stagectx is a context.Context test double that starts reporting cancellation
// once its owner has reached a chosen stage.
//
// setState only records the stage. The channel is closed lazily, by the first
// Done call made while curstate equals state2cancel.
type stagectx struct {
	mu           sync.Mutex    // guards curstate and the closing of ch
	curstate     fakeTaskState // stage most recently recorded through setState
	state2cancel fakeTaskState // stage at which Done begins to report cancellation
	ch           chan struct{} // channel handed out by Done; closed at most once
}

// Compile-time check that *stagectx satisfies context.Context.
var _ context.Context = (*stagectx)(nil)

// Deadline reports that no deadline is set (ok is always false).
func (s *stagectx) Deadline() (time.Time, bool) {
	return time.Now(), false
}

// Done returns the cancellation channel, closing it first when the current
// stage equals the cancel stage.
//
// The call is idempotent: once the channel has been closed, further calls
// return the same closed channel instead of closing it again, which would
// panic.
func (s *stagectx) Done() <-chan struct{} {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.curstate == s.state2cancel {
		select {
		case <-s.ch:
			// Already closed by an earlier call; nothing left to do.
		default:
			close(s.ch)
		}
	}
	return s.ch
}

// Err returns a "canceled" error once the channel has been closed and nil
// before that. It never closes the channel itself.
func (s *stagectx) Err() error {
	select {
	case <-s.ch:
		return errors.New("canceled")
	default:
		return nil
	}
}

// Value always returns nil; this context carries no values.
func (s *stagectx) Value(k interface{}) interface{} {
	return nil
}

// setState records state as the current stage.
func (s *stagectx) setState(state fakeTaskState) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.curstate = state
}
// _taskwg counts fake tasks that have been enqueued but not yet reset:
// fakeTask.OnEnqueue adds one and fakeTask.Reset marks one done.
var _taskwg sync.WaitGroup
// fakeTask is a Task test double: each stage method records the stage it
// represents and returns whatever error was configured for that stage.
type fakeTask struct {
// id is the number handed out by newTask; Name embeds it in the task name.
id int
// ctx is returned by Ctx. The stage methods type-assert it to *stagectx.
ctx context.Context
// state is the stage most recently entered by one of the stage methods.
state fakeTaskState
// reterr maps a stage to the error its method returns; a missing entry
// (or a nil map) yields a nil error.
reterr map[fakeTaskState]error
// retstate is the job state last stored by SetState and reported by GetState.
retstate indexpb.JobState
// expectedState is the job state the test expects GetState to report.
expectedState indexpb.JobState
// failReason is the reason string last stored by SetState.
failReason string
}
// Compile-time check that *fakeTask satisfies Task.
var _ Task = &fakeTask{}
// Name returns a label of the form "fake-task-<id>".
func (t *fakeTask) Name() string {
	label := fmt.Sprintf("fake-task-%d", t.id)
	return label
}
// Ctx returns the context stored on the task.
func (t *fakeTask) Ctx() context.Context {
return t.ctx
}
// GetSlot always reports 1.
func (t *fakeTask) GetSlot() int64 {
return 1
}
// OnEnqueue registers the task with _taskwg, records the enqueued stage on
// both the task and its stagectx, and returns the error configured for that
// stage (nil when none is configured). The ctx argument is not used; the
// task's own context is updated instead.
func (t *fakeTask) OnEnqueue(ctx context.Context) error {
	const stage = fakeTaskEnqueued

	_taskwg.Add(1)
	t.state = stage
	// The assertion panics if the task was built with any other context type.
	sctx := t.ctx.(*stagectx)
	sctx.setState(stage)
	return t.reterr[stage]
}
// PreExecute records the prepared stage on both the task and its stagectx and
// returns the error configured for that stage (nil when none is configured).
// The ctx argument is not used; the task's own context is updated instead.
func (t *fakeTask) PreExecute(ctx context.Context) error {
	const stage = fakeTaskPrepared

	t.state = stage
	// The assertion panics if the task was built with any other context type.
	sctx := t.ctx.(*stagectx)
	sctx.setState(stage)
	return t.reterr[stage]
}
// LoadData records the loaded-data stage on both the task and its stagectx
// and returns the error configured for that stage (nil when none is
// configured). The ctx argument is not used; the task's own context is
// updated instead.
func (t *fakeTask) LoadData(ctx context.Context) error {
	const stage = fakeTaskLoadedData

	t.state = stage
	// The assertion panics if the task was built with any other context type.
	sctx := t.ctx.(*stagectx)
	sctx.setState(stage)
	return t.reterr[stage]
}
// Execute records the built-index stage on both the task and its stagectx and
// returns the error configured for that stage (nil when none is configured).
// The ctx argument is not used; the task's own context is updated instead.
func (t *fakeTask) Execute(ctx context.Context) error {
	const stage = fakeTaskBuiltIndex

	t.state = stage
	// The assertion panics if the task was built with any other context type.
	sctx := t.ctx.(*stagectx)
	sctx.setState(stage)
	return t.reterr[stage]
}
// PostExecute records the saved-indexes stage on both the task and its
// stagectx and returns the error configured for that stage (nil when none is
// configured). The ctx argument is not used; the task's own context is
// updated instead.
func (t *fakeTask) PostExecute(ctx context.Context) error {
	const stage = fakeTaskSavedIndexes

	t.state = stage
	// The assertion panics if the task was built with any other context type.
	sctx := t.ctx.(*stagectx)
	sctx.setState(stage)
	return t.reterr[stage]
}
// Reset marks one unit of work as done on _taskwg, pairing with the Add(1)
// performed in OnEnqueue.
func (t *fakeTask) Reset() {
_taskwg.Done()
}
// SetState stores the given job state and failure reason on the task; the
// state is later reported by GetState.
func (t *fakeTask) SetState(state indexpb.JobState, failReason string) {
	t.retstate, t.failReason = state, failReason
}
// GetState returns the job state last stored by SetState, or the initial
// value given at construction if SetState was never called.
func (t *fakeTask) GetState() indexpb.JobState {
return t.retstate
}
// Package-level counter used by newTask to give every fake task a distinct id.
var (
// idLock serializes access to id.
idLock sync.Mutex
// id is the next task id to hand out; it starts at zero.
id = 0
)
// newTask builds a fakeTask with a fresh id and a stagectx that starts
// reporting cancellation once the task has reached cancelStage.
//
// reterror maps stages to the error the matching stage method should return
// and may be nil. expectedState is stored on the task for the test to compare
// against later. The task starts in fakeTaskInited with job state
// JobStateNone.
func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expectedState indexpb.JobState) Task {
	// Reserve the next id under the lock so concurrent callers never share one.
	idLock.Lock()
	taskID := id
	id++
	idLock.Unlock()

	sctx := &stagectx{
		curstate:     fakeTaskInited,
		state2cancel: cancelStage,
		ch:           make(chan struct{}),
	}

	task := &fakeTask{
		id:            taskID,
		ctx:           sctx,
		state:         fakeTaskInited,
		reterr:        reterror,
		retstate:      indexpb.JobState_JobStateNone,
		expectedState: expectedState,
	}
	return task
}
// TestIndexTaskScheduler drives fake tasks through a TaskScheduler and checks
// the job state each one ends up in.
//
// Phase one enqueues, on a running scheduler, tasks whose context is canceled
// at different stages plus one task whose last stage returns an error. Phase
// two enqueues 1024 tasks on a scheduler that has not been started, expects
// one further Enqueue to be rejected, then starts the scheduler and expects
// every accepted task to finish.
func TestIndexTaskScheduler(t *testing.T) {
	paramtable.Init()

	// Phase one: cancellation at various stages and a failing final stage.
	scheduler := NewTaskScheduler(context.TODO())
	scheduler.Start()

	tasks := []Task{
		newTask(fakeTaskEnqueued, nil, indexpb.JobState_JobStateRetry),
		newTask(fakeTaskPrepared, nil, indexpb.JobState_JobStateRetry),
		newTask(fakeTaskBuiltIndex, nil, indexpb.JobState_JobStateRetry),
		newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished),
		newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: errors.New("auth failed")}, indexpb.JobState_JobStateRetry),
	}
	for _, task := range tasks {
		assert.NoError(t, scheduler.TaskQueue.Enqueue(task))
	}
	// Every task calls _taskwg.Done from Reset, so Wait returns once all of
	// them have been reset.
	_taskwg.Wait()
	scheduler.Close()
	scheduler.wg.Wait()

	// All tasks but the last must stop at exactly the stage they were told to
	// cancel at. Assertions use testify's (expected, actual) argument order.
	last := len(tasks) - 1
	for _, task := range tasks[:last] {
		sctx := task.Ctx().(*stagectx)
		assert.Equal(t, task.(*fakeTask).expectedState, task.GetState())
		assert.Equal(t, sctx.state2cancel, sctx.curstate)
	}
	// The last task runs through its final stage, which returns an error.
	errTask := tasks[last]
	assert.Equal(t, errTask.(*fakeTask).expectedState, errTask.GetState())
	assert.Equal(t, fakeTaskState(fakeTaskSavedIndexes), errTask.Ctx().(*stagectx).curstate)

	// Phase two: queue 1024 tasks before the scheduler is started.
	// NOTE(review): 1024 is presumably the task queue capacity; confirm
	// against the queue implementation.
	const queued = 1024
	scheduler = NewTaskScheduler(context.TODO())
	tasks = make([]Task, 0, queued)
	for i := 0; i < queued; i++ {
		task := newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished)
		tasks = append(tasks, task)
		assert.NoError(t, scheduler.TaskQueue.Enqueue(task))
	}

	// One more task must be rejected.
	failTask := newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished)
	err := scheduler.TaskQueue.Enqueue(failTask)
	assert.Error(t, err)
	// Reset calls _taskwg.Done. This presumably balances the Add(1) made by
	// OnEnqueue during the rejected Enqueue, so the Wait below is not held up
	// by a task that will never run.
	failTask.Reset()

	scheduler.Start()
	_taskwg.Wait()
	scheduler.Close()
	scheduler.wg.Wait()

	for _, task := range tasks {
		assert.Equal(t, indexpb.JobState_JobStateFinished, task.GetState())
	}
}