mirror of https://github.com/milvus-io/milvus.git
Make load operation parallel to improve performance (#15928)
Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/16000/head
parent
2256e94ea8
commit
4f3d5fb6eb
|
@ -21,6 +21,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path"
|
"path"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -89,7 +90,7 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
|
||||||
zap.Any("loadType", segmentType),
|
zap.Any("loadType", segmentType),
|
||||||
)
|
)
|
||||||
// check memory limit
|
// check memory limit
|
||||||
err := loader.checkSegmentSize(req.CollectionID, req.Infos)
|
err := loader.checkSegmentSize(req.CollectionID, req.Infos, runtime.GOMAXPROCS(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("load failed, OOM if loaded", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err))
|
log.Error("load failed, OOM if loaded", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err))
|
||||||
return err
|
return err
|
||||||
|
@ -127,8 +128,8 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
|
||||||
newSegments[segmentID] = segment
|
newSegments[segmentID] = segment
|
||||||
}
|
}
|
||||||
|
|
||||||
// start to load
|
loadSegmentFunc := func(idx int) error {
|
||||||
for _, loadInfo := range req.Infos {
|
loadInfo := req.Infos[idx]
|
||||||
collectionID := loadInfo.CollectionID
|
collectionID := loadInfo.CollectionID
|
||||||
partitionID := loadInfo.PartitionID
|
partitionID := loadInfo.PartitionID
|
||||||
segmentID := loadInfo.SegmentID
|
segmentID := loadInfo.SegmentID
|
||||||
|
@ -142,10 +143,18 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
|
||||||
zap.Int64("segmentID", segmentID),
|
zap.Int64("segmentID", segmentID),
|
||||||
zap.Int32("segment type", int32(segmentType)),
|
zap.Int32("segment type", int32(segmentType)),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
segmentGC()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.QueryNodeLoadSegmentLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.QueryNodeID)).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
metrics.QueryNodeLoadSegmentLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.QueryNodeID)).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// start to load
|
||||||
|
err = funcutil.ProcessFuncParallel(len(req.Infos), runtime.GOMAXPROCS(0), loadSegmentFunc, "loadSegmentFunc")
|
||||||
|
if err != nil {
|
||||||
|
segmentGC()
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// set segment to meta replica
|
// set segment to meta replica
|
||||||
|
@ -604,9 +613,12 @@ func JoinIDPath(ids ...UniqueID) string {
|
||||||
return path.Join(idStr...)
|
return path.Join(idStr...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoadInfos []*querypb.SegmentLoadInfo) error {
|
func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoadInfos []*querypb.SegmentLoadInfo, concurrency int) error {
|
||||||
usedMem := metricsinfo.GetUsedMemoryCount()
|
usedMem := metricsinfo.GetUsedMemoryCount()
|
||||||
totalMem := metricsinfo.GetMemoryCount()
|
totalMem := metricsinfo.GetMemoryCount()
|
||||||
|
if len(segmentLoadInfos) < concurrency {
|
||||||
|
concurrency = len(segmentLoadInfos)
|
||||||
|
}
|
||||||
|
|
||||||
if usedMem == 0 || totalMem == 0 {
|
if usedMem == 0 || totalMem == 0 {
|
||||||
return fmt.Errorf("get memory failed when checkSegmentSize, collectionID = %d", collectionID)
|
return fmt.Errorf("get memory failed when checkSegmentSize, collectionID = %d", collectionID)
|
||||||
|
@ -623,7 +635,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
|
||||||
}
|
}
|
||||||
|
|
||||||
// when load segment, data will be copied from go memory to c++ memory
|
// when load segment, data will be copied from go memory to c++ memory
|
||||||
if uint64(usedMemAfterLoad+maxSegmentSize) > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) {
|
if uint64(usedMemAfterLoad+maxSegmentSize*int64(concurrency)) > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) {
|
||||||
return fmt.Errorf("load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %d, usedMemAfterLoad = %d, totalMem = %d, thresholdFactor = %f",
|
return fmt.Errorf("load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %d, usedMemAfterLoad = %d, totalMem = %d, thresholdFactor = %f",
|
||||||
collectionID, maxSegmentSize, usedMemAfterLoad, totalMem, Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage)
|
collectionID, maxSegmentSize, usedMemAfterLoad, totalMem, Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage)
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package querynode
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -315,7 +316,7 @@ func TestSegmentLoader_checkSegmentSize(t *testing.T) {
|
||||||
loader := node.loader
|
loader := node.loader
|
||||||
assert.NotNil(t, loader)
|
assert.NotNil(t, loader)
|
||||||
|
|
||||||
err = loader.checkSegmentSize(defaultSegmentID, []*querypb.SegmentLoadInfo{{SegmentID: defaultSegmentID, SegmentSize: 1024}})
|
err = loader.checkSegmentSize(defaultSegmentID, []*querypb.SegmentLoadInfo{{SegmentID: defaultSegmentID, SegmentSize: 1024}}, runtime.GOMAXPROCS(0))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
//totalMem, err := getTotalMemory()
|
//totalMem, err := getTotalMemory()
|
||||||
|
|
Loading…
Reference in New Issue