diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 3e87b3a31f..95293ef67c 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -338,7 +338,6 @@ func (s *Server) init() error { s.etcdCli = etcdCli s.proxy.SetEtcdClient(s.etcdCli) s.proxy.SetAddress(fmt.Sprintf("%s:%d", Params.IP, Params.InternalPort)) - proxy.InitPolicyModel() errChan := make(chan error, 1) { diff --git a/internal/metastore/kv/rootcoord/kv_catalog.go b/internal/metastore/kv/rootcoord/kv_catalog.go index 308b723402..28402bc254 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog.go +++ b/internal/metastore/kv/rootcoord/kv_catalog.go @@ -625,7 +625,7 @@ func (kc *Catalog) CreateRole(ctx context.Context, tenant string, entity *milvus k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, entity.Name) err := kc.save(k) if err != nil { - log.Error("fail to save the role", zap.String("key", k), zap.Error(err)) + log.Warn("fail to save the role", zap.String("key", k), zap.Error(err)) } return err } diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index dc2dc670d2..e78f296ddc 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -40,16 +40,14 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su ` ) -var ( - casbinModel model.Model -) +var templateModel = getPolicyModel(ModelStr) -func InitPolicyModel() { - var err error - casbinModel, err = model.NewModelFromString(ModelStr) +func getPolicyModel(modelString string) model.Model { + m, err := model.NewModelFromString(modelString) if err != nil { log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) } + return m } // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. @@ -107,9 +105,8 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context policy := fmt.Sprintf("[%s]", policyInfo) b := []byte(policy) a := jsonadapter.NewAdapter(&b) - if casbinModel == nil { - log.Panic("fail to get policy model") - } + // the `templateModel` object isn't safe in the concurrent situation + casbinModel := templateModel.Copy() e, err := casbin.NewEnforcer(casbinModel, a) if err != nil { log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err)) diff --git a/internal/proxy/privilege_interceptor_test.go b/internal/proxy/privilege_interceptor_test.go index 95a5f12a73..18be50a92a 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -2,14 +2,14 @@ package proxy import ( "context" + "sync" "testing" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/paramtable" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/stretchr/testify/assert" ) @@ -22,7 +22,6 @@ func TestPrivilegeInterceptor(t *testing.T) { ctx := context.Background() t.Run("Authorization Disabled", func(t *testing.T) { - InitPolicyModel() paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "false") _, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{ DbName: "db_test", @@ -32,7 +31,6 @@ func TestPrivilegeInterceptor(t *testing.T) { }) t.Run("Authorization Enabled", func(t *testing.T) { - InitPolicyModel() paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") _, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{}) @@ -114,12 +112,23 @@ func TestPrivilegeInterceptor(t *testing.T) { }) assert.Nil(t, err) - casbinModel = nil + g := sync.WaitGroup{} + for i := 0; i < 20; i++ { + g.Add(1) + go func() { + defer g.Done() + assert.NotPanics(t, func() { + PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ + DbName: "db_test", + CollectionName: "col1", + }) + }) + }() + } + g.Wait() + assert.Panics(t, func() { - PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ - DbName: "db_test", - CollectionName: "col1", - }) + getPolicyModel("foo") }) }) diff --git a/internal/util/cache/local_cache_test.go b/internal/util/cache/local_cache_test.go index 71464d772f..d49e078103 100644 --- a/internal/util/cache/local_cache_test.go +++ b/internal/util/cache/local_cache_test.go @@ -218,6 +218,7 @@ func TestExpireAfterAccess(t *testing.T) { } mockTime := newMockTime() currentTime = mockTime.now + defer resetCurrentTime() c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn), WithInsertionListener(fn)).(*localCache[uint, uint]) defer c.Close() @@ -258,6 +259,7 @@ func TestExpireAfterWrite(t *testing.T) { mockTime := newMockTime() currentTime = mockTime.now + defer resetCurrentTime() c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second)) defer c.Close() @@ -306,6 +308,7 @@ func TestRefreshAterWrite(t *testing.T) { } mockTime := newMockTime() currentTime = mockTime.now + defer resetCurrentTime() c := NewLoadingCache(loader, WithExpireAfterAccess[int, int](4*time.Second), WithRefreshAfterWrite[int, int](2*time.Second), @@ -347,6 +350,7 @@ func TestGetIfPresentExpired(t *testing.T) { c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc)) mockTime := newMockTime() currentTime = mockTime.now + defer resetCurrentTime() v, ok := c.GetIfPresent(0) assert.False(t, ok) @@ -409,20 +413,20 @@ func TestWithAsyncInitPreLoader(t *testing.T) { func TestSynchronousReload(t *testing.T) { var val string loader := func(k int) (string, error) { - time.Sleep(1 * time.Millisecond) if val == "" { return "", errors.New("empty") } return val, nil } - c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](1*time.Second)) + c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](200*time.Millisecond)) val = "a" v, err := c.Get(1) assert.NoError(t, err) assert.Equal(t, val, v) val = "b" + time.Sleep(300 * time.Millisecond) v, err = c.Get(1) assert.NoError(t, err) assert.Equal(t, val, v) @@ -500,3 +504,7 @@ func (t *mockTime) now() time.Time { defer t.mu.RUnlock() return t.value } + +func resetCurrentTime() { + currentTime = time.Now +}