mirror of https://github.com/milvus-io/milvus.git
modify diskann memory estimation method (#23892)
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
pull/23970/head
parent 9b47d90c13
commit f7159189fd
@@ -31,8 +31,6 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/util/funcutil"
-	"github.com/milvus-io/milvus/internal/util/indexparamcheck"
-	"github.com/milvus-io/milvus/internal/util/indexparams"
 )

 // LoadIndexInfo is a wrapper of the underlying C-structure C.CLoadIndexInfo
@@ -72,12 +70,6 @@ func (li *LoadIndexInfo) appendLoadIndexInfo(bytesIndex [][]byte, indexInfo *que

 	// some build params also exist in indexParams, which are useless during loading process
 	indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
-	if indexParams["index_type"] == indexparamcheck.IndexDISKANN {
-		err = indexparams.SetDiskIndexLoadParams(&Params, indexParams, indexInfo.GetNumRows())
-		if err != nil {
-			return err
-		}
-	}

 	for key, value := range indexParams {
 		err = li.appendIndexParam(key, value)

@@ -48,6 +48,7 @@ import (
 	"github.com/milvus-io/milvus/internal/util/funcutil"
 	"github.com/milvus-io/milvus/internal/util/hardware"
 	"github.com/milvus-io/milvus/internal/util/indexparamcheck"
+	"github.com/milvus-io/milvus/internal/util/indexparams"
 	"github.com/milvus-io/milvus/internal/util/timerecord"
 	"github.com/milvus-io/milvus/internal/util/tsoutil"
 	"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -56,7 +57,7 @@ import (
 )

 const (
 	UsedDiskMemoryRatio = 4
+	DiskANNCacheExpansionFactor = 1.5
 )

 var (
@@ -88,6 +89,25 @@ func (loader *segmentLoader) getFieldType(segment *Segment, fieldID FieldID) (sc
 	return coll.getFieldType(fieldID)
 }

+func (loader *segmentLoader) appendFieldIndexLoadParams(req *querypb.LoadSegmentsRequest) error {
+	// set diskann load params
+	for _, loadInfo := range req.Infos {
+		for _, fieldIndexInfo := range loadInfo.IndexInfos {
+			if fieldIndexInfo.EnableIndex {
+				indexParams := funcutil.KeyValuePair2Map(fieldIndexInfo.IndexParams)
+				if indexParams["index_type"] == indexparamcheck.IndexDISKANN {
+					err := indexparams.SetDiskIndexLoadParams(&Params, indexParams, fieldIndexInfo.NumRows)
+					if err != nil {
+						return err
+					}
+				}
+				fieldIndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams)
+			}
+		}
+	}
+	return nil
+}
+
 func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadSegmentsRequest, segmentType segmentType) ([]UniqueID, error) {
 	if req.Base == nil {
 		return nil, fmt.Errorf("nil base message when load segment, collectionID = %d", req.CollectionID)
@@ -114,6 +134,13 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS
 		}
 		return minValue
 	}

+	err := loader.appendFieldIndexLoadParams(req)
+	if err != nil {
+		log.Error("Fail to append load parameters ", zap.Error(err))
+		return nil, err
+	}
+
 	concurrencyLevel := min(runtime.GOMAXPROCS(0), len(req.Infos))
 	for ; concurrencyLevel > 1; concurrencyLevel /= 2 {
 		err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
@@ -122,7 +149,7 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS
 		}
 	}

-	err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
+	err = loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
 	if err != nil {
 		log.Error("load failed, OOM if loaded",
 			zap.Int64("loadSegmentRequest msgID", req.Base.MsgID),
@@ -737,8 +764,28 @@ func GetStorageSizeByIndexInfo(indexInfo *querypb.FieldIndexInfo) (uint64, uint6
 		return 0, 0, fmt.Errorf("index type not exist in index params")
 	}
 	if indexType == indexparamcheck.IndexDISKANN {
-		neededMemSize := indexInfo.IndexSize / UsedDiskMemoryRatio
-		return uint64(neededMemSize), uint64(indexInfo.IndexSize), nil
+		chunkManagerBufferSize := uint64(64 * 1024 * 1024)
+		indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
+		PQCodeProportion, err := strconv.ParseFloat(indexParams[indexparams.PQCodeBudgetRatioKey], 64)
+		if err != nil {
+			return 0, 0, fmt.Errorf("%s not exist in index params", indexparams.PQCodeBudgetRatioKey)
+		}
+		searchCacheProportion, err := strconv.ParseFloat(indexParams[indexparams.SearchCacheBudgetKey], 64)
+		if err != nil {
+			return 0, 0, fmt.Errorf("%s not exist in index params", indexparams.SearchCacheBudgetKey)
+		}
+		dim, err := strconv.ParseInt(indexParams["dim"], 10, 64)
+		if err != nil {
+			return 0, 0, fmt.Errorf("dim not exist in index params")
+		}
+		rawDataSize := indexparams.GetRowDataSizeOfFloatVector(indexInfo.NumRows, dim)
+		// the disk index only keeps the pq table, pq code, and search cache in memory;
+		// searchCacheProportion is expressed in GB and must be converted to bytes
+		neededMemSize := uint64(float64(rawDataSize)*PQCodeProportion + searchCacheProportion*DiskANNCacheExpansionFactor*(1024*1024*1024))
+		if neededMemSize < chunkManagerBufferSize {
+			neededMemSize = chunkManagerBufferSize
+		}
+		return neededMemSize, uint64(indexInfo.IndexSize), nil
 	}

 	return uint64(indexInfo.IndexSize), 0, nil
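To see what the new estimate above returns in practice, here is a back-of-the-envelope run with assumed inputs (1M rows, dim 128, pq_code_budget_gb_ratio 0.125, a 0.05 GB search cache budget); the old path would instead have reported IndexSize / UsedDiskMemoryRatio regardless of these parameters.

package main

import "fmt"

func main() {
	const (
		numRows                = 1_000_000
		dim                    = 128
		pqCodeProportion       = 0.125 // pq_code_budget_gb_ratio
		searchCacheGB          = 0.05  // search_cache_budget_gb, already in GB
		cacheExpansionFactor   = 1.5   // DiskANNCacheExpansionFactor
		chunkManagerBufferSize = 64 * 1024 * 1024
	)
	rawDataSize := float64(4 * dim * numRows) // float32 vectors
	needed := rawDataSize*pqCodeProportion + searchCacheGB*cacheExpansionFactor*(1024*1024*1024)
	if needed < chunkManagerBufferSize {
		needed = chunkManagerBufferSize // never below the chunk manager buffer
	}
	fmt.Printf("%.1f MiB\n", needed/(1024*1024)) // 137.8 MiB
}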
@@ -761,7 +808,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
 	}

 	usedMemAfterLoad := usedMem
-	maxSegmentSize := uint64(0)
+	maxSegmentMemUsgae := uint64(0)

 	localUsedSize, err := GetLocalUsedSize()
 	if err != nil {
@@ -771,6 +818,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad

 	for _, loadInfo := range segmentLoadInfos {
 		oldUsedMem := usedMemAfterLoad
+		diskIndexMemSize := uint64(0)
 		vecFieldID2IndexInfo := make(map[int64]*querypb.FieldIndexInfo)
 		for _, fieldIndexInfo := range loadInfo.IndexInfos {
 			if fieldIndexInfo.EnableIndex {
@@ -789,8 +837,13 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
 					zap.Int64("indexBuildID", fieldIndexInfo.BuildID))
 				return err
 			}
-			usedMemAfterLoad += neededMemSize
 			usedLocalSizeAfterLoad += neededDiskSize
+			// diskann does not need to copy data
+			indexType, _ := funcutil.GetAttrByKeyFromRepeatedKV("index_type", fieldIndexInfo.IndexParams)
+			if indexType == indexparamcheck.IndexDISKANN {
+				diskIndexMemSize += neededMemSize
+			}
+			usedMemAfterLoad += neededMemSize
 		} else {
 			usedMemAfterLoad += uint64(funcutil.GetFieldSizeFromFieldBinlog(fieldBinlog))
 		}
@ -806,8 +859,13 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
|
|||
usedMemAfterLoad += uint64(funcutil.GetFieldSizeFromFieldBinlog(fieldBinlog))
|
||||
}
|
||||
|
||||
if usedMemAfterLoad-oldUsedMem > maxSegmentSize {
|
||||
maxSegmentSize = usedMemAfterLoad - oldUsedMem
|
||||
currentSegmentSize := usedMemAfterLoad - oldUsedMem
|
||||
currentSegmentMemUsage := uint64(float64(currentSegmentSize) * Params.QueryNodeCfg.LoadMemoryUsageFactor)
|
||||
if Params.QueryNodeCfg.LoadMemoryUsageFactor > 1 {
|
||||
currentSegmentMemUsage -= uint64(float64(diskIndexMemSize) * (Params.QueryNodeCfg.LoadMemoryUsageFactor - 1))
|
||||
}
|
||||
if currentSegmentMemUsage > maxSegmentMemUsgae {
|
||||
maxSegmentMemUsgae = currentSegmentMemUsage
|
||||
}
|
||||
}
|
||||
|
||||
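The subtraction above is the heart of the new estimate: LoadMemoryUsageFactor over-provisions for the transient go-to-c++ copy during load, but the DiskANN share of a segment is never copied, so only the non-disk-index bytes pay the factor. A quick numeric illustration, with all values assumed:

package main

import "fmt"

func main() {
	const (
		loadMemoryUsageFactor = 2.0   // assumed Params.QueryNodeCfg.LoadMemoryUsageFactor
		currentSegmentSize    = 300.0 // MiB estimated for one segment
		diskIndexMemSize      = 200.0 // MiB of that held by the DiskANN index
	)
	usage := currentSegmentSize * loadMemoryUsageFactor
	if loadMemoryUsageFactor > 1 {
		usage -= diskIndexMemSize * (loadMemoryUsageFactor - 1)
	}
	// 400 MiB budgeted, versus 600 MiB under the old uniform scaling
	fmt.Printf("%.0f MiB\n", usage)
}

The total-budget hunk below then multiplies this per-segment figure by concurrency alone, since the factor is already baked into maxSegmentMemUsgae.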
@@ -816,8 +874,8 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
 	}

 	// when loading a segment, data will be copied from go memory to c++ memory
-	memLoadingUsage := usedMemAfterLoad + uint64(
-		float64(maxSegmentSize)*float64(concurrency)*Params.QueryNodeCfg.LoadMemoryUsageFactor)
+	memLoadingUsage := usedMemAfterLoad + uint64(float64(maxSegmentMemUsgae)*float64(concurrency))

 	log.Info("predict memory and disk usage while loading (in MiB)",
 		zap.Int64("collectionID", collectionID),
 		zap.Int("concurrency", concurrency),
@@ -829,10 +887,10 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
 		zap.Uint64("currentTotalMemory", toMB(totalMem)))

 	if memLoadingUsage > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) {
-		return fmt.Errorf("%w, load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, currentAvailableMemCount = %v MB, currentTotalMem = %v MB, thresholdFactor = %f",
+		return fmt.Errorf("%w, load segment failed, OOM if load, collectionID = %d, MaxSegmentMemUsage = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, currentAvailableMemCount = %v MB, currentTotalMem = %v MB, thresholdFactor = %f",
 			ErrInsufficientMemory,
 			collectionID,
-			toMB(maxSegmentSize),
+			toMB(maxSegmentMemUsgae),
 			concurrency,
 			toMB(usedMemAfterLoad),
 			toMB(currentAvailableMemCount),

@@ -41,6 +41,7 @@ import (
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/concurrency"
 	"github.com/milvus-io/milvus/internal/util/funcutil"
+	"github.com/milvus-io/milvus/internal/util/indexparamcheck"
 )

 func TestSegmentLoader_loadSegment(t *testing.T) {
@@ -1122,3 +1123,54 @@ func (s *SegmentLoaderMockSuite) TestSkipBFLoad() {
 func TestSegmentLoaderWithMock(t *testing.T) {
 	suite.Run(t, new(SegmentLoaderMockSuite))
 }
+
+func TestSegmentLoader_check_diskann_mem_usage(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	node, generr := genSimpleQueryNode(ctx)
+	assert.NoError(t, generr)
+	indexParams := make(map[string]string)
+	indexParams["index_type"] = indexparamcheck.IndexDISKANN
+	indexParams["metric_type"] = "L2"
+	indexParams["dim"] = "128"
+	indexParams["pq_code_budget_gb_ratio"] = "0.125"
+	schema := genTestCollectionSchema()
+	fieldBinlog, statsLog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
+	assert.NoError(t, err)
+	indexInfo := &querypb.FieldIndexInfo{
+		FieldID:        simpleFloatVecField.id,
+		EnableIndex:    true,
+		IndexName:      indexName,
+		IndexID:        indexID,
+		BuildID:        buildID,
+		IndexParams:    funcutil.Map2KeyValuePair(indexParams),
+		IndexFilePaths: []string{"diskindexpathtmp"},
+	}
+	req := &querypb.LoadSegmentsRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_LoadSegments,
+			MsgID:   rand.Int63(),
+		},
+		DstNodeID: 0,
+		Schema:    schema,
+		Infos: []*querypb.SegmentLoadInfo{
+			{
+				SegmentID:    defaultSegmentID,
+				PartitionID:  defaultPartitionID,
+				CollectionID: defaultCollectionID,
+				BinlogPaths:  fieldBinlog,
+				IndexInfos:   []*querypb.FieldIndexInfo{indexInfo},
+				Statslogs:    statsLog,
+				NumOfRows:    defaultMsgLength,
+			},
+		},
+	}
+	loader := node.loader
+	concurrencyLevel := 1
+	err = loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
+	assert.Error(t, err)
+
+	loader.appendFieldIndexLoadParams(req)
+	err = loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
+	assert.NoError(t, err)
+}
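The error-then-success pair in this test follows directly from the hunks above: the request initially carries only dim and pq_code_budget_gb_ratio, so the DiskANN branch of GetStorageSizeByIndexInfo fails to parse the search cache budget and checkSegmentSize returns an error; appendFieldIndexLoadParams then fills the missing key and the same check passes. A sketch of that precondition, assuming "search_cache_budget_gb" and "DISKANN" as the literal values of indexparams.SearchCacheBudgetKey and indexparamcheck.IndexDISKANN, which this diff never prints:

package main

import "fmt"

func main() {
	// indexParams as the test builds them, before appendFieldIndexLoadParams
	indexParams := map[string]string{
		"index_type":              "DISKANN",
		"metric_type":             "L2",
		"dim":                     "128",
		"pq_code_budget_gb_ratio": "0.125",
	}
	for _, key := range []string{"pq_code_budget_gb_ratio", "search_cache_budget_gb", "dim"} {
		if _, ok := indexParams[key]; !ok {
			fmt.Printf("%s not exist in index params\n", key) // the first check fails here
		}
	}
}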
@@ -45,7 +45,7 @@ const (
 	MaxBeamWidth = 16
 )

-func getRowDataSizeOfFloatVector(numRows int64, dim int64) int64 {
+func GetRowDataSizeOfFloatVector(numRows int64, dim int64) int64 {
 	var floatValue float32
 	/* #nosec G103 */
 	return int64(unsafe.Sizeof(floatValue)) * dim * numRows
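Exporting this helper matters because GetStorageSizeByIndexInfo in segment_loader.go now calls it across package boundaries; the rename changes no behavior. Its output is simply sizeof(float32) * dim * numRows, which for the 100-row, dim-128 case asserted in the tests below comes to 51,200 bytes:

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	var f float32
	numRows, dim := int64(100), int64(128)
	// same computation as GetRowDataSizeOfFloatVector(100, 128)
	fmt.Println(int64(unsafe.Sizeof(f)) * dim * numRows) // 51200
}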
@@ -136,7 +136,7 @@ func SetDiskIndexLoadParams(params *paramtable.ComponentParam, indexParams map[s
 	}

 	indexParams[SearchCacheBudgetKey] = fmt.Sprintf("%f",
-		float32(getRowDataSizeOfFloatVector(numRows, dim))*float32(searchCacheBudgetGBRatio)/(1<<30))
+		float32(GetRowDataSizeOfFloatVector(numRows, dim))*float32(searchCacheBudgetGBRatio)/(1<<30))

 	numLoadThread := int(float32(hardware.GetCPUNum()) * float32(loadNumThreadRatio))
 	if numLoadThread > MaxLoadThread {

@@ -121,7 +121,7 @@ func TestDiskIndexParams(t *testing.T) {

 	searchCacheBudget, ok := indexParams[SearchCacheBudgetKey]
 	assert.True(t, ok)
-	assert.Equal(t, fmt.Sprintf("%f", float32(getRowDataSizeOfFloatVector(100, 128))*float32(extraParams.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)
+	assert.Equal(t, fmt.Sprintf("%f", float32(GetRowDataSizeOfFloatVector(100, 128))*float32(extraParams.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)

 	numLoadThread, ok := indexParams[NumLoadThreadKey]
 	assert.True(t, ok)
@@ -148,7 +148,7 @@ func TestDiskIndexParams(t *testing.T) {

 	searchCacheBudget, ok = indexParams[SearchCacheBudgetKey]
 	assert.True(t, ok)
-	assert.Equal(t, fmt.Sprintf("%f", float32(getRowDataSizeOfFloatVector(100, 128))*float32(params.CommonCfg.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)
+	assert.Equal(t, fmt.Sprintf("%f", float32(GetRowDataSizeOfFloatVector(100, 128))*float32(params.CommonCfg.SearchCacheBudgetGBRatio)/(1<<30)), searchCacheBudget)

 	numLoadThread, ok = indexParams[NumLoadThreadKey]
 	assert.True(t, ok)