Mirror of https://github.com/milvus-io/milvus.git

Add cacheSize to prevent OOM in query node (#8765)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>

parent 803d9ae8ca
commit 9d6c95bf25
@@ -71,6 +71,7 @@ queryCoord:
   clientMaxSendSize: 104857600 # 100 MB, 100 * 1024 * 1024

 queryNode:
+  cacheSize: 32 # GB, default 32 GB, `cacheSize` is the memory used for caching data for faster query. The `cacheSize` must be less than system memory size.
   gracefulTime: 1000 # ms, for search
   port: 21123

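The new `queryNode.cacheSize` knob is specified in gigabytes, while the loader changes below work in megabytes and the reworked test works in bytes. A minimal sketch of those unit conversions, assuming a hypothetical cacheSizeGB value read from this config key:

package main

import "fmt"

func main() {
    // Hypothetical value read from queryNode.cacheSize (GB).
    cacheSizeGB := int64(32)

    // The same conversions this commit applies elsewhere:
    // GB -> MB for the loader's threshold, GB -> bytes in the loader test.
    cacheSizeMB := cacheSizeGB * 1024
    cacheSizeBytes := cacheSizeGB * 1024 * 1024 * 1024

    fmt.Println(cacheSizeMB, cacheSizeBytes) // 32768 34359738368
}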
@@ -80,6 +80,7 @@ type ReplicaInterface interface {
     addExcludedSegments(collectionID UniqueID, segmentInfos []*datapb.SegmentInfo) error
     getExcludedSegments(collectionID UniqueID) ([]*datapb.SegmentInfo, error)

+    getSegmentsMemSize() int64
     freeAll()
     printReplica()
 }
@@ -97,6 +98,17 @@ type collectionReplica struct {
     etcdKV *etcdkv.EtcdKV
 }

+func (colReplica *collectionReplica) getSegmentsMemSize() int64 {
+    colReplica.mu.RLock()
+    defer colReplica.mu.RUnlock()
+
+    memSize := int64(0)
+    for _, segment := range colReplica.segments {
+        memSize += segment.getMemSize()
+    }
+    return memSize
+}
+
 func (colReplica *collectionReplica) printReplica() {
     colReplica.mu.Lock()
     defer colReplica.mu.Unlock()
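getSegmentsMemSize walks every cached segment under a read lock and sums their in-memory sizes; the segment loader below uses that total as the "used" memory figure. A self-contained sketch of the same pattern, with hypothetical segment and replica types standing in for the real ones:

package main

import (
    "fmt"
    "sync"
)

// Hypothetical stand-ins for the query node's segment and replica types.
type segment struct{ memSize int64 }

func (s *segment) getMemSize() int64 { return s.memSize }

type replica struct {
    mu       sync.RWMutex
    segments map[int64]*segment
}

// Same shape as collectionReplica.getSegmentsMemSize: take a read lock,
// then accumulate the per-segment memory usage in bytes.
func (r *replica) getSegmentsMemSize() int64 {
    r.mu.RLock()
    defer r.mu.RUnlock()

    memSize := int64(0)
    for _, seg := range r.segments {
        memSize += seg.getMemSize()
    }
    return memSize
}

func main() {
    r := &replica{segments: map[int64]*segment{
        1: {memSize: 4 << 20}, // 4 MB
        2: {memSize: 6 << 20}, // 6 MB
    }}
    fmt.Println(r.getSegmentsMemSize() / 1024 / 1024) // 10 (MB)
}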
@@ -12,10 +12,13 @@
 package querynode

 import (
     "os"
     "strconv"
     "strings"
     "sync"

     "go.uber.org/zap"

     "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/util/paramtable"
 )
@@ -32,6 +35,7 @@ type ParamTable struct {
     QueryNodeIP   string
     QueryNodePort int64
     QueryNodeID   UniqueID
+    CacheSize     int64

     // channel prefix
     ClusterChannelPrefix string
@@ -96,6 +100,8 @@ func (p *ParamTable) Init() {
         panic(err)
     }

+    p.initCacheSize()
+
     p.initMinioEndPoint()
     p.initMinioAccessKeyID()
     p.initMinioSecretAccessKey()
@@ -130,6 +136,26 @@ func (p *ParamTable) Init() {
     p.initLogCfg()
 }

+func (p *ParamTable) initCacheSize() {
+    const defaultCacheSize = 32 // GB
+    p.CacheSize = defaultCacheSize
+
+    var err error
+    cacheSize := os.Getenv("CACHE_SIZE")
+    if cacheSize == "" {
+        cacheSize, err = p.Load("queryNode.cacheSize")
+        if err != nil {
+            return
+        }
+    }
+    value, err := strconv.ParseInt(cacheSize, 10, 64)
+    if err != nil {
+        return
+    }
+    p.CacheSize = value
+    log.Debug("init cacheSize", zap.Any("cacheSize (GB)", p.CacheSize))
+}
+
 // ---------------------------------------------------------- minio
 func (p *ParamTable) initMinioEndPoint() {
     url, err := p.Load("_MinioAddress")
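initCacheSize resolves the value with a simple precedence: the CACHE_SIZE environment variable wins, then the queryNode.cacheSize config key, and any load or parse failure silently falls back to the 32 GB default. A sketch of the same precedence outside the ParamTable machinery; loadFromConfig here is a hypothetical stand-in for p.Load("queryNode.cacheSize"):

package main

import (
    "fmt"
    "os"
    "strconv"
)

// loadFromConfig is a hypothetical stand-in for ParamTable.Load("queryNode.cacheSize").
func loadFromConfig() (string, error) { return "32", nil }

// resolveCacheSizeGB mirrors initCacheSize: env var first, then config key, then default.
func resolveCacheSizeGB() int64 {
    const defaultCacheSize = 32 // GB

    raw := os.Getenv("CACHE_SIZE")
    if raw == "" {
        var err error
        raw, err = loadFromConfig()
        if err != nil {
            return defaultCacheSize
        }
    }
    value, err := strconv.ParseInt(raw, 10, 64)
    if err != nil {
        return defaultCacheSize
    }
    return value
}

func main() {
    os.Setenv("CACHE_SIZE", "2") // e.g. a small container with only a few GB of RAM
    fmt.Println(resolveCacheSizeGB()) // 2
}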
@@ -13,6 +13,7 @@ package querynode

 import (
     "fmt"
+    "os"
     "strings"
     "testing"

@@ -26,6 +27,19 @@ func TestParamTable_PulsarAddress(t *testing.T) {
     assert.Equal(t, "6650", split[len(split)-1])
 }

+func TestParamTable_cacheSize(t *testing.T) {
+    cacheSize := Params.CacheSize
+    assert.Equal(t, int64(32), cacheSize)
+    err := os.Setenv("CACHE_SIZE", "2")
+    assert.NoError(t, err)
+    Params.initCacheSize()
+    assert.Equal(t, int64(2), Params.CacheSize)
+    err = os.Setenv("CACHE_SIZE", "32")
+    assert.NoError(t, err)
+    Params.initCacheSize()
+    assert.Equal(t, int64(32), Params.CacheSize)
+}
+
 func TestParamTable_minio(t *testing.T) {
     t.Run("Test endPoint", func(t *testing.T) {
         endPoint := Params.MinioEndPoint
@@ -29,7 +29,6 @@ import (
     "github.com/milvus-io/milvus/internal/storage"
     "github.com/milvus-io/milvus/internal/types"
     "github.com/milvus-io/milvus/internal/util/funcutil"
-    "github.com/milvus-io/milvus/internal/util/metricsinfo"
     "github.com/milvus-io/milvus/internal/util/typeutil"
 )

@@ -194,10 +193,10 @@ func (loader *segmentLoader) loadSegmentInternal(collectionID UniqueID, segment
 }

 func (loader *segmentLoader) checkSegmentMemory(segmentLoadInfos []*querypb.SegmentLoadInfo) error {
-    totalRAM := metricsinfo.GetMemoryCount()
-    usedRAM := metricsinfo.GetUsedMemoryCount()
+    totalRAMInMB := Params.CacheSize * 1024.0
+    usedRAMInMB := loader.historicalReplica.getSegmentsMemSize() / 1024.0 / 1024.0

-    segmentTotalSize := uint64(0)
+    segmentTotalSize := int64(0)
     for _, segInfo := range segmentLoadInfos {
         collectionID := segInfo.CollectionID
         segmentID := segInfo.SegmentID
@@ -212,22 +211,21 @@ func (loader *segmentLoader) checkSegmentMemory(segmentLoadInfos []*querypb.Segm
             return err
         }

-        segmentSize := uint64(int64(sizePerRecord) * segInfo.NumOfRows)
-        segmentTotalSize += segmentSize
-        // TODO: get 0.9 from param table
-        thresholdMemSize := float64(totalRAM) * 0.9
+        segmentSize := int64(sizePerRecord) * segInfo.NumOfRows
+        segmentTotalSize += segmentSize / 1024.0 / 1024.0
+        // TODO: get threshold factor from param table
+        thresholdMemSize := float64(totalRAMInMB) * 0.5

-        log.Debug("memory size[byte] stats when load segment",
+        log.Debug("memory stats when load segment",
             zap.Any("collectionIDs", collectionID),
             zap.Any("segmentID", segmentID),
             zap.Any("numOfRows", segInfo.NumOfRows),
-            zap.Any("totalRAM", totalRAM),
-            zap.Any("usedRAM", usedRAM),
-            zap.Any("segmentSize", segmentSize),
-            zap.Any("segmentTotalSize", segmentTotalSize),
-            zap.Any("thresholdMemSize", thresholdMemSize),
+            zap.Any("totalRAM(MB)", totalRAMInMB),
+            zap.Any("usedRAM(MB)", usedRAMInMB),
+            zap.Any("segmentTotalSize(MB)", segmentTotalSize),
+            zap.Any("thresholdMemSize(MB)", thresholdMemSize),
         )
-        if usedRAM+segmentTotalSize > uint64(thresholdMemSize) {
+        if usedRAMInMB+segmentTotalSize > int64(thresholdMemSize) {
             return errors.New("load segment failed, OOM if load, collectionID = " + fmt.Sprintln(collectionID))
         }
     }
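With this change the OOM check no longer asks the OS for total or used RAM; everything is computed in megabytes against the configured cache size, and a load is rejected when already-cached segment memory plus the estimated size of the incoming segments would exceed 50% of cacheSize. A standalone sketch of that arithmetic, with hypothetical inputs standing in for the replica and the segment load infos:

package main

import (
    "errors"
    "fmt"
)

// checkWouldOOM mirrors the reworked checkSegmentMemory: all sizes in MB,
// threshold at 50% of the configured cache size (the real code still carries
// a TODO to move that factor into the param table).
func checkWouldOOM(cacheSizeGB, usedBytes int64, sizePerRecord int, numOfRows int64) error {
    totalRAMInMB := cacheSizeGB * 1024
    usedRAMInMB := usedBytes / 1024 / 1024

    segmentSize := int64(sizePerRecord) * numOfRows // bytes
    segmentTotalSize := segmentSize / 1024 / 1024   // MB

    thresholdMemSize := float64(totalRAMInMB) * 0.5
    if usedRAMInMB+segmentTotalSize > int64(thresholdMemSize) {
        return errors.New("load segment failed, OOM if load")
    }
    return nil
}

func main() {
    // 32 GB cache, 1 GB already cached, 10M rows at 256 bytes/row (~2.4 GB): allowed.
    fmt.Println(checkWouldOOM(32, 1<<30, 256, 10_000_000))
    // Same cache, 100M rows (~24 GB) would cross the 16 GB threshold: rejected.
    fmt.Println(checkWouldOOM(32, 1<<30, 256, 100_000_000))
}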
@@ -23,7 +23,6 @@ import (
     "github.com/milvus-io/milvus/internal/proto/commonpb"
     "github.com/milvus-io/milvus/internal/proto/querypb"
     "github.com/milvus-io/milvus/internal/proto/schemapb"
-    "github.com/milvus-io/milvus/internal/util/metricsinfo"
     "github.com/milvus-io/milvus/internal/util/typeutil"
 )

@@ -229,7 +228,7 @@ func TestSegmentLoader_CheckSegmentMemory(t *testing.T) {
     })

     t.Run("test OOM", func(t *testing.T) {
-        totalRAM := metricsinfo.GetMemoryCount()
+        totalRAM := Params.CacheSize * 1024 * 1024 * 1024

         loader := genSegmentLoader()
         col, err := loader.historicalReplica.getCollectionByID(collectionID)
@@ -239,7 +238,7 @@ func TestSegmentLoader_CheckSegmentMemory(t *testing.T) {
         assert.NoError(t, err)

         info := genSegmentLoadInfo()
-        info.NumOfRows = int64(totalRAM / uint64(sizePerRecord))
+        info.NumOfRows = totalRAM / int64(sizePerRecord)
         err = loader.checkSegmentMemory([]*querypb.SegmentLoadInfo{info})
         assert.Error(t, err)
     })
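The rewritten "test OOM" case derives the row count from the configured cache size instead of the host's physical RAM, which makes the failure deterministic: loading roughly CacheSize worth of bytes always overshoots the 50% threshold. A rough back-of-the-envelope, with an illustrative sizePerRecord (the real value comes from the test collection's schema):

package main

import "fmt"

func main() {
    cacheSizeGB := int64(32)
    sizePerRecord := int64(256) // illustrative only

    totalRAM := cacheSizeGB * 1024 * 1024 * 1024 // bytes, as in the test
    numOfRows := totalRAM / sizePerRecord

    // The loader estimates segmentTotalSize ≈ the full cache size in MB,
    // which is always above the 50% threshold, so checkSegmentMemory errors out.
    segmentTotalSizeMB := sizePerRecord * numOfRows / 1024 / 1024
    thresholdMB := cacheSizeGB * 1024 / 2
    fmt.Println(segmentTotalSizeMB, thresholdMB, segmentTotalSizeMB > thresholdMB) // 32768 16384 true
}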