Return task context error when read task times out (#19288)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/19513/head
congqixia 2022-09-28 12:08:59 +08:00 committed by GitHub
parent b648034cee
commit 16170e2cef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 464 additions and 16 deletions

View File

@ -0,0 +1,248 @@
// Code generated by mockery v2.14.0. DO NOT EDIT.
package querynode
import mock "github.com/stretchr/testify/mock"
// MockTSafeReplicaInterface is an autogenerated mock type for the TSafeReplicaInterface type
type MockTSafeReplicaInterface struct {
	mock.Mock
}

// MockTSafeReplicaInterface_Expecter wraps a pointer to the mock.Mock embedded
// in a MockTSafeReplicaInterface; its methods register expectations on it.
type MockTSafeReplicaInterface_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock's embedded mock.Mock.
func (_m *MockTSafeReplicaInterface) EXPECT() *MockTSafeReplicaInterface_Expecter {
	expecter := MockTSafeReplicaInterface_Expecter{mock: &_m.Mock}
	return &expecter
}
// Watch provides a mock function with given fields:
//
// The call is recorded on the embedded mock. If the first configured return
// value is a func() Listener it is invoked and its result returned; otherwise
// a non-nil value is type-asserted to Listener, and a nil value yields the
// zero Listener.
func (_m *MockTSafeReplicaInterface) Watch() Listener {
	returned := _m.Called()
	first := returned.Get(0)

	if producer, ok := first.(func() Listener); ok {
		return producer()
	}

	var listener Listener
	if first != nil {
		listener = first.(Listener)
	}
	return listener
}
// MockTSafeReplicaInterface_Watch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Watch'
type MockTSafeReplicaInterface_Watch_Call struct {
	*mock.Call
}

// Watch is a helper method to define mock.On call
func (_e *MockTSafeReplicaInterface_Expecter) Watch() *MockTSafeReplicaInterface_Watch_Call {
	registered := _e.mock.On("Watch")
	return &MockTSafeReplicaInterface_Watch_Call{Call: registered}
}

// Run forwards a wrapper to the embedded Call.Run; the wrapper ignores the
// recorded arguments (Watch takes none) and simply calls run.
func (_c *MockTSafeReplicaInterface_Watch_Call) Run(run func()) *MockTSafeReplicaInterface_Watch_Call {
	hook := func(_ mock.Arguments) {
		run()
	}
	_c.Call.Run(hook)
	return _c
}

// Return sets the Listener handed back by the mocked Watch and returns the
// receiver so calls can be chained.
func (_c *MockTSafeReplicaInterface_Watch_Call) Return(_a0 Listener) *MockTSafeReplicaInterface_Watch_Call {
	_c.Call.Return(_a0)
	return _c
}
// WatchChannel provides a mock function with given fields: channel
//
// The call is recorded with channel as its argument. A configured
// func(string) Listener is invoked with channel; otherwise a non-nil return
// value is type-asserted to Listener, and nil yields the zero Listener.
func (_m *MockTSafeReplicaInterface) WatchChannel(channel string) Listener {
	returned := _m.Called(channel)
	first := returned.Get(0)

	if producer, ok := first.(func(string) Listener); ok {
		return producer(channel)
	}

	var listener Listener
	if first != nil {
		listener = first.(Listener)
	}
	return listener
}
// MockTSafeReplicaInterface_WatchChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchChannel'
type MockTSafeReplicaInterface_WatchChannel_Call struct {
	*mock.Call
}

// WatchChannel is a helper method to define mock.On call
//   - channel string
func (_e *MockTSafeReplicaInterface_Expecter) WatchChannel(channel interface{}) *MockTSafeReplicaInterface_WatchChannel_Call {
	registered := _e.mock.On("WatchChannel", channel)
	return &MockTSafeReplicaInterface_WatchChannel_Call{Call: registered}
}

// Run forwards a wrapper to the embedded Call.Run; the wrapper asserts the
// first recorded argument to string and passes it to run.
func (_c *MockTSafeReplicaInterface_WatchChannel_Call) Run(run func(channel string)) *MockTSafeReplicaInterface_WatchChannel_Call {
	_c.Call.Run(func(callArgs mock.Arguments) {
		channel := callArgs[0].(string)
		run(channel)
	})
	return _c
}

// Return sets the Listener handed back by the mocked WatchChannel and returns
// the receiver so calls can be chained.
func (_c *MockTSafeReplicaInterface_WatchChannel_Call) Return(_a0 Listener) *MockTSafeReplicaInterface_WatchChannel_Call {
	_c.Call.Return(_a0)
	return _c
}
// addTSafe provides a mock function with given fields: vChannel
//
// It only records the call on the embedded mock; there is no return value.
func (_m *MockTSafeReplicaInterface) addTSafe(vChannel string) {
	_m.Called(vChannel)
}

// MockTSafeReplicaInterface_addTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'addTSafe'
type MockTSafeReplicaInterface_addTSafe_Call struct {
	*mock.Call
}

// addTSafe is a helper method to define mock.On call
//   - vChannel string
func (_e *MockTSafeReplicaInterface_Expecter) addTSafe(vChannel interface{}) *MockTSafeReplicaInterface_addTSafe_Call {
	registered := _e.mock.On("addTSafe", vChannel)
	return &MockTSafeReplicaInterface_addTSafe_Call{Call: registered}
}

// Run forwards a wrapper to the embedded Call.Run; the wrapper asserts the
// first recorded argument to string and passes it to run.
func (_c *MockTSafeReplicaInterface_addTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_addTSafe_Call {
	_c.Call.Run(func(callArgs mock.Arguments) {
		vChannel := callArgs[0].(string)
		run(vChannel)
	})
	return _c
}

// Return registers an empty return list (addTSafe returns nothing) and
// returns the receiver so calls can be chained.
func (_c *MockTSafeReplicaInterface_addTSafe_Call) Return() *MockTSafeReplicaInterface_addTSafe_Call {
	_c.Call.Return()
	return _c
}
// getTSafe provides a mock function with given fields: vChannel
//
// Each configured return value may be either a plain value or a function of
// vChannel; functions are invoked to produce the value. The uint64 is
// resolved before the error.
func (_m *MockTSafeReplicaInterface) getTSafe(vChannel string) (uint64, error) {
	returned := _m.Called(vChannel)

	var ts uint64
	if tsFn, ok := returned.Get(0).(func(string) uint64); ok {
		ts = tsFn(vChannel)
	} else {
		// Panics if the configured value is not a uint64.
		ts = returned.Get(0).(uint64)
	}

	if errFn, ok := returned.Get(1).(func(string) error); ok {
		return ts, errFn(vChannel)
	}
	return ts, returned.Error(1)
}
// MockTSafeReplicaInterface_getTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getTSafe'
type MockTSafeReplicaInterface_getTSafe_Call struct {
	*mock.Call
}

// getTSafe is a helper method to define mock.On call
//   - vChannel string
func (_e *MockTSafeReplicaInterface_Expecter) getTSafe(vChannel interface{}) *MockTSafeReplicaInterface_getTSafe_Call {
	registered := _e.mock.On("getTSafe", vChannel)
	return &MockTSafeReplicaInterface_getTSafe_Call{Call: registered}
}

// Run forwards a wrapper to the embedded Call.Run; the wrapper asserts the
// first recorded argument to string and passes it to run.
func (_c *MockTSafeReplicaInterface_getTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_getTSafe_Call {
	_c.Call.Run(func(callArgs mock.Arguments) {
		vChannel := callArgs[0].(string)
		run(vChannel)
	})
	return _c
}

// Return sets the (uint64, error) pair handed back by the mocked getTSafe and
// returns the receiver so calls can be chained.
func (_c *MockTSafeReplicaInterface_getTSafe_Call) Return(_a0 uint64, _a1 error) *MockTSafeReplicaInterface_getTSafe_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}
// removeTSafe provides a mock function with given fields: vChannel
//
// It only records the call on the embedded mock; there is no return value.
func (_m *MockTSafeReplicaInterface) removeTSafe(vChannel string) {
	_m.Called(vChannel)
}

// MockTSafeReplicaInterface_removeTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeTSafe'
type MockTSafeReplicaInterface_removeTSafe_Call struct {
	*mock.Call
}

// removeTSafe is a helper method to define mock.On call
//   - vChannel string
func (_e *MockTSafeReplicaInterface_Expecter) removeTSafe(vChannel interface{}) *MockTSafeReplicaInterface_removeTSafe_Call {
	registered := _e.mock.On("removeTSafe", vChannel)
	return &MockTSafeReplicaInterface_removeTSafe_Call{Call: registered}
}

// Run forwards a wrapper to the embedded Call.Run; the wrapper asserts the
// first recorded argument to string and passes it to run.
func (_c *MockTSafeReplicaInterface_removeTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_removeTSafe_Call {
	_c.Call.Run(func(callArgs mock.Arguments) {
		vChannel := callArgs[0].(string)
		run(vChannel)
	})
	return _c
}

// Return registers an empty return list (removeTSafe returns nothing) and
// returns the receiver so calls can be chained.
func (_c *MockTSafeReplicaInterface_removeTSafe_Call) Return() *MockTSafeReplicaInterface_removeTSafe_Call {
	_c.Call.Return()
	return _c
}
// setTSafe provides a mock function with given fields: vChannel, timestamp
//
// A configured func(string, uint64) error is invoked with the call's
// arguments; otherwise the first configured return value is returned as an
// error.
func (_m *MockTSafeReplicaInterface) setTSafe(vChannel string, timestamp uint64) error {
	returned := _m.Called(vChannel, timestamp)

	if errFn, ok := returned.Get(0).(func(string, uint64) error); ok {
		return errFn(vChannel, timestamp)
	}
	return returned.Error(0)
}
// MockTSafeReplicaInterface_setTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'setTSafe'
type MockTSafeReplicaInterface_setTSafe_Call struct {
	*mock.Call
}

// setTSafe is a helper method to define mock.On call
//   - vChannel string
//   - timestamp uint64
func (_e *MockTSafeReplicaInterface_Expecter) setTSafe(vChannel interface{}, timestamp interface{}) *MockTSafeReplicaInterface_setTSafe_Call {
	registered := _e.mock.On("setTSafe", vChannel, timestamp)
	return &MockTSafeReplicaInterface_setTSafe_Call{Call: registered}
}

// Run forwards a wrapper to the embedded Call.Run; the wrapper asserts the
// recorded arguments to (string, uint64), in that order, and passes them to
// run.
func (_c *MockTSafeReplicaInterface_setTSafe_Call) Run(run func(vChannel string, timestamp uint64)) *MockTSafeReplicaInterface_setTSafe_Call {
	_c.Call.Run(func(callArgs mock.Arguments) {
		vChannel := callArgs[0].(string)
		timestamp := callArgs[1].(uint64)
		run(vChannel, timestamp)
	})
	return _c
}

// Return sets the error handed back by the mocked setTSafe and returns the
// receiver so calls can be chained.
func (_c *MockTSafeReplicaInterface_setTSafe_Call) Return(_a0 error) *MockTSafeReplicaInterface_setTSafe_Call {
	_c.Call.Return(_a0)
	return _c
}
// mockConstructorTestingTNewMockTSafeReplicaInterface is the testing handle
// required by NewMockTSafeReplicaInterface: a mock.TestingT that can also
// register cleanup callbacks.
type mockConstructorTestingTNewMockTSafeReplicaInterface interface {
	mock.TestingT
	Cleanup(func())
}

// NewMockTSafeReplicaInterface creates a new instance of MockTSafeReplicaInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockTSafeReplicaInterface(t mockConstructorTestingTNewMockTSafeReplicaInterface) *MockTSafeReplicaInterface {
	// Named m rather than mock so the local does not shadow the mock package.
	m := &MockTSafeReplicaInterface{}
	m.Mock.Test(t)

	t.Cleanup(func() {
		m.AssertExpectations(t)
	})

	return m
}

View File

@ -1049,11 +1049,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
resultMut.Lock()
defer resultMut.Unlock()
if streamErr != nil {
cancel()
// not set cancel error
if !errors.Is(streamErr, context.Canceled) {
if err == nil {
err = fmt.Errorf("stream operation failed: %w", streamErr)
}
cancel()
}
}()
@ -1077,11 +1076,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
resultMut.Lock()
defer resultMut.Unlock()
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
cancel()
// not set cancel error
if !errors.Is(nodeErr, context.Canceled) {
if err == nil {
err = fmt.Errorf("Search %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
}
cancel()
return
}
results = append(results, partialResult)
@ -1128,11 +1126,10 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
streamErr := withStreaming(reqCtx)
if streamErr != nil {
cancel()
// not set cancel error
if !errors.Is(streamErr, context.Canceled) {
if err == nil {
err = fmt.Errorf("stream operation failed: %w", streamErr)
}
cancel()
}
}()
@ -1156,8 +1153,8 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
resultMut.Lock()
defer resultMut.Unlock()
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
cancel()
err = fmt.Errorf("Query %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
cancel()
return
}
results = append(results, partialResult)

View File

@ -41,6 +41,7 @@ type readTask interface {
CanMergeWith(readTask) bool
CPUUsage() int32
Timeout() bool
TimeoutError() error
SetMaxCPUUsage(int32)
SetStep(step TaskStep)
@ -133,12 +134,16 @@ func (b *baseReadTask) Timeout() bool {
return !funcutil.CheckCtxValid(b.Ctx())
}
// TimeoutError returns the error reported by the task's context
// (b.ctx.Err()), e.g. context.Canceled or context.DeadlineExceeded.
// The result is nil while the context has not been canceled or expired.
func (b *baseReadTask) TimeoutError() error {
return b.ctx.Err()
}
func (b *baseReadTask) Ready() (bool, error) {
if b.waitTSafeTr == nil {
b.waitTSafeTr = timerecord.NewTimeRecorder("waitTSafeTimeRecorder")
}
if b.Timeout() {
return false, fmt.Errorf("deadline exceed")
return false, b.TimeoutError()
}
var channel Channel
if b.DataScope == querypb.DataScope_Streaming {

View File

@ -0,0 +1,115 @@
package querynode
import (
"context"
"testing"
"time"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/stretchr/testify/suite"
)
// baseReadTaskSuite is a testify suite exercising baseReadTask.
type baseReadTaskSuite struct {
suite.Suite
// qs is the query shard created once in SetupSuite and closed in TearDownSuite.
qs *queryShard
// tsafe is a slot for the mocked tSafe replica.
// NOTE(review): no visible test reads this field — confirm it is populated where needed.
tsafe *MockTSafeReplicaInterface
// task is the task under test; SetupTest creates it and TearDownTest clears it.
task *baseReadTask
}
// SetupSuite builds the query shard shared by every test in the suite.
//
// The shard is wired with a replica mock whose getCollectionByIDFunc always
// returns a collection with the default ID, two chunk manager mocks, and a
// tSafe replica mock. The tSafe mock is also stored on the suite so tests can
// register expectations on the same instance the shard uses; previously it
// was only a local and s.tsafe stayed nil.
func (s *baseReadTaskSuite) SetupSuite() {
	meta := newMockReplicaInterface()
	meta.getCollectionByIDFunc = func(collectionID UniqueID) (*Collection, error) {
		return &Collection{
			id: defaultCollectionID,
		}, nil
	}
	rcm := &mocks.ChunkManager{}
	lcm := &mocks.ChunkManager{}
	tsafe := &MockTSafeReplicaInterface{}

	qs, err := newQueryShard(context.Background(), defaultCollectionID, defaultDMLChannel, defaultReplicaID, nil, meta, tsafe, lcm, rcm, false)
	s.Require().NoError(err)

	s.qs = qs
	s.tsafe = tsafe
}
// TearDownSuite closes the query shard created in SetupSuite.
func (s *baseReadTaskSuite) TearDownSuite() {
	s.qs.Close()
}

// SetupTest gives every test a fresh baseReadTask bound to the suite's query
// shard and carrying its own time recorder.
func (s *baseReadTaskSuite) SetupTest() {
	recorder := timerecord.NewTimeRecorder("baseReadTaskTest")
	s.task = &baseReadTask{
		QS: s.qs,
		tr: recorder,
	}
}

// TearDownTest drops the per-test task.
func (s *baseReadTaskSuite) TearDownTest() {
	s.task = nil
}
// TestPreExecute checks that PreExecute succeeds with a background context
// and leaves the task in the TaskStepPreExecute step.
func (s *baseReadTaskSuite) TestPreExecute() {
	check := s.Assert()

	check.NoError(s.task.PreExecute(context.Background()))
	check.Equal(TaskStepPreExecute, s.task.step)
}
// TestExecute checks that Execute succeeds with a background context and
// leaves the task in the TaskStepExecute step.
func (s *baseReadTaskSuite) TestExecute() {
	check := s.Assert()

	check.NoError(s.task.Execute(context.Background()))
	check.Equal(TaskStepExecute, s.task.step)
}
// TestTimeout verifies Timeout() for a live context, an already-canceled
// context, and a context whose deadline lies in the past.
func (s *baseReadTaskSuite) TestTimeout() {
	noop := func() {}

	cases := []struct {
		name    string
		makeCtx func() (context.Context, context.CancelFunc)
		expired bool
	}{
		{
			name: "background ctx",
			makeCtx: func() (context.Context, context.CancelFunc) {
				return context.Background(), noop
			},
			expired: false,
		},
		{
			name: "context canceled",
			makeCtx: func() (context.Context, context.CancelFunc) {
				ctx, cancel := context.WithCancel(context.Background())
				cancel()
				return ctx, cancel
			},
			expired: true,
		},
		{
			name: "deadline exceeded",
			makeCtx: func() (context.Context, context.CancelFunc) {
				return context.WithDeadline(context.Background(), time.Now().Add(-1*time.Minute))
			},
			expired: true,
		},
	}

	for _, tc := range cases {
		tc := tc
		s.Run(tc.name, func() {
			ctx, cancel := tc.makeCtx()
			// Calling a CancelFunc more than once is a no-op.
			defer cancel()

			s.task.ctx = ctx
			if tc.expired {
				s.Assert().True(s.task.Timeout())
			} else {
				s.Assert().False(s.task.Timeout())
			}
		})
	}
}
// TestTimeoutError verifies that TimeoutError() surfaces the context's own
// error: nil for a live context, context.Canceled after cancellation, and
// context.DeadlineExceeded once the deadline has passed.
func (s *baseReadTaskSuite) TestTimeoutError() {
	canceledCtx, cancelNow := context.WithCancel(context.Background())
	cancelNow()

	expiredCtx, release := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Minute))
	defer release()

	s.Run("background ctx", func() {
		s.task.ctx = context.Background()
		err := s.task.TimeoutError()
		s.Assert().Nil(err)
	})

	s.Run("context canceled", func() {
		s.task.ctx = canceledCtx
		err := s.task.TimeoutError()
		s.Assert().ErrorIs(err, context.Canceled)
	})

	s.Run("deadline exceeded", func() {
		s.task.ctx = expiredCtx
		err := s.task.TimeoutError()
		s.Assert().ErrorIs(err, context.DeadlineExceeded)
	})
}
func TestBaseReadTask(t *testing.T) {
suite.Run(t, new(baseReadTaskSuite))
}

View File

@ -146,7 +146,6 @@ func (s *taskScheduler) tryEvictUnsolvedReadTask(headCount int) {
if diff <= 0 {
return
}
timeoutErr := fmt.Errorf("deadline exceed")
var next *list.Element
for e := s.unsolvedReadTasks.Front(); e != nil; e = next {
next = e.Next()
@ -160,7 +159,7 @@ func (s *taskScheduler) tryEvictUnsolvedReadTask(headCount int) {
if t.Timeout() {
s.unsolvedReadTasks.Remove(e)
rateCol.rtCounter.sub(t, unsolvedQueueType)
t.Notify(timeoutErr)
t.Notify(t.TimeoutError())
diff--
}
}
@ -188,7 +187,7 @@ func (s *taskScheduler) scheduleReadTasks() {
for {
select {
case <-s.ctx.Done():
log.Warn("QueryNode sop schedulerReadTasks")
log.Warn("QueryNode stop schedulerReadTasks")
return
case <-s.notifyChan:
@ -273,12 +272,11 @@ func (s *taskScheduler) executeReadTasks() {
defer s.wg.Done()
var taskWg sync.WaitGroup
defer taskWg.Wait()
timeoutErr := fmt.Errorf("deadline exceed")
executeFunc := func(t readTask) {
defer taskWg.Done()
if t.Timeout() {
t.Notify(timeoutErr)
t.Notify(t.TimeoutError())
} else {
s.processReadTask(t)
}
@ -302,6 +300,7 @@ func (s *taskScheduler) executeReadTasks() {
pendingTaskLen := len(s.executeReadTaskChan)
taskWg.Add(1)
atomic.AddInt32(&s.readConcurrency, int32(pendingTaskLen+1))
log.Debug("begin to execute task")
go executeFunc(t)
for i := 0; i < pendingTaskLen; i++ {

View File

@ -20,6 +20,8 @@ import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type mockTask struct {
@ -65,6 +67,7 @@ type mockReadTask struct {
ready bool
canMerge bool
timeout bool
timeoutError error
step TaskStep
readyError error
}
@ -89,6 +92,10 @@ func (m *mockReadTask) Timeout() bool {
return m.timeout
}
// TimeoutError returns the canned error stored in the mock's timeoutError
// field, so tests control what is reported for a timed-out task.
func (m *mockReadTask) TimeoutError() error {
return m.timeoutError
}
func (m *mockReadTask) SetMaxCPUUsage(cpu int32) {
m.maxCPU = cpu
}
@ -125,3 +132,80 @@ func TestTaskScheduler(t *testing.T) {
ts.Close()
}
// TestTaskScheduler_tryEvictUnsolvedReadTask checks that evicting from the
// unsolved queue notifies a timed-out task with its own TimeoutError value
// (context.Canceled here) and leaves a task that has not timed out queued.
func TestTaskScheduler_tryEvictUnsolvedReadTask(t *testing.T) {
	t.Run("evict canceled task", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()

		tSafe := newTSafeReplica()
		ts := newTaskScheduler(ctx, tSafe)

		taskCanceled := &mockReadTask{
			mockTask: mockTask{
				baseTask: baseTask{
					ctx:  ctx,
					done: make(chan error, 1024),
				},
			},
			timeout:      true,
			timeoutError: context.Canceled,
		}
		taskNormal := &mockReadTask{
			mockTask: mockTask{
				baseTask: baseTask{
					ctx:  context.Background(),
					done: make(chan error, 1024),
				},
			},
		}

		ts.unsolvedReadTasks.PushBack(taskNormal)
		ts.unsolvedReadTasks.PushBack(taskCanceled)

		// set max len to 2
		// NOTE(review): presumably two queued tasks plus headCount 1 exceed this
		// limit and trigger eviction — confirm against tryEvictUnsolvedReadTask.
		tmp := Params.QueryNodeCfg.MaxUnsolvedQueueSize
		Params.QueryNodeCfg.MaxUnsolvedQueueSize = 2
		// Restore the global via defer so the override cannot leak into other
		// tests on any exit path from this subtest.
		defer func() {
			Params.QueryNodeCfg.MaxUnsolvedQueueSize = tmp
		}()

		ts.tryEvictUnsolvedReadTask(1)

		err := <-taskCanceled.done
		assert.ErrorIs(t, err, context.Canceled)

		select {
		case <-taskNormal.done:
			t.Error("task that has not timed out must not be notified by eviction")
		default:
		}
		assert.Equal(t, 1, ts.unsolvedReadTasks.Len())
	})
}
// TestTaskScheduler_executeReadTasks pushes a task that reports itself as
// timed out straight into the scheduler's execute channel and expects the
// task to be notified with its own TimeoutError value (context.Canceled).
func TestTaskScheduler_executeReadTasks(t *testing.T) {
	t.Run("execute canceled task", func(t *testing.T) {
		schedCtx, stop := context.WithCancel(context.Background())
		defer stop()

		replica := newTSafeReplica()
		scheduler := newTaskScheduler(schedCtx, replica)
		scheduler.Start()
		defer scheduler.Close()

		canceled := &mockReadTask{
			mockTask: mockTask{
				baseTask: baseTask{
					ctx:  schedCtx,
					done: make(chan error, 1024),
				},
			},
			timeout:      true,
			timeoutError: context.Canceled,
		}

		scheduler.executeReadTaskChan <- canceled

		notified := <-canceled.done
		assert.ErrorIs(t, notified, context.Canceled)
	})
}