mirror of https://github.com/milvus-io/milvus.git
Return task context error when read task timeout (#19288)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com> Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/19513/head
parent
b648034cee
commit
16170e2cef
|
@ -0,0 +1,248 @@
|
|||
// Code generated by mockery v2.14.0. DO NOT EDIT.
|
||||
|
||||
package querynode
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockTSafeReplicaInterface is an autogenerated mock type for the TSafeReplicaInterface type
|
||||
type MockTSafeReplicaInterface struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockTSafeReplicaInterface_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockTSafeReplicaInterface) EXPECT() *MockTSafeReplicaInterface_Expecter {
|
||||
return &MockTSafeReplicaInterface_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Watch provides a mock function with given fields:
|
||||
func (_m *MockTSafeReplicaInterface) Watch() Listener {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 Listener
|
||||
if rf, ok := ret.Get(0).(func() Listener); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(Listener)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockTSafeReplicaInterface_Watch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Watch'
|
||||
type MockTSafeReplicaInterface_Watch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Watch is a helper method to define mock.On call
|
||||
func (_e *MockTSafeReplicaInterface_Expecter) Watch() *MockTSafeReplicaInterface_Watch_Call {
|
||||
return &MockTSafeReplicaInterface_Watch_Call{Call: _e.mock.On("Watch")}
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_Watch_Call) Run(run func()) *MockTSafeReplicaInterface_Watch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_Watch_Call) Return(_a0 Listener) *MockTSafeReplicaInterface_Watch_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// WatchChannel provides a mock function with given fields: channel
|
||||
func (_m *MockTSafeReplicaInterface) WatchChannel(channel string) Listener {
|
||||
ret := _m.Called(channel)
|
||||
|
||||
var r0 Listener
|
||||
if rf, ok := ret.Get(0).(func(string) Listener); ok {
|
||||
r0 = rf(channel)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(Listener)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockTSafeReplicaInterface_WatchChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchChannel'
|
||||
type MockTSafeReplicaInterface_WatchChannel_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// WatchChannel is a helper method to define mock.On call
|
||||
// - channel string
|
||||
func (_e *MockTSafeReplicaInterface_Expecter) WatchChannel(channel interface{}) *MockTSafeReplicaInterface_WatchChannel_Call {
|
||||
return &MockTSafeReplicaInterface_WatchChannel_Call{Call: _e.mock.On("WatchChannel", channel)}
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_WatchChannel_Call) Run(run func(channel string)) *MockTSafeReplicaInterface_WatchChannel_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_WatchChannel_Call) Return(_a0 Listener) *MockTSafeReplicaInterface_WatchChannel_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// addTSafe provides a mock function with given fields: vChannel
|
||||
func (_m *MockTSafeReplicaInterface) addTSafe(vChannel string) {
|
||||
_m.Called(vChannel)
|
||||
}
|
||||
|
||||
// MockTSafeReplicaInterface_addTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'addTSafe'
|
||||
type MockTSafeReplicaInterface_addTSafe_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// addTSafe is a helper method to define mock.On call
|
||||
// - vChannel string
|
||||
func (_e *MockTSafeReplicaInterface_Expecter) addTSafe(vChannel interface{}) *MockTSafeReplicaInterface_addTSafe_Call {
|
||||
return &MockTSafeReplicaInterface_addTSafe_Call{Call: _e.mock.On("addTSafe", vChannel)}
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_addTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_addTSafe_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_addTSafe_Call) Return() *MockTSafeReplicaInterface_addTSafe_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// getTSafe provides a mock function with given fields: vChannel
|
||||
func (_m *MockTSafeReplicaInterface) getTSafe(vChannel string) (uint64, error) {
|
||||
ret := _m.Called(vChannel)
|
||||
|
||||
var r0 uint64
|
||||
if rf, ok := ret.Get(0).(func(string) uint64); ok {
|
||||
r0 = rf(vChannel)
|
||||
} else {
|
||||
r0 = ret.Get(0).(uint64)
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(vChannel)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockTSafeReplicaInterface_getTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getTSafe'
|
||||
type MockTSafeReplicaInterface_getTSafe_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// getTSafe is a helper method to define mock.On call
|
||||
// - vChannel string
|
||||
func (_e *MockTSafeReplicaInterface_Expecter) getTSafe(vChannel interface{}) *MockTSafeReplicaInterface_getTSafe_Call {
|
||||
return &MockTSafeReplicaInterface_getTSafe_Call{Call: _e.mock.On("getTSafe", vChannel)}
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_getTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_getTSafe_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_getTSafe_Call) Return(_a0 uint64, _a1 error) *MockTSafeReplicaInterface_getTSafe_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// removeTSafe provides a mock function with given fields: vChannel
|
||||
func (_m *MockTSafeReplicaInterface) removeTSafe(vChannel string) {
|
||||
_m.Called(vChannel)
|
||||
}
|
||||
|
||||
// MockTSafeReplicaInterface_removeTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeTSafe'
|
||||
type MockTSafeReplicaInterface_removeTSafe_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// removeTSafe is a helper method to define mock.On call
|
||||
// - vChannel string
|
||||
func (_e *MockTSafeReplicaInterface_Expecter) removeTSafe(vChannel interface{}) *MockTSafeReplicaInterface_removeTSafe_Call {
|
||||
return &MockTSafeReplicaInterface_removeTSafe_Call{Call: _e.mock.On("removeTSafe", vChannel)}
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_removeTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_removeTSafe_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_removeTSafe_Call) Return() *MockTSafeReplicaInterface_removeTSafe_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// setTSafe provides a mock function with given fields: vChannel, timestamp
|
||||
func (_m *MockTSafeReplicaInterface) setTSafe(vChannel string, timestamp uint64) error {
|
||||
ret := _m.Called(vChannel, timestamp)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, uint64) error); ok {
|
||||
r0 = rf(vChannel, timestamp)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockTSafeReplicaInterface_setTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'setTSafe'
|
||||
type MockTSafeReplicaInterface_setTSafe_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// setTSafe is a helper method to define mock.On call
|
||||
// - vChannel string
|
||||
// - timestamp uint64
|
||||
func (_e *MockTSafeReplicaInterface_Expecter) setTSafe(vChannel interface{}, timestamp interface{}) *MockTSafeReplicaInterface_setTSafe_Call {
|
||||
return &MockTSafeReplicaInterface_setTSafe_Call{Call: _e.mock.On("setTSafe", vChannel, timestamp)}
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_setTSafe_Call) Run(run func(vChannel string, timestamp uint64)) *MockTSafeReplicaInterface_setTSafe_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(uint64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockTSafeReplicaInterface_setTSafe_Call) Return(_a0 error) *MockTSafeReplicaInterface_setTSafe_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockTSafeReplicaInterface interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockTSafeReplicaInterface creates a new instance of MockTSafeReplicaInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockTSafeReplicaInterface(t mockConstructorTestingTNewMockTSafeReplicaInterface) *MockTSafeReplicaInterface {
|
||||
mock := &MockTSafeReplicaInterface{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -1049,11 +1049,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
|
|||
resultMut.Lock()
|
||||
defer resultMut.Unlock()
|
||||
if streamErr != nil {
|
||||
cancel()
|
||||
// not set cancel error
|
||||
if !errors.Is(streamErr, context.Canceled) {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("stream operation failed: %w", streamErr)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1077,11 +1076,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
|
|||
resultMut.Lock()
|
||||
defer resultMut.Unlock()
|
||||
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
cancel()
|
||||
// not set cancel error
|
||||
if !errors.Is(nodeErr, context.Canceled) {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("Search %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
results = append(results, partialResult)
|
||||
|
@ -1128,11 +1126,10 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
|
|||
|
||||
streamErr := withStreaming(reqCtx)
|
||||
if streamErr != nil {
|
||||
cancel()
|
||||
// not set cancel error
|
||||
if !errors.Is(streamErr, context.Canceled) {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("stream operation failed: %w", streamErr)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1156,8 +1153,8 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
|
|||
resultMut.Lock()
|
||||
defer resultMut.Unlock()
|
||||
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
cancel()
|
||||
err = fmt.Errorf("Query %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
results = append(results, partialResult)
|
||||
|
|
|
@ -41,6 +41,7 @@ type readTask interface {
|
|||
CanMergeWith(readTask) bool
|
||||
CPUUsage() int32
|
||||
Timeout() bool
|
||||
TimeoutError() error
|
||||
|
||||
SetMaxCPUUsage(int32)
|
||||
SetStep(step TaskStep)
|
||||
|
@ -133,12 +134,16 @@ func (b *baseReadTask) Timeout() bool {
|
|||
return !funcutil.CheckCtxValid(b.Ctx())
|
||||
}
|
||||
|
||||
func (b *baseReadTask) TimeoutError() error {
|
||||
return b.ctx.Err()
|
||||
}
|
||||
|
||||
func (b *baseReadTask) Ready() (bool, error) {
|
||||
if b.waitTSafeTr == nil {
|
||||
b.waitTSafeTr = timerecord.NewTimeRecorder("waitTSafeTimeRecorder")
|
||||
}
|
||||
if b.Timeout() {
|
||||
return false, fmt.Errorf("deadline exceed")
|
||||
return false, b.TimeoutError()
|
||||
}
|
||||
var channel Channel
|
||||
if b.DataScope == querypb.DataScope_Streaming {
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type baseReadTaskSuite struct {
|
||||
suite.Suite
|
||||
|
||||
qs *queryShard
|
||||
tsafe *MockTSafeReplicaInterface
|
||||
|
||||
task *baseReadTask
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) SetupSuite() {
|
||||
meta := newMockReplicaInterface()
|
||||
meta.getCollectionByIDFunc = func(collectionID UniqueID) (*Collection, error) {
|
||||
return &Collection{
|
||||
id: defaultCollectionID,
|
||||
}, nil
|
||||
}
|
||||
rcm := &mocks.ChunkManager{}
|
||||
lcm := &mocks.ChunkManager{}
|
||||
|
||||
tsafe := &MockTSafeReplicaInterface{}
|
||||
|
||||
qs, err := newQueryShard(context.Background(), defaultCollectionID, defaultDMLChannel, defaultReplicaID, nil, meta, tsafe, lcm, rcm, false)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.qs = qs
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) TearDownSuite() {
|
||||
s.qs.Close()
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) SetupTest() {
|
||||
s.task = &baseReadTask{QS: s.qs, tr: timerecord.NewTimeRecorder("baseReadTaskTest")}
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) TearDownTest() {
|
||||
s.task = nil
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) TestPreExecute() {
|
||||
ctx := context.Background()
|
||||
err := s.task.PreExecute(ctx)
|
||||
s.Assert().NoError(err)
|
||||
s.Assert().Equal(TaskStepPreExecute, s.task.step)
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) TestExecute() {
|
||||
ctx := context.Background()
|
||||
err := s.task.Execute(ctx)
|
||||
s.Assert().NoError(err)
|
||||
s.Assert().Equal(TaskStepExecute, s.task.step)
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) TestTimeout() {
|
||||
s.Run("background ctx", func() {
|
||||
s.task.ctx = context.Background()
|
||||
s.Assert().False(s.task.Timeout())
|
||||
})
|
||||
|
||||
s.Run("context canceled", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
s.task.ctx = ctx
|
||||
|
||||
s.Assert().True(s.task.Timeout())
|
||||
})
|
||||
|
||||
s.Run("deadline exceeded", func() {
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Minute))
|
||||
defer cancel()
|
||||
s.task.ctx = ctx
|
||||
|
||||
s.Assert().True(s.task.Timeout())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *baseReadTaskSuite) TestTimeoutError() {
|
||||
s.Run("background ctx", func() {
|
||||
s.task.ctx = context.Background()
|
||||
s.Assert().Nil(s.task.TimeoutError())
|
||||
})
|
||||
|
||||
s.Run("context canceled", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
s.task.ctx = ctx
|
||||
|
||||
s.Assert().ErrorIs(s.task.TimeoutError(), context.Canceled)
|
||||
})
|
||||
|
||||
s.Run("deadline exceeded", func() {
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Minute))
|
||||
defer cancel()
|
||||
s.task.ctx = ctx
|
||||
|
||||
s.Assert().ErrorIs(s.task.TimeoutError(), context.DeadlineExceeded)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestBaseReadTask(t *testing.T) {
|
||||
suite.Run(t, new(baseReadTaskSuite))
|
||||
}
|
|
@ -146,7 +146,6 @@ func (s *taskScheduler) tryEvictUnsolvedReadTask(headCount int) {
|
|||
if diff <= 0 {
|
||||
return
|
||||
}
|
||||
timeoutErr := fmt.Errorf("deadline exceed")
|
||||
var next *list.Element
|
||||
for e := s.unsolvedReadTasks.Front(); e != nil; e = next {
|
||||
next = e.Next()
|
||||
|
@ -160,7 +159,7 @@ func (s *taskScheduler) tryEvictUnsolvedReadTask(headCount int) {
|
|||
if t.Timeout() {
|
||||
s.unsolvedReadTasks.Remove(e)
|
||||
rateCol.rtCounter.sub(t, unsolvedQueueType)
|
||||
t.Notify(timeoutErr)
|
||||
t.Notify(t.TimeoutError())
|
||||
diff--
|
||||
}
|
||||
}
|
||||
|
@ -188,7 +187,7 @@ func (s *taskScheduler) scheduleReadTasks() {
|
|||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Warn("QueryNode sop schedulerReadTasks")
|
||||
log.Warn("QueryNode stop schedulerReadTasks")
|
||||
return
|
||||
|
||||
case <-s.notifyChan:
|
||||
|
@ -273,12 +272,11 @@ func (s *taskScheduler) executeReadTasks() {
|
|||
defer s.wg.Done()
|
||||
var taskWg sync.WaitGroup
|
||||
defer taskWg.Wait()
|
||||
timeoutErr := fmt.Errorf("deadline exceed")
|
||||
|
||||
executeFunc := func(t readTask) {
|
||||
defer taskWg.Done()
|
||||
if t.Timeout() {
|
||||
t.Notify(timeoutErr)
|
||||
t.Notify(t.TimeoutError())
|
||||
} else {
|
||||
s.processReadTask(t)
|
||||
}
|
||||
|
@ -302,6 +300,7 @@ func (s *taskScheduler) executeReadTasks() {
|
|||
pendingTaskLen := len(s.executeReadTaskChan)
|
||||
taskWg.Add(1)
|
||||
atomic.AddInt32(&s.readConcurrency, int32(pendingTaskLen+1))
|
||||
log.Debug("begin to execute task")
|
||||
go executeFunc(t)
|
||||
|
||||
for i := 0; i < pendingTaskLen; i++ {
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockTask struct {
|
||||
|
@ -65,6 +67,7 @@ type mockReadTask struct {
|
|||
ready bool
|
||||
canMerge bool
|
||||
timeout bool
|
||||
timeoutError error
|
||||
step TaskStep
|
||||
readyError error
|
||||
}
|
||||
|
@ -89,6 +92,10 @@ func (m *mockReadTask) Timeout() bool {
|
|||
return m.timeout
|
||||
}
|
||||
|
||||
func (m *mockReadTask) TimeoutError() error {
|
||||
return m.timeoutError
|
||||
}
|
||||
|
||||
func (m *mockReadTask) SetMaxCPUUsage(cpu int32) {
|
||||
m.maxCPU = cpu
|
||||
}
|
||||
|
@ -125,3 +132,80 @@ func TestTaskScheduler(t *testing.T) {
|
|||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestTaskScheduler_tryEvictUnsolvedReadTask(t *testing.T) {
|
||||
t.Run("evict canceled task", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
tSafe := newTSafeReplica()
|
||||
|
||||
ts := newTaskScheduler(ctx, tSafe)
|
||||
|
||||
taskCanceled := &mockReadTask{
|
||||
mockTask: mockTask{
|
||||
baseTask: baseTask{
|
||||
ctx: ctx,
|
||||
done: make(chan error, 1024),
|
||||
},
|
||||
},
|
||||
timeout: true,
|
||||
timeoutError: context.Canceled,
|
||||
}
|
||||
taskNormal := &mockReadTask{
|
||||
mockTask: mockTask{
|
||||
baseTask: baseTask{
|
||||
ctx: context.Background(),
|
||||
done: make(chan error, 1024),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ts.unsolvedReadTasks.PushBack(taskNormal)
|
||||
ts.unsolvedReadTasks.PushBack(taskCanceled)
|
||||
|
||||
// set max len to 2
|
||||
tmp := Params.QueryNodeCfg.MaxUnsolvedQueueSize
|
||||
Params.QueryNodeCfg.MaxUnsolvedQueueSize = 2
|
||||
ts.tryEvictUnsolvedReadTask(1)
|
||||
Params.QueryNodeCfg.MaxUnsolvedQueueSize = tmp
|
||||
|
||||
err := <-taskCanceled.done
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
select {
|
||||
case <-taskNormal.done:
|
||||
t.Fail()
|
||||
default:
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, ts.unsolvedReadTasks.Len())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskScheduler_executeReadTasks(t *testing.T) {
|
||||
t.Run("execute canceled task", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tSafe := newTSafeReplica()
|
||||
|
||||
ts := newTaskScheduler(ctx, tSafe)
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
taskCanceled := &mockReadTask{
|
||||
mockTask: mockTask{
|
||||
baseTask: baseTask{
|
||||
ctx: ctx,
|
||||
done: make(chan error, 1024),
|
||||
},
|
||||
},
|
||||
timeout: true,
|
||||
timeoutError: context.Canceled,
|
||||
}
|
||||
|
||||
ts.executeReadTaskChan <- taskCanceled
|
||||
|
||||
err := <-taskCanceled.done
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue