Add ctx parameter for organizeTask and GetWorker method (#26835)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/26843/head
congqixia 2023-09-05 10:05:48 +08:00 committed by GitHub
parent c132c53b1a
commit 4b58c71908
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 55 deletions

View File

@ -17,6 +17,7 @@
package cluster
import (
context "context"
"fmt"
"strconv"
@ -29,11 +30,11 @@ import (
// Manager is the interface for worker manager.
type Manager interface {
GetWorker(nodeID int64) (Worker, error)
GetWorker(ctx context.Context, nodeID int64) (Worker, error)
}
// WorkerBuilder is function alias to build a worker from NodeID
type WorkerBuilder func(nodeID int64) (Worker, error)
type WorkerBuilder func(ctx context.Context, nodeID int64) (Worker, error)
type grpcWorkerManager struct {
workers *typeutil.ConcurrentMap[int64, Worker]
@ -42,12 +43,12 @@ type grpcWorkerManager struct {
}
// GetWorker returns worker with specified nodeID.
func (m *grpcWorkerManager) GetWorker(nodeID int64) (Worker, error) {
func (m *grpcWorkerManager) GetWorker(ctx context.Context, nodeID int64) (Worker, error) {
worker, ok := m.workers.Get(nodeID)
var err error
if !ok {
worker, err, _ = m.sf.Do(strconv.FormatInt(nodeID, 10), func() (Worker, error) {
worker, err = m.builder(nodeID)
worker, err = m.builder(ctx, nodeID)
if err != nil {
log.Warn("failed to build worker",
zap.Int64("nodeID", nodeID),

View File

@ -17,6 +17,7 @@
package cluster
import (
context "context"
"testing"
"github.com/cockroachdb/errors"
@ -25,23 +26,24 @@ import (
)
func TestManager(t *testing.T) {
ctx := context.Background()
t.Run("normal_get", func(t *testing.T) {
worker := &MockWorker{}
worker.EXPECT().IsHealthy().Return(true)
var buildErr error
var called int
builder := func(nodeID int64) (Worker, error) {
builder := func(_ context.Context, nodeID int64) (Worker, error) {
called++
return worker, buildErr
}
manager := NewWorkerManager(builder)
w, err := manager.GetWorker(0)
w, err := manager.GetWorker(ctx, 0)
assert.Equal(t, worker, w)
assert.NoError(t, err)
assert.Equal(t, 1, called)
w, err = manager.GetWorker(0)
w, err = manager.GetWorker(ctx, 0)
assert.Equal(t, worker, w)
assert.NoError(t, err)
assert.Equal(t, 1, called)
@ -53,13 +55,13 @@ func TestManager(t *testing.T) {
var buildErr error
var called int
buildErr = errors.New("mocked error")
builder := func(nodeID int64) (Worker, error) {
builder := func(_ context.Context, nodeID int64) (Worker, error) {
called++
return worker, buildErr
}
manager := NewWorkerManager(builder)
_, err := manager.GetWorker(0)
_, err := manager.GetWorker(ctx, 0)
assert.Error(t, err)
assert.Equal(t, 1, called)
})
@ -69,13 +71,13 @@ func TestManager(t *testing.T) {
worker.EXPECT().IsHealthy().Return(false)
var buildErr error
var called int
builder := func(nodeID int64) (Worker, error) {
builder := func(_ context.Context, nodeID int64) (Worker, error) {
called++
return worker, buildErr
}
manager := NewWorkerManager(builder)
_, err := manager.GetWorker(0)
_, err := manager.GetWorker(ctx, 0)
assert.Error(t, err)
assert.Equal(t, 1, called)
})

View File

@ -2,7 +2,11 @@
package cluster
import mock "github.com/stretchr/testify/mock"
import (
context "context"
mock "github.com/stretchr/testify/mock"
)
// MockManager is an autogenerated mock type for the Manager type
type MockManager struct {
@ -17,25 +21,25 @@ func (_m *MockManager) EXPECT() *MockManager_Expecter {
return &MockManager_Expecter{mock: &_m.Mock}
}
// GetWorker provides a mock function with given fields: nodeID
func (_m *MockManager) GetWorker(nodeID int64) (Worker, error) {
ret := _m.Called(nodeID)
// GetWorker provides a mock function with given fields: ctx, nodeID
func (_m *MockManager) GetWorker(ctx context.Context, nodeID int64) (Worker, error) {
ret := _m.Called(ctx, nodeID)
var r0 Worker
var r1 error
if rf, ok := ret.Get(0).(func(int64) (Worker, error)); ok {
return rf(nodeID)
if rf, ok := ret.Get(0).(func(context.Context, int64) (Worker, error)); ok {
return rf(ctx, nodeID)
}
if rf, ok := ret.Get(0).(func(int64) Worker); ok {
r0 = rf(nodeID)
if rf, ok := ret.Get(0).(func(context.Context, int64) Worker); ok {
r0 = rf(ctx, nodeID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(Worker)
}
}
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(nodeID)
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, nodeID)
} else {
r1 = ret.Error(1)
}
@ -49,14 +53,15 @@ type MockManager_GetWorker_Call struct {
}
// GetWorker is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
func (_e *MockManager_Expecter) GetWorker(nodeID interface{}) *MockManager_GetWorker_Call {
return &MockManager_GetWorker_Call{Call: _e.mock.On("GetWorker", nodeID)}
func (_e *MockManager_Expecter) GetWorker(ctx interface{}, nodeID interface{}) *MockManager_GetWorker_Call {
return &MockManager_GetWorker_Call{Call: _e.mock.On("GetWorker", ctx, nodeID)}
}
func (_c *MockManager_GetWorker_Call) Run(run func(nodeID int64)) *MockManager_GetWorker_Call {
func (_c *MockManager_GetWorker_Call) Run(run func(ctx context.Context, nodeID int64)) *MockManager_GetWorker_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -66,7 +71,7 @@ func (_c *MockManager_GetWorker_Call) Return(_a0 Worker, _a1 error) *MockManager
return _c
}
func (_c *MockManager_GetWorker_Call) RunAndReturn(run func(int64) (Worker, error)) *MockManager_GetWorker_Call {
func (_c *MockManager_GetWorker_Call) RunAndReturn(run func(context.Context, int64) (Worker, error)) *MockManager_GetWorker_Call {
_c.Call.Return(run)
return _c
}

View File

@ -249,7 +249,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
zap.Int("sealedNum", sealedNum),
zap.Int("growingNum", len(growing)),
)
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, sd.modifySearchRequest)
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
if err != nil {
log.Warn("Search organizeSubTask failed", zap.Error(err))
return nil, err
@ -313,7 +313,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
zap.Int("sealedNum", sealedNum),
zap.Int("growingNum", len(growing)),
)
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, sd.modifyQueryRequest)
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifyQueryRequest)
if err != nil {
log.Warn("query organizeSubTask failed", zap.Error(err))
return nil, err
@ -356,7 +356,7 @@ func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetSta
sealed, growing, version := sd.distribution.GetSegments(true, req.Req.GetPartitionIDs()...)
defer sd.distribution.FinishUsage(version)
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, func(req *querypb.GetStatisticsRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.GetStatisticsRequest {
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, func(req *querypb.GetStatisticsRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.GetStatisticsRequest {
nodeReq := proto.Clone(req).(*querypb.GetStatisticsRequest)
nodeReq.GetReq().GetBase().TargetID = targetID
nodeReq.Scope = scope
@ -386,7 +386,8 @@ type subTask[T any] struct {
worker cluster.Worker
}
func organizeSubTask[T any](req T, sealed []SnapshotItem, growing []SegmentEntry, workerManager cluster.Manager, modify func(T, querypb.DataScope, []int64, int64) T) ([]subTask[T], error) {
func organizeSubTask[T any](ctx context.Context, req T, sealed []SnapshotItem, growing []SegmentEntry, sd *shardDelegator, modify func(T, querypb.DataScope, []int64, int64) T) ([]subTask[T], error) {
log := sd.getLogger(ctx)
result := make([]subTask[T], 0, len(sealed)+1)
packSubTask := func(segments []SegmentEntry, workerID int64, scope querypb.DataScope) error {
@ -399,7 +400,7 @@ func organizeSubTask[T any](req T, sealed []SnapshotItem, growing []SegmentEntry
// update request
req := modify(req, scope, segmentIDs, workerID)
worker, err := workerManager.GetWorker(workerID)
worker, err := sd.workerManager.GetWorker(ctx, workerID)
if err != nil {
log.Warn("failed to get worker",
zap.Int64("nodeID", workerID),

View File

@ -193,7 +193,7 @@ func (sd *shardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) {
for _, entry := range sealed {
entry := entry
eg.Go(func() error {
worker, err := sd.workerManager.GetWorker(entry.NodeID)
worker, err := sd.workerManager.GetWorker(ctx, entry.NodeID)
if err != nil {
log.Warn("failed to get worker",
zap.Int64("nodeID", paramtable.GetNodeID()),
@ -209,7 +209,7 @@ func (sd *shardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) {
}
if len(growing) > 0 {
eg.Go(func() error {
worker, err := sd.workerManager.GetWorker(paramtable.GetNodeID())
worker, err := sd.workerManager.GetWorker(ctx, paramtable.GetNodeID())
if err != nil {
log.Error("failed to get worker(local)",
zap.Int64("nodeID", paramtable.GetNodeID()),
@ -338,7 +338,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
zap.Int64s("segments", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 { return info.GetSegmentID() })),
)
worker, err := sd.workerManager.GetWorker(targetNodeID)
worker, err := sd.workerManager.GetWorker(ctx, targetNodeID)
if err != nil {
log.Warn("delegator failed to find worker", zap.Error(err))
return err
@ -603,7 +603,7 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
}
if !force {
worker, err := sd.workerManager.GetWorker(targetNodeID)
worker, err := sd.workerManager.GetWorker(ctx, targetNodeID)
if err != nil {
log.Warn("delegator failed to find worker",
zap.Error(err),

View File

@ -262,7 +262,7 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
// load growing
@ -351,7 +351,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -416,7 +416,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -474,7 +474,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
s.loader.ExpectedCalls = nil
}()
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Return(nil, errors.New("mock error"))
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Return(nil, errors.New("mock error"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -510,7 +510,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -554,7 +554,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(errors.New("mocked error"))
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -624,7 +624,7 @@ func (s *DelegatorDataSuite) TestReleaseSegment() {
Return(nil)
worker1.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).Return(nil)
worker2.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
// load growing

View File

@ -277,7 +277,7 @@ func (s *DelegatorSuite) TestSearch() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -308,7 +308,7 @@ func (s *DelegatorSuite) TestSearch() {
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -346,7 +346,7 @@ func (s *DelegatorSuite) TestSearch() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -386,7 +386,7 @@ func (s *DelegatorSuite) TestSearch() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -530,7 +530,7 @@ func (s *DelegatorSuite) TestQuery() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -561,7 +561,7 @@ func (s *DelegatorSuite) TestQuery() {
worker2.EXPECT().QuerySegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -599,7 +599,7 @@ func (s *DelegatorSuite) TestQuery() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -636,7 +636,7 @@ func (s *DelegatorSuite) TestQuery() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -751,7 +751,7 @@ func (s *DelegatorSuite) TestGetStats() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.GetStatisticsResponse{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -787,7 +787,7 @@ func (s *DelegatorSuite) TestGetStats() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.GetStatisticsResponse{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
@ -824,7 +824,7 @@ func (s *DelegatorSuite) TestGetStats() {
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.GetStatisticsResponse{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)

View File

@ -298,7 +298,7 @@ func (node *QueryNode) Init() error {
)
log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy))
node.clusterManager = cluster.NewWorkerManager(func(nodeID int64) (cluster.Worker, error) {
node.clusterManager = cluster.NewWorkerManager(func(ctx context.Context, nodeID int64) (cluster.Worker, error) {
if nodeID == paramtable.GetNodeID() {
return NewLocalWorker(node), nil
}
@ -316,7 +316,7 @@ func (node *QueryNode) Init() error {
}
}
client, err := grpcquerynodeclient.NewClient(node.ctx, addr, nodeID)
client, err := grpcquerynodeclient.NewClient(ctx, addr, nodeID)
if err != nil {
return nil, err
}