mirror of https://github.com/milvus-io/milvus.git
Check the collection num when creating the collection (#22946)
Signed-off-by: SimFG <bang.fu@zilliz.com>pull/22962/head
parent
af458017f4
commit
e8f8c1b445
|
@ -4,33 +4,31 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/log"
|
||||||
"github.com/milvus-io/milvus/internal/metrics"
|
"github.com/milvus-io/milvus/internal/metrics"
|
||||||
"github.com/milvus-io/milvus/internal/util"
|
"github.com/milvus-io/milvus/internal/util"
|
||||||
"github.com/milvus-io/milvus/internal/util/crypto"
|
"github.com/milvus-io/milvus/internal/util/crypto"
|
||||||
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validAuth validates the authentication
|
func parseMD(authorization []string) (username, password string) {
|
||||||
func validAuth(ctx context.Context, authorization []string) bool {
|
|
||||||
if len(authorization) < 1 {
|
if len(authorization) < 1 {
|
||||||
//log.Warn("key not found in header", zap.String("key", headerAuthorize))
|
log.Warn("key not found in header")
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
// token format: base64<username:password>
|
// token format: base64<username:password>
|
||||||
//token := strings.TrimPrefix(authorization[0], "Bearer ")
|
//token := strings.TrimPrefix(authorization[0], "Bearer ")
|
||||||
token := authorization[0]
|
token := authorization[0]
|
||||||
rawToken, err := crypto.Base64Decode(token)
|
rawToken, err := crypto.Base64Decode(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
log.Warn("fail to decode the token", zap.Error(err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2)
|
secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2)
|
||||||
username := secrets[0]
|
username = secrets[0]
|
||||||
password := secrets[1]
|
password = secrets[1]
|
||||||
isSuccess := passwordVerify(ctx, username, password, globalMetaCache)
|
return
|
||||||
if isSuccess {
|
|
||||||
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
|
||||||
}
|
|
||||||
return isSuccess
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validSourceID(ctx context.Context, authorization []string) bool {
|
func validSourceID(ctx context.Context, authorization []string) bool {
|
||||||
|
@ -62,10 +60,13 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
|
||||||
// 1. if rpc call from a member (like index/query/data component)
|
// 1. if rpc call from a member (like index/query/data component)
|
||||||
// 2. if rpc call from sdk
|
// 2. if rpc call from sdk
|
||||||
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||||
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) &&
|
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) {
|
||||||
!validAuth(ctx, md[strings.ToLower(util.HeaderAuthorize)]) {
|
username, password := parseMD(md[strings.ToLower(util.HeaderAuthorize)])
|
||||||
|
if !passwordVerify(ctx, username, password, globalMetaCache) {
|
||||||
return nil, ErrUnauthenticated()
|
return nil, ErrUnauthenticated()
|
||||||
}
|
}
|
||||||
|
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,18 +4,24 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util"
|
"github.com/milvus-io/milvus/internal/util"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/util/crypto"
|
"github.com/milvus-io/milvus/internal/util/crypto"
|
||||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validAuth validates the authentication
|
// validAuth validates the authentication
|
||||||
func TestValidAuth(t *testing.T) {
|
func TestValidAuth(t *testing.T) {
|
||||||
|
validAuth := func(ctx context.Context, authorization []string) bool {
|
||||||
|
username, password := parseMD(authorization)
|
||||||
|
if username == "" || password == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return passwordVerify(ctx, username, password, globalMetaCache)
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
// no metadata
|
// no metadata
|
||||||
res := validAuth(ctx, nil)
|
res := validAuth(ctx, nil)
|
||||||
|
|
|
@ -22,8 +22,6 @@ import (
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"go.uber.org/zap"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||||
|
@ -34,7 +32,9 @@ import (
|
||||||
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
|
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/merr"
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type collectionChannels struct {
|
type collectionChannels struct {
|
||||||
|
@ -287,6 +287,17 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
existedCollInfos, err := t.core.meta.ListCollections(ctx, typeutil.MaxTimestamp)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("fail to list collections for checking the collection count", zap.Error(err))
|
||||||
|
return fmt.Errorf("fail to list collections for checking the collection count")
|
||||||
|
}
|
||||||
|
maxCollectionNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
|
||||||
|
if len(existedCollInfos) >= maxCollectionNum {
|
||||||
|
log.Error("unable to create collection because the number of collection has reached the limit", zap.Int("max_collection_num", maxCollectionNum))
|
||||||
|
return merr.WrapErrCollectionResourceLimitExceeded(fmt.Sprintf("Failed to create collection, limit={%d}", maxCollectionNum))
|
||||||
|
}
|
||||||
|
|
||||||
undoTask := newBaseUndoTask(t.core.stepExecutor)
|
undoTask := newBaseUndoTask(t.core.stepExecutor)
|
||||||
undoTask.AddStep(&expireCacheStep{
|
undoTask.AddStep(&expireCacheStep{
|
||||||
baseStep: baseStep{core: t.core},
|
baseStep: baseStep{core: t.core},
|
||||||
|
|
|
@ -22,11 +22,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
|
|
||||||
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
|
@ -34,8 +29,10 @@ import (
|
||||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/etcdpb"
|
"github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||||
|
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
||||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_createCollectionTask_validate(t *testing.T) {
|
func Test_createCollectionTask_validate(t *testing.T) {
|
||||||
|
@ -311,6 +308,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
||||||
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
|
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
|
||||||
return coll, nil
|
return coll, nil
|
||||||
}
|
}
|
||||||
|
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||||
|
return []*model.Collection{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
|
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
|
||||||
|
|
||||||
|
@ -356,6 +356,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
||||||
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
|
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
|
||||||
return coll, nil
|
return coll, nil
|
||||||
}
|
}
|
||||||
|
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||||
|
return []*model.Collection{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
|
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
|
||||||
|
|
||||||
|
@ -456,6 +459,28 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
||||||
schema: schema,
|
schema: schema,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||||
|
return nil, errors.New("mock error")
|
||||||
|
}
|
||||||
|
err = task.Execute(context.Background())
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
originFormatter := Params.QuotaConfig.MaxCollectionNum.Formatter
|
||||||
|
Params.QuotaConfig.MaxCollectionNum.Formatter = func(originValue string) string {
|
||||||
|
return "10"
|
||||||
|
}
|
||||||
|
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||||
|
maxNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
|
||||||
|
return make([]*model.Collection, maxNum), nil
|
||||||
|
}
|
||||||
|
err = task.Execute(context.Background())
|
||||||
|
assert.Error(t, err)
|
||||||
|
Params.QuotaConfig.MaxCollectionNum.Formatter = originFormatter
|
||||||
|
|
||||||
|
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||||
|
return []*model.Collection{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
err = task.Execute(context.Background())
|
err = task.Execute(context.Background())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
@ -477,6 +502,9 @@ func Test_createCollectionTask_Execute(t *testing.T) {
|
||||||
meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error {
|
meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
|
||||||
|
return []*model.Collection{}, nil
|
||||||
|
}
|
||||||
// inject error here.
|
// inject error here.
|
||||||
meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error {
|
meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error {
|
||||||
return errors.New("error mock ChangeCollectionState")
|
return errors.New("error mock ChangeCollectionState")
|
||||||
|
|
|
@ -19,10 +19,9 @@ package merr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/samber/lo"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
|
"github.com/samber/lo"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -58,6 +57,7 @@ var (
|
||||||
// Collection related
|
// Collection related
|
||||||
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
|
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
|
||||||
ErrCollectionNotLoaded = newMilvusError("collection not loaded", 101, false)
|
ErrCollectionNotLoaded = newMilvusError("collection not loaded", 101, false)
|
||||||
|
ErrCollectionNumLimitExceeded = newMilvusError("exceeded the limit number of collections", 102, false)
|
||||||
|
|
||||||
// Partition related
|
// Partition related
|
||||||
ErrPartitionNotFound = newMilvusError("partition not found", 202, false)
|
ErrPartitionNotFound = newMilvusError("partition not found", 202, false)
|
||||||
|
|
|
@ -108,7 +108,6 @@ func (s *ErrSuite) TestWrap() {
|
||||||
|
|
||||||
// Metrics related
|
// Metrics related
|
||||||
s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound)
|
s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ErrSuite) TestCombine() {
|
func (s *ErrSuite) TestCombine() {
|
||||||
|
|
|
@ -165,6 +165,14 @@ func WrapErrCollectionNotLoaded(collection any, msg ...string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WrapErrCollectionResourceLimitExceeded(msg ...string) error {
|
||||||
|
var err error = ErrCollectionNumLimitExceeded
|
||||||
|
if len(msg) > 0 {
|
||||||
|
err = errors.Wrap(err, strings.Join(msg, "; "))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Partition related
|
// Partition related
|
||||||
func WrapErrPartitionNotFound(partition any, msg ...string) error {
|
func WrapErrPartitionNotFound(partition any, msg ...string) error {
|
||||||
err := wrapWithField(ErrPartitionNotFound, "partition", partition)
|
err := wrapWithField(ErrPartitionNotFound, "partition", partition)
|
||||||
|
|
|
@ -497,7 +497,7 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
||||||
p.MaxCollectionNum = ParamItem{
|
p.MaxCollectionNum = ParamItem{
|
||||||
Key: "quotaAndLimits.limits.collection.maxNum",
|
Key: "quotaAndLimits.limits.collection.maxNum",
|
||||||
Version: "2.2.0",
|
Version: "2.2.0",
|
||||||
DefaultValue: "64",
|
DefaultValue: "65535",
|
||||||
}
|
}
|
||||||
p.MaxCollectionNum.Init(base.mgr)
|
p.MaxCollectionNum.Init(base.mgr)
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ func TestQuotaParam(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("test limits", func(t *testing.T) {
|
t.Run("test limits", func(t *testing.T) {
|
||||||
assert.Equal(t, 64, qc.MaxCollectionNum.GetAsInt())
|
assert.Equal(t, 65535, qc.MaxCollectionNum.GetAsInt())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("test limit writing", func(t *testing.T) {
|
t.Run("test limit writing", func(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue