diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index 7f2851a31e..efb0ae1427 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -336,39 +336,8 @@ func (lct *loadCollectionTask) timestamp() Timestamp { } func (lct *loadCollectionTask) updateTaskProcess() { - collectionID := lct.CollectionID - childTasks := lct.getChildTask() - allDone := true - for _, t := range childTasks { - if t.getState() != taskDone { - allDone = false - break - } - - // wait watchDeltaChannel task done after loading segment - nodeID := getDstNodeIDByTask(t) - if t.msgType() == commonpb.MsgType_LoadSegments { - if !lct.cluster.HasWatchedDeltaChannel(lct.ctx, nodeID, collectionID) { - allDone = false - break - } - } - - } - if allDone { - err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_LoadCollection) - if err != nil { - log.Error("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID)) - lct.setResultInfo(err) - return - } - - lct.once.Do(func() { - metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() - metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lct.elapseSpan().Milliseconds())) - metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lct.getChildTask()))) - }) - } + //TODO move check all child task done to globalPostExecute + //this function shall just calculate intermediate progress } func (lct *loadCollectionTask) preExecute(ctx context.Context) error { @@ -622,6 +591,18 @@ func (lct *loadCollectionTask) globalPostExecute(ctx context.Context) error { } } + err = lct.meta.setLoadPercentage(lct.CollectionID, 0, 100, querypb.LoadType_LoadCollection) + if err != nil { + log.Error("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", lct.CollectionID)) + lct.setResultInfo(err) + return err + } + + lct.once.Do(func() { + metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() + metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lct.elapseSpan().Milliseconds())) + metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lct.getChildTask()))) + }) return nil } @@ -801,39 +782,8 @@ func (lpt *loadPartitionTask) timestamp() Timestamp { } func (lpt *loadPartitionTask) updateTaskProcess() { - collectionID := lpt.CollectionID - partitionIDs := lpt.PartitionIDs - childTasks := lpt.getChildTask() - allDone := true - for _, t := range childTasks { - if t.getState() != taskDone { - allDone = false - } - - // wait watchDeltaChannel task done after loading segment - nodeID := getDstNodeIDByTask(t) - if t.msgType() == commonpb.MsgType_LoadSegments { - if !lpt.cluster.HasWatchedDeltaChannel(lpt.ctx, nodeID, collectionID) { - allDone = false - break - } - } - } - if allDone { - for _, id := range partitionIDs { - err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition) - if err != nil { - log.Error("loadPartitionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id)) - lpt.setResultInfo(err) - return - } - } - lpt.once.Do(func() { - metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() - metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lpt.elapseSpan().Milliseconds())) - metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lpt.getChildTask()))) - }) - } + //TODO move check all child task done to globalPostExecute + //this function shall just calculate intermediate progress } func (lpt *loadPartitionTask) preExecute(context.Context) error { @@ -1057,6 +1007,7 @@ func (lpt *loadPartitionTask) postExecute(ctx context.Context) error { func (lpt *loadPartitionTask) globalPostExecute(ctx context.Context) error { collectionID := lpt.CollectionID + partitionIDs := lpt.PartitionIDs collection, err := lpt.meta.getCollectionInfoByID(collectionID) if err != nil { @@ -1080,6 +1031,21 @@ func (lpt *loadPartitionTask) globalPostExecute(ctx context.Context) error { } } + for _, id := range partitionIDs { + err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition) + if err != nil { + log.Error("loadPartitionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id)) + lpt.setResultInfo(err) + return err + } + } + + lpt.once.Do(func() { + metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() + metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lpt.elapseSpan().Milliseconds())) + metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lpt.getChildTask()))) + }) + return nil } diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index ed26fd7f01..bdfb8a5cd8 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -1347,6 +1347,12 @@ func TestUpdateTaskProcessWhenLoadSegment(t *testing.T) { queryCoord.scheduler.processTask(watchDeltaChannel) collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID) assert.Nil(t, err) + assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage) + + err = loadCollectionTask.globalPostExecute(ctx) + assert.NoError(t, err) + collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID) + assert.NoError(t, err) assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage) err = removeAllSession() @@ -1365,6 +1371,7 @@ func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) { queryCoord.meta.addCollection(defaultCollectionID, querypb.LoadType_LoadCollection, genDefaultCollectionSchema(false)) watchDmChannel := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := watchDmChannel.getParentTask() collectionInfo, err := queryCoord.meta.getCollectionInfoByID(defaultCollectionID) assert.Nil(t, err) @@ -1373,6 +1380,12 @@ func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) { collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID) assert.Nil(t, err) + assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage) + + err = loadCollectionTask.globalPostExecute(ctx) + assert.NoError(t, err) + collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID) + assert.NoError(t, err) assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage) err = removeAllSession()