mirror of https://github.com/milvus-io/milvus.git
fix: Fix checkGeneralCapacity slowly (#37976)
Cache the general count to speed up `checkGeneralCapacity`.
issue: https://github.com/milvus-io/milvus/issues/37630
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
parent 2319018fcb
commit 7fb0c281f2
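The gist of the change: checkGeneralCapacity used to recompute sum(partitions * shards) by walking every available collection (with a database lookup and a collection lookup each) on every create-collection or create-partition request; the meta table now maintains that sum as a counter, and the check reads it in O(1). A minimal standalone sketch of the idea follows — the types and names below are illustrative stand-ins, not the actual milvus API:

package main

import (
	"fmt"
	"sync"
)

// collInfo is a hypothetical stand-in for collection metadata; only the two
// fields the capacity check cares about are kept.
type collInfo struct {
	partitions int
	shards     int
}

// metaTable caches generalCnt = sum over collections of partitions*shards,
// so the capacity check no longer has to scan every collection.
type metaTable struct {
	mu         sync.RWMutex
	colls      map[int64]collInfo
	generalCnt int
}

func (m *metaTable) addCollection(id int64, c collInfo) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.colls[id] = c
	m.generalCnt += c.partitions * c.shards // keep the cached count in sync
}

func (m *metaTable) removeCollection(id int64) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if c, ok := m.colls[id]; ok {
		m.generalCnt -= c.partitions * c.shards
		delete(m.colls, id)
	}
}

// generalCount is the O(1) read that replaces the old O(#collections) scan.
func (m *metaTable) generalCount() int {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.generalCnt
}

// checkGeneralCapacity mirrors the new check: cached count plus whatever the
// pending request would add, compared against the configured maximum.
func checkGeneralCapacity(m *metaTable, addedNum, maxGeneralCapacity int) error {
	if total := m.generalCount() + addedNum; total > maxGeneralCapacity {
		return fmt.Errorf("general count %d exceeds max general capacity %d", total, maxGeneralCapacity)
	}
	return nil
}

func main() {
	m := &metaTable{colls: map[int64]collInfo{}}
	m.addCollection(1, collInfo{partitions: 16, shards: 2}) // contributes 32
	// A new collection with 4 partitions and 2 shards would add 8 more.
	fmt.Println(checkGeneralCapacity(m, 4*2, 36)) // 40 > 36 -> error
}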
@@ -20,7 +20,6 @@ import (
 	"context"

 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )

 const (
@@ -35,7 +34,6 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
 	newParNum int64,
 	newShardNum int32,
 	core *Core,
-	ts typeutil.Timestamp,
 ) error {
 	var addedNum int64 = 0
 	if newColNum > 0 && newParNum > 0 && newShardNum > 0 {
@@ -46,25 +44,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
 		addedNum += newParNum
 	}

-	var generalNum int64 = 0
-	collectionsMap := core.meta.ListAllAvailCollections(ctx)
-	for dbId, collectionIDs := range collectionsMap {
-		db, err := core.meta.GetDatabaseByID(ctx, dbId, ts)
-		if err == nil {
-			for _, collectionId := range collectionIDs {
-				collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true)
-				if err == nil {
-					partNum := int64(collection.GetPartitionNum(false))
-					shardNum := int64(collection.ShardsNum)
-					generalNum += partNum * shardNum
-				}
-			}
-		}
-	}
-
-	generalNum += addedNum
-	if generalNum > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64() {
-		return merr.WrapGeneralCapacityExceed(generalNum, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(),
+	generalCount := core.meta.GetGeneralCount(ctx)
+	generalCount += int(addedNum)
+	if generalCount > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt() {
+		return merr.WrapGeneralCapacityExceed(generalCount, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(),
 			"failed checking constraint: sum_collections(parition*shard) exceeding the max general capacity:")
 	}
 	return nil
@@ -107,7 +107,7 @@ func (t *createCollectionTask) validate() error {
 	if t.Req.GetNumPartitions() > 0 {
 		newPartNum = t.Req.GetNumPartitions()
 	}
-	return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core, t.ts)
+	return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core)
 }

 // checkMaxCollectionsPerDB DB properties take precedence over quota configurations for max collections.
@@ -246,23 +246,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
 		meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{1: {1, 2}})
 		meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
 			Return(&model.Database{Name: "db1"}, nil).Once()
-
-		meta.On("GetDatabaseByID",
-			mock.Anything, mock.Anything, mock.Anything,
-		).Return(&model.Database{
-			Name: "default",
-		}, nil)
-		meta.On("GetCollectionByID",
-			mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
-		).Return(&model.Collection{
-			Name:      "default",
-			ShardsNum: 2,
-			Partitions: []*model.Partition{
-				{
-					PartitionID: 1,
-				},
-			},
-		}, nil)
+		meta.EXPECT().GetGeneralCount(mock.Anything).Return(1)

 		core := newTestCore(withMeta(meta))
@@ -295,8 +279,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
 				},
 			},
 		}, nil).Once()
-		meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
-			Return(nil, errors.New("mock"))
+		meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)

 		core := newTestCore(withMeta(meta))
 		task := createCollectionTask{
@@ -734,6 +717,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
 		).Return(map[int64][]int64{
 			util.DefaultDBID: {1, 2},
 		}, nil)
+		meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)

 		paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
 		defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
@@ -754,8 +738,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
 	})

 	t.Run("invalid schema", func(t *testing.T) {
-		meta.On("GetDatabaseByID", mock.Anything,
-			mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
 		core := newTestCore(withMeta(meta))
 		collectionName := funcutil.GenRandomStr()
 		task := &createCollectionTask{
@@ -784,8 +766,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
 		}
 		marshaledSchema, err := proto.Marshal(schema)
 		assert.NoError(t, err)
-		meta.On("GetDatabaseByID", mock.Anything,
-			mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
 		core := newTestCore(withInvalidIDAllocator(), withMeta(meta))

 		task := createCollectionTask{
@@ -808,8 +788,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
 		field1 := funcutil.GenRandomStr()

 		ticker := newRocksMqTtSynchronizer()
-		meta.On("GetDatabaseByID", mock.Anything,
-			mock.Anything, mock.Anything).Return(nil, errors.New("mock"))

 		core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))

@@ -1160,8 +1138,7 @@ func Test_createCollectionTask_PartitionKey(t *testing.T) {
 	).Return(map[int64][]int64{
 		util.DefaultDBID: {1, 2},
 	}, nil)
-	meta.On("GetDatabaseByID", mock.Anything,
-		mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
+	meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)

 	paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
 	defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
@@ -46,7 +46,7 @@ func (t *createPartitionTask) Prepare(ctx context.Context) error {
 		return err
 	}
 	t.collMeta = collMeta
-	return checkGeneralCapacity(ctx, 0, 1, 0, t.core, t.ts)
+	return checkGeneralCapacity(ctx, 0, 1, 0, t.core)
 }

 func (t *createPartitionTask) Execute(ctx context.Context) error {
@@ -20,7 +20,6 @@ import (
 	"context"
 	"testing"

 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"

@@ -62,14 +61,7 @@ func Test_createPartitionTask_Prepare(t *testing.T) {
 			mock.Anything,
 			mock.Anything,
 		).Return(coll.Clone(), nil)
-		meta.On("ListAllAvailCollections",
-			mock.Anything,
-		).Return(map[int64][]int64{
-			1: {1, 2},
-		}, nil)
-		meta.On("GetDatabaseByID",
-			mock.Anything, mock.Anything, mock.Anything,
-		).Return(nil, errors.New("mock"))
+		meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)

 		core := newTestCore(withMeta(meta))
 		task := &createPartitionTask{
@@ -74,6 +74,7 @@ type IMetaTable interface {
 	ListAliases(ctx context.Context, dbName string, collectionName string, ts Timestamp) ([]string, error)
 	AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error
 	RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error
+	GetGeneralCount(ctx context.Context) int

 	// TODO: it'll be a big cost if we handle the time travel logic, since we should always list all aliases in catalog.
 	IsAlias(ctx context.Context, db, name string) bool
@@ -115,6 +116,8 @@ type MetaTable struct {
 	dbName2Meta map[string]*model.Database              // database name -> db meta
 	collID2Meta map[typeutil.UniqueID]*model.Collection // collection id -> collection meta

+	generalCnt int // sum of product of partition number and shard number
+
 	// collections *collectionDb
 	names   *nameDb
 	aliases *nameDb
@@ -189,6 +192,7 @@ func (mt *MetaTable) reload() error {
 		}
 		for _, collection := range collections {
 			mt.collID2Meta[collection.CollectionID] = collection
+			mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum)
 			if collection.Available() {
 				mt.names.insert(dbName, collection.Name, collection.CollectionID)
 				collectionNum++
@@ -417,6 +421,8 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection)
 	mt.collID2Meta[coll.CollectionID] = coll.Clone()
 	mt.names.insert(db.Name, coll.Name, coll.CollectionID)

+	mt.generalCnt += len(coll.Partitions) * int(coll.ShardsNum)
+
 	log.Ctx(ctx).Info("add collection to meta table",
 		zap.Int64("dbID", coll.DBID),
 		zap.String("collection", coll.Name),
@@ -521,6 +527,8 @@ func (mt *MetaTable) RemoveCollection(ctx context.Context, collectionID UniqueID
 	mt.removeAllNamesIfMatchedInternal(collectionID, allNames)
 	mt.removeCollectionByIDInternal(collectionID)

+	mt.generalCnt -= len(coll.Partitions) * int(coll.ShardsNum)
+
 	log.Ctx(ctx).Info("remove collection",
 		zap.Int64("dbID", coll.DBID),
 		zap.String("name", coll.Name),
@@ -895,6 +903,8 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
 	}
 	mt.collID2Meta[partition.CollectionID].Partitions = append(mt.collID2Meta[partition.CollectionID].Partitions, partition.Clone())

+	mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum
+
 	metrics.RootCoordNumOfPartitions.WithLabelValues().Inc()

 	log.Ctx(ctx).Info("add partition to meta table",
@@ -961,6 +971,7 @@ func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collection
 	}
 	if loc != -1 {
 		coll.Partitions = append(coll.Partitions[:loc], coll.Partitions[loc+1:]...)
+		mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum
 	}
 	log.Info("remove partition", zap.Int64("collection", collectionID), zap.Int64("partition", partitionID), zap.Uint64("ts", ts))
 	return nil
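Taken together, the bookkeeping added in reload, AddCollection/RemoveCollection and AddPartition/RemovePartition keeps the cached counter equal to the quantity the old code recomputed on every request: the sum over all collections of partition count times shard count. A small self-contained sketch of that invariant, using stand-in types rather than the real model package:

package main

import "fmt"

// collMeta is a stand-in for the fields of model.Collection that the counter depends on.
type collMeta struct {
	partitions int // stands in for len(coll.Partitions)
	shards     int // stands in for int(coll.ShardsNum)
}

// expectedGeneralCnt recomputes what the cached generalCnt must equal at any point in time.
func expectedGeneralCnt(colls map[int64]collMeta) int {
	total := 0
	for _, c := range colls {
		total += c.partitions * c.shards
	}
	return total
}

func main() {
	colls := map[int64]collMeta{
		1: {partitions: 16, shards: 2}, // contributes 32
		2: {partitions: 1, shards: 1},  // contributes 1
	}
	fmt.Println(expectedGeneralCnt(colls)) // 33
}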
@@ -1229,6 +1240,14 @@ func (mt *MetaTable) ListAliasesByID(ctx context.Context, collID UniqueID) []str
 	return mt.listAliasesByID(collID)
 }

+// GetGeneralCount gets the general count(sum of product of partition number and shard number).
+func (mt *MetaTable) GetGeneralCount(ctx context.Context) int {
+	mt.ddLock.RLock()
+	defer mt.ddLock.RUnlock()
+
+	return mt.generalCnt
+}
+
 // AddCredential add credential
 func (mt *MetaTable) AddCredential(ctx context.Context, credInfo *internalpb.CredentialInfo) error {
 	if credInfo.Username == "" {
@@ -1473,6 +1473,52 @@ func (_c *IMetaTable_GetDatabaseByName_Call) RunAndReturn(run func(context.Conte
 	return _c
 }

+// GetGeneralCount provides a mock function with given fields: ctx
+func (_m *IMetaTable) GetGeneralCount(ctx context.Context) int {
+	ret := _m.Called(ctx)
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetGeneralCount")
+	}
+
+	var r0 int
+	if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+		r0 = rf(ctx)
+	} else {
+		r0 = ret.Get(0).(int)
+	}
+
+	return r0
+}
+
+// IMetaTable_GetGeneralCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGeneralCount'
+type IMetaTable_GetGeneralCount_Call struct {
+	*mock.Call
+}
+
+// GetGeneralCount is a helper method to define mock.On call
+//   - ctx context.Context
+func (_e *IMetaTable_Expecter) GetGeneralCount(ctx interface{}) *IMetaTable_GetGeneralCount_Call {
+	return &IMetaTable_GetGeneralCount_Call{Call: _e.mock.On("GetGeneralCount", ctx)}
+}
+
+func (_c *IMetaTable_GetGeneralCount_Call) Run(run func(ctx context.Context)) *IMetaTable_GetGeneralCount_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(context.Context))
+	})
+	return _c
+}
+
+func (_c *IMetaTable_GetGeneralCount_Call) Return(_a0 int) *IMetaTable_GetGeneralCount_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *IMetaTable_GetGeneralCount_Call) RunAndReturn(run func(context.Context) int) *IMetaTable_GetGeneralCount_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // GetPChannelInfo provides a mock function with given fields: ctx, pchannel
 func (_m *IMetaTable) GetPChannelInfo(ctx context.Context, pchannel string) *rootcoordpb.GetPChannelInfoResponse {
 	ret := _m.Called(ctx, pchannel)