fix: Make syncmgr lock key before returning future (#32865)

See also #32860

SyncMgr did not ensure that the task key was locked before `SyncData` returned,
which may cause concurrency problems during sync with multiple policies.

This PR changes the sync manager implementation to make sure the key is locked
before returning the task result `*conc.Future`.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/32822/head
congqixia 2024-05-09 10:09:30 +08:00 committed by GitHub
parent 3d78b90fe7
commit a06f601c6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 153 additions and 33 deletions

View File

@ -13,32 +13,33 @@ type Task interface {
StartPosition() *msgpb.MsgPosition
ChannelName() string
Run() error
HandleError(error)
}
type keyLockDispatcher[K comparable] struct {
keyLock *lock.KeyLock[K]
workerPool *conc.Pool[error]
workerPool *conc.Pool[struct{}]
}
func newKeyLockDispatcher[K comparable](maxParallel int) *keyLockDispatcher[K] {
dispatcher := &keyLockDispatcher[K]{
workerPool: conc.NewPool[error](maxParallel, conc.WithPreAlloc(false)),
workerPool: conc.NewPool[struct{}](maxParallel, conc.WithPreAlloc(false)),
keyLock: lock.NewKeyLock[K](),
}
return dispatcher
}
func (d *keyLockDispatcher[K]) Submit(key K, t Task, callbacks ...func(error)) *conc.Future[error] {
func (d *keyLockDispatcher[K]) Submit(key K, t Task, callbacks ...func(error) error) *conc.Future[struct{}] {
d.keyLock.Lock(key)
return d.workerPool.Submit(func() (error, error) {
return d.workerPool.Submit(func() (struct{}, error) {
defer d.keyLock.Unlock(key)
err := t.Run()
for _, callback := range callbacks {
callback(err)
err = callback(err)
}
return err, nil
return struct{}{}, err
})
}

View File

@ -155,6 +155,39 @@ func (_c *MockTask_Checkpoint_Call) RunAndReturn(run func() *msgpb.MsgPosition)
return _c
}
// HandleError provides a mock function with given fields: _a0
// It forwards _a0 to _m.Called and has no return value.
func (_m *MockTask) HandleError(_a0 error) {
_m.Called(_a0)
}
// MockTask_HandleError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HandleError'
type MockTask_HandleError_Call struct {
*mock.Call
}
// HandleError is a helper method to define mock.On call
// - _a0 error
func (_e *MockTask_Expecter) HandleError(_a0 interface{}) *MockTask_HandleError_Call {
return &MockTask_HandleError_Call{Call: _e.mock.On("HandleError", _a0)}
}
// Run registers a typed hook on the wrapped call: the hook receives the first
// recorded argument asserted to error. It returns the receiver for chaining.
// NOTE(review): the single-value type assertion below panics when args[0] does
// not hold an error value (presumably including a nil interface) — confirm
// HandleError is never expected to be called with nil when Run is used.
func (_c *MockTask_HandleError_Call) Run(run func(_a0 error)) *MockTask_HandleError_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(error))
})
return _c
}
// Return configures the wrapped call with no return values and returns the
// receiver for chaining.
func (_c *MockTask_HandleError_Call) Return() *MockTask_HandleError_Call {
_c.Call.Return()
return _c
}
// RunAndReturn passes run itself to Call.Return and returns the receiver for
// chaining.
// NOTE(review): MockTask.HandleError only calls _m.Called and never reads the
// configured return values, so run does not appear to be invoked through this
// path — confirm against the mockery version that generated this file.
func (_c *MockTask_HandleError_Call) RunAndReturn(run func(error)) *MockTask_HandleError_Call {
_c.Call.Return(run)
return _c
}
// Run provides a mock function with given fields:
func (_m *MockTask) Run() error {
ret := _m.Called()

View File

@ -121,30 +121,44 @@ func (mgr *syncManager) SyncData(ctx context.Context, task Task) *conc.Future[st
func (mgr *syncManager) safeSubmitTask(task Task) *conc.Future[struct{}] {
taskKey := fmt.Sprintf("%d-%d", task.SegmentID(), task.Checkpoint().GetTimestamp())
mgr.tasks.Insert(taskKey, task)
defer mgr.tasks.Remove(taskKey)
return conc.Go(func() (struct{}, error) {
defer mgr.tasks.Remove(taskKey)
for {
targetID, err := task.CalcTargetSegment()
if err != nil {
return struct{}{}, err
}
log.Info("task calculated target segment id",
zap.Int64("targetID", targetID),
zap.Int64("segmentID", task.SegmentID()),
)
key, err := task.CalcTargetSegment()
if err != nil {
task.HandleError(err)
return conc.Go(func() (struct{}, error) { return struct{}{}, err })
}
// make sync for same segment execute in sequence
// if previous sync task is not finished, block here
f := mgr.Submit(targetID, task)
err, _ = f.Await()
if errors.Is(err, errTargetSegmentNotMatch) {
log.Info("target updated during submitting", zap.Error(err))
continue
}
return struct{}{}, err
return mgr.submit(key, task)
}
// submit hands task to mgr.Submit under the given key together with a handler
// callback, and returns the future produced by mgr.Submit.
//
// The handler maps the error it receives to the error it returns:
//   - any error that is not errTargetSegmentNotMatch is reported through
//     task.HandleError and returned unchanged;
//   - for errTargetSegmentNotMatch the target segment is recalculated; a
//     calculation failure or an unchanged key is reported through
//     task.HandleError and returned, otherwise the task is re-submitted under
//     the new key and the Err() of that nested future is returned.
//
// NOTE(review): errors.Is(nil, errTargetSegmentNotMatch) is false, so the first
// branch also matches a nil err. If Submit invokes the handler after a
// successful run, task.HandleError(nil) is called — confirm HandleError
// tolerates nil or guard against it.
func (mgr *syncManager) submit(key int64, task Task) *conc.Future[struct{}] {
handler := func(err error) error {
// unexpected error
if !errors.Is(err, errTargetSegmentNotMatch) {
task.HandleError(err)
return err
}
})
// target segment changed while the task was waiting/running: recompute it
targetID, err := task.CalcTargetSegment()
// shall not reach, segment meta lost during sync
if err != nil {
task.HandleError(err)
return err
}
// recalculation yielded the key that was just rejected; fail instead of
// re-submitting under the same key again
if targetID == key {
err = merr.WrapErrServiceInternal("recaluated with same key", fmt.Sprint(targetID))
task.HandleError(err)
return err
}
log.Info("task calculated target segment id",
zap.Int64("targetID", targetID),
zap.Int64("segmentID", task.SegmentID()),
)
// retry under the new key; presumably Err() waits for the nested future — TODO confirm
return mgr.submit(targetID, task).Err()
}
log.Info("sync mgr sumbit task with key", zap.Int64("key", key))
return mgr.Submit(key, task, handler)
}
func (mgr *syncManager) GetEarliestPosition(channel string) (int64, *msgpb.MsgPosition) {

View File

@ -23,6 +23,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
@ -241,7 +242,7 @@ func (s *SyncManagerSuite) TestBlock() {
MsgID: []byte{1, 2, 3, 4},
Timestamp: 100,
})
manager.SyncData(context.Background(), task)
go manager.SyncData(context.Background(), task)
select {
case <-sig:
@ -317,12 +318,86 @@ func (s *SyncManagerSuite) TestTargetUpdated() {
task.EXPECT().CalcTargetSegment().Return(1001, nil).Once()
task.EXPECT().Run().Return(errTargetSegmentNotMatch).Once()
task.EXPECT().Run().Return(nil).Once()
task.EXPECT().HandleError(mock.Anything)
f := manager.SyncData(context.Background(), task)
_, err = f.Await()
s.NoError(err)
}
// TestUnexpectedError checks that when the task's Run returns an error other
// than errTargetSegmentNotMatch, awaiting the future returned by SyncData
// yields an error.
func (s *SyncManagerSuite) TestUnexpectedError() {
manager, err := NewSyncManager(s.chunkManager, s.allocator)
s.NoError(err)
task := NewMockTask(s.T())
task.EXPECT().SegmentID().Return(1000)
task.EXPECT().Checkpoint().Return(&msgpb.MsgPosition{})
task.EXPECT().CalcTargetSegment().Return(1000, nil).Once()
// Run fails once with a service-internal error
task.EXPECT().Run().Return(merr.WrapErrServiceInternal("mocked")).Once()
task.EXPECT().HandleError(mock.Anything)
f := manager.SyncData(context.Background(), task)
_, err = f.Await()
s.Error(err)
}
// TestCalcTargetError checks that an error returned by CalcTargetSegment is
// the error obtained from awaiting the SyncData future, in two scenarios.
func (s *SyncManagerSuite) TestCalcTargetError() {
// the very first CalcTargetSegment call fails; no Run expectation is set
s.Run("fail_before_submit", func() {
manager, err := NewSyncManager(s.chunkManager, s.allocator)
s.NoError(err)
mockErr := merr.WrapErrServiceInternal("mocked")
task := NewMockTask(s.T())
task.EXPECT().SegmentID().Return(1000)
task.EXPECT().Checkpoint().Return(&msgpb.MsgPosition{})
task.EXPECT().CalcTargetSegment().Return(0, mockErr).Once()
task.EXPECT().HandleError(mock.Anything)
f := manager.SyncData(context.Background(), task)
_, err = f.Await()
s.Error(err)
s.ErrorIs(err, mockErr)
})
// first CalcTargetSegment succeeds, Run reports errTargetSegmentNotMatch,
// and the second CalcTargetSegment call fails
s.Run("fail_during_rerun", func() {
manager, err := NewSyncManager(s.chunkManager, s.allocator)
s.NoError(err)
mockErr := merr.WrapErrServiceInternal("mocked")
task := NewMockTask(s.T())
task.EXPECT().SegmentID().Return(1000)
task.EXPECT().Checkpoint().Return(&msgpb.MsgPosition{})
task.EXPECT().CalcTargetSegment().Return(1000, nil).Once()
task.EXPECT().CalcTargetSegment().Return(0, mockErr).Once()
task.EXPECT().Run().Return(errTargetSegmentNotMatch).Once()
task.EXPECT().HandleError(mock.Anything)
f := manager.SyncData(context.Background(), task)
_, err = f.Await()
s.Error(err)
s.ErrorIs(err, mockErr)
})
}
// TestTargetUpdateSameID checks that when Run reports errTargetSegmentNotMatch
// but CalcTargetSegment returns the same ID (1000) on both calls, awaiting the
// SyncData future yields an error.
func (s *SyncManagerSuite) TestTargetUpdateSameID() {
manager, err := NewSyncManager(s.chunkManager, s.allocator)
s.NoError(err)
task := NewMockTask(s.T())
task.EXPECT().SegmentID().Return(1000)
task.EXPECT().Checkpoint().Return(&msgpb.MsgPosition{})
// both target calculations return the same segment ID
task.EXPECT().CalcTargetSegment().Return(1000, nil).Once()
task.EXPECT().CalcTargetSegment().Return(1000, nil).Once()
task.EXPECT().Run().Return(errTargetSegmentNotMatch).Once()
task.EXPECT().HandleError(mock.Anything)
f := manager.SyncData(context.Background(), task)
_, err = f.Await()
s.Error(err)
}
// TestSyncManager is the go test entry point that runs SyncManagerSuite.
func TestSyncManager(t *testing.T) {
suite.Run(t, new(SyncManagerSuite))
}

View File

@ -109,7 +109,7 @@ func (t *SyncTask) getLogger() *log.MLogger {
)
}
func (t *SyncTask) handleError(err error) {
func (t *SyncTask) HandleError(err error) {
if errors.Is(err, errTargetSegmentNotMatch) {
return
}
@ -129,7 +129,7 @@ func (t *SyncTask) Run() (err error) {
log := t.getLogger()
defer func() {
if err != nil {
t.handleError(err)
t.HandleError(err)
}
}()
@ -138,7 +138,6 @@ func (t *SyncTask) Run() (err error) {
if !has {
log.Warn("failed to sync data, segment not found in metacache")
err := merr.WrapErrSegmentNotFound(t.segmentID)
t.handleError(err)
return err
}
@ -175,7 +174,6 @@ func (t *SyncTask) Run() (err error) {
err = t.writeLogs()
if err != nil {
log.Warn("failed to save serialized data into storage", zap.Error(err))
t.handleError(err)
return err
}
@ -195,7 +193,6 @@ func (t *SyncTask) Run() (err error) {
err = t.writeMeta()
if err != nil {
log.Warn("failed to save serialized data into storage", zap.Error(err))
t.handleError(err)
return err
}
}