mirror of https://github.com/milvus-io/milvus.git
Make load operation parallel improve performance (#15928)
Signed-off-by: yah01 <yang.cen@zilliz.com>pull/16000/head
parent
2256e94ea8
commit
4f3d5fb6eb
|
@ -21,6 +21,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -89,7 +90,7 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
|
|||
zap.Any("loadType", segmentType),
|
||||
)
|
||||
// check memory limit
|
||||
err := loader.checkSegmentSize(req.CollectionID, req.Infos)
|
||||
err := loader.checkSegmentSize(req.CollectionID, req.Infos, runtime.GOMAXPROCS(0))
|
||||
if err != nil {
|
||||
log.Error("load failed, OOM if loaded", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err))
|
||||
return err
|
||||
|
@ -127,8 +128,8 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
|
|||
newSegments[segmentID] = segment
|
||||
}
|
||||
|
||||
// start to load
|
||||
for _, loadInfo := range req.Infos {
|
||||
loadSegmentFunc := func(idx int) error {
|
||||
loadInfo := req.Infos[idx]
|
||||
collectionID := loadInfo.CollectionID
|
||||
partitionID := loadInfo.PartitionID
|
||||
segmentID := loadInfo.SegmentID
|
||||
|
@ -142,10 +143,18 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
|
|||
zap.Int64("segmentID", segmentID),
|
||||
zap.Int32("segment type", int32(segmentType)),
|
||||
zap.Error(err))
|
||||
segmentGC()
|
||||
return err
|
||||
}
|
||||
|
||||
metrics.QueryNodeLoadSegmentLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.QueryNodeID)).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
return nil
|
||||
}
|
||||
// start to load
|
||||
err = funcutil.ProcessFuncParallel(len(req.Infos), runtime.GOMAXPROCS(0), loadSegmentFunc, "loadSegmentFunc")
|
||||
if err != nil {
|
||||
segmentGC()
|
||||
return err
|
||||
}
|
||||
|
||||
// set segment to meta replica
|
||||
|
@ -604,9 +613,12 @@ func JoinIDPath(ids ...UniqueID) string {
|
|||
return path.Join(idStr...)
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoadInfos []*querypb.SegmentLoadInfo) error {
|
||||
func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoadInfos []*querypb.SegmentLoadInfo, concurrency int) error {
|
||||
usedMem := metricsinfo.GetUsedMemoryCount()
|
||||
totalMem := metricsinfo.GetMemoryCount()
|
||||
if len(segmentLoadInfos) < concurrency {
|
||||
concurrency = len(segmentLoadInfos)
|
||||
}
|
||||
|
||||
if usedMem == 0 || totalMem == 0 {
|
||||
return fmt.Errorf("get memory failed when checkSegmentSize, collectionID = %d", collectionID)
|
||||
|
@ -623,7 +635,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
|
|||
}
|
||||
|
||||
// when load segment, data will be copied from go memory to c++ memory
|
||||
if uint64(usedMemAfterLoad+maxSegmentSize) > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) {
|
||||
if uint64(usedMemAfterLoad+maxSegmentSize*int64(concurrency)) > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) {
|
||||
return fmt.Errorf("load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %d, usedMemAfterLoad = %d, totalMem = %d, thresholdFactor = %f",
|
||||
collectionID, maxSegmentSize, usedMemAfterLoad, totalMem, Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage)
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package querynode
|
|||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -315,7 +316,7 @@ func TestSegmentLoader_checkSegmentSize(t *testing.T) {
|
|||
loader := node.loader
|
||||
assert.NotNil(t, loader)
|
||||
|
||||
err = loader.checkSegmentSize(defaultSegmentID, []*querypb.SegmentLoadInfo{{SegmentID: defaultSegmentID, SegmentSize: 1024}})
|
||||
err = loader.checkSegmentSize(defaultSegmentID, []*querypb.SegmentLoadInfo{{SegmentID: defaultSegmentID, SegmentSize: 1024}}, runtime.GOMAXPROCS(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
//totalMem, err := getTotalMemory()
|
||||
|
|
Loading…
Reference in New Issue