mirror of https://github.com/milvus-io/milvus.git
Fix the unsafe casbin `Model` (#21129)
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>pull/21218/head
parent
655db658c8
commit
f31d5facff
|
@ -338,7 +338,6 @@ func (s *Server) init() error {
|
||||||
s.etcdCli = etcdCli
|
s.etcdCli = etcdCli
|
||||||
s.proxy.SetEtcdClient(s.etcdCli)
|
s.proxy.SetEtcdClient(s.etcdCli)
|
||||||
s.proxy.SetAddress(fmt.Sprintf("%s:%d", Params.IP, Params.InternalPort))
|
s.proxy.SetAddress(fmt.Sprintf("%s:%d", Params.IP, Params.InternalPort))
|
||||||
proxy.InitPolicyModel()
|
|
||||||
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
{
|
{
|
||||||
|
|
|
@ -625,7 +625,7 @@ func (kc *Catalog) CreateRole(ctx context.Context, tenant string, entity *milvus
|
||||||
k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, entity.Name)
|
k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, entity.Name)
|
||||||
err := kc.save(k)
|
err := kc.save(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("fail to save the role", zap.String("key", k), zap.Error(err))
|
log.Warn("fail to save the role", zap.String("key", k), zap.Error(err))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,16 +40,14 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var templateModel = getPolicyModel(ModelStr)
|
||||||
casbinModel model.Model
|
|
||||||
)
|
|
||||||
|
|
||||||
func InitPolicyModel() {
|
func getPolicyModel(modelString string) model.Model {
|
||||||
var err error
|
m, err := model.NewModelFromString(modelString)
|
||||||
casbinModel, err = model.NewModelFromString(ModelStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
|
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
|
||||||
}
|
}
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
|
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
|
||||||
|
@ -107,9 +105,8 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
|
||||||
policy := fmt.Sprintf("[%s]", policyInfo)
|
policy := fmt.Sprintf("[%s]", policyInfo)
|
||||||
b := []byte(policy)
|
b := []byte(policy)
|
||||||
a := jsonadapter.NewAdapter(&b)
|
a := jsonadapter.NewAdapter(&b)
|
||||||
if casbinModel == nil {
|
// the `templateModel` object isn't safe in the concurrent situation
|
||||||
log.Panic("fail to get policy model")
|
casbinModel := templateModel.Copy()
|
||||||
}
|
|
||||||
e, err := casbin.NewEnforcer(casbinModel, a)
|
e, err := casbin.NewEnforcer(casbinModel, a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err))
|
log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err))
|
||||||
|
|
|
@ -2,14 +2,14 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
|
||||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
t.Run("Authorization Disabled", func(t *testing.T) {
|
t.Run("Authorization Disabled", func(t *testing.T) {
|
||||||
InitPolicyModel()
|
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "false")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "false")
|
||||||
_, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{
|
_, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{
|
||||||
DbName: "db_test",
|
DbName: "db_test",
|
||||||
|
@ -32,7 +31,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Authorization Enabled", func(t *testing.T) {
|
t.Run("Authorization Enabled", func(t *testing.T) {
|
||||||
InitPolicyModel()
|
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
|
|
||||||
_, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{})
|
_, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{})
|
||||||
|
@ -114,12 +112,23 @@ func TestPrivilegeInterceptor(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
casbinModel = nil
|
g := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
g.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer g.Done()
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
|
||||||
|
DbName: "db_test",
|
||||||
|
CollectionName: "col1",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
g.Wait()
|
||||||
|
|
||||||
assert.Panics(t, func() {
|
assert.Panics(t, func() {
|
||||||
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
|
getPolicyModel("foo")
|
||||||
DbName: "db_test",
|
|
||||||
CollectionName: "col1",
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,7 @@ func TestExpireAfterAccess(t *testing.T) {
|
||||||
}
|
}
|
||||||
mockTime := newMockTime()
|
mockTime := newMockTime()
|
||||||
currentTime = mockTime.now
|
currentTime = mockTime.now
|
||||||
|
defer resetCurrentTime()
|
||||||
c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn),
|
c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn),
|
||||||
WithInsertionListener(fn)).(*localCache[uint, uint])
|
WithInsertionListener(fn)).(*localCache[uint, uint])
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
@ -258,6 +259,7 @@ func TestExpireAfterWrite(t *testing.T) {
|
||||||
|
|
||||||
mockTime := newMockTime()
|
mockTime := newMockTime()
|
||||||
currentTime = mockTime.now
|
currentTime = mockTime.now
|
||||||
|
defer resetCurrentTime()
|
||||||
c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second))
|
c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second))
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
|
@ -306,6 +308,7 @@ func TestRefreshAterWrite(t *testing.T) {
|
||||||
}
|
}
|
||||||
mockTime := newMockTime()
|
mockTime := newMockTime()
|
||||||
currentTime = mockTime.now
|
currentTime = mockTime.now
|
||||||
|
defer resetCurrentTime()
|
||||||
c := NewLoadingCache(loader,
|
c := NewLoadingCache(loader,
|
||||||
WithExpireAfterAccess[int, int](4*time.Second),
|
WithExpireAfterAccess[int, int](4*time.Second),
|
||||||
WithRefreshAfterWrite[int, int](2*time.Second),
|
WithRefreshAfterWrite[int, int](2*time.Second),
|
||||||
|
@ -347,6 +350,7 @@ func TestGetIfPresentExpired(t *testing.T) {
|
||||||
c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc))
|
c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc))
|
||||||
mockTime := newMockTime()
|
mockTime := newMockTime()
|
||||||
currentTime = mockTime.now
|
currentTime = mockTime.now
|
||||||
|
defer resetCurrentTime()
|
||||||
|
|
||||||
v, ok := c.GetIfPresent(0)
|
v, ok := c.GetIfPresent(0)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
|
@ -409,20 +413,20 @@ func TestWithAsyncInitPreLoader(t *testing.T) {
|
||||||
func TestSynchronousReload(t *testing.T) {
|
func TestSynchronousReload(t *testing.T) {
|
||||||
var val string
|
var val string
|
||||||
loader := func(k int) (string, error) {
|
loader := func(k int) (string, error) {
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
if val == "" {
|
if val == "" {
|
||||||
return "", errors.New("empty")
|
return "", errors.New("empty")
|
||||||
}
|
}
|
||||||
return val, nil
|
return val, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](1*time.Second))
|
c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](200*time.Millisecond))
|
||||||
val = "a"
|
val = "a"
|
||||||
v, err := c.Get(1)
|
v, err := c.Get(1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, val, v)
|
assert.Equal(t, val, v)
|
||||||
|
|
||||||
val = "b"
|
val = "b"
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
v, err = c.Get(1)
|
v, err = c.Get(1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, val, v)
|
assert.Equal(t, val, v)
|
||||||
|
@ -500,3 +504,7 @@ func (t *mockTime) now() time.Time {
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.RUnlock()
|
||||||
return t.value
|
return t.value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resetCurrentTime() {
|
||||||
|
currentTime = time.Now
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue