mirror of https://github.com/milvus-io/milvus.git
fix: the panic when db isn't existed in the rate limit interceptor (#33244)
issue: #33243 Signed-off-by: SimFG <bang.fu@zilliz.com>pull/32643/head
parent
3c4df81261
commit
dd0c6d6980
|
@ -23,9 +23,7 @@ import (
|
|||
|
||||
"github.com/casbin/casbin/v2/model"
|
||||
jsonadapter "github.com/casbin/json-adapter/v2"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -51,7 +49,6 @@ func (a *MetaCacheCasbinAdapter) LoadPolicy(model model.Model) error {
|
|||
policyInfo := strings.Join(cache.GetPrivilegeInfo(context.Background()), ",")
|
||||
|
||||
policy := fmt.Sprintf("[%s]", policyInfo)
|
||||
log.Ctx(context.Background()).Info("LoddPolicy update policyinfo", zap.String("policyInfo", policy))
|
||||
byteSource := []byte(policy)
|
||||
jAdapter := jsonadapter.NewAdapter(&byteSource)
|
||||
return jAdapter.LoadPolicy(model)
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/requestutil"
|
||||
|
@ -119,6 +120,9 @@ func getCollectionAndPartitionIDs(ctx context.Context, r reqPartNames) (int64, m
|
|||
|
||||
func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
|
||||
db, _ := globalMetaCache.GetDatabaseInfo(context.TODO(), r.GetDbName())
|
||||
if db == nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}
|
||||
}
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return db.dbID, map[int64][]int64{collectionID: {}}
|
||||
}
|
||||
|
@ -177,14 +181,14 @@ func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
|
|||
case *milvuspb.FlushRequest:
|
||||
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
||||
if err != nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, err
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
|
||||
collToPartIDs := make(map[int64][]int64, 0)
|
||||
for _, collectionName := range r.GetCollectionNames() {
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, err
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
collToPartIDs[collectionID] = []int64{}
|
||||
}
|
||||
|
@ -193,16 +197,16 @@ func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
|
|||
dbName := GetCurDBNameFromContextOrDefault(ctx)
|
||||
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
|
||||
if err != nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, err
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
return dbInfo.dbID, map[int64][]int64{
|
||||
r.GetCollectionID(): {},
|
||||
}, internalpb.RateType_DDLCompaction, 1, nil
|
||||
default: // TODO: support more request
|
||||
if req == nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
}
|
||||
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -367,7 +368,7 @@ func TestGetInfo(t *testing.T) {
|
|||
}()
|
||||
|
||||
t.Run("fail to get database", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(4)
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(5)
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
|
@ -394,6 +395,11 @@ func TestGetInfo(t *testing.T) {
|
|||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
dbID, collectionIDInfos := getCollectionID(&milvuspb.CreateCollectionRequest{})
|
||||
assert.Equal(t, util.InvalidDBID, dbID)
|
||||
assert.Equal(t, 0, len(collectionIDInfos))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail to get collection", func(t *testing.T) {
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
@ -79,7 +80,7 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
|
|||
}
|
||||
|
||||
// 2. check database level rate limits
|
||||
if ret == nil {
|
||||
if ret == nil && dbID != util.InvalidDBID {
|
||||
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
|
||||
ret = dbRateLimiters.Check(rt, n)
|
||||
if ret != nil {
|
||||
|
@ -92,6 +93,9 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
|
|||
// 3. check collection level rate limits
|
||||
if ret == nil && len(collectionIDToPartIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
|
||||
for collectionID := range collectionIDToPartIDs {
|
||||
if collectionID == 0 || dbID == util.InvalidDBID {
|
||||
continue
|
||||
}
|
||||
// only dml and dql have collection level rate limits
|
||||
collectionRateLimiters := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
|
||||
newDatabaseLimiter, newCollectionLimiters)
|
||||
|
@ -108,6 +112,9 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
|
|||
if ret == nil && len(collectionIDToPartIDs) > 0 {
|
||||
for collectionID, partitionIDs := range collectionIDToPartIDs {
|
||||
for _, partID := range partitionIDs {
|
||||
if collectionID == 0 || partID == 0 || dbID == util.InvalidDBID {
|
||||
continue
|
||||
}
|
||||
partitionRateLimiters := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partID,
|
||||
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
|
||||
ret = partitionRateLimiters.Check(rt, n)
|
||||
|
|
|
@ -87,9 +87,11 @@ func TestSimpleRateLimiter(t *testing.T) {
|
|||
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
|
||||
|
||||
collectionIDToPartIDs := map[int64][]int64{
|
||||
0: {},
|
||||
1: {},
|
||||
2: {},
|
||||
3: {},
|
||||
4: {0},
|
||||
}
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
|
|
|
@ -433,9 +433,15 @@ func (q *QuotaCenter) collectMetrics() error {
|
|||
}
|
||||
}
|
||||
|
||||
datacoordQuotaCollections := make([]int64, 0)
|
||||
q.diskMu.Lock()
|
||||
if dataCoordTopology.Cluster.Self.QuotaMetrics != nil {
|
||||
q.dataCoordMetrics = dataCoordTopology.Cluster.Self.QuotaMetrics
|
||||
for _, metricCollections := range q.dataCoordMetrics.PartitionsBinlogSize {
|
||||
for metricCollection := range metricCollections {
|
||||
datacoordQuotaCollections = append(datacoordQuotaCollections, metricCollection)
|
||||
}
|
||||
}
|
||||
}
|
||||
q.diskMu.Unlock()
|
||||
|
||||
|
@ -447,7 +453,6 @@ func (q *QuotaCenter) collectMetrics() error {
|
|||
}
|
||||
var rangeErr error
|
||||
collections.Range(func(collectionID int64) bool {
|
||||
var coll *model.Collection
|
||||
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
|
||||
if getErr != nil {
|
||||
rangeErr = getErr
|
||||
|
@ -482,7 +487,23 @@ func (q *QuotaCenter) collectMetrics() error {
|
|||
}
|
||||
return true
|
||||
})
|
||||
return rangeErr
|
||||
if rangeErr != nil {
|
||||
return rangeErr
|
||||
}
|
||||
for _, collectionID := range datacoordQuotaCollections {
|
||||
_, ok := q.collectionIDToDBID.Get(collectionID)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
|
||||
if getErr != nil {
|
||||
return getErr
|
||||
}
|
||||
q.collectionIDToDBID.Insert(collectionID, coll.DBID)
|
||||
q.collections.Insert(FormatCollectionKey(coll.DBID, coll.Name), collectionID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
// get Proxies metrics
|
||||
group.Go(func() error {
|
||||
|
|
|
@ -56,6 +56,7 @@ const (
|
|||
DefaultDBName = "default"
|
||||
DefaultDBID = int64(1)
|
||||
NonDBID = int64(0)
|
||||
InvalidDBID = int64(-1)
|
||||
|
||||
PrivilegeWord = "Privilege"
|
||||
AnyWord = "*"
|
||||
|
|
Loading…
Reference in New Issue