diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index d0ccbd7859..cfce7e12b5 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -102,7 +102,7 @@ func (ob *CollectionObserver) Stop() { func (ob *CollectionObserver) Observe(ctx context.Context) { ob.observeTimeout() - ob.observeLoadStatus() + ob.observeLoadStatus(ctx) } func (ob *CollectionObserver) observeTimeout() { @@ -158,7 +158,7 @@ func (ob *CollectionObserver) readyToObserve(collectionID int64) bool { return metaExist && targetExist } -func (ob *CollectionObserver) observeLoadStatus() { +func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) { partitions := ob.meta.CollectionManager.GetAllPartitions() if len(partitions) > 0 { log.Info("observe partitions status", zap.Int("partitionNum", len(partitions))) @@ -170,7 +170,7 @@ func (ob *CollectionObserver) observeLoadStatus() { } if ob.readyToObserve(partition.CollectionID) { replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID()) - ob.observePartitionLoadStatus(partition, replicaNum) + ob.observePartitionLoadStatus(ctx, partition, replicaNum) loading = true } } @@ -180,7 +180,7 @@ func (ob *CollectionObserver) observeLoadStatus() { } } -func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition, replicaNum int32) { +func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, partition *meta.Partition, replicaNum int32) { log := log.With( zap.Int64("collectionID", partition.GetCollectionID()), zap.Int64("partitionID", partition.GetPartitionID()), @@ -230,7 +230,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti } ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount - if loadPercentage == 100 && ob.targetObserver.Check(partition.GetCollectionID()) && 
ob.leaderObserver.CheckTargetVersion(partition.GetCollectionID()) { + if loadPercentage == 100 && ob.targetObserver.Check(ctx, partition.GetCollectionID()) && ob.leaderObserver.CheckTargetVersion(ctx, partition.GetCollectionID()) { delete(ob.partitionLoadedCount, partition.GetPartitionID()) } collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(partition.PartitionID, loadPercentage) diff --git a/internal/querycoordv2/observers/leader_observer.go b/internal/querycoordv2/observers/leader_observer.go index 09da5c9a40..0e01477fbd 100644 --- a/internal/querycoordv2/observers/leader_observer.go +++ b/internal/querycoordv2/observers/leader_observer.go @@ -137,13 +137,25 @@ func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64 return result } -func (ob *LeaderObserver) CheckTargetVersion(collectionID int64) bool { +// CheckTargetVersion submits a manual check request for collectionID and +// waits for the reply on a per-call notifier channel. It returns false when +// ctx is done before the request is accepted or before a reply arrives. +func (ob *LeaderObserver) CheckTargetVersion(ctx context.Context, collectionID int64) bool { - notifier := make(chan bool) + // Capacity 1: a reply sent on Notifier after this call has already returned + // on ctx.Done() must not block its sender. + notifier := make(chan bool, 1) - ob.manualCheck <- checkRequest{ - CollectionID: collectionID, - Notifier: notifier, + select { + case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}: + case <-ctx.Done(): + return false + } + + select { + case result := <-notifier: + return result + case <-ctx.Done(): + return false } - return <-notifier } func (o *LeaderObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView) *querypb.SyncAction { diff --git a/internal/querycoordv2/observers/leader_observer_test.go b/internal/querycoordv2/observers/leader_observer_test.go index 3a2738f4ff..c2f1f771d2 100644 --- a/internal/querycoordv2/observers/leader_observer_test.go +++ b/internal/querycoordv2/observers/leader_observer_test.go @@ -591,6 +591,44 @@ func (suite *LeaderObserverTestSuite) TestSyncTargetVersion() { suite.Len(action.SealedInTarget, 1) } +func (suite *LeaderObserverTestSuite) TestCheckTargetVersion() { + collectionID := int64(1001) + observer := suite.observer + + 
suite.Run("check_channel_blocked", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // zero-length channel + observer.manualCheck = make(chan checkRequest) + + ctx, cancel := context.WithCancel(context.Background()) + // cancel context, make test return fast + cancel() + + result := observer.CheckTargetVersion(ctx, collectionID) + suite.False(result) + }) + + suite.Run("check_return_ctx_timeout", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // make channel length = 1, task received + observer.manualCheck = make(chan checkRequest, 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + result := observer.CheckTargetVersion(ctx, collectionID) + suite.False(result) + }) +} + func TestLeaderObserverSuite(t *testing.T) { suite.Run(t, new(LeaderObserverTestSuite)) } diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go index 955547a111..9c057a5fee 100644 --- a/internal/querycoordv2/observers/target_observer.go +++ b/internal/querycoordv2/observers/target_observer.go @@ -131,13 +131,24 @@ func (ob *TargetObserver) schedule(ctx context.Context) { // Check checks whether the next target is ready, // and updates the current target if it is, // returns true if current target is not nil +// It returns false when ctx is done before the request is accepted on +// ob.manualCheck or before a reply arrives on the notifier. -func (ob *TargetObserver) Check(collectionID int64) bool { +func (ob *TargetObserver) Check(ctx context.Context, collectionID int64) bool { - notifier := make(chan bool) + // Capacity 1: a reply sent on Notifier after this call has already returned + // on ctx.Done() must not block its sender. + notifier := make(chan bool, 1) - ob.manualCheck <- checkRequest{ - CollectionID: collectionID, - Notifier: notifier, + select { + case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}: + case <-ctx.Done(): + return false + } + + select { + case result := <-notifier: + return result + case <-ctx.Done(): + return false } - return <-notifier } func (ob *TargetObserver) check(collectionID int64) { diff --git 
a/internal/querycoordv2/observers/target_observer_test.go b/internal/querycoordv2/observers/target_observer_test.go index dde5bf1b3a..c0da2a556a 100644 --- a/internal/querycoordv2/observers/target_observer_test.go +++ b/internal/querycoordv2/observers/target_observer_test.go @@ -17,6 +17,7 @@ package observers import ( + "context" "testing" "time" @@ -215,6 +216,101 @@ func (suite *TargetObserverSuite) TearDownSuite() { suite.observer.Stop() } +type TargetObserverCheckSuite struct { + suite.Suite + + kv kv.MetaKv + // dependency + meta *meta.Meta + targetMgr *meta.TargetManager + distMgr *meta.DistributionManager + broker *meta.MockBroker + + observer *TargetObserver + + collectionID int64 + partitionID int64 +} + +func (suite *TargetObserverCheckSuite) SetupSuite() { + paramtable.Init() +} + +func (suite *TargetObserverCheckSuite) SetupTest() { + var err error + config := GenerateEtcdConfig() + cli, err := etcd.GetEtcdClient( + config.UseEmbedEtcd.GetAsBool(), + config.EtcdUseSSL.GetAsBool(), + config.Endpoints.GetAsStrings(), + config.EtcdTLSCert.GetValue(), + config.EtcdTLSKey.GetValue(), + config.EtcdTLSCACert.GetValue(), + config.EtcdTLSMinVersion.GetValue()) + suite.Require().NoError(err) + suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + + // meta + store := querycoord.NewCatalog(suite.kv) + idAllocator := RandomIncrementIDAllocator() + suite.meta = meta.NewMeta(idAllocator, store, session.NewNodeManager()) + + suite.broker = meta.NewMockBroker(suite.T()) + suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) + suite.distMgr = meta.NewDistributionManager() + suite.observer = NewTargetObserver(suite.meta, suite.targetMgr, suite.distMgr, suite.broker) + suite.collectionID = int64(1000) + suite.partitionID = int64(100) + + err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1)) + suite.NoError(err) + err = 
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID)) + suite.NoError(err) + replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName) + suite.NoError(err) + replicas[0].AddNode(2) + err = suite.meta.ReplicaManager.Put(replicas...) + suite.NoError(err) +} + +func (suite *TargetObserverCheckSuite) TestCheckCtxDone() { + observer := suite.observer + + suite.Run("check_channel_blocked", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // zero-length channel + observer.manualCheck = make(chan checkRequest) + + ctx, cancel := context.WithCancel(context.Background()) + // cancel context, make test return fast + cancel() + + result := observer.Check(ctx, suite.collectionID) + suite.False(result) + }) + + suite.Run("check_return_ctx_timeout", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // make channel length = 1, task received + observer.manualCheck = make(chan checkRequest, 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + result := observer.Check(ctx, suite.collectionID) + suite.False(result) + }) +} + func TestTargetObserver(t *testing.T) { suite.Run(t, new(TargetObserverSuite)) + suite.Run(t, new(TargetObserverCheckSuite)) }