fix: panic when the db doesn't exist in the rate limit interceptor (#33244)

issue: #33243

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/32643/head
SimFG 2024-05-22 15:57:39 +08:00 committed by GitHub
parent 3c4df81261
commit dd0c6d6980
7 changed files with 50 additions and 12 deletions


@@ -23,9 +23,7 @@ import (
"github.com/casbin/casbin/v2/model"
jsonadapter "github.com/casbin/json-adapter/v2"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@@ -51,7 +49,6 @@ func (a *MetaCacheCasbinAdapter) LoadPolicy(model model.Model) error {
policyInfo := strings.Join(cache.GetPrivilegeInfo(context.Background()), ",")
policy := fmt.Sprintf("[%s]", policyInfo)
log.Ctx(context.Background()).Info("LoddPolicy update policyinfo", zap.String("policyInfo", policy))
byteSource := []byte(policy)
jAdapter := jsonadapter.NewAdapter(&byteSource)
return jAdapter.LoadPolicy(model)


@@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/requestutil"
@@ -119,6 +120,9 @@ func getCollectionAndPartitionIDs(ctx context.Context, r reqPartNames) (int64, m
func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
db, _ := globalMetaCache.GetDatabaseInfo(context.TODO(), r.GetDbName())
if db == nil {
return util.InvalidDBID, map[int64][]int64{}
}
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return db.dbID, map[int64][]int64{collectionID: {}}
}
@@ -177,14 +181,14 @@ func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
case *milvuspb.FlushRequest:
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs := make(map[int64][]int64, 0)
for _, collectionName := range r.GetCollectionNames() {
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs[collectionID] = []int64{}
}
@@ -193,16 +197,16 @@ func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
dbName := GetCurDBNameFromContextOrDefault(ctx)
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
return dbInfo.dbID, map[int64][]int64{
r.GetCollectionID(): {},
}, internalpb.RateType_DDLCompaction, 1, nil
default: // TODO: support more request
if req == nil {
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
}
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
}
}
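
In getCollectionID above, the error from GetDatabaseInfo is discarded and db.dbID used to be dereferenced unconditionally, which panics once the database has been dropped (#33243); the new nil check and the switch from 0 to util.InvalidDBID give every error path a sentinel that cannot collide with NonDBID (0). A minimal sketch of that sentinel pattern, assuming a hypothetical lookupDBID helper in place of the real globalMetaCache:

package main

import (
	"errors"
	"fmt"
)

// Sentinel DB IDs mirroring the pkg/util constants touched by this commit:
// NonDBID (0) is a legitimate value, so an unresolved database needs a
// distinct marker, InvalidDBID (-1).
const (
	NonDBID     int64 = 0
	InvalidDBID int64 = -1
)

// lookupDBID is a hypothetical stand-in for globalMetaCache.GetDatabaseInfo;
// the real cache and its error paths are not part of this sketch.
func lookupDBID(name string) (int64, error) {
	if name == "missing_db" {
		return 0, errors.New("database not found")
	}
	return 7, nil
}

// getRequestDBID illustrates the pattern used by getRequestInfo above: on any
// lookup failure it returns InvalidDBID rather than 0, so callers never
// mistake the failure for a real database ID.
func getRequestDBID(dbName string) (int64, error) {
	dbID, err := lookupDBID(dbName)
	if err != nil {
		return InvalidDBID, err
	}
	return dbID, nil
}

func main() {
	id, err := getRequestDBID("missing_db")
	fmt.Println(id, err) // -1 database not found
}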


@@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@@ -367,7 +368,7 @@ func TestGetInfo(t *testing.T) {
}()
t.Run("fail to get database", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(4)
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(5)
{
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
@@ -394,6 +395,11 @@ func TestGetInfo(t *testing.T) {
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
assert.Error(t, err)
}
{
dbID, collectionIDInfos := getCollectionID(&milvuspb.CreateCollectionRequest{})
assert.Equal(t, util.InvalidDBID, dbID)
assert.Equal(t, 0, len(collectionIDInfos))
}
})
t.Run("fail to get collection", func(t *testing.T) {


@@ -32,6 +32,7 @@ import (
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@@ -79,7 +80,7 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
}
// 2. check database level rate limits
if ret == nil {
if ret == nil && dbID != util.InvalidDBID {
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
ret = dbRateLimiters.Check(rt, n)
if ret != nil {
@@ -92,6 +93,9 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
// 3. check collection level rate limits
if ret == nil && len(collectionIDToPartIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
for collectionID := range collectionIDToPartIDs {
if collectionID == 0 || dbID == util.InvalidDBID {
continue
}
// only dml and dql have collection level rate limits
collectionRateLimiters := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
newDatabaseLimiter, newCollectionLimiters)
@@ -108,6 +112,9 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
if ret == nil && len(collectionIDToPartIDs) > 0 {
for collectionID, partitionIDs := range collectionIDToPartIDs {
for _, partID := range partitionIDs {
if collectionID == 0 || partID == 0 || dbID == util.InvalidDBID {
continue
}
partitionRateLimiters := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partID,
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
ret = partitionRateLimiters.Check(rt, n)
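
Because the interceptor can now hand Check the util.InvalidDBID sentinel, the limiter skips the database, collection, and partition levels for such requests instead of lazily creating limiter nodes for placeholder IDs (collection 0, partition 0). A compact sketch of that guard-clause shape, with a hypothetical checkLevel standing in for the per-level limiter tree and the separate collection and partition passes collapsed into one loop:

package main

import "fmt"

const invalidDBID int64 = -1

// checkLevel is a hypothetical stand-in for the per-level limiter checks
// (database / collection / partition) performed by SimpleLimiter.Check.
func checkLevel(level string, ids ...int64) error {
	fmt.Println("checking", level, ids)
	return nil // pretend no quota is exceeded
}

// check mirrors the control flow added above: every level is skipped when the
// database could not be resolved, and placeholder IDs (0) never create
// collection or partition limiters.
func check(dbID int64, collToParts map[int64][]int64) error {
	if dbID != invalidDBID {
		if err := checkLevel("database", dbID); err != nil {
			return err
		}
	}
	for collID, partIDs := range collToParts {
		if dbID == invalidDBID || collID == 0 {
			continue // unknown database or placeholder collection
		}
		if err := checkLevel("collection", dbID, collID); err != nil {
			return err
		}
		for _, partID := range partIDs {
			if partID == 0 {
				continue // placeholder partition
			}
			if err := checkLevel("partition", dbID, collID, partID); err != nil {
				return err
			}
		}
	}
	return nil
}

func main() {
	// With an unresolved database only the cluster-level checks (not shown)
	// would run; collection 0 and partition 0 are always skipped.
	_ = check(invalidDBID, map[int64][]int64{0: {}, 4: {0}})
}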


@@ -87,9 +87,11 @@ func TestSimpleRateLimiter(t *testing.T) {
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
collectionIDToPartIDs := map[int64][]int64{
0: {},
1: {},
2: {},
3: {},
4: {0},
}
for i := 1; i <= 3; i++ {


@@ -433,9 +433,15 @@ func (q *QuotaCenter) collectMetrics() error {
}
}
datacoordQuotaCollections := make([]int64, 0)
q.diskMu.Lock()
if dataCoordTopology.Cluster.Self.QuotaMetrics != nil {
q.dataCoordMetrics = dataCoordTopology.Cluster.Self.QuotaMetrics
for _, metricCollections := range q.dataCoordMetrics.PartitionsBinlogSize {
for metricCollection := range metricCollections {
datacoordQuotaCollections = append(datacoordQuotaCollections, metricCollection)
}
}
}
q.diskMu.Unlock()
@@ -447,7 +453,6 @@ func (q *QuotaCenter) collectMetrics() error {
}
var rangeErr error
collections.Range(func(collectionID int64) bool {
var coll *model.Collection
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil {
rangeErr = getErr
@@ -482,7 +487,23 @@ func (q *QuotaCenter) collectMetrics() error {
}
return true
})
return rangeErr
if rangeErr != nil {
return rangeErr
}
for _, collectionID := range datacoordQuotaCollections {
_, ok := q.collectionIDToDBID.Get(collectionID)
if ok {
continue
}
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil {
return getErr
}
q.collectionIDToDBID.Insert(collectionID, coll.DBID)
q.collections.Insert(FormatCollectionKey(coll.DBID, coll.Name), collectionID)
}
return nil
})
// get Proxies metrics
group.Go(func() error {
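
collectMetrics now also records every collection for which DataCoord reports binlog sizes and, after the existing Range pass, backfills collectionIDToDBID and collections for any of them that RootCoord's cache has not seen yet, so that later quota calculations have a DB mapping for those collections. A rough sketch of that backfill step, assuming plain maps in place of QuotaCenter's concurrent maps and a hypothetical fetchCollection helper:

package main

import "fmt"

// collectionMeta is a minimal stand-in for the model.Collection fields the
// backfill needs.
type collectionMeta struct {
	DBID int64
	Name string
}

// fetchCollection is hypothetical; the real code calls
// q.meta.GetCollectionByIDWithMaxTs.
func fetchCollection(id int64) (collectionMeta, error) {
	return collectionMeta{DBID: 1, Name: fmt.Sprintf("coll_%d", id)}, nil
}

// backfill inserts the DB mapping for collections that showed up in
// DataCoord's quota metrics but are missing from the local caches, mirroring
// the loop over datacoordQuotaCollections above. The "%d.%s" key format is a
// placeholder for FormatCollectionKey.
func backfill(reported []int64, collToDB map[int64]int64, collections map[string]int64) error {
	for _, id := range reported {
		if _, ok := collToDB[id]; ok {
			continue // already known, nothing to do
		}
		coll, err := fetchCollection(id)
		if err != nil {
			return err
		}
		collToDB[id] = coll.DBID
		collections[fmt.Sprintf("%d.%s", coll.DBID, coll.Name)] = id
	}
	return nil
}

func main() {
	collToDB := map[int64]int64{}
	collections := map[string]int64{}
	_ = backfill([]int64{42}, collToDB, collections)
	fmt.Println(collToDB, collections)
}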


@@ -56,6 +56,7 @@ const (
DefaultDBName = "default"
DefaultDBID = int64(1)
NonDBID = int64(0)
InvalidDBID = int64(-1)
PrivilegeWord = "Privilege"
AnyWord = "*"