Fix load progress complete before syncReplica (#17828)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/17818/head
congqixia 2022-06-27 20:52:16 +08:00 committed by GitHub
parent f55ca5eb56
commit 6840486efc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 66 deletions

View File

@ -336,39 +336,8 @@ func (lct *loadCollectionTask) timestamp() Timestamp {
}
// updateTaskProcess checks whether every child task of this load-collection
// task has finished and, if so, marks the collection load as complete.
//
// A child task counts as finished when its state is taskDone. For a child
// whose message type is LoadSegments, the destination node must additionally
// report that it has watched the delta channel of the collection. The scan
// stops at the first child that does not qualify.
//
// When all children qualify, the load percentage is set to 100 through
// lct.meta.setLoadPercentage and the success metrics are recorded inside
// lct.once.Do. If setting the percentage fails, the error is logged and stored
// with lct.setResultInfo, and no metrics are recorded.
//
// TODO: move the "all child tasks done" check to globalPostExecute; this
// function should only calculate intermediate progress.
func (lct *loadCollectionTask) updateTaskProcess() {
	collectionID := lct.CollectionID
	childTasks := lct.getChildTask()
	allDone := true
	for _, t := range childTasks {
		if t.getState() != taskDone {
			allDone = false
			break
		}

		// wait watchDeltaChannel task done after loading segment
		nodeID := getDstNodeIDByTask(t)
		if t.msgType() == commonpb.MsgType_LoadSegments {
			if !lct.cluster.HasWatchedDeltaChannel(lct.ctx, nodeID, collectionID) {
				allDone = false
				break
			}
		}
	}

	if allDone {
		// NOTE(review): the second argument (0) is presumably the partition ID,
		// unused for a whole-collection load — confirm against setLoadPercentage.
		err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_LoadCollection)
		if err != nil {
			// Attach the error itself so the failure cause is visible in the log.
			log.Error("loadCollectionTask: set load percentage to meta's collectionInfo",
				zap.Int64("collectionID", collectionID),
				zap.Error(err))
			lct.setResultInfo(err)
			return
		}
		// Record the success metrics through lct.once so repeated calls do not
		// count the same load again.
		lct.once.Do(func() {
			metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
			metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lct.elapseSpan().Milliseconds()))
			metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lct.getChildTask())))
		})
	}
}
func (lct *loadCollectionTask) preExecute(ctx context.Context) error {
@ -622,6 +591,18 @@ func (lct *loadCollectionTask) globalPostExecute(ctx context.Context) error {
}
}
err = lct.meta.setLoadPercentage(lct.CollectionID, 0, 100, querypb.LoadType_LoadCollection)
if err != nil {
log.Error("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", lct.CollectionID))
lct.setResultInfo(err)
return err
}
lct.once.Do(func() {
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lct.elapseSpan().Milliseconds()))
metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lct.getChildTask())))
})
return nil
}
@ -801,39 +782,8 @@ func (lpt *loadPartitionTask) timestamp() Timestamp {
}
// updateTaskProcess checks whether every child task of this load-partition
// task has finished and, if so, marks each requested partition as fully loaded.
//
// A child task counts as finished when its state is taskDone. For a child
// whose message type is LoadSegments, the destination node must additionally
// report that it has watched the delta channel of the collection. The scan
// stops at the first child that does not qualify, matching the behavior of
// loadCollectionTask.updateTaskProcess.
//
// When all children qualify, the load percentage of every partition in
// lpt.PartitionIDs is set to 100 through lpt.meta.setLoadPercentage and the
// success metrics are recorded inside lpt.once.Do. If setting the percentage
// fails for a partition, the error is logged and stored with
// lpt.setResultInfo, the remaining partitions are not updated, and no metrics
// are recorded.
//
// TODO: move the "all child tasks done" check to globalPostExecute; this
// function should only calculate intermediate progress.
func (lpt *loadPartitionTask) updateTaskProcess() {
	collectionID := lpt.CollectionID
	partitionIDs := lpt.PartitionIDs
	childTasks := lpt.getChildTask()
	allDone := true
	for _, t := range childTasks {
		if t.getState() != taskDone {
			allDone = false
			// The outcome is already decided; stop instead of querying the
			// cluster for the remaining children.
			break
		}

		// wait watchDeltaChannel task done after loading segment
		nodeID := getDstNodeIDByTask(t)
		if t.msgType() == commonpb.MsgType_LoadSegments {
			if !lpt.cluster.HasWatchedDeltaChannel(lpt.ctx, nodeID, collectionID) {
				allDone = false
				break
			}
		}
	}

	if allDone {
		for _, id := range partitionIDs {
			err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition)
			if err != nil {
				// Attach the error itself so the failure cause is visible in the log.
				log.Error("loadPartitionTask: set load percentage to meta's collectionInfo",
					zap.Int64("collectionID", collectionID),
					zap.Int64("partitionID", id),
					zap.Error(err))
				lpt.setResultInfo(err)
				return
			}
		}
		// Record the success metrics through lpt.once so repeated calls do not
		// count the same load again.
		lpt.once.Do(func() {
			metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
			metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lpt.elapseSpan().Milliseconds()))
			metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lpt.getChildTask())))
		})
	}
}
func (lpt *loadPartitionTask) preExecute(context.Context) error {
@ -1057,6 +1007,7 @@ func (lpt *loadPartitionTask) postExecute(ctx context.Context) error {
func (lpt *loadPartitionTask) globalPostExecute(ctx context.Context) error {
collectionID := lpt.CollectionID
partitionIDs := lpt.PartitionIDs
collection, err := lpt.meta.getCollectionInfoByID(collectionID)
if err != nil {
@ -1080,6 +1031,21 @@ func (lpt *loadPartitionTask) globalPostExecute(ctx context.Context) error {
}
}
for _, id := range partitionIDs {
err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition)
if err != nil {
log.Error("loadPartitionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id))
lpt.setResultInfo(err)
return err
}
}
lpt.once.Do(func() {
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lpt.elapseSpan().Milliseconds()))
metrics.QueryCoordNumChildTasks.WithLabelValues().Sub(float64(len(lpt.getChildTask())))
})
return nil
}

View File

@ -1347,6 +1347,12 @@ func TestUpdateTaskProcessWhenLoadSegment(t *testing.T) {
queryCoord.scheduler.processTask(watchDeltaChannel)
collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage)
err = loadCollectionTask.globalPostExecute(ctx)
assert.NoError(t, err)
collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.NoError(t, err)
assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage)
err = removeAllSession()
@ -1365,6 +1371,7 @@ func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) {
queryCoord.meta.addCollection(defaultCollectionID, querypb.LoadType_LoadCollection, genDefaultCollectionSchema(false))
watchDmChannel := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID)
loadCollectionTask := watchDmChannel.getParentTask()
collectionInfo, err := queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
@ -1373,6 +1380,12 @@ func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) {
collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage)
err = loadCollectionTask.globalPostExecute(ctx)
assert.NoError(t, err)
collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.NoError(t, err)
assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage)
err = removeAllSession()