package proxy import ( "context" "fmt" "reflect" "strings" "sync" "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/model" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/funcutil" ) type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error) const ( // sub -> role name, like admin, public // obj -> contact object with object name, like Global-*, Collection-col1 // act -> privilege, like CreateCollection, DescribeCollection ModelStr = ` [request_definition] r = sub, obj, act [policy_definition] p = sub, obj, act [policy_effect] e = some(where (p.eft == allow)) [matchers] m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && p.act == "PrivilegeAll") ` ) var templateModel = getPolicyModel(ModelStr) var ( enforcer *casbin.SyncedEnforcer initOnce sync.Once ) func getEnforcer() *casbin.SyncedEnforcer { initOnce.Do(func() { e, err := casbin.NewSyncedEnforcer() if err != nil { log.Panic("failed to create casbin enforcer", zap.Error(err)) } casbinModel := getPolicyModel(ModelStr) adapter := NewMetaCacheCasbinAdapter(func() Cache { return globalMetaCache }) e.InitWithModelAndAdapter(casbinModel, adapter) e.AddFunction("dbMatch", DBMatchFunc) enforcer = e }) return enforcer } func getPolicyModel(modelString string) model.Model { m, err := model.NewModelFromString(modelString) if err != nil { log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) } return m } // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { newCtx, err := privilegeFunc(ctx, req) if err != nil { return nil, err } return handler(newCtx, req) } } func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context, error) { if !Params.CommonCfg.AuthorizationEnabled.GetAsBool() { return ctx, nil } log := log.Ctx(ctx) log.RatedDebug(60, "PrivilegeInterceptor", zap.String("type", reflect.TypeOf(req).String())) privilegeExt, err := funcutil.GetPrivilegeExtObj(req) if err != nil { log.RatedInfo(60, "GetPrivilegeExtObj err", zap.Error(err)) return ctx, nil } username, password, err := contextutil.GetAuthInfoFromContext(ctx) if err != nil { log.Warn("GetCurUserFromContext fail", zap.Error(err)) return ctx, err } if username == util.UserRoot { return ctx, nil } roleNames, err := GetRole(username) if err != nil { log.Warn("GetRole fail", zap.String("username", username), zap.Error(err)) return ctx, err } roleNames = append(roleNames, util.RolePublic) objectType := privilegeExt.ObjectType.String() objectNameIndex := privilegeExt.ObjectNameIndex objectName := funcutil.GetObjectName(req, objectNameIndex) if isCurUserObject(objectType, username, objectName) { return ctx, nil } if isSelectMyRoleGrants(req, roleNames) { return ctx, nil } objectNameIndexs := privilegeExt.ObjectNameIndexs objectNames := funcutil.GetObjectNames(req, objectNameIndexs) objectPrivilege := privilegeExt.ObjectPrivilege.String() dbName := GetCurDBNameFromContextOrDefault(ctx) log = log.With(zap.String("username", username), zap.Strings("role_names", roleNames), zap.String("object_type", objectType), zap.String("object_privilege", objectPrivilege), zap.String("db_name", dbName), zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName), zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames)) e := getEnforcer() for _, roleName := range roleNames { permitFunc := func(resName string) (bool, error) { object := funcutil.PolicyForResource(dbName, objectType, resName) isPermit, cached, version := GetPrivilegeCache(roleName, object, objectPrivilege) if cached { return isPermit, nil } isPermit, err := e.Enforce(roleName, object, objectPrivilege) if err != nil { return false, err } SetPrivilegeCache(roleName, object, objectPrivilege, isPermit, version) return isPermit, nil } if objectNameIndex != 0 { // handle the api which refers one resource permitObject, err := permitFunc(objectName) if err != nil { log.Warn("fail to execute permit func", zap.String("name", objectName), zap.Error(err)) return ctx, err } if permitObject { return ctx, nil } } if objectNameIndexs != 0 { // handle the api which refers many resources permitObjects := true for _, name := range objectNames { p, err := permitFunc(name) if err != nil { log.Warn("fail to execute permit func", zap.String("name", name), zap.Error(err)) return ctx, err } if !p { permitObjects = false break } } if permitObjects && len(objectNames) != 0 { return ctx, nil } } } log.Info("permission deny", zap.Strings("roles", roleNames)) if password == util.PasswordHolder { username = "apikey user" } return ctx, status.Error(codes.PermissionDenied, fmt.Sprintf("%s: permission deny to %s in the `%s` database", objectPrivilege, username, dbName)) } // isCurUserObject Determine whether it is an Object of type User that operates on its own user information, // like updating password or viewing your own role information. // make users operate their own user information when the related privileges are not granted. func isCurUserObject(objectType string, curUser string, object string) bool { if objectType != commonpb.ObjectType_User.String() { return false } return curUser == object } func isSelectMyRoleGrants(req interface{}, roleNames []string) bool { selectGrantReq, ok := req.(*milvuspb.SelectGrantRequest) if !ok { return false } filterGrantEntity := selectGrantReq.GetEntity() roleName := filterGrantEntity.GetRole().GetName() return funcutil.SliceContain(roleNames, roleName) } func DBMatchFunc(args ...interface{}) (interface{}, error) { name1 := args[0].(string) name2 := args[1].(string) db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:]) db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:]) return db1 == db2, nil }