fix: [10kcp] Fix checkGeneralCapacity slowly (#37981)

Cache the general count to speed up checkGeneralCapacity.

issue: https://github.com/milvus-io/milvus/issues/37630

pr: https://github.com/milvus-io/milvus/pull/37976

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/38058/head
yihao.dai 2024-11-25 14:50:24 +08:00 committed by GitHub
parent fd30034c77
commit e5c16e0676
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 72 additions and 83 deletions

View File

@ -20,7 +20,6 @@ import (
"context"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
const (
@ -35,7 +34,6 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
newParNum int64,
newShardNum int32,
core *Core,
ts typeutil.Timestamp,
) error {
var addedNum int64 = 0
if newColNum > 0 && newParNum > 0 && newShardNum > 0 {
@ -46,25 +44,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
addedNum += newParNum
}
var generalNum int64 = 0
collectionsMap := core.meta.ListAllAvailCollections(ctx)
for dbId, collectionIDs := range collectionsMap {
db, err := core.meta.GetDatabaseByID(ctx, dbId, ts)
if err == nil {
for _, collectionId := range collectionIDs {
collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true)
if err == nil {
partNum := int64(collection.GetPartitionNum(false))
shardNum := int64(collection.ShardsNum)
generalNum += partNum * shardNum
}
}
}
}
generalNum += addedNum
if generalNum > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64() {
return merr.WrapGeneralCapacityExceed(generalNum, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(),
generalCount := core.meta.GetGeneralCount(ctx)
generalCount += int(addedNum)
if generalCount > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt() {
return merr.WrapGeneralCapacityExceed(generalCount, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(),
"failed checking constraint: sum_collections(parition*shard) exceeding the max general capacity:")
}
return nil

View File

@ -103,7 +103,7 @@ func (t *createCollectionTask) validate() error {
if t.Req.GetNumPartitions() > 0 {
newPartNum = t.Req.GetNumPartitions()
}
return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core, t.ts)
return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core)
}
// checkMaxCollectionsPerDB DB properties take precedence over quota configurations for max collections.

View File

@ -246,23 +246,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{1: {1, 2}})
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
Return(&model.Database{Name: "db1"}, nil).Once()
meta.On("GetDatabaseByID",
mock.Anything, mock.Anything, mock.Anything,
).Return(&model.Database{
Name: "default",
}, nil)
meta.On("GetCollectionByID",
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&model.Collection{
Name: "default",
ShardsNum: 2,
Partitions: []*model.Partition{
{
PartitionID: 1,
},
},
}, nil)
meta.EXPECT().GetGeneralCount(mock.Anything).Return(1)
core := newTestCore(withMeta(meta))
@ -295,8 +279,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
},
},
}, nil).Once()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
core := newTestCore(withMeta(meta))
task := createCollectionTask{
@ -642,6 +625,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}, nil)
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
@ -662,8 +646,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
})
t.Run("invalid schema", func(t *testing.T) {
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
core := newTestCore(withMeta(meta))
collectionName := funcutil.GenRandomStr()
task := &createCollectionTask{
@ -692,8 +674,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
}
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
core := newTestCore(withInvalidIDAllocator(), withMeta(meta))
task := createCollectionTask{
@ -716,8 +696,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
field1 := funcutil.GenRandomStr()
ticker := newRocksMqTtSynchronizer()
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))
@ -1056,8 +1034,7 @@ func Test_createCollectionTask_PartitionKey(t *testing.T) {
).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}, nil)
meta.On("GetDatabaseByID", mock.Anything,
mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)

View File

@ -44,7 +44,7 @@ func (t *createPartitionTask) Prepare(ctx context.Context) error {
return err
}
t.collMeta = collMeta
return checkGeneralCapacity(ctx, 0, 1, 0, t.core, t.ts)
return checkGeneralCapacity(ctx, 0, 1, 0, t.core)
}
func (t *createPartitionTask) Execute(ctx context.Context) error {

View File

@ -20,7 +20,6 @@ import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -62,14 +61,7 @@ func Test_createPartitionTask_Prepare(t *testing.T) {
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
1: {1, 2},
}, nil)
meta.On("GetDatabaseByID",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil, errors.New("mock"))
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
core := newTestCore(withMeta(meta))
task := &createPartitionTask{

View File

@ -72,6 +72,7 @@ type IMetaTable interface {
ListAliases(ctx context.Context, dbName string, collectionName string, ts Timestamp) ([]string, error)
AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error
RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error
GetGeneralCount(ctx context.Context) int
// TODO: it'll be a big cost if we handle the time travel logic, since we should always list all aliases in catalog.
IsAlias(db, name string) bool
@ -114,6 +115,8 @@ type MetaTable struct {
dbName2Meta map[string]*model.Database // database name -> db meta
collID2Meta map[typeutil.UniqueID]*model.Collection // collection id -> collection meta
generalCnt int // sum of product of partition number and shard number
// collections *collectionDb
names *nameDb
aliases *nameDb
@ -187,6 +190,7 @@ func (mt *MetaTable) reload() error {
}
for _, collection := range collections {
mt.collID2Meta[collection.CollectionID] = collection
mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum)
if collection.Available() {
mt.names.insert(dbName, collection.Name, collection.CollectionID)
collectionNum++
@ -409,6 +413,8 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection)
mt.collID2Meta[coll.CollectionID] = coll.Clone()
mt.names.insert(db.Name, coll.Name, coll.CollectionID)
mt.generalCnt += len(coll.Partitions) * int(coll.ShardsNum)
log.Ctx(ctx).Info("add collection to meta table",
zap.Int64("dbID", coll.DBID),
zap.String("collection", coll.Name),
@ -513,6 +519,8 @@ func (mt *MetaTable) RemoveCollection(ctx context.Context, collectionID UniqueID
mt.removeAllNamesIfMatchedInternal(collectionID, allNames)
mt.removeCollectionByIDInternal(collectionID)
mt.generalCnt -= len(coll.Partitions) * int(coll.ShardsNum)
log.Ctx(ctx).Info("remove collection",
zap.Int64("dbID", coll.DBID),
zap.String("name", coll.Name),
@ -738,6 +746,14 @@ func (mt *MetaTable) ListCollectionPhysicalChannels() map[typeutil.UniqueID][]st
return chanMap
}
// GetGeneralCount gets the general count(sum of product of partition number and shard number).
func (mt *MetaTable) GetGeneralCount(ctx context.Context) int {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()
return mt.generalCnt
}
func (mt *MetaTable) AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error {
mt.ddLock.Lock()
defer mt.ddLock.Unlock()
@ -861,6 +877,8 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
}
mt.collID2Meta[partition.CollectionID].Partitions = append(mt.collID2Meta[partition.CollectionID].Partitions, partition.Clone())
mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum
metrics.RootCoordNumOfPartitions.WithLabelValues().Inc()
log.Ctx(ctx).Info("add partition to meta table",
@ -927,6 +945,7 @@ func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collection
}
if loc != -1 {
coll.Partitions = append(coll.Partitions[:loc], coll.Partitions[loc+1:]...)
mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum
}
log.Info("remove partition", zap.Int64("collection", collectionID), zap.Int64("partition", partitionID), zap.Uint64("ts", ts))
return nil

View File

@ -574,10 +574,6 @@ func (_c *IMetaTable_CreateDatabase_Call) RunAndReturn(run func(context.Context,
func (_m *IMetaTable) CreatePrivilegeGroup(groupName string) error {
ret := _m.Called(groupName)
if len(ret) == 0 {
panic("no return value specified for CreatePrivilegeGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(groupName)
@ -892,10 +888,6 @@ func (_c *IMetaTable_DropGrant_Call) RunAndReturn(run func(string, *milvuspb.Rol
func (_m *IMetaTable) DropPrivilegeGroup(groupName string) error {
ret := _m.Called(groupName)
if len(ret) == 0 {
panic("no return value specified for DropPrivilegeGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(groupName)
@ -1357,14 +1349,52 @@ func (_c *IMetaTable_GetDatabaseByName_Call) RunAndReturn(run func(context.Conte
return _c
}
// GetGeneralCount provides a mock function with given fields: ctx
func (_m *IMetaTable) GetGeneralCount(ctx context.Context) int {
ret := _m.Called(ctx)
var r0 int
if rf, ok := ret.Get(0).(func(context.Context) int); ok {
r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// IMetaTable_GetGeneralCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGeneralCount'
type IMetaTable_GetGeneralCount_Call struct {
*mock.Call
}
// GetGeneralCount is a helper method to define mock.On call
// - ctx context.Context
func (_e *IMetaTable_Expecter) GetGeneralCount(ctx interface{}) *IMetaTable_GetGeneralCount_Call {
return &IMetaTable_GetGeneralCount_Call{Call: _e.mock.On("GetGeneralCount", ctx)}
}
func (_c *IMetaTable_GetGeneralCount_Call) Run(run func(ctx context.Context)) *IMetaTable_GetGeneralCount_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *IMetaTable_GetGeneralCount_Call) Return(_a0 int) *IMetaTable_GetGeneralCount_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_GetGeneralCount_Call) RunAndReturn(run func(context.Context) int) *IMetaTable_GetGeneralCount_Call {
_c.Call.Return(run)
return _c
}
// GetPrivilegeGroupRoles provides a mock function with given fields: groupName
func (_m *IMetaTable) GetPrivilegeGroupRoles(groupName string) ([]*milvuspb.RoleEntity, error) {
ret := _m.Called(groupName)
if len(ret) == 0 {
panic("no return value specified for GetPrivilegeGroupRoles")
}
var r0 []*milvuspb.RoleEntity
var r1 error
if rf, ok := ret.Get(0).(func(string) ([]*milvuspb.RoleEntity, error)); ok {
@ -1462,10 +1492,6 @@ func (_c *IMetaTable_IsAlias_Call) RunAndReturn(run func(string, string) bool) *
func (_m *IMetaTable) IsCustomPrivilegeGroup(groupName string) (bool, error) {
ret := _m.Called(groupName)
if len(ret) == 0 {
panic("no return value specified for IsCustomPrivilegeGroup")
}
var r0 bool
var r1 error
if rf, ok := ret.Get(0).(func(string) (bool, error)); ok {
@ -1925,10 +1951,6 @@ func (_c *IMetaTable_ListPolicy_Call) RunAndReturn(run func(string) ([]string, e
func (_m *IMetaTable) ListPrivilegeGroups() ([]*milvuspb.PrivilegeGroupInfo, error) {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ListPrivilegeGroups")
}
var r0 []*milvuspb.PrivilegeGroupInfo
var r1 error
if rf, ok := ret.Get(0).(func() ([]*milvuspb.PrivilegeGroupInfo, error)); ok {
@ -2080,10 +2102,6 @@ func (_c *IMetaTable_OperatePrivilege_Call) RunAndReturn(run func(string, *milvu
func (_m *IMetaTable) OperatePrivilegeGroup(groupName string, privileges []*milvuspb.PrivilegeEntity, operateType milvuspb.OperatePrivilegeGroupType) error {
ret := _m.Called(groupName, privileges, operateType)
if len(ret) == 0 {
panic("no return value specified for OperatePrivilegeGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(string, []*milvuspb.PrivilegeEntity, milvuspb.OperatePrivilegeGroupType) error); ok {
r0 = rf(groupName, privileges, operateType)