From c8f89907b678237fab8e2557f988e157eec08859 Mon Sep 17 00:00:00 2001
From: yah01 <yang.cen@zilliz.com>
Date: Tue, 17 Jan 2023 11:41:51 +0800
Subject: [PATCH] Fix current target may be updated to an invalid target
 (#21742)

Signed-off-by: yah01 <yang.cen@zilliz.com>
---
 .../observers/collection_observer.go          |  9 +--
 .../observers/collection_observer_test.go     | 19 +++++-
 .../querycoordv2/observers/target_observer.go | 59 +++++++++++++------
 internal/querycoordv2/server.go               | 11 ++--
 internal/querycoordv2/server_test.go          | 13 ++++
 5 files changed, 80 insertions(+), 31 deletions(-)

diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go
index cfb0764e7e..20a17d7ef4 100644
--- a/internal/querycoordv2/observers/collection_observer.go
+++ b/internal/querycoordv2/observers/collection_observer.go
@@ -37,6 +37,7 @@ type CollectionObserver struct {
 	dist                  *meta.DistributionManager
 	meta                  *meta.Meta
 	targetMgr             *meta.TargetManager
+	targetObserver        *TargetObserver
 	collectionLoadedCount map[int64]int
 	partitionLoadedCount  map[int64]int
 
@@ -47,12 +48,14 @@ func NewCollectionObserver(
 	dist *meta.DistributionManager,
 	meta *meta.Meta,
 	targetMgr *meta.TargetManager,
+	targetObserver *TargetObserver,
 ) *CollectionObserver {
 	return &CollectionObserver{
 		stopCh:                make(chan struct{}),
 		dist:                  dist,
 		meta:                  meta,
 		targetMgr:             targetMgr,
+		targetObserver:        targetObserver,
 		collectionLoadedCount: make(map[int64]int),
 		partitionLoadedCount:  make(map[int64]int),
 	}
@@ -201,9 +204,8 @@ func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Colle
 		return
 	}
 	ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount
-	if updated.LoadPercentage == 100 {
+	if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) {
 		delete(ob.collectionLoadedCount, collection.GetCollectionID())
-		ob.targetMgr.UpdateCollectionCurrentTarget(updated.CollectionID)
 		updated.Status = querypb.LoadStatus_Loaded
 		ob.meta.CollectionManager.UpdateCollection(updated)
 
@@ -265,9 +267,8 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
 		return
 	}
 	ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount
-	if updated.LoadPercentage == 100 {
+	if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) {
 		delete(ob.partitionLoadedCount, partition.GetPartitionID())
-		ob.targetMgr.UpdateCollectionCurrentTarget(partition.GetCollectionID(), partition.GetPartitionID())
 		updated.Status = querypb.LoadStatus_Loaded
 		ob.meta.CollectionManager.PutPartition(updated)
 
diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go
index 4b30593929..b74ecf45e5 100644
--- a/internal/querycoordv2/observers/collection_observer_test.go
+++ b/internal/querycoordv2/observers/collection_observer_test.go
@@ -56,9 +56,10 @@ type CollectionObserverSuite struct {
 	broker      *meta.MockBroker
 
 	// Dependencies
-	dist      *meta.DistributionManager
-	meta      *meta.Meta
-	targetMgr *meta.TargetManager
+	dist           *meta.DistributionManager
+	meta           *meta.Meta
+	targetMgr      *meta.TargetManager
+	targetObserver *TargetObserver
 
 	// Test object
 	ob *CollectionObserver
@@ -180,18 +181,30 @@ func (suite *CollectionObserverSuite) SetupTest() {
 	suite.meta = meta.NewMeta(suite.idAllocator, suite.store)
 	suite.broker = meta.NewMockBroker(suite.T())
 	suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
+	suite.targetObserver = NewTargetObserver(suite.meta,
+		suite.targetMgr,
+		suite.dist,
+		suite.broker,
+	)
 
 	// Test object
 	suite.ob = NewCollectionObserver(
 		suite.dist,
 		suite.meta,
 		suite.targetMgr,
+		suite.targetObserver,
 	)
 
+	for _, collection := range suite.collections {
+		suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
+	}
+	suite.targetObserver.Start(context.Background())
+
 	suite.loadAll()
 }
 
 func (suite *CollectionObserverSuite) TearDownTest() {
+	suite.targetObserver.Stop() // stop the target observer that SetupTest started
 	suite.ob.Stop()
 	suite.kv.Close()
 }
diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go
index 8ab1d2e43c..108450704f 100644
--- a/internal/querycoordv2/observers/target_observer.go
+++ b/internal/querycoordv2/observers/target_observer.go
@@ -30,6 +30,11 @@ import (
 	"github.com/milvus-io/milvus/internal/util/typeutil"
 )
 
+type checkRequest struct { // checkRequest is one manual check request, sent by Check to the schedule loop
+	CollectionID int64     // collection whose targets should be checked
+	Notifier     chan bool // receives whether the current target exists once the check has run
+}
+
 type targetUpdateRequest struct {
 	CollectionID  int64
 	Notifier      chan error
@@ -44,6 +49,7 @@ type TargetObserver struct {
 	distMgr   *meta.DistributionManager
 	broker    meta.Broker
 
+	manualCheck          chan checkRequest
 	nextTargetLastUpdate map[int64]time.Time
 	updateChan           chan targetUpdateRequest
 	mut                  sync.Mutex                // Guard readyNotifiers
@@ -59,6 +65,7 @@ func NewTargetObserver(meta *meta.Meta, targetMgr *meta.TargetManager, distMgr *
 		targetMgr:            targetMgr,
 		distMgr:              distMgr,
 		broker:               broker,
+		manualCheck:          make(chan checkRequest, 10),
 		nextTargetLastUpdate: make(map[int64]time.Time),
 		updateChan:           make(chan targetUpdateRequest),
 		readyNotifiers:       make(map[int64][]chan struct{}),
@@ -95,21 +102,48 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
 			ob.clean()
 			ob.tryUpdateTarget()
 
-		case request := <-ob.updateChan:
-			err := ob.updateNextTarget(request.CollectionID)
+		case req := <-ob.manualCheck:
+			ob.check(req.CollectionID)
+			req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID)
+
+		case req := <-ob.updateChan:
+			err := ob.updateNextTarget(req.CollectionID)
 			if err != nil {
-				close(request.ReadyNotifier)
+				close(req.ReadyNotifier)
 			} else {
 				ob.mut.Lock()
-				ob.readyNotifiers[request.CollectionID] = append(ob.readyNotifiers[request.CollectionID], request.ReadyNotifier)
+				ob.readyNotifiers[req.CollectionID] = append(ob.readyNotifiers[req.CollectionID], req.ReadyNotifier)
 				ob.mut.Unlock()
 			}
 
-			request.Notifier <- err
+			req.Notifier <- err
 		}
 	}
 }
 
+// Check asks the schedule loop to run check for the collection and blocks until
+// the loop replies; it returns true if the current target exists afterwards.
+// NOTE(review): looks like this never returns if the schedule loop is not running — confirm.
+func (ob *TargetObserver) Check(collectionID int64) bool {
+	notifier := make(chan bool) // unbuffered: the reply is handed over directly to this caller
+	ob.manualCheck <- checkRequest{
+		CollectionID: collectionID,
+		Notifier:     notifier,
+	}
+	return <-notifier // wait for the schedule loop's answer
+}
+
+func (ob *TargetObserver) check(collectionID int64) { // check runs one update pass over the collection's current and next targets
+	if ob.shouldUpdateCurrentTarget(collectionID) {
+		ob.updateCurrentTarget(collectionID)
+	}
+
+	if ob.shouldUpdateNextTarget(collectionID) {
+		// update next target in collection level; the error returned by updateNextTarget is discarded here
+		ob.updateNextTarget(collectionID)
+	}
+}
+
 // UpdateNextTarget updates the next target,
 // returns a channel which will be closed when the next target is ready,
 // or returns error if failed to pull target
@@ -138,14 +172,7 @@ func (ob *TargetObserver) ReleaseCollection(collectionID int64) {
 func (ob *TargetObserver) tryUpdateTarget() {
 	collections := ob.meta.GetAll()
 	for _, collectionID := range collections {
-		if ob.shouldUpdateCurrentTarget(collectionID) {
-			ob.updateCurrentTarget(collectionID)
-		}
-
-		if ob.shouldUpdateNextTarget(collectionID) {
-			// update next target in collection level
-			ob.updateNextTarget(collectionID)
-		}
+		ob.check(collectionID)
 	}
 
 	collectionSet := typeutil.NewUniqueSet(collections...)
@@ -199,12 +226,6 @@ func (ob *TargetObserver) updateNextTargetTimestamp(collectionID int64) {
 }
 
 func (ob *TargetObserver) shouldUpdateCurrentTarget(collectionID int64) bool {
-	// Collection observer will update the current target as loading done,
-	// avoid double updating, which will cause update current target to a unfinished next target
-	if !ob.targetMgr.IsCurrentTargetExist(collectionID) {
-		return false
-	}
-
 	replicaNum := ob.meta.CollectionManager.GetReplicaNumber(collectionID)
 
 	// check channel first
diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index dec049b425..d8bff40973 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -279,11 +279,6 @@ func (s *Server) initMeta() error {
 
 func (s *Server) initObserver() {
 	log.Info("init observers")
-	s.collectionObserver = observers.NewCollectionObserver(
-		s.dist,
-		s.meta,
-		s.targetMgr,
-	)
 	s.leaderObserver = observers.NewLeaderObserver(
 		s.dist,
 		s.meta,
@@ -296,6 +291,12 @@ func (s *Server) initObserver() {
 		s.dist,
 		s.broker,
 	)
+	s.collectionObserver = observers.NewCollectionObserver(
+		s.dist,
+		s.meta,
+		s.targetMgr,
+		s.targetObserver,
+	)
 }
 
 func (s *Server) afterStart() {
diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go
index 59d6a9c14d..b71a7574c4 100644
--- a/internal/querycoordv2/server_test.go
+++ b/internal/querycoordv2/server_test.go
@@ -35,6 +35,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querycoordv2/dist"
 	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
 	"github.com/milvus-io/milvus/internal/querycoordv2/mocks"
+	"github.com/milvus-io/milvus/internal/querycoordv2/observers"
 	"github.com/milvus-io/milvus/internal/querycoordv2/params"
 	"github.com/milvus-io/milvus/internal/querycoordv2/session"
 	"github.com/milvus-io/milvus/internal/querycoordv2/task"
@@ -401,6 +402,18 @@ func (suite *ServerSuite) hackServer() {
 		suite.server.balancer,
 		suite.server.taskScheduler,
 	)
+	suite.server.targetObserver = observers.NewTargetObserver(
+		suite.server.meta,
+		suite.server.targetMgr,
+		suite.server.dist,
+		suite.broker,
+	)
+	suite.server.collectionObserver = observers.NewCollectionObserver(
+		suite.server.dist,
+		suite.server.meta,
+		suite.server.targetMgr,
+		suite.server.targetObserver,
+	)
 
 	suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe()
 	for _, collection := range suite.collections {