// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package task

import (
	"context"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	mock "github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/datacoord/session"
	taskcommon "github.com/milvus-io/milvus/pkg/v2/taskcommon"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)

func init() {
	// Initialize the global parameter table so scheduler configuration is
	// available when the tests construct schedulers.
	paramtable.Init()
}

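// TestGlobalScheduler_Enqueue checks that Enqueue routes Init tasks to the
// pending queue and InProgress tasks to the running set, and that enqueueing
// the same task twice is idempotent.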
func TestGlobalScheduler_Enqueue(t *testing.T) {
	cluster := session.NewMockCluster(t)
	scheduler := NewGlobalTaskScheduler(context.TODO(), cluster)

	// An Init task lands in the pending queue; a duplicate Enqueue is a no-op.
	task := NewMockTask(t)
	task.EXPECT().GetTaskID().Return(1)
	task.EXPECT().GetTaskState().Return(taskcommon.Init)
	task.EXPECT().GetTaskType().Return(taskcommon.Compaction)
	task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return()
	scheduler.Enqueue(task)
	assert.Equal(t, 1, len(scheduler.(*globalTaskScheduler).pendingTasks.TaskIDs()))
	scheduler.Enqueue(task)
	assert.Equal(t, 1, len(scheduler.(*globalTaskScheduler).pendingTasks.TaskIDs()))

	// An InProgress task goes straight to the running set, also idempotently.
	task = NewMockTask(t)
	task.EXPECT().GetTaskID().Return(2)
	task.EXPECT().GetTaskState().Return(taskcommon.InProgress)
	task.EXPECT().GetTaskType().Return(taskcommon.Compaction)
	task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return()
	scheduler.Enqueue(task)
	assert.Equal(t, 1, scheduler.(*globalTaskScheduler).runningTasks.Len())
	scheduler.Enqueue(task)
	assert.Equal(t, 1, scheduler.(*globalTaskScheduler).runningTasks.Len())
}

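// TestGlobalScheduler_AbortAndRemoveTask checks that AbortAndRemoveTask drops
// the task on its worker and removes it from the pending queue or the running
// set, whichever currently holds it.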
func TestGlobalScheduler_AbortAndRemoveTask(t *testing.T) {
	cluster := session.NewMockCluster(t)
	scheduler := NewGlobalTaskScheduler(context.TODO(), cluster)

	// Abort a pending (Init) task.
	task := NewMockTask(t)
	task.EXPECT().GetTaskID().Return(1)
	task.EXPECT().GetTaskState().Return(taskcommon.Init)
	task.EXPECT().GetTaskType().Return(taskcommon.Compaction)
	task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return()
	task.EXPECT().DropTaskOnWorker(mock.Anything).Return()
	scheduler.Enqueue(task)
	assert.Equal(t, 1, len(scheduler.(*globalTaskScheduler).pendingTasks.TaskIDs()))
	scheduler.AbortAndRemoveTask(1)
	assert.Equal(t, 0, len(scheduler.(*globalTaskScheduler).pendingTasks.TaskIDs()))

	// Abort a running (InProgress) task.
	task = NewMockTask(t)
	task.EXPECT().GetTaskID().Return(2)
	task.EXPECT().GetTaskState().Return(taskcommon.InProgress)
	task.EXPECT().GetTaskType().Return(taskcommon.Compaction)
	task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return()
	task.EXPECT().DropTaskOnWorker(mock.Anything).Return()
	scheduler.Enqueue(task)
	assert.Equal(t, 1, scheduler.(*globalTaskScheduler).runningTasks.Len())
	scheduler.AbortAndRemoveTask(2)
	assert.Equal(t, 0, scheduler.(*globalTaskScheduler).runningTasks.Len())
}

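// TestGlobalScheduler_pickNode checks the worker-selection policy: ties are
// broken randomly, an oversized task falls back to the node with the most
// available slots, and NullNodeID is returned when no slots are free.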
func TestGlobalScheduler_pickNode(t *testing.T) {
	scheduler := NewGlobalTaskScheduler(context.TODO(), nil).(*globalTaskScheduler)

	nodeID := scheduler.pickNode(map[int64]*session.WorkerSlots{
		1: {
			NodeID:         1,
			AvailableSlots: 30,
		},
		2: {
			NodeID:         2,
			AvailableSlots: 30,
		},
	}, 1)
	assert.True(t, nodeID == int64(1) || nodeID == int64(2)) // equal slots: either node may be picked at random

	nodeID = scheduler.pickNode(map[int64]*session.WorkerSlots{
		1: {
			NodeID:         1,
			AvailableSlots: 20,
		},
		2: {
			NodeID:         2,
			AvailableSlots: 30,
		},
	}, 100)
	assert.Equal(t, int64(2), nodeID) // task needs more slots than any node offers; fall back to the node with the most available

	nodeID = scheduler.pickNode(map[int64]*session.WorkerSlots{
		1: {
			NodeID:         1,
			AvailableSlots: 0,
		},
		2: {
			NodeID:         2,
			AvailableSlots: 0,
		},
	}, 1)
	assert.Equal(t, int64(NullNodeID), nodeID) // no free slots anywhere
}

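// TestGlobalScheduler_TestSchedule drives the scheduling loop end to end
// against a mocked two-node cluster, covering retry after task creation,
// retry after status polling, and the normal Init -> InProgress -> Finished
// path.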
func TestGlobalScheduler_TestSchedule(t *testing.T) {
	// newCluster returns a mock cluster with two workers, each reporting 100
	// available slots.
	newCluster := func() session.Cluster {
		cluster := session.NewMockCluster(t)
		cluster.EXPECT().QuerySlot().Return(map[int64]*session.WorkerSlots{
			1: {
				NodeID:         1,
				AvailableSlots: 100,
			},
			2: {
				NodeID:         2,
				AvailableSlots: 100,
			},
		}).Maybe()
		return cluster
	}

	// newTask returns a mock compaction task; each subtest wires up its own
	// state transitions.
	newTask := func() *MockTask {
		task := NewMockTask(t)
		task.EXPECT().GetTaskID().Return(1).Maybe()
		task.EXPECT().GetTaskType().Return(taskcommon.Compaction).Maybe()
		task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return().Maybe()
		task.EXPECT().GetTaskSlot().Return(1).Maybe()
		return task
	}

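	// A task whose creation on the worker ends in Retry should be taken out
	// of the running set and re-queued as pending.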
t.Run("task retry when CreateTaskOnWorker", func(t *testing.T) {
|
|
scheduler := NewGlobalTaskScheduler(context.TODO(), newCluster())
|
|
scheduler.Start()
|
|
defer scheduler.Stop()
|
|
|
|
task := newTask()
|
|
var stateCounter atomic.Int32
|
|
|
|
// Set initial state
|
|
task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State {
|
|
counter := stateCounter.Load()
|
|
if counter == 0 {
|
|
return taskcommon.Init
|
|
}
|
|
return taskcommon.Retry
|
|
}).Maybe()
|
|
|
|
task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) {
|
|
stateCounter.Store(1) // Mark that CreateTaskOnWorker was called
|
|
}).Maybe()
|
|
|
|
scheduler.Enqueue(task)
|
|
assert.Eventually(t, func() bool {
|
|
s := scheduler.(*globalTaskScheduler)
|
|
s.mu.RLock(task.GetTaskID())
|
|
defer s.mu.RUnlock(task.GetTaskID())
|
|
return task.GetTaskState() == taskcommon.Retry &&
|
|
s.runningTasks.Len() == 0 && len(s.pendingTasks.TaskIDs()) == 1
|
|
}, 10*time.Second, 10*time.Millisecond)
|
|
})
|
|
|
|
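	// A task that reaches InProgress but whose status poll ends in Retry
	// should likewise be re-queued as pending.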
t.Run("task retry when QueryTaskOnWorker", func(t *testing.T) {
|
|
scheduler := NewGlobalTaskScheduler(context.TODO(), newCluster())
|
|
scheduler.Start()
|
|
defer scheduler.Stop()
|
|
|
|
task := newTask()
|
|
var stateCounter atomic.Int32
|
|
|
|
task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State {
|
|
counter := stateCounter.Load()
|
|
switch counter {
|
|
case 0:
|
|
return taskcommon.Init
|
|
case 1:
|
|
return taskcommon.InProgress
|
|
default:
|
|
return taskcommon.Retry
|
|
}
|
|
}).Maybe()
|
|
|
|
task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) {
|
|
stateCounter.Store(1) // CreateTaskOnWorker called
|
|
}).Maybe()
|
|
|
|
task.EXPECT().QueryTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) {
|
|
stateCounter.Store(2) // QueryTaskOnWorker called
|
|
}).Maybe()
|
|
|
|
scheduler.Enqueue(task)
|
|
assert.Eventually(t, func() bool {
|
|
s := scheduler.(*globalTaskScheduler)
|
|
s.mu.RLock(1)
|
|
defer s.mu.RUnlock(1)
|
|
return task.GetTaskState() == taskcommon.Retry &&
|
|
s.runningTasks.Len() == 0 && len(s.pendingTasks.TaskIDs()) == 1
|
|
}, 10*time.Second, 10*time.Millisecond)
|
|
})
|
|
|
|
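	// Happy path: the task is created, polled to completion, and dropped on
	// the worker, leaving both queues empty.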
t.Run("normal case", func(t *testing.T) {
|
|
scheduler := NewGlobalTaskScheduler(context.TODO(), newCluster())
|
|
scheduler.Start()
|
|
defer scheduler.Stop()
|
|
|
|
task := newTask()
|
|
var stateCounter atomic.Int32
|
|
|
|
task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State {
|
|
counter := stateCounter.Load()
|
|
switch counter {
|
|
case 0:
|
|
return taskcommon.Init
|
|
case 1:
|
|
return taskcommon.InProgress
|
|
default:
|
|
return taskcommon.Finished
|
|
}
|
|
}).Maybe()
|
|
|
|
task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) {
|
|
stateCounter.Store(1) // CreateTaskOnWorker called
|
|
}).Maybe()
|
|
|
|
task.EXPECT().QueryTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) {
|
|
stateCounter.Store(2) // QueryTaskOnWorker called
|
|
}).Maybe()
|
|
|
|
task.EXPECT().DropTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) {
|
|
stateCounter.Store(3) // DropTaskOnWorker called
|
|
}).Maybe()
|
|
|
|
scheduler.Enqueue(task)
|
|
assert.Eventually(t, func() bool {
|
|
s := scheduler.(*globalTaskScheduler)
|
|
s.mu.RLock(task.GetTaskID())
|
|
defer s.mu.RUnlock(task.GetTaskID())
|
|
return task.GetTaskState() == taskcommon.Finished &&
|
|
s.runningTasks.Len() == 0 && len(s.pendingTasks.TaskIDs()) == 0
|
|
}, 10*time.Second, 10*time.Millisecond)
|
|
})
|
|
}
|