Fix the unsafe casbin `Model` (#21129)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/21218/head
SimFG 2022-12-14 10:29:22 +08:00 committed by GitHub
parent 655db658c8
commit f31d5facff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 23 deletions

View File

@ -338,7 +338,6 @@ func (s *Server) init() error {
s.etcdCli = etcdCli
s.proxy.SetEtcdClient(s.etcdCli)
s.proxy.SetAddress(fmt.Sprintf("%s:%d", Params.IP, Params.InternalPort))
proxy.InitPolicyModel()
errChan := make(chan error, 1)
{

View File

@ -625,7 +625,7 @@ func (kc *Catalog) CreateRole(ctx context.Context, tenant string, entity *milvus
k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, entity.Name)
err := kc.save(k)
if err != nil {
log.Error("fail to save the role", zap.String("key", k), zap.Error(err))
log.Warn("fail to save the role", zap.String("key", k), zap.Error(err))
}
return err
}

View File

@ -40,16 +40,14 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su
`
)
var (
casbinModel model.Model
)
var templateModel = getPolicyModel(ModelStr)
func InitPolicyModel() {
var err error
casbinModel, err = model.NewModelFromString(ModelStr)
func getPolicyModel(modelString string) model.Model {
m, err := model.NewModelFromString(modelString)
if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
}
return m
}
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
@ -107,9 +105,8 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
policy := fmt.Sprintf("[%s]", policyInfo)
b := []byte(policy)
a := jsonadapter.NewAdapter(&b)
if casbinModel == nil {
log.Panic("fail to get policy model")
}
// the `templateModel` object isn't safe in the concurrent situation
casbinModel := templateModel.Copy()
e, err := casbin.NewEnforcer(casbinModel, a)
if err != nil {
log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err))

View File

@ -2,14 +2,14 @@ package proxy
import (
"context"
"sync"
"testing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/assert"
)
@ -22,7 +22,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
ctx := context.Background()
t.Run("Authorization Disabled", func(t *testing.T) {
InitPolicyModel()
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "false")
_, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{
DbName: "db_test",
@ -32,7 +31,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
t.Run("Authorization Enabled", func(t *testing.T) {
InitPolicyModel()
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
_, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{})
@ -114,12 +112,23 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
assert.Nil(t, err)
casbinModel = nil
g := sync.WaitGroup{}
for i := 0; i < 20; i++ {
g.Add(1)
go func() {
defer g.Done()
assert.NotPanics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
})
}()
}
g.Wait()
assert.Panics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
getPolicyModel("foo")
})
})

View File

@ -218,6 +218,7 @@ func TestExpireAfterAccess(t *testing.T) {
}
mockTime := newMockTime()
currentTime = mockTime.now
defer resetCurrentTime()
c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn),
WithInsertionListener(fn)).(*localCache[uint, uint])
defer c.Close()
@ -258,6 +259,7 @@ func TestExpireAfterWrite(t *testing.T) {
mockTime := newMockTime()
currentTime = mockTime.now
defer resetCurrentTime()
c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second))
defer c.Close()
@ -306,6 +308,7 @@ func TestRefreshAterWrite(t *testing.T) {
}
mockTime := newMockTime()
currentTime = mockTime.now
defer resetCurrentTime()
c := NewLoadingCache(loader,
WithExpireAfterAccess[int, int](4*time.Second),
WithRefreshAfterWrite[int, int](2*time.Second),
@ -347,6 +350,7 @@ func TestGetIfPresentExpired(t *testing.T) {
c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc))
mockTime := newMockTime()
currentTime = mockTime.now
defer resetCurrentTime()
v, ok := c.GetIfPresent(0)
assert.False(t, ok)
@ -409,20 +413,20 @@ func TestWithAsyncInitPreLoader(t *testing.T) {
func TestSynchronousReload(t *testing.T) {
var val string
loader := func(k int) (string, error) {
time.Sleep(1 * time.Millisecond)
if val == "" {
return "", errors.New("empty")
}
return val, nil
}
c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](1*time.Second))
c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](200*time.Millisecond))
val = "a"
v, err := c.Get(1)
assert.NoError(t, err)
assert.Equal(t, val, v)
val = "b"
time.Sleep(300 * time.Millisecond)
v, err = c.Get(1)
assert.NoError(t, err)
assert.Equal(t, val, v)
@ -500,3 +504,7 @@ func (t *mockTime) now() time.Time {
defer t.mu.RUnlock()
return t.value
}
func resetCurrentTime() {
currentTime = time.Now
}