mirror of https://github.com/milvus-io/milvus.git
enhance: the return result of list db api (#31544)
issue: #31543 Signed-off-by: SimFG <bang.fu@zilliz.com>pull/31680/head
parent
b1a1cca10b
commit
8f3e0b6b41
|
@ -1041,7 +1041,7 @@ func (kc *Catalog) ListGrant(ctx context.Context, tenant string, entity *milvusp
|
|||
appendGrantEntity := func(v string, object string, objectName string) error {
|
||||
dbName := ""
|
||||
dbName, objectName = funcutil.SplitObjectName(objectName)
|
||||
if dbName != entity.DbName && dbName != util.AnyWord {
|
||||
if dbName != entity.DbName && dbName != util.AnyWord && entity.DbName != util.AnyWord {
|
||||
return nil
|
||||
}
|
||||
granteeIDKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, v)
|
||||
|
|
|
@ -2408,7 +2408,7 @@ func TestRBAC_Grant(t *testing.T) {
|
|||
{true, &milvuspb.GrantEntity{
|
||||
DbName: "*",
|
||||
Role: &milvuspb.RoleEntity{Name: "role1"},
|
||||
}, "valid role and any dbName without object", 2},
|
||||
}, "valid role and any dbName without object", 6},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
|
|
@ -2,12 +2,18 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
|
@ -205,7 +211,14 @@ func (ldt *listDatabaseTask) PreExecute(ctx context.Context) error {
|
|||
|
||||
func (ldt *listDatabaseTask) Execute(ctx context.Context) error {
|
||||
var err error
|
||||
ldt.result, err = ldt.rootCoord.ListDatabases(ctx, ldt.ListDatabasesRequest)
|
||||
curUser, _ := GetCurUserFromContext(ldt.ctx)
|
||||
if curUser != "" {
|
||||
originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeperator, curUser)
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
ldt.ctx = metadata.AppendToOutgoingContext(ldt.ctx, authKey, authValue)
|
||||
}
|
||||
ldt.result, err = ldt.rootCoord.ListDatabases(ldt.ctx, ldt.ListDatabasesRequest)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -2,13 +2,17 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
|
@ -118,7 +122,7 @@ func TestListDatabaseTask(t *testing.T) {
|
|||
rc := NewRootCoordMock()
|
||||
defer rc.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := GetContext(context.Background(), "root:123456")
|
||||
task := &listDatabaseTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ListDatabasesRequest: &milvuspb.ListDatabasesRequest{
|
||||
|
@ -149,5 +153,12 @@ func TestListDatabaseTask(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
|
||||
assert.Equal(t, UniqueID(0), task.ID())
|
||||
|
||||
md, ok := metadata.FromOutgoingContext(task.ctx)
|
||||
assert.True(t, ok)
|
||||
authorization, ok := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
assert.True(t, ok)
|
||||
expectAuth := crypto.Base64Encode("root:root")
|
||||
assert.Equal(t, expectAuth, authorization[0])
|
||||
})
|
||||
}
|
||||
|
|
|
@ -882,25 +882,7 @@ func ValidatePrivilege(entity string) error {
|
|||
}
|
||||
|
||||
func GetCurUserFromContext(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("fail to get md from the context")
|
||||
}
|
||||
authorization, ok := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
if !ok || len(authorization) < 1 {
|
||||
return "", fmt.Errorf("fail to get authorization from the md, %s:[token]", strings.ToLower(util.HeaderAuthorize))
|
||||
}
|
||||
token := authorization[0]
|
||||
rawToken, err := crypto.Base64Decode(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to decode the token, token: %s", token)
|
||||
}
|
||||
secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2)
|
||||
if len(secrets) < 2 {
|
||||
return "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken)
|
||||
}
|
||||
username := secrets[0]
|
||||
return username, nil
|
||||
return contextutil.GetCurUserFromContext(ctx)
|
||||
}
|
||||
|
||||
func GetCurUserFromContextOrDefault(ctx context.Context) string {
|
||||
|
|
|
@ -20,7 +20,10 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type listDatabaseTask struct {
|
||||
|
@ -35,6 +38,67 @@ func (t *listDatabaseTask) Prepare(ctx context.Context) error {
|
|||
|
||||
func (t *listDatabaseTask) Execute(ctx context.Context) error {
|
||||
t.Resp.Status = merr.Success()
|
||||
|
||||
getVisibleDBs := func() (typeutil.Set[string], error) {
|
||||
enableAuth := Params.CommonCfg.AuthorizationEnabled.GetAsBool()
|
||||
privilegeDBs := typeutil.NewSet[string]()
|
||||
if !enableAuth {
|
||||
privilegeDBs.Insert(util.AnyWord)
|
||||
return privilegeDBs, nil
|
||||
}
|
||||
curUser, err := contextutil.GetCurUserFromContext(ctx)
|
||||
// it will fail if the inner node server use the list database API
|
||||
if err != nil || curUser == util.UserRoot {
|
||||
privilegeDBs.Insert(util.AnyWord)
|
||||
return privilegeDBs, nil
|
||||
}
|
||||
userRoles, err := t.core.meta.SelectUser("", &milvuspb.UserEntity{
|
||||
Name: curUser,
|
||||
}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(userRoles) == 0 {
|
||||
return privilegeDBs, nil
|
||||
}
|
||||
for _, role := range userRoles[0].Roles {
|
||||
if role.GetName() == util.RoleAdmin {
|
||||
privilegeDBs.Insert(util.AnyWord)
|
||||
return privilegeDBs, nil
|
||||
}
|
||||
entities, err := t.core.meta.SelectGrant("", &milvuspb.GrantEntity{
|
||||
Role: role,
|
||||
DbName: util.AnyWord,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, entity := range entities {
|
||||
privilegeDBs.Insert(entity.GetDbName())
|
||||
if entity.GetDbName() == util.AnyWord {
|
||||
return privilegeDBs, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return privilegeDBs, nil
|
||||
}
|
||||
|
||||
isVisibleDBForCurUser := func(dbName string, visibleDBs typeutil.Set[string]) bool {
|
||||
if visibleDBs.Contain(util.AnyWord) {
|
||||
return true
|
||||
}
|
||||
return visibleDBs.Contain(dbName)
|
||||
}
|
||||
|
||||
visibleDBs, err := getVisibleDBs()
|
||||
if err != nil {
|
||||
t.Resp.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
if len(visibleDBs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret, err := t.core.meta.ListDatabases(ctx, t.GetTs())
|
||||
if err != nil {
|
||||
t.Resp.Status = merr.Status(err)
|
||||
|
@ -44,6 +108,9 @@ func (t *listDatabaseTask) Execute(ctx context.Context) error {
|
|||
dbNames := make([]string, 0, len(ret))
|
||||
createdTimes := make([]uint64, 0, len(ret))
|
||||
for _, db := range ret {
|
||||
if !isVisibleDBForCurUser(db.Name, visibleDBs) {
|
||||
continue
|
||||
}
|
||||
dbNames = append(dbNames, db.Name)
|
||||
createdTimes = append(createdTimes, db.CreatedTime)
|
||||
}
|
||||
|
|
|
@ -18,18 +18,25 @@ package rootcoord
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func Test_ListDBTask(t *testing.T) {
|
||||
paramtable.Init()
|
||||
t.Run("list db fails", func(t *testing.T) {
|
||||
core := newTestCore(withInvalidMeta())
|
||||
task := &listDatabaseTask{
|
||||
|
@ -78,4 +85,199 @@ func Test_ListDBTask(t *testing.T) {
|
|||
assert.Equal(t, ret[0].Name, task.Resp.GetDbNames()[0])
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, task.Resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
t.Run("list db with auth", func(t *testing.T) {
|
||||
Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||
defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key)
|
||||
ret := []*model.Database{model.NewDefaultDatabase()}
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
|
||||
core := newTestCore(withMeta(meta))
|
||||
getTask := func() *listDatabaseTask {
|
||||
return &listDatabaseTask{
|
||||
baseTask: newBaseTask(context.TODO(), core),
|
||||
Req: &milvuspb.ListDatabasesRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_ListDatabases,
|
||||
},
|
||||
},
|
||||
Resp: &milvuspb.ListDatabasesResponse{},
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// inner node
|
||||
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(ret, nil).Once()
|
||||
|
||||
task := getTask()
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(task.Resp.GetDbNames()))
|
||||
assert.Equal(t, ret[0].Name, task.Resp.GetDbNames()[0])
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, task.Resp.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// proxy node with root user
|
||||
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(ret, nil).Once()
|
||||
|
||||
ctx := GetContext(context.Background(), "root:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(task.Resp.GetDbNames()))
|
||||
assert.Equal(t, ret[0].Name, task.Resp.GetDbNames()[0])
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, task.Resp.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// select role fail
|
||||
meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(nil, errors.New("mock select user error")).Once()
|
||||
ctx := GetContext(context.Background(), "foo:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
{
|
||||
// select role, empty result
|
||||
meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.UserResult{}, nil).Once()
|
||||
ctx := GetContext(context.Background(), "foo:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(task.Resp.GetDbNames()))
|
||||
}
|
||||
|
||||
{
|
||||
// select role, the user is added to admin role
|
||||
meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.UserResult{
|
||||
{
|
||||
User: &milvuspb.UserEntity{
|
||||
Name: "foo",
|
||||
},
|
||||
Roles: []*milvuspb.RoleEntity{
|
||||
{
|
||||
Name: "admin",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(ret, nil).Once()
|
||||
ctx := GetContext(context.Background(), "foo:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(task.Resp.GetDbNames()))
|
||||
}
|
||||
|
||||
{
|
||||
// select grant fail
|
||||
meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.UserResult{
|
||||
{
|
||||
User: &milvuspb.UserEntity{
|
||||
Name: "foo",
|
||||
},
|
||||
Roles: []*milvuspb.RoleEntity{
|
||||
{
|
||||
Name: "hoo",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).
|
||||
Return(nil, errors.New("mock select grant error")).Once()
|
||||
ctx := GetContext(context.Background(), "foo:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
{
|
||||
// normal user
|
||||
meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.UserResult{
|
||||
{
|
||||
User: &milvuspb.UserEntity{
|
||||
Name: "foo",
|
||||
},
|
||||
Roles: []*milvuspb.RoleEntity{
|
||||
{
|
||||
Name: "hoo",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{
|
||||
{
|
||||
Name: "fooDB",
|
||||
},
|
||||
{
|
||||
Name: "default",
|
||||
},
|
||||
}, nil).Once()
|
||||
meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.GrantEntity{
|
||||
{
|
||||
DbName: "fooDB",
|
||||
},
|
||||
}, nil).Once()
|
||||
ctx := GetContext(context.Background(), "foo:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(task.Resp.GetDbNames()))
|
||||
assert.Equal(t, "fooDB", task.Resp.GetDbNames()[0])
|
||||
}
|
||||
|
||||
{
|
||||
// normal user with any db privilege
|
||||
meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.UserResult{
|
||||
{
|
||||
User: &milvuspb.UserEntity{
|
||||
Name: "foo",
|
||||
},
|
||||
Roles: []*milvuspb.RoleEntity{
|
||||
{
|
||||
Name: "hoo",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{
|
||||
{
|
||||
Name: "fooDB",
|
||||
},
|
||||
{
|
||||
Name: "default",
|
||||
},
|
||||
}, nil).Once()
|
||||
meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).
|
||||
Return([]*milvuspb.GrantEntity{
|
||||
{
|
||||
DbName: "*",
|
||||
},
|
||||
}, nil).Once()
|
||||
ctx := GetContext(context.Background(), "foo:root")
|
||||
task := getTask()
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(task.Resp.GetDbNames()))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func GetContext(ctx context.Context, originValue string) context.Context {
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
contextMap := map[string]string{
|
||||
authKey: authValue,
|
||||
}
|
||||
md := metadata.New(contextMap)
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
|
|
@ -19,8 +19,12 @@ package contextutil
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
)
|
||||
|
||||
type ctxTenantKey struct{}
|
||||
|
@ -58,3 +62,25 @@ func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context
|
|||
}
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
||||
func GetCurUserFromContext(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("fail to get md from the context")
|
||||
}
|
||||
authorization, ok := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
if !ok || len(authorization) < 1 {
|
||||
return "", fmt.Errorf("fail to get authorization from the md, %s:[token]", strings.ToLower(util.HeaderAuthorize))
|
||||
}
|
||||
token := authorization[0]
|
||||
rawToken, err := crypto.Base64Decode(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to decode the token, token: %s", token)
|
||||
}
|
||||
secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2)
|
||||
if len(secrets) < 2 {
|
||||
return "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken)
|
||||
}
|
||||
username := secrets[0]
|
||||
return username, nil
|
||||
}
|
||||
|
|
|
@ -20,10 +20,15 @@ package contextutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
)
|
||||
|
||||
func TestAppendToIncomingContext(t *testing.T) {
|
||||
|
@ -42,3 +47,30 @@ func TestAppendToIncomingContext(t *testing.T) {
|
|||
assert.Equal(t, "bar", md.Get("foo")[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCurUserFromContext(t *testing.T) {
|
||||
_, err := GetCurUserFromContext(context.Background())
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = GetCurUserFromContext(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{})))
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = GetCurUserFromContext(GetContext(context.Background(), "123456"))
|
||||
assert.Error(t, err)
|
||||
|
||||
root := "root"
|
||||
password := "123456"
|
||||
username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password)))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "root", username)
|
||||
}
|
||||
|
||||
func GetContext(ctx context.Context, originValue string) context.Context {
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
contextMap := map[string]string{
|
||||
authKey: authValue,
|
||||
}
|
||||
md := metadata.New(contextMap)
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue