Fix the unsafe casbin `Model` (#21129)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/21218/head
SimFG 2022-12-14 10:29:22 +08:00 committed by GitHub
parent 655db658c8
commit f31d5facff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 23 deletions

View File

@ -338,7 +338,6 @@ func (s *Server) init() error {
s.etcdCli = etcdCli s.etcdCli = etcdCli
s.proxy.SetEtcdClient(s.etcdCli) s.proxy.SetEtcdClient(s.etcdCli)
s.proxy.SetAddress(fmt.Sprintf("%s:%d", Params.IP, Params.InternalPort)) s.proxy.SetAddress(fmt.Sprintf("%s:%d", Params.IP, Params.InternalPort))
proxy.InitPolicyModel()
errChan := make(chan error, 1) errChan := make(chan error, 1)
{ {

View File

@ -625,7 +625,7 @@ func (kc *Catalog) CreateRole(ctx context.Context, tenant string, entity *milvus
k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, entity.Name) k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, entity.Name)
err := kc.save(k) err := kc.save(k)
if err != nil { if err != nil {
log.Error("fail to save the role", zap.String("key", k), zap.Error(err)) log.Warn("fail to save the role", zap.String("key", k), zap.Error(err))
} }
return err return err
} }

View File

@ -40,16 +40,14 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su
` `
) )
var ( var templateModel = getPolicyModel(ModelStr)
casbinModel model.Model
)
func InitPolicyModel() { func getPolicyModel(modelString string) model.Model {
var err error m, err := model.NewModelFromString(modelString)
casbinModel, err = model.NewModelFromString(ModelStr)
if err != nil { if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
} }
return m
} }
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
@ -107,9 +105,8 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
policy := fmt.Sprintf("[%s]", policyInfo) policy := fmt.Sprintf("[%s]", policyInfo)
b := []byte(policy) b := []byte(policy)
a := jsonadapter.NewAdapter(&b) a := jsonadapter.NewAdapter(&b)
if casbinModel == nil { // the `templateModel` object isn't safe in the concurrent situation
log.Panic("fail to get policy model") casbinModel := templateModel.Copy()
}
e, err := casbin.NewEnforcer(casbinModel, a) e, err := casbin.NewEnforcer(casbinModel, a)
if err != nil { if err != nil {
log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err)) log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err))

View File

@ -2,14 +2,14 @@ package proxy
import ( import (
"context" "context"
"sync"
"testing" "testing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -22,7 +22,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("Authorization Disabled", func(t *testing.T) { t.Run("Authorization Disabled", func(t *testing.T) {
InitPolicyModel()
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "false") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "false")
_, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{ _, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{
DbName: "db_test", DbName: "db_test",
@ -32,7 +31,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
}) })
t.Run("Authorization Enabled", func(t *testing.T) { t.Run("Authorization Enabled", func(t *testing.T) {
InitPolicyModel()
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
_, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{}) _, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{})
@ -114,12 +112,23 @@ func TestPrivilegeInterceptor(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
casbinModel = nil g := sync.WaitGroup{}
for i := 0; i < 20; i++ {
g.Add(1)
go func() {
defer g.Done()
assert.NotPanics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
})
}()
}
g.Wait()
assert.Panics(t, func() { assert.Panics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ getPolicyModel("foo")
DbName: "db_test",
CollectionName: "col1",
})
}) })
}) })

View File

@ -218,6 +218,7 @@ func TestExpireAfterAccess(t *testing.T) {
} }
mockTime := newMockTime() mockTime := newMockTime()
currentTime = mockTime.now currentTime = mockTime.now
defer resetCurrentTime()
c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn), c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn),
WithInsertionListener(fn)).(*localCache[uint, uint]) WithInsertionListener(fn)).(*localCache[uint, uint])
defer c.Close() defer c.Close()
@ -258,6 +259,7 @@ func TestExpireAfterWrite(t *testing.T) {
mockTime := newMockTime() mockTime := newMockTime()
currentTime = mockTime.now currentTime = mockTime.now
defer resetCurrentTime()
c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second)) c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second))
defer c.Close() defer c.Close()
@ -306,6 +308,7 @@ func TestRefreshAterWrite(t *testing.T) {
} }
mockTime := newMockTime() mockTime := newMockTime()
currentTime = mockTime.now currentTime = mockTime.now
defer resetCurrentTime()
c := NewLoadingCache(loader, c := NewLoadingCache(loader,
WithExpireAfterAccess[int, int](4*time.Second), WithExpireAfterAccess[int, int](4*time.Second),
WithRefreshAfterWrite[int, int](2*time.Second), WithRefreshAfterWrite[int, int](2*time.Second),
@ -347,6 +350,7 @@ func TestGetIfPresentExpired(t *testing.T) {
c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc)) c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc))
mockTime := newMockTime() mockTime := newMockTime()
currentTime = mockTime.now currentTime = mockTime.now
defer resetCurrentTime()
v, ok := c.GetIfPresent(0) v, ok := c.GetIfPresent(0)
assert.False(t, ok) assert.False(t, ok)
@ -409,20 +413,20 @@ func TestWithAsyncInitPreLoader(t *testing.T) {
func TestSynchronousReload(t *testing.T) { func TestSynchronousReload(t *testing.T) {
var val string var val string
loader := func(k int) (string, error) { loader := func(k int) (string, error) {
time.Sleep(1 * time.Millisecond)
if val == "" { if val == "" {
return "", errors.New("empty") return "", errors.New("empty")
} }
return val, nil return val, nil
} }
c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](1*time.Second)) c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](200*time.Millisecond))
val = "a" val = "a"
v, err := c.Get(1) v, err := c.Get(1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, val, v) assert.Equal(t, val, v)
val = "b" val = "b"
time.Sleep(300 * time.Millisecond)
v, err = c.Get(1) v, err = c.Get(1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, val, v) assert.Equal(t, val, v)
@ -500,3 +504,7 @@ func (t *mockTime) now() time.Time {
defer t.mu.RUnlock() defer t.mu.RUnlock()
return t.value return t.value
} }
func resetCurrentTime() {
currentTime = time.Now
}