Add cacheSize to prevent OOM in query node (#8765)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
bigsheeper 2021-09-28 22:24:03 +08:00 committed by GitHub
parent 803d9ae8ca
commit 9d6c95bf25
6 changed files with 68 additions and 18 deletions

@@ -71,6 +71,7 @@ queryCoord:
clientMaxSendSize: 104857600 # 100 MB, 100 * 1024 * 1024
queryNode:
cacheSize: 32 # GB, default 32 GB. `cacheSize` is the amount of memory used to cache data for faster queries; it must be smaller than the system memory size.
gracefulTime: 1000 # ms, for search
port: 21123
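
For orientation, the GB value configured here becomes an in-memory MB budget in the query node's segment loader further down in this commit; a minimal standalone sketch of that conversion (names and numbers are illustrative, not part of the change):

package main

import "fmt"

func main() {
	// queryNode.cacheSize is expressed in GB; the segment loader
	// (see segment_loader.go below) works with an MB budget.
	cacheSizeGB := int64(32)
	totalRAMInMB := cacheSizeGB * 1024
	fmt.Println(totalRAMInMB) // 32768 MB available for caching segment data
}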

@@ -80,6 +80,7 @@ type ReplicaInterface interface {
addExcludedSegments(collectionID UniqueID, segmentInfos []*datapb.SegmentInfo) error
getExcludedSegments(collectionID UniqueID) ([]*datapb.SegmentInfo, error)
getSegmentsMemSize() int64
freeAll()
printReplica()
}
@@ -97,6 +98,17 @@ type collectionReplica struct {
etcdKV *etcdkv.EtcdKV
}
func (colReplica *collectionReplica) getSegmentsMemSize() int64 {
colReplica.mu.RLock()
defer colReplica.mu.RUnlock()
memSize := int64(0)
for _, segment := range colReplica.segments {
memSize += segment.getMemSize()
}
return memSize
}
func (colReplica *collectionReplica) printReplica() {
colReplica.mu.Lock()
defer colReplica.mu.Unlock()

@@ -12,10 +12,13 @@
package querynode
import (
"os"
"strconv"
"strings"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/paramtable"
)
@@ -32,6 +35,7 @@ type ParamTable struct {
QueryNodeIP string
QueryNodePort int64
QueryNodeID UniqueID
CacheSize int64
// channel prefix
ClusterChannelPrefix string
@@ -96,6 +100,8 @@ func (p *ParamTable) Init() {
panic(err)
}
p.initCacheSize()
p.initMinioEndPoint()
p.initMinioAccessKeyID()
p.initMinioSecretAccessKey()
@@ -130,6 +136,26 @@ func (p *ParamTable) Init() {
p.initLogCfg()
}
func (p *ParamTable) initCacheSize() {
const defaultCacheSize = 32 // GB
p.CacheSize = defaultCacheSize
var err error
cacheSize := os.Getenv("CACHE_SIZE")
if cacheSize == "" {
cacheSize, err = p.Load("queryNode.cacheSize")
if err != nil {
return
}
}
value, err := strconv.ParseInt(cacheSize, 10, 64)
if err != nil {
return
}
p.CacheSize = value
log.Debug("init cacheSize", zap.Any("cacheSize (GB)", p.CacheSize))
}
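
As a usage note, here is a minimal self-contained sketch of the same precedence as initCacheSize above (environment variable first, then the config key, then the 32 GB default); resolveCacheSizeGB and configValue are illustrative names, and configValue stands in for p.Load("queryNode.cacheSize"):

package main

import (
	"fmt"
	"os"
	"strconv"
)

// resolveCacheSizeGB mirrors the precedence in initCacheSize: the CACHE_SIZE
// environment variable wins, then the configured value, then the 32 GB default.
func resolveCacheSizeGB(configValue string) int64 {
	const defaultCacheSize = 32 // GB
	raw := os.Getenv("CACHE_SIZE")
	if raw == "" {
		raw = configValue
	}
	value, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return defaultCacheSize
	}
	return value
}

func main() {
	os.Setenv("CACHE_SIZE", "16")
	fmt.Println(resolveCacheSizeGB("32")) // prints 16: the env var overrides the config value
}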
// ---------------------------------------------------------- minio
func (p *ParamTable) initMinioEndPoint() {
url, err := p.Load("_MinioAddress")

@@ -13,6 +13,7 @@ package querynode
import (
"fmt"
"os"
"strings"
"testing"
@@ -26,6 +27,19 @@ func TestParamTable_PulsarAddress(t *testing.T) {
assert.Equal(t, "6650", split[len(split)-1])
}
func TestParamTable_cacheSize(t *testing.T) {
cacheSize := Params.CacheSize
assert.Equal(t, int64(32), cacheSize)
err := os.Setenv("CACHE_SIZE", "2")
assert.NoError(t, err)
Params.initCacheSize()
assert.Equal(t, int64(2), Params.CacheSize)
err = os.Setenv("CACHE_SIZE", "32")
assert.NoError(t, err)
Params.initCacheSize()
assert.Equal(t, int64(32), Params.CacheSize)
}
func TestParamTable_minio(t *testing.T) {
t.Run("Test endPoint", func(t *testing.T) {
endPoint := Params.MinioEndPoint

@@ -29,7 +29,6 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
@@ -194,10 +193,10 @@ func (loader *segmentLoader) loadSegmentInternal(collectionID UniqueID, segment
}
func (loader *segmentLoader) checkSegmentMemory(segmentLoadInfos []*querypb.SegmentLoadInfo) error {
totalRAM := metricsinfo.GetMemoryCount()
usedRAM := metricsinfo.GetUsedMemoryCount()
totalRAMInMB := Params.CacheSize * 1024.0
usedRAMInMB := loader.historicalReplica.getSegmentsMemSize() / 1024.0 / 1024.0
segmentTotalSize := uint64(0)
segmentTotalSize := int64(0)
for _, segInfo := range segmentLoadInfos {
collectionID := segInfo.CollectionID
segmentID := segInfo.SegmentID
@@ -212,22 +211,21 @@ func (loader *segmentLoader) checkSegmentMemory(segmentLoadInfos []*querypb.Segm
return err
}
segmentSize := uint64(int64(sizePerRecord) * segInfo.NumOfRows)
segmentTotalSize += segmentSize
// TODO: get 0.9 from param table
thresholdMemSize := float64(totalRAM) * 0.9
segmentSize := int64(sizePerRecord) * segInfo.NumOfRows
segmentTotalSize += segmentSize / 1024.0 / 1024.0
// TODO: get threshold factor from param table
thresholdMemSize := float64(totalRAMInMB) * 0.5
log.Debug("memory size[byte] stats when load segment",
log.Debug("memory stats when load segment",
zap.Any("collectionIDs", collectionID),
zap.Any("segmentID", segmentID),
zap.Any("numOfRows", segInfo.NumOfRows),
zap.Any("totalRAM", totalRAM),
zap.Any("usedRAM", usedRAM),
zap.Any("segmentSize", segmentSize),
zap.Any("segmentTotalSize", segmentTotalSize),
zap.Any("thresholdMemSize", thresholdMemSize),
zap.Any("totalRAM(MB)", totalRAMInMB),
zap.Any("usedRAM(MB)", usedRAMInMB),
zap.Any("segmentTotalSize(MB)", segmentTotalSize),
zap.Any("thresholdMemSize(MB)", thresholdMemSize),
)
if usedRAM+segmentTotalSize > uint64(thresholdMemSize) {
if usedRAMInMB+segmentTotalSize > int64(thresholdMemSize) {
return errors.New("load segment failed, OOM if load, collectionID = " + fmt.Sprintln(collectionID))
}
}

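To make the new guard concrete, here is a minimal standalone sketch of the same check with illustrative numbers; the function and parameter names are local to the example, not part of the commit:

package main

import "fmt"

// wouldOOM reproduces the arithmetic of checkSegmentMemory above: the budget
// is cacheSize (GB) converted to MB, and a load is rejected once the memory
// already used by cached segments plus the estimated size of the incoming
// segments crosses 50% of that budget.
func wouldOOM(cacheSizeGB, usedMB, segmentTotalMB int64) bool {
	totalRAMInMB := cacheSizeGB * 1024
	thresholdMemSize := float64(totalRAMInMB) * 0.5 // hard-coded factor, see TODO above
	return usedMB+segmentTotalMB > int64(thresholdMemSize)
}

func main() {
	// With the default 32 GB cache the threshold is 16384 MB.
	fmt.Println(wouldOOM(32, 12000, 3000)) // false: 15000 MB stays under 16384 MB
	fmt.Println(wouldOOM(32, 12000, 5000)) // true: 17000 MB would exceed the threshold
}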
@@ -23,7 +23,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
@@ -229,7 +228,7 @@ func TestSegmentLoader_CheckSegmentMemory(t *testing.T) {
})
t.Run("test OOM", func(t *testing.T) {
totalRAM := metricsinfo.GetMemoryCount()
totalRAM := Params.CacheSize * 1024 * 1024 * 1024
loader := genSegmentLoader()
col, err := loader.historicalReplica.getCollectionByID(collectionID)
@@ -239,7 +238,7 @@ func TestSegmentLoader_CheckSegmentMemory(t *testing.T) {
assert.NoError(t, err)
info := genSegmentLoadInfo()
info.NumOfRows = int64(totalRAM / uint64(sizePerRecord))
info.NumOfRows = totalRAM / int64(sizePerRecord)
err = loader.checkSegmentMemory([]*querypb.SegmentLoadInfo{info})
assert.Error(t, err)
})