diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index eb9847e737..d01e2a1ff9 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -180,6 +180,7 @@ queryCoord:
   checkHandoffInterval: 5000
   taskMergeCap: 16
   enableActiveStandby: false  # Enable active-standby
+  refreshTargetsIntervalSeconds: 300
 
 # Related configuration of queryNode, used to run hybrid search between vector and scalar data.
 queryNode:
diff --git a/internal/querycoordv2/job/job.go b/internal/querycoordv2/job/job.go
index 60f7975bc9..3392ab458b 100644
--- a/internal/querycoordv2/job/job.go
+++ b/internal/querycoordv2/job/job.go
@@ -218,7 +218,6 @@ func (job *LoadCollectionJob) Execute() error {
 		log.Error(msg, zap.Error(err))
 		return utils.WrapError(msg, err)
 	}
-	job.handoffObserver.StartHandoff(job.CollectionID())
 
 	err = job.meta.CollectionManager.PutCollection(&meta.Collection{
 		CollectionLoadInfo: &querypb.CollectionLoadInfo{
@@ -242,7 +241,7 @@ func (job *LoadCollectionJob) Execute() error {
 func (job *LoadCollectionJob) PostExecute() {
 	if job.Error() != nil && !job.meta.Exist(job.CollectionID()) {
 		job.meta.ReplicaManager.RemoveCollection(job.CollectionID())
-		job.handoffObserver.Unregister(job.ctx)
+		job.handoffObserver.Unregister(job.ctx, job.CollectionID())
 		job.targetMgr.RemoveCollection(job.req.GetCollectionID())
 	}
 }
diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go
index 06e2451d4c..d0a71fa8c6 100644
--- a/internal/querycoordv2/observers/collection_observer.go
+++ b/internal/querycoordv2/observers/collection_observer.go
@@ -29,6 +29,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
 	. "github.com/milvus-io/milvus/internal/querycoordv2/params"
 	"github.com/milvus-io/milvus/internal/querycoordv2/utils"
+	"github.com/samber/lo"
 )
 
 type CollectionObserver struct {
@@ -37,6 +38,10 @@ type CollectionObserver struct {
 	dist      *meta.DistributionManager
 	meta      *meta.Meta
 	targetMgr *meta.TargetManager
+	broker    meta.Broker
+	handoffOb *HandoffObserver
+
+	refreshed map[int64]time.Time
 
 	stopOnce sync.Once
 }
@@ -45,12 +50,18 @@
 func NewCollectionObserver(
 	dist *meta.DistributionManager,
 	meta *meta.Meta,
 	targetMgr *meta.TargetManager,
+	broker meta.Broker,
+	handoffObserver *HandoffObserver,
 ) *CollectionObserver {
 	return &CollectionObserver{
 		stopCh:    make(chan struct{}),
 		dist:      dist,
 		meta:      meta,
 		targetMgr: targetMgr,
+		broker:    broker,
+		handoffOb: handoffObserver,
+
+		refreshed: make(map[int64]time.Time),
 	}
 }
 
@@ -90,17 +101,27 @@ func (ob *CollectionObserver) Observe() {
 func (ob *CollectionObserver) observeTimeout() {
 	collections := ob.meta.CollectionManager.GetAllCollections()
 	for _, collection := range collections {
-		if collection.GetStatus() != querypb.LoadStatus_Loading ||
-			time.Now().Before(collection.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds)) {
+		if collection.GetStatus() != querypb.LoadStatus_Loading {
 			continue
 		}
 
-		log.Info("load collection timeout, cancel it",
-			zap.Int64("collectionID", collection.GetCollectionID()),
-			zap.Duration("loadTime", time.Since(collection.CreatedAt)))
-		ob.meta.CollectionManager.RemoveCollection(collection.GetCollectionID())
-		ob.meta.ReplicaManager.RemoveCollection(collection.GetCollectionID())
-		ob.targetMgr.RemoveCollection(collection.GetCollectionID())
+		refreshTime := collection.UpdatedAt.Add(Params.QueryCoordCfg.RefreshTargetsIntervalSeconds)
+		timeoutTime := collection.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds)
+
+		now := time.Now()
+		if now.After(timeoutTime) {
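+			// Loading has exceeded LoadTimeoutSeconds: cancel the load and
+			// drop the collection's metadata, replicas and targets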
log.Info("load collection timeout, cancel it", + zap.Int64("collectionID", collection.GetCollectionID()), + zap.Duration("loadTime", time.Since(collection.CreatedAt))) + ob.meta.CollectionManager.RemoveCollection(collection.GetCollectionID()) + ob.meta.ReplicaManager.RemoveCollection(collection.GetCollectionID()) + ob.targetMgr.RemoveCollection(collection.GetCollectionID()) + } else if now.After(refreshTime) { + ob.refreshTargets(collection.UpdatedAt, collection.GetCollectionID()) + log.Info("load for long time, refresh targets of collection", + zap.Duration("loadTime", time.Since(collection.CreatedAt)), + ) + } } partitions := utils.GroupPartitionsByCollection( @@ -113,23 +134,69 @@ func (ob *CollectionObserver) observeTimeout() { zap.Int64("collectionID", collection), ) for _, partition := range partitions { - if partition.GetStatus() != querypb.LoadStatus_Loading || - time.Now().Before(partition.CreatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds)) { + if partition.GetStatus() != querypb.LoadStatus_Loading { continue } - log.Info("load partition timeout, cancel all partitions", - zap.Int64("partitionID", partition.GetPartitionID()), - zap.Duration("loadTime", time.Since(partition.CreatedAt))) - // TODO(yah01): Now, releasing part of partitions is not allowed - ob.meta.CollectionManager.RemoveCollection(partition.GetCollectionID()) - ob.meta.ReplicaManager.RemoveCollection(partition.GetCollectionID()) - ob.targetMgr.RemoveCollection(partition.GetCollectionID()) - break + refreshTime := partition.UpdatedAt.Add(Params.QueryCoordCfg.RefreshTargetsIntervalSeconds) + timeoutTime := partition.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds) + + now := time.Now() + if now.After(timeoutTime) { + log.Info("load partition timeout, cancel all partitions", + zap.Int64("partitionID", partition.GetPartitionID()), + zap.Duration("loadTime", time.Since(partition.CreatedAt))) + // TODO(yah01): Now, releasing part of partitions is not allowed + ob.meta.CollectionManager.RemoveCollection(partition.GetCollectionID()) + ob.meta.ReplicaManager.RemoveCollection(partition.GetCollectionID()) + ob.targetMgr.RemoveCollection(partition.GetCollectionID()) + break + } else if now.After(refreshTime) { + partitionIDs := lo.Map(partitions, func(partition *meta.Partition, _ int) int64 { + return partition.GetPartitionID() + }) + ob.refreshTargets(partition.UpdatedAt, partition.GetCollectionID(), partitionIDs...) 
+ log.Info("load for long time, refresh targets of partitions", + zap.Duration("loadTime", time.Since(partition.CreatedAt)), + ) + break + } } } } +func (ob *CollectionObserver) refreshTargets(updatedAt time.Time, collectionID int64, partitions ...int64) { + refreshedTime, ok := ob.refreshed[collectionID] + if ok && refreshedTime.Equal(updatedAt) { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ob.targetMgr.RemoveCollection(collectionID) + ob.handoffOb.Unregister(ctx, collectionID) + + if len(partitions) == 0 { + var err error + partitions, err = ob.broker.GetPartitions(ctx, collectionID) + if err != nil { + log.Error("failed to get partitions from RootCoord", zap.Error(err)) + return + } + } + + ob.handoffOb.Register(collectionID) + utils.RegisterTargets(ctx, + ob.targetMgr, + ob.broker, + collectionID, + partitions, + ) + + ob.refreshed[collectionID] = updatedAt +} + func (ob *CollectionObserver) observeLoadStatus() { collections := ob.meta.CollectionManager.GetAllCollections() for _, collection := range collections { @@ -196,9 +263,11 @@ func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Colle if updated.LoadPercentage <= collection.LoadPercentage { return } + if loadedCount >= len(segmentTargets)+len(channelTargets) { updated.Status = querypb.LoadStatus_Loaded ob.meta.CollectionManager.UpdateCollection(updated) + ob.handoffOb.StartHandoff(updated.GetCollectionID()) elapsed := time.Since(updated.CreatedAt) metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds())) @@ -257,9 +326,11 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti if updated.LoadPercentage <= partition.LoadPercentage { return } + if loadedCount >= len(segmentTargets)+len(channelTargets) { updated.Status = querypb.LoadStatus_Loaded ob.meta.CollectionManager.UpdatePartition(updated) + ob.handoffOb.StartHandoff(updated.GetCollectionID()) elapsed := time.Since(updated.CreatedAt) metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds())) diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index b32af880cd..45475de42b 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/meta" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/util/etcd" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -56,6 +57,8 @@ type CollectionObserverSuite struct { dist *meta.DistributionManager meta *meta.Meta targetMgr *meta.TargetManager + broker *meta.MockBroker + handoffOb *HandoffObserver // Test object ob *CollectionObserver @@ -153,14 +156,27 @@ func (suite *CollectionObserverSuite) SetupTest() { suite.dist = meta.NewDistributionManager() suite.meta = meta.NewMeta(suite.idAllocator, suite.store) suite.targetMgr = meta.NewTargetManager() + suite.broker = meta.NewMockBroker(suite.T()) + suite.handoffOb = NewHandoffObserver( + suite.store, + suite.meta, + suite.dist, + suite.targetMgr, + suite.broker, + ) // Test object suite.ob = NewCollectionObserver( suite.dist, suite.meta, suite.targetMgr, + suite.broker, + suite.handoffOb, ) + Params.QueryCoordCfg.LoadTimeoutSeconds = 600 * time.Second + Params.QueryCoordCfg.RefreshTargetsIntervalSeconds = 600 * time.Second + suite.loadAll() } @@ -169,7 +185,35 @@ func (suite *CollectionObserverSuite) TearDownTest() { suite.kv.Close() } -func (suite *CollectionObserverSuite) TestObserve() { +func (suite *CollectionObserverSuite) TestObserveCollectionTimeout() { + const ( + timeout = 2 * time.Second + ) + // Not timeout + Params.QueryCoordCfg.LoadTimeoutSeconds = timeout + suite.ob.Start(context.Background()) + + // Collection 100 timeout, + // collection 101 loaded timeout + suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{ + ID: 1, + CollectionID: 101, + Channel: "101-dmc0", + Segments: map[int64]*querypb.SegmentDist{3: {NodeID: 1, Version: 0}}, + }) + suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{ + ID: 2, + CollectionID: 101, + Channel: "101-dmc1", + Segments: map[int64]*querypb.SegmentDist{4: {NodeID: 2, Version: 0}}, + }) + suite.Eventually(func() bool { + return suite.isCollectionTimeout(suite.collections[0]) && + suite.isCollectionLoaded(suite.collections[1]) + }, timeout*2, timeout/10) +} + +func (suite *CollectionObserverSuite) TestObservePartitionsTimeout() { const ( timeout = 2 * time.Second ) @@ -197,6 +241,63 @@ func (suite *CollectionObserverSuite) TestObserve() { }, timeout*2, timeout/10) } +func (suite *CollectionObserverSuite) TestObserveCollectionRefresh() { + const ( + timeout = 2 * time.Second + ) + // Not timeout + Params.QueryCoordCfg.RefreshTargetsIntervalSeconds = timeout + suite.broker.EXPECT().GetPartitions(mock.Anything, int64(100)).Return(suite.partitions[100], nil) + for _, partition := range suite.partitions[100] { + suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(100), partition).Return(nil, nil, nil) + } + suite.ob.Start(context.Background()) + + // Collection 100 refreshed, + // collection 101 loaded + suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{ + ID: 1, + CollectionID: 101, + Channel: "101-dmc0", + Segments: map[int64]*querypb.SegmentDist{3: {NodeID: 1, Version: 0}}, + }) + suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{ + ID: 2, + CollectionID: 101, + Channel: "101-dmc1", + Segments: map[int64]*querypb.SegmentDist{4: {NodeID: 2, Version: 0}}, + }) + time.Sleep(timeout * 2) +} + +func (suite *CollectionObserverSuite) TestObservePartitionsRefresh() { + const ( + timeout = 2 * time.Second + ) + // Not timeout + Params.QueryCoordCfg.RefreshTargetsIntervalSeconds = timeout + for _, partition := range suite.partitions[101] { + 
+		suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(101), partition).Return(nil, nil, nil)
+	}
+	suite.ob.Start(context.Background())
+
+	// Collection 100 loaded,
+	// collection 101 refreshed
+	suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{
+		ID:           1,
+		CollectionID: 100,
+		Channel:      "100-dmc0",
+		Segments:     map[int64]*querypb.SegmentDist{1: {NodeID: 1, Version: 0}},
+	})
+	suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{
+		ID:           2,
+		CollectionID: 100,
+		Channel:      "100-dmc1",
+		Segments:     map[int64]*querypb.SegmentDist{2: {NodeID: 2, Version: 0}},
+	})
+	time.Sleep(timeout * 2)
+}
+
 func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool {
 	exist := suite.meta.Exist(collection)
 	percentage := suite.meta.GetLoadPercentage(collection)
diff --git a/internal/querycoordv2/observers/handoff_observer.go b/internal/querycoordv2/observers/handoff_observer.go
index 8e94549795..3e6778f858 100644
--- a/internal/querycoordv2/observers/handoff_observer.go
+++ b/internal/querycoordv2/observers/handoff_observer.go
@@ -120,6 +120,7 @@ func (ob *HandoffObserver) Unregister(ctx context.Context, collectionIDs ...int64) {
 	for segmentID, event := range ob.handoffEvents {
 		if collectionSet.Contain(event.Segment.GetCollectionID()) {
 			delete(ob.handoffEvents, segmentID)
+			ob.cleanEvent(ctx, event.Segment)
 		}
 	}
 }
diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index d075ab50bb..2f6f7d21c8 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -273,17 +273,6 @@ func (s *Server) initMeta() error {
 
 func (s *Server) initObserver() {
 	log.Info("init observers")
-	s.collectionObserver = observers.NewCollectionObserver(
-		s.dist,
-		s.meta,
-		s.targetMgr,
-	)
-	s.leaderObserver = observers.NewLeaderObserver(
-		s.dist,
-		s.meta,
-		s.targetMgr,
-		s.cluster,
-	)
 	s.handoffObserver = observers.NewHandoffObserver(
 		s.store,
 		s.meta,
@@ -291,6 +280,19 @@ func (s *Server) initObserver() {
 		s.targetMgr,
 		s.broker,
 	)
+	s.collectionObserver = observers.NewCollectionObserver(
+		s.dist,
+		s.meta,
+		s.targetMgr,
+		s.broker,
+		s.handoffObserver,
+	)
+	s.leaderObserver = observers.NewLeaderObserver(
+		s.dist,
+		s.meta,
+		s.targetMgr,
+		s.cluster,
+	)
 }
 
 func (s *Server) Start() error {
@@ -518,18 +520,13 @@ func (s *Server) recoverCollectionTargets(ctx context.Context, collection int64) error {
 	}
 
 	s.handoffObserver.Register(collection)
-	err = utils.RegisterTargets(
+	return utils.RegisterTargets(
 		ctx,
 		s.targetMgr,
 		s.broker,
 		collection,
 		partitions,
 	)
-	if err != nil {
-		return err
-	}
-	s.handoffObserver.StartHandoff(collection)
-	return nil
 }
 
 func (s *Server) watchNodes(revision int64) {
diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go
index 0a9c311014..cc3adc4b97 100644
--- a/internal/util/paramtable/component_param.go
+++ b/internal/util/paramtable/component_param.go
@@ -720,6 +720,7 @@ type queryCoordConfig struct {
 	LoadTimeoutSeconds            time.Duration
 	CheckHandoffInterval          time.Duration
 	EnableActiveStandby           bool
+	RefreshTargetsIntervalSeconds time.Duration
 }
 
 func (p *queryCoordConfig) init(base *BaseTable) {
@@ -746,6 +747,7 @@ func (p *queryCoordConfig) init(base *BaseTable) {
 	p.initLoadTimeoutSeconds()
 	p.initCheckHandoffInterval()
 	p.initEnableActiveStandby()
+	p.initRefreshTargetsIntervalSeconds()
 }
 
 func (p *queryCoordConfig) initTaskRetryNum() {
@@ -877,6 +879,15 @@ func (p *queryCoordConfig) GetNodeID() UniqueID {
 	return 0
 }
 
+func (p *queryCoordConfig) initRefreshTargetsIntervalSeconds() {
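+	// Parse queryCoord.refreshTargetsIntervalSeconds (default "300") into a time.Duration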
+	interval := p.Base.LoadWithDefault("queryCoord.refreshTargetsIntervalSeconds", "300")
+	refreshInterval, err := strconv.ParseInt(interval, 10, 64)
+	if err != nil {
+		panic(err)
+	}
+	p.RefreshTargetsIntervalSeconds = time.Duration(refreshInterval) * time.Second
+}
+
 // /////////////////////////////////////////////////////////////////////////////
 // --- querynode ---
 type queryNodeConfig struct {