modify diskann memory estimation method (#23892)

Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
pull/23970/head
cqy123456 2023-05-09 20:29:09 +08:00 committed by GitHub
parent 9b47d90c13
commit f7159189fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 24 deletions

View File

@ -31,8 +31,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/indexparams"
)
// LoadIndexInfo is a wrapper of the underlying C-structure C.CLoadIndexInfo
@ -72,12 +70,6 @@ func (li *LoadIndexInfo) appendLoadIndexInfo(bytesIndex [][]byte, indexInfo *que
// some build params also exist in indexParams; they are useless during the loading process
indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
if indexParams["index_type"] == indexparamcheck.IndexDISKANN {
err = indexparams.SetDiskIndexLoadParams(&Params, indexParams, indexInfo.GetNumRows())
if err != nil {
return err
}
}
for key, value := range indexParams {
err = li.appendIndexParam(key, value)

View File

@ -48,6 +48,7 @@ import (
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/hardware"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/indexparams"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
@ -56,7 +57,7 @@ import (
)
const (
UsedDiskMemoryRatio = 4
DiskANNCacheExpansionFactor = 1.5
)
var (
@ -88,6 +89,25 @@ func (loader *segmentLoader) getFieldType(segment *Segment, fieldID FieldID) (sc
return coll.getFieldType(fieldID)
}
// appendFieldIndexLoadParams derives the runtime DiskANN load parameters
// (e.g. the search cache budget computed by SetDiskIndexLoadParams) for every
// enabled DiskANN index in the request, and writes the updated key/value
// pairs back onto each FieldIndexInfo in place.
// Non-DiskANN indexes are left untouched (their params are round-tripped
// through the map unchanged). Returns the first error reported by
// SetDiskIndexLoadParams, leaving any already-processed infos updated.
func (loader *segmentLoader) appendFieldIndexLoadParams(req *querypb.LoadSegmentsRequest) error {
// set diskann load params
for _, loadInfo := range req.Infos {
for _, fieldIndexInfo := range loadInfo.IndexInfos {
if fieldIndexInfo.EnableIndex {
indexParams := funcutil.KeyValuePair2Map(fieldIndexInfo.IndexParams)
if indexParams["index_type"] == indexparamcheck.IndexDISKANN {
// NOTE(review): &Params is the process-global config — assumed safe
// to read concurrently here; confirm against other callers.
err := indexparams.SetDiskIndexLoadParams(&Params, indexParams, fieldIndexInfo.NumRows)
if err != nil {
return err
}
}
// persist the (possibly augmented) params back onto the request
fieldIndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams)
}
}
}
return nil
}
func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadSegmentsRequest, segmentType segmentType) ([]UniqueID, error) {
if req.Base == nil {
return nil, fmt.Errorf("nil base message when load segment, collectionID = %d", req.CollectionID)
@ -114,6 +134,13 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS
}
return minValue
}
err := loader.appendFieldIndexLoadParams(req)
if err != nil {
log.Error("Fail to append load parameters ", zap.Error(err))
return nil, err
}
concurrencyLevel := min(runtime.GOMAXPROCS(0), len(req.Infos))
for ; concurrencyLevel > 1; concurrencyLevel /= 2 {
err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
@ -122,7 +149,7 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS
}
}
err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
err = loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
if err != nil {
log.Error("load failed, OOM if loaded",
zap.Int64("loadSegmentRequest msgID", req.Base.MsgID),
@ -737,8 +764,28 @@ func GetStorageSizeByIndexInfo(indexInfo *querypb.FieldIndexInfo) (uint64, uint6
return 0, 0, fmt.Errorf("index type not exist in index params")
}
if indexType == indexparamcheck.IndexDISKANN {
neededMemSize := indexInfo.IndexSize / UsedDiskMemoryRatio
return uint64(neededMemSize), uint64(indexInfo.IndexSize), nil
chunkManagerBufferSize := uint64(64 * 1024 * 1024)
indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
PQCodeProportion, err := strconv.ParseFloat(indexParams[indexparams.PQCodeBudgetRatioKey], 64)
if err != nil {
return 0, 0, fmt.Errorf("%s not exist in index params", indexparams.PQCodeBudgetRatioKey)
}
searchCacheProportion, err := strconv.ParseFloat(indexParams[indexparams.SearchCacheBudgetKey], 64)
if err != nil {
return 0, 0, fmt.Errorf("%s not exist in index params", indexparams.SearchCacheBudgetKey)
}
dim, err := strconv.ParseInt(indexParams["dim"], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("dim not exist in index params")
}
rawDataSize := indexparams.GetRowDataSizeOfFloatVector(indexInfo.NumRows, dim)
// a disk index only keeps the PQ table, PQ codes, and the search cache in memory;
// searchCacheProportion is expressed in GB, so it needs to be converted to bytes
neededMemSize := uint64(float64(rawDataSize)*(PQCodeProportion) + searchCacheProportion*DiskANNCacheExpansionFactor*(1024*1024*1024))
if neededMemSize < chunkManagerBufferSize {
neededMemSize = chunkManagerBufferSize
}
return neededMemSize, uint64(indexInfo.IndexSize), nil
}
return uint64(indexInfo.IndexSize), 0, nil
@ -761,7 +808,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
}
usedMemAfterLoad := usedMem
maxSegmentSize := uint64(0)
maxSegmentMemUsgae := uint64(0)
localUsedSize, err := GetLocalUsedSize()
if err != nil {
@ -771,6 +818,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
for _, loadInfo := range segmentLoadInfos {
oldUsedMem := usedMemAfterLoad
diskIndexMemSize := uint64(0)
vecFieldID2IndexInfo := make(map[int64]*querypb.FieldIndexInfo)
for _, fieldIndexInfo := range loadInfo.IndexInfos {
if fieldIndexInfo.EnableIndex {
@ -789,8 +837,13 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
zap.Int64("indexBuildID", fieldIndexInfo.BuildID))
return err
}
usedMemAfterLoad += neededMemSize
usedLocalSizeAfterLoad += neededDiskSize
// DiskANN does not need to copy its data into memory
indexType, _ := funcutil.GetAttrByKeyFromRepeatedKV("index_type", fieldIndexInfo.IndexParams)
if indexType == indexparamcheck.IndexDISKANN {
diskIndexMemSize += neededMemSize
}
usedMemAfterLoad += neededMemSize
} else {
usedMemAfterLoad += uint64(funcutil.GetFieldSizeFromFieldBinlog(fieldBinlog))
}
@ -806,8 +859,13 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
usedMemAfterLoad += uint64(funcutil.GetFieldSizeFromFieldBinlog(fieldBinlog))
}
if usedMemAfterLoad-oldUsedMem > maxSegmentSize {
maxSegmentSize = usedMemAfterLoad - oldUsedMem
currentSegmentSize := usedMemAfterLoad - oldUsedMem
currentSegmentMemUsage := uint64(float64(currentSegmentSize) * Params.QueryNodeCfg.LoadMemoryUsageFactor)
if Params.QueryNodeCfg.LoadMemoryUsageFactor > 1 {
currentSegmentMemUsage -= uint64(float64(diskIndexMemSize) * (Params.QueryNodeCfg.LoadMemoryUsageFactor - 1))
}
if currentSegmentMemUsage > maxSegmentMemUsgae {
maxSegmentMemUsgae = currentSegmentMemUsage
}
}
@ -816,8 +874,8 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
}
// when loading a segment, data will be copied from Go memory to C++ memory
memLoadingUsage := usedMemAfterLoad + uint64(
float64(maxSegmentSize)*float64(concurrency)*Params.QueryNodeCfg.LoadMemoryUsageFactor)
memLoadingUsage := usedMemAfterLoad + uint64(float64(maxSegmentMemUsgae)*float64(concurrency))
log.Info("predict memory and disk usage while loading (in MiB)",
zap.Int64("collectionID", collectionID),
zap.Int("concurrency", concurrency),
@ -829,10 +887,10 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
zap.Uint64("currentTotalMemory", toMB(totalMem)))
if memLoadingUsage > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) {
return fmt.Errorf("%w, load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, currentAvailableMemCount = %v MB, currentTotalMem = %v MB, thresholdFactor = %f",
return fmt.Errorf("%w, load segment failed, OOM if load, collectionID = %d, MaxSegmentMemUsage = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, currentAvailableMemCount = %v MB, currentTotalMem = %v MB, thresholdFactor = %f",
ErrInsufficientMemory,
collectionID,
toMB(maxSegmentSize),
toMB(maxSegmentMemUsgae),
concurrency,
toMB(usedMemAfterLoad),
toMB(currentAvailableMemCount),

View File

@ -41,6 +41,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
)
func TestSegmentLoader_loadSegment(t *testing.T) {
@ -1122,3 +1123,54 @@ func (s *SegmentLoaderMockSuite) TestSkipBFLoad() {
func TestSegmentLoaderWithMock(t *testing.T) {
suite.Run(t, new(SegmentLoaderMockSuite))
}
// TestSegmentLoader_check_diskann_mem_usage verifies the DiskANN memory
// estimation path of checkSegmentSize: before appendFieldIndexLoadParams has
// derived the runtime load params, the size estimation must fail (the index
// params lack the derived search-cache budget); after appending them, the
// same check must succeed.
func TestSegmentLoader_check_diskann_mem_usage(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	node, generr := genSimpleQueryNode(ctx)
	assert.NoError(t, generr)

	// DiskANN build-time params only; derived load params (e.g. the search
	// cache budget) are intentionally absent at this point.
	indexParams := make(map[string]string)
	indexParams["index_type"] = indexparamcheck.IndexDISKANN
	indexParams["metric_type"] = "L2"
	indexParams["dim"] = "128"
	indexParams["pq_code_budget_gb_ratio"] = "0.125"

	schema := genTestCollectionSchema()
	fieldBinlog, statsLog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
	assert.NoError(t, err)

	indexInfo := &querypb.FieldIndexInfo{
		FieldID:        simpleFloatVecField.id,
		EnableIndex:    true,
		IndexName:      indexName,
		IndexID:        indexID,
		BuildID:        buildID,
		IndexParams:    funcutil.Map2KeyValuePair(indexParams),
		IndexFilePaths: []string{"diskindexpathtmp"},
	}

	req := &querypb.LoadSegmentsRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadSegments,
			MsgID:   rand.Int63(),
		},
		DstNodeID: 0,
		Schema:    schema,
		Infos: []*querypb.SegmentLoadInfo{
			{
				SegmentID:    defaultSegmentID,
				PartitionID:  defaultPartitionID,
				CollectionID: defaultCollectionID,
				BinlogPaths:  fieldBinlog,
				IndexInfos:   []*querypb.FieldIndexInfo{indexInfo},
				Statslogs:    statsLog,
				NumOfRows:    defaultMsgLength,
			},
		},
	}

	loader := node.loader
	concurrencyLevel := 1

	// Estimation fails while the derived load params are missing.
	err = loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
	assert.Error(t, err)

	// Fix: the error from appendFieldIndexLoadParams was previously ignored;
	// assert it so a failure here is reported at its source.
	err = loader.appendFieldIndexLoadParams(req)
	assert.NoError(t, err)

	err = loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
	assert.NoError(t, err)
}

View File

@ -45,7 +45,7 @@ const (
MaxBeamWidth = 16
)
func getRowDataSizeOfFloatVector(numRows int64, dim int64) int64 {
func GetRowDataSizeOfFloatVector(numRows int64, dim int64) int64 {
var floatValue float32
/* #nosec G103 */
return int64(unsafe.Sizeof(floatValue)) * dim * numRows
@ -136,7 +136,7 @@ func SetDiskIndexLoadParams(params *paramtable.ComponentParam, indexParams map[s
}
indexParams[SearchCacheBudgetKey] = fmt.Sprintf("%f",
float32(getRowDataSizeOfFloatVector(numRows, dim))*float32(searchCacheBudgetGBRatio)/(1<<30))
float32(GetRowDataSizeOfFloatVector(numRows, dim))*float32(searchCacheBudgetGBRatio)/(1<<30))
numLoadThread := int(float32(hardware.GetCPUNum()) * float32(loadNumThreadRatio))
if numLoadThread > MaxLoadThread {

View File

@ -121,7 +121,7 @@ func TestDiskIndexParams(t *testing.T) {
searchCacheBudget, ok := indexParams[SearchCacheBudgetKey]
assert.True(t, ok)
assert.Equal(t, fmt.Sprintf("%f", float32(getRowDataSizeOfFloatVector(100, 128))*float32(extraParams.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)
assert.Equal(t, fmt.Sprintf("%f", float32(GetRowDataSizeOfFloatVector(100, 128))*float32(extraParams.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)
numLoadThread, ok := indexParams[NumLoadThreadKey]
assert.True(t, ok)
@ -148,7 +148,7 @@ func TestDiskIndexParams(t *testing.T) {
searchCacheBudget, ok = indexParams[SearchCacheBudgetKey]
assert.True(t, ok)
assert.Equal(t, fmt.Sprintf("%f", float32(getRowDataSizeOfFloatVector(100, 128))*float32(params.CommonCfg.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)
assert.Equal(t, fmt.Sprintf("%f", float32(GetRowDataSizeOfFloatVector(100, 128))*float32(params.CommonCfg.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)
numLoadThread, ok = indexParams[NumLoadThreadKey]
assert.True(t, ok)