fix: Fix checkGeneralCapacity slowly (#37976)

Cache the general count to speed up `checkGeneralCapacity`.

issue: https://github.com/milvus-io/milvus/issues/37630

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/35835/merge
yihao.dai 2024-12-01 18:28:38 +08:00 committed by GitHub
parent 2319018fcb
commit 7fb0c281f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 76 additions and 59 deletions

View File

@ -20,7 +20,6 @@ import (
"context"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
const (
@ -35,7 +34,6 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
newParNum int64,
newShardNum int32,
core *Core,
ts typeutil.Timestamp,
) error {
var addedNum int64 = 0
if newColNum > 0 && newParNum > 0 && newShardNum > 0 {
@ -46,25 +44,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
addedNum += newParNum
}
var generalNum int64 = 0
collectionsMap := core.meta.ListAllAvailCollections(ctx)
for dbId, collectionIDs := range collectionsMap {
db, err := core.meta.GetDatabaseByID(ctx, dbId, ts)
if err == nil {
for _, collectionId := range collectionIDs {
collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true)
if err == nil {
partNum := int64(collection.GetPartitionNum(false))
shardNum := int64(collection.ShardsNum)
generalNum += partNum * shardNum
}
}
}
}
generalNum += addedNum
if generalNum > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64() {
return merr.WrapGeneralCapacityExceed(generalNum, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(),
generalCount := core.meta.GetGeneralCount(ctx)
generalCount += int(addedNum)
if generalCount > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt() {
return merr.WrapGeneralCapacityExceed(generalCount, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(),
"failed checking constraint: sum_collections(parition*shard) exceeding the max general capacity:")
}
return nil

View File

@ -107,7 +107,7 @@ func (t *createCollectionTask) validate() error {
if t.Req.GetNumPartitions() > 0 {
newPartNum = t.Req.GetNumPartitions()
}
return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core, t.ts)
return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core)
}
// checkMaxCollectionsPerDB DB properties take precedence over quota configurations for max collections.

View File

@ -246,23 +246,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{1: {1, 2}})
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
Return(&model.Database{Name: "db1"}, nil).Once()
meta.On("GetDatabaseByID",
mock.Anything, mock.Anything, mock.Anything,
).Return(&model.Database{
Name: "default",
}, nil)
meta.On("GetCollectionByID",
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&model.Collection{
Name: "default",
ShardsNum: 2,
Partitions: []*model.Partition{
{
PartitionID: 1,
},
},
}, nil)
meta.EXPECT().GetGeneralCount(mock.Anything).Return(1)
core := newTestCore(withMeta(meta))
@ -295,8 +279,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
},
},
}, nil).Once()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
core := newTestCore(withMeta(meta))
task := createCollectionTask{
@ -734,6 +717,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}, nil)
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
@ -754,8 +738,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
})
t.Run("invalid schema", func(t *testing.T) {
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
core := newTestCore(withMeta(meta))
collectionName := funcutil.GenRandomStr()
task := &createCollectionTask{
@ -784,8 +766,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
}
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
core := newTestCore(withInvalidIDAllocator(), withMeta(meta))
task := createCollectionTask{
@ -808,8 +788,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
field1 := funcutil.GenRandomStr()
ticker := newRocksMqTtSynchronizer()
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))
@ -1160,8 +1138,7 @@ func Test_createCollectionTask_PartitionKey(t *testing.T) {
).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}, nil)
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)

View File

@ -46,7 +46,7 @@ func (t *createPartitionTask) Prepare(ctx context.Context) error {
return err
}
t.collMeta = collMeta
return checkGeneralCapacity(ctx, 0, 1, 0, t.core, t.ts)
return checkGeneralCapacity(ctx, 0, 1, 0, t.core)
}
func (t *createPartitionTask) Execute(ctx context.Context) error {

View File

@ -20,7 +20,6 @@ import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -62,14 +61,7 @@ func Test_createPartitionTask_Prepare(t *testing.T) {
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
1: {1, 2},
}, nil)
meta.On("GetDatabaseByID",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil, errors.New("mock"))
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
core := newTestCore(withMeta(meta))
task := &createPartitionTask{

View File

@ -74,6 +74,7 @@ type IMetaTable interface {
ListAliases(ctx context.Context, dbName string, collectionName string, ts Timestamp) ([]string, error)
AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error
RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error
GetGeneralCount(ctx context.Context) int
// TODO: it'll be a big cost if we handle the time travel logic, since we should always list all aliases in catalog.
IsAlias(ctx context.Context, db, name string) bool
@ -115,6 +116,8 @@ type MetaTable struct {
dbName2Meta map[string]*model.Database // database name -> db meta
collID2Meta map[typeutil.UniqueID]*model.Collection // collection id -> collection meta
generalCnt int // sum of product of partition number and shard number
// collections *collectionDb
names *nameDb
aliases *nameDb
@ -189,6 +192,7 @@ func (mt *MetaTable) reload() error {
}
for _, collection := range collections {
mt.collID2Meta[collection.CollectionID] = collection
mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum)
if collection.Available() {
mt.names.insert(dbName, collection.Name, collection.CollectionID)
collectionNum++
@ -417,6 +421,8 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection)
mt.collID2Meta[coll.CollectionID] = coll.Clone()
mt.names.insert(db.Name, coll.Name, coll.CollectionID)
mt.generalCnt += len(coll.Partitions) * int(coll.ShardsNum)
log.Ctx(ctx).Info("add collection to meta table",
zap.Int64("dbID", coll.DBID),
zap.String("collection", coll.Name),
@ -521,6 +527,8 @@ func (mt *MetaTable) RemoveCollection(ctx context.Context, collectionID UniqueID
mt.removeAllNamesIfMatchedInternal(collectionID, allNames)
mt.removeCollectionByIDInternal(collectionID)
mt.generalCnt -= len(coll.Partitions) * int(coll.ShardsNum)
log.Ctx(ctx).Info("remove collection",
zap.Int64("dbID", coll.DBID),
zap.String("name", coll.Name),
@ -895,6 +903,8 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
}
mt.collID2Meta[partition.CollectionID].Partitions = append(mt.collID2Meta[partition.CollectionID].Partitions, partition.Clone())
mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum
metrics.RootCoordNumOfPartitions.WithLabelValues().Inc()
log.Ctx(ctx).Info("add partition to meta table",
@ -961,6 +971,7 @@ func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collection
}
if loc != -1 {
coll.Partitions = append(coll.Partitions[:loc], coll.Partitions[loc+1:]...)
mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum
}
log.Info("remove partition", zap.Int64("collection", collectionID), zap.Int64("partition", partitionID), zap.Uint64("ts", ts))
return nil
@ -1229,6 +1240,14 @@ func (mt *MetaTable) ListAliasesByID(ctx context.Context, collID UniqueID) []str
return mt.listAliasesByID(collID)
}
// GetGeneralCount gets the general count(sum of product of partition number and shard number).
func (mt *MetaTable) GetGeneralCount(ctx context.Context) int {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()
return mt.generalCnt
}
// AddCredential add credential
func (mt *MetaTable) AddCredential(ctx context.Context, credInfo *internalpb.CredentialInfo) error {
if credInfo.Username == "" {

View File

@ -1473,6 +1473,52 @@ func (_c *IMetaTable_GetDatabaseByName_Call) RunAndReturn(run func(context.Conte
return _c
}
// GetGeneralCount provides a mock function with given fields: ctx
func (_m *IMetaTable) GetGeneralCount(ctx context.Context) int {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetGeneralCount")
}
var r0 int
if rf, ok := ret.Get(0).(func(context.Context) int); ok {
r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// IMetaTable_GetGeneralCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGeneralCount'
type IMetaTable_GetGeneralCount_Call struct {
*mock.Call
}
// GetGeneralCount is a helper method to define mock.On call
// - ctx context.Context
func (_e *IMetaTable_Expecter) GetGeneralCount(ctx interface{}) *IMetaTable_GetGeneralCount_Call {
return &IMetaTable_GetGeneralCount_Call{Call: _e.mock.On("GetGeneralCount", ctx)}
}
func (_c *IMetaTable_GetGeneralCount_Call) Run(run func(ctx context.Context)) *IMetaTable_GetGeneralCount_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *IMetaTable_GetGeneralCount_Call) Return(_a0 int) *IMetaTable_GetGeneralCount_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_GetGeneralCount_Call) RunAndReturn(run func(context.Context) int) *IMetaTable_GetGeneralCount_Call {
_c.Call.Return(run)
return _c
}
// GetPChannelInfo provides a mock function with given fields: ctx, pchannel
func (_m *IMetaTable) GetPChannelInfo(ctx context.Context, pchannel string) *rootcoordpb.GetPChannelInfoResponse {
ret := _m.Called(ctx, pchannel)