mirror of https://github.com/milvus-io/milvus.git
Check the collection num when creating the collection (#22919)
Signed-off-by: SimFG <bang.fu@zilliz.com>pull/22961/head
parent
bc8fca80fa
commit
5784943862
|
@ -4,33 +4,31 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/util"
|
||||
"github.com/milvus-io/milvus/internal/util/crypto"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// validAuth validates the authentication
|
||||
func validAuth(ctx context.Context, authorization []string) bool {
|
||||
func parseMD(authorization []string) (username, password string) {
|
||||
if len(authorization) < 1 {
|
||||
//log.Warn("key not found in header", zap.String("key", headerAuthorize))
|
||||
return false
|
||||
log.Warn("key not found in header")
|
||||
return
|
||||
}
|
||||
// token format: base64<username:password>
|
||||
//token := strings.TrimPrefix(authorization[0], "Bearer ")
|
||||
token := authorization[0]
|
||||
rawToken, err := crypto.Base64Decode(token)
|
||||
if err != nil {
|
||||
return false
|
||||
log.Warn("fail to decode the token", zap.Error(err))
|
||||
return
|
||||
}
|
||||
secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2)
|
||||
username := secrets[0]
|
||||
password := secrets[1]
|
||||
isSuccess := passwordVerify(ctx, username, password, globalMetaCache)
|
||||
if isSuccess {
|
||||
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
||||
}
|
||||
return isSuccess
|
||||
username = secrets[0]
|
||||
password = secrets[1]
|
||||
return
|
||||
}
|
||||
|
||||
func validSourceID(ctx context.Context, authorization []string) bool {
|
||||
|
@ -62,9 +60,14 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
|
|||
// 1. if rpc call from a member (like index/query/data component)
|
||||
// 2. if rpc call from sdk
|
||||
if Params.CommonCfg.AuthorizationEnabled {
|
||||
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) &&
|
||||
!validAuth(ctx, md[strings.ToLower(util.HeaderAuthorize)]) {
|
||||
return nil, ErrUnauthenticated()
|
||||
var username string
|
||||
var password string
|
||||
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) {
|
||||
username, password = parseMD(md[strings.ToLower(util.HeaderAuthorize)])
|
||||
if !passwordVerify(ctx, username, password, globalMetaCache) {
|
||||
return nil, ErrUnauthenticated()
|
||||
}
|
||||
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
|
|
|
@ -14,6 +14,14 @@ import (
|
|||
|
||||
// validAuth validates the authentication
|
||||
func TestValidAuth(t *testing.T) {
|
||||
validAuth := func(ctx context.Context, authorization []string) bool {
|
||||
username, password := parseMD(authorization)
|
||||
if username == "" || password == "" {
|
||||
return false
|
||||
}
|
||||
return passwordVerify(ctx, username, password, globalMetaCache)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
Params.InitOnce()
|
||||
// no metadata
|
||||
|
|
|
@ -309,6 +309,17 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
existedCollInfos, err := t.core.meta.ListCollections(ctx, typeutil.MaxTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("fail to list collections for checking the collection count", zap.Error(err))
|
||||
return fmt.Errorf("fail to list collections for checking the collection count")
|
||||
}
|
||||
if len(existedCollInfos) >= Params.QuotaConfig.MaxCollectionNum {
|
||||
errMsg := "unable to create collection because the number of collection has reached the limit"
|
||||
log.Error(errMsg, zap.Int("max_collection_num", Params.QuotaConfig.MaxCollectionNum))
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
undoTask := newBaseUndoTask(t.core.stepExecutor)
|
||||
undoTask.AddStep(&expireCacheStep{
|
||||
baseStep: baseStep{core: t.core},
|
||||
|
|
|
@ -325,6 +325,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
|||
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
|
||||
return coll, nil
|
||||
}
|
||||
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||
return []*model.Collection{}, nil
|
||||
}
|
||||
|
||||
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
|
||||
|
||||
|
@ -370,6 +373,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
|||
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
|
||||
return coll, nil
|
||||
}
|
||||
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||
return []*model.Collection{}, nil
|
||||
}
|
||||
|
||||
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
|
||||
|
||||
|
@ -470,6 +476,25 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
|||
schema: schema,
|
||||
}
|
||||
|
||||
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||
return nil, errors.New("mock error")
|
||||
}
|
||||
err = task.Execute(context.Background())
|
||||
assert.Error(t, err)
|
||||
|
||||
originValue := Params.QuotaConfig.MaxCollectionNum
|
||||
Params.QuotaConfig.MaxCollectionNum = 10
|
||||
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||
maxNum := Params.QuotaConfig.MaxCollectionNum
|
||||
return make([]*model.Collection, maxNum), nil
|
||||
}
|
||||
err = task.Execute(context.Background())
|
||||
assert.Error(t, err)
|
||||
Params.QuotaConfig.MaxCollectionNum = originValue
|
||||
|
||||
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||
return []*model.Collection{}, nil
|
||||
}
|
||||
err = task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
@ -491,6 +516,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
|||
meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error {
|
||||
return nil
|
||||
}
|
||||
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||
return []*model.Collection{}, nil
|
||||
}
|
||||
// inject error here.
|
||||
meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error {
|
||||
return errors.New("error mock ChangeCollectionState")
|
||||
|
|
|
@ -425,7 +425,7 @@ func (p *quotaConfig) initDQLMinQueryRate() {
|
|||
}
|
||||
|
||||
func (p *quotaConfig) initMaxCollectionNum() {
|
||||
p.MaxCollectionNum = p.Base.ParseIntWithDefault("quotaAndLimits.limits.collection.maxNum", 64)
|
||||
p.MaxCollectionNum = p.Base.ParseIntWithDefault("quotaAndLimits.limits.collection.maxNum", 65535)
|
||||
}
|
||||
|
||||
func (p *quotaConfig) initForceDenyWriting() {
|
||||
|
|
|
@ -66,7 +66,7 @@ func TestQuotaParam(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("test limits", func(t *testing.T) {
|
||||
assert.Equal(t, 64, qc.MaxCollectionNum)
|
||||
assert.Equal(t, 65535, qc.MaxCollectionNum)
|
||||
})
|
||||
|
||||
t.Run("test limit writing", func(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue