From c8f89907b678237fab8e2557f988e157eec08859 Mon Sep 17 00:00:00 2001 From: yah01 <yang.cen@zilliz.com> Date: Tue, 17 Jan 2023 11:41:51 +0800 Subject: [PATCH] Fix current target may be updated to an invalid target (#21742) Signed-off-by: yah01 <yang.cen@zilliz.com> --- .../observers/collection_observer.go | 9 +-- .../observers/collection_observer_test.go | 19 +++++- .../querycoordv2/observers/target_observer.go | 59 +++++++++++++------ internal/querycoordv2/server.go | 11 ++-- internal/querycoordv2/server_test.go | 13 ++++ 5 files changed, 80 insertions(+), 31 deletions(-) diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index cfb0764e7e..20a17d7ef4 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -37,6 +37,7 @@ type CollectionObserver struct { dist *meta.DistributionManager meta *meta.Meta targetMgr *meta.TargetManager + targetObserver *TargetObserver collectionLoadedCount map[int64]int partitionLoadedCount map[int64]int @@ -47,12 +48,14 @@ func NewCollectionObserver( dist *meta.DistributionManager, meta *meta.Meta, targetMgr *meta.TargetManager, + targetObserver *TargetObserver, ) *CollectionObserver { return &CollectionObserver{ stopCh: make(chan struct{}), dist: dist, meta: meta, targetMgr: targetMgr, + targetObserver: targetObserver, collectionLoadedCount: make(map[int64]int), partitionLoadedCount: make(map[int64]int), } @@ -201,9 +204,8 @@ func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Colle return } ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount - if updated.LoadPercentage == 100 { + if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) { delete(ob.collectionLoadedCount, collection.GetCollectionID()) - ob.targetMgr.UpdateCollectionCurrentTarget(updated.CollectionID) updated.Status = querypb.LoadStatus_Loaded ob.meta.CollectionManager.UpdateCollection(updated) @@ -265,9 +267,8 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti return } ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount - if updated.LoadPercentage == 100 { + if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) { delete(ob.partitionLoadedCount, partition.GetPartitionID()) - ob.targetMgr.UpdateCollectionCurrentTarget(partition.GetCollectionID(), partition.GetPartitionID()) updated.Status = querypb.LoadStatus_Loaded ob.meta.CollectionManager.PutPartition(updated) diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index 4b30593929..b74ecf45e5 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -56,9 +56,10 @@ type CollectionObserverSuite struct { broker *meta.MockBroker // Dependencies - dist *meta.DistributionManager - meta *meta.Meta - targetMgr *meta.TargetManager + dist *meta.DistributionManager + meta *meta.Meta + targetMgr *meta.TargetManager + targetObserver *TargetObserver // Test object ob *CollectionObserver @@ -180,18 +181,30 @@ func (suite *CollectionObserverSuite) SetupTest() { suite.meta = meta.NewMeta(suite.idAllocator, suite.store) suite.broker = meta.NewMockBroker(suite.T()) suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) + suite.targetObserver = NewTargetObserver(suite.meta, + suite.targetMgr, + suite.dist, + suite.broker, + ) // Test object suite.ob = NewCollectionObserver( suite.dist, suite.meta, suite.targetMgr, + suite.targetObserver, ) + for _, collection := range suite.collections { + suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() + } + suite.targetObserver.Start(context.Background()) + suite.loadAll() } func (suite *CollectionObserverSuite) TearDownTest() { + suite.targetObserver.Stop() suite.ob.Stop() suite.kv.Close() } diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go index 8ab1d2e43c..108450704f 100644 --- a/internal/querycoordv2/observers/target_observer.go +++ b/internal/querycoordv2/observers/target_observer.go @@ -30,6 +30,11 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) +type checkRequest struct { + CollectionID int64 + Notifier chan bool +} + type targetUpdateRequest struct { CollectionID int64 Notifier chan error @@ -44,6 +49,7 @@ type TargetObserver struct { distMgr *meta.DistributionManager broker meta.Broker + manualCheck chan checkRequest nextTargetLastUpdate map[int64]time.Time updateChan chan targetUpdateRequest mut sync.Mutex // Guard readyNotifiers @@ -59,6 +65,7 @@ func NewTargetObserver(meta *meta.Meta, targetMgr *meta.TargetManager, distMgr * targetMgr: targetMgr, distMgr: distMgr, broker: broker, + manualCheck: make(chan checkRequest, 10), nextTargetLastUpdate: make(map[int64]time.Time), updateChan: make(chan targetUpdateRequest), readyNotifiers: make(map[int64][]chan struct{}), @@ -95,21 +102,48 @@ func (ob *TargetObserver) schedule(ctx context.Context) { ob.clean() ob.tryUpdateTarget() - case request := <-ob.updateChan: - err := ob.updateNextTarget(request.CollectionID) + case req := <-ob.manualCheck: + ob.check(req.CollectionID) + req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID) + + case req := <-ob.updateChan: + err := ob.updateNextTarget(req.CollectionID) if err != nil { - close(request.ReadyNotifier) + close(req.ReadyNotifier) } else { ob.mut.Lock() - ob.readyNotifiers[request.CollectionID] = append(ob.readyNotifiers[request.CollectionID], request.ReadyNotifier) + ob.readyNotifiers[req.CollectionID] = append(ob.readyNotifiers[req.CollectionID], req.ReadyNotifier) ob.mut.Unlock() } - request.Notifier <- err + req.Notifier <- err } } } +// Check checks whether the next target is ready, +// and updates the current target if it is, +// returns true if current target is not nil +func (ob *TargetObserver) Check(collectionID int64) bool { + notifier := make(chan bool) + ob.manualCheck <- checkRequest{ + CollectionID: collectionID, + Notifier: notifier, + } + return <-notifier +} + +func (ob *TargetObserver) check(collectionID int64) { + if ob.shouldUpdateCurrentTarget(collectionID) { + ob.updateCurrentTarget(collectionID) + } + + if ob.shouldUpdateNextTarget(collectionID) { + // update next target in collection level + ob.updateNextTarget(collectionID) + } +} + // UpdateNextTarget updates the next target, // returns a channel which will be closed when the next target is ready, // or returns error if failed to pull target @@ -138,14 +172,7 @@ func (ob *TargetObserver) ReleaseCollection(collectionID int64) { func (ob *TargetObserver) tryUpdateTarget() { collections := ob.meta.GetAll() for _, collectionID := range collections { - if ob.shouldUpdateCurrentTarget(collectionID) { - ob.updateCurrentTarget(collectionID) - } - - if ob.shouldUpdateNextTarget(collectionID) { - // update next target in collection level - ob.updateNextTarget(collectionID) - } + ob.check(collectionID) } collectionSet := typeutil.NewUniqueSet(collections...) @@ -199,12 +226,6 @@ func (ob *TargetObserver) updateNextTargetTimestamp(collectionID int64) { } func (ob *TargetObserver) shouldUpdateCurrentTarget(collectionID int64) bool { - // Collection observer will update the current target as loading done, - // avoid double updating, which will cause update current target to a unfinished next target - if !ob.targetMgr.IsCurrentTargetExist(collectionID) { - return false - } - replicaNum := ob.meta.CollectionManager.GetReplicaNumber(collectionID) // check channel first diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index dec049b425..d8bff40973 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -279,11 +279,6 @@ func (s *Server) initMeta() error { func (s *Server) initObserver() { log.Info("init observers") - s.collectionObserver = observers.NewCollectionObserver( - s.dist, - s.meta, - s.targetMgr, - ) s.leaderObserver = observers.NewLeaderObserver( s.dist, s.meta, @@ -296,6 +291,12 @@ func (s *Server) initObserver() { s.dist, s.broker, ) + s.collectionObserver = observers.NewCollectionObserver( + s.dist, + s.meta, + s.targetMgr, + s.targetObserver, + ) } func (s *Server) afterStart() { diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 59d6a9c14d..b71a7574c4 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/dist" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/mocks" + "github.com/milvus-io/milvus/internal/querycoordv2/observers" "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" @@ -401,6 +402,18 @@ func (suite *ServerSuite) hackServer() { suite.server.balancer, suite.server.taskScheduler, ) + suite.server.targetObserver = observers.NewTargetObserver( + suite.server.meta, + suite.server.targetMgr, + suite.server.dist, + suite.broker, + ) + suite.server.collectionObserver = observers.NewCollectionObserver( + suite.server.dist, + suite.server.meta, + suite.server.targetMgr, + suite.server.targetObserver, + ) suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe() for _, collection := range suite.collections {