Check the collection num when creating the collection (#22946)

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/22962/head
SimFG 2023-03-23 16:47:57 +08:00 committed by GitHub
parent af458017f4
commit e8f8c1b445
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 85 additions and 32 deletions

View File

@ -4,33 +4,31 @@ import (
"context" "context"
"strings" "strings"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto" "github.com/milvus-io/milvus/internal/util/crypto"
"go.uber.org/zap"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
// validAuth validates the authentication func parseMD(authorization []string) (username, password string) {
func validAuth(ctx context.Context, authorization []string) bool {
if len(authorization) < 1 { if len(authorization) < 1 {
//log.Warn("key not found in header", zap.String("key", headerAuthorize)) log.Warn("key not found in header")
return false return
} }
// token format: base64<username:password> // token format: base64<username:password>
//token := strings.TrimPrefix(authorization[0], "Bearer ") //token := strings.TrimPrefix(authorization[0], "Bearer ")
token := authorization[0] token := authorization[0]
rawToken, err := crypto.Base64Decode(token) rawToken, err := crypto.Base64Decode(token)
if err != nil { if err != nil {
return false log.Warn("fail to decode the token", zap.Error(err))
return
} }
secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2)
username := secrets[0] username = secrets[0]
password := secrets[1] password = secrets[1]
isSuccess := passwordVerify(ctx, username, password, globalMetaCache) return
if isSuccess {
metrics.UserRPCCounter.WithLabelValues(username).Inc()
}
return isSuccess
} }
func validSourceID(ctx context.Context, authorization []string) bool { func validSourceID(ctx context.Context, authorization []string) bool {
@ -62,9 +60,12 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
// 1. if rpc call from a member (like index/query/data component) // 1. if rpc call from a member (like index/query/data component)
// 2. if rpc call from sdk // 2. if rpc call from sdk
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() { if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) && if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) {
!validAuth(ctx, md[strings.ToLower(util.HeaderAuthorize)]) { username, password := parseMD(md[strings.ToLower(util.HeaderAuthorize)])
return nil, ErrUnauthenticated() if !passwordVerify(ctx, username, password, globalMetaCache) {
return nil, ErrUnauthenticated()
}
metrics.UserRPCCounter.WithLabelValues(username).Inc()
} }
} }
return ctx, nil return ctx, nil

View File

@ -4,18 +4,24 @@ import (
"context" "context"
"testing" "testing"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto" "github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
) )
// validAuth validates the authentication // validAuth validates the authentication
func TestValidAuth(t *testing.T) { func TestValidAuth(t *testing.T) {
validAuth := func(ctx context.Context, authorization []string) bool {
username, password := parseMD(authorization)
if username == "" || password == "" {
return false
}
return passwordVerify(ctx, username, password, globalMetaCache)
}
ctx := context.Background() ctx := context.Background()
// no metadata // no metadata
res := validAuth(ctx, nil) res := validAuth(ctx, nil)

View File

@ -22,8 +22,6 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/msgpb" "github.com/milvus-io/milvus-proto/go-api/msgpb"
@ -34,7 +32,9 @@ import (
pb "github.com/milvus-io/milvus/internal/proto/etcdpb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/util/commonpbutil" "github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/merr"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
) )
type collectionChannels struct { type collectionChannels struct {
@ -287,6 +287,17 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
return nil return nil
} }
existedCollInfos, err := t.core.meta.ListCollections(ctx, typeutil.MaxTimestamp)
if err != nil {
log.Warn("fail to list collections for checking the collection count", zap.Error(err))
return fmt.Errorf("fail to list collections for checking the collection count")
}
maxCollectionNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
if len(existedCollInfos) >= maxCollectionNum {
log.Error("unable to create collection because the number of collection has reached the limit", zap.Int("max_collection_num", maxCollectionNum))
return merr.WrapErrCollectionResourceLimitExceeded(fmt.Sprintf("Failed to create collection, limit={%d}", maxCollectionNum))
}
undoTask := newBaseUndoTask(t.core.stepExecutor) undoTask := newBaseUndoTask(t.core.stepExecutor)
undoTask.AddStep(&expireCacheStep{ undoTask.AddStep(&expireCacheStep{
baseStep: baseStep{core: t.core}, baseStep: baseStep{core: t.core},

View File

@ -22,11 +22,6 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/stretchr/testify/mock"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
@ -34,8 +29,10 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/etcdpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
) )
func Test_createCollectionTask_validate(t *testing.T) { func Test_createCollectionTask_validate(t *testing.T) {
@ -311,6 +308,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
return coll, nil return coll, nil
} }
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return []*model.Collection{}, nil
}
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker)) core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
@ -356,6 +356,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
return coll, nil return coll, nil
} }
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return []*model.Collection{}, nil
}
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker)) core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
@ -456,6 +459,28 @@ func Test_createCollectionTask_Execute(t *testing.T) {
schema: schema, schema: schema,
} }
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return nil, errors.New("mock error")
}
err = task.Execute(context.Background())
assert.Error(t, err)
originFormatter := Params.QuotaConfig.MaxCollectionNum.Formatter
Params.QuotaConfig.MaxCollectionNum.Formatter = func(originValue string) string {
return "10"
}
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
maxNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
return make([]*model.Collection, maxNum), nil
}
err = task.Execute(context.Background())
assert.Error(t, err)
Params.QuotaConfig.MaxCollectionNum.Formatter = originFormatter
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return []*model.Collection{}, nil
}
err = task.Execute(context.Background()) err = task.Execute(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -477,6 +502,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error { meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error {
return nil return nil
} }
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return []*model.Collection{}, nil
}
// inject error here. // inject error here.
meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error { meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error {
return errors.New("error mock ChangeCollectionState") return errors.New("error mock ChangeCollectionState")

View File

@ -19,10 +19,9 @@ package merr
import ( import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
) )
const ( const (
@ -56,8 +55,9 @@ var (
ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus
// Collection related // Collection related
ErrCollectionNotFound = newMilvusError("collection not found", 100, false) ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
ErrCollectionNotLoaded = newMilvusError("collection not loaded", 101, false) ErrCollectionNotLoaded = newMilvusError("collection not loaded", 101, false)
ErrCollectionNumLimitExceeded = newMilvusError("exceeded the limit number of collections", 102, false)
// Partition related // Partition related
ErrPartitionNotFound = newMilvusError("partition not found", 202, false) ErrPartitionNotFound = newMilvusError("partition not found", 202, false)

View File

@ -108,7 +108,6 @@ func (s *ErrSuite) TestWrap() {
// Metrics related // Metrics related
s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound) s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound)
} }
func (s *ErrSuite) TestCombine() { func (s *ErrSuite) TestCombine() {

View File

@ -165,6 +165,14 @@ func WrapErrCollectionNotLoaded(collection any, msg ...string) error {
return err return err
} }
func WrapErrCollectionResourceLimitExceeded(msg ...string) error {
var err error = ErrCollectionNumLimitExceeded
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "; "))
}
return err
}
// Partition related // Partition related
func WrapErrPartitionNotFound(partition any, msg ...string) error { func WrapErrPartitionNotFound(partition any, msg ...string) error {
err := wrapWithField(ErrPartitionNotFound, "partition", partition) err := wrapWithField(ErrPartitionNotFound, "partition", partition)

View File

@ -497,7 +497,7 @@ The maximum rate will not be greater than ` + "max" + `.`,
p.MaxCollectionNum = ParamItem{ p.MaxCollectionNum = ParamItem{
Key: "quotaAndLimits.limits.collection.maxNum", Key: "quotaAndLimits.limits.collection.maxNum",
Version: "2.2.0", Version: "2.2.0",
DefaultValue: "64", DefaultValue: "65535",
} }
p.MaxCollectionNum.Init(base.mgr) p.MaxCollectionNum.Init(base.mgr)

View File

@ -66,7 +66,7 @@ func TestQuotaParam(t *testing.T) {
}) })
t.Run("test limits", func(t *testing.T) { t.Run("test limits", func(t *testing.T) {
assert.Equal(t, 64, qc.MaxCollectionNum.GetAsInt()) assert.Equal(t, 65535, qc.MaxCollectionNum.GetAsInt())
}) })
t.Run("test limit writing", func(t *testing.T) { t.Run("test limit writing", func(t *testing.T) {